reorganize code

This commit is contained in:
Zack Scholl 2019-07-10 16:24:00 -07:00
parent c5109f51ca
commit 8a1b684a03

292
main.go
View file

@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"os"
"path/filepath"
@ -20,49 +21,8 @@ import (
"github.com/vincent-petithory/dataurl"
)
type server struct {
publicURL string
port string
// connections stored as map of domain -> connections
conn map[string][]*Connection
sync.Mutex
}
// Connection determine what can be held
type Connection struct {
ID int
Joined time.Time
Domain string
LastGet string
ws *WebsocketConn
}
type WebsocketConn struct {
ws *websocket.Conn
sync.Mutex
}
func NewWebsocket(ws *websocket.Conn) *WebsocketConn {
return &WebsocketConn{
ws: ws,
}
}
func (ws *WebsocketConn) Send(p Payload) (err error) {
ws.Lock()
defer ws.Unlock()
log.Tracef("sending %+v", p)
err = ws.ws.WriteJSON(p)
return
}
func (ws *WebsocketConn) Receive() (p Payload, err error) {
ws.Lock()
defer ws.Unlock()
err = ws.ws.ReadJSON(&p)
log.Tracef("recv %+v", p)
return
func init() {
rand.Seed(time.Now().UTC().UnixNano())
}
var Version string
@ -119,47 +79,15 @@ func main() {
if err != nil {
log.Debug(err)
}
// var debug, flagServer bool
// var flagPort, flagPublicURL string
// flag.StringVar(&flagPort, "port", "8001", "port")
// flag.StringVar(&flagPublicURL, "url", "", "public url to use")
// flag.BoolVar(&debug, "debug", false, "debug mode")
// flag.BoolVar(&flagServer, "serve", false, "serve files")
// flag.Parse()
// if debug {
// log.SetLevel("debug")
// } else {
// log.SetLevel("info")
// }
// if flagServer {
// log.SetLevel("trace")
// client := &Client{
// WebsocketURL: "wss://omni.schollz.com/ws",
// }
// client.Run()
// os.Exit(1)
// }
// if flagPublicURL == "" {
// flagPublicURL = "localhost:" + flagPort
// }
// if !strings.HasPrefix(flagPublicURL, "http") {
// flagPublicURL = "http://" + flagPublicURL
// }
// s := &server{
// port: flagPort,
// conn: make(map[string][]*Connection),
// publicURL: flagPublicURL,
// }
// s.serve()
}
func host(c *cli.Context) (err error) {
if c.GlobalBool("debug") {
log.SetLevel("debug")
} else {
log.SetLevel("info")
}
return
}
@ -186,6 +114,85 @@ func relay(c *cli.Context) (err error) {
return s.serve()
}
//
// websocket implementation
//
var wsupgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// Connection determine what can be held
type Connection struct {
ID int
Joined time.Time
Domain string
LastGet string
ws *WebsocketConn
}
// Payload lists the data exchanged
type Payload struct {
Success bool `json:"success"`
Type string `json:"type,omitempty"`
Message string `json:"message,omitempty"`
IPAddress string `json:"ip,omitempty"`
Key string `json:"key,omitempty"`
}
func (p Payload) String() string {
b, _ := json.Marshal(p)
return string(b)
}
// WebsocketConn provides convenience functions for sending
// and receiving data with websockets, using mutex to
// make sure only one writer/reader
type WebsocketConn struct {
ws *websocket.Conn
sync.Mutex
}
// NewWebsocket returns a new websocket
func NewWebsocket(ws *websocket.Conn) *WebsocketConn {
return &WebsocketConn{
ws: ws,
}
}
func (ws *WebsocketConn) Send(p Payload) (err error) {
ws.Lock()
defer ws.Unlock()
log.Tracef("sending %+v", p)
err = ws.ws.WriteJSON(p)
return
}
func (ws *WebsocketConn) Receive() (p Payload, err error) {
ws.Lock()
defer ws.Unlock()
err = ws.ws.ReadJSON(&p)
log.Tracef("recv %+v", p)
return
}
//
// server implementation
//
type server struct {
publicURL string
port string
// connections stored as map of domain -> connections
conn map[string][]*Connection
sync.Mutex
}
func (s *server) serve() (err error) {
log.Infof("listening on :%s", s.port)
http.HandleFunc("/", s.handler)
@ -204,6 +211,7 @@ func (s *server) handler(w http.ResponseWriter, r *http.Request) {
func (s *server) handle(w http.ResponseWriter, r *http.Request) (err error) {
log.Debugf("URL: %s, Referer: %s", r.URL.Path, r.Referer())
// very special paths
if r.URL.Path == "/robots.txt" {
// special path
@ -229,49 +237,56 @@ Disallow: /`))
}
return t.Execute(w, view{PublicURL: s.publicURL})
} else {
log.Debugf("attempting to find %s", r.URL.Path)
pathToFile := r.URL.Path[1:]
domain := strings.Split(r.URL.Path[1:], "/")[0]
// check to make sure it has domain prepended
piecesOfReferer := strings.Split(r.Referer(), "/")
if len(piecesOfReferer) > 4 && strings.HasPrefix(r.Referer(), s.publicURL) {
domain = piecesOfReferer[3]
}
// prefix the domain if it doesn't exist
if !strings.HasPrefix(pathToFile, domain) {
pathToFile = domain + "/" + pathToFile
http.Redirect(w, r, "/"+pathToFile, 302)
return
}
// add index.html if it doesn't exist
if filepath.Ext(pathToFile) == "" {
if string(pathToFile[len(pathToFile)-1]) != "/" {
pathToFile += "/"
}
pathToFile += "index.html"
http.Redirect(w, r, "/"+pathToFile, 302)
return
}
// get IP address
var ipAddress string
ipAddress, err = GetClientIPHelper(r)
if err != nil {
log.Debugf("could not determine ip: %s", err.Error())
}
var data string
data, err = s.get(pathToFile, ipAddress)
if err != nil {
log.Debugf("attempting to find %s", r.URL.Path)
// determine file path and the domain
pathToFile := r.URL.Path[1:]
domain := strings.Split(r.URL.Path[1:], "/")[0]
// if there is a referer, try to obtain the domain from referer
piecesOfReferer := strings.Split(r.Referer(), "/")
if len(piecesOfReferer) > 4 && strings.HasPrefix(r.Referer(), s.publicURL) {
domain = piecesOfReferer[3]
}
// prefix the domain if it doesn't exist
if !strings.HasPrefix(pathToFile, domain) {
pathToFile = domain + "/" + pathToFile
http.Redirect(w, r, "/"+pathToFile, 302)
return
}
// trim prefix to get the path to file
pathToFile = strings.TrimPrefix(pathToFile, domain+"/")
// send GET request to websockets
var data string
data, err = s.get(domain, pathToFile, ipAddress)
if err != nil {
// try index.html if it doesn't exist
if string(pathToFile[len(pathToFile)-1]) != "/" {
pathToFile += "/"
}
pathToFile += "index.html"
data, err = s.get(pathToFile, ipAddress)
if err != nil {
return
}
}
// decode the data URI
var dataURL *dataurl.DataURL
dataURL, err = dataurl.DecodeString(data)
if err != nil {
log.Errorf("problem decoding '%s': %s", data, err.Error())
return
}
// determine the content type
contentType := dataURL.MediaType.ContentType()
if contentType == "application/octet-stream" || contentType == "" {
pathToFileExt := filepath.Ext(pathToFile)
@ -288,6 +303,8 @@ Disallow: /`))
}
}
}
// write the data to the requester
w.Header().Set("Content-Type", contentType)
w.Write(dataURL.Data)
return
@ -295,27 +312,6 @@ Disallow: /`))
return
}
var wsupgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
type Payload struct {
// message meta
Type string `json:"type"`
Success bool `json:"success"`
Message string `json:"message"`
IPAddress string `json:"ip"`
}
func (p Payload) String() string {
b, _ := json.Marshal(p)
return string(b)
}
func (s *server) handleWebsocket(w http.ResponseWriter, r *http.Request) (err error) {
// handle websockets on this page
c, errUpgrade := wsupgrader.Upgrade(w, r, nil)
@ -336,7 +332,7 @@ func (s *server) handleWebsocket(w http.ResponseWriter, r *http.Request) (err er
}
log.Debugf("recv: %s", p)
if !(p.Type == "domain" && p.Message != "") {
if !(p.Type == "domain" && p.Message != "" && p.Key != "") {
err = fmt.Errorf("got wrong type/domain: %s/%s", p.Type, p.Message)
log.Debug(err)
c.Close()
@ -355,6 +351,7 @@ func (s *server) handleWebsocket(w http.ResponseWriter, r *http.Request) (err er
ID: len(s.conn[domain]),
Domain: domain,
Joined: time.Now(),
Key: p.Key,
ws: NewWebsocket(c),
}
s.conn[domain] = append(s.conn[domain], conn)
@ -370,9 +367,7 @@ func (s *server) handleWebsocket(w http.ResponseWriter, r *http.Request) (err er
return
}
func (s *server) get(filePath, ipAddress string) (payload string, err error) {
log.Debugf("requesting %s", filePath)
domain := strings.Split(filePath, "/")[0]
func (s *server) get(domain, filePath, ipAddress string) (payload string, err error) {
var connections []*Connection
s.Lock()
@ -382,11 +377,16 @@ func (s *server) get(filePath, ipAddress string) (payload string, err error) {
s.Unlock()
if connections == nil {
err = fmt.Errorf("no connections available for domain %s", domain)
log.Debug(err)
return
}
log.Debugf("requesting %s/%s from %d connections", domain, filePath, len(connections))
// loop through connections and try to get one to serve the file
for _, conn := range connections {
// any connection that initated with this key is viable
key := connections[0].Key
// loop through connections randomly and try to get one to serve the file
for _, i := range rand.Perm(len(connections)) {
var p Payload
p, err = func() (p Payload, err error) {
err = conn.ws.Send(Payload{
@ -397,7 +397,7 @@ func (s *server) get(filePath, ipAddress string) (payload string, err error) {
if err != nil {
return
}
p, err = conn.ws.Receive()
p, err = connections[i].ws.Receive()
return
}()
if err != nil {
@ -405,20 +405,20 @@ func (s *server) get(filePath, ipAddress string) (payload string, err error) {
s.DumpConnection(domain, conn.ID)
continue
}
if p.Type == "get" {
if len(p.Message) > 10 {
p.Message = p.Message[:10] + "..."
}
log.Debugf("recv: %+v", p)
if p.Type == "get" && p.Key == key {
payload = p.Message
if !p.Success {
err = fmt.Errorf(payload)
}
return
}
if len(p.Message) > 10 {
p.Message = p.Message[:10] + "..."
}
log.Debugf("recv: %+v", p)
err = fmt.Errorf("invalid response")
break
log.Debugf("no good data from %d", i)
}
err = fmt.Errorf("invalid response")
return
}