diff --git a/main.go b/main.go index 344d073..fc5b9a8 100644 --- a/main.go +++ b/main.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "io/ioutil" + "math/rand" "net/http" "os" "path/filepath" @@ -20,49 +21,8 @@ import ( "github.com/vincent-petithory/dataurl" ) -type server struct { - publicURL string - port string - - // connections stored as map of domain -> connections - conn map[string][]*Connection - sync.Mutex -} - -// Connection determine what can be held -type Connection struct { - ID int - Joined time.Time - Domain string - LastGet string - ws *WebsocketConn -} - -type WebsocketConn struct { - ws *websocket.Conn - sync.Mutex -} - -func NewWebsocket(ws *websocket.Conn) *WebsocketConn { - return &WebsocketConn{ - ws: ws, - } -} - -func (ws *WebsocketConn) Send(p Payload) (err error) { - ws.Lock() - defer ws.Unlock() - log.Tracef("sending %+v", p) - err = ws.ws.WriteJSON(p) - return -} - -func (ws *WebsocketConn) Receive() (p Payload, err error) { - ws.Lock() - defer ws.Unlock() - err = ws.ws.ReadJSON(&p) - log.Tracef("recv %+v", p) - return +func init() { + rand.Seed(time.Now().UTC().UnixNano()) } var Version string @@ -119,47 +79,15 @@ func main() { if err != nil { log.Debug(err) } - // var debug, flagServer bool - // var flagPort, flagPublicURL string - // flag.StringVar(&flagPort, "port", "8001", "port") - // flag.StringVar(&flagPublicURL, "url", "", "public url to use") - // flag.BoolVar(&debug, "debug", false, "debug mode") - // flag.BoolVar(&flagServer, "serve", false, "serve files") - // flag.Parse() - - // if debug { - // log.SetLevel("debug") - // } else { - // log.SetLevel("info") - // } - - // if flagServer { - // log.SetLevel("trace") - // client := &Client{ - // WebsocketURL: "wss://omni.schollz.com/ws", - // } - // client.Run() - - // os.Exit(1) - // } - - // if flagPublicURL == "" { - // flagPublicURL = "localhost:" + flagPort - // } - // if !strings.HasPrefix(flagPublicURL, "http") { - // flagPublicURL = "http://" + flagPublicURL - // } - - // s := &server{ - // port: flagPort, - // conn: make(map[string][]*Connection), - // publicURL: flagPublicURL, - // } - - // s.serve() } func host(c *cli.Context) (err error) { + if c.GlobalBool("debug") { + log.SetLevel("debug") + } else { + log.SetLevel("info") + } + return } @@ -186,6 +114,85 @@ func relay(c *cli.Context) (err error) { return s.serve() } +// +// websocket implementation +// + +var wsupgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +// Connection determine what can be held +type Connection struct { + ID int + Joined time.Time + Domain string + LastGet string + ws *WebsocketConn +} + +// Payload lists the data exchanged +type Payload struct { + Success bool `json:"success"` + Type string `json:"type,omitempty"` + Message string `json:"message,omitempty"` + IPAddress string `json:"ip,omitempty"` + Key string `json:"key,omitempty"` +} + +func (p Payload) String() string { + b, _ := json.Marshal(p) + return string(b) +} + +// WebsocketConn provides convenience functions for sending +// and receiving data with websockets, using mutex to +// make sure only one writer/reader +type WebsocketConn struct { + ws *websocket.Conn + sync.Mutex +} + +// NewWebsocket returns a new websocket +func NewWebsocket(ws *websocket.Conn) *WebsocketConn { + return &WebsocketConn{ + ws: ws, + } +} + +func (ws *WebsocketConn) Send(p Payload) (err error) { + ws.Lock() + defer ws.Unlock() + log.Tracef("sending %+v", p) + err = ws.ws.WriteJSON(p) + return +} + +func (ws *WebsocketConn) Receive() (p Payload, err error) { + ws.Lock() + defer ws.Unlock() + err = ws.ws.ReadJSON(&p) + log.Tracef("recv %+v", p) + return +} + +// +// server implementation +// + +type server struct { + publicURL string + port string + + // connections stored as map of domain -> connections + conn map[string][]*Connection + sync.Mutex +} + func (s *server) serve() (err error) { log.Infof("listening on :%s", s.port) http.HandleFunc("/", s.handler) @@ -204,6 +211,7 @@ func (s *server) handler(w http.ResponseWriter, r *http.Request) { func (s *server) handle(w http.ResponseWriter, r *http.Request) (err error) { log.Debugf("URL: %s, Referer: %s", r.URL.Path, r.Referer()) + // very special paths if r.URL.Path == "/robots.txt" { // special path @@ -229,49 +237,56 @@ Disallow: /`)) } return t.Execute(w, view{PublicURL: s.publicURL}) } else { - log.Debugf("attempting to find %s", r.URL.Path) - - pathToFile := r.URL.Path[1:] - domain := strings.Split(r.URL.Path[1:], "/")[0] - // check to make sure it has domain prepended - piecesOfReferer := strings.Split(r.Referer(), "/") - if len(piecesOfReferer) > 4 && strings.HasPrefix(r.Referer(), s.publicURL) { - domain = piecesOfReferer[3] - } - - // prefix the domain if it doesn't exist - if !strings.HasPrefix(pathToFile, domain) { - pathToFile = domain + "/" + pathToFile - http.Redirect(w, r, "/"+pathToFile, 302) - return - } - - // add index.html if it doesn't exist - if filepath.Ext(pathToFile) == "" { - if string(pathToFile[len(pathToFile)-1]) != "/" { - pathToFile += "/" - } - pathToFile += "index.html" - http.Redirect(w, r, "/"+pathToFile, 302) - return - } - + // get IP address var ipAddress string ipAddress, err = GetClientIPHelper(r) if err != nil { log.Debugf("could not determine ip: %s", err.Error()) } - var data string - data, err = s.get(pathToFile, ipAddress) - if err != nil { + log.Debugf("attempting to find %s", r.URL.Path) + + // determine file path and the domain + pathToFile := r.URL.Path[1:] + domain := strings.Split(r.URL.Path[1:], "/")[0] + // if there is a referer, try to obtain the domain from referer + piecesOfReferer := strings.Split(r.Referer(), "/") + if len(piecesOfReferer) > 4 && strings.HasPrefix(r.Referer(), s.publicURL) { + domain = piecesOfReferer[3] + } + // prefix the domain if it doesn't exist + if !strings.HasPrefix(pathToFile, domain) { + pathToFile = domain + "/" + pathToFile + http.Redirect(w, r, "/"+pathToFile, 302) return } + // trim prefix to get the path to file + pathToFile = strings.TrimPrefix(pathToFile, domain+"/") + + // send GET request to websockets + var data string + data, err = s.get(domain, pathToFile, ipAddress) + if err != nil { + // try index.html if it doesn't exist + if string(pathToFile[len(pathToFile)-1]) != "/" { + pathToFile += "/" + } + pathToFile += "index.html" + data, err = s.get(pathToFile, ipAddress) + if err != nil { + return + } + } + + // decode the data URI var dataURL *dataurl.DataURL dataURL, err = dataurl.DecodeString(data) if err != nil { + log.Errorf("problem decoding '%s': %s", data, err.Error()) return } + + // determine the content type contentType := dataURL.MediaType.ContentType() if contentType == "application/octet-stream" || contentType == "" { pathToFileExt := filepath.Ext(pathToFile) @@ -288,6 +303,8 @@ Disallow: /`)) } } } + + // write the data to the requester w.Header().Set("Content-Type", contentType) w.Write(dataURL.Data) return @@ -295,27 +312,6 @@ Disallow: /`)) return } -var wsupgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return true - }, -} - -type Payload struct { - // message meta - Type string `json:"type"` - Success bool `json:"success"` - Message string `json:"message"` - IPAddress string `json:"ip"` -} - -func (p Payload) String() string { - b, _ := json.Marshal(p) - return string(b) -} - func (s *server) handleWebsocket(w http.ResponseWriter, r *http.Request) (err error) { // handle websockets on this page c, errUpgrade := wsupgrader.Upgrade(w, r, nil) @@ -336,7 +332,7 @@ func (s *server) handleWebsocket(w http.ResponseWriter, r *http.Request) (err er } log.Debugf("recv: %s", p) - if !(p.Type == "domain" && p.Message != "") { + if !(p.Type == "domain" && p.Message != "" && p.Key != "") { err = fmt.Errorf("got wrong type/domain: %s/%s", p.Type, p.Message) log.Debug(err) c.Close() @@ -355,6 +351,7 @@ func (s *server) handleWebsocket(w http.ResponseWriter, r *http.Request) (err er ID: len(s.conn[domain]), Domain: domain, Joined: time.Now(), + Key: p.Key, ws: NewWebsocket(c), } s.conn[domain] = append(s.conn[domain], conn) @@ -370,9 +367,7 @@ func (s *server) handleWebsocket(w http.ResponseWriter, r *http.Request) (err er return } -func (s *server) get(filePath, ipAddress string) (payload string, err error) { - log.Debugf("requesting %s", filePath) - domain := strings.Split(filePath, "/")[0] +func (s *server) get(domain, filePath, ipAddress string) (payload string, err error) { var connections []*Connection s.Lock() @@ -382,11 +377,16 @@ func (s *server) get(filePath, ipAddress string) (payload string, err error) { s.Unlock() if connections == nil { err = fmt.Errorf("no connections available for domain %s", domain) + log.Debug(err) return } + log.Debugf("requesting %s/%s from %d connections", domain, filePath, len(connections)) - // loop through connections and try to get one to serve the file - for _, conn := range connections { + // any connection that initated with this key is viable + key := connections[0].Key + + // loop through connections randomly and try to get one to serve the file + for _, i := range rand.Perm(len(connections)) { var p Payload p, err = func() (p Payload, err error) { err = conn.ws.Send(Payload{ @@ -397,7 +397,7 @@ func (s *server) get(filePath, ipAddress string) (payload string, err error) { if err != nil { return } - p, err = conn.ws.Receive() + p, err = connections[i].ws.Receive() return }() if err != nil { @@ -405,20 +405,20 @@ func (s *server) get(filePath, ipAddress string) (payload string, err error) { s.DumpConnection(domain, conn.ID) continue } - if p.Type == "get" { + if len(p.Message) > 10 { + p.Message = p.Message[:10] + "..." + } + log.Debugf("recv: %+v", p) + if p.Type == "get" && p.Key == key { payload = p.Message if !p.Success { err = fmt.Errorf(payload) } return } - if len(p.Message) > 10 { - p.Message = p.Message[:10] + "..." - } - log.Debugf("recv: %+v", p) - err = fmt.Errorf("invalid response") - break + log.Debugf("no good data from %d", i) } + err = fmt.Errorf("invalid response") return }