headscale/hscontrol/mapper/batcher.go
Kristoffer Dalby 3b4b9a4436 hscontrol: fix tag updates not propagating to node self view
When SetNodeTags changed a node's tags, the node's self view wasn't
updated. The bug manifested as: the first SetNodeTags call updates
the server but the client's self view doesn't update until a second
call with the same tag.

Root cause: Three issues combined to prevent self-updates:

1. SetNodeTags returned PolicyChange which doesn't set OriginNode,
   so the mapper's self-update check failed.

2. The Change.Merge function didn't preserve OriginNode, so when
   changes were batched together, OriginNode was lost.

3. generateMapResponse checked OriginNode only in buildFromChange(),
   but PolicyChange uses RequiresRuntimePeerComputation which
   bypasses that code path entirely and calls policyChangeResponse()
   instead.

The fix addresses all three:
- state.go: Set OriginNode on the returned change
- change.go: Preserve OriginNode (and TargetNode) during merge
- batcher.go: Pass isSelfUpdate to policyChangeResponse so the
  origin node gets both self info AND packet filters
- mapper.go: Add includeSelf parameter to policyChangeResponse

Fixes #2978
2026-01-20 10:13:47 +01:00

178 lines
5.8 KiB
Go

package mapper
import (
"errors"
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/prometheus/client_golang/prometheus"
"github.com/prometheus/client_golang/prometheus/promauto"
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
)
// mapResponseGenerated is a Prometheus counter vector created through
// promauto with the fully qualified pieces namespace "headscale" and name
// "mapresponse_generated_total". It has a single label, "response_type";
// generateMapResponse increments it using the change's Type() as the
// label value.
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
Help: "total count of mapresponses generated by response type",
}, []string{"response_type"})
// batcherFunc is the signature of a constructor that builds a Batcher from
// a configuration and a state; NewBatcherAndMapper has this signature.
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
// Batcher defines the common interface for all batcher implementations.
//
// NOTE(review): the per-method notes below are derived from the method names
// and signatures only; the implementations are not part of this declaration,
// so confirm the details against the concrete batcher.
type Batcher interface {
// Start and Close presumably begin and end the batcher's lifecycle — TODO confirm.
Start()
Close()
// AddNode takes a node ID, a send-only channel of map responses and the
// node's capability version, and may fail with an error.
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
// RemoveNode takes a node ID and a send-only channel and reports a bool
// (presumably whether a connection was removed — TODO confirm).
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
// IsConnected reports a bool for the given node ID.
IsConnected(id types.NodeID) bool
// ConnectedMap returns an xsync map keyed by node ID with bool values.
ConnectedMap() *xsync.Map[types.NodeID, bool]
// AddWork accepts zero or more changes (variadic).
AddWork(r ...change.Change)
// MapResponseFromChange returns a map response, or an error, for node id
// and change r.
MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error)
// DebugMapResponses returns map responses grouped by node ID, or an error.
DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
}
// NewBatcher builds a LockFreeBatcher that uses the given mapper, worker
// count and batching interval. The interval drives a time.Ticker, and the
// work queue is a buffered channel holding 200 entries per worker.
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
	// The per-worker queue depth is an arbitrary choice; the sizing should be
	// revisited.
	const queuePerWorker = 200

	ticker := time.NewTicker(batchTime)
	queue := make(chan work, workers*queuePerWorker)
	nodeConns := xsync.NewMap[types.NodeID, *multiChannelNodeConn]()
	connectedAt := xsync.NewMap[types.NodeID, *time.Time]()
	pending := xsync.NewMap[types.NodeID, []change.Change]()

	return &LockFreeBatcher{
		mapper:         mapper,
		workers:        workers,
		tick:           ticker,
		workCh:         queue,
		nodes:          nodeConns,
		connected:      connectedAt,
		pendingChanges: pending,
	}
}
// NewBatcherAndMapper creates a mapper and a batcher that reference each
// other and returns the batcher as a Batcher. The batcher's interval and
// worker count are read from cfg.Tuning.
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
	mp := newMapper(cfg, state)

	delay := cfg.Tuning.BatchChangeDelay
	workerCount := cfg.Tuning.BatcherWorkers

	batcher := NewBatcher(delay, workerCount, mp)

	// Close the loop: the mapper keeps a reference to the batcher it serves.
	mp.batcher = batcher

	return batcher
}
// nodeConnection interface for different connection implementations.
// generateMapResponse and handleNodeChange operate on a node exclusively
// through this interface.
type nodeConnection interface {
// nodeID returns the ID of the node this connection belongs to.
nodeID() types.NodeID
// version returns the capability version associated with the connection.
version() tailcfg.CapabilityVersion
// send delivers a map response over the connection and reports failure
// as an error.
send(data *tailcfg.MapResponse) error
// computePeerDiff returns peers that were previously sent but are no longer in the current list.
computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID)
// updateSentPeers updates the tracking of which peers have been sent to this node.
updateSentPeers(resp *tailcfg.MapResponse)
}
// generateMapResponse generates a [tailcfg.MapResponse] for the node behind
// nc based on the provided [change.Change].
//
// It returns (nil, nil) when there is nothing to send: either the change is
// empty, or it is a self-only change targeted at a different node.
//
// It returns an error when nc is nil, when the connection reports node ID 0,
// when mapper is nil, or when building the response fails.
func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) {
	// A nil connection would panic on the method calls below; report it as an
	// error instead, matching handleNodeChange.
	if nc == nil {
		return nil, errors.New("nodeConnection is nil")
	}

	nodeID := nc.nodeID()
	version := nc.version()

	if r.IsEmpty() {
		return nil, nil //nolint:nilnil // Empty response means nothing to send
	}

	if nodeID == 0 {
		return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
	}

	if mapper == nil {
		return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
	}

	// Handle self-only responses
	if r.IsSelfOnly() && r.TargetNode != nodeID {
		return nil, nil //nolint:nilnil // No response needed for other nodes when self-only
	}

	// Check if this is a self-update (the changed node is the receiving node).
	// When true, ensure the response includes the node's self info so it sees
	// its own attribute changes (e.g., tags changed via admin API).
	isSelfUpdate := r.OriginNode != 0 && r.OriginNode == nodeID

	// Track metric using categorized type, not free-form reason
	mapResponseGenerated.WithLabelValues(r.Type()).Inc()

	var (
		mapResp *tailcfg.MapResponse
		err     error
	)

	switch {
	case r.RequiresRuntimePeerComputation:
		// Runtime peer visibility computation (e.g., policy changes): collect
		// the IDs of the peers currently listed for this node and diff them
		// against what the connection has already been sent.
		currentPeers := mapper.state.ListPeers(nodeID)

		currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len())
		for _, peer := range currentPeers.All() {
			currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID())
		}

		removedPeers := nc.computePeerDiff(currentPeerIDs)

		// Include self node when this is a self-update (e.g., node's own tags changed)
		// so the node sees its updated self info along with new packet filters.
		mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers, isSelfUpdate)

	case isSelfUpdate:
		// Non-policy self-update: just send the self node info
		mapResp, err = mapper.selfMapResponse(nodeID, version)

	default:
		mapResp, err = mapper.buildFromChange(nodeID, version, &r)
	}

	if err != nil {
		return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
	}

	return mapResp, nil
}
// handleNodeChange builds a [tailcfg.MapResponse] for the node behind nc from
// the given [change.Change] and sends it over the connection.
//
// It returns an error when nc is nil, when generating the response fails, or
// when sending fails. When generation yields no response, nothing is sent
// and nil is returned. Peer tracking on the connection is updated only after
// a successful send.
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
	if nc == nil {
		return errors.New("nodeConnection is nil")
	}

	id := nc.nodeID()

	log.Debug().
		Caller().
		Uint64("node.id", id.Uint64()).
		Str("reason", r.Reason).
		Msg("Node change processing started because change notification received")

	resp, err := generateMapResponse(nc, mapper, r)

	switch {
	case err != nil:
		return fmt.Errorf("generating map response for node %d: %w", id, err)
	case resp == nil:
		// A nil response without an error is valid for some response types.
		return nil
	}

	if sendErr := nc.send(resp); sendErr != nil {
		return fmt.Errorf("sending map response to node %d: %w", id, sendErr)
	}

	// Record what was delivered only once the send has succeeded.
	nc.updateSentPeers(resp)

	return nil
}
// workResult represents the result of processing a change.
// Values of this type are what a work item's resultCh carries.
type workResult struct {
// mapResponse is the generated response; presumably nil when err is set
// or when there was nothing to send — confirm against the worker code.
mapResponse *tailcfg.MapResponse
// err is the error produced while processing, if any.
err error
}
// work represents a unit of work to be processed by workers.
// Items of this type are queued on the batcher's buffered work channel
// created in NewBatcher.
type work struct {
// c is the change to process.
c change.Change
// nodeID identifies the node this unit of work is for (presumably the
// recipient of the resulting map response — confirm against the worker).
nodeID types.NodeID
resultCh chan<- workResult // optional channel for synchronous operations
}