headscale/hscontrol/state/node_store_test.go
Kristoffer Dalby b36438bf90 all: disable thelper linter and clean up nolint comments
Disable the thelper linter which triggers on inline test closures in
table-driven tests. These closures are intentionally not standalone
helpers and don't benefit from t.Helper().

Also remove explanatory comments from nolint directives throughout the
codebase as they add noise without providing significant value.
2026-01-21 15:56:57 +00:00

1339 lines
41 KiB
Go

package state
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime"
"sync"
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/key"
)
// Sentinel errors used by the concurrency tests. Each is wrapped with %w
// together with the goroutine/operation/node context in which a NodeStore
// call returned an unexpected result.
var (
	// errTestPutNodeFailed is reported when PutNode returns an invalid view.
	errTestPutNodeFailed = errors.New("PutNode failed")
	// errTestGetNodeFailed is reported when GetNode does not find a valid node.
	errTestGetNodeFailed = errors.New("GetNode failed")
	// errTestUpdateNodeFailed is reported when UpdateNode returns an invalid
	// view (and, in some tests, when it also reports ok == false).
	errTestUpdateNodeFailed = errors.New("UpdateNode failed")
)
// TestSnapshotFromNodes checks that snapshotFromNodes populates every index of
// a Snapshot (nodesByID, allNodes, peersByNode, nodesByUser) from a node map
// and a PeersFunc. It covers empty input, a single node, multiple nodes for one
// or several users, and a PeersFunc that filters peers by ID parity.
//
// Length checks that guard a following [0] index use require.Len, so a wrong
// length fails the subtest cleanly instead of panicking with index out of range.
func TestSnapshotFromNodes(t *testing.T) {
	tests := []struct {
		name      string
		setupFunc func() (map[types.NodeID]types.Node, PeersFunc)
		validate  func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot)
	}{
		{
			name: "empty nodes",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := make(map[types.NodeID]types.Node)
				peersFunc := func(_ []types.NodeView) map[types.NodeID][]types.NodeView {
					return make(map[types.NodeID][]types.NodeView)
				}

				return nodes, peersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				t.Helper()
				assert.Empty(t, snapshot.nodesByID)
				assert.Empty(t, snapshot.allNodes)
				assert.Empty(t, snapshot.peersByNode)
				assert.Empty(t, snapshot.nodesByUser)
			},
		},
		{
			name: "single node",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
				}

				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				t.Helper()
				assert.Len(t, snapshot.nodesByID, 1)
				assert.Len(t, snapshot.allNodes, 1)
				assert.Len(t, snapshot.peersByNode, 1)
				assert.Len(t, snapshot.nodesByUser, 1)

				require.Contains(t, snapshot.nodesByID, types.NodeID(1))
				assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID)
				assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers

				// require: the next line indexes [0].
				require.Len(t, snapshot.nodesByUser[1], 1)
				assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID())
			},
		},
		{
			name: "multiple nodes same user",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 1, "user1", "node2"),
				}

				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				t.Helper()
				assert.Len(t, snapshot.nodesByID, 2)
				assert.Len(t, snapshot.allNodes, 2)
				assert.Len(t, snapshot.peersByNode, 2)
				assert.Len(t, snapshot.nodesByUser, 1)

				// Each node sees the other as peer (but not itself)
				require.Len(t, snapshot.peersByNode[1], 1)
				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
				require.Len(t, snapshot.peersByNode[2], 1)
				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
				assert.Len(t, snapshot.nodesByUser[1], 2)
			},
		},
		{
			name: "multiple nodes different users",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 2, "user2", "node2"),
					3: createTestNode(3, 1, "user1", "node3"),
				}

				return nodes, allowAllPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				t.Helper()
				assert.Len(t, snapshot.nodesByID, 3)
				assert.Len(t, snapshot.allNodes, 3)
				assert.Len(t, snapshot.peersByNode, 3)
				assert.Len(t, snapshot.nodesByUser, 2)

				// Each node should have 2 peers (all others, but not itself)
				assert.Len(t, snapshot.peersByNode[1], 2)
				assert.Len(t, snapshot.peersByNode[2], 2)
				assert.Len(t, snapshot.peersByNode[3], 2)

				// User groupings
				assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3
				assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2
			},
		},
		{
			name: "odd-even peers filtering",
			setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) {
				nodes := map[types.NodeID]types.Node{
					1: createTestNode(1, 1, "user1", "node1"),
					2: createTestNode(2, 2, "user2", "node2"),
					3: createTestNode(3, 3, "user3", "node3"),
					4: createTestNode(4, 4, "user4", "node4"),
				}

				return nodes, oddEvenPeersFunc
			},
			validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
				t.Helper()
				assert.Len(t, snapshot.nodesByID, 4)
				assert.Len(t, snapshot.allNodes, 4)
				assert.Len(t, snapshot.peersByNode, 4)
				assert.Len(t, snapshot.nodesByUser, 4)

				// Odd nodes should only see other odd nodes as peers
				require.Len(t, snapshot.peersByNode[1], 1)
				assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
				require.Len(t, snapshot.peersByNode[3], 1)
				assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())

				// Even nodes should only see other even nodes as peers
				require.Len(t, snapshot.peersByNode[2], 1)
				assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
				require.Len(t, snapshot.peersByNode[4], 1)
				assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			nodes, peersFunc := tt.setupFunc()
			snapshot := snapshotFromNodes(nodes, peersFunc)
			tt.validate(t, nodes, snapshot)
		})
	}
}
// Helper functions

