diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go
index fbeedd86..c90cdc32 100644
--- a/hscontrol/mapper/batcher_lockfree.go
+++ b/hscontrol/mapper/batcher_lockfree.go
@@ -1,8 +1,8 @@
 package mapper
 
 import (
-	"context"
 	"crypto/rand"
+	"errors"
 	"fmt"
 	"sync"
 	"sync/atomic"
@@ -16,6 +16,8 @@ import (
 	"tailscale.com/types/ptr"
 )
 
+var errConnectionClosed = errors.New("connection channel already closed")
+
 // LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
 type LockFreeBatcher struct {
 	tick *time.Ticker
@@ -26,9 +28,9 @@ type LockFreeBatcher struct {
 	connected *xsync.Map[types.NodeID, *time.Time]
 
 	// Work queue channel
-	workCh chan work
-	ctx    context.Context
-	cancel context.CancelFunc
+	workCh     chan work
+	workChOnce sync.Once // Ensures workCh is only closed once
+	done       chan struct{}
 
 	// Batching state
 	pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
@@ -144,23 +146,20 @@ func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
 }
 
 func (b *LockFreeBatcher) Start() {
-	b.ctx, b.cancel = context.WithCancel(context.Background())
+	b.done = make(chan struct{})
 	go b.doWork()
 }
 
 func (b *LockFreeBatcher) Close() {
-	if b.cancel != nil {
-		b.cancel()
-		b.cancel = nil
+	// Signal shutdown to all goroutines
+	if b.done != nil {
+		close(b.done)
 	}
 
-	// Only close workCh once
-	select {
-	case <-b.workCh:
-		// Channel is already closed
-	default:
+	// Only close workCh once using sync.Once to prevent races
+	b.workChOnce.Do(func() {
 		close(b.workCh)
-	}
+	})
 
 	// Close the underlying channels supplying the data to the clients.
 	b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
@@ -186,8 +185,8 @@ func (b *LockFreeBatcher) doWork() {
 		case <-cleanupTicker.C:
 			// Clean up nodes that have been offline for too long
 			b.cleanupOfflineNodes()
-		case <-b.ctx.Done():
-			log.Info().Msg("batcher context done, stopping to feed workers")
+		case <-b.done:
+			log.Info().Msg("batcher done channel closed, stopping to feed workers")
 			return
 		}
 	}
@@ -235,7 +234,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
 			// Send result
 			select {
 			case w.resultCh <- result:
-			case <-b.ctx.Done():
+			case <-b.done:
 				return
 			}
 
@@ -258,8 +257,8 @@ func (b *LockFreeBatcher) worker(workerID int) {
 						Msg("failed to apply change")
 				}
 			}
-		case <-b.ctx.Done():
-			log.Debug().Int("workder.id", workerID).Msg("batcher context is done, exiting worker")
+		case <-b.done:
+			log.Debug().Int("worker.id", workerID).Msg("batcher shutting down, exiting worker")
 			return
 		}
 	}
@@ -276,7 +275,7 @@ func (b *LockFreeBatcher) queueWork(w work) {
 	select {
 	case b.workCh <- w:
 		// Successfully queued
-	case <-b.ctx.Done():
+	case <-b.done:
 		// Batcher is shutting down
 		return
 	}
@@ -443,7 +442,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.Change
 	select {
 	case result := <-resultCh:
 		return result.mapResponse, result.err
-	case <-b.ctx.Done():
+	case <-b.done:
 		return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
 	}
 }
@@ -455,6 +454,7 @@ type connectionEntry struct {
 	version  tailcfg.CapabilityVersion
 	created  time.Time
 	lastUsed atomic.Int64 // Unix timestamp of last successful send
+	closed   atomic.Bool  // Indicates if this connection has been closed
 }
 
 // multiChannelNodeConn manages multiple concurrent connections for a single node.
@@ -488,6 +488,9 @@ func (mc *multiChannelNodeConn) close() {
 	defer mc.mutex.Unlock()
 
 	for _, conn := range mc.connections {
+		// Mark as closed before closing the channel to prevent
+		// send on closed channel panics from concurrent workers
+		conn.closed.Store(true)
 		close(conn.c)
 	}
 }
@@ -620,6 +623,12 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
 		return nil
 	}
 
+	// Check if the connection has been closed to prevent send on closed channel panic.
+	// This can happen during shutdown when Close() is called while workers are still processing.
+	if entry.closed.Load() {
+		return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
+	}
+
 	// Use a short timeout to detect stale connections where the client isn't reading the channel.
 	// This is critical for detecting Docker containers that are forcefully terminated
 	// but still have channels that appear open.
diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go
index 8cbcaa75..f43ea5a1 100644
--- a/hscontrol/mapper/batcher_test.go
+++ b/hscontrol/mapper/batcher_test.go
@@ -147,12 +147,12 @@ type node struct {
 	n  *types.Node
 	ch chan *tailcfg.MapResponse
 
-	// Update tracking
+	// Update tracking (all accessed atomically for thread safety)
 	updateCount   int64
 	patchCount    int64
 	fullCount     int64
-	maxPeersCount int
-	lastPeerCount int
+	maxPeersCount atomic.Int64
+	lastPeerCount atomic.Int64
 	stop          chan struct{}
 	stopped       chan struct{}
 }
@@ -422,18 +422,32 @@ func (n *node) start() {
 			// Track update types
 			if info.IsFull {
 				atomic.AddInt64(&n.fullCount, 1)
-				n.lastPeerCount = info.PeerCount
-				// Update max peers seen
-				if info.PeerCount > n.maxPeersCount {
-					n.maxPeersCount = info.PeerCount
+				n.lastPeerCount.Store(int64(info.PeerCount))
+				// Update max peers seen using compare-and-swap for thread safety
+				for {
+					current := n.maxPeersCount.Load()
+					if int64(info.PeerCount) <= current {
+						break
+					}
+
+					if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) {
+						break
+					}
 				}
 			}
 
 			if info.IsPatch {
 				atomic.AddInt64(&n.patchCount, 1)
-				// For patches, we track how many patch items
-				if info.PatchCount > n.maxPeersCount {
-					n.maxPeersCount = info.PatchCount
+				// For patches, we track how many patch items using compare-and-swap
+				for {
+					current := n.maxPeersCount.Load()
+					if int64(info.PatchCount) <= current {
+						break
+					}
+
+					if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) {
+						break
+					}
 				}
 			}
 		}
@@ -465,8 +479,8 @@ func (n *node) cleanup() NodeStats {
 		TotalUpdates:  atomic.LoadInt64(&n.updateCount),
 		PatchUpdates:  atomic.LoadInt64(&n.patchCount),
 		FullUpdates:   atomic.LoadInt64(&n.fullCount),
-		MaxPeersSeen:  n.maxPeersCount,
-		LastPeerCount: n.lastPeerCount,
+		MaxPeersSeen:  int(n.maxPeersCount.Load()),
+		LastPeerCount: int(n.lastPeerCount.Load()),
 	}
 }
 
@@ -665,7 +679,8 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
 	connectedCount := 0
 	for i := range allNodes {
 		node := &allNodes[i]
-		currentMaxPeers := node.maxPeersCount
+
+		currentMaxPeers := int(node.maxPeersCount.Load())
 		if currentMaxPeers >= expectedPeers {
 			connectedCount++
 		}
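As background for the shutdown changes above, here is a minimal, self-contained sketch of the pattern the batcher now relies on: a done channel whose close broadcasts cancellation to every select statement, plus a sync.Once that guarantees the work channel is closed exactly once. The batcher type, field names, worker count, and timings below are illustrative stand-ins, not the headscale types.

package main

import (
	"fmt"
	"sync"
	"time"
)

// batcher is an illustrative stand-in for LockFreeBatcher, reduced to the
// shutdown plumbing shown in the diff: a done channel for broadcast
// cancellation and a sync.Once guarding the close of the work channel.
type batcher struct {
	workCh     chan int
	workChOnce sync.Once
	done       chan struct{}
	wg         sync.WaitGroup
}

func (b *batcher) Start(workers int) {
	b.done = make(chan struct{})
	for i := 0; i < workers; i++ {
		b.wg.Add(1)
		go b.worker(i)
	}
}

func (b *batcher) worker(id int) {
	defer b.wg.Done()
	for {
		select {
		case w, ok := <-b.workCh:
			if !ok {
				return // work channel closed during shutdown
			}
			fmt.Printf("worker %d processed item %d\n", id, w)
		case <-b.done:
			// Closing done unblocks every waiting worker at once; a value
			// sent on a channel would wake only a single receiver.
			return
		}
	}
}

func (b *batcher) Close() {
	close(b.done) // broadcast shutdown to all selects
	// sync.Once ensures the work channel is closed exactly once, avoiding a
	// "close of closed channel" panic if shutdown paths ever overlap.
	b.workChOnce.Do(func() {
		close(b.workCh)
	})
	b.wg.Wait()
}

func main() {
	b := &batcher{workCh: make(chan int, 8)}
	b.Start(3)
	for i := 0; i < 5; i++ {
		b.workCh <- i
	}
	time.Sleep(50 * time.Millisecond)
	b.Close()
}

A closed channel stays readable forever, which is the same mechanism context.Context uses internally for Done(); the done channel keeps that broadcast behaviour while avoiding the extra context plumbing.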
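The test-side changes swap plain int counters for atomic.Int64 and maintain the maximum with a compare-and-swap loop. The following sketch isolates that pattern; storeMax, maxPeers, and the sample peer counts are hypothetical names and values used only for illustration, and atomic.Int64 assumes Go 1.19 or newer.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// storeMax atomically raises *max to v when v is larger. A plain
// "if v > max { max = v }" on a shared int is a data race once several
// goroutines record peer counts concurrently; the CAS loop is not.
func storeMax(max *atomic.Int64, v int64) {
	for {
		current := max.Load()
		if v <= current {
			return // another goroutine already recorded a value at least as large
		}
		if max.CompareAndSwap(current, v) {
			return // we published the new maximum
		}
		// CAS failed: max changed between Load and the swap; retry with the
		// fresh value.
	}
}

func main() {
	var maxPeers atomic.Int64
	var wg sync.WaitGroup

	for _, peers := range []int64{3, 12, 7, 9, 12, 5} {
		wg.Add(1)
		go func(p int64) {
			defer wg.Done()
			storeMax(&maxPeers, p)
		}(peers)
	}
	wg.Wait()

	fmt.Println("max peers seen:", maxPeers.Load()) // prints 12
}

Because the stored value only ever increases, the retry loop terminates quickly: a failed CompareAndSwap means another goroutine raised the maximum in the meantime, and the next Load observes that larger value.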