mirror of
https://github.com/juanfont/headscale.git
synced 2026-01-23 02:24:10 +00:00
Disable the thelper linter which triggers on inline test closures in table-driven tests. These closures are intentionally not standalone helpers and don't benefit from t.Helper(). Also remove explanatory comments from nolint directives throughout the codebase as they add noise without providing significant value.
188 lines
6.1 KiB
Go
188 lines
6.1 KiB
Go
package mapper
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"time"
|
|
|
|
"github.com/juanfont/headscale/hscontrol/state"
|
|
"github.com/juanfont/headscale/hscontrol/types"
|
|
"github.com/juanfont/headscale/hscontrol/types/change"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"github.com/prometheus/client_golang/prometheus/promauto"
|
|
"github.com/puzpuzpuz/xsync/v4"
|
|
"github.com/rs/zerolog/log"
|
|
"tailscale.com/tailcfg"
|
|
)
|
|
|
|
// ErrInvalidNodeID is returned when an operation is attempted for node ID 0.
var ErrInvalidNodeID = errors.New("invalid nodeID")

// ErrMapperNil is returned when map generation is requested without a mapper.
var ErrMapperNil = errors.New("mapper is nil")

// ErrNodeConnectionNil is returned when a change is handled without a node connection.
var ErrNodeConnectionNil = errors.New("nodeConnection is nil")
|
|
|
|
// workChannelMultiplier is the multiplier for work channel capacity based on worker count:
// the work channel is buffered to workers*workChannelMultiplier entries.
// The size is arbitrarily chosen; the sizing should be revisited.
const workChannelMultiplier = 200
|
|
|
|
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
|
|
Namespace: "headscale",
|
|
Name: "mapresponse_generated_total",
|
|
Help: "total count of mapresponses generated by response type",
|
|
}, []string{"response_type"})
|
|
|
|
// batcherFunc is the signature of a constructor that builds a Batcher from
// the server configuration and state.
type batcherFunc func(*types.Config, *state.State) Batcher
|
|
|
|
// Batcher defines the common interface for all batcher implementations.
|
|
type Batcher interface {
|
|
Start()
|
|
Close()
|
|
AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error
|
|
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
|
|
IsConnected(id types.NodeID) bool
|
|
ConnectedMap() *xsync.Map[types.NodeID, bool]
|
|
AddWork(r ...change.Change)
|
|
MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error)
|
|
DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
|
|
}
|
|
|
|
func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher {
|
|
return &LockFreeBatcher{
|
|
mapper: mapper,
|
|
workers: workers,
|
|
tick: time.NewTicker(batchTime),
|
|
|
|
workCh: make(chan work, workers*workChannelMultiplier),
|
|
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
|
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
|
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
|
|
}
|
|
}
|
|
|
|
// NewBatcherAndMapper creates a Batcher implementation.
|
|
func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher {
|
|
m := newMapper(cfg, state)
|
|
b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m)
|
|
m.batcher = b
|
|
|
|
return b
|
|
}
|
|
|
|
// nodeConnection interface for different connection implementations.
|
|
type nodeConnection interface {
|
|
nodeID() types.NodeID
|
|
version() tailcfg.CapabilityVersion
|
|
send(data *tailcfg.MapResponse) error
|
|
// computePeerDiff returns peers that were previously sent but are no longer in the current list.
|
|
computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID)
|
|
// updateSentPeers updates the tracking of which peers have been sent to this node.
|
|
updateSentPeers(resp *tailcfg.MapResponse)
|
|
}
|
|
|
|
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change].
|
|
func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) {
|
|
nodeID := nc.nodeID()
|
|
version := nc.version()
|
|
|
|
if r.IsEmpty() {
|
|
return nil, nil //nolint:nilnil
|
|
}
|
|
|
|
if nodeID == 0 {
|
|
return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
|
|
}
|
|
|
|
if mapper == nil {
|
|
return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
|
|
}
|
|
|
|
// Handle self-only responses
|
|
if r.IsSelfOnly() && r.TargetNode != nodeID {
|
|
return nil, nil //nolint:nilnil
|
|
}
|
|
|
|
// Check if this is a self-update (the changed node is the receiving node).
|
|
// When true, ensure the response includes the node's self info so it sees
|
|
// its own attribute changes (e.g., tags changed via admin API).
|
|
isSelfUpdate := r.OriginNode != 0 && r.OriginNode == nodeID
|
|
|
|
var (
|
|
mapResp *tailcfg.MapResponse
|
|
err error
|
|
)
|
|
|
|
// Track metric using categorized type, not free-form reason
|
|
mapResponseGenerated.WithLabelValues(r.Type()).Inc()
|
|
|
|
// Check if this requires runtime peer visibility computation (e.g., policy changes)
|
|
if r.RequiresRuntimePeerComputation {
|
|
currentPeers := mapper.state.ListPeers(nodeID)
|
|
|
|
currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len())
|
|
for _, peer := range currentPeers.All() {
|
|
currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID())
|
|
}
|
|
|
|
removedPeers := nc.computePeerDiff(currentPeerIDs)
|
|
// Include self node when this is a self-update (e.g., node's own tags changed)
|
|
// so the node sees its updated self info along with new packet filters.
|
|
mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers, isSelfUpdate)
|
|
} else if isSelfUpdate {
|
|
// Non-policy self-update: just send the self node info
|
|
mapResp, err = mapper.selfMapResponse(nodeID, version)
|
|
} else {
|
|
mapResp, err = mapper.buildFromChange(nodeID, version, &r)
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
|
|
}
|
|
|
|
return mapResp, nil
|
|
}
|
|
|
|
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
|
|
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
|
|
if nc == nil {
|
|
return ErrNodeConnectionNil
|
|
}
|
|
|
|
nodeID := nc.nodeID()
|
|
|
|
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received")
|
|
|
|
data, err := generateMapResponse(nc, mapper, r)
|
|
if err != nil {
|
|
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
|
|
}
|
|
|
|
if data == nil {
|
|
// No data to send is valid for some response types
|
|
return nil
|
|
}
|
|
|
|
// Send the map response
|
|
err = nc.send(data)
|
|
if err != nil {
|
|
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
|
|
}
|
|
|
|
// Update peer tracking after successful send
|
|
nc.updateSentPeers(data)
|
|
|
|
return nil
|
|
}
|
|
|
|
// workResult represents the result of processing a change. It is the value
// type carried by a work item's resultCh.
type workResult struct {
	// mapResponse is the generated response. NOTE(review): may be nil when
	// there was nothing to send — confirm against the producer.
	mapResponse *tailcfg.MapResponse
	// err is the error encountered while processing the change, if any.
	err error
}
|
|
|
|
// work represents a unit of work to be processed by workers.
type work struct {
	// c is the change to process.
	c change.Change
	// nodeID identifies the node this work item is for.
	nodeID types.NodeID
	// resultCh, when non-nil, receives the outcome so the submitter can wait
	// for it synchronously.
	resultCh chan<- workResult // optional channel for synchronous operations
}
|