// createTestNode builds a populated types.Node for tests. It sets:
//   - freshly generated machine, node and disco keys (public halves);
//   - the given ID, the owner (userID plus a User with Name and DisplayName set
//     to username), and hostname as both Hostname and GivenName;
//   - RegisterMethod "test";
//   - CreatedAt and UpdatedAt to the same current instant.
//
// Every node receives the same fixed address pair, 100.64.0.1 and
// fd7a:115c:a1e0::1.
func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node {
	var (
		stamp = time.Now()
		mk    = key.NewMachine()
		nk    = key.NewNode()
		dk    = key.NewDisco()
		addr4 = netip.MustParseAddr("100.64.0.1")
		addr6 = netip.MustParseAddr("fd7a:115c:a1e0::1")
		owner = userID // copied so the returned node owns its own *uint
	)

	node := types.Node{
		ID:             nodeID,
		MachineKey:     mk.Public(),
		NodeKey:        nk.Public(),
		DiscoKey:       dk.Public(),
		Hostname:       hostname,
		GivenName:      hostname,
		UserID:         &owner,
		RegisterMethod: "test",
		IPv4:           &addr4,
		IPv6:           &addr6,
		CreatedAt:      stamp,
		UpdatedAt:      stamp,
	}
	node.User = &types.User{
		Name:        username,
		DisplayName: username,
	}

	return node
}
// Peer functions

