diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index a1429aa6..05a4c7c8 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -24,7 +24,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" - "tailscale.com/net/tsaddr" "zgo.at/zcache/v2" ) @@ -168,7 +167,7 @@ AND auth_key_id NOT IN ( } for nodeID, routes := range nodeRoutes { - tsaddr.SortPrefixes(routes) + slices.SortFunc(routes, netip.Prefix.Compare) routes = slices.Compact(routes) data, err := json.Marshal(routes) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a35a73af..3605be60 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -4,6 +4,7 @@ package hscontrol import ( + "cmp" "context" "errors" "fmt" @@ -11,7 +12,6 @@ import ( "net/netip" "os" "slices" - "sort" "strings" "time" @@ -135,8 +135,8 @@ func (api headscaleV1APIServer) ListUsers( response[index] = user.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.User) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListUsersResponse{Users: response}, nil @@ -221,8 +221,8 @@ func (api headscaleV1APIServer) ListPreAuthKeys( response[index] = key.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.PreAuthKey) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil @@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes( newApproved = append(newApproved, prefix) } } - tsaddr.SortPrefixes(newApproved) + slices.SortFunc(newApproved, netip.Prefix.Compare) newApproved = slices.Compact(newApproved) node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved) @@ -535,8 +535,8 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N response[index] = resp } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.Node) int { + return cmp.Compare(a.Id, b.Id) }) return response @@ -632,8 +632,8 @@ func (api headscaleV1APIServer) ListApiKeys( response[index] = key.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.ApiKey) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListApiKeysResponse{ApiKeys: response}, nil diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 5b7030de..dc1dd1c0 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -13,7 +13,6 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestTailNode(t *testing.T) { @@ -95,7 +94,7 @@ func TestTailNode(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: ptr.To(uint(0)), + UserID: new(uint(0)), User: &types.User{ Name: "mini", }, @@ -136,10 +135,10 @@ func TestTailNode(t *testing.T) { ), Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, AllowedIPs: []netip.Prefix{ - tsaddr.AllIPv4(), - netip.MustParsePrefix("192.168.0.0/24"), - netip.MustParsePrefix("100.64.0.1/32"), - tsaddr.AllIPv6(), + tsaddr.AllIPv4(), // 0.0.0.0/0 + netip.MustParsePrefix("100.64.0.1/32"), // lower IPv4 + netip.MustParsePrefix("192.168.0.0/24"), // higher IPv4 + tsaddr.AllIPv6(), // ::/0 (IPv6) }, PrimaryRoutes: []netip.Prefix{ netip.MustParsePrefix("192.168.0.0/24"), diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 677cb854..24d2865e 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -9,7 +9,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/samber/lo" - "tailscale.com/net/tsaddr" "tailscale.com/types/views" ) @@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove } // Sort and deduplicate - tsaddr.SortPrefixes(newApproved) + slices.SortFunc(newApproved, netip.Prefix.Compare) newApproved = slices.Compact(newApproved) newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool { return route.IsValid() @@ -120,7 +119,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove // Sort the current approved for comparison sortedCurrent := make([]netip.Prefix, len(currentApproved)) copy(sortedCurrent, currentApproved) - tsaddr.SortPrefixes(sortedCurrent) + slices.SortFunc(sortedCurrent, netip.Prefix.Compare) // Only update if the routes actually changed if !slices.Equal(sortedCurrent, newApproved) { diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go index 61c69067..b7a758e6 100644 --- a/hscontrol/policy/policy_autoapprove_test.go +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -3,6 +3,7 @@ package policy import ( "fmt" "net/netip" + "slices" "testing" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" @@ -10,9 +11,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "gorm.io/gorm" - "tailscale.com/net/tsaddr" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -32,10 +31,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test-node", - UserID: ptr.To(user1.ID), - User: ptr.To(user1), + UserID: new(user1.ID), + User: new(user1), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), Tags: []string{"tag:test"}, } @@ -44,10 +43,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "other-node", - UserID: ptr.To(user2.ID), - User: ptr.To(user2), + UserID: new(user2.ID), + User: new(user2), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + IPv4: new(netip.MustParseAddr("100.64.0.2")), } // Create a policy that auto-approves specific routes @@ -194,7 +193,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description) // Sort for comparison since ApproveRoutesWithPolicy sorts the results - tsaddr.SortPrefixes(tt.wantApproved) + slices.SortFunc(tt.wantApproved, netip.Prefix.Compare) assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description) // Verify that all previously approved routes are still present @@ -304,10 +303,10 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, } nodes := types.Nodes{&node} @@ -330,7 +329,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { if tt.wantApproved == nil { assert.Nil(t, gotApproved, "expected nil approved routes") } else { - tsaddr.SortPrefixes(tt.wantApproved) + slices.SortFunc(tt.wantApproved, netip.Prefix.Compare) assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch") } }) diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go index 70aa6a21..0e974a1a 100644 --- a/hscontrol/policy/policy_route_approval_test.go +++ b/hscontrol/policy/policy_route_approval_test.go @@ -13,7 +13,6 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { @@ -91,9 +90,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }, announcedRoutes: []netip.Prefix{}, // No routes announced anymore nodeUser: "test", + // Sorted by netip.Prefix.Compare: by IP address then by prefix length wantApproved: []netip.Prefix{ - netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("192.168.0.0/24"), }, wantChanged: false, @@ -123,9 +123,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }, nodeUser: "test", nodeTags: []string{"tag:approved"}, + // Sorted by netip.Prefix.Compare: by IP address then by prefix length wantApproved: []netip.Prefix{ - netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved + netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved }, wantChanged: true, }, @@ -168,13 +169,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: tt.nodeHostname, - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, Tags: tt.nodeTags, } @@ -294,13 +295,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, } nodes := types.Nodes{&node} @@ -343,13 +344,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: currentApproved, } diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 977dc7a9..72eb2a5b 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -4,7 +4,6 @@ import ( "fmt" "net/netip" "slices" - "sort" "strings" "sync" @@ -57,7 +56,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { // this is important so the same node is chosen two times in a row // as the primary route. ids := types.NodeIDs(xmaps.Keys(pr.routes)) - sort.Sort(ids) + slices.Sort(ids) // Create a map of prefixes to nodes that serve them so we // can determine the primary route for each prefix. @@ -236,7 +235,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix { } } - tsaddr.SortPrefixes(routes) + slices.SortFunc(routes, netip.Prefix.Compare) return routes } @@ -254,7 +253,7 @@ func (pr *PrimaryRoutes) stringLocked() string { fmt.Fprintln(&sb, "Available routes:") ids := types.NodeIDs(xmaps.Keys(pr.routes)) - sort.Sort(ids) + slices.Sort(ids) for _, id := range ids { prefixes := pr.routes[id] fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", ")) @@ -294,7 +293,7 @@ func (pr *PrimaryRoutes) DebugJSON() DebugRoutes { // Populate available routes for nodeID, routes := range pr.routes { prefixes := routes.Slice() - tsaddr.SortPrefixes(prefixes) + slices.SortFunc(prefixes, netip.Prefix.Compare) debug.AvailableRoutes[nodeID] = prefixes } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index d1401ef0..1004151e 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -25,10 +25,8 @@ import ( "github.com/rs/zerolog/log" "golang.org/x/sync/errgroup" "gorm.io/gorm" - "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/types/views" zcache "zgo.at/zcache/v2" ) @@ -133,7 +131,7 @@ func NewState(cfg *types.Config) (*State, error) { // On startup, all nodes should be marked as offline until they reconnect // This ensures we don't have stale online status from previous runs for _, node := range nodes { - node.IsOnline = ptr.To(false) + node.IsOnline = new(false) } users, err := db.ListUsers() if err != nil { @@ -468,10 +466,8 @@ func (s *State) Connect(id types.NodeID) []change.Change { // CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, // the NodeStore already reflects the correct online status for full map generation. - // now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { - n.IsOnline = ptr.To(true) - // n.LastSeen = ptr.To(now) + n.IsOnline = new(true) }) if !ok { return nil @@ -498,9 +494,9 @@ func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { - n.LastSeen = ptr.To(now) + n.LastSeen = new(now) // NodeStore is the source of truth for all node state including online status. - n.IsOnline = ptr.To(false) + n.IsOnline = new(false) }) if !ok { @@ -790,7 +786,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) { // Preserve online status and NetInfo when refreshing from database existingNode, exists := s.nodeStore.GetNode(node.ID) if exists && existingNode.Valid() { - node.IsOnline = ptr.To(existingNode.IsOnline().Get()) + node.IsOnline = new(existingNode.IsOnline().Get()) // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). @@ -1117,7 +1113,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro DiscoKey: params.DiscoKey, Hostinfo: params.Hostinfo, Endpoints: params.Endpoints, - LastSeen: ptr.To(time.Now()), + LastSeen: new(time.Now()), RegisterMethod: params.RegisterMethod, Expiry: params.Expiry, } @@ -1407,8 +1403,8 @@ func (s *State) HandleNodeFromAuthPath( node.Endpoints = regEntry.Node.Endpoints node.RegisterMethod = regEntry.Node.RegisterMethod - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) + node.IsOnline = new(false) + node.LastSeen = new(time.Now()) // Tagged nodes keep their existing expiry (disabled). // User-owned nodes update expiry from the provided value or registration entry. @@ -1669,8 +1665,8 @@ func (s *State) HandleNodeFromPreAuthKey( // Only update AuthKey reference node.AuthKey = pak node.AuthKeyID = &pak.ID - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) + node.IsOnline = new(false) + node.LastSeen = new(time.Now()) // Tagged nodes keep their existing expiry (disabled). // User-owned nodes update expiry from the client request. @@ -2122,8 +2118,8 @@ func routesChanged(oldNode types.NodeView, newHI *tailcfg.Hostinfo) bool { newRoutes = []netip.Prefix{} } - tsaddr.SortPrefixes(oldRoutes) - tsaddr.SortPrefixes(newRoutes) + slices.SortFunc(oldRoutes, netip.Prefix.Compare) + slices.SortFunc(newRoutes, netip.Prefix.Compare) return !slices.Equal(oldRoutes, newRoutes) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 41cd9759..1a66341d 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -42,10 +42,6 @@ type ( NodeIDs []NodeID ) -func (n NodeIDs) Len() int { return len(n) } -func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] } -func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] } - func (id NodeID) StableID() tailcfg.StableNodeID { return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10)) } @@ -197,13 +193,7 @@ func (node *Node) IPs() []netip.Addr { // HasIP reports if a node has a given IP address. func (node *Node) HasIP(i netip.Addr) bool { - for _, ip := range node.IPs() { - if ip.Compare(i) == 0 { - return true - } - } - - return false + return slices.Contains(node.IPs(), i) } // IsTagged reports if a device is tagged and therefore should not be treated @@ -355,13 +345,9 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { } func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool { - for _, node := range nodes { - if node.NodeKey == nodeKey { - return true - } - } - - return false + return slices.ContainsFunc(nodes, func(node *Node) bool { + return node.NodeKey == nodeKey + }) } func (node *Node) Proto() *v1.Node { @@ -1048,7 +1034,7 @@ func (nv NodeView) TailNode( primaryRoutes := primaryRouteFunc(nv.ID()) allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes()) - tsaddr.SortPrefixes(allowedIPs) + slices.SortFunc(allowedIPs, netip.Prefix.Compare) capMap := tailcfg.NodeCapMap{ tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},