all: replace tsaddr.SortPrefixes with netip.Prefix.Compare

Replace the Tailscale-specific tsaddr.SortPrefixes function with
the standard library's netip.Prefix.Compare via slices.SortFunc.

The netip.Prefix.Compare function sorts prefixes lexicographically
by IP address first, then by prefix length. This differs slightly
from tsaddr.SortPrefixes, so test expectations are updated to
match the new sort order.

Before:
    tsaddr.SortPrefixes(prefixes)

After:
    slices.SortFunc(prefixes, netip.Prefix.Compare)
This commit is contained in:
Kristoffer Dalby 2026-01-20 14:22:36 +00:00
parent 9ab229675d
commit f9b3265158
9 changed files with 64 additions and 86 deletions

View file

@ -24,7 +24,6 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"tailscale.com/net/tsaddr"
"zgo.at/zcache/v2"
)
@ -168,7 +167,7 @@ AND auth_key_id NOT IN (
}
for nodeID, routes := range nodeRoutes {
tsaddr.SortPrefixes(routes)
slices.SortFunc(routes, netip.Prefix.Compare)
routes = slices.Compact(routes)
data, err := json.Marshal(routes)

View file

@ -4,6 +4,7 @@
package hscontrol
import (
"cmp"
"context"
"errors"
"fmt"
@ -11,7 +12,6 @@ import (
"net/netip"
"os"
"slices"
"sort"
"strings"
"time"
@ -135,8 +135,8 @@ func (api headscaleV1APIServer) ListUsers(
response[index] = user.Proto()
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
slices.SortFunc(response, func(a, b *v1.User) int {
return cmp.Compare(a.Id, b.Id)
})
return &v1.ListUsersResponse{Users: response}, nil
@ -221,8 +221,8 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
response[index] = key.Proto()
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
slices.SortFunc(response, func(a, b *v1.PreAuthKey) int {
return cmp.Compare(a.Id, b.Id)
})
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
newApproved = append(newApproved, prefix)
}
}
tsaddr.SortPrefixes(newApproved)
slices.SortFunc(newApproved, netip.Prefix.Compare)
newApproved = slices.Compact(newApproved)
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
@ -535,8 +535,8 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N
response[index] = resp
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
slices.SortFunc(response, func(a, b *v1.Node) int {
return cmp.Compare(a.Id, b.Id)
})
return response
@ -632,8 +632,8 @@ func (api headscaleV1APIServer) ListApiKeys(
response[index] = key.Proto()
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
slices.SortFunc(response, func(a, b *v1.ApiKey) int {
return cmp.Compare(a.Id, b.Id)
})
return &v1.ListApiKeysResponse{ApiKeys: response}, nil

View file

@ -13,7 +13,6 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestTailNode(t *testing.T) {
@ -95,7 +94,7 @@ func TestTailNode(t *testing.T) {
IPv4: iap("100.64.0.1"),
Hostname: "mini",
GivenName: "mini",
UserID: ptr.To(uint(0)),
UserID: new(uint(0)),
User: &types.User{
Name: "mini",
},
@ -136,10 +135,10 @@ func TestTailNode(t *testing.T) {
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
AllowedIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv6(),
tsaddr.AllIPv4(), // 0.0.0.0/0
netip.MustParsePrefix("100.64.0.1/32"), // lower IPv4
netip.MustParsePrefix("192.168.0.0/24"), // higher IPv4
tsaddr.AllIPv6(), // ::/0 (IPv6)
},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),

View file

@ -9,7 +9,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/types/views"
)
@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
}
// Sort and deduplicate
tsaddr.SortPrefixes(newApproved)
slices.SortFunc(newApproved, netip.Prefix.Compare)
newApproved = slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
return route.IsValid()
@ -120,7 +119,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
// Sort the current approved for comparison
sortedCurrent := make([]netip.Prefix, len(currentApproved))
copy(sortedCurrent, currentApproved)
tsaddr.SortPrefixes(sortedCurrent)
slices.SortFunc(sortedCurrent, netip.Prefix.Compare)
// Only update if the routes actually changed
if !slices.Equal(sortedCurrent, newApproved) {

View file

@ -3,6 +3,7 @@ package policy
import (
"fmt"
"net/netip"
"slices"
"testing"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
@ -10,9 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
)
@ -32,10 +31,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test-node",
UserID: ptr.To(user1.ID),
User: ptr.To(user1),
UserID: new(user1.ID),
User: new(user1),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
Tags: []string{"tag:test"},
}
@ -44,10 +43,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "other-node",
UserID: ptr.To(user2.ID),
User: ptr.To(user2),
UserID: new(user2.ID),
User: new(user2),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
IPv4: new(netip.MustParseAddr("100.64.0.2")),
}
// Create a policy that auto-approves specific routes
@ -194,7 +193,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
tsaddr.SortPrefixes(tt.wantApproved)
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
// Verify that all previously approved routes are still present
@ -304,10 +303,10 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: ptr.To(user.ID),
User: ptr.To(user),
UserID: new(user.ID),
User: new(user),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
}
nodes := types.Nodes{&node}
@ -330,7 +329,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
if tt.wantApproved == nil {
assert.Nil(t, gotApproved, "expected nil approved routes")
} else {
tsaddr.SortPrefixes(tt.wantApproved)
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
}
})

View file

@ -13,7 +13,6 @@ import (
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
@ -91,9 +90,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
},
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
nodeUser: "test",
// Sorted by netip.Prefix.Compare: by IP address then by prefix length
wantApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/24"),
},
wantChanged: false,
@ -123,9 +123,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
},
nodeUser: "test",
nodeTags: []string{"tag:approved"},
// Sorted by netip.Prefix.Compare: by IP address then by prefix length
wantApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
},
wantChanged: true,
},
@ -168,13 +169,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: tt.nodeHostname,
UserID: ptr.To(user.ID),
User: ptr.To(user),
UserID: new(user.ID),
User: new(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
Tags: tt.nodeTags,
}
@ -294,13 +295,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: ptr.To(user.ID),
User: ptr.To(user),
UserID: new(user.ID),
User: new(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
}
nodes := types.Nodes{&node}
@ -343,13 +344,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: ptr.To(user.ID),
User: ptr.To(user),
UserID: new(user.ID),
User: new(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: currentApproved,
}

View file

@ -4,7 +4,6 @@ import (
"fmt"
"net/netip"
"slices"
"sort"
"strings"
"sync"
@ -57,7 +56,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
// this is important so the same node is chosen two times in a row
// as the primary route.
ids := types.NodeIDs(xmaps.Keys(pr.routes))
sort.Sort(ids)
slices.Sort(ids)
// Create a map of prefixes to nodes that serve them so we
// can determine the primary route for each prefix.
@ -236,7 +235,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
}
}
tsaddr.SortPrefixes(routes)
slices.SortFunc(routes, netip.Prefix.Compare)
return routes
}
@ -254,7 +253,7 @@ func (pr *PrimaryRoutes) stringLocked() string {
fmt.Fprintln(&sb, "Available routes:")
ids := types.NodeIDs(xmaps.Keys(pr.routes))
sort.Sort(ids)
slices.Sort(ids)
for _, id := range ids {
prefixes := pr.routes[id]
fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", "))
@ -294,7 +293,7 @@ func (pr *PrimaryRoutes) DebugJSON() DebugRoutes {
// Populate available routes
for nodeID, routes := range pr.routes {
prefixes := routes.Slice()
tsaddr.SortPrefixes(prefixes)
slices.SortFunc(prefixes, netip.Prefix.Compare)
debug.AvailableRoutes[nodeID] = prefixes
}

View file

@ -25,10 +25,8 @@ import (
"github.com/rs/zerolog/log"
"golang.org/x/sync/errgroup"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
zcache "zgo.at/zcache/v2"
)
@ -133,7 +131,7 @@ func NewState(cfg *types.Config) (*State, error) {
// On startup, all nodes should be marked as offline until they reconnect
// This ensures we don't have stale online status from previous runs
for _, node := range nodes {
node.IsOnline = ptr.To(false)
node.IsOnline = new(false)
}
users, err := db.ListUsers()
if err != nil {
@ -468,10 +466,8 @@ func (s *State) Connect(id types.NodeID) []change.Change {
// CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification
// This ensures that when the NodeCameOnline change is distributed and processed by other nodes,
// the NodeStore already reflects the correct online status for full map generation.
// now := time.Now()
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
n.IsOnline = ptr.To(true)
// n.LastSeen = ptr.To(now)
n.IsOnline = new(true)
})
if !ok {
return nil
@ -498,9 +494,9 @@ func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
now := time.Now()
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
n.LastSeen = ptr.To(now)
n.LastSeen = new(now)
// NodeStore is the source of truth for all node state including online status.
n.IsOnline = ptr.To(false)
n.IsOnline = new(false)
})
if !ok {
@ -790,7 +786,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
// Preserve online status and NetInfo when refreshing from database
existingNode, exists := s.nodeStore.GetNode(node.ID)
if exists && existingNode.Valid() {
node.IsOnline = ptr.To(existingNode.IsOnline().Get())
node.IsOnline = new(existingNode.IsOnline().Get())
// TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics
// when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest).
@ -1117,7 +1113,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
DiscoKey: params.DiscoKey,
Hostinfo: params.Hostinfo,
Endpoints: params.Endpoints,
LastSeen: ptr.To(time.Now()),
LastSeen: new(time.Now()),
RegisterMethod: params.RegisterMethod,
Expiry: params.Expiry,
}
@ -1407,8 +1403,8 @@ func (s *State) HandleNodeFromAuthPath(
node.Endpoints = regEntry.Node.Endpoints
node.RegisterMethod = regEntry.Node.RegisterMethod
node.IsOnline = ptr.To(false)
node.LastSeen = ptr.To(time.Now())
node.IsOnline = new(false)
node.LastSeen = new(time.Now())
// Tagged nodes keep their existing expiry (disabled).
// User-owned nodes update expiry from the provided value or registration entry.
@ -1669,8 +1665,8 @@ func (s *State) HandleNodeFromPreAuthKey(
// Only update AuthKey reference
node.AuthKey = pak
node.AuthKeyID = &pak.ID
node.IsOnline = ptr.To(false)
node.LastSeen = ptr.To(time.Now())
node.IsOnline = new(false)
node.LastSeen = new(time.Now())
// Tagged nodes keep their existing expiry (disabled).
// User-owned nodes update expiry from the client request.
@ -2122,8 +2118,8 @@ func routesChanged(oldNode types.NodeView, newHI *tailcfg.Hostinfo) bool {
newRoutes = []netip.Prefix{}
}
tsaddr.SortPrefixes(oldRoutes)
tsaddr.SortPrefixes(newRoutes)
slices.SortFunc(oldRoutes, netip.Prefix.Compare)
slices.SortFunc(newRoutes, netip.Prefix.Compare)
return !slices.Equal(oldRoutes, newRoutes)
}

View file

@ -42,10 +42,6 @@ type (
NodeIDs []NodeID
)
func (n NodeIDs) Len() int { return len(n) }
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func (id NodeID) StableID() tailcfg.StableNodeID {
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
}
@ -197,13 +193,7 @@ func (node *Node) IPs() []netip.Addr {
// HasIP reports if a node has a given IP address.
func (node *Node) HasIP(i netip.Addr) bool {
for _, ip := range node.IPs() {
if ip.Compare(i) == 0 {
return true
}
}
return false
return slices.Contains(node.IPs(), i)
}
// IsTagged reports if a device is tagged and therefore should not be treated
@ -355,13 +345,9 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
}
func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool {
for _, node := range nodes {
if node.NodeKey == nodeKey {
return true
}
}
return false
return slices.ContainsFunc(nodes, func(node *Node) bool {
return node.NodeKey == nodeKey
})
}
func (node *Node) Proto() *v1.Node {
@ -1048,7 +1034,7 @@ func (nv NodeView) TailNode(
primaryRoutes := primaryRouteFunc(nv.ID())
allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes())
tsaddr.SortPrefixes(allowedIPs)
slices.SortFunc(allowedIPs, netip.Prefix.Compare)
capMap := tailcfg.NodeCapMap{
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},