From f9b3265158cae7b2c7ae72aca29c85f4eff2a5bc Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 14:22:36 +0000 Subject: [PATCH] all: replace tsaddr.SortPrefixes with netip.Prefix.Compare Replace the Tailscale-specific tsaddr.SortPrefixes function with the standard library's netip.Prefix.Compare via slices.SortFunc. The netip.Prefix.Compare function sorts prefixes lexicographically by IP address first, then by prefix length. This differs slightly from tsaddr.SortPrefixes, so test expectations are updated to match the new sort order. Before: tsaddr.SortPrefixes(prefixes) After: slices.SortFunc(prefixes, netip.Prefix.Compare) --- hscontrol/db/db.go | 3 +- hscontrol/grpcv1.go | 20 ++++++------- hscontrol/mapper/tail_test.go | 11 ++++---- hscontrol/policy/policy.go | 5 ++-- hscontrol/policy/policy_autoapprove_test.go | 25 ++++++++--------- .../policy/policy_route_approval_test.go | 25 +++++++++-------- hscontrol/routes/primary.go | 9 +++--- hscontrol/state/state.go | 28 ++++++++----------- hscontrol/types/node.go | 24 ++++------------ 9 files changed, 64 insertions(+), 86 deletions(-) diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index a1429aa6..05a4c7c8 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -24,7 +24,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" - "tailscale.com/net/tsaddr" "zgo.at/zcache/v2" ) @@ -168,7 +167,7 @@ AND auth_key_id NOT IN ( } for nodeID, routes := range nodeRoutes { - tsaddr.SortPrefixes(routes) + slices.SortFunc(routes, netip.Prefix.Compare) routes = slices.Compact(routes) data, err := json.Marshal(routes) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a35a73af..3605be60 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -4,6 +4,7 @@ package hscontrol import ( + "cmp" "context" "errors" "fmt" @@ -11,7 +12,6 @@ import ( "net/netip" "os" "slices" - "sort" "strings" "time" @@ -135,8 +135,8 @@ func (api headscaleV1APIServer) ListUsers( response[index] = user.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.User) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListUsersResponse{Users: response}, nil @@ -221,8 +221,8 @@ func (api headscaleV1APIServer) ListPreAuthKeys( response[index] = key.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.PreAuthKey) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil @@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes( newApproved = append(newApproved, prefix) } } - tsaddr.SortPrefixes(newApproved) + slices.SortFunc(newApproved, netip.Prefix.Compare) newApproved = slices.Compact(newApproved) node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved) @@ -535,8 +535,8 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N response[index] = resp } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.Node) int { + return cmp.Compare(a.Id, b.Id) }) return response @@ -632,8 +632,8 @@ func (api headscaleV1APIServer) ListApiKeys( response[index] = key.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.ApiKey) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListApiKeysResponse{ApiKeys: response}, nil diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 5b7030de..dc1dd1c0 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -13,7 +13,6 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestTailNode(t *testing.T) { @@ -95,7 +94,7 @@ func TestTailNode(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: ptr.To(uint(0)), + UserID: new(uint(0)), User: &types.User{ Name: "mini", }, @@ -136,10 +135,10 @@ func TestTailNode(t *testing.T) { ), Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, AllowedIPs: []netip.Prefix{ - tsaddr.AllIPv4(), - netip.MustParsePrefix("192.168.0.0/24"), - netip.MustParsePrefix("100.64.0.1/32"), - tsaddr.AllIPv6(), + tsaddr.AllIPv4(), // 0.0.0.0/0 + netip.MustParsePrefix("100.64.0.1/32"), // lower IPv4 + netip.MustParsePrefix("192.168.0.0/24"), // higher IPv4 + tsaddr.AllIPv6(), // ::/0 (IPv6) }, PrimaryRoutes: []netip.Prefix{ netip.MustParsePrefix("192.168.0.0/24"), diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 677cb854..24d2865e 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -9,7 +9,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/samber/lo" - "tailscale.com/net/tsaddr" "tailscale.com/types/views" ) @@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove } // Sort and deduplicate - tsaddr.SortPrefixes(newApproved) + slices.SortFunc(newApproved, netip.Prefix.Compare) newApproved = slices.Compact(newApproved) newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool { return route.IsValid() @@ -120,7 +119,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove // Sort the current approved for comparison sortedCurrent := make([]netip.Prefix, len(currentApproved)) copy(sortedCurrent, currentApproved) - tsaddr.SortPrefixes(sortedCurrent) + slices.SortFunc(sortedCurrent, netip.Prefix.Compare) // Only update if the routes actually changed if !slices.Equal(sortedCurrent, newApproved) { diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go index 61c69067..b7a758e6 100644 --- a/hscontrol/policy/policy_autoapprove_test.go +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -3,6 +3,7 @@ package policy import ( "fmt" "net/netip" + "slices" "testing" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" @@ -10,9 +11,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "gorm.io/gorm" - "tailscale.com/net/tsaddr" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -32,10 +31,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test-node", - UserID: ptr.To(user1.ID), - User: ptr.To(user1), + UserID: new(user1.ID), + User: new(user1), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), Tags: []string{"tag:test"}, } @@ -44,10 +43,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "other-node", - UserID: ptr.To(user2.ID), - User: ptr.To(user2), + UserID: new(user2.ID), + User: new(user2), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + IPv4: new(netip.MustParseAddr("100.64.0.2")), } // Create a policy that auto-approves specific routes @@ -194,7 +193,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description) // Sort for comparison since ApproveRoutesWithPolicy sorts the results - tsaddr.SortPrefixes(tt.wantApproved) + slices.SortFunc(tt.wantApproved, netip.Prefix.Compare) assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description) // Verify that all previously approved routes are still present @@ -304,10 +303,10 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, } nodes := types.Nodes{&node} @@ -330,7 +329,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { if tt.wantApproved == nil { assert.Nil(t, gotApproved, "expected nil approved routes") } else { - tsaddr.SortPrefixes(tt.wantApproved) + slices.SortFunc(tt.wantApproved, netip.Prefix.Compare) assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch") } }) diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go index 70aa6a21..0e974a1a 100644 --- a/hscontrol/policy/policy_route_approval_test.go +++ b/hscontrol/policy/policy_route_approval_test.go @@ -13,7 +13,6 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { @@ -91,9 +90,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }, announcedRoutes: []netip.Prefix{}, // No routes announced anymore nodeUser: "test", + // Sorted by netip.Prefix.Compare: by IP address then by prefix length wantApproved: []netip.Prefix{ - netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("192.168.0.0/24"), }, wantChanged: false, @@ -123,9 +123,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }, nodeUser: "test", nodeTags: []string{"tag:approved"}, + // Sorted by netip.Prefix.Compare: by IP address then by prefix length wantApproved: []netip.Prefix{ - netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved + netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved }, wantChanged: true, }, @@ -168,13 +169,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: tt.nodeHostname, - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, Tags: tt.nodeTags, } @@ -294,13 +295,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, } nodes := types.Nodes{&node} @@ -343,13 +344,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: currentApproved, } diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 977dc7a9..72eb2a5b 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -4,7 +4,6 @@ import ( "fmt" "net/netip" "slices" - "sort" "strings" "sync" @@ -57,7 +56,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { // this is important so the same node is chosen two times in a row // as the primary route. ids := types.NodeIDs(xmaps.Keys(pr.routes)) - sort.Sort(ids) + slices.Sort(ids) // Create a map of prefixes to nodes that serve them so we // can determine the primary route for each prefix. @@ -236,7 +235,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix { } } - tsaddr.SortPrefixes(routes) + slices.SortFunc(routes, netip.Prefix.Compare) return routes } @@ -254,7 +253,7 @@ func (pr *PrimaryRoutes) stringLocked() string { fmt.Fprintln(&sb, "Available routes:") ids := types.NodeIDs(xmaps.Keys(pr.routes)) - sort.Sort(ids) + slices.Sort(ids) for _, id := range ids { prefixes := pr.routes[id] fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", ")) @@ -294,7 +293,7 @@ func (pr *PrimaryRoutes) DebugJSON() DebugRoutes { // Populate available routes for nodeID, routes := range pr.routes { prefixes := routes.Slice() - tsaddr.SortPrefixes(prefixes) + slices.SortFunc(prefixes, netip.Prefix.Compare) debug.AvailableRoutes[nodeID] = prefixes } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index d1401ef0..1004151e 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -25,10 +25,8 @@ import ( "github.com/rs/zerolog/log" "golang.org/x/sync/errgroup" "gorm.io/gorm" - "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/types/views" zcache "zgo.at/zcache/v2" ) @@ -133,7 +131,7 @@ func NewState(cfg *types.Config) (*State, error) { // On startup, all nodes should be marked as offline until they reconnect // This ensures we don't have stale online status from previous runs for _, node := range nodes { - node.IsOnline = ptr.To(false) + node.IsOnline = new(false) } users, err := db.ListUsers() if err != nil { @@ -468,10 +466,8 @@ func (s *State) Connect(id types.NodeID) []change.Change { // CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, // the NodeStore already reflects the correct online status for full map generation. - // now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { - n.IsOnline = ptr.To(true) - // n.LastSeen = ptr.To(now) + n.IsOnline = new(true) }) if !ok { return nil @@ -498,9 +494,9 @@ func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { - n.LastSeen = ptr.To(now) + n.LastSeen = new(now) // NodeStore is the source of truth for all node state including online status. - n.IsOnline = ptr.To(false) + n.IsOnline = new(false) }) if !ok { @@ -790,7 +786,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) { // Preserve online status and NetInfo when refreshing from database existingNode, exists := s.nodeStore.GetNode(node.ID) if exists && existingNode.Valid() { - node.IsOnline = ptr.To(existingNode.IsOnline().Get()) + node.IsOnline = new(existingNode.IsOnline().Get()) // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). @@ -1117,7 +1113,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro DiscoKey: params.DiscoKey, Hostinfo: params.Hostinfo, Endpoints: params.Endpoints, - LastSeen: ptr.To(time.Now()), + LastSeen: new(time.Now()), RegisterMethod: params.RegisterMethod, Expiry: params.Expiry, } @@ -1407,8 +1403,8 @@ func (s *State) HandleNodeFromAuthPath( node.Endpoints = regEntry.Node.Endpoints node.RegisterMethod = regEntry.Node.RegisterMethod - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) + node.IsOnline = new(false) + node.LastSeen = new(time.Now()) // Tagged nodes keep their existing expiry (disabled). // User-owned nodes update expiry from the provided value or registration entry. @@ -1669,8 +1665,8 @@ func (s *State) HandleNodeFromPreAuthKey( // Only update AuthKey reference node.AuthKey = pak node.AuthKeyID = &pak.ID - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) + node.IsOnline = new(false) + node.LastSeen = new(time.Now()) // Tagged nodes keep their existing expiry (disabled). // User-owned nodes update expiry from the client request. @@ -2122,8 +2118,8 @@ func routesChanged(oldNode types.NodeView, newHI *tailcfg.Hostinfo) bool { newRoutes = []netip.Prefix{} } - tsaddr.SortPrefixes(oldRoutes) - tsaddr.SortPrefixes(newRoutes) + slices.SortFunc(oldRoutes, netip.Prefix.Compare) + slices.SortFunc(newRoutes, netip.Prefix.Compare) return !slices.Equal(oldRoutes, newRoutes) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 41cd9759..1a66341d 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -42,10 +42,6 @@ type ( NodeIDs []NodeID ) -func (n NodeIDs) Len() int { return len(n) } -func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] } -func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] } - func (id NodeID) StableID() tailcfg.StableNodeID { return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10)) } @@ -197,13 +193,7 @@ func (node *Node) IPs() []netip.Addr { // HasIP reports if a node has a given IP address. func (node *Node) HasIP(i netip.Addr) bool { - for _, ip := range node.IPs() { - if ip.Compare(i) == 0 { - return true - } - } - - return false + return slices.Contains(node.IPs(), i) } // IsTagged reports if a device is tagged and therefore should not be treated @@ -355,13 +345,9 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { } func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool { - for _, node := range nodes { - if node.NodeKey == nodeKey { - return true - } - } - - return false + return slices.ContainsFunc(nodes, func(node *Node) bool { + return node.NodeKey == nodeKey + }) } func (node *Node) Proto() *v1.Node { @@ -1048,7 +1034,7 @@ func (nv NodeView) TailNode( primaryRoutes := primaryRouteFunc(nv.ID()) allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes()) - tsaddr.SortPrefixes(allowedIPs) + slices.SortFunc(allowedIPs, netip.Prefix.Compare) capMap := tailcfg.NodeCapMap{ tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},