// allowAllPeersFunc is an unfiltered PeersFunc. Every node is a peer of every
// other node, and never of itself. A node with no other nodes around maps to a
// nil slice.
func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
	peerMap := make(map[types.NodeID][]types.NodeView, len(nodes))

	for i := range nodes {
		self := nodes[i].ID()

		var others []types.NodeView
		for j := range nodes {
			if nodes[j].ID() == self {
				continue // a node is never its own peer
			}
			others = append(others, nodes[j])
		}

		peerMap[self] = others
	}

	return peerMap
}
// oddEvenPeersFunc is a PeersFunc that splits nodes into two disjoint groups by
// ID parity. Odd IDs peer only with other odd IDs, and even IDs only with other
// even IDs. A node is never its own peer. A node with no same-parity neighbour
// maps to a nil slice.
func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
	isOdd := func(id types.NodeID) bool { return id%2 == 1 }

	peerMap := make(map[types.NodeID][]types.NodeView, len(nodes))

	for i := range nodes {
		self := nodes[i].ID()
		selfOdd := isOdd(self)

		var sameParity []types.NodeView
		for j := range nodes {
			other := nodes[j].ID()
			if other != self && isOdd(other) == selfOdd {
				sameParity = append(sameParity, nodes[j])
			}
		}

		peerMap[self] = sameParity
	}

	return peerMap
}
// TestNodeStoreOperations runs ordered scenarios against a live NodeStore.
// For each case it builds a store with setupFunc and starts it. It then runs
// the case's steps in order, each as a subtest sharing that store, so later
// steps observe the state left by earlier ones.
func TestNodeStoreOperations(t *testing.T) {
	// opStep is one named step of a scenario. action receives the *testing.T
	// of the step's own subtest, for two reasons:
	//   - failures are reported against the step that failed;
	//   - require.* calls FailNow on the goroutine that is actually running.
	// FailNow must only be called from the goroutine running that test, so a
	// step must not assert with the enclosing test's T.
	type opStep struct {
		name   string
		action func(t *testing.T, store *NodeStore)
	}

	tests := []struct {
		name      string
		setupFunc func(t *testing.T) *NodeStore
		steps     []opStep
	}{
		{
			name: "create empty store and add single node",
			setupFunc: func(t *testing.T) *NodeStore {
				t.Helper()
				return NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
			},
			steps: []opStep{
				{
					name: "verify empty store",
					action: func(t *testing.T, store *NodeStore) {
						snapshot := store.data.Load()
						assert.Empty(t, snapshot.nodesByID)
						assert.Empty(t, snapshot.allNodes)
						assert.Empty(t, snapshot.peersByNode)
						assert.Empty(t, snapshot.nodesByUser)
					},
				},
				{
					name: "add first node",
					action: func(t *testing.T, store *NodeStore) {
						node := createTestNode(1, 1, "user1", "node1")
						resultNode := store.PutNode(node)
						assert.True(t, resultNode.Valid(), "PutNode should return valid node")
						assert.Equal(t, node.ID, resultNode.ID())

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 1)
						assert.Len(t, snapshot.allNodes, 1)
						assert.Len(t, snapshot.peersByNode, 1)
						assert.Len(t, snapshot.nodesByUser, 1)
						require.Contains(t, snapshot.nodesByID, types.NodeID(1))
						assert.Equal(t, node.ID, snapshot.nodesByID[1].ID)
						assert.Empty(t, snapshot.peersByNode[1]) // no peers yet
						assert.Len(t, snapshot.nodesByUser[1], 1)
					},
				},
			},
		},
		{
			name: "create store with initial node and add more",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				initialNodes := types.Nodes{&node1}

				return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
			},
			steps: []opStep{
				{
					name: "verify initial state",
					action: func(t *testing.T, store *NodeStore) {
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 1)
						assert.Len(t, snapshot.allNodes, 1)
						assert.Len(t, snapshot.peersByNode, 1)
						assert.Len(t, snapshot.nodesByUser, 1)
						assert.Empty(t, snapshot.peersByNode[1])
					},
				},
				{
					name: "add second node same user",
					action: func(t *testing.T, store *NodeStore) {
						node2 := createTestNode(2, 1, "user1", "node2")
						resultNode := store.PutNode(node2)
						assert.True(t, resultNode.Valid(), "PutNode should return valid node")
						assert.Equal(t, types.NodeID(2), resultNode.ID())

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 2)
						assert.Len(t, snapshot.allNodes, 2)
						assert.Len(t, snapshot.peersByNode, 2)
						assert.Len(t, snapshot.nodesByUser, 1)

						// Now both nodes should see each other as peers.
						// require: the following lines index [0].
						require.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
						assert.Len(t, snapshot.nodesByUser[1], 2)
					},
				},
				{
					name: "add third node different user",
					action: func(t *testing.T, store *NodeStore) {
						node3 := createTestNode(3, 2, "user2", "node3")
						resultNode := store.PutNode(node3)
						assert.True(t, resultNode.Valid(), "PutNode should return valid node")
						assert.Equal(t, types.NodeID(3), resultNode.ID())

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						assert.Len(t, snapshot.allNodes, 3)
						assert.Len(t, snapshot.peersByNode, 3)
						assert.Len(t, snapshot.nodesByUser, 2)

						// All nodes should see the other 2 as peers
						assert.Len(t, snapshot.peersByNode[1], 2)
						assert.Len(t, snapshot.peersByNode[2], 2)
						assert.Len(t, snapshot.peersByNode[3], 2)

						// User groupings
						assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2
						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3
					},
				},
			},
		},
		{
			name: "test node deletion",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				node3 := createTestNode(3, 2, "user2", "node3")
				initialNodes := types.Nodes{&node1, &node2, &node3}

				return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
			},
			steps: []opStep{
				{
					name: "verify initial 3 nodes",
					action: func(t *testing.T, store *NodeStore) {
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						assert.Len(t, snapshot.allNodes, 3)
						assert.Len(t, snapshot.peersByNode, 3)
						assert.Len(t, snapshot.nodesByUser, 2)
					},
				},
				{
					name: "delete middle node",
					action: func(t *testing.T, store *NodeStore) {
						store.DeleteNode(2)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 2)
						assert.Len(t, snapshot.allNodes, 2)
						assert.Len(t, snapshot.peersByNode, 2)
						assert.Len(t, snapshot.nodesByUser, 2)

						// Node 2 should be gone
						assert.NotContains(t, snapshot.nodesByID, types.NodeID(2))

						// Remaining nodes should see each other as peers.
						// require: the following lines index [0].
						require.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
						require.Len(t, snapshot.peersByNode[3], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())

						// User groupings updated
						assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1
						assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3
					},
				},
				{
					name: "delete all remaining nodes",
					action: func(t *testing.T, store *NodeStore) {
						store.DeleteNode(1)
						store.DeleteNode(3)

						snapshot := store.data.Load()
						assert.Empty(t, snapshot.nodesByID)
						assert.Empty(t, snapshot.allNodes)
						assert.Empty(t, snapshot.peersByNode)
						assert.Empty(t, snapshot.nodesByUser)
					},
				},
			},
		},
		{
			name: "test node updates",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				initialNodes := types.Nodes{&node1, &node2}

				return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
			},
			steps: []opStep{
				{
					name: "verify initial hostnames",
					action: func(t *testing.T, store *NodeStore) {
						snapshot := store.data.Load()
						assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
					},
				},
				{
					name: "update node hostname",
					action: func(t *testing.T, store *NodeStore) {
						resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
							n.Hostname = "updated-node1"
							n.GivenName = "updated-node1"
						})
						assert.True(t, ok, "UpdateNode should return true for existing node")
						assert.True(t, resultNode.Valid(), "Result node should be valid")
						assert.Equal(t, "updated-node1", resultNode.Hostname())
						assert.Equal(t, "updated-node1", resultNode.GivenName())

						snapshot := store.data.Load()
						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName)
						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged

						// Peers should still work correctly
						assert.Len(t, snapshot.peersByNode[1], 1)
						assert.Len(t, snapshot.peersByNode[2], 1)
					},
				},
			},
		},
		{
			name: "test with odd-even peers filtering",
			setupFunc: func(t *testing.T) *NodeStore {
				return NewNodeStore(nil, oddEvenPeersFunc, TestBatchSize, TestBatchTimeout)
			},
			steps: []opStep{
				{
					name: "add nodes with odd-even filtering",
					action: func(t *testing.T, store *NodeStore) {
						// Add nodes in sequence
						n1 := store.PutNode(createTestNode(1, 1, "user1", "node1"))
						assert.True(t, n1.Valid())
						n2 := store.PutNode(createTestNode(2, 2, "user2", "node2"))
						assert.True(t, n2.Valid())
						n3 := store.PutNode(createTestNode(3, 3, "user3", "node3"))
						assert.True(t, n3.Valid())
						n4 := store.PutNode(createTestNode(4, 4, "user4", "node4"))
						assert.True(t, n4.Valid())

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 4)

						// Verify odd-even peer relationships
						require.Len(t, snapshot.peersByNode[1], 1)
						assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID())
						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
						require.Len(t, snapshot.peersByNode[3], 1)
						assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID())
						require.Len(t, snapshot.peersByNode[4], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
					},
				},
				{
					name: "delete odd node and verify even nodes unaffected",
					action: func(t *testing.T, store *NodeStore) {
						store.DeleteNode(1)

						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)

						// Node 3 (odd) should now have no peers
						assert.Empty(t, snapshot.peersByNode[3])

						// Even nodes should still see each other
						require.Len(t, snapshot.peersByNode[2], 1)
						assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID())
						require.Len(t, snapshot.peersByNode[4], 1)
						assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID())
					},
				},
			},
		},
		{
			name: "test batch modifications return correct node state",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				initialNodes := types.Nodes{&node1, &node2}

				return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
			},
			steps: []opStep{
				{
					name: "verify initial state",
					action: func(t *testing.T, store *NodeStore) {
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 2)
						assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname)
					},
				},
				{
					name: "concurrent updates should reflect all batch changes",
					action: func(t *testing.T, store *NodeStore) {
						// Start multiple updates that will be batched together
						done1 := make(chan struct{})
						done2 := make(chan struct{})
						done3 := make(chan struct{})

						var (
							resultNode1, resultNode2 types.NodeView
							newNode3                 types.NodeView
							ok1, ok2                 bool
						)

						// These should all be processed in the same batch
						go func() {
							resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
								n.Hostname = "batch-updated-node1"
								n.GivenName = "batch-given-1"
							})
							close(done1)
						}()

						go func() {
							resultNode2, ok2 = store.UpdateNode(2, func(n *types.Node) {
								n.Hostname = "batch-updated-node2"
								n.GivenName = "batch-given-2"
							})
							close(done2)
						}()

						go func() {
							node3 := createTestNode(3, 1, "user1", "node3")
							newNode3 = store.PutNode(node3)
							close(done3)
						}()

						// Wait for all operations to complete; the channel
						// receives also order the result writes before the reads below.
						<-done1
						<-done2
						<-done3

						// Verify the returned nodes reflect the batch state
						assert.True(t, ok1, "UpdateNode should succeed for node 1")
						assert.True(t, ok2, "UpdateNode should succeed for node 2")
						assert.True(t, resultNode1.Valid())
						assert.True(t, resultNode2.Valid())
						assert.True(t, newNode3.Valid())

						// Check that returned nodes have the updated values
						assert.Equal(t, "batch-updated-node1", resultNode1.Hostname())
						assert.Equal(t, "batch-given-1", resultNode1.GivenName())
						assert.Equal(t, "batch-updated-node2", resultNode2.Hostname())
						assert.Equal(t, "batch-given-2", resultNode2.GivenName())
						assert.Equal(t, "node3", newNode3.Hostname())

						// Verify the snapshot also reflects all changes
						snapshot := store.data.Load()
						assert.Len(t, snapshot.nodesByID, 3)
						assert.Equal(t, "batch-updated-node1", snapshot.nodesByID[1].Hostname)
						assert.Equal(t, "batch-updated-node2", snapshot.nodesByID[2].Hostname)
						assert.Equal(t, "node3", snapshot.nodesByID[3].Hostname)

						// Verify peer relationships are updated correctly with new node
						assert.Len(t, snapshot.peersByNode[1], 2) // sees nodes 2 and 3
						assert.Len(t, snapshot.peersByNode[2], 2) // sees nodes 1 and 3
						assert.Len(t, snapshot.peersByNode[3], 2) // sees nodes 1 and 2
					},
				},
				{
					name: "update non-existent node returns invalid view",
					action: func(t *testing.T, store *NodeStore) {
						resultNode, ok := store.UpdateNode(999, func(n *types.Node) {
							n.Hostname = "should-not-exist"
						})
						assert.False(t, ok, "UpdateNode should return false for non-existent node")
						assert.False(t, resultNode.Valid(), "Result should be invalid NodeView")
					},
				},
				{
					name: "multiple updates to same node in batch all see final state",
					action: func(t *testing.T, store *NodeStore) {
						// This test verifies that when multiple updates to the same node
						// are batched together, each returned node reflects ALL changes
						// in the batch, not just the individual update's changes.
						done1 := make(chan struct{})
						done2 := make(chan struct{})
						done3 := make(chan struct{})

						var (
							resultNode1, resultNode2, resultNode3 types.NodeView
							ok1, ok2, ok3                         bool
						)

						// These updates all modify node 1 and should be batched together
						// The final state should have all three modifications applied
						go func() {
							resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
								n.Hostname = "multi-update-hostname"
							})
							close(done1)
						}()

						go func() {
							resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) {
								n.GivenName = "multi-update-givenname"
							})
							close(done2)
						}()

						go func() {
							resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
								n.Tags = []string{"tag1", "tag2"}
							})
							close(done3)
						}()

						// Wait for all operations to complete
						<-done1
						<-done2
						<-done3

						// All updates should succeed
						assert.True(t, ok1, "First update should succeed")
						assert.True(t, ok2, "Second update should succeed")
						assert.True(t, ok3, "Third update should succeed")

						// CRITICAL: Each returned node should reflect ALL changes from the batch
						// not just the change from its specific update call

						// resultNode1 (from hostname update) should also have the givenname and tags changes
						assert.Equal(t, "multi-update-hostname", resultNode1.Hostname())
						assert.Equal(t, "multi-update-givenname", resultNode1.GivenName())
						assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.Tags().AsSlice())

						// resultNode2 (from givenname update) should also have the hostname and tags changes
						assert.Equal(t, "multi-update-hostname", resultNode2.Hostname())
						assert.Equal(t, "multi-update-givenname", resultNode2.GivenName())
						assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.Tags().AsSlice())

						// resultNode3 (from tags update) should also have the hostname and givenname changes
						assert.Equal(t, "multi-update-hostname", resultNode3.Hostname())
						assert.Equal(t, "multi-update-givenname", resultNode3.GivenName())
						assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.Tags().AsSlice())

						// Verify the snapshot also has all changes
						snapshot := store.data.Load()
						finalNode := snapshot.nodesByID[1]
						assert.Equal(t, "multi-update-hostname", finalNode.Hostname)
						assert.Equal(t, "multi-update-givenname", finalNode.GivenName)
						assert.Equal(t, []string{"tag1", "tag2"}, finalNode.Tags)
					},
				},
			},
		},
		{
			name: "test UpdateNode result is immutable for database save",
			setupFunc: func(t *testing.T) *NodeStore {
				node1 := createTestNode(1, 1, "user1", "node1")
				node2 := createTestNode(2, 1, "user1", "node2")
				initialNodes := types.Nodes{&node1, &node2}

				return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
			},
			steps: []opStep{
				{
					name: "verify returned node is complete and consistent",
					action: func(t *testing.T, store *NodeStore) {
						// Update a node and verify the returned view is complete
						resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
							n.Hostname = "db-save-hostname"
							n.GivenName = "db-save-given"
							n.Tags = []string{"db-tag1", "db-tag2"}
						})
						assert.True(t, ok, "UpdateNode should succeed")
						assert.True(t, resultNode.Valid(), "Result should be valid")

						// Verify the returned node has all expected values
						assert.Equal(t, "db-save-hostname", resultNode.Hostname())
						assert.Equal(t, "db-save-given", resultNode.GivenName())
						assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.Tags().AsSlice())

						// Convert to struct as would be done for database save
						nodePtr := resultNode.AsStruct()
						assert.NotNil(t, nodePtr)
						assert.Equal(t, "db-save-hostname", nodePtr.Hostname)
						assert.Equal(t, "db-save-given", nodePtr.GivenName)
						assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.Tags)

						// Verify the snapshot also reflects the same state
						snapshot := store.data.Load()
						storedNode := snapshot.nodesByID[1]
						assert.Equal(t, "db-save-hostname", storedNode.Hostname)
						assert.Equal(t, "db-save-given", storedNode.GivenName)
						assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.Tags)
					},
				},
				{
					name: "concurrent updates all return consistent final state for DB save",
					action: func(t *testing.T, store *NodeStore) {
						// Multiple goroutines updating the same node
						// All should receive the final batch state suitable for DB save
						done1 := make(chan struct{})
						done2 := make(chan struct{})
						done3 := make(chan struct{})

						var (
							result1, result2, result3 types.NodeView
							ok1, ok2, ok3             bool
						)

						// Start concurrent updates
						go func() {
							result1, ok1 = store.UpdateNode(1, func(n *types.Node) {
								n.Hostname = "concurrent-db-hostname"
							})
							close(done1)
						}()

						go func() {
							result2, ok2 = store.UpdateNode(1, func(n *types.Node) {
								n.GivenName = "concurrent-db-given"
							})
							close(done2)
						}()

						go func() {
							result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
								n.Tags = []string{"concurrent-tag"}
							})
							close(done3)
						}()

						// Wait for all to complete
						<-done1
						<-done2
						<-done3

						assert.True(t, ok1 && ok2 && ok3, "All updates should succeed")

						// All results should be valid and suitable for database save
						assert.True(t, result1.Valid())
						assert.True(t, result2.Valid())
						assert.True(t, result3.Valid())

						// Convert each to struct as would be done for DB save
						nodePtr1 := result1.AsStruct()
						nodePtr2 := result2.AsStruct()
						nodePtr3 := result3.AsStruct()

						// All should have the complete final state
						assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname)
						assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName)
						assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.Tags)
						assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname)
						assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName)
						assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.Tags)
						assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname)
						assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName)
						assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.Tags)

						// Verify consistency with stored state
						snapshot := store.data.Load()
						storedNode := snapshot.nodesByID[1]
						assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname)
						assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName)
						assert.Equal(t, nodePtr1.Tags, storedNode.Tags)
					},
				},
				{
					name: "verify returned node preserves all fields for DB save",
					action: func(t *testing.T, store *NodeStore) {
						// Get initial state
						snapshot := store.data.Load()
						originalNode := snapshot.nodesByID[2]
						originalIPv4 := originalNode.IPv4
						originalIPv6 := originalNode.IPv6
						originalCreatedAt := originalNode.CreatedAt
						originalUser := originalNode.User

						// Update only hostname
						resultNode, ok := store.UpdateNode(2, func(n *types.Node) {
							n.Hostname = "preserve-test-hostname"
						})
						assert.True(t, ok, "Update should succeed")

						// Convert to struct for DB save
						nodeForDB := resultNode.AsStruct()

						// Verify all fields are preserved
						assert.Equal(t, "preserve-test-hostname", nodeForDB.Hostname)
						assert.Equal(t, originalIPv4, nodeForDB.IPv4)
						assert.Equal(t, originalIPv6, nodeForDB.IPv6)
						assert.Equal(t, originalCreatedAt, nodeForDB.CreatedAt)
						assert.Equal(t, originalUser.Name, nodeForDB.User.Name)
						assert.Equal(t, types.NodeID(2), nodeForDB.ID)

						// These fields should be suitable for direct database save
						assert.NotNil(t, nodeForDB.IPv4)
						assert.NotNil(t, nodeForDB.IPv6)
						assert.False(t, nodeForDB.CreatedAt.IsZero())
					},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			store := tt.setupFunc(t)
			store.Start()
			defer store.Stop()

			for _, step := range tt.steps {
				t.Run(step.name, func(t *testing.T) {
					// Hand the step its own subtest T (see opStep).
					step.action(t, store)
				})
			}
		})
	}
}
// testStep is a named action executed against a shared *NodeStore.
// The action receives only the store, so a step that makes assertions
// has to capture a *testing.T from its enclosing scope.
type testStep struct {
	// name is the label for the step.
	name string
	// action performs the step's operations and checks against the store.
	action func(store *NodeStore)
}
// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests ---

// createConcurrentTestNode returns a minimal types.Node for the concurrency
// tests. It sets the given ID and hostname and freshly generated machine and
// node keys. The node is owned by user ID 1, whose User is named
// "concurrent-test-user". No disco key, IP addresses or timestamps are set.
func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
	ownerID := uint(1)
	mk, nk := key.NewMachine(), key.NewNode()

	return types.Node{
		ID:         id,
		Hostname:   hostname,
		MachineKey: mk.Public(),
		NodeKey:    nk.Public(),
		UserID:     &ownerID,
		User:       &types.User{Name: "concurrent-test-user"},
	}
}
// --- Concurrency: concurrent PutNode operations ---.

// TestNodeStoreConcurrentPutNode inserts concurrentOps distinct nodes (IDs 1..N)
// from separate goroutines. It requires every PutNode call to return a valid
// node view.
func TestNodeStoreConcurrentPutNode(t *testing.T) {
	const concurrentOps = 20

	store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
	store.Start()
	defer store.Stop()

	// Each goroutine writes only its own slot, so no extra locking is needed.
	// wg.Wait orders those writes before the reads below.
	valid := make([]bool, concurrentOps)

	var wg sync.WaitGroup
	for id := types.NodeID(1); id <= concurrentOps; id++ {
		wg.Add(1)
		go func(nodeID types.NodeID) {
			defer wg.Done()
			node := createConcurrentTestNode(nodeID, "concurrent-node")
			valid[nodeID-1] = store.PutNode(node).Valid()
		}(id)
	}
	wg.Wait()

	successCount := 0
	for _, ok := range valid {
		if ok {
			successCount++
		}
	}

	require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed")
}
// --- Batching: concurrent ops fit in one batch ---.

// TestNodeStoreBatchingEfficiency issues `ops` concurrent PutNode calls
// (IDs 1..ops). The section header describes these as fitting into a single
// batch, but the test itself only checks that every call returns a valid
// node view.
func TestNodeStoreBatchingEfficiency(t *testing.T) {
	const ops = 15

	store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
	store.Start()
	defer store.Stop()

	var (
		wg        sync.WaitGroup
		mu        sync.Mutex
		succeeded int
	)

	for i := 1; i <= ops; i++ {
		wg.Add(1)
		go func(nodeID types.NodeID) {
			defer wg.Done()

			node := createConcurrentTestNode(nodeID, "batch-node")
			if !store.PutNode(node).Valid() {
				return
			}

			mu.Lock()
			succeeded++
			mu.Unlock()
		}(types.NodeID(i)) //nolint:gosec
	}
	wg.Wait()

	require.Equal(t, ops, succeeded, "All batch PutNode operations should succeed")
}
// --- Race conditions: many goroutines on same node ---.

// TestNodeStoreRaceConditions hammers a single node ID from numGoroutines
// goroutines. Each goroutine runs opsPerGoroutine operations that rotate
// through UpdateNode, GetNode and PutNode. The test fails if any call returns
// an invalid result (or, for GetNode, reports not found).
func TestNodeStoreRaceConditions(t *testing.T) {
	store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
	store.Start()
	defer store.Stop()

	nodeID := types.NodeID(1)
	node := createConcurrentTestNode(nodeID, "race-node")
	resultNode := store.PutNode(node)
	require.True(t, resultNode.Valid())

	const (
		numGoroutines   = 30
		opsPerGoroutine = 10
	)

	var wg sync.WaitGroup

	// Named errCh (not "errors") so the imported errors package is not
	// shadowed inside this function. Sized for one error per operation so
	// senders never block.
	errCh := make(chan error, numGoroutines*opsPerGoroutine)

	for i := range numGoroutines {
		wg.Add(1)
		go func(gid int) {
			defer wg.Done()

			for j := range opsPerGoroutine {
				switch j % 3 {
				case 0:
					updated, _ := store.UpdateNode(nodeID, func(n *types.Node) {
						n.Hostname = "race-updated"
					})
					if !updated.Valid() {
						errCh <- fmt.Errorf("%w in goroutine %d, op %d", errTestUpdateNodeFailed, gid, j)
					}
				case 1:
					retrieved, found := store.GetNode(nodeID)
					if !found || !retrieved.Valid() {
						errCh <- fmt.Errorf("%w in goroutine %d, op %d", errTestGetNodeFailed, gid, j)
					}
				case 2:
					newNode := createConcurrentTestNode(nodeID, "race-put")
					put := store.PutNode(newNode)
					if !put.Valid() {
						errCh <- fmt.Errorf("%w in goroutine %d, op %d", errTestPutNodeFailed, gid, j)
					}
				}
			}
		}(i)
	}

	wg.Wait()
	close(errCh)

	errorCount := 0
	for err := range errCh {
		t.Error(err)
		errorCount++
	}

	if errorCount > 0 {
		t.Fatalf("Race condition test failed with %d errors", errorCount)
	}
}
// --- Resource cleanup: goroutine leak detection ---.

// TestNodeStoreResourceCleanup runs put/update/get cycles for 100 nodes on a
// running store, deleting every tenth node. Afterwards the goroutine count
// must settle to at most two above the count taken right after Start. This
// checks that individual operations do not leave goroutines behind.
func TestNodeStoreResourceCleanup(t *testing.T) {
	store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
	store.Start()
	defer store.Stop()

	// Wait for store to be ready and record the baseline goroutine count.
	var afterStartGoroutines int
	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		afterStartGoroutines = runtime.NumGoroutine()
		assert.Positive(c, afterStartGoroutines) // Just ensure we have a valid count
	}, time.Second, 10*time.Millisecond, "store should be running")

	const ops = 100
	for i := range ops {
		nodeID := types.NodeID(i + 1) //nolint:gosec
		node := createConcurrentTestNode(nodeID, "cleanup-node")

		resultNode := store.PutNode(node)
		assert.True(t, resultNode.Valid())

		// The node was inserted just above, so UpdateNode must find it.
		// Check the result rather than discarding it.
		updated, ok := store.UpdateNode(nodeID, func(n *types.Node) {
			n.Hostname = "cleanup-updated"
		})
		assert.True(t, ok && updated.Valid(), "UpdateNode should succeed for node %d", nodeID)

		retrieved, found := store.GetNode(nodeID)
		assert.True(t, found && retrieved.Valid())

		// Delete every tenth node so removal is exercised too.
		if i%10 == 9 {
			store.DeleteNode(nodeID)
		}
	}

	runtime.GC()

	// Wait for goroutines to settle and check for leaks
	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		finalGoroutines := runtime.NumGoroutine()
		assert.LessOrEqual(c, finalGoroutines, afterStartGoroutines+2,
			"Potential goroutine leak: started with %d, ended with %d", afterStartGoroutines, finalGoroutines)
	}, time.Second, 10*time.Millisecond, "goroutines should not leak")
}
// --- Timeout/deadlock: operations complete within reasonable time ---.

// TestNodeStoreOperationTimeout runs `ops` concurrent PutNode calls, then `ops`
// concurrent UpdateNode calls on the same IDs. It fails if either phase has not
// finished before a shared 3-second deadline (a likely deadlock), or if any
// call returns an invalid result.
func TestNodeStoreOperationTimeout(t *testing.T) {
	store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
	store.Start()
	defer store.Stop()

	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
	defer cancel()

	// waitOrFail blocks until wg is done, or fails the test once ctx expires.
	// Both phases go through it, so a hang in either surfaces here instead of
	// blocking until the global `go test -timeout`.
	waitOrFail := func(wg *sync.WaitGroup, phase string) {
		done := make(chan struct{})
		go func() {
			wg.Wait()
			close(done)
		}()

		select {
		case <-done:
		case <-ctx.Done():
			fmt.Printf("[TestNodeStoreOperationTimeout] Timeout reached during %s phase, test failed\n", phase)
			t.Fatalf("%s operations timed out - potential deadlock or resource issue", phase)
		}
	}

	const ops = 30

	putResults := make([]error, ops)
	updateResults := make([]error, ops)

	// Launch all PutNode operations concurrently
	var putWG sync.WaitGroup
	for i := 1; i <= ops; i++ {
		nodeID := types.NodeID(i) //nolint:gosec
		putWG.Add(1)
		go func(idx int, id types.NodeID) {
			defer putWG.Done()

			startPut := time.Now()
			fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id)
			node := createConcurrentTestNode(id, "timeout-node")
			resultNode := store.PutNode(node)
			endPut := time.Now()
			fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut))

			if !resultNode.Valid() {
				putResults[idx-1] = fmt.Errorf("%w for node %d", errTestPutNodeFailed, id)
			}
		}(i, nodeID)
	}
	waitOrFail(&putWG, "PutNode")

	// Launch all UpdateNode operations concurrently
	var updateWG sync.WaitGroup
	for i := 1; i <= ops; i++ {
		nodeID := types.NodeID(i) //nolint:gosec
		updateWG.Add(1)
		go func(idx int, id types.NodeID) {
			defer updateWG.Done()

			startUpdate := time.Now()
			fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id)
			resultNode, ok := store.UpdateNode(id, func(n *types.Node) {
				n.Hostname = "timeout-updated"
			})
			endUpdate := time.Now()
			fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate))

			if !ok || !resultNode.Valid() {
				updateResults[idx-1] = fmt.Errorf("%w for node %d", errTestUpdateNodeFailed, id)
			}
		}(i, nodeID)
	}
	waitOrFail(&updateWG, "UpdateNode")

	// Both phases completed; waitOrFail's channel receive orders the result
	// writes above before these reads.
	errorCount := 0
	for _, err := range putResults {
		if err != nil {
			t.Error(err)
			errorCount++
		}
	}

	for _, err := range updateResults {
		if err != nil {
			t.Error(err)
			errorCount++
		}
	}

	if errorCount > 0 {
		t.Fatalf("Some concurrent operations failed: %d errors", errorCount)
	}

	t.Log("All concurrent operations completed successfully within timeout")
}
// --- Edge case: update non-existent node ---.
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
for i := range 10 {
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
nonExistentID := types.NodeID(999 + i) //nolint:gosec
updateCallCount := 0
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID)
resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) {
updateCallCount++
n.Hostname = "should-never-be-called"
})
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) finished, valid=%v, ok=%v, updateCallCount=%d\n", nonExistentID, resultNode.Valid(), ok, updateCallCount)
assert.False(t, ok, "UpdateNode should return false for non-existent node")
assert.False(t, resultNode.Valid(), "UpdateNode should return invalid node for non-existent node")
assert.Equal(t, 0, updateCallCount, "UpdateFn should not be called for non-existent node")
store.Stop()
}
}
// --- Allocation benchmark ---.
// BenchmarkNodeStoreAllocations measures one PutNode/UpdateNode/GetNode cycle
// per iteration on a running store, deleting every tenth node again so that
// DeleteNode is part of the mix.
//
// NOTE(review): node IDs keep increasing and ~90% of inserted nodes are never
// deleted, so the store grows over the run; per-op numbers may therefore
// depend on how many iterations b.Loop performs — confirm this is intended.
func BenchmarkNodeStoreAllocations(b *testing.B) {
	store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
	store.Start()
	defer store.Stop()

	iteration := 0
	for b.Loop() {
		id := types.NodeID(iteration + 1) //nolint:gosec

		store.PutNode(createConcurrentTestNode(id, "bench-node"))
		store.UpdateNode(id, func(n *types.Node) { n.Hostname = "bench-updated" })
		store.GetNode(id)

		// Every tenth node (iterations 9, 19, 29, ...) is removed again.
		if (iteration+1)%10 == 0 {
			store.DeleteNode(id)
		}

		iteration++
	}
}
// TestNodeStoreAllocationStats runs BenchmarkNodeStoreAllocations once through
// testing.Benchmark and logs its allocation profile. It asserts nothing; it
// only surfaces the numbers in verbose test output.
//
// testing.Benchmark runs for the full -benchtime (1s by default), so the test
// is skipped in -short mode to keep quick test runs quick.
func TestNodeStoreAllocationStats(t *testing.T) {
	if testing.Short() {
		t.Skip("skipping allocation benchmark in short mode")
	}

	res := testing.Benchmark(BenchmarkNodeStoreAllocations)

	// AllocsPerOp and AllocedBytesPerOp are already integer averages; printing
	// them through float64 with %.2f only suggested precision that isn't there.
	t.Logf("NodeStore allocations per op: %d (%d B/op)", res.AllocsPerOp(), res.AllocedBytesPerOp())
}
// TestRebuildPeerMapsWithChangedPeersFunc verifies that RebuildPeerMaps
// recomputes the peer map using the current behavior of the store's peersFunc.
//
// This mirrors SetNodeTags changing a node's tags: the PolicyManager's
// matchers change underneath the store, and the peer map must be rebuilt so
// that it reflects the new policy.
func TestRebuildPeerMapsWithChangedPeersFunc(t *testing.T) {
	// allowPeers is a plain variable captured by the peersFunc below. Flipping
	// it between RebuildPeerMaps calls stands in for a policy change, much like
	// PolicyManager.BuildPeerMap reads state that can change between calls.
	allowPeers := true

	dynamicPeersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
		peersEnabled := allowPeers
		peerMap := make(map[types.NodeID][]types.NodeView, len(nodes))

		for _, self := range nodes {
			if !peersEnabled {
				// Deny-all policy: every node gets an empty (non-nil) peer list.
				peerMap[self.ID()] = []types.NodeView{}
				continue
			}

			// Allow-all policy: every other node is a peer.
			var others []types.NodeView
			for _, candidate := range nodes {
				if candidate.ID() == self.ID() {
					continue
				}
				others = append(others, candidate)
			}
			peerMap[self.ID()] = others
		}

		return peerMap
	}

	node1 := createTestNode(1, 1, "user1", "node1")
	node2 := createTestNode(2, 2, "user2", "node2")

	store := NewNodeStore(types.Nodes{&node1, &node2}, dynamicPeersFunc, TestBatchSize, TestBatchTimeout)
	store.Start()
	defer store.Stop()

	// Phase 1: peers allowed, so the two nodes see each other.
	snap := store.data.Load()
	require.Len(t, snap.peersByNode[1], 1, "node1 should have 1 peer initially")
	require.Len(t, snap.peersByNode[2], 1, "node2 should have 1 peer initially")
	require.Equal(t, types.NodeID(2), snap.peersByNode[1][0].ID())
	require.Equal(t, types.NodeID(1), snap.peersByNode[2][0].ID())

	// Phase 2: "change the policy" to deny all peers, then rebuild.
	allowPeers = false
	store.RebuildPeerMaps()

	snap = store.data.Load()
	assert.Empty(t, snap.peersByNode[1], "node1 should have no peers after rebuild")
	assert.Empty(t, snap.peersByNode[2], "node2 should have no peers after rebuild")

	node1Peers, node2Peers := store.ListPeers(1), store.ListPeers(2)
	assert.Equal(t, 0, node1Peers.Len(), "ListPeers for node1 should return empty")
	assert.Equal(t, 0, node2Peers.Len(), "ListPeers for node2 should return empty")

	// Phase 3: allow peers again; the rebuild must restore mutual peering.
	allowPeers = true
	store.RebuildPeerMaps()

	snap = store.data.Load()
	require.Len(t, snap.peersByNode[1], 1, "node1 should have 1 peer after re-enabling")
	require.Len(t, snap.peersByNode[2], 1, "node2 should have 1 peer after re-enabling")

	node1Peers, node2Peers = store.ListPeers(1), store.ListPeers(2)
	assert.Equal(t, 1, node1Peers.Len(), "ListPeers for node1 should return 1")
	assert.Equal(t, 1, node2Peers.Len(), "ListPeers for node2 should return 1")
}