From 3bb54fdad4619cc364ae7d9ef66e77fd4d38410f Mon Sep 17 00:00:00 2001 From: sergystepanov Date: Sat, 18 Mar 2023 20:24:06 +0300 Subject: [PATCH] Clean API (#391) Remove hard coupling between api (all the API data structures) and com (app clients communication protocol and logic). --- pkg/api/api.go | 92 +++++----- pkg/api/coordinator.go | 56 ++---- pkg/api/user.go | 4 - pkg/api/worker.go | 73 ++++---- pkg/com/com.go | 152 ++++++++-------- pkg/com/map.go | 109 ++++-------- pkg/com/map_test.go | 82 ++++++--- pkg/com/net.go | 262 ++++++++++++++-------------- pkg/com/net_test.go | 107 ++++++------ pkg/coordinator/coordinator.go | 4 +- pkg/coordinator/hub.go | 235 +++++++++++++------------ pkg/coordinator/user.go | 52 ++++-- pkg/coordinator/userapi.go | 8 +- pkg/coordinator/userhandlers.go | 29 ++-- pkg/coordinator/worker.go | 57 ++++-- pkg/coordinator/workerapi.go | 63 +++---- pkg/coordinator/workerhandlers.go | 7 +- pkg/logger/logger.go | 20 ++- pkg/network/uid.go | 22 --- pkg/network/websocket/websocket.go | 267 +++++++++++++++-------------- pkg/worker/coordinator.go | 100 ++++++----- pkg/worker/coordinatorhandlers.go | 172 ++++++++++--------- pkg/worker/room.go | 6 +- pkg/worker/router.go | 22 ++- pkg/worker/worker.go | 11 +- 25 files changed, 1042 insertions(+), 970 deletions(-) delete mode 100644 pkg/network/uid.go diff --git a/pkg/api/api.go b/pkg/api/api.go index bbb34f96..cfdafe0a 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -1,37 +1,64 @@ +// Package api defines the general API for both coordinator and worker applications. +// +// Each API call (request and response) is a JSON-encoded "packet" of the following structure: +// +// id - (optional) a globally unique packet id; +// t - (required) one of the predefined unique packet types; +// p - (optional) packet payload with arbitrary data. +// +// The basic idea behind this API is that the packets differentiate by their predefined types +// with which it is possible to unwrap the payload into distinct request/response data structures. +// And the id field is used for tracking packets through a chain of different network points (apps, devices), +// for example, passing a packet from a browser forward to a worker and back through a coordinator. +// +// Example: +// +// {"t":4,"p":{"ice":[{"urls":"stun:stun.l.google.com:19302"}],"games":["Sushi The Cat"],"wid":"cfv68irdrc3ifu3jn6bg"}} package api import ( - "encoding/base64" + "encoding/json" "fmt" - - "github.com/giongto35/cloud-game/v3/pkg/network" - "github.com/goccy/go-json" ) type ( - Stateful struct { - Id network.Uid `json:"id"` + Id interface { + String() string + } + Stateful[T Id] struct { + Id T `json:"id"` } Room struct { Rid string `json:"room_id"` // room id } - StatefulRoom struct { - Stateful + StatefulRoom[T Id] struct { + Stateful[T] Room } PT uint8 ) -type ( - RoomInterface interface { - GetRoom() string - } -) - -func StateRoom(id network.Uid, rid string) StatefulRoom { - return StatefulRoom{Stateful: Stateful{id}, Room: Room{rid}} +type In[I Id] struct { + Id I `json:"id,omitempty"` + T PT `json:"t"` + Payload json.RawMessage `json:"p,omitempty"` // should be json.RawMessage for 2-pass unmarshal } -func (sr StatefulRoom) GetRoom() string { return sr.Rid } + +func (i In[I]) GetId() I { return i.Id } +func (i In[I]) GetPayload() []byte { return i.Payload } +func (i In[I]) GetType() PT { return i.T } + +type Out struct { + Id string `json:"id,omitempty"` // string because omitempty won't work as intended with arrays + T uint8 `json:"t"` + Payload any `json:"p,omitempty"` +} + +func (o *Out) SetId(s string) { o.Id = s } +func (o *Out) SetType(u uint8) { o.T = u } +func (o *Out) SetPayload(a any) { o.Payload = a } +func (o *Out) SetGetId(s fmt.Stringer) { o.Id = s.String() } +func (o *Out) GetPayload() any { return o.Payload } // Packet codes: // @@ -110,6 +137,12 @@ var ( ErrMalformed = fmt.Errorf("malformed") ) +var ( + EmptyPacket = Out{Payload: ""} + ErrPacket = Out{Payload: "err"} + OkPacket = Out{Payload: "ok"} +) + func Unwrap[T any](data []byte) *T { out := new(T) if err := json.Unmarshal(data, out); err != nil { @@ -124,28 +157,3 @@ func UnwrapChecked[T any](bytes []byte, err error) (*T, error) { } return Unwrap[T](bytes), nil } - -// ToBase64Json encodes data to a URL-encoded Base64+JSON string. -func ToBase64Json(data any) (string, error) { - if data == nil { - return "", nil - } - b, err := json.Marshal(data) - if err != nil { - return "", err - } - return base64.URLEncoding.EncodeToString(b), nil -} - -// FromBase64Json decodes data from a URL-encoded Base64+JSON string. -func FromBase64Json(data string, obj any) error { - b, err := base64.URLEncoding.DecodeString(data) - if err != nil { - return err - } - err = json.Unmarshal(b, obj) - if err != nil { - return err - } - return nil -} diff --git a/pkg/api/coordinator.go b/pkg/api/coordinator.go index 9c23bfb4..6c79bc8b 100644 --- a/pkg/api/coordinator.go +++ b/pkg/api/coordinator.go @@ -1,24 +1,16 @@ package api -import ( - "encoding/base64" - "fmt" - - "github.com/giongto35/cloud-game/v3/pkg/network" -) - type ( - CloseRoomRequest string - ConnectionRequest struct { + CloseRoomRequest string + ConnectionRequest[T Id] struct { Addr string `json:"addr,omitempty"` - Id string `json:"id,omitempty"` + Id T `json:"id,omitempty"` IsHTTPS bool `json:"is_https,omitempty"` PingURL string `json:"ping_url,omitempty"` Port string `json:"port,omitempty"` Tag string `json:"tag,omitempty"` Zone string `json:"zone,omitempty"` } - GetWorkerListRequest struct{} GetWorkerListResponse struct { Servers []Server `json:"servers"` } @@ -32,40 +24,18 @@ const ( WorkerIdParam = "wid" ) -func RequestToHandshake(data string) (*ConnectionRequest, error) { - if data == "" { - return nil, ErrMalformed - } - handshake, err := UnwrapChecked[ConnectionRequest](base64.URLEncoding.DecodeString(data)) - if err != nil || handshake == nil { - return nil, fmt.Errorf("%v (%v)", err, handshake) - } - return handshake, nil -} - -func (c ConnectionRequest) HasUID() (bool, network.Uid) { - hid := network.Uid(c.Id) - if !(c.Id == "" || !network.ValidUid(hid)) { - return true, hid - } - return false, "" -} - // Server contains a list of server groups. // Server is a separate machine that may contain // multiple sub-processes. type Server struct { - Addr string `json:"addr,omitempty"` - Id network.Uid `json:"id,omitempty"` - IsBusy bool `json:"is_busy,omitempty"` - InGroup bool `json:"in_group,omitempty"` - PingURL string `json:"ping_url"` - Port string `json:"port,omitempty"` - Replicas uint32 `json:"replicas,omitempty"` - Tag string `json:"tag,omitempty"` - Zone string `json:"zone,omitempty"` -} - -type HasServerInfo interface { - GetServerList() []Server + Addr string `json:"addr,omitempty"` + Id Id `json:"id,omitempty"` + IsBusy bool `json:"is_busy,omitempty"` + InGroup bool `json:"in_group,omitempty"` + Machine string `json:"machine,omitempty"` + PingURL string `json:"ping_url"` + Port string `json:"port,omitempty"` + Replicas uint32 `json:"replicas,omitempty"` + Tag string `json:"tag,omitempty"` + Zone string `json:"zone,omitempty"` } diff --git a/pkg/api/user.go b/pkg/api/user.go index 998b3320..df7f9e22 100644 --- a/pkg/api/user.go +++ b/pkg/api/user.go @@ -24,7 +24,3 @@ type ( WebrtcAnswerUserRequest string WebrtcUserIceCandidate string ) - -func InitSessionResult(ice []IceServer, games []string, wid string) (PT, InitSessionUserResponse) { - return InitSession, InitSessionUserResponse{Ice: ice, Games: games, Wid: wid} -} diff --git a/pkg/api/worker.go b/pkg/api/worker.go index 2a90225f..c1b4f14d 100644 --- a/pkg/api/worker.go +++ b/pkg/api/worker.go @@ -1,68 +1,61 @@ package api -import "github.com/giongto35/cloud-game/v3/pkg/network" - -type GameInfo struct { - Name string `json:"name"` - Base string `json:"base"` - Path string `json:"path"` - Type string `json:"type"` -} - type ( - ChangePlayerRequest = struct { - StatefulRoom + ChangePlayerRequest[T Id] struct { + StatefulRoom[T] Index int `json:"index"` } - ChangePlayerResponse int - GameQuitRequest struct { - StatefulRoom + ChangePlayerResponse int + GameQuitRequest[T Id] struct { + StatefulRoom[T] } - LoadGameRequest struct { - StatefulRoom + LoadGameRequest[T Id] struct { + StatefulRoom[T] } - LoadGameResponse string - SaveGameRequest struct { - StatefulRoom + LoadGameResponse string + SaveGameRequest[T Id] struct { + StatefulRoom[T] } - SaveGameResponse string - StartGameRequest struct { - StatefulRoom + SaveGameResponse string + StartGameRequest[T Id] struct { + StatefulRoom[T] Record bool RecordUser string Game GameInfo `json:"game"` PlayerIndex int `json:"player_index"` } + GameInfo struct { + Name string `json:"name"` + Base string `json:"base"` + Path string `json:"path"` + Type string `json:"type"` + } StartGameResponse struct { Room Record bool } - RecordGameRequest struct { - StatefulRoom + RecordGameRequest[T Id] struct { + StatefulRoom[T] Active bool `json:"active"` User string `json:"user"` } - RecordGameResponse string - TerminateSessionRequest struct { - Stateful + RecordGameResponse string + TerminateSessionRequest[T Id] struct { + Stateful[T] } - ToggleMultitapRequest struct { - StatefulRoom + ToggleMultitapRequest[T Id] struct { + StatefulRoom[T] } - WebrtcAnswerRequest struct { - Stateful + WebrtcAnswerRequest[T Id] struct { + Stateful[T] Sdp string `json:"sdp"` } - WebrtcIceCandidateRequest struct { - Stateful - Candidate string `json:"candidate"` + WebrtcIceCandidateRequest[T Id] struct { + Stateful[T] + Candidate string `json:"candidate"` // Base64-encoded ICE candidate } - WebrtcInitRequest struct { - Stateful + WebrtcInitRequest[T Id] struct { + Stateful[T] } WebrtcInitResponse string ) - -func NewWebrtcIceCandidateRequest(id network.Uid, can string) (PT, any) { - return WebrtcIce, WebrtcIceCandidateRequest{Stateful: Stateful{id}, Candidate: can} -} diff --git a/pkg/com/com.go b/pkg/com/com.go index 256dccd0..87247046 100644 --- a/pkg/com/com.go +++ b/pkg/com/com.go @@ -1,93 +1,101 @@ package com -import ( - "encoding/json" +import "github.com/giongto35/cloud-game/v3/pkg/logger" - "github.com/giongto35/cloud-game/v3/pkg/api" - "github.com/giongto35/cloud-game/v3/pkg/logger" - "github.com/giongto35/cloud-game/v3/pkg/network" -) - -type ( - In struct { - Id network.Uid `json:"id,omitempty"` - T api.PT `json:"t"` - Payload json.RawMessage `json:"p,omitempty"` - } - Out struct { - Id network.Uid `json:"id,omitempty"` - T api.PT `json:"t"` - Payload any `json:"p,omitempty"` - } -) - -var ( - EmptyPacket = Out{Payload: ""} - ErrPacket = Out{Payload: "err"} - OkPacket = Out{Payload: "ok"} -) - -type ( - NetClient interface { - Close() - Id() network.Uid - } - RegionalClient interface { - In(region string) bool - } -) - -type SocketClient struct { - NetClient - - id network.Uid - wire *Client - Tag string - Log *logger.Logger +type NetClient interface { + Disconnect() + Id() Uid } -func New(conn *Client, tag string, id network.Uid, log *logger.Logger) SocketClient { - l := log.Extend(log.With().Str("cid", id.Short())) - dir := "→" +type NetMap[T NetClient] struct{ Map[Uid, T] } + +func NewNetMap[T NetClient]() NetMap[T] { return NetMap[T]{Map: Map[Uid, T]{m: make(map[Uid]T, 10)}} } + +func (m *NetMap[T]) Add(client T) { m.Put(client.Id(), client) } +func (m *NetMap[T]) Remove(client T) { m.Map.Remove(client.Id()) } +func (m *NetMap[T]) RemoveDisconnect(client T) { client.Disconnect(); m.Remove(client) } + +type SocketClient[T ~uint8, P Packet[T], X any, P2 Packet2[X]] struct { + id Uid + rpc *RPC[T, P] + sock *Connection + log *logger.Logger // a special logger for showing x -> y directions +} + +func NewConnection[T ~uint8, P Packet[T], X any, P2 Packet2[X]](conn *Connection, id Uid, log *logger.Logger) *SocketClient[T, P, X, P2] { + if id.IsNil() { + id = NewUid() + } + dir := logger.MarkOut if conn.IsServer() { - dir = "←" + dir = logger.MarkIn } - l.Debug().Str("c", tag).Str("d", dir).Msg("Connect") - return SocketClient{id: id, wire: conn, Tag: tag, Log: l} + dirClLog := log.Extend(log.With(). + Str("cid", id.Short()). + Str(logger.DirectionField, dir), + ) + dirClLog.Debug().Msg("Connect") + return &SocketClient[T, P, X, P2]{sock: conn, id: id, log: dirClLog} } -func (c *SocketClient) SetId(id network.Uid) { c.id = id } - -func (c *SocketClient) OnPacket(fn func(p In) error) { - logFn := func(p In) { - c.Log.Debug().Str("c", c.Tag).Str("d", "←").Msgf("%s", p.T) - if err := fn(p); err != nil { - c.Log.Error().Err(err).Send() +func (c *SocketClient[T, P, _, _]) ProcessPackets(fn func(in P) error) chan struct{} { + c.rpc = NewRPC[T, P]() + c.rpc.Handler = func(p P) { + c.log.Debug().Str(logger.DirectionField, logger.MarkIn).Msgf("%v", p.GetType()) + if err := fn(p); err != nil { // 3rd handler + c.log.Error().Err(err).Send() } } - c.wire.OnPacket(logFn) + c.sock.conn.SetMessageHandler(c.handleMessage) // 1st handler + return c.sock.conn.Listen() +} + +func (c *SocketClient[_, _, _, _]) handleMessage(message []byte, err error) { + if err != nil { + c.log.Error().Err(err).Send() + return + } + if err = c.rpc.handleMessage(message); err != nil { // 2nd handler + c.log.Error().Err(err).Send() + return + } +} + +func (c *SocketClient[_, P, X, P2]) Route(in P, out P2) { + rq := P2(new(X)) + rq.SetId(in.GetId().String()) + rq.SetType(uint8(in.GetType())) + rq.SetPayload(out.GetPayload()) + if err := c.rpc.Send(c.sock.conn, rq); err != nil { + c.log.Error().Err(err).Msgf("message route fail") + } } // Send makes a blocking call. -func (c *SocketClient) Send(t api.PT, data any) ([]byte, error) { - c.Log.Debug().Str("c", c.Tag).Str("d", "→").Msgf("ᵇ%s", t) - return c.wire.Call(t, data) +func (c *SocketClient[T, P, X, P2]) Send(t T, data any) ([]byte, error) { + c.log.Debug().Str(logger.DirectionField, logger.MarkOut).Msgf("ᵇ%v", t) + rq := P2(new(X)) + rq.SetType(uint8(t)) + rq.SetPayload(data) + return c.rpc.Call(c.sock.conn, rq) } // Notify just sends a message and goes further. -func (c *SocketClient) Notify(t api.PT, data any) { - c.Log.Debug().Str("c", c.Tag).Str("d", "→").Msgf("%s", t) - _ = c.wire.Send(t, data) +func (c *SocketClient[T, P, X, P2]) Notify(t T, data any) { + c.log.Debug().Str(logger.DirectionField, logger.MarkOut).Msgf("%v", t) + rq := P2(new(X)) + rq.SetType(uint8(t)) + rq.SetPayload(data) + if err := c.rpc.Send(c.sock.conn, rq); err != nil { + c.log.Error().Err(err).Msgf("notify fail") + } } -func (c *SocketClient) Close() { - c.wire.Close() - c.Log.Debug().Str("c", c.Tag).Str("d", "x").Msg("Close") +func (c *SocketClient[_, _, _, _]) Disconnect() { + c.sock.conn.Close() + c.rpc.Cleanup() + c.log.Debug().Str(logger.DirectionField, logger.MarkCross).Msg("Close") } -func (c *SocketClient) Id() network.Uid { return c.id } -func (c *SocketClient) Listen() { c.ProcessMessages(); <-c.Done() } -func (c *SocketClient) ProcessMessages() { c.wire.Listen() } -func (c *SocketClient) Route(in In, out Out) { _ = c.wire.Route(in, out) } -func (c *SocketClient) String() string { return c.Tag + ":" + string(c.Id()) } -func (c *SocketClient) Done() chan struct{} { return c.wire.Wait() } +func (c *SocketClient[_, _, _, _]) Id() Uid { return c.id } +func (c *SocketClient[_, _, _, _]) String() string { return c.Id().String() } diff --git a/pkg/com/map.go b/pkg/com/map.go index 0d59a52e..e7e4cf83 100644 --- a/pkg/com/map.go +++ b/pkg/com/map.go @@ -1,98 +1,53 @@ package com -import ( - "errors" - "sync" +import "sync" - "github.com/giongto35/cloud-game/v3/pkg/network" -) - -// NetMap defines a thread-safe NetClient list. -type NetMap[T NetClient] struct { - m map[string]T +// Map defines a concurrent-safe map structure. +// Keep in mind that the underlying map structure will grow indefinitely. +type Map[K comparable, V any] struct { + m map[K]V mu sync.Mutex } -// ErrNotFound is returned by NetMap when some value is not present. -var ErrNotFound = errors.New("not found") - -func NewNetMap[T NetClient]() NetMap[T] { return NetMap[T]{m: make(map[string]T, 10)} } - -// Add adds a new NetClient value with its id value as the key. -func (m *NetMap[T]) Add(client T) { m.Put(string(client.Id()), client) } - -// Put adds a new NetClient value with a custom key value. -func (m *NetMap[T]) Put(key string, client T) { - m.mu.Lock() - m.m[key] = client - m.mu.Unlock() -} - -// Remove removes NetClient from the map if present. -func (m *NetMap[T]) Remove(client T) { m.RemoveByKey(string(client.Id())) } - -// RemoveByKey removes NetClient from the map by a specified key value. -func (m *NetMap[T]) RemoveByKey(key string) { +func (m *Map[K, _]) Has(key K) bool { _, ok := m.Find(key); return ok } +func (m *Map[_, _]) Len() int { m.mu.Lock(); defer m.mu.Unlock(); return len(m.m) } +func (m *Map[K, V]) Pop(key K) V { m.mu.Lock() + v := m.m[key] delete(m.m, key) m.mu.Unlock() + return v } +func (m *Map[K, V]) Put(key K, v V) { m.mu.Lock(); m.m[key] = v; m.mu.Unlock() } +func (m *Map[K, _]) Remove(key K) { m.mu.Lock(); delete(m.m, key); m.mu.Unlock() } -// RemoveAll removes all occurrences of specified NetClient. -func (m *NetMap[T]) RemoveAll(client T) { +// Find returns the first value found and a boolean flag if its found or not. +func (m *Map[K, V]) Find(key K) (v V, ok bool) { m.mu.Lock() defer m.mu.Unlock() - for k, c := range m.m { - if c.Id() == client.Id() { - delete(m.m, k) + if vv, ok := m.m[key]; ok { + return vv, true + } + return v, false +} + +// FindBy searches the first key-value with the provided predicate function. +func (m *Map[K, V]) FindBy(fn func(v V) bool) (v V, ok bool) { + m.mu.Lock() + defer m.mu.Unlock() + for _, vv := range m.m { + if fn(vv) { + return vv, true } } + return v, false } -func (m *NetMap[T]) IsEmpty() bool { +// ForEach processes every element with the provided callback function. +func (m *Map[K, V]) ForEach(fn func(v V)) { m.mu.Lock() defer m.mu.Unlock() - return len(m.m) == 0 -} - -// List returns the current NetClient map. -func (m *NetMap[T]) List() map[string]T { return m.m } - -func (m *NetMap[T]) Has(id network.Uid) bool { - _, err := m.Find(string(id)) - return err == nil -} - -// Find searches the first NetClient by a specified key value. -func (m *NetMap[T]) Find(key string) (client T, err error) { - if key == "" { - return client, ErrNotFound - } - m.mu.Lock() - defer m.mu.Unlock() - if c, ok := m.m[key]; ok { - return c, nil - } - return client, ErrNotFound -} - -// FindBy searches the first NetClient with the provided predicate function. -func (m *NetMap[T]) FindBy(fn func(c T) bool) (client T, err error) { - m.mu.Lock() - defer m.mu.Unlock() - for _, w := range m.m { - if fn(w) { - return w, nil - } - } - return client, ErrNotFound -} - -// ForEach processes every NetClient with the provided callback function. -func (m *NetMap[T]) ForEach(fn func(c T)) { - m.mu.Lock() - defer m.mu.Unlock() - for _, w := range m.m { - fn(w) + for _, v := range m.m { + fn(v) } } diff --git a/pkg/com/map_test.go b/pkg/com/map_test.go index f107896d..3a4d89bf 100644 --- a/pkg/com/map_test.go +++ b/pkg/com/map_test.go @@ -1,32 +1,62 @@ package com -import ( - "fmt" - "sync/atomic" - "testing" +import "testing" - "github.com/giongto35/cloud-game/v3/pkg/network" -) +func TestMap_Base(t *testing.T) { + // map map + m := Map[int, int]{m: make(map[int]int)} -type testClient struct { - NetClient - id int - c int32 -} - -func (t *testClient) Id() network.Uid { return network.Uid(fmt.Sprintf("%v", t.id)) } -func (t *testClient) change(n int) { atomic.AddInt32(&t.c, int32(n)) } - -func TestPointerValue(t *testing.T) { - m := NewNetMap[*testClient]() - c := testClient{id: 1} - m.Add(&c) - fc, _ := m.FindBy(func(c *testClient) bool { return c.id == 1 }) - c.change(100) - fc2, _ := m.Find(fc.Id().String()) - - expected := c.c == fc.c && c.c == fc2.c - if !expected { - t.Errorf("not expected change, o: %v != %v != %v", c.c, fc.c, fc2.c) + if m.Len() > 0 { + t.Errorf("should be empty, %v %v", m.Len(), m.m) + } + k := 0 + m.Put(k, 0) + if m.Len() == 0 { + t.Errorf("should not be empty, %v", m.m) + } + if !m.Has(k) { + t.Errorf("should have the key %v, %v", k, m.m) + } + v, ok := m.Find(k) + if v != 0 && !ok { + t.Errorf("should have the key %v and ok, %v %v", k, ok, m.m) + } + v, ok = m.Find(k + 1) + if ok { + t.Errorf("should not find anything, %v %v", ok, m.m) + } + m.Put(1, 1) + v, ok = m.FindBy(func(v int) bool { return v == 1 }) + if v != 1 && !ok { + t.Errorf("should have the key %v and ok, %v %v", 1, ok, m.m) + } + sum := 0 + m.ForEach(func(v int) { sum += v }) + if sum != 1 { + t.Errorf("shoud have exact sum of 1, but have %v", sum) + } + m.Remove(1) + if !m.Has(0) || m.Len() > 1 { + t.Errorf("should remove only one element, but has %v", m.m) + } + m.Put(3, 3) + v = m.Pop(3) + if v != 3 { + t.Errorf("should have value %v, but has %v %v", 3, v, m.m) + } + m.Remove(3) + m.Remove(0) + if m.Len() != 0 { + t.Errorf("should be completely empty, but %v", m.m) + } +} + +func TestMap_Concurrency(t *testing.T) { + m := Map[int, int]{m: make(map[int]int)} + for i := 0; i < 100; i++ { + i := i + go m.Put(i, i) + go m.Has(i) + go m.Pop(i) } } diff --git a/pkg/com/net.go b/pkg/com/net.go index 674774f6..2d0c90af 100644 --- a/pkg/com/net.go +++ b/pkg/com/net.go @@ -2,187 +2,175 @@ package com import ( "errors" + "fmt" "net/http" "net/url" - "sync" "time" - "github.com/giongto35/cloud-game/v3/pkg/api" - "github.com/giongto35/cloud-game/v3/pkg/logger" - "github.com/giongto35/cloud-game/v3/pkg/network" "github.com/giongto35/cloud-game/v3/pkg/network/websocket" "github.com/goccy/go-json" + "github.com/rs/xid" ) +type Uid struct { + xid.ID +} + +var NilUid = Uid{xid.NilID()} + +func NewUid() Uid { return Uid{xid.New()} } + +func UidFromString(id string) (Uid, error) { + x, err := xid.FromString(id) + if err != nil { + return NilUid, err + } + return Uid{x}, nil +} + +func (u Uid) Short() string { return u.String()[:3] + "." + u.String()[len(u.String())-3:] } + +type HasCallId interface { + SetGetId(fmt.Stringer) +} + +type Writer interface { + Write([]byte) +} + +type Packet[T ~uint8] interface { + GetId() Uid + GetType() T + GetPayload() []byte +} + +type Packet2[T any] interface { + SetId(string) + SetType(uint8) + SetPayload(any) + SetGetId(fmt.Stringer) + GetPayload() any + *T // non-interface type constraint element +} + +type Transport interface { + SetMessageHandler(func([]byte, error)) +} + +type RPC[T ~uint8, P Packet[T]] struct { + CallTimeout time.Duration + Handler func(P) + Transport Transport + + calls Map[Uid, *request] +} + +type request struct { + done chan struct{} + err error + response []byte +} + +const DefaultCallTimeout = 5 * time.Second + +var errCanceled = errors.New("canceled") +var errTimeout = errors.New("timeout") + type ( - Connector struct { - tag string - wu *websocket.Upgrader - } Client struct { - conn *websocket.WS - queue map[network.Uid]*call - onPacket func(packet In) - mu sync.Mutex + websocket.Client } - call struct { - done chan struct{} - err error - Response In + Server struct { + websocket.Server + } + Connection struct { + conn *websocket.Connection } - Option = func(c *Connector) ) -var ( - errConnClosed = errors.New("connection closed") - errTimeout = errors.New("timeout") -) -var outPool = sync.Pool{New: func() any { o := Out{}; return &o }} +func (c *Client) Connect(addr url.URL) (*Connection, error) { return connect(c.Client.Connect(addr)) } -func WithOrigin(url string) Option { return func(c *Connector) { c.wu = websocket.NewUpgrader(url) } } -func WithTag(tag string) Option { return func(c *Connector) { c.tag = tag } } +func (s *Server) Origin(host string) { s.Upgrader = websocket.NewUpgrader(host) } -const callTimeout = 5 * time.Second - -func NewConnector(opts ...Option) *Connector { - c := &Connector{} - for _, opt := range opts { - opt(c) - } - if c.wu == nil { - c.wu = &websocket.DefaultUpgrader - } - return c +func (s *Server) Connect(w http.ResponseWriter, r *http.Request) (*Connection, error) { + return connect(s.Server.Connect(w, r, nil)) } -func (co *Connector) NewClientServer(w http.ResponseWriter, r *http.Request, log *logger.Logger) (*SocketClient, error) { - ws, err := co.wu.Upgrade(w, r, nil) +func (c Connection) IsServer() bool { return c.conn.IsServer() } + +func connect(conn *websocket.Connection, err error) (*Connection, error) { if err != nil { return nil, err } - conn, err := connect(websocket.NewServerWithConn(ws, log)) - if err != nil { - return nil, err - } - c := New(conn, co.tag, network.NewUid(), log) - return &c, nil + return &Connection{conn: conn}, nil } -func (co *Connector) NewClient(address url.URL, log *logger.Logger) (*Client, error) { - return connect(websocket.NewClient(address, log)) +func NewRPC[T ~uint8, P Packet[T]]() *RPC[T, P] { + return &RPC[T, P]{calls: Map[Uid, *request]{m: make(map[Uid]*request, 10)}} } -func connect(conn *websocket.WS, err error) (*Client, error) { - if err != nil { - return nil, err - } - client := &Client{conn: conn, queue: make(map[network.Uid]*call, 1)} - client.conn.OnMessage = client.handleMessage - return client, nil -} - -func (c *Client) IsServer() bool { return c.conn.IsServer() } - -func (c *Client) OnPacket(fn func(packet In)) { c.mu.Lock(); c.onPacket = fn; c.mu.Unlock() } - -func (c *Client) Listen() { c.mu.Lock(); c.conn.Listen(); c.mu.Unlock() } - -func (c *Client) Close() { - // !to handle error - c.conn.Close() - c.drain(errConnClosed) -} - -func (c *Client) Call(type_ api.PT, payload any) ([]byte, error) { - // !to expose channel instead of results - rq := outPool.Get().(*Out) - id := network.NewUid() - rq.Id, rq.T, rq.Payload = id, type_, payload - r, err := json.Marshal(rq) - outPool.Put(rq) - if err != nil { - //delete(c.queue, id) - return nil, err - } - - task := &call{done: make(chan struct{})} - c.mu.Lock() - c.queue[id] = task - c.conn.Write(r) - c.mu.Unlock() - select { - case <-task.done: - case <-time.After(callTimeout): - task.err = errTimeout - } - return task.Response.Payload, task.err -} - -func (c *Client) Send(type_ api.PT, pl any) error { - rq := outPool.Get().(*Out) - rq.Id, rq.T, rq.Payload = "", type_, pl - defer outPool.Put(rq) - return c.SendPacket(rq) -} - -func (c *Client) Route(p In, pl Out) error { - rq := outPool.Get().(*Out) - rq.Id, rq.T, rq.Payload = p.Id, p.T, pl.Payload - defer outPool.Put(rq) - return c.SendPacket(rq) -} - -func (c *Client) SendPacket(packet *Out) error { +func (t *RPC[_, _]) Send(w Writer, packet any) error { r, err := json.Marshal(packet) if err != nil { return err } - c.mu.Lock() - c.conn.Write(r) - c.mu.Unlock() + w.Write(r) return nil } -func (c *Client) Wait() chan struct{} { return c.conn.Done } +func (t *RPC[_, _]) Call(w Writer, rq HasCallId) ([]byte, error) { + id := NewUid() + // set new request id for the external request structure as string + rq.SetGetId(id) -func (c *Client) handleMessage(message []byte, err error) { + r, err := json.Marshal(rq) if err != nil { - return + return nil, err } - - var res In - if err = json.Unmarshal(message, &res); err != nil { - return + task := &request{done: make(chan struct{})} + t.calls.Put(id, task) + w.Write(r) + select { + case <-task.done: + case <-time.After(t.callTimeout()): + task.err = errTimeout } + return task.response, task.err +} - // empty id implies that we won't track (wait) the response - if !res.Id.Empty() { - if task := c.pop(res.Id); task != nil { - task.Response = res - close(task.done) - return +func (t *RPC[_, P]) handleMessage(message []byte) error { + res := *new(P) + if err := json.Unmarshal(message, &res); err != nil { + return err + } + // if we have an id, then unblock blocking call with that id + id := res.GetId() + if id != NilUid { + if blocked := t.calls.Pop(id); blocked != nil { + blocked.response = res.GetPayload() + close(blocked.done) + return nil } } - c.onPacket(res) + if t.Handler != nil { + t.Handler(res) + } + return nil } -// pop extracts and removes a task from the queue by its id. -func (c *Client) pop(id network.Uid) *call { - c.mu.Lock() - task := c.queue[id] - delete(c.queue, id) - c.mu.Unlock() - return task +func (t *RPC[_, _]) callTimeout() time.Duration { + if t.CallTimeout > 0 { + return t.CallTimeout + } + return DefaultCallTimeout } -// drain cancels all what's left in the task queue. -func (c *Client) drain(err error) { - c.mu.Lock() - for _, task := range c.queue { +func (t *RPC[_, _]) Cleanup() { + // drain cancels all what's left in the task queue. + t.calls.ForEach(func(task *request) { if task.err == nil { - task.err = err + task.err = errCanceled } close(task.done) - } - c.mu.Unlock() + }) } diff --git a/pkg/com/net_test.go b/pkg/com/net_test.go index 3ace306e..44c21805 100644 --- a/pkg/com/net_test.go +++ b/pkg/com/net_test.go @@ -10,22 +10,32 @@ import ( "testing" "time" - "github.com/giongto35/cloud-game/v3/pkg/api" "github.com/giongto35/cloud-game/v3/pkg/logger" "github.com/giongto35/cloud-game/v3/pkg/network/websocket" ) -var log = logger.Default() - -func TestPackets(t *testing.T) { - r, err := json.Marshal(Out{Payload: "asd"}) - if err != nil { - t.Fatalf("can't marshal packet") - } - - t.Logf("PACKET: %v", string(r)) +type TestIn struct { + Id Uid + T uint8 + Payload json.RawMessage } +func (i TestIn) GetId() Uid { return i.Id } +func (i TestIn) GetType() uint8 { return i.T } +func (i TestIn) GetPayload() []byte { return i.Payload } + +type TestOut struct { + Id string + T uint8 + Payload any +} + +func (o *TestOut) SetId(s string) { o.Id = s } +func (o *TestOut) SetType(u uint8) { o.T = u } +func (o *TestOut) SetPayload(a any) { o.Payload = a } +func (o *TestOut) SetGetId(stringer fmt.Stringer) { o.Id = stringer.String() } +func (o *TestOut) GetPayload() any { return o.Payload } + func TestWebsocket(t *testing.T) { testCases := []struct { name string @@ -39,37 +49,29 @@ func TestWebsocket(t *testing.T) { } func testWebsocket(t *testing.T) { - var wg sync.WaitGroup - sh := newServer(t) + server := newServer(t) client := newClient(t, url.URL{Scheme: "ws", Host: "localhost:8080", Path: "/ws"}) - client.OnPacket(func(packet In) { - // nop - }) - client.Listen() - wg.Wait() + clDone := client.ProcessPackets(func(in TestIn) error { return nil }) - server := sh.s - - if server == nil { + if server.conn == nil { t.Fatalf("couldn't make new socket") } calls := []struct { - typ api.PT - payload any + packet TestOut concurrent bool value any }{ - {typ: 10, payload: "test", value: "test", concurrent: true}, - {typ: 10, payload: "test2", value: "test2"}, - {typ: 11, payload: "test3", value: "test3"}, - {typ: 99, payload: "", value: ""}, - {typ: 0}, - {typ: 12, payload: 123, value: 123}, - {typ: 10, payload: false, value: false}, - {typ: 10, payload: true, value: true}, - {typ: 11, payload: []string{"test", "test", "test"}, value: []string{"test", "test", "test"}}, - {typ: 22, payload: []string{}, value: []string{}}, + {packet: TestOut{T: 10, Payload: "test"}, value: "test", concurrent: true}, + {packet: TestOut{T: 10, Payload: "test2"}, value: "test2"}, + {packet: TestOut{T: 11, Payload: "test3"}, value: "test3"}, + {packet: TestOut{T: 99, Payload: ""}, value: ""}, + {packet: TestOut{T: 0}}, + {packet: TestOut{T: 12, Payload: 123}, value: 123}, + {packet: TestOut{T: 10, Payload: false}, value: false}, + {packet: TestOut{T: 10, Payload: true}, value: true}, + {packet: TestOut{T: 11, Payload: []string{"test", "test", "test"}}, value: []string{"test", "test", "test"}}, + {packet: TestOut{T: 22, Payload: []string{}}, value: []string{}}, } const n = 42 @@ -82,10 +84,11 @@ func testWebsocket(t *testing.T) { if call.concurrent { rand.New(rand.NewSource(time.Now().UnixNano())) for i := 0; i < n; i++ { + packet := call.packet go func() { defer wait.Done() time.Sleep(time.Duration(rand.Intn(200-100)+100) * time.Millisecond) - vv, err := client.Call(call.typ, call.payload) + vv, err := client.rpc.Call(client.sock.conn, &packet) err = checkCall(vv, err, call.value) if err != nil { t.Errorf("%v", err) @@ -95,7 +98,8 @@ func testWebsocket(t *testing.T) { } } else { for i := 0; i < n; i++ { - vv, err := client.Call(call.typ, call.payload) + packet := call.packet + vv, err := client.rpc.Call(client.sock.conn, &packet) err = checkCall(vv, err, call.value) if err != nil { wait.Done() @@ -108,18 +112,22 @@ func testWebsocket(t *testing.T) { } wait.Wait() - client.Close() - <-client.conn.Done - server.Close() - <-server.Done + client.sock.conn.Close() + client.rpc.Cleanup() + <-clDone + server.conn.Close() + <-server.done } -func newClient(t *testing.T, addr url.URL) *Client { - conn, err := NewConnector().NewClient(addr, log) +func newClient(t *testing.T, addr url.URL) *SocketClient[uint8, TestIn, TestOut, *TestOut] { + connector := Client{} + conn, err := connector.Connect(addr) if err != nil { t.Fatalf("error: couldn't connect to %v because of %v", addr.String(), err) } - return conn + rpc := new(RPC[uint8, TestIn]) + rpc.calls = Map[Uid, *request]{m: make(map[Uid]*request, 10)} + return &SocketClient[uint8, TestIn, TestOut, *TestOut]{sock: conn, log: logger.Default(), rpc: rpc} } func checkCall(v []byte, err error, need any) error { @@ -164,22 +172,21 @@ func checkCall(v []byte, err error, need any) error { } type serverHandler struct { - s *websocket.WS // ws server reference made dynamically on HTTP request + conn *websocket.Connection // ws server reference made dynamically on HTTP request + done chan struct{} } func (s *serverHandler) serve(t *testing.T) func(w http.ResponseWriter, r *http.Request) { + connector := Server{} + return func(w http.ResponseWriter, r *http.Request) { - conn, err := websocket.DefaultUpgrader.Upgrade(w, r, nil) - if err != nil { - t.Fatalf("no socket, %v", err) - } - sock, err := websocket.NewServerWithConn(conn, log) + sock, err := connector.Server.Connect(w, r, nil) if err != nil { t.Fatalf("couldn't init socket server") } - s.s = sock - s.s.OnMessage = func(m []byte, err error) { s.s.Write(m) } // echo - s.s.Listen() + s.conn = sock + s.conn.SetMessageHandler(func(m []byte, err error) { s.conn.Write(m) }) // echo + s.done = s.conn.Listen() } } diff --git a/pkg/coordinator/coordinator.go b/pkg/coordinator/coordinator.go index bae91242..82e41e0e 100644 --- a/pkg/coordinator/coordinator.go +++ b/pkg/coordinator/coordinator.go @@ -18,8 +18,8 @@ func New(conf coordinator.Config, log *logger.Logger) (services service.Group) { lib.Scan() hub := NewHub(conf, lib, log) h, err := NewHTTPServer(conf, log, func(mux *httpx.Mux) *httpx.Mux { - mux.HandleFunc("/ws", hub.handleUserConnection) - mux.HandleFunc("/wso", hub.handleWorkerConnection) + mux.HandleFunc("/ws", hub.handleUserConnection()) + mux.HandleFunc("/wso", hub.handleWorkerConnection()) return mux }) if err != nil { diff --git a/pkg/coordinator/hub.go b/pkg/coordinator/hub.go index d3dba89e..5487b2b3 100644 --- a/pkg/coordinator/hub.go +++ b/pkg/coordinator/hub.go @@ -2,6 +2,8 @@ package coordinator import ( "bytes" + "encoding/base64" + "fmt" "net/http" "net/url" @@ -10,20 +12,23 @@ import ( "github.com/giongto35/cloud-game/v3/pkg/config/coordinator" "github.com/giongto35/cloud-game/v3/pkg/games" "github.com/giongto35/cloud-game/v3/pkg/logger" - "github.com/giongto35/cloud-game/v3/pkg/service" - "github.com/rs/xid" ) -type Hub struct { - service.Service +type Connection interface { + Disconnect() + Id() com.Uid + ProcessPackets(func(api.In[com.Uid]) error) chan struct{} + Send(api.PT, any) ([]byte, error) + Notify(api.PT, any) +} + +type Hub struct { conf coordinator.Config launcher games.Launcher + log *logger.Logger users com.NetMap[*User] workers com.NetMap[*Worker] - log *logger.Logger - - wConn, uConn *com.Connector } func NewHub(conf coordinator.Config, lib games.GameLibrary, log *logger.Logger) *Hub { @@ -33,135 +38,146 @@ func NewHub(conf coordinator.Config, lib games.GameLibrary, log *logger.Logger) workers: com.NewNetMap[*Worker](), launcher: games.NewGameLauncher(lib), log: log, - wConn: com.NewConnector( - com.WithOrigin(conf.Coordinator.Origin.WorkerWs), - com.WithTag("w"), - ), - uConn: com.NewConnector( - com.WithOrigin(conf.Coordinator.Origin.UserWs), - com.WithTag("u"), - ), } } // handleUserConnection handles all connections from user/frontend. -func (h *Hub) handleUserConnection(w http.ResponseWriter, r *http.Request) { - h.log.Debug().Str("c", "u").Str("d", "←").Msgf("Handshake %v", r.Host) - conn, err := h.uConn.NewClientServer(w, r, h.log) - if err != nil { - h.log.Error().Err(err).Msg("couldn't init user connection") - } - usr := NewUserConnection(conn) - defer func() { - if usr != nil { - usr.Disconnect() - h.users.Remove(usr) +func (h *Hub) handleUserConnection() http.HandlerFunc { + var connector com.Server + connector.Origin(h.conf.Coordinator.Origin.UserWs) + + log := h.log.Extend(h.log.With(). + Str(logger.ClientField, "u"). + Str(logger.DirectionField, logger.MarkIn), + ) + + return func(w http.ResponseWriter, r *http.Request) { + h.log.Debug().Msgf("Handshake %v", r.Host) + + conn, err := connector.Connect(w, r) + if err != nil { + h.log.Error().Err(err).Msg("user connection fail") + return } - }() - usr.HandleRequests(h, h.launcher, h.conf) - wkr := h.findWorkerFor(usr, r.URL.Query()) - if wkr == nil { - usr.Log.Info().Msg("no free workers") - return + user := NewUser(conn, log) + defer h.users.RemoveDisconnect(user) + done := user.HandleRequests(h, h.launcher, h.conf) + params := r.URL.Query() + worker := h.findWorkerFor(user, params, h.log.Extend(h.log.With().Str("cid", user.Id().Short()))) + if worker == nil { + h.log.Info().Msg("no free workers") + return + } + user.Bind(worker) + h.users.Add(user) + user.InitSession(worker.Id().String(), h.conf.Webrtc.IceServers, h.launcher.GetAppNames()) + log.Info().Str(logger.DirectionField, logger.MarkPlus).Msgf("user %s", user.Id()) + <-done } +} - usr.SetWorker(wkr) - h.users.Add(usr) - usr.InitSession(wkr.Id().String(), h.conf.Webrtc.IceServers, h.launcher.GetAppNames()) - <-usr.Done() +func RequestToHandshake(data string) (*api.ConnectionRequest[com.Uid], error) { + if data == "" { + return nil, api.ErrMalformed + } + handshake, err := api.UnwrapChecked[api.ConnectionRequest[com.Uid]](base64.URLEncoding.DecodeString(data)) + if err != nil || handshake == nil { + return nil, fmt.Errorf("%v (%v)", err, handshake) + } + return handshake, nil } // handleWorkerConnection handles all connections from a new worker to coordinator. -func (h *Hub) handleWorkerConnection(w http.ResponseWriter, r *http.Request) { - h.log.Debug().Str("c", "w").Str("d", "←").Msgf("Handshake %v", r.Host) +func (h *Hub) handleWorkerConnection() http.HandlerFunc { + var connector com.Server + connector.Origin(h.conf.Coordinator.Origin.WorkerWs) - handshake, err := api.RequestToHandshake(r.URL.Query().Get(api.DataQueryParam)) - if err != nil { - h.log.Error().Err(err).Msg("handshake fail") - return - } + log := h.log.Extend(h.log.With(). + Str(logger.ClientField, "w"). + Str(logger.DirectionField, logger.MarkIn), + ) - if handshake.PingURL == "" { - h.log.Warn().Msg("Ping address is not set") - } + return func(w http.ResponseWriter, r *http.Request) { + h.log.Debug().Msgf("Handshake %v", r.Host) - if h.conf.Coordinator.Server.Https && !handshake.IsHTTPS { - h.log.Warn().Msg("Unsecure worker connection. Unsecure to secure may be bad.") - } - - conn, err := h.wConn.NewClientServer(w, r, h.log) - if err != nil { - h.log.Error().Err(err).Msg("couldn't init worker connection") - return - } - - worker := &Worker{ - SocketClient: *conn, - Addr: handshake.Addr, - PingServer: handshake.PingURL, - Port: handshake.Port, - Tag: handshake.Tag, - Zone: handshake.Zone, - } - // set connection uid from the handshake - if ok, uid := handshake.HasUID(); ok { - worker.Log.Debug().Msgf("connection id will be changed %s->%s", worker.Id(), uid) - worker.SetId(uid) - } - defer func() { - if worker != nil { - worker.Disconnect() - h.workers.Remove(worker) + handshake, err := RequestToHandshake(r.URL.Query().Get(api.DataQueryParam)) + if err != nil { + h.log.Error().Err(err).Msg("handshake fail") + return } - }() - worker.HandleRequests(&h.users) - h.workers.Add(worker) - h.log.Info().Msgf("> [+] worker %s", worker.PrintInfo()) - worker.Listen() + if handshake.PingURL == "" { + h.log.Warn().Msg("Ping address is not set") + } + + if h.conf.Coordinator.Server.Https && !handshake.IsHTTPS { + h.log.Warn().Msg("Unsecure worker connection. Unsecure to secure may be bad.") + } + + // set connection uid from the handshake + if handshake.Id != com.NilUid { + h.log.Debug().Msgf("Worker uid will be set to %v", handshake.Id) + } + + conn, err := connector.Connect(w, r) + if err != nil { + log.Error().Err(err).Msg("worker connection fail") + return + } + + worker := NewWorker(conn, *handshake, log) + defer h.workers.RemoveDisconnect(worker) + done := worker.HandleRequests(&h.users) + h.workers.Add(worker) + log.Info(). + Str(logger.DirectionField, logger.MarkPlus). + Msgf("worker %s", worker.PrintInfo()) + <-done + } } func (h *Hub) GetServerList() (r []api.Server) { - for _, w := range h.workers.List() { + h.workers.ForEach(func(w *Worker) { r = append(r, api.Server{ Addr: w.Addr, Id: w.Id(), IsBusy: !w.HasSlot(), + Machine: string(w.Id().Machine()), PingURL: w.PingServer, Port: w.Port, Tag: w.Tag, Zone: w.Zone, }) - } + }) return } // findWorkerFor searches a free worker for the user depending on // various conditions. -func (h *Hub) findWorkerFor(usr *User, q url.Values) *Worker { - usr.Log.Debug().Msg("Search available workers") +func (h *Hub) findWorkerFor(usr *User, q url.Values, log *logger.Logger) *Worker { + log.Debug().Msg("Search available workers") roomId := q.Get(api.RoomIdQueryParam) zone := q.Get(api.ZoneQueryParam) wid := q.Get(api.WorkerIdParam) var worker *Worker if worker = h.findWorkerByRoom(roomId, zone); worker != nil { - usr.Log.Debug().Str("room", roomId).Msg("An existing worker has been found") + log.Debug().Str("room", roomId).Msg("An existing worker has been found") } else if worker = h.findWorkerById(wid, h.conf.Coordinator.Debug); worker != nil { - usr.Log.Debug().Msgf("Worker with id: %v has been found", wid) + log.Debug().Msgf("Worker with id: %v has been found", wid) } else { switch h.conf.Coordinator.Selector { case coordinator.SelectByPing: - usr.Log.Debug().Msgf("Searching fastest free worker...") + log.Debug().Msgf("Searching fastest free worker...") if worker = h.findFastestWorker(zone, func(servers []string) (map[string]int64, error) { return usr.CheckLatency(servers) }); worker != nil { - usr.Log.Debug().Msg("The fastest worker has been found") + log.Debug().Msg("The fastest worker has been found") } default: - usr.Log.Debug().Msgf("Searching any free worker...") + log.Debug().Msgf("Searching any free worker...") if worker = h.find1stFreeWorker(zone); worker != nil { - usr.Log.Debug().Msgf("Found next free worker") + log.Debug().Msgf("Found next free worker") } } } @@ -240,28 +256,31 @@ func (h *Hub) findFastestWorker(region string, fn func(addresses []string) (map[ return bestWorker } -func (h *Hub) findWorkerById(workerId string, useAllWorkers bool) *Worker { - // when we select one particular worker - if workerId != "" { - if xid_, err := xid.FromString(workerId); err == nil { - if useAllWorkers { - for _, w := range h.getAvailableWorkers("") { - if xid_.String() == w.Id().String() { - return w - } - } - } else { - for _, w := range h.getAvailableWorkers("") { - xid__, err := xid.FromString(workerId) - if err != nil { - continue - } - if bytes.Equal(xid_.Machine(), xid__.Machine()) { - return w - } - } +func (h *Hub) findWorkerById(id string, useAllWorkers bool) *Worker { + if id == "" { + return nil + } + + uid, err := com.UidFromString(id) + if err != nil { + return nil + } + + for _, w := range h.getAvailableWorkers("") { + if w.Id() == com.NilUid { + continue + } + if useAllWorkers { + if uid == w.Id() { + return w + } + } else { + // select any worker on the same machine when workers are grouped on the client + if bytes.Equal(uid.Machine(), w.Id().Machine()) { + return w } } } + return nil } diff --git a/pkg/coordinator/user.go b/pkg/coordinator/user.go index 0b2f7db6..c6841cab 100644 --- a/pkg/coordinator/user.go +++ b/pkg/coordinator/user.go @@ -5,55 +5,71 @@ import ( "github.com/giongto35/cloud-game/v3/pkg/com" "github.com/giongto35/cloud-game/v3/pkg/config/coordinator" "github.com/giongto35/cloud-game/v3/pkg/games" + "github.com/giongto35/cloud-game/v3/pkg/logger" ) type User struct { - com.SocketClient - w *Worker // linked worker + Connection + w *Worker // linked worker + log *logger.Logger } -// NewUserConnection supposed to be a bidirectional one. -func NewUserConnection(conn *com.SocketClient) *User { return &User{SocketClient: *conn} } +type HasServerInfo interface { + GetServerList() []api.Server +} -func (u *User) SetWorker(w *Worker) { u.w = w; u.w.Reserve() } +func NewUser(sock *com.Connection, log *logger.Logger) *User { + conn := com.NewConnection[api.PT, api.In[com.Uid], api.Out](sock, com.NewUid(), log) + return &User{ + Connection: conn, + log: log.Extend(log.With(). + Str(logger.ClientField, logger.MarkNone). + Str(logger.DirectionField, logger.MarkNone). + Str("cid", conn.Id().Short())), + } +} + +func (u *User) Bind(w *Worker) { + u.w = w + u.w.Reserve() +} func (u *User) Disconnect() { - u.SocketClient.Close() + u.Connection.Disconnect() if u.w != nil { u.w.UnReserve() u.w.TerminateSession(u.Id()) } } -func (u *User) HandleRequests(info api.HasServerInfo, launcher games.Launcher, conf coordinator.Config) { - u.ProcessMessages() - u.OnPacket(func(x com.In) error { - // !to use proper channels - switch x.T { +func (u *User) HandleRequests(info HasServerInfo, launcher games.Launcher, conf coordinator.Config) chan struct{} { + return u.ProcessPackets(func(x api.In[com.Uid]) error { + payload := x.GetPayload() + switch x.GetType() { case api.WebrtcInit: if u.w != nil { u.HandleWebrtcInit() } case api.WebrtcAnswer: - rq := api.Unwrap[api.WebrtcAnswerUserRequest](x.Payload) + rq := api.Unwrap[api.WebrtcAnswerUserRequest](payload) if rq == nil { return api.ErrMalformed } u.HandleWebrtcAnswer(*rq) case api.WebrtcIce: - rq := api.Unwrap[api.WebrtcUserIceCandidate](x.Payload) + rq := api.Unwrap[api.WebrtcUserIceCandidate](payload) if rq == nil { return api.ErrMalformed } u.HandleWebrtcIceCandidate(*rq) case api.StartGame: - rq := api.Unwrap[api.GameStartUserRequest](x.Payload) + rq := api.Unwrap[api.GameStartUserRequest](payload) if rq == nil { return api.ErrMalformed } u.HandleStartGame(*rq, launcher, conf) case api.QuitGame: - rq := api.Unwrap[api.GameQuitRequest](x.Payload) + rq := api.Unwrap[api.GameQuitRequest[com.Uid]](payload) if rq == nil { return api.ErrMalformed } @@ -63,7 +79,7 @@ func (u *User) HandleRequests(info api.HasServerInfo, launcher games.Launcher, c case api.LoadGame: return u.HandleLoadGame() case api.ChangePlayer: - rq := api.Unwrap[api.ChangePlayerUserRequest](x.Payload) + rq := api.Unwrap[api.ChangePlayerUserRequest](payload) if rq == nil { return api.ErrMalformed } @@ -74,7 +90,7 @@ func (u *User) HandleRequests(info api.HasServerInfo, launcher games.Launcher, c if !conf.Recording.Enabled { return api.ErrForbidden } - rq := api.Unwrap[api.RecordGameRequest](x.Payload) + rq := api.Unwrap[api.RecordGameRequest[com.Uid]](payload) if rq == nil { return api.ErrMalformed } @@ -82,7 +98,7 @@ func (u *User) HandleRequests(info api.HasServerInfo, launcher games.Launcher, c case api.GetWorkerList: u.handleGetWorkerList(conf.Coordinator.Debug, info) default: - u.Log.Warn().Msgf("Unknown packet: %+v", x) + u.log.Warn().Msgf("Unknown packet: %+v", x) } return nil }) diff --git a/pkg/coordinator/userapi.go b/pkg/coordinator/userapi.go index b6e68f83..8a02ba7d 100644 --- a/pkg/coordinator/userapi.go +++ b/pkg/coordinator/userapi.go @@ -23,8 +23,12 @@ func (u *User) CheckLatency(req api.CheckLatencyUserResponse) (api.CheckLatencyU // InitSession signals the user that the app is ready to go. func (u *User) InitSession(wid string, ice []webrtc.IceServer, games []string) { - // don't do this at home - u.Notify(api.InitSessionResult(*(*[]api.IceServer)(unsafe.Pointer(&ice)), games, wid)) + u.Notify(api.InitSession, api.InitSessionUserResponse{ + // don't do this at home + Ice: *(*[]api.IceServer)(unsafe.Pointer(&ice)), + Games: games, + Wid: wid, + }) } // SendWebrtcOffer sends SDP offer back to the user. diff --git a/pkg/coordinator/userhandlers.go b/pkg/coordinator/userhandlers.go index a20021f4..43fa11b0 100644 --- a/pkg/coordinator/userhandlers.go +++ b/pkg/coordinator/userhandlers.go @@ -4,6 +4,7 @@ import ( "sort" "github.com/giongto35/cloud-game/v3/pkg/api" + "github.com/giongto35/cloud-game/v3/pkg/com" "github.com/giongto35/cloud-game/v3/pkg/config/coordinator" "github.com/giongto35/cloud-game/v3/pkg/games" ) @@ -11,7 +12,7 @@ import ( func (u *User) HandleWebrtcInit() { resp, err := u.w.WebrtcInit(u.Id()) if err != nil || resp == nil || *resp == api.EMPTY { - u.Log.Error().Err(err).Msg("malformed WebRTC init response") + u.log.Error().Err(err).Msg("malformed WebRTC init response") return } u.SendWebrtcOffer(string(*resp)) @@ -33,7 +34,7 @@ func (u *User) HandleStartGame(rq api.GameStartUserRequest, launcher games.Launc if rq.RoomId != "" { name := launcher.ExtractAppNameFromUrl(rq.RoomId) if name == "" { - u.Log.Warn().Msg("couldn't decode game name from the room id") + u.log.Warn().Msg("couldn't decode game name from the room id") return } game = name @@ -41,20 +42,20 @@ func (u *User) HandleStartGame(rq api.GameStartUserRequest, launcher games.Launc gameInfo, err := launcher.FindAppByName(game) if err != nil { - u.Log.Error().Err(err).Str("game", game).Msg("couldn't find game info") + u.log.Error().Err(err).Send() return } startGameResp, err := u.w.StartGame(u.Id(), gameInfo, rq) if err != nil || startGameResp == nil { - u.Log.Error().Err(err).Msg("malformed game start response") + u.log.Error().Err(err).Msg("malformed game start response") return } if startGameResp.Rid == "" { - u.Log.Error().Msg("there is no room") + u.log.Error().Msg("there is no room") return } - u.Log.Info().Str("id", startGameResp.Rid).Msg("Received room response from worker") + u.log.Info().Str("id", startGameResp.Rid).Msg("Received room response from worker") u.StartGame() // send back recording status @@ -63,7 +64,7 @@ func (u *User) HandleStartGame(rq api.GameStartUserRequest, launcher games.Launc } } -func (u *User) HandleQuitGame(rq api.GameQuitRequest) { +func (u *User) HandleQuitGame(rq api.GameQuitRequest[com.Uid]) { if rq.Room.Rid == u.w.RoomId { u.w.QuitGame(u.Id()) } @@ -91,7 +92,7 @@ func (u *User) HandleChangePlayer(rq api.ChangePlayerUserRequest) { resp, err := u.w.ChangePlayer(u.Id(), int(rq)) // !to make it a little less convoluted if err != nil || resp == nil || *resp == -1 { - u.Log.Error().Err(err).Msg("player switch failed for some reason") + u.log.Error().Err(err).Msg("player switch failed for some reason") return } u.Notify(api.ChangePlayer, rq) @@ -99,27 +100,27 @@ func (u *User) HandleChangePlayer(rq api.ChangePlayerUserRequest) { func (u *User) HandleToggleMultitap() { u.w.ToggleMultitap(u.Id()) } -func (u *User) HandleRecordGame(rq api.RecordGameRequest) { +func (u *User) HandleRecordGame(rq api.RecordGameRequest[com.Uid]) { if u.w == nil { return } - u.Log.Debug().Msgf("??? room: %v, rec: %v user: %v", u.w.RoomId, rq.Active, rq.User) + u.log.Debug().Msgf("??? room: %v, rec: %v user: %v", u.w.RoomId, rq.Active, rq.User) if u.w.RoomId == "" { - u.Log.Error().Msg("Recording in the empty room is not allowed!") + u.log.Error().Msg("Recording in the empty room is not allowed!") return } resp, err := u.w.RecordGame(u.Id(), rq.Active, rq.User) if err != nil { - u.Log.Error().Err(err).Msg("malformed game record request") + u.log.Error().Err(err).Msg("malformed game record request") return } u.Notify(api.RecordGame, resp) } -func (u *User) handleGetWorkerList(debug bool, info api.HasServerInfo) { +func (u *User) handleGetWorkerList(debug bool, info HasServerInfo) { response := api.GetWorkerListResponse{} servers := info.GetServerList() @@ -129,7 +130,7 @@ func (u *User) handleGetWorkerList(debug bool, info api.HasServerInfo) { // not sure if []byte to string always reversible :/ unique := map[string]*api.Server{} for _, s := range servers { - mid := s.Id.Machine() + mid := s.Machine if _, ok := unique[mid]; !ok { unique[mid] = &api.Server{Addr: s.Addr, PingURL: s.PingURL, Id: s.Id, InGroup: true} } diff --git a/pkg/coordinator/worker.go b/pkg/coordinator/worker.go index 50a00054..7fcc93bc 100644 --- a/pkg/coordinator/worker.go +++ b/pkg/coordinator/worker.go @@ -6,11 +6,12 @@ import ( "github.com/giongto35/cloud-game/v3/pkg/api" "github.com/giongto35/cloud-game/v3/pkg/com" + "github.com/giongto35/cloud-game/v3/pkg/logger" ) type Worker struct { - com.SocketClient - com.RegionalClient + Connection + RegionalClient slotted Addr string @@ -19,33 +20,63 @@ type Worker struct { RoomId string // room reference Tag string Zone string + + log *logger.Logger } -func (w *Worker) HandleRequests(users *com.NetMap[*User]) { - // !to make a proper multithreading abstraction - w.OnPacket(func(p com.In) error { - switch p.T { +type RegionalClient interface { + In(region string) bool +} + +type HasUserRegistry interface { + Find(com.Uid) (*User, bool) +} + +func NewWorker(sock *com.Connection, handshake api.ConnectionRequest[com.Uid], log *logger.Logger) *Worker { + conn := com.NewConnection[api.PT, api.In[com.Uid], api.Out](sock, handshake.Id, log) + return &Worker{ + Connection: conn, + Addr: handshake.Addr, + PingServer: handshake.PingURL, + Port: handshake.Port, + Tag: handshake.Tag, + Zone: handshake.Zone, + log: log.Extend(log.With(). + Str(logger.ClientField, logger.MarkNone). + Str(logger.DirectionField, logger.MarkNone). + Str("cid", conn.Id().Short())), + } +} + +func (w *Worker) HandleRequests(users HasUserRegistry) chan struct{} { + return w.ProcessPackets(func(p api.In[com.Uid]) error { + payload := p.GetPayload() + switch p.GetType() { case api.RegisterRoom: - rq := api.Unwrap[api.RegisterRoomRequest](p.Payload) + rq := api.Unwrap[api.RegisterRoomRequest](payload) if rq == nil { return api.ErrMalformed } - w.Log.Info().Msgf("set room [%v] = %v", w.Id(), *rq) + w.log.Info().Msgf("set room [%v] = %v", w.Id(), *rq) w.HandleRegisterRoom(*rq) case api.CloseRoom: - rq := api.Unwrap[api.CloseRoomRequest](p.Payload) + rq := api.Unwrap[api.CloseRoomRequest](payload) if rq == nil { return api.ErrMalformed } w.HandleCloseRoom(*rq) case api.IceCandidate: - rq := api.Unwrap[api.WebrtcIceCandidateRequest](p.Payload) + rq := api.Unwrap[api.WebrtcIceCandidateRequest[com.Uid]](payload) if rq == nil { return api.ErrMalformed } - w.HandleIceCandidate(*rq, users) + err := w.HandleIceCandidate(*rq, users) + if err != nil { + w.log.Error().Err(err).Send() + return api.ErrMalformed + } default: - w.Log.Warn().Msgf("Unknown packet: %+v", p) + w.log.Warn().Msgf("Unknown packet: %+v", p) } return nil }) @@ -76,7 +107,7 @@ func (s *slotted) UnReserve() { func (s *slotted) FreeSlots() { atomic.StoreInt32((*int32)(s), 0) } func (w *Worker) Disconnect() { - w.SocketClient.Close() + w.Connection.Disconnect() w.RoomId = "" w.FreeSlots() } diff --git a/pkg/coordinator/workerapi.go b/pkg/coordinator/workerapi.go index 034df4dc..0943d666 100644 --- a/pkg/coordinator/workerapi.go +++ b/pkg/coordinator/workerapi.go @@ -2,61 +2,66 @@ package coordinator import ( "github.com/giongto35/cloud-game/v3/pkg/api" + "github.com/giongto35/cloud-game/v3/pkg/com" "github.com/giongto35/cloud-game/v3/pkg/games" - "github.com/giongto35/cloud-game/v3/pkg/network" ) -func (w *Worker) WebrtcInit(id network.Uid) (*api.WebrtcInitResponse, error) { +func (w *Worker) WebrtcInit(id com.Uid) (*api.WebrtcInitResponse, error) { return api.UnwrapChecked[api.WebrtcInitResponse]( - w.Send(api.WebrtcInit, api.WebrtcInitRequest{Stateful: api.Stateful{Id: id}})) + w.Send(api.WebrtcInit, api.WebrtcInitRequest[com.Uid]{Stateful: api.Stateful[com.Uid]{Id: id}})) } -func (w *Worker) WebrtcAnswer(id network.Uid, sdp string) { - w.Notify(api.WebrtcAnswer, api.WebrtcAnswerRequest{Stateful: api.Stateful{Id: id}, Sdp: sdp}) +func (w *Worker) WebrtcAnswer(id com.Uid, sdp string) { + w.Notify(api.WebrtcAnswer, api.WebrtcAnswerRequest[com.Uid]{Stateful: api.Stateful[com.Uid]{Id: id}, Sdp: sdp}) } -func (w *Worker) WebrtcIceCandidate(id network.Uid, can string) { - w.Notify(api.NewWebrtcIceCandidateRequest(id, can)) +func (w *Worker) WebrtcIceCandidate(id com.Uid, can string) { + w.Notify(api.WebrtcIce, api.WebrtcIceCandidateRequest[com.Uid]{Stateful: api.Stateful[com.Uid]{Id: id}, Candidate: can}) } -func (w *Worker) StartGame(id network.Uid, app games.AppMeta, req api.GameStartUserRequest) (*api.StartGameResponse, error) { - return api.UnwrapChecked[api.StartGameResponse](w.Send(api.StartGame, api.StartGameRequest{ - StatefulRoom: api.StateRoom(id, req.RoomId), - Game: api.GameInfo{Name: app.Name, Base: app.Base, Path: app.Path, Type: app.Type}, - PlayerIndex: req.PlayerIndex, - Record: req.Record, - RecordUser: req.RecordUser, - })) +func (w *Worker) StartGame(id com.Uid, app games.AppMeta, req api.GameStartUserRequest) (*api.StartGameResponse, error) { + return api.UnwrapChecked[api.StartGameResponse]( + w.Send(api.StartGame, api.StartGameRequest[com.Uid]{ + StatefulRoom: StateRoom(id, req.RoomId), + Game: api.GameInfo{Name: app.Name, Base: app.Base, Path: app.Path, Type: app.Type}, + PlayerIndex: req.PlayerIndex, + Record: req.Record, + RecordUser: req.RecordUser, + })) } -func (w *Worker) QuitGame(id network.Uid) { - w.Notify(api.QuitGame, api.GameQuitRequest{StatefulRoom: api.StateRoom(id, w.RoomId)}) +func (w *Worker) QuitGame(id com.Uid) { + w.Notify(api.QuitGame, api.GameQuitRequest[com.Uid]{StatefulRoom: StateRoom(id, w.RoomId)}) } -func (w *Worker) SaveGame(id network.Uid) (*api.SaveGameResponse, error) { +func (w *Worker) SaveGame(id com.Uid) (*api.SaveGameResponse, error) { return api.UnwrapChecked[api.SaveGameResponse]( - w.Send(api.SaveGame, api.SaveGameRequest{StatefulRoom: api.StateRoom(id, w.RoomId)})) + w.Send(api.SaveGame, api.SaveGameRequest[com.Uid]{StatefulRoom: StateRoom(id, w.RoomId)})) } -func (w *Worker) LoadGame(id network.Uid) (*api.LoadGameResponse, error) { +func (w *Worker) LoadGame(id com.Uid) (*api.LoadGameResponse, error) { return api.UnwrapChecked[api.LoadGameResponse]( - w.Send(api.LoadGame, api.LoadGameRequest{StatefulRoom: api.StateRoom(id, w.RoomId)})) + w.Send(api.LoadGame, api.LoadGameRequest[com.Uid]{StatefulRoom: StateRoom(id, w.RoomId)})) } -func (w *Worker) ChangePlayer(id network.Uid, index int) (*api.ChangePlayerResponse, error) { +func (w *Worker) ChangePlayer(id com.Uid, index int) (*api.ChangePlayerResponse, error) { return api.UnwrapChecked[api.ChangePlayerResponse]( - w.Send(api.ChangePlayer, api.ChangePlayerRequest{StatefulRoom: api.StateRoom(id, w.RoomId), Index: index})) + w.Send(api.ChangePlayer, api.ChangePlayerRequest[com.Uid]{StatefulRoom: StateRoom(id, w.RoomId), Index: index})) } -func (w *Worker) ToggleMultitap(id network.Uid) { - _, _ = w.Send(api.ToggleMultitap, api.ToggleMultitapRequest{StatefulRoom: api.StateRoom(id, w.RoomId)}) +func (w *Worker) ToggleMultitap(id com.Uid) { + _, _ = w.Send(api.ToggleMultitap, api.ToggleMultitapRequest[com.Uid]{StatefulRoom: StateRoom(id, w.RoomId)}) } -func (w *Worker) RecordGame(id network.Uid, rec bool, recUser string) (*api.RecordGameResponse, error) { +func (w *Worker) RecordGame(id com.Uid, rec bool, recUser string) (*api.RecordGameResponse, error) { return api.UnwrapChecked[api.RecordGameResponse]( - w.Send(api.RecordGame, api.RecordGameRequest{StatefulRoom: api.StateRoom(id, w.RoomId), Active: rec, User: recUser})) + w.Send(api.RecordGame, api.RecordGameRequest[com.Uid]{StatefulRoom: StateRoom(id, w.RoomId), Active: rec, User: recUser})) } -func (w *Worker) TerminateSession(id network.Uid) { - _, _ = w.Send(api.TerminateSession, api.TerminateSessionRequest{Stateful: api.Stateful{Id: id}}) +func (w *Worker) TerminateSession(id com.Uid) { + _, _ = w.Send(api.TerminateSession, api.TerminateSessionRequest[com.Uid]{Stateful: api.Stateful[com.Uid]{Id: id}}) +} + +func StateRoom[T api.Id](id T, rid string) api.StatefulRoom[T] { + return api.StatefulRoom[T]{Stateful: api.Stateful[T]{Id: id}, Room: api.Room{Rid: rid}} } diff --git a/pkg/coordinator/workerhandlers.go b/pkg/coordinator/workerhandlers.go index bbd5b17b..97a7839d 100644 --- a/pkg/coordinator/workerhandlers.go +++ b/pkg/coordinator/workerhandlers.go @@ -13,10 +13,11 @@ func (w *Worker) HandleCloseRoom(rq api.CloseRoomRequest) { } } -func (w *Worker) HandleIceCandidate(rq api.WebrtcIceCandidateRequest, users *com.NetMap[*User]) { - if usr, err := users.Find(string(rq.Id)); err == nil { +func (w *Worker) HandleIceCandidate(rq api.WebrtcIceCandidateRequest[com.Uid], users HasUserRegistry) error { + if usr, ok := users.Find(rq.Id); ok { usr.SendWebrtcIceCandidate(rq.Candidate) } else { - w.Log.Warn().Str("id", rq.Id.String()).Msg("unknown session") + w.log.Warn().Str("id", rq.Id.String()).Msg("unknown session") } + return nil } diff --git a/pkg/logger/logger.go b/pkg/logger/logger.go index 02945982..53c3a2d2 100644 --- a/pkg/logger/logger.go +++ b/pkg/logger/logger.go @@ -27,6 +27,16 @@ const ( // Values less than TraceLevel are handled as numbers. ) +const ( + ClientField = "c" + DirectionField = "d" + MarkNone = " " + MarkIn = "←" + MarkOut = "→" + MarkPlus = "+" + MarkCross = "x" +) + func (l Level) String() string { switch l { case TraceLevel: @@ -81,12 +91,12 @@ func NewConsole(isDebug bool, tag string, noColor bool) *Logger { zerolog.LevelFieldName, zerolog.CallerFieldName, "s", - "d", - "c", + DirectionField, + ClientField, "m", zerolog.MessageFieldName, }, - FieldsExclude: []string{"s", "c", "d", "m", "pid"}, + FieldsExclude: []string{"s", ClientField, DirectionField, "m", "pid"}, } if output.NoColor { @@ -103,8 +113,8 @@ func NewConsole(isDebug bool, tag string, noColor bool) *Logger { Str("pid", fmt.Sprintf("%4x", pid)). Str("s", tag). Str("m", ""). - Str("d", " "). - Str("c", " "). + Str(DirectionField, MarkNone). + Str(ClientField, MarkNone). // Str("tag", tag). use when a file writer Timestamp().Logger() return &Logger{logger: &logger} diff --git a/pkg/network/uid.go b/pkg/network/uid.go deleted file mode 100644 index 5d5b464a..00000000 --- a/pkg/network/uid.go +++ /dev/null @@ -1,22 +0,0 @@ -package network - -import "github.com/rs/xid" - -type Uid string - -func NewUid() Uid { return Uid(xid.New().String()) } - -func ValidUid(u Uid) bool { - _, err := xid.FromString(string(u)) - return err == nil -} -func (u Uid) Empty() bool { return u == "" } -func (u Uid) Short() string { return string(u)[:3] + "." + string(u)[len(u)-3:] } -func (u Uid) String() string { return string(u) } -func (u Uid) Machine() string { - id, err := xid.FromString(string(u)) - if err != nil { - return "" - } - return string(id.Machine()) -} diff --git a/pkg/network/websocket/websocket.go b/pkg/network/websocket/websocket.go index c020ecd7..456558f9 100644 --- a/pkg/network/websocket/websocket.go +++ b/pkg/network/websocket/websocket.go @@ -2,14 +2,12 @@ package websocket import ( "crypto/tls" - "errors" "net" "net/http" "net/url" "sync" "time" - "github.com/giongto35/cloud-game/v3/pkg/logger" "github.com/gorilla/websocket" ) @@ -20,29 +18,89 @@ const ( writeWait = 1 * time.Second ) -type ( - WS struct { - conn deadlineConn - send chan []byte - OnMessage WSMessageHandler - pingPong bool - once sync.Once - Done chan struct{} - alive bool - log *logger.Logger - server bool +type Client struct { + Dialer *websocket.Dialer +} + +type Server struct { + Upgrader *Upgrader +} + +type Connection struct { + alive bool + callback MessageHandler + conn deadlineConn + done chan struct{} + once sync.Once + pingPong bool + send chan []byte +} + +type deadlineConn struct { + *websocket.Conn + wt time.Duration + mu sync.Mutex // needed for concurrent writes of Gorilla +} + +type MessageHandler func([]byte, error) + +type Upgrader struct { + websocket.Upgrader + Origin string +} + +var DefaultDialer = websocket.DefaultDialer +var DefaultUpgrader = Upgrader{Upgrader: websocket.Upgrader{ + ReadBufferSize: 2048, + WriteBufferSize: 2048, + WriteBufferPool: &sync.Pool{}, + EnableCompression: true, +}} + +func NewUpgrader(origin string) *Upgrader { + u := DefaultUpgrader + switch { + case origin == "*": + u.CheckOrigin = func(r *http.Request) bool { return true } + case origin != "": + u.CheckOrigin = func(r *http.Request) bool { return r.Header.Get("Origin") == origin } } - WSMessageHandler func(message []byte, err error) - Upgrader struct { - websocket.Upgrader - origin string + return &u +} + +func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) { + if u.Origin != "" { + w.Header().Set("Access-Control-Allow-Origin", u.Origin) } - deadlineConn struct { - *websocket.Conn - wt time.Duration - mu sync.Mutex // needed for concurrent writes of Gorilla + return u.Upgrader.Upgrade(w, r, responseHeader) +} + +func (s *Server) Connect(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Connection, error) { + u := s.Upgrader + if u == nil { + u = &DefaultUpgrader } -) + conn, err := u.Upgrade(w, r, responseHeader) + if err != nil { + return nil, err + } + return newSocket(conn, true), nil +} + +func (c *Client) Connect(address url.URL) (*Connection, error) { + dialer := c.Dialer + if dialer == nil { + dialer = DefaultDialer + } + if address.Scheme == "wss" { + dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} + } + conn, _, err := dialer.Dial(address.String(), nil) + if err != nil { + return nil, err + } + return newSocket(conn, false), nil +} func (conn *deadlineConn) write(t int, mess []byte) error { conn.mu.Lock() @@ -59,72 +117,22 @@ func (conn *deadlineConn) writeControl(messageType int, data []byte, deadline ti return conn.Conn.WriteControl(messageType, data, deadline) } -var DefaultUpgrader = Upgrader{ - Upgrader: websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - WriteBufferPool: &sync.Pool{}, - EnableCompression: true, - }, -} - -var ErrNilConnection = errors.New("nil connection") - -func NewUpgrader(origin string) *Upgrader { - u := DefaultUpgrader - switch { - case origin == "*": - u.CheckOrigin = func(r *http.Request) bool { return true } - case origin != "": - u.CheckOrigin = func(r *http.Request) bool { return r.Header.Get("Origin") == origin } - } - return &u -} - -func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*websocket.Conn, error) { - if u.origin != "" { - w.Header().Set("Access-Control-Allow-Origin", u.origin) - } - return u.Upgrader.Upgrade(w, r, responseHeader) -} - -func NewServerWithConn(conn *websocket.Conn, log *logger.Logger) (*WS, error) { - if conn == nil { - return nil, ErrNilConnection - } - return newSocket(conn, true, true, log), nil -} - -func NewClient(address url.URL, log *logger.Logger) (*WS, error) { - dialer := websocket.DefaultDialer - if address.Scheme == "wss" { - dialer.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} - } - conn, _, err := dialer.Dial(address.String(), nil) - if err != nil { - return nil, err - } - return newSocket(conn, false, false, log), nil -} - -func (ws *WS) IsServer() bool { return ws.server } - -// reader pumps messages from the websocket connection to the OnMessage callback. +// reader pumps messages from the websocket connection to the SetMessageHandler callback. // Blocking, must be called as goroutine. Serializes all websocket reads. -func (ws *WS) reader() { +func (c *Connection) reader() { defer func() { - close(ws.send) - ws.shutdown() + close(c.send) + c.close() }() - ws.conn.SetReadLimit(maxMessageSize) - _ = ws.conn.SetReadDeadline(time.Now().Add(pongTime)) - if ws.pingPong { - ws.conn.SetPongHandler(func(string) error { _ = ws.conn.SetReadDeadline(time.Now().Add(pongTime)); return nil }) + c.conn.SetReadLimit(maxMessageSize) + _ = c.conn.SetReadDeadline(time.Now().Add(pongTime)) + if c.pingPong { + c.conn.SetPongHandler(func(string) error { _ = c.conn.SetReadDeadline(time.Now().Add(pongTime)); return nil }) } else { - ws.conn.SetPingHandler(func(string) error { - _ = ws.conn.SetReadDeadline(time.Now().Add(pongTime)) - err := ws.conn.writeControl(websocket.PongMessage, nil, time.Now().Add(writeWait)) + c.conn.SetPingHandler(func(string) error { + _ = c.conn.SetReadDeadline(time.Now().Add(pongTime)) + err := c.conn.writeControl(websocket.PongMessage, nil, time.Now().Add(writeWait)) if err == websocket.ErrCloseSent { return nil } else if e, ok := err.(net.Error); ok && e.Timeout() { @@ -134,94 +142,101 @@ func (ws *WS) reader() { }) } for { - _, message, err := ws.conn.ReadMessage() + _, message, err := c.conn.ReadMessage() if err != nil { if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure, websocket.CloseGoingAway) { - ws.log.Error().Err(err).Msg("WebSocket read fail") + c.callback(message, err) } break } - ws.OnMessage(message, err) + c.callback(message, err) } } // writer pumps messages from the send channel to the websocket connection. // Blocking, must be called as goroutine. Serializes all websocket writes. -func (ws *WS) writer() { - defer ws.shutdown() +func (c *Connection) writer() { + defer c.close() - if ws.pingPong { + if c.pingPong { ticker := time.NewTicker(pingTime) defer ticker.Stop() for { select { - case message, ok := <-ws.send: - if !ws.handleMessage(message, ok) { + case message, ok := <-c.send: + if !c.handleMessage(message, ok) { return } case <-ticker.C: - if err := ws.conn.write(websocket.PingMessage, nil); err != nil { + if err := c.conn.write(websocket.PingMessage, nil); err != nil { return } } } } else { - for message := range ws.send { - if !ws.handleMessage(message, true) { + for message := range c.send { + if !c.handleMessage(message, true) { return } } } } -func (ws *WS) handleMessage(message []byte, ok bool) bool { +func (c *Connection) handleMessage(message []byte, ok bool) bool { if !ok { - _ = ws.conn.write(websocket.CloseMessage, nil) + _ = c.conn.write(websocket.CloseMessage, nil) return false } - if err := ws.conn.write(websocket.TextMessage, message); err != nil { + if err := c.conn.write(websocket.TextMessage, message); err != nil { return false } return true } -func newSocket(conn *websocket.Conn, pingPong bool, server bool, log *logger.Logger) *WS { - return &WS{ - conn: deadlineConn{Conn: conn, wt: writeWait}, - send: make(chan []byte), - once: sync.Once{}, - Done: make(chan struct{}, 1), - pingPong: pingPong, - server: server, - OnMessage: func(message []byte, err error) {}, - log: log, +func (c *Connection) close() { + c.once.Do(func() { + c.alive = false + _ = c.conn.Close() + close(c.done) + }) +} + +func newSocket(conn *websocket.Conn, pingPong bool) *Connection { + return &Connection{ + callback: func(message []byte, err error) {}, + conn: deadlineConn{Conn: conn, wt: writeWait}, + done: make(chan struct{}, 1), + once: sync.Once{}, + pingPong: pingPong, + send: make(chan []byte), } } -func (ws *WS) Listen() { - ws.alive = true - go ws.writer() - go ws.reader() +// IsServer returns true if the connection has server capabilities and not just a client. +// For now, we assume every connection with ping/pong handler is a server. +func (c *Connection) IsServer() bool { return c.pingPong } + +func (c *Connection) SetMessageHandler(fn MessageHandler) { c.callback = fn } + +func (c *Connection) Listen() chan struct{} { + if c.alive { + return c.done + } + c.alive = true + go c.writer() + go c.reader() + return c.done } -func (ws *WS) Write(data []byte) { - if ws.alive { - ws.send <- data +func (c *Connection) Write(data []byte) { + if c.alive { + c.send <- data } } -func (ws *WS) Close() { - if ws.alive { - _ = ws.conn.write(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) +func (c *Connection) Close() { + if c.alive { + _ = c.conn.write(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) } } - -func (ws *WS) shutdown() { ws.once.Do(ws.close) } - -func (ws *WS) close() { - ws.alive = false - _ = ws.conn.Close() - close(ws.Done) - ws.log.Debug().Msg("WebSocket should be closed now") -} diff --git a/pkg/worker/coordinator.go b/pkg/worker/coordinator.go index 647cdf09..7df238a6 100644 --- a/pkg/worker/coordinator.go +++ b/pkg/worker/coordinator.go @@ -7,120 +7,142 @@ import ( "github.com/giongto35/cloud-game/v3/pkg/com" "github.com/giongto35/cloud-game/v3/pkg/config/worker" "github.com/giongto35/cloud-game/v3/pkg/logger" - "github.com/giongto35/cloud-game/v3/pkg/network" "github.com/giongto35/cloud-game/v3/pkg/network/webrtc" ) -type coordinator struct { - com.SocketClient +type Connection interface { + Disconnect() + Id() com.Uid + ProcessPackets(func(api.In[com.Uid]) error) chan struct{} + + Send(api.PT, any) ([]byte, error) + Notify(api.PT, any) + Route(api.In[com.Uid], *api.Out) } -var connector = com.NewConnector() +type coordinator struct { + Connection + log *logger.Logger +} -// connect to a coordinator. -func connect(host string, conf worker.Worker, addr string, log *logger.Logger) (*coordinator, error) { +var connector com.Client + +func newCoordinatorConnection(host string, conf worker.Worker, addr string, log *logger.Logger) (*coordinator, error) { scheme := "ws" if conf.Network.Secure { scheme = "wss" } address := url.URL{Scheme: scheme, Host: host, Path: conf.Network.Endpoint} - log.Debug().Str("c", "c").Str("d", "→").Msgf("Handshake %s", address.String()) + log.Debug(). + Str(logger.ClientField, "c"). + Str(logger.DirectionField, logger.MarkOut). + Msgf("Handshake %s", address.String()) - id := network.NewUid() + id := com.NewUid() req, err := buildConnQuery(id, conf, addr) if req != "" && err == nil { address.RawQuery = "data=" + req + } else { + return nil, err } - conn, err := connector.NewClient(address, log) + + conn, err := connector.Connect(address) if err != nil { return nil, err } - return &coordinator{com.New(conn, "c", id, log)}, nil + + client := com.NewConnection[api.PT, api.In[com.Uid], api.Out]( + conn, id, + log.Extend(log.With().Str(logger.ClientField, "c")), + ) + return &coordinator{ + Connection: client, + log: log.Extend(log.With().Str("cid", client.Id().Short())), + }, nil } -func (c *coordinator) HandleRequests(w *Worker) { - ap, err := webrtc.NewApiFactory(w.conf.Webrtc, c.Log, nil) +func (c *coordinator) HandleRequests(w *Worker) chan struct{} { + ap, err := webrtc.NewApiFactory(w.conf.Webrtc, c.log, nil) if err != nil { - c.Log.Panic().Err(err).Msg("WebRTC API creation has been failed") + c.log.Panic().Err(err).Msg("WebRTC API creation has been failed") } - c.ProcessMessages() - skipped := com.Out{} + skipped := api.Out{} - c.OnPacket(func(x com.In) (err error) { - var out com.Out + return c.ProcessPackets(func(x api.In[com.Uid]) (err error) { + var out api.Out switch x.T { case api.WebrtcInit: - if dat := api.Unwrap[api.WebrtcInitRequest](x.Payload); dat == nil { - err, out = api.ErrMalformed, com.EmptyPacket + if dat := api.Unwrap[api.WebrtcInitRequest[com.Uid]](x.Payload); dat == nil { + err, out = api.ErrMalformed, api.EmptyPacket } else { out = c.HandleWebrtcInit(*dat, w, ap) } case api.WebrtcAnswer: - dat := api.Unwrap[api.WebrtcAnswerRequest](x.Payload) + dat := api.Unwrap[api.WebrtcAnswerRequest[com.Uid]](x.Payload) if dat == nil { return api.ErrMalformed } c.HandleWebrtcAnswer(*dat, w) case api.WebrtcIce: - dat := api.Unwrap[api.WebrtcIceCandidateRequest](x.Payload) + dat := api.Unwrap[api.WebrtcIceCandidateRequest[com.Uid]](x.Payload) if dat == nil { return api.ErrMalformed } c.HandleWebrtcIceCandidate(*dat, w) case api.StartGame: - if dat := api.Unwrap[api.StartGameRequest](x.Payload); dat == nil { - err, out = api.ErrMalformed, com.EmptyPacket + if dat := api.Unwrap[api.StartGameRequest[com.Uid]](x.Payload); dat == nil { + err, out = api.ErrMalformed, api.EmptyPacket } else { out = c.HandleGameStart(*dat, w) } case api.TerminateSession: - dat := api.Unwrap[api.TerminateSessionRequest](x.Payload) + dat := api.Unwrap[api.TerminateSessionRequest[com.Uid]](x.Payload) if dat == nil { return api.ErrMalformed } c.HandleTerminateSession(*dat, w) case api.QuitGame: - dat := api.Unwrap[api.GameQuitRequest](x.Payload) + dat := api.Unwrap[api.GameQuitRequest[com.Uid]](x.Payload) if dat == nil { return api.ErrMalformed } c.HandleQuitGame(*dat, w) case api.SaveGame: - if dat := api.Unwrap[api.SaveGameRequest](x.Payload); dat == nil { - err, out = api.ErrMalformed, com.EmptyPacket + if dat := api.Unwrap[api.SaveGameRequest[com.Uid]](x.Payload); dat == nil { + err, out = api.ErrMalformed, api.EmptyPacket } else { out = c.HandleSaveGame(*dat, w) } case api.LoadGame: - if dat := api.Unwrap[api.LoadGameRequest](x.Payload); dat == nil { - err, out = api.ErrMalformed, com.EmptyPacket + if dat := api.Unwrap[api.LoadGameRequest[com.Uid]](x.Payload); dat == nil { + err, out = api.ErrMalformed, api.EmptyPacket } else { out = c.HandleLoadGame(*dat, w) } case api.ChangePlayer: - if dat := api.Unwrap[api.ChangePlayerRequest](x.Payload); dat == nil { - err, out = api.ErrMalformed, com.EmptyPacket + if dat := api.Unwrap[api.ChangePlayerRequest[com.Uid]](x.Payload); dat == nil { + err, out = api.ErrMalformed, api.EmptyPacket } else { out = c.HandleChangePlayer(*dat, w) } case api.ToggleMultitap: - if dat := api.Unwrap[api.ToggleMultitapRequest](x.Payload); dat == nil { - err, out = api.ErrMalformed, com.EmptyPacket + if dat := api.Unwrap[api.ToggleMultitapRequest[com.Uid]](x.Payload); dat == nil { + err, out = api.ErrMalformed, api.EmptyPacket } else { c.HandleToggleMultitap(*dat, w) } case api.RecordGame: - if dat := api.Unwrap[api.RecordGameRequest](x.Payload); dat == nil { - err, out = api.ErrMalformed, com.EmptyPacket + if dat := api.Unwrap[api.RecordGameRequest[com.Uid]](x.Payload); dat == nil { + err, out = api.ErrMalformed, api.EmptyPacket } else { out = c.HandleRecordGame(*dat, w) } default: - c.Log.Warn().Msgf("unhandled packet type %v", x.T) + c.log.Warn().Msgf("unhandled packet type %v", x.T) } if out != skipped { - w.cord.Route(x, out) + w.cord.Route(x, &out) } return err }) @@ -130,6 +152,6 @@ func (c *coordinator) RegisterRoom(id string) { c.Notify(api.RegisterRoom, id) } // CloseRoom sends a signal to coordinator which will remove that room from its list. func (c *coordinator) CloseRoom(id string) { c.Notify(api.CloseRoom, id) } -func (c *coordinator) IceCandidate(candidate string, sessionId network.Uid) { - c.Notify(api.NewWebrtcIceCandidateRequest(sessionId, candidate)) +func (c *coordinator) IceCandidate(candidate string, sessionId com.Uid) { + c.Notify(api.WebrtcIce, api.WebrtcIceCandidateRequest[com.Uid]{Stateful: api.Stateful[com.Uid]{Id: sessionId}, Candidate: candidate}) } diff --git a/pkg/worker/coordinatorhandlers.go b/pkg/worker/coordinatorhandlers.go index d9381d2a..ad6b611a 100644 --- a/pkg/worker/coordinatorhandlers.go +++ b/pkg/worker/coordinatorhandlers.go @@ -1,21 +1,22 @@ package worker import ( - "fmt" + "encoding/base64" "github.com/giongto35/cloud-game/v3/pkg/api" "github.com/giongto35/cloud-game/v3/pkg/com" "github.com/giongto35/cloud-game/v3/pkg/config/worker" "github.com/giongto35/cloud-game/v3/pkg/games" "github.com/giongto35/cloud-game/v3/pkg/network/webrtc" + "github.com/goccy/go-json" ) // buildConnQuery builds initial connection data query to a coordinator. -func buildConnQuery[S fmt.Stringer](id S, conf worker.Worker, address string) (string, error) { +func buildConnQuery(id com.Uid, conf worker.Worker, address string) (string, error) { addr := conf.GetPingAddr(address) - return api.ToBase64Json(api.ConnectionRequest{ + return toBase64Json(api.ConnectionRequest[com.Uid]{ Addr: addr.Hostname(), - Id: id.String(), + Id: id, IsHTTPS: conf.Server.Https, PingURL: addr.String(), Port: conf.GetPort(address), @@ -24,55 +25,55 @@ func buildConnQuery[S fmt.Stringer](id S, conf worker.Worker, address string) (s }) } -func (c *coordinator) HandleWebrtcInit(rq api.WebrtcInitRequest, w *Worker, connApi *webrtc.ApiFactory) com.Out { - peer := webrtc.New(c.Log, connApi) +func (c *coordinator) HandleWebrtcInit(rq api.WebrtcInitRequest[com.Uid], w *Worker, connApi *webrtc.ApiFactory) api.Out { + peer := webrtc.New(c.log, connApi) localSDP, err := peer.NewCall(w.conf.Encoder.Video.Codec, audioCodec, func(data any) { - candidate, err := api.ToBase64Json(data) + candidate, err := toBase64Json(data) if err != nil { - c.Log.Error().Err(err).Msgf("ICE candidate encode fail for [%v]", data) + c.log.Error().Err(err).Msgf("ICE candidate encode fail for [%v]", data) return } c.IceCandidate(candidate, rq.Id) }) if err != nil { - c.Log.Error().Err(err).Msg("cannot create new webrtc session") - return com.EmptyPacket + c.log.Error().Err(err).Msg("cannot create new webrtc session") + return api.EmptyPacket } - sdp, err := api.ToBase64Json(localSDP) + sdp, err := toBase64Json(localSDP) if err != nil { - c.Log.Error().Err(err).Msgf("SDP encode fail fro [%v]", localSDP) - return com.EmptyPacket + c.log.Error().Err(err).Msgf("SDP encode fail fro [%v]", localSDP) + return api.EmptyPacket } // use user uid from the coordinator user := NewSession(peer, rq.Id) w.router.AddUser(user) - c.Log.Info().Str("id", string(rq.Id)).Msgf("Peer connection (uid:%s)", user.Id()) + c.log.Info().Str("id", rq.Id.String()).Msgf("Peer connection (uid:%s)", user.Id()) - return com.Out{Payload: sdp} + return api.Out{Payload: sdp} } -func (c *coordinator) HandleWebrtcAnswer(rq api.WebrtcAnswerRequest, w *Worker) { +func (c *coordinator) HandleWebrtcAnswer(rq api.WebrtcAnswerRequest[com.Uid], w *Worker) { if user := w.router.GetUser(rq.Id); user != nil { - if err := user.GetPeerConn().SetRemoteSDP(rq.Sdp, api.FromBase64Json); err != nil { - c.Log.Error().Err(err).Msgf("cannot set remote SDP of client [%v]", rq.Id) + if err := user.GetPeerConn().SetRemoteSDP(rq.Sdp, fromBase64Json); err != nil { + c.log.Error().Err(err).Msgf("cannot set remote SDP of client [%v]", rq.Id) } } } -func (c *coordinator) HandleWebrtcIceCandidate(rs api.WebrtcIceCandidateRequest, w *Worker) { +func (c *coordinator) HandleWebrtcIceCandidate(rs api.WebrtcIceCandidateRequest[com.Uid], w *Worker) { if user := w.router.GetUser(rs.Id); user != nil { - if err := user.GetPeerConn().AddCandidate(rs.Candidate, api.FromBase64Json); err != nil { - c.Log.Error().Err(err).Msgf("cannot add ICE candidate of the client [%v]", rs.Id) + if err := user.GetPeerConn().AddCandidate(rs.Candidate, fromBase64Json); err != nil { + c.log.Error().Err(err).Msgf("cannot add ICE candidate of the client [%v]", rs.Id) } } } -func (c *coordinator) HandleGameStart(rq api.StartGameRequest, w *Worker) com.Out { +func (c *coordinator) HandleGameStart(rq api.StartGameRequest[com.Uid], w *Worker) api.Out { user := w.router.GetUser(rq.Id) if user == nil { - c.Log.Error().Msgf("no user [%v]", rq.Id) - return com.EmptyPacket + c.log.Error().Msgf("no user [%v]", rq.Id) + return api.EmptyPacket } w.log.Info().Msgf("Starting game: %v", rq.Game.Name) @@ -82,7 +83,7 @@ func (c *coordinator) HandleGameStart(rq api.StartGameRequest, w *Worker) com.Ou rq.Room.Rid, games.GameMetadata{Name: rq.Game.Name, Base: rq.Game.Base, Type: rq.Game.Type, Path: rq.Game.Path}, func(room *Room) { - w.router.RemoveRoom() + w.router.SetRoom(nil) c.CloseRoom(room.id) w.log.Debug().Msgf("Room close has been called %v", room.id) }, @@ -108,8 +109,8 @@ func (c *coordinator) HandleGameStart(rq api.StartGameRequest, w *Worker) com.Ou } if room == nil { - c.Log.Error().Msgf("couldn't create a room [%v]", rq.Id) - return com.EmptyPacket + c.log.Error().Msgf("couldn't create a room [%v]", rq.Id) + return api.EmptyPacket } if !room.HasUser(user) { @@ -120,13 +121,13 @@ func (c *coordinator) HandleGameStart(rq api.StartGameRequest, w *Worker) com.Ou c.RegisterRoom(room.GetId()) - return com.Out{Payload: api.StartGameResponse{Room: api.Room{Rid: room.GetId()}, Record: w.conf.Recording.Enabled}} + return api.Out{Payload: api.StartGameResponse{Room: api.Room{Rid: room.GetId()}, Record: w.conf.Recording.Enabled}} } // HandleTerminateSession handles cases when a user has been disconnected from the websocket of coordinator. -func (c *coordinator) HandleTerminateSession(rq api.TerminateSessionRequest, w *Worker) { +func (c *coordinator) HandleTerminateSession(rq api.TerminateSessionRequest[com.Uid], w *Worker) { if session := w.router.GetUser(rq.Id); session != nil { - w.router.RemoveUser(session) + w.router.RemoveDisconnect(session) if room := session.GetSetRoom(nil); room != nil { room.CleanupUser(session) } @@ -134,7 +135,7 @@ func (c *coordinator) HandleTerminateSession(rq api.TerminateSessionRequest, w * } // HandleQuitGame handles cases when a user manually exits the game. -func (c *coordinator) HandleQuitGame(rq api.GameQuitRequest, w *Worker) { +func (c *coordinator) HandleQuitGame(rq api.GameQuitRequest[com.Uid], w *Worker) { if user := w.router.GetUser(rq.Id); user != nil { // we don't strictly need a room id form the request, // since users hold their room reference @@ -145,65 +146,82 @@ func (c *coordinator) HandleQuitGame(rq api.GameQuitRequest, w *Worker) { } } -func (c *coordinator) HandleSaveGame(rq api.SaveGameRequest, w *Worker) com.Out { - if room := roomy(rq, w); room != nil { - if err := room.SaveGame(); err != nil { - c.Log.Error().Err(err).Msg("cannot save game state") - return com.ErrPacket - } - return com.OkPacket +func (c *coordinator) HandleSaveGame(rq api.SaveGameRequest[com.Uid], w *Worker) api.Out { + room := w.router.GetRoom(rq.Rid) + if room == nil { + return api.ErrPacket } - return com.ErrPacket + if err := room.SaveGame(); err != nil { + c.log.Error().Err(err).Msg("cannot save game state") + return api.ErrPacket + } + return api.OkPacket } -func (c *coordinator) HandleLoadGame(rq api.LoadGameRequest, w *Worker) com.Out { - if room := roomy(rq, w); room != nil { - if err := room.LoadGame(); err != nil { - c.Log.Error().Err(err).Msg("cannot load game state") - return com.ErrPacket - } - return com.OkPacket +func (c *coordinator) HandleLoadGame(rq api.LoadGameRequest[com.Uid], w *Worker) api.Out { + room := w.router.GetRoom(rq.Rid) + if room == nil { + return api.ErrPacket } - return com.ErrPacket + if err := room.LoadGame(); err != nil { + c.log.Error().Err(err).Msg("cannot load game state") + return api.ErrPacket + } + return api.OkPacket } -func (c *coordinator) HandleChangePlayer(rq api.ChangePlayerRequest, w *Worker) com.Out { +func (c *coordinator) HandleChangePlayer(rq api.ChangePlayerRequest[com.Uid], w *Worker) api.Out { user := w.router.GetUser(rq.Id) if user == nil || w.router.GetRoom(rq.Rid) == nil { - return com.Out{Payload: -1} // semi-predicates + return api.Out{Payload: -1} // semi-predicates } user.SetPlayerIndex(rq.Index) w.log.Info().Msgf("Updated player index to: %d", rq.Index) - return com.Out{Payload: rq.Index} + return api.Out{Payload: rq.Index} } -func (c *coordinator) HandleToggleMultitap(rq api.ToggleMultitapRequest, w *Worker) com.Out { - if room := roomy(rq, w); room != nil { - room.ToggleMultitap() - return com.OkPacket - } - return com.ErrPacket -} - -func (c *coordinator) HandleRecordGame(rq api.RecordGameRequest, w *Worker) com.Out { - if !w.conf.Recording.Enabled { - return com.ErrPacket - } - if room := roomy(rq, w); room != nil { - room.(*RecordingRoom).ToggleRecording(rq.Active, rq.User) - return com.OkPacket - } - return com.ErrPacket -} - -func roomy(rq api.RoomInterface, w *Worker) GamingRoom { - rid := rq.GetRoom() - if rid == "" { - return nil - } - room := w.router.GetRoom(rid) +func (c *coordinator) HandleToggleMultitap(rq api.ToggleMultitapRequest[com.Uid], w *Worker) api.Out { + room := w.router.GetRoom(rq.Rid) if room == nil { - return nil + return api.ErrPacket } - return room + room.ToggleMultitap() + return api.OkPacket +} + +func (c *coordinator) HandleRecordGame(rq api.RecordGameRequest[com.Uid], w *Worker) api.Out { + if !w.conf.Recording.Enabled { + return api.ErrPacket + } + room := w.router.GetRoom(rq.Rid) + if room == nil { + return api.ErrPacket + } + room.(*RecordingRoom).ToggleRecording(rq.Active, rq.User) + return api.OkPacket +} + +// fromBase64Json decodes data from a URL-encoded Base64+JSON string. +func fromBase64Json(data string, obj any) error { + b, err := base64.URLEncoding.DecodeString(data) + if err != nil { + return err + } + err = json.Unmarshal(b, obj) + if err != nil { + return err + } + return nil +} + +// toBase64Json encodes data to a URL-encoded Base64+JSON string. +func toBase64Json(data any) (string, error) { + if data == nil { + return "", nil + } + b, err := json.Marshal(data) + if err != nil { + return "", err + } + return base64.URLEncoding.EncodeToString(b), nil } diff --git a/pkg/worker/room.go b/pkg/worker/room.go index 56755e8a..eba8227c 100644 --- a/pkg/worker/room.go +++ b/pkg/worker/room.go @@ -82,7 +82,7 @@ func (r *Room) GetId() string { return r.id } func (r *Room) GetLog() *logger.Logger { return r.log } func (r *Room) HasSave() bool { return os.Exists(r.emulator.GetHashPath()) } func (r *Room) HasUser(u *Session) bool { return r != nil && r.users.Has(u.id) } -func (r *Room) IsEmpty() bool { return r.users.IsEmpty() } +func (r *Room) IsEmpty() bool { return r.users.Len() == 0 } func (r *Room) LoadGame() error { return r.emulator.LoadGameState() } func (r *Room) SaveGame() error { return r.emulator.SaveGameState() } func (r *Room) StartEmulator() { go r.emulator.Start() } @@ -144,14 +144,14 @@ func (r *Room) PollUserInput(session *Session) { func (r *Room) AddUser(user *Session) { r.users.Add(user) user.SetRoom(r) - r.log.Debug().Str("user", string(user.Id())).Msg("User has joined the room") + r.log.Debug().Str("user", user.Id().String()).Msg("User has joined the room") } func (r *Room) CleanupUser(user *Session) { user.SetRoom(nil) if r.HasUser(user) { r.users.Remove(user) - r.log.Debug().Str("user", string(user.Id())).Msg("User has left the room") + r.log.Debug().Str("user", user.Id().String()).Msg("User has left the room") } if r.IsEmpty() { r.log.Debug().Msg("The room is empty") diff --git a/pkg/worker/router.go b/pkg/worker/router.go index 11c14ae5..d3878cff 100644 --- a/pkg/worker/router.go +++ b/pkg/worker/router.go @@ -2,7 +2,6 @@ package worker import ( "github.com/giongto35/cloud-game/v3/pkg/com" - "github.com/giongto35/cloud-game/v3/pkg/network" "github.com/giongto35/cloud-game/v3/pkg/network/webrtc" "github.com/pion/webrtc/v3/pkg/media" ) @@ -18,9 +17,9 @@ type Router struct { // Session represents WebRTC connection of the user. type Session struct { - id network.Uid + id com.Uid conn *webrtc.Peer - pi int + pi int // player index room GamingRoom // back reference } @@ -39,19 +38,18 @@ func (r *Router) GetRoom(id string) GamingRoom { } return nil } -func (r *Router) GetUser(uid network.Uid) *Session { sess, _ := r.users.Find(string(uid)); return sess } -func (r *Router) RemoveRoom() { r.room = nil } -func (r *Router) RemoveUser(user *Session) { r.users.Remove(user); user.Close() } +func (r *Router) GetUser(uid com.Uid) *Session { sess, _ := r.users.Find(uid); return sess } +func (r *Router) RemoveDisconnect(user *Session) { r.users.Remove(user); user.Disconnect() } -func NewSession(rtc *webrtc.Peer, id network.Uid) *Session { return &Session{id: id, conn: rtc} } +func NewSession(rtc *webrtc.Peer, id com.Uid) *Session { return &Session{id: id, conn: rtc} } -func (s *Session) Id() network.Uid { return s.id } -func (s *Session) GetSetRoom(v GamingRoom) GamingRoom { vv := s.room; s.room = v; return vv } +func (s *Session) Disconnect() { s.conn.Disconnect() } func (s *Session) GetPeerConn() *webrtc.Peer { return s.conn } func (s *Session) GetPlayerIndex() int { return s.pi } +func (s *Session) GetSetRoom(v GamingRoom) GamingRoom { vv := s.room; s.room = v; return vv } +func (s *Session) Id() com.Uid { return s.id } func (s *Session) IsConnected() bool { return s.conn.IsConnected() } -func (s *Session) SendVideo(sample *media.Sample) error { return s.conn.WriteVideo(sample) } func (s *Session) SendAudio(sample *media.Sample) error { return s.conn.WriteAudio(sample) } -func (s *Session) SetRoom(room GamingRoom) { s.room = room } +func (s *Session) SendVideo(sample *media.Sample) error { return s.conn.WriteVideo(sample) } func (s *Session) SetPlayerIndex(index int) { s.pi = index } -func (s *Session) Close() { s.conn.Disconnect() } +func (s *Session) SetRoom(room GamingRoom) { s.room = room } diff --git a/pkg/worker/worker.go b/pkg/worker/worker.go index 1ac216e3..46d25926 100644 --- a/pkg/worker/worker.go +++ b/pkg/worker/worker.go @@ -64,7 +64,7 @@ func (w *Worker) Run() { remoteAddr := w.conf.Worker.Network.CoordinatorAddress defer func() { if w.cord != nil { - w.cord.Close() + w.cord.Disconnect() } w.router.Close() w.log.Debug().Msgf("Service loop end") @@ -75,16 +75,15 @@ func (w *Worker) Run() { case <-w.done: return default: - conn, err := connect(remoteAddr, w.conf.Worker, w.address, w.log) + cord, err := newCoordinatorConnection(remoteAddr, w.conf.Worker, w.address, w.log) if err != nil { w.log.Error().Err(err).Msgf("no connection: %v. Retrying in %v", remoteAddr, retry) time.Sleep(retry) continue } - w.cord = conn - w.cord.Log.Info().Msgf("Connected to the coordinator %v", remoteAddr) - w.cord.HandleRequests(w) - <-w.cord.Done() + w.cord = cord + w.cord.log.Info().Msgf("Connected to the coordinator %v", remoteAddr) + <-w.cord.HandleRequests(w) w.router.Close() } }