cloud-game/pkg/network/websocket/websocket.go
2024-11-17 22:27:34 +03:00

258 lines
5.9 KiB
Go

package websocket
import (
"crypto/tls"
"net"
"net/http"
"net/url"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Size and timing limits applied to every Connection.
const (
	// maxMessageSize is the read limit, in bytes, used when no custom
	// limit has been set on the connection.
	maxMessageSize = 10 << 10

	// pongTime is how far the read deadline is pushed forward each time
	// keep-alive traffic is seen.
	pongTime = 7 * time.Second

	// pingTime is the interval between pings; it is 90% of pongTime so a
	// ping always goes out before the read deadline expires.
	pingTime = pongTime * 9 / 10

	// writeWait is the time allowed for a single write to complete.
	writeWait = 5 * time.Second
)
// Client opens outgoing websocket connections.
type Client struct {
	// Dialer is used to establish the connection. When nil, Connect falls
	// back to DefaultDialer.
	Dialer *websocket.Dialer
}
// Server accepts incoming websocket connections by upgrading HTTP requests.
type Server struct {
	// Upgrader performs the HTTP -> websocket upgrade. When nil, Connect
	// falls back to DefaultUpgrader.
	Upgrader *Upgrader
}
// Connection is a websocket connection served by a reader and a writer
// goroutine, both started by Listen.
type Connection struct {
	// alive is set by Listen and cleared by close; Write and Close are no-ops
	// while it is false. It is a plain bool with no synchronization.
	alive bool
	// callback receives every message read, and read errors that are not
	// routed to errorHandler. newSocket initializes it with a no-op.
	callback MessageHandler
	// conn is the underlying websocket together with its write lock and
	// write timeout.
	conn deadlineConn
	// done is closed exactly once, by close, and is the channel returned
	// from Listen.
	done chan struct{}
	// errorHandler, when non-nil, receives unexpected close errors from the
	// reader.
	errorHandler ErrorHandler
	// once guards close so the shutdown sequence runs a single time.
	once sync.Once
	// pingPong selects the keep-alive role: when true the writer sends
	// periodic pings and the reader handles pongs; when false the reader
	// answers incoming pings. Server.Connect sets it to true, Client.Connect
	// to false.
	pingPong bool
	// send carries outgoing messages from Write to the writer goroutine.
	// The reader closes it when it exits.
	send chan []byte
	// messSize overrides maxMessageSize as the read limit when positive.
	messSize int64
}
// deadlineConn wraps a websocket connection so that writes are serialized
// by a mutex and data writes get a fresh write deadline.
type deadlineConn struct {
	*websocket.Conn
	// wt is the timeout added to the current time to form the write
	// deadline of each write call.
	wt time.Duration
	mu sync.Mutex // needed for concurrent writes of Gorilla
}
// MessageHandler is called by the reader with each received message and a
// nil error, or with the read error when reading fails for a reason other
// than an unexpected close.
type MessageHandler func([]byte, error)
// ErrorHandler is called by the reader with the error of an unexpected
// connection close.
type ErrorHandler func(err error)
// Upgrader extends the Gorilla upgrader with an optional CORS origin.
type Upgrader struct {
	websocket.Upgrader
	// Origin, when non-empty, is sent back by Upgrade as the value of the
	// Access-Control-Allow-Origin response header.
	Origin string
}
// DefaultDialer is the dialer used by Client.Connect when Client.Dialer is
// nil. It is the same pointer as websocket.DefaultDialer, so any change made
// through it is visible to every other user of that dialer.
var DefaultDialer = websocket.DefaultDialer
// DefaultUpgrader is the upgrader used by Server.Connect when Server.Upgrader
// is nil, and the template copied by NewUpgrader. It uses 2 KiB read and write
// buffers, a pooled write buffer, and has compression enabled. No CheckOrigin
// function is set on it.
var DefaultUpgrader = Upgrader{Upgrader: websocket.Upgrader{
	ReadBufferSize: 2048,
	WriteBufferSize: 2048,
	WriteBufferPool: &sync.Pool{},
	EnableCompression: true,
}}
// NewUpgrader returns a copy of DefaultUpgrader whose origin check depends on
// origin:
//   - "*" accepts requests from any origin;
//   - any other non-empty value accepts only requests whose Origin header
//     equals it exactly;
//   - "" leaves CheckOrigin as configured in DefaultUpgrader.
//
// NOTE(review): the Origin field of the returned Upgrader is left empty, so
// Upgrade will not emit Access-Control-Allow-Origin for upgraders built here —
// confirm whether that is intended.
func NewUpgrader(origin string) *Upgrader {
	upgrader := DefaultUpgrader
	if origin == "" {
		return &upgrader
	}

	if origin == "*" {
		upgrader.CheckOrigin = func(*http.Request) bool { return true }
	} else {
		upgrader.CheckOrigin = func(req *http.Request) bool {
			return req.Header.Get("Origin") == origin
		}
	}
	return &upgrader
}
// Upgrade switches the HTTP connection to the websocket protocol using the
// embedded Gorilla upgrader. When Origin is set, it is first written to the
// response as the Access-Control-Allow-Origin header.
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) {
	if allowed := u.Origin; allowed != "" {
		w.Header().Set("Access-Control-Allow-Origin", allowed)
	}
	return u.Upgrader.Upgrade(w, r, responseHeader)
}
// Connect upgrades the given HTTP request to a websocket and wraps it in a
// server-side Connection (one that sends pings). The Server's Upgrader is
// used when set, DefaultUpgrader otherwise. The upgrade error, if any, is
// returned unchanged.
func (s *Server) Connect(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Connection, error) {
	upgrader := &DefaultUpgrader
	if s.Upgrader != nil {
		upgrader = s.Upgrader
	}

	ws, err := upgrader.Upgrade(w, r, responseHeader)
	if err != nil {
		return nil, err
	}
	return newSocket(ws, true), nil
}
// Connect dials the given address and wraps the resulting websocket in a
// client-side Connection (one that answers pings instead of sending them).
// The Client's Dialer is used when set, DefaultDialer otherwise. The dial
// error, if any, is returned unchanged.
//
// For the "wss" scheme the TLS configuration is replaced with one that skips
// certificate verification. The replacement is applied to a private copy of
// the dialer, so neither the caller's dialer nor the shared DefaultDialer is
// modified.
//
// NOTE(review): InsecureSkipVerify disables server certificate and host name
// checks, which leaves wss connections open to man-in-the-middle attacks;
// kept as-is because existing callers may depend on it — confirm.
func (c *Client) Connect(address url.URL) (*Connection, error) {
	dialer := c.Dialer
	if dialer == nil {
		dialer = DefaultDialer
	}
	if address.Scheme == "wss" {
		// Shallow-copy before changing the TLS config: dialer may point to the
		// process-wide default dialer or to a dialer shared between goroutines.
		secure := *dialer
		secure.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
		dialer = &secure
	}
	conn, _, err := dialer.Dial(address.String(), nil)
	if err != nil {
		return nil, err
	}
	return newSocket(conn, false), nil
}
// write sends one message of websocket type t while holding the write lock.
// The write deadline is reset to now+wt first; if that fails, its error is
// returned and nothing is written.
func (conn *deadlineConn) write(t int, mess []byte) error {
	conn.mu.Lock()
	defer conn.mu.Unlock()

	deadline := time.Now().Add(conn.wt)
	err := conn.SetWriteDeadline(deadline)
	if err == nil {
		err = conn.WriteMessage(t, mess)
	}
	return err
}
// writeControl sends a control message with the given deadline while holding
// the same lock as write, so control and data writes never interleave.
func (conn *deadlineConn) writeControl(messageType int, data []byte, deadline time.Time) error {
	lock := &conn.mu
	lock.Lock()
	defer lock.Unlock()

	ws := conn.Conn
	return ws.WriteControl(messageType, data, deadline)
}
// reader pumps messages from the websocket connection to the SetMessageHandler callback.
// Blocking, must be called as goroutine. Serializes all websocket reads.
//
// The read limit is messSize when positive and maxMessageSize otherwise. The
// read deadline starts at pongTime from now and is only pushed forward by
// keep-alive traffic: a pong when pingPong is set, a ping otherwise. Incoming
// pings are answered with a pong that echoes the ping's application data.
//
// A read error ends the loop. Close errors with a code other than normal
// closure or going away are passed to the error handler when one is set;
// every other error is passed to the message callback. On exit the send
// channel is closed and the connection is shut down.
//
// NOTE(review): the send channel is closed here while Write only checks the
// alive flag before sending, so a Write racing with reader exit could send on
// a closed channel — confirm callers never write during teardown.
func (c *Connection) reader() {
	defer func() {
		close(c.send)
		c.close()
	}()

	var limit int64 = maxMessageSize
	if c.messSize > 0 {
		limit = c.messSize
	}
	c.conn.SetReadLimit(limit)

	// Deadline errors are ignored on purpose: a broken connection surfaces
	// through ReadMessage below.
	_ = c.conn.SetReadDeadline(time.Now().Add(pongTime))
	if c.pingPong {
		c.conn.SetPongHandler(func(string) error {
			_ = c.conn.SetReadDeadline(time.Now().Add(pongTime))
			return nil
		})
	} else {
		c.conn.SetPingHandler(func(appData string) error {
			_ = c.conn.SetReadDeadline(time.Now().Add(pongTime))
			// A pong must carry the same application data as the ping it
			// answers (RFC 6455, section 5.5.3).
			err := c.conn.writeControl(websocket.PongMessage, []byte(appData), time.Now().Add(writeWait))
			if err == websocket.ErrCloseSent {
				return nil
			}
			// A pong that timed out is not fatal for the read loop.
			if e, ok := err.(net.Error); ok && e.Timeout() {
				return nil
			}
			return err
		})
	}

	for {
		_, message, err := c.conn.ReadMessage()
		if err != nil {
			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
				if c.errorHandler != nil {
					c.errorHandler(err)
				}
			} else {
				c.callback(message, err)
			}
			return
		}
		c.callback(message, err)
	}
}
// writer pumps messages from the send channel to the websocket connection.
// Blocking, must be called as goroutine. Serializes all websocket writes.
//
// With pingPong set it additionally sends a ping every pingTime. It returns,
// shutting the connection down, as soon as a write fails or the send channel
// is closed.
func (c *Connection) writer() {
	defer c.close()

	if !c.pingPong {
		// No keep-alive duty: just drain the queue until it is closed.
		for message := range c.send {
			if !c.handleMessage(message, true) {
				return
			}
		}
		return
	}

	ticker := time.NewTicker(pingTime)
	defer ticker.Stop()
	for {
		select {
		case <-ticker.C:
			if c.conn.write(websocket.PingMessage, nil) != nil {
				return
			}
		case message, ok := <-c.send:
			if !c.handleMessage(message, ok) {
				return
			}
		}
	}
}
// handleMessage processes one receive from the send channel and reports
// whether the writer should keep running. ok is the channel's receive status:
// when false the channel was closed, so a close message is sent (its error is
// ignored) and false is returned. Otherwise message is sent as a text message
// and the result is true only if that write succeeded.
func (c *Connection) handleMessage(message []byte, ok bool) bool {
	if ok {
		return c.conn.write(websocket.TextMessage, message) == nil
	}
	_ = c.conn.write(websocket.CloseMessage, nil)
	return false
}
// close shuts the connection down exactly once: it clears the alive flag,
// closes the underlying websocket (ignoring the error) and closes the done
// channel. Later calls do nothing.
func (c *Connection) close() {
	shutdown := func() {
		c.alive = false
		_ = c.conn.Close()
		close(c.done)
	}
	c.once.Do(shutdown)
}
// newSocket wraps an established websocket in a Connection that is not yet
// listening. pingPong selects the keep-alive role of the connection. The
// message callback starts as a no-op, the write timeout is writeWait and the
// send channel is unbuffered.
func newSocket(conn *websocket.Conn, pingPong bool) *Connection {
	ignore := func([]byte, error) {}
	return &Connection{
		pingPong: pingPong,
		callback: ignore,
		conn:     deadlineConn{Conn: conn, wt: writeWait},
		send:     make(chan []byte),
		done:     make(chan struct{}, 1),
	}
}
// IsServer returns true if the connection has server capabilities and not just a client.
// For now, we assume every connection with ping/pong handler is a server.
func (c *Connection) IsServer() bool {
	server := c.pingPong
	return server
}
// SetMessageHandler installs fn as the callback invoked by the reader for
// received messages and read errors. fn must not be nil: the reader calls it
// without a nil check.
func (c *Connection) SetMessageHandler(fn MessageHandler) {
	c.callback = fn
}
// SetErrorHandler installs fn as the handler for unexpected close errors
// seen by the reader. A nil fn disables the notification.
func (c *Connection) SetErrorHandler(fn ErrorHandler) {
	c.errorHandler = fn
}
// SetMaxMessageSize sets the read limit in bytes. Values of zero or less keep
// the default maxMessageSize. The reader applies the limit once when it
// starts, so this has to be called before Listen to take effect.
func (c *Connection) SetMaxMessageSize(s int64) {
	c.messSize = s
}
// Listen starts the reader and writer goroutines and returns a channel that
// is closed when the connection shuts down.
//
// Calling Listen again while the connection is alive returns the same channel
// without starting anything. Calling it after the connection has shut down
// also just returns the (already closed) channel: the send and done channels
// cannot be reopened, and restarting the reader would make it close the send
// channel a second time and panic.
func (c *Connection) Listen() chan struct{} {
	if c.alive {
		return c.done
	}
	// Nothing is ever sent on done, so a successful receive means it has been
	// closed, i.e. the connection already went through shutdown.
	select {
	case <-c.done:
		return c.done
	default:
	}
	c.alive = true
	go c.writer()
	go c.reader()
	return c.done
}
// Write queues data for the writer goroutine. It blocks until the writer
// picks the message up, because the send channel is unbuffered. Data is
// silently dropped when the connection is not alive (not listening yet or
// already shut down).
//
// NOTE(review): the alive check and the send are not atomic; a shutdown that
// happens in between may leave this call sending on a closed channel —
// confirm callers do not write concurrently with teardown.
func (c *Connection) Write(data []byte) {
	if !c.alive {
		return
	}
	c.send <- data
}
// Close sends a normal-closure close message to the peer when the connection
// is alive; the write error is ignored. It does not close the underlying
// connection or the done channel itself.
func (c *Connection) Close() {
	if !c.alive {
		return
	}
	farewell := websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")
	_ = c.conn.write(websocket.CloseMessage, farewell)
}