mirror of
https://github.com/giongto35/cloud-game.git
synced 2026-01-23 10:35:44 +00:00
258 lines
5.9 KiB
Go
258 lines
5.9 KiB
Go
package websocket
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"net"
|
|
"net/http"
|
|
"net/url"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Connection limits and keep-alive timings.
const (
	// maxMessageSize is the default read limit in bytes (see SetMaxMessageSize).
	maxMessageSize = 10 * 1024

	// pongTime is the read deadline window: a server-side reader must see a pong,
	// and a client-side reader a ping, within this time.
	pongTime = 7 * time.Second
	// pingTime is the server ping interval, 90% of pongTime so each ping goes out
	// before the peer's read deadline lapses.
	pingTime = pongTime * 9 / 10

	// writeWait bounds every single write on the connection.
	writeWait = 5 * time.Second
)
|
|
|
|
// Client opens client-side websocket connections; see Client.Connect.
type Client struct {
	// Dialer is used by Connect; when nil, DefaultDialer is used.
	Dialer *websocket.Dialer
}
|
|
|
|
// Server upgrades incoming HTTP requests to server-side websocket connections;
// see Server.Connect.
type Server struct {
	// Upgrader is used by Connect; when nil, DefaultUpgrader is used.
	Upgrader *Upgrader
}
|
|
|
|
// Connection wraps a websocket connection. Listen starts one reader and one
// writer goroutine for it; outgoing messages are queued with Write.
type Connection struct {
	// alive is set by Listen and cleared by close.
	// NOTE(review): read and written from several goroutines without
	// synchronization — consider atomic.Bool.
	alive bool
	// callback receives every message read by the reader goroutine, and also
	// the terminal read error unless it is an unexpected close error.
	callback MessageHandler
	// conn is the underlying connection; its writes are mutex-serialized and
	// bounded by a write deadline.
	conn deadlineConn
	// done is closed exactly once by close and is returned by Listen.
	done chan struct{}
	// errorHandler, when non-nil, receives unexpected close errors from the reader.
	errorHandler ErrorHandler
	// once makes close idempotent.
	once sync.Once
	// pingPong is true for server-side connections: the writer sends pings and
	// the reader extends its deadline on pongs. Client-side connections instead
	// answer pings from the reader.
	pingPong bool
	// send hands outgoing text messages to the writer goroutine.
	send chan []byte
	// messSize, when > 0, replaces maxMessageSize as the read limit.
	messSize int64
}
|
|
|
|
// deadlineConn is a gorilla connection whose writes are funnelled through
// write / writeControl, which hold mu so that writes never run concurrently.
type deadlineConn struct {
	*websocket.Conn
	// wt is the write timeout applied by write before each message.
	wt time.Duration
	mu sync.Mutex // needed for concurrent writes of Gorilla
}
|
|
|
|
type (
	// MessageHandler is invoked by the reader for each message read; when reading
	// stops it may be invoked once more with the read error.
	MessageHandler func(message []byte, err error)

	// ErrorHandler is invoked by the reader for unexpected close errors.
	ErrorHandler func(err error)
)
|
|
|
|
// Upgrader extends gorilla's websocket.Upgrader with an optional Origin value
// that Upgrade sends as the Access-Control-Allow-Origin response header.
type Upgrader struct {
	websocket.Upgrader
	// Origin, when non-empty, is written to Access-Control-Allow-Origin.
	Origin string
}
|
|
|
|
var DefaultDialer = websocket.DefaultDialer
|
|
var DefaultUpgrader = Upgrader{Upgrader: websocket.Upgrader{
|
|
ReadBufferSize: 2048,
|
|
WriteBufferSize: 2048,
|
|
WriteBufferPool: &sync.Pool{},
|
|
EnableCompression: true,
|
|
}}
|
|
|
|
func NewUpgrader(origin string) *Upgrader {
|
|
u := DefaultUpgrader
|
|
switch {
|
|
case origin == "*":
|
|
u.CheckOrigin = func(r *http.Request) bool { return true }
|
|
case origin != "":
|
|
u.CheckOrigin = func(r *http.Request) bool { return r.Header.Get("Origin") == origin }
|
|
}
|
|
return &u
|
|
}
|
|
|
|
func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) {
|
|
if u.Origin != "" {
|
|
w.Header().Set("Access-Control-Allow-Origin", u.Origin)
|
|
}
|
|
return u.Upgrader.Upgrade(w, r, responseHeader)
|
|
}
|
|
|
|
func (s *Server) Connect(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Connection, error) {
|
|
u := s.Upgrader
|
|
if u == nil {
|
|
u = &DefaultUpgrader
|
|
}
|
|
conn, err := u.Upgrade(w, r, responseHeader)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return newSocket(conn, true), nil
|
|
}
|
|
|
|
func (c *Client) Connect(address url.URL) (*Connection, error) {
|
|
dialer := c.Dialer
|
|
if dialer == nil {
|
|
dialer = DefaultDialer
|
|
}
|
|
if address.Scheme == "wss" {
|
|
dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
|
|
}
|
|
conn, _, err := dialer.Dial(address.String(), nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return newSocket(conn, false), nil
|
|
}
|
|
|
|
func (conn *deadlineConn) write(t int, mess []byte) error {
|
|
conn.mu.Lock()
|
|
defer conn.mu.Unlock()
|
|
if err := conn.SetWriteDeadline(time.Now().Add(conn.wt)); err != nil {
|
|
return err
|
|
}
|
|
return conn.WriteMessage(t, mess)
|
|
}
|
|
|
|
func (conn *deadlineConn) writeControl(messageType int, data []byte, deadline time.Time) error {
|
|
conn.mu.Lock()
|
|
defer conn.mu.Unlock()
|
|
return conn.Conn.WriteControl(messageType, data, deadline)
|
|
}
|
|
|
|
// reader pumps messages from the websocket connection to the SetMessageHandler callback.
|
|
// Blocking, must be called as goroutine. Serializes all websocket reads.
|
|
func (c *Connection) reader() {
|
|
defer func() {
|
|
close(c.send)
|
|
c.close()
|
|
}()
|
|
|
|
var s int64 = maxMessageSize
|
|
if c.messSize > 0 {
|
|
s = c.messSize
|
|
}
|
|
c.conn.SetReadLimit(s)
|
|
|
|
_ = c.conn.SetReadDeadline(time.Now().Add(pongTime))
|
|
if c.pingPong {
|
|
c.conn.SetPongHandler(func(string) error { _ = c.conn.SetReadDeadline(time.Now().Add(pongTime)); return nil })
|
|
} else {
|
|
c.conn.SetPingHandler(func(string) error {
|
|
_ = c.conn.SetReadDeadline(time.Now().Add(pongTime))
|
|
err := c.conn.writeControl(websocket.PongMessage, nil, time.Now().Add(writeWait))
|
|
if err == websocket.ErrCloseSent {
|
|
return nil
|
|
} else if e, ok := err.(net.Error); ok && e.Timeout() {
|
|
return nil
|
|
}
|
|
return err
|
|
})
|
|
}
|
|
for {
|
|
_, message, err := c.conn.ReadMessage()
|
|
if err != nil {
|
|
if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) {
|
|
if c.errorHandler != nil {
|
|
c.errorHandler(err)
|
|
}
|
|
} else {
|
|
c.callback(message, err)
|
|
}
|
|
break
|
|
}
|
|
c.callback(message, err)
|
|
}
|
|
}
|
|
|
|
// writer pumps messages from the send channel to the websocket connection.
|
|
// Blocking, must be called as goroutine. Serializes all websocket writes.
|
|
func (c *Connection) writer() {
|
|
defer c.close()
|
|
|
|
if c.pingPong {
|
|
ticker := time.NewTicker(pingTime)
|
|
defer ticker.Stop()
|
|
|
|
for {
|
|
select {
|
|
case message, ok := <-c.send:
|
|
if !c.handleMessage(message, ok) {
|
|
return
|
|
}
|
|
case <-ticker.C:
|
|
if err := c.conn.write(websocket.PingMessage, nil); err != nil {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
} else {
|
|
for message := range c.send {
|
|
if !c.handleMessage(message, true) {
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (c *Connection) handleMessage(message []byte, ok bool) bool {
|
|
if !ok {
|
|
_ = c.conn.write(websocket.CloseMessage, nil)
|
|
return false
|
|
}
|
|
if err := c.conn.write(websocket.TextMessage, message); err != nil {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
func (c *Connection) close() {
|
|
c.once.Do(func() {
|
|
c.alive = false
|
|
_ = c.conn.Close()
|
|
close(c.done)
|
|
})
|
|
}
|
|
|
|
func newSocket(conn *websocket.Conn, pingPong bool) *Connection {
|
|
return &Connection{
|
|
callback: func(message []byte, err error) {},
|
|
conn: deadlineConn{Conn: conn, wt: writeWait},
|
|
done: make(chan struct{}, 1),
|
|
once: sync.Once{},
|
|
pingPong: pingPong,
|
|
send: make(chan []byte),
|
|
}
|
|
}
|
|
|
|
// IsServer returns true if the connection has server capabilities and not just a client.
|
|
// For now, we assume every connection with ping/pong handler is a server.
|
|
func (c *Connection) IsServer() bool { return c.pingPong }
|
|
|
|
func (c *Connection) SetMessageHandler(fn MessageHandler) { c.callback = fn }
|
|
|
|
func (c *Connection) SetErrorHandler(fn ErrorHandler) { c.errorHandler = fn }
|
|
|
|
func (c *Connection) SetMaxMessageSize(s int64) { c.messSize = s }
|
|
|
|
func (c *Connection) Listen() chan struct{} {
|
|
if c.alive {
|
|
return c.done
|
|
}
|
|
c.alive = true
|
|
go c.writer()
|
|
go c.reader()
|
|
return c.done
|
|
}
|
|
|
|
func (c *Connection) Write(data []byte) {
|
|
if c.alive {
|
|
c.send <- data
|
|
}
|
|
}
|
|
|
|
func (c *Connection) Close() {
|
|
if c.alive {
|
|
_ = c.conn.write(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""))
|
|
}
|
|
}
|