From 57d96193217fd04cfa925edb10c2904aee2b8551 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 14:21:36 +0000 Subject: [PATCH 01/30] go.mod: bump Go version to 1.26rc2 Update the Go version requirement to 1.26rc2 to enable new language features like the enhanced new() builtin and errors.AsType for type-safe error unwrapping. --- go.mod | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/go.mod b/go.mod index 905a27db..5616acda 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/juanfont/headscale -go 1.25 +go 1.26rc2 require ( github.com/arl/statsviz v0.7.2 From 9ab229675d7ead6a552a6b6e12d3e931ff69506e Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 14:22:27 +0000 Subject: [PATCH 02/30] all: use errors.AsType for type-safe error unwrapping Replace errors.As with the new errors.AsType generic function introduced in Go 1.26. This provides compile-time type safety and approximately 3x better performance by avoiding reflection. Before: var target *AppError if errors.As(err, &target) { // use target } After: if target, ok := errors.AsType[*AppError](err); ok { // use target } --- cmd/headscale/cli/serve.go | 3 +-- hscontrol/auth.go | 11 ++++------- hscontrol/handlers.go | 3 +-- hscontrol/noise.go | 3 +-- hscontrol/policy/v2/types.go | 26 ++++++++++++-------------- hscontrol/types/preauth_key_test.go | 3 +-- 6 files changed, 20 insertions(+), 29 deletions(-) diff --git a/cmd/headscale/cli/serve.go b/cmd/headscale/cli/serve.go index 8f05f851..f815f9f9 100644 --- a/cmd/headscale/cli/serve.go +++ b/cmd/headscale/cli/serve.go @@ -23,8 +23,7 @@ var serveCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { app, err := newHeadscaleServerWithConfig() if err != nil { - var squibbleErr squibble.ValidationError - if errors.As(err, &squibbleErr) { + if squibbleErr, ok := errors.AsType[squibble.ValidationError](err); ok { fmt.Printf("SQLite schema failed to validate:\n") fmt.Println(squibbleErr.Diff) } diff --git a/hscontrol/auth.go b/hscontrol/auth.go index ac5968e3..aa7088d7 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -16,7 +16,6 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) type AuthProvider interface { @@ -113,8 +112,7 @@ func (h *Headscale) handleRegister( resp, err := h.handleRegisterWithAuthKey(req, machineKey) if err != nil { // Preserve HTTPError types so they can be handled properly by the HTTP layer - var httpErr HTTPError - if errors.As(err, &httpErr) { + if httpErr, ok := errors.AsType[HTTPError](err); ok { return nil, httpErr } @@ -316,7 +314,7 @@ func (h *Headscale) reqToNewRegisterResponse( MachineKey: machineKey, NodeKey: req.NodeKey, Hostinfo: hostinfo, - LastSeen: ptr.To(time.Now()), + LastSeen: new(time.Now()), }, ) @@ -344,8 +342,7 @@ func (h *Headscale) handleRegisterWithAuthKey( if errors.Is(err, gorm.ErrRecordNotFound) { return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) } - var perr types.PAKError - if errors.As(err, &perr) { + if perr, ok := errors.AsType[types.PAKError](err); ok { return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil) } @@ -443,7 +440,7 @@ func (h *Headscale) handleRegisterInteractive( MachineKey: machineKey, NodeKey: req.NodeKey, Hostinfo: hostinfo, - LastSeen: ptr.To(time.Now()), + LastSeen: new(time.Now()), }, ) diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index dc693dae..2aee3cb2 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ 
-36,8 +36,7 @@ const ( // httpError logs an error and sends an HTTP error response with the given. func httpError(w http.ResponseWriter, err error) { - var herr HTTPError - if errors.As(err, &herr) { + if herr, ok := errors.AsType[HTTPError](err); ok { http.Error(w, herr.Msg, herr.Code) log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg) } else { diff --git a/hscontrol/noise.go b/hscontrol/noise.go index a667cd1f..f0e2fefa 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -256,8 +256,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer()) if err != nil { - var httpErr HTTPError - if errors.As(err, &httpErr) { + if httpErr, ok := errors.AsType[HTTPError](err); ok { resp = &tailcfg.RegisterResponse{ Error: httpErr.Msg, } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 75b16bc1..fbce8a2b 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -16,7 +16,6 @@ import ( "go4.org/netipx" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/multierr" "tailscale.com/util/slicesx" @@ -656,17 +655,17 @@ func parseAlias(vs string) (Alias, error) { case isWildcard(vs): return Wildcard, nil case isUser(vs): - return ptr.To(Username(vs)), nil + return new(Username(vs)), nil case isGroup(vs): - return ptr.To(Group(vs)), nil + return new(Group(vs)), nil case isTag(vs): - return ptr.To(Tag(vs)), nil + return new(Tag(vs)), nil case isAutoGroup(vs): - return ptr.To(AutoGroup(vs)), nil + return new(AutoGroup(vs)), nil } if isHost(vs) { - return ptr.To(Host(vs)), nil + return new(Host(vs)), nil } return nil, fmt.Errorf(`Invalid alias %q. An alias must be one of the following types: @@ -829,11 +828,11 @@ func (aa AutoApprovers) MarshalJSON() ([]byte, error) { func parseAutoApprover(s string) (AutoApprover, error) { switch { case isUser(s): - return ptr.To(Username(s)), nil + return new(Username(s)), nil case isGroup(s): - return ptr.To(Group(s)), nil + return new(Group(s)), nil case isTag(s): - return ptr.To(Tag(s)), nil + return new(Tag(s)), nil } return nil, fmt.Errorf(`Invalid AutoApprover %q. An alias must be one of the following types: @@ -925,11 +924,11 @@ func (o Owners) MarshalJSON() ([]byte, error) { func parseOwner(s string) (Owner, error) { switch { case isUser(s): - return ptr.To(Username(s)), nil + return new(Username(s)), nil case isGroup(s): - return ptr.To(Group(s)), nil + return new(Group(s)), nil case isTag(s): - return ptr.To(Tag(s)), nil + return new(Tag(s)), nil } return nil, fmt.Errorf(`Invalid Owner %q. 
An alias must be one of the following types: @@ -2023,8 +2022,7 @@ func unmarshalPolicy(b []byte) (*Policy, error) { ast.Standardize() if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { - var serr *json.SemanticError - if errors.As(err, &serr) && serr.Err == json.ErrUnknownName { + if serr, ok := errors.AsType[*json.SemanticError](err); ok && serr.Err == json.ErrUnknownName { ptr := serr.JSONPointer name := ptr.LastToken() return nil, fmt.Errorf("unknown field %q", name) diff --git a/hscontrol/types/preauth_key_test.go b/hscontrol/types/preauth_key_test.go index 4ab1c717..1b280149 100644 --- a/hscontrol/types/preauth_key_test.go +++ b/hscontrol/types/preauth_key_test.go @@ -110,8 +110,7 @@ func TestCanUsePreAuthKey(t *testing.T) { if err == nil { t.Errorf("expected error but got none") } else { - var httpErr PAKError - ok := errors.As(err, &httpErr) + httpErr, ok := errors.AsType[PAKError](err) if !ok { t.Errorf("expected HTTPError but got %T", err) } else { From f9b3265158cae7b2c7ae72aca29c85f4eff2a5bc Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 14:22:36 +0000 Subject: [PATCH 03/30] all: replace tsaddr.SortPrefixes with netip.Prefix.Compare Replace the Tailscale-specific tsaddr.SortPrefixes function with the standard library's netip.Prefix.Compare via slices.SortFunc. The netip.Prefix.Compare function sorts prefixes lexicographically by IP address first, then by prefix length. This differs slightly from tsaddr.SortPrefixes, so test expectations are updated to match the new sort order. Before: tsaddr.SortPrefixes(prefixes) After: slices.SortFunc(prefixes, netip.Prefix.Compare) --- hscontrol/db/db.go | 3 +- hscontrol/grpcv1.go | 20 ++++++------- hscontrol/mapper/tail_test.go | 11 ++++---- hscontrol/policy/policy.go | 5 ++-- hscontrol/policy/policy_autoapprove_test.go | 25 ++++++++--------- .../policy/policy_route_approval_test.go | 25 +++++++++-------- hscontrol/routes/primary.go | 9 +++--- hscontrol/state/state.go | 28 ++++++++----------- hscontrol/types/node.go | 24 ++++------------ 9 files changed, 64 insertions(+), 86 deletions(-) diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index a1429aa6..05a4c7c8 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -24,7 +24,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" - "tailscale.com/net/tsaddr" "zgo.at/zcache/v2" ) @@ -168,7 +167,7 @@ AND auth_key_id NOT IN ( } for nodeID, routes := range nodeRoutes { - tsaddr.SortPrefixes(routes) + slices.SortFunc(routes, netip.Prefix.Compare) routes = slices.Compact(routes) data, err := json.Marshal(routes) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a35a73af..3605be60 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -4,6 +4,7 @@ package hscontrol import ( + "cmp" "context" "errors" "fmt" @@ -11,7 +12,6 @@ import ( "net/netip" "os" "slices" - "sort" "strings" "time" @@ -135,8 +135,8 @@ func (api headscaleV1APIServer) ListUsers( response[index] = user.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.User) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListUsersResponse{Users: response}, nil @@ -221,8 +221,8 @@ func (api headscaleV1APIServer) ListPreAuthKeys( response[index] = key.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.PreAuthKey) int { + return cmp.Compare(a.Id, b.Id) }) return 
&v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil @@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes( newApproved = append(newApproved, prefix) } } - tsaddr.SortPrefixes(newApproved) + slices.SortFunc(newApproved, netip.Prefix.Compare) newApproved = slices.Compact(newApproved) node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved) @@ -535,8 +535,8 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N response[index] = resp } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.Node) int { + return cmp.Compare(a.Id, b.Id) }) return response @@ -632,8 +632,8 @@ func (api headscaleV1APIServer) ListApiKeys( response[index] = key.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.ApiKey) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListApiKeysResponse{ApiKeys: response}, nil diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 5b7030de..dc1dd1c0 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -13,7 +13,6 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestTailNode(t *testing.T) { @@ -95,7 +94,7 @@ func TestTailNode(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: ptr.To(uint(0)), + UserID: new(uint(0)), User: &types.User{ Name: "mini", }, @@ -136,10 +135,10 @@ func TestTailNode(t *testing.T) { ), Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, AllowedIPs: []netip.Prefix{ - tsaddr.AllIPv4(), - netip.MustParsePrefix("192.168.0.0/24"), - netip.MustParsePrefix("100.64.0.1/32"), - tsaddr.AllIPv6(), + tsaddr.AllIPv4(), // 0.0.0.0/0 + netip.MustParsePrefix("100.64.0.1/32"), // lower IPv4 + netip.MustParsePrefix("192.168.0.0/24"), // higher IPv4 + tsaddr.AllIPv6(), // ::/0 (IPv6) }, PrimaryRoutes: []netip.Prefix{ netip.MustParsePrefix("192.168.0.0/24"), diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 677cb854..24d2865e 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -9,7 +9,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/samber/lo" - "tailscale.com/net/tsaddr" "tailscale.com/types/views" ) @@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove } // Sort and deduplicate - tsaddr.SortPrefixes(newApproved) + slices.SortFunc(newApproved, netip.Prefix.Compare) newApproved = slices.Compact(newApproved) newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool { return route.IsValid() @@ -120,7 +119,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove // Sort the current approved for comparison sortedCurrent := make([]netip.Prefix, len(currentApproved)) copy(sortedCurrent, currentApproved) - tsaddr.SortPrefixes(sortedCurrent) + slices.SortFunc(sortedCurrent, netip.Prefix.Compare) // Only update if the routes actually changed if !slices.Equal(sortedCurrent, newApproved) { diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go index 61c69067..b7a758e6 100644 --- a/hscontrol/policy/policy_autoapprove_test.go +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -3,6 +3,7 @@ package policy import ( "fmt" 
"net/netip" + "slices" "testing" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" @@ -10,9 +11,7 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "gorm.io/gorm" - "tailscale.com/net/tsaddr" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -32,10 +31,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test-node", - UserID: ptr.To(user1.ID), - User: ptr.To(user1), + UserID: new(user1.ID), + User: new(user1), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), Tags: []string{"tag:test"}, } @@ -44,10 +43,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "other-node", - UserID: ptr.To(user2.ID), - User: ptr.To(user2), + UserID: new(user2.ID), + User: new(user2), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + IPv4: new(netip.MustParseAddr("100.64.0.2")), } // Create a policy that auto-approves specific routes @@ -194,7 +193,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description) // Sort for comparison since ApproveRoutesWithPolicy sorts the results - tsaddr.SortPrefixes(tt.wantApproved) + slices.SortFunc(tt.wantApproved, netip.Prefix.Compare) assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description) // Verify that all previously approved routes are still present @@ -304,10 +303,10 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, } nodes := types.Nodes{&node} @@ -330,7 +329,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { if tt.wantApproved == nil { assert.Nil(t, gotApproved, "expected nil approved routes") } else { - tsaddr.SortPrefixes(tt.wantApproved) + slices.SortFunc(tt.wantApproved, netip.Prefix.Compare) assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch") } }) diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go index 70aa6a21..0e974a1a 100644 --- a/hscontrol/policy/policy_route_approval_test.go +++ b/hscontrol/policy/policy_route_approval_test.go @@ -13,7 +13,6 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { @@ -91,9 +90,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }, announcedRoutes: []netip.Prefix{}, // No routes announced anymore nodeUser: "test", + // Sorted by netip.Prefix.Compare: by IP address then by prefix length wantApproved: []netip.Prefix{ - netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("192.168.0.0/24"), }, wantChanged: false, @@ -123,9 +123,10 @@ 
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }, nodeUser: "test", nodeTags: []string{"tag:approved"}, + // Sorted by netip.Prefix.Compare: by IP address then by prefix length wantApproved: []netip.Prefix{ - netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved + netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved }, wantChanged: true, }, @@ -168,13 +169,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: tt.nodeHostname, - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, Tags: tt.nodeTags, } @@ -294,13 +295,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, } nodes := types.Nodes{&node} @@ -343,13 +344,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: currentApproved, } diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 977dc7a9..72eb2a5b 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -4,7 +4,6 @@ import ( "fmt" "net/netip" "slices" - "sort" "strings" "sync" @@ -57,7 +56,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { // this is important so the same node is chosen two times in a row // as the primary route. ids := types.NodeIDs(xmaps.Keys(pr.routes)) - sort.Sort(ids) + slices.Sort(ids) // Create a map of prefixes to nodes that serve them so we // can determine the primary route for each prefix. 
@@ -236,7 +235,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix { } } - tsaddr.SortPrefixes(routes) + slices.SortFunc(routes, netip.Prefix.Compare) return routes } @@ -254,7 +253,7 @@ func (pr *PrimaryRoutes) stringLocked() string { fmt.Fprintln(&sb, "Available routes:") ids := types.NodeIDs(xmaps.Keys(pr.routes)) - sort.Sort(ids) + slices.Sort(ids) for _, id := range ids { prefixes := pr.routes[id] fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", ")) @@ -294,7 +293,7 @@ func (pr *PrimaryRoutes) DebugJSON() DebugRoutes { // Populate available routes for nodeID, routes := range pr.routes { prefixes := routes.Slice() - tsaddr.SortPrefixes(prefixes) + slices.SortFunc(prefixes, netip.Prefix.Compare) debug.AvailableRoutes[nodeID] = prefixes } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index d1401ef0..1004151e 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -25,10 +25,8 @@ import ( "github.com/rs/zerolog/log" "golang.org/x/sync/errgroup" "gorm.io/gorm" - "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/types/views" zcache "zgo.at/zcache/v2" ) @@ -133,7 +131,7 @@ func NewState(cfg *types.Config) (*State, error) { // On startup, all nodes should be marked as offline until they reconnect // This ensures we don't have stale online status from previous runs for _, node := range nodes { - node.IsOnline = ptr.To(false) + node.IsOnline = new(false) } users, err := db.ListUsers() if err != nil { @@ -468,10 +466,8 @@ func (s *State) Connect(id types.NodeID) []change.Change { // CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, // the NodeStore already reflects the correct online status for full map generation. - // now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { - n.IsOnline = ptr.To(true) - // n.LastSeen = ptr.To(now) + n.IsOnline = new(true) }) if !ok { return nil @@ -498,9 +494,9 @@ func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { - n.LastSeen = ptr.To(now) + n.LastSeen = new(now) // NodeStore is the source of truth for all node state including online status. - n.IsOnline = ptr.To(false) + n.IsOnline = new(false) }) if !ok { @@ -790,7 +786,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) { // Preserve online status and NetInfo when refreshing from database existingNode, exists := s.nodeStore.GetNode(node.ID) if exists && existingNode.Valid() { - node.IsOnline = ptr.To(existingNode.IsOnline().Get()) + node.IsOnline = new(existingNode.IsOnline().Get()) // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). 
@@ -1117,7 +1113,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro DiscoKey: params.DiscoKey, Hostinfo: params.Hostinfo, Endpoints: params.Endpoints, - LastSeen: ptr.To(time.Now()), + LastSeen: new(time.Now()), RegisterMethod: params.RegisterMethod, Expiry: params.Expiry, } @@ -1407,8 +1403,8 @@ func (s *State) HandleNodeFromAuthPath( node.Endpoints = regEntry.Node.Endpoints node.RegisterMethod = regEntry.Node.RegisterMethod - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) + node.IsOnline = new(false) + node.LastSeen = new(time.Now()) // Tagged nodes keep their existing expiry (disabled). // User-owned nodes update expiry from the provided value or registration entry. @@ -1669,8 +1665,8 @@ func (s *State) HandleNodeFromPreAuthKey( // Only update AuthKey reference node.AuthKey = pak node.AuthKeyID = &pak.ID - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) + node.IsOnline = new(false) + node.LastSeen = new(time.Now()) // Tagged nodes keep their existing expiry (disabled). // User-owned nodes update expiry from the client request. @@ -2122,8 +2118,8 @@ func routesChanged(oldNode types.NodeView, newHI *tailcfg.Hostinfo) bool { newRoutes = []netip.Prefix{} } - tsaddr.SortPrefixes(oldRoutes) - tsaddr.SortPrefixes(newRoutes) + slices.SortFunc(oldRoutes, netip.Prefix.Compare) + slices.SortFunc(newRoutes, netip.Prefix.Compare) return !slices.Equal(oldRoutes, newRoutes) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 41cd9759..1a66341d 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -42,10 +42,6 @@ type ( NodeIDs []NodeID ) -func (n NodeIDs) Len() int { return len(n) } -func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] } -func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] } - func (id NodeID) StableID() tailcfg.StableNodeID { return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10)) } @@ -197,13 +193,7 @@ func (node *Node) IPs() []netip.Addr { // HasIP reports if a node has a given IP address. 
func (node *Node) HasIP(i netip.Addr) bool { - for _, ip := range node.IPs() { - if ip.Compare(i) == 0 { - return true - } - } - - return false + return slices.Contains(node.IPs(), i) } // IsTagged reports if a device is tagged and therefore should not be treated @@ -355,13 +345,9 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { } func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool { - for _, node := range nodes { - if node.NodeKey == nodeKey { - return true - } - } - - return false + return slices.ContainsFunc(nodes, func(node *Node) bool { + return node.NodeKey == nodeKey + }) } func (node *Node) Proto() *v1.Node { @@ -1048,7 +1034,7 @@ func (nv NodeView) TailNode( primaryRoutes := primaryRouteFunc(nv.ID()) allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes()) - tsaddr.SortPrefixes(allowedIPs) + slices.SortFunc(allowedIPs, netip.Prefix.Compare) capMap := tailcfg.NodeCapMap{ tailcfg.CapabilityAdmin: []tailcfg.RawMessage{}, From 094faf7a6aae19375a46a6044a1f240f3d319b2d Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 14:22:46 +0000 Subject: [PATCH 04/30] all: modernize sorting with slices package Replace deprecated sort package functions with their modern slices package equivalents: - sort.Slice -> slices.SortFunc - sort.SliceStable -> slices.SortStableFunc - sort.Sort -> slices.Sort - sort.Strings -> slices.Sort Also removes the now-unused sort.Interface implementation (Len, Less, Swap methods) from types.NodeIDs since slices.Sort works directly with ordered types. --- cmd/hi/stats.go | 7 +-- hscontrol/db/node.go | 7 ++- hscontrol/db/preauth_keys_test.go | 5 +- hscontrol/mapper/builder.go | 7 +-- hscontrol/policy/v2/policy.go | 9 +--- hscontrol/types/change/change.go | 8 ++-- hscontrol/types/change/change_test.go | 4 +- integration/auth_oidc_test.go | 67 +++++++++++++-------------- integration/helpers.go | 21 ++++----- integration/hsic/hsic.go | 6 +-- integration/route_test.go | 27 +++++------ integration/tags_test.go | 41 ++++++++-------- 12 files changed, 98 insertions(+), 111 deletions(-) diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index b68215a6..c1bb9cfe 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -1,12 +1,13 @@ package main import ( + "cmp" "context" "encoding/json" "errors" "fmt" "log" - "sort" + "slices" "strings" "sync" "time" @@ -371,8 +372,8 @@ func (sc *StatsCollector) GetSummary() []ContainerStatsSummary { } // Sort by container name for consistent output - sort.Slice(summaries, func(i, j int) bool { - return summaries[i].ContainerName < summaries[j].ContainerName + slices.SortFunc(summaries, func(a, b ContainerStatsSummary) int { + return cmp.Compare(a.ContainerName, b.ContainerName) }) return summaries diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index bf407bb4..3887350b 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -1,13 +1,13 @@ package db import ( + "cmp" "encoding/json" "errors" "fmt" "net/netip" "regexp" "slices" - "sort" "strconv" "strings" "sync" @@ -20,7 +20,6 @@ import ( "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) const ( @@ -60,7 +59,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types return types.Nodes{}, err } - sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) + slices.SortFunc(nodes, func(a, b *types.Node) int { return cmp.Compare(a.ID, b.ID) }) return nodes, nil } @@ -668,7 +667,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, 
hostname ...string) Hostname: nodeName, UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: new(pak.ID), } err = hsdb.DB.Save(node).Error diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 7c5dcbd7..2f28d449 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -11,7 +11,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "tailscale.com/types/ptr" ) func TestCreatePreAuthKey(t *testing.T) { @@ -24,7 +23,7 @@ func TestCreatePreAuthKey(t *testing.T) { test: func(t *testing.T, db *HSDatabase) { t.Helper() - _, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil) + _, err := db.CreatePreAuthKey(new(types.UserID(12345)), true, false, nil, nil) assert.Error(t, err) }, }, @@ -127,7 +126,7 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) { Hostname: "testest", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(key.ID), + AuthKeyID: new(key.ID), } db.DB.Save(&node) diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index c666ff24..b6f0b534 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -1,9 +1,10 @@ package mapper import ( + "cmp" "errors" "net/netip" - "sort" + "slices" "time" "github.com/juanfont/headscale/hscontrol/policy" @@ -261,8 +262,8 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ( } // Peers is always returned sorted by Node.ID. - sort.SliceStable(tailPeers, func(x, y int) bool { - return tailPeers[x].ID < tailPeers[y].ID + slices.SortStableFunc(tailPeers, func(a, b *tailcfg.Node) int { + return cmp.Compare(a.ID, b.ID) }) return tailPeers, nil diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 54196e6b..042c2723 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -956,14 +956,7 @@ func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types. // It will return a Owners list where all the Tag types have been resolved to their underlying Owners. func flattenTags(tagOwners TagOwners, tag Tag, visiting map[Tag]bool, chain []Tag) (Owners, error) { if visiting[tag] { - cycleStart := 0 - - for i, t := range chain { - if t == tag { - cycleStart = i - break - } - } + cycleStart := slices.Index(chain, tag) cycleTags := make([]string, len(chain[cycleStart:])) for i, t := range chain[cycleStart:] { diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index a76fb7c4..6913d7d9 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -333,7 +333,7 @@ func NodeOnline(nodeID types.NodeID) Change { PeerPatches: []*tailcfg.PeerChange{ { NodeID: nodeID.NodeID(), - Online: ptrTo(true), + Online: new(true), }, }, } @@ -346,7 +346,7 @@ func NodeOffline(nodeID types.NodeID) Change { PeerPatches: []*tailcfg.PeerChange{ { NodeID: nodeID.NodeID(), - Online: ptrTo(false), + Online: new(false), }, }, } @@ -366,8 +366,10 @@ func KeyExpiry(nodeID types.NodeID, expiry *time.Time) Change { } // ptrTo returns a pointer to the given value. 
+// +//go:fix inline func ptrTo[T any](v T) *T { - return &v + return new(v) } // High-level change constructors diff --git a/hscontrol/types/change/change_test.go b/hscontrol/types/change/change_test.go index 9f181dd6..dc2dd0af 100644 --- a/hscontrol/types/change/change_test.go +++ b/hscontrol/types/change/change_test.go @@ -16,8 +16,8 @@ func TestChange_FieldSync(t *testing.T) { typ := reflect.TypeFor[Change]() boolCount := 0 - for i := range typ.NumField() { - if typ.Field(i).Type.Kind() == reflect.Bool { + for field := range typ.Fields() { + if field.Type.Kind() == reflect.Bool { boolCount++ } } diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 359dd456..c1d066f8 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -1,15 +1,16 @@ package integration import ( + "cmp" "maps" "net/netip" "net/url" - "sort" + "slices" "strconv" "testing" "time" - "github.com/google/go-cmp/cmp" + gocmp "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" @@ -111,11 +112,11 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { t.Fatalf("unexpected users: %s", diff) } } @@ -388,11 +389,11 @@ func TestOIDC024UserCreation(t *testing.T) { listUsers, err := headscale.ListUsers() require.NoError(t, err) - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { t.Errorf("unexpected users: %s", diff) } }) @@ -517,11 +518,11 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("User validation failed after first login - unexpected users: %s", diff) } }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") @@ -599,11 +600,11 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff 
:= gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("User validation failed after user2 login - expected both user1 and user2: %s", diff) } }, 30*time.Second, 1*time.Second, "validating both user1 and user2 exist after second OIDC login") @@ -763,11 +764,11 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("Final user validation failed - both users should persist after relogin cycle: %s", diff) } }, 30*time.Second, 1*time.Second, "validating user persistence after complete relogin cycle (user1->user2->user1)") @@ -935,13 +936,11 @@ func TestOIDCFollowUpUrl(t *testing.T) { }, } - sort.Slice( - listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() - }, - ) + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) + }) - if diff := cmp.Diff( + if diff := gocmp.Diff( wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), @@ -1046,13 +1045,11 @@ func TestOIDCMultipleOpenedLoginUrls(t *testing.T) { }, } - sort.Slice( - listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() - }, - ) + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) + }) - if diff := cmp.Diff( + if diff := gocmp.Diff( wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), @@ -1155,11 +1152,11 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("User validation failed after first login - unexpected users: %s", diff) } }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") @@ -1249,11 +1246,11 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("Final user validation failed - user1 should persist after same-user relogin: %s", diff) } }, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin cycle") diff --git a/integration/helpers.go b/integration/helpers.go index 7d40c8e6..5acf4729 100644 --- a/integration/helpers.go +++ 
b/integration/helpers.go @@ -26,7 +26,6 @@ import ( "golang.org/x/exp/maps" "golang.org/x/exp/slices" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) const ( @@ -839,32 +838,32 @@ func wildcard() policyv2.Alias { // usernamep returns a pointer to a Username as an Alias for policy v2 configurations. // Used in ACL rules to reference specific users in network access policies. func usernamep(name string) policyv2.Alias { - return ptr.To(policyv2.Username(name)) + return new(policyv2.Username(name)) } // hostp returns a pointer to a Host as an Alias for policy v2 configurations. // Used in ACL rules to reference specific hosts in network access policies. func hostp(name string) policyv2.Alias { - return ptr.To(policyv2.Host(name)) + return new(policyv2.Host(name)) } // groupp returns a pointer to a Group as an Alias for policy v2 configurations. // Used in ACL rules to reference user groups in network access policies. func groupp(name string) policyv2.Alias { - return ptr.To(policyv2.Group(name)) + return new(policyv2.Group(name)) } // tagp returns a pointer to a Tag as an Alias for policy v2 configurations. // Used in ACL rules to reference node tags in network access policies. func tagp(name string) policyv2.Alias { - return ptr.To(policyv2.Tag(name)) + return new(policyv2.Tag(name)) } // prefixp returns a pointer to a Prefix from a CIDR string for policy v2 configurations. // Converts CIDR notation to policy prefix format for network range specifications. func prefixp(cidr string) policyv2.Alias { prefix := netip.MustParsePrefix(cidr) - return ptr.To(policyv2.Prefix(prefix)) + return new(policyv2.Prefix(prefix)) } // aliasWithPorts creates an AliasWithPorts structure from an alias and port ranges. @@ -880,31 +879,31 @@ func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.A // usernameOwner returns a Username as an Owner for use in TagOwners policies. // Specifies which users can assign and manage specific tags in ACL configurations. func usernameOwner(name string) policyv2.Owner { - return ptr.To(policyv2.Username(name)) + return new(policyv2.Username(name)) } // groupOwner returns a Group as an Owner for use in TagOwners policies. // Specifies which groups can assign and manage specific tags in ACL configurations. func groupOwner(name string) policyv2.Owner { - return ptr.To(policyv2.Group(name)) + return new(policyv2.Group(name)) } // usernameApprover returns a Username as an AutoApprover for subnet route policies. // Specifies which users can automatically approve subnet route advertisements. func usernameApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Username(name)) + return new(policyv2.Username(name)) } // groupApprover returns a Group as an AutoApprover for subnet route policies. // Specifies which groups can automatically approve subnet route advertisements. func groupApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Group(name)) + return new(policyv2.Group(name)) } // tagApprover returns a Tag as an AutoApprover for subnet route policies. // Specifies which tagged nodes can automatically approve subnet route advertisements. func tagApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Tag(name)) + return new(policyv2.Tag(name)) } // oidcMockUser creates a MockUser for OIDC authentication testing. 
diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 42bb8e93..202f2014 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -16,7 +16,7 @@ import ( "os" "path" "path/filepath" - "sort" + "slices" "strconv" "strings" "time" @@ -1232,8 +1232,8 @@ func (t *HeadscaleInContainer) ListNodes( } } - sort.Slice(ret, func(i, j int) bool { - return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1 + slices.SortFunc(ret, func(a, b *v1.Node) int { + return cmp.Compare(a.GetId(), b.GetId()) }) return ret, nil diff --git a/integration/route_test.go b/integration/route_test.go index 0460b5ef..6d0a1be2 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -7,7 +7,6 @@ import ( "maps" "net/netip" "slices" - "sort" "strconv" "strings" "testing" @@ -287,11 +286,10 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("webservice: %s, %s", webip.String(), weburl) // Sort nodes by ID - sort.SliceStable(allClients, func(i, j int) bool { - statusI := allClients[i].MustStatus() - statusJ := allClients[j].MustStatus() - - return statusI.Self.ID < statusJ.Self.ID + slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { + statusA := a.MustStatus() + statusB := b.MustStatus() + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) // This is ok because the scenario makes users in order, so the three first @@ -1359,10 +1357,10 @@ func TestSubnetRouteACL(t *testing.T) { } // Sort nodes by ID - sort.SliceStable(allClients, func(i, j int) bool { - statusI := allClients[i].MustStatus() - statusJ := allClients[j].MustStatus() - return statusI.Self.ID < statusJ.Self.ID + slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { + statusA := a.MustStatus() + statusB := b.MustStatus() + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) subRouter1 := allClients[0] @@ -2403,11 +2401,10 @@ func TestAutoApproveMultiNetwork(t *testing.T) { t.Logf("webservice: %s, %s", webip.String(), weburl) // Sort nodes by ID - sort.SliceStable(allClients, func(i, j int) bool { - statusI := allClients[i].MustStatus() - statusJ := allClients[j].MustStatus() - - return statusI.Self.ID < statusJ.Self.ID + slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { + statusA := a.MustStatus() + statusB := b.MustStatus() + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) // This is ok because the scenario makes users in order, so the three first diff --git a/integration/tags_test.go b/integration/tags_test.go index 5dad36e5..91c771c4 100644 --- a/integration/tags_test.go +++ b/integration/tags_test.go @@ -1,7 +1,7 @@ package integration import ( - "sort" + "slices" "testing" "time" @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) const tagTestUser = "taguser" @@ -30,9 +29,9 @@ const tagTestUser = "taguser" func tagsTestPolicy() *policyv2.Policy { return &policyv2.Policy{ TagOwners: policyv2.TagOwners{ - "tag:valid-owned": policyv2.Owners{ptr.To(policyv2.Username(tagTestUser + "@"))}, - "tag:second": policyv2.Owners{ptr.To(policyv2.Username(tagTestUser + "@"))}, - "tag:valid-unowned": policyv2.Owners{ptr.To(policyv2.Username("other-user@"))}, + "tag:valid-owned": policyv2.Owners{new(policyv2.Username(tagTestUser + "@"))}, + "tag:second": policyv2.Owners{new(policyv2.Username(tagTestUser + "@"))}, + "tag:valid-unowned": policyv2.Owners{new(policyv2.Username("other-user@"))}, // Note: tag:nonexistent deliberately NOT defined }, ACLs: 
[]policyv2.ACL{ @@ -51,11 +50,11 @@ func tagsEqual(actual, expected []string) bool { return false } - sortedActual := append([]string{}, actual...) - sortedExpected := append([]string{}, expected...) + sortedActual := slices.Clone(actual) + sortedExpected := slices.Clone(expected) - sort.Strings(sortedActual) - sort.Strings(sortedExpected) + slices.Sort(sortedActual) + slices.Sort(sortedExpected) for i := range sortedActual { if sortedActual[i] != sortedExpected[i] { @@ -69,11 +68,11 @@ func tagsEqual(actual, expected []string) bool { // assertNodeHasTagsWithCollect asserts that a node has exactly the expected tags (order-independent). func assertNodeHasTagsWithCollect(c *assert.CollectT, node *v1.Node, expectedTags []string) { actualTags := node.GetTags() - sortedActual := append([]string{}, actualTags...) - sortedExpected := append([]string{}, expectedTags...) + sortedActual := slices.Clone(actualTags) + sortedExpected := slices.Clone(expectedTags) - sort.Strings(sortedActual) - sort.Strings(sortedExpected) + slices.Sort(sortedActual) + slices.Sort(sortedExpected) assert.Equal(c, sortedExpected, sortedActual, "Node %s tags mismatch", node.GetName()) } @@ -102,11 +101,11 @@ func assertNodeSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClient } } - sortedActual := append([]string{}, actualTagsSlice...) - sortedExpected := append([]string{}, expectedTags...) + sortedActual := slices.Clone(actualTagsSlice) + sortedExpected := slices.Clone(expectedTags) - sort.Strings(sortedActual) - sort.Strings(sortedExpected) + slices.Sort(sortedActual) + slices.Sort(sortedExpected) assert.Equal(c, sortedExpected, sortedActual, "Client %s self tags mismatch", client.Hostname()) } @@ -2507,11 +2506,11 @@ func assertNetmapSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClie } } - sortedActual := append([]string{}, actualTagsSlice...) - sortedExpected := append([]string{}, expectedTags...) + sortedActual := slices.Clone(actualTagsSlice) + sortedExpected := slices.Clone(expectedTags) - sort.Strings(sortedActual) - sort.Strings(sortedExpected) + slices.Sort(sortedActual) + slices.Sort(sortedExpected) assert.Equal(c, sortedExpected, sortedActual, "Client %s netmap self tags mismatch", client.Hostname()) } From 3675b6550401c52e48235cbdf24e7d8f3da02a45 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 14:23:21 +0000 Subject: [PATCH 05/30] all: use new() builtin and slices utilities Modernize code to use Go 1.26 features: 1. Replace ptr.To with the new() builtin: Before: ptr.To(value) After: new(value) 2. Replace append clone pattern with slices.Clone: Before: copy := append([]T{}, slice...) After: copy := slices.Clone(slice) 3. Replace manual contains loops with slices.Contains: Before: for _, v := range slice { if v == target { return true } } After: slices.Contains(slice, target) The ptr.To function from tailscale.com/types/ptr is no longer needed as Go 1.26's enhanced new() builtin accepts a value argument and returns a pointer to a copy of that value. Note: Auto-generated files (types_clone.go) are not modified as they are generated by tailscale.com/cmd/cloner. 
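For illustration, a minimal self-contained sketch combining the three
patterns (this assumes the Go 1.26 value-argument form of new(); the
demoNode type and the sample values below are illustrative only and are
not part of this diff):

    package main

    import (
        "fmt"
        "net/netip"
        "slices"
    )

    type demoNode struct {
        Hostname string
        IPv4     *netip.Addr // pointer field previously populated with ptr.To
    }

    func main() {
        addr := netip.MustParseAddr("100.64.0.1")

        // new(value) returns a pointer to a copy of the value (Go 1.26),
        // replacing ptr.To(addr).
        n := demoNode{Hostname: "mini", IPv4: new(addr)}

        // slices.Clone replaces the append([]T{}, s...) copy idiom.
        tags := []string{"tag:b", "tag:a"}
        sorted := slices.Clone(tags)
        slices.Sort(sorted)

        // slices.Contains replaces a hand-written membership loop.
        fmt.Println(n.Hostname, sorted, slices.Contains(sorted, "tag:a"))
    }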
--- hscontrol/db/ip_test.go | 11 +- hscontrol/db/node_test.go | 17 +- hscontrol/db/text_serialiser.go | 8 +- hscontrol/db/users_test.go | 3 +- hscontrol/mapper/batcher_lockfree.go | 3 +- hscontrol/mapper/batcher_test.go | 5 +- hscontrol/mapper/mapper_test.go | 3 +- hscontrol/policy/policy_test.go | 13 +- hscontrol/policy/policyutil/reduce_test.go | 47 ++- hscontrol/policy/route_approval_test.go | 17 +- hscontrol/policy/v2/filter_test.go | 115 ++++---- hscontrol/policy/v2/policy_test.go | 33 +-- hscontrol/policy/v2/types_test.go | 321 ++++++++++----------- hscontrol/state/ephemeral_test.go | 15 +- hscontrol/state/maprequest_test.go | 3 +- hscontrol/state/node_store_test.go | 5 +- hscontrol/types/node_tags_test.go | 13 +- hscontrol/types/users.go | 2 +- integration/acl_test.go | 17 +- integration/api_auth_test.go | 6 +- integration/auth_key_test.go | 3 +- integration/ssh_test.go | 5 +- 22 files changed, 324 insertions(+), 341 deletions(-) diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index 7ba335e8..73895876 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/net/tsaddr" - "tailscale.com/types/ptr" ) var mpp = func(pref string) *netip.Prefix { @@ -488,8 +487,8 @@ func TestIPAllocatorNextNoReservedIPs(t *testing.T) { alloc, err := NewIPAllocator( db, - ptr.To(tsaddr.CGNATRange()), - ptr.To(tsaddr.TailscaleULARange()), + new(tsaddr.CGNATRange()), + new(tsaddr.TailscaleULARange()), types.IPAllocationStrategySequential, ) if err != nil { @@ -497,17 +496,17 @@ func TestIPAllocatorNextNoReservedIPs(t *testing.T) { } // Validate that we do not give out 100.100.100.100 - nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange())) + nextQuad100, err := alloc.next(na("100.100.100.99"), new(tsaddr.CGNATRange())) require.NoError(t, err) assert.Equal(t, na("100.100.100.101"), *nextQuad100) // Validate that we do not give out fd7a:115c:a1e0::53 - nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange())) + nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), new(tsaddr.TailscaleULARange())) require.NoError(t, err) assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6) // Validate that we do not give out fd7a:115c:a1e0::53 - nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange())) + nextChrome, err := alloc.next(na("100.115.91.255"), new(tsaddr.CGNATRange())) t.Logf("chrome: %s", nextChrome.String()) require.NoError(t, err) assert.Equal(t, na("100.115.94.0"), *nextChrome) diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 7e00f9ca..e82cdb62 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -22,7 +22,6 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestGetNode(t *testing.T) { @@ -115,7 +114,7 @@ func TestExpireNode(t *testing.T) { Hostname: "testnode", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: new(pak.ID), Expiry: &time.Time{}, } db.DB.Save(node) @@ -159,7 +158,7 @@ func TestSetTags(t *testing.T) { Hostname: "testnode", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: new(pak.ID), } trx := db.DB.Save(node) @@ -443,7 +442,7 @@ func TestAutoApproveRoutes(t *testing.T) { Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.routes, }, - IPv4: 
ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), } err = adb.DB.Save(&node).Error @@ -460,7 +459,7 @@ func TestAutoApproveRoutes(t *testing.T) { RoutableIPs: tt.routes, }, Tags: []string{"tag:exit"}, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + IPv4: new(netip.MustParseAddr("100.64.0.2")), } err = adb.DB.Save(&nodeTagged).Error @@ -649,7 +648,7 @@ func TestListEphemeralNodes(t *testing.T) { Hostname: "test", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: new(pak.ID), } nodeEph := types.Node{ @@ -659,7 +658,7 @@ func TestListEphemeralNodes(t *testing.T) { Hostname: "ephemeral", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pakEph.ID), + AuthKeyID: new(pakEph.ID), } err = db.DB.Save(&node).Error @@ -750,8 +749,8 @@ func TestNodeNaming(t *testing.T) { if err != nil { return err } - _, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil) - _, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil) + _, err = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil) + _, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil) return err }) require.NoError(t, err) diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go index 6172e7e0..46bd154f 100644 --- a/hscontrol/db/text_serialiser.go +++ b/hscontrol/db/text_serialiser.go @@ -17,7 +17,7 @@ func isTextUnmarshaler(rv reflect.Value) bool { } func maybeInstantiatePtr(rv reflect.Value) { - if rv.Kind() == reflect.Ptr && rv.IsNil() { + if rv.Kind() == reflect.Pointer && rv.IsNil() { np := reflect.New(rv.Type().Elem()) rv.Set(np) } @@ -36,7 +36,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect // If the field is a pointer, we need to dereference it to get the actual type // so we do not end with a second pointer. - if fieldValue.Elem().Kind() == reflect.Ptr { + if fieldValue.Elem().Kind() == reflect.Pointer { fieldValue = fieldValue.Elem() } @@ -65,7 +65,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect // If it is not a pointer, we need to assign the value to the // field. 
dstField := field.ReflectValueOf(ctx, dst) - if dstField.Kind() == reflect.Ptr { + if dstField.Kind() == reflect.Pointer { dstField.Set(fieldValue) } else { dstField.Set(fieldValue.Elem()) @@ -86,7 +86,7 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec // If the value is nil, we return nil, however, go nil values are not // always comparable, particularly when reflection is involved: // https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8 - if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) { + if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) { return nil, nil } b, err := v.MarshalText() diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index a3fd49b3..bbb8e4d4 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" - "tailscale.com/types/ptr" ) func TestCreateAndDestroyUser(t *testing.T) { @@ -79,7 +78,7 @@ func TestDestroyUserErrors(t *testing.T) { Hostname: "testnode", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: new(pak.ID), } trx := db.DB.Save(&node) require.NoError(t, trx.Error) diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index e00512b6..1d9c2c32 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -13,7 +13,6 @@ import ( "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) var errConnectionClosed = errors.New("connection channel already closed") @@ -136,7 +135,7 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo // No active connections - keep the node entry alive for rapid reconnections // The node will get a fresh full map when it reconnects log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection") - b.connected.Store(id, ptr.To(time.Now())) + b.connected.Store(id, new(time.Now())) return false } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 70d5e377..00053892 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/netip" "runtime" + "slices" "strings" "sync" "sync/atomic" @@ -327,7 +328,7 @@ func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats { // Return a copy to avoid race conditions return UpdateStats{ TotalUpdates: stats.TotalUpdates, - UpdateSizes: append([]int{}, stats.UpdateSizes...), + UpdateSizes: slices.Clone(stats.UpdateSizes), LastUpdate: stats.LastUpdate, } } @@ -344,7 +345,7 @@ func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats { for nodeID, stats := range ut.stats { result[nodeID] = UpdateStats{ TotalUpdates: stats.TotalUpdates, - UpdateSizes: append([]int{}, stats.UpdateSizes...), + UpdateSizes: slices.Clone(stats.UpdateSizes), LastUpdate: stats.LastUpdate, } } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 1bafd135..a503c08c 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -14,7 +14,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" - "tailscale.com/types/ptr" ) var iap = 
func(ipStr string) *netip.Addr { @@ -51,7 +50,7 @@ func TestDNSConfigMapResponse(t *testing.T) { mach := func(hostname, username string, userid uint) *types.Node { return &types.Node{ Hostname: hostname, - UserID: ptr.To(userid), + UserID: new(userid), User: &types.User{ Name: username, }, diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 87142dd9..eb3d85b6 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) var ap = func(ipStr string) *netip.Addr { @@ -1074,21 +1073,21 @@ func TestSSHPolicyRules(t *testing.T) { nodeUser1 := types.Node{ Hostname: "user1-device", IPv4: ap("100.64.0.1"), - UserID: ptr.To(uint(1)), - User: ptr.To(users[0]), + UserID: new(uint(1)), + User: new(users[0]), } nodeUser2 := types.Node{ Hostname: "user2-device", IPv4: ap("100.64.0.2"), - UserID: ptr.To(uint(2)), - User: ptr.To(users[1]), + UserID: new(uint(2)), + User: new(users[1]), } taggedClient := types.Node{ Hostname: "tagged-client", IPv4: ap("100.64.0.4"), - UserID: ptr.To(uint(2)), - User: ptr.To(users[1]), + UserID: new(uint(2)), + User: new(users[1]), Tags: []string{"tag:client"}, } diff --git a/hscontrol/policy/policyutil/reduce_test.go b/hscontrol/policy/policyutil/reduce_test.go index 35f5b472..bd975d23 100644 --- a/hscontrol/policy/policyutil/reduce_test.go +++ b/hscontrol/policy/policyutil/reduce_test.go @@ -16,7 +16,6 @@ import ( "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" "tailscale.com/util/must" ) @@ -144,13 +143,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: ptr.To(users[0]), + User: new(users[0]), }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: ptr.To(users[0]), + User: new(users[0]), }, }, want: []tailcfg.FilterRule{}, @@ -191,7 +190,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -202,7 +201,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -283,19 +282,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, // "internal" exit node &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -344,7 +343,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -353,12 +352,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, 
&types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -453,7 +452,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -462,12 +461,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -565,7 +564,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, @@ -574,12 +573,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -655,7 +654,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, @@ -664,12 +663,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -737,7 +736,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -747,7 +746,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -804,13 +803,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[3]), + User: new(users[3]), }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, }, diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 39b15cee..5aa5e28c 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" - "tailscale.com/types/ptr" ) func TestNodeCanApproveRoute(t *testing.T) { @@ -25,24 +24,24 @@ func TestNodeCanApproveRoute(t *testing.T) { ID: 1, Hostname: "user1-device", IPv4: ap("100.64.0.1"), - 
UserID: ptr.To(uint(1)), - User: ptr.To(users[0]), + UserID: new(uint(1)), + User: new(users[0]), } exitNode := types.Node{ ID: 2, Hostname: "user2-device", IPv4: ap("100.64.0.2"), - UserID: ptr.To(uint(2)), - User: ptr.To(users[1]), + UserID: new(uint(2)), + User: new(users[1]), } taggedNode := types.Node{ ID: 3, Hostname: "tagged-server", IPv4: ap("100.64.0.3"), - UserID: ptr.To(uint(3)), - User: ptr.To(users[2]), + UserID: new(uint(3)), + User: new(users[2]), Tags: []string{"tag:router"}, } @@ -50,8 +49,8 @@ func TestNodeCanApproveRoute(t *testing.T) { ID: 4, Hostname: "multi-tag-node", IPv4: ap("100.64.0.4"), - UserID: ptr.To(uint(2)), - User: ptr.To(users[1]), + UserID: new(uint(2)), + User: new(users[1]), Tags: []string{"tag:router", "tag:server"}, } diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index 46f544c9..d798b5f7 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) // aliasWithPorts creates an AliasWithPorts structure from an alias and ports. @@ -410,14 +409,14 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { nodeUser1 := types.Node{ Hostname: "user1-device", IPv4: createAddr("100.64.0.1"), - UserID: ptr.To(users[0].ID), - User: ptr.To(users[0]), + UserID: new(users[0].ID), + User: new(users[0]), } nodeUser2 := types.Node{ Hostname: "user2-device", IPv4: createAddr("100.64.0.2"), - UserID: ptr.To(users[1].ID), - User: ptr.To(users[1]), + UserID: new(users[1].ID), + User: new(users[1]), } nodes := types.Nodes{&nodeUser1, &nodeUser2} @@ -622,14 +621,14 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { nodeUser1 := types.Node{ Hostname: "user1-device", IPv4: createAddr("100.64.0.1"), - UserID: ptr.To(users[0].ID), - User: ptr.To(users[0]), + UserID: new(users[0].ID), + User: new(users[0]), } nodeUser2 := types.Node{ Hostname: "user2-device", IPv4: createAddr("100.64.0.2"), - UserID: ptr.To(users[1].ID), - User: ptr.To(users[1]), + UserID: new(users[1].ID), + User: new(users[1]), } nodes := types.Nodes{&nodeUser1, &nodeUser2} @@ -683,15 +682,15 @@ func TestSSHIntegrationReproduction(t *testing.T) { node1 := &types.Node{ Hostname: "user1-node", IPv4: createAddr("100.64.0.1"), - UserID: ptr.To(users[0].ID), - User: ptr.To(users[0]), + UserID: new(users[0].ID), + User: new(users[0]), } node2 := &types.Node{ Hostname: "user2-node", IPv4: createAddr("100.64.0.2"), - UserID: ptr.To(users[1].ID), - User: ptr.To(users[1]), + UserID: new(users[1].ID), + User: new(users[1]), } nodes := types.Nodes{node1, node2} @@ -806,19 +805,19 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { nodes := types.Nodes{ { - User: ptr.To(users[0]), + User: new(users[0]), IPv4: ap("100.64.0.1"), }, { - User: ptr.To(users[0]), + User: new(users[0]), IPv4: ap("100.64.0.2"), }, { - User: ptr.To(users[1]), + User: new(users[1]), IPv4: ap("100.64.0.3"), }, { - User: ptr.To(users[1]), + User: new(users[1]), IPv4: ap("100.64.0.4"), }, // Tagged device for user1 @@ -938,11 +937,11 @@ func TestTagUserMutualExclusivity(t *testing.T) { nodes := types.Nodes{ // User-owned nodes { - User: ptr.To(users[0]), + User: new(users[0]), IPv4: ap("100.64.0.1"), }, { - User: ptr.To(users[1]), + User: new(users[1]), IPv4: ap("100.64.0.2"), }, // Tagged nodes @@ -960,8 +959,8 @@ func TestTagUserMutualExclusivity(t *testing.T) { policy := &Policy{ TagOwners: TagOwners{ - Tag("tag:server"): 
Owners{ptr.To(Username("user1@"))}, - Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:server"): Owners{new(Username("user1@"))}, + Tag("tag:database"): Owners{new(Username("user2@"))}, }, ACLs: []ACL{ // Rule 1: user1 (user-owned) should NOT be able to reach tagged nodes @@ -1056,11 +1055,11 @@ func TestAutogroupTagged(t *testing.T) { nodes := types.Nodes{ // User-owned nodes (not tagged) { - User: ptr.To(users[0]), + User: new(users[0]), IPv4: ap("100.64.0.1"), }, { - User: ptr.To(users[1]), + User: new(users[1]), IPv4: ap("100.64.0.2"), }, // Tagged nodes @@ -1083,10 +1082,10 @@ func TestAutogroupTagged(t *testing.T) { policy := &Policy{ TagOwners: TagOwners{ - Tag("tag:server"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, - Tag("tag:web"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:server"): Owners{new(Username("user1@"))}, + Tag("tag:database"): Owners{new(Username("user2@"))}, + Tag("tag:web"): Owners{new(Username("user1@"))}, + Tag("tag:prod"): Owners{new(Username("user1@"))}, }, ACLs: []ACL{ // Rule: autogroup:tagged can reach user-owned nodes @@ -1206,10 +1205,10 @@ func TestAutogroupSelfWithSpecificUserSource(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, + {User: new(users[0]), IPv4: ap("100.64.0.1")}, + {User: new(users[0]), IPv4: ap("100.64.0.2")}, + {User: new(users[1]), IPv4: ap("100.64.0.3")}, + {User: new(users[1]), IPv4: ap("100.64.0.4")}, } policy := &Policy{ @@ -1273,11 +1272,11 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, - {User: ptr.To(users[2]), IPv4: ap("100.64.0.5")}, + {User: new(users[0]), IPv4: ap("100.64.0.1")}, + {User: new(users[0]), IPv4: ap("100.64.0.2")}, + {User: new(users[1]), IPv4: ap("100.64.0.3")}, + {User: new(users[1]), IPv4: ap("100.64.0.4")}, + {User: new(users[2]), IPv4: ap("100.64.0.5")}, } policy := &Policy{ @@ -1342,13 +1341,13 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { nodes := types.Nodes{ // User1's nodes - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"}, + {User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"}, + {User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"}, // User2's nodes - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"}, + {User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"}, + {User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"}, // Tagged node for user1 (should be excluded) - {User: ptr.To(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}}, + {User: new(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}}, } policy := &Policy{ @@ -1420,10 +1419,10 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, - {User: 
ptr.To(users[0]), IPv4: ap("100.64.0.2")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, + {User: new(users[0]), IPv4: ap("100.64.0.1")}, + {User: new(users[0]), IPv4: ap("100.64.0.2")}, + {User: new(users[1]), IPv4: ap("100.64.0.3")}, + {User: new(users[1]), IPv4: ap("100.64.0.4")}, } policy := &Policy{ @@ -1474,11 +1473,11 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, - {User: ptr.To(users[2]), IPv4: ap("100.64.0.5")}, + {User: new(users[0]), IPv4: ap("100.64.0.1")}, + {User: new(users[0]), IPv4: ap("100.64.0.2")}, + {User: new(users[1]), IPv4: ap("100.64.0.3")}, + {User: new(users[1]), IPv4: ap("100.64.0.4")}, + {User: new(users[2]), IPv4: ap("100.64.0.5")}, } policy := &Policy{ @@ -1531,10 +1530,10 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}}, + {User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"}, + {User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"}, + {User: new(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}}, + {User: new(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}}, } policy := &Policy{ @@ -1591,10 +1590,10 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}}, + {User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"}, + {User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"}, + {User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"}, + {User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}}, } policy := &Policy{ diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 26b0d141..80c08eed 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -11,7 +11,6 @@ import ( "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { @@ -20,8 +19,8 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) Hostname: name, IPv4: ap(ipv4), IPv6: ap(ipv6), - User: ptr.To(user), - UserID: ptr.To(user.ID), + User: new(user), + UserID: new(user.ID), Hostinfo: hostinfo, } } @@ -457,8 +456,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) { Hostname: "test-1-device", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), 
+ UserID: new(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -468,8 +467,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) { Hostname: "test-2-router", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), Tags: []string{"tag:node-router"}, Hostinfo: &tailcfg.Hostinfo{}, } @@ -537,8 +536,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) { Hostname: "test-1-device", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -547,8 +546,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) { Hostname: "test-2-device", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -647,8 +646,8 @@ func TestTagPropagationToPeerMap(t *testing.T) { Hostname: "user1-node", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Tags: []string{"tag:web", "tag:internal"}, } @@ -658,8 +657,8 @@ func TestTagPropagationToPeerMap(t *testing.T) { Hostname: "user2-node", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), } initialNodes := types.Nodes{user1Node, user2Node} @@ -686,8 +685,8 @@ func TestTagPropagationToPeerMap(t *testing.T) { Hostname: "user1-node", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Tags: []string{"tag:internal"}, // tag:web removed! 
} diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 664f76b7..8f4f7a85 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -19,7 +19,6 @@ import ( "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) // TestUnmarshalPolicy tests the unmarshalling of JSON into Policy objects and the marshalling @@ -53,11 +52,11 @@ func TestMarshalJSON(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - ptr.To(Username("user@example.com")), + new(Username("user@example.com")), }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Username("other@example.com")), + Alias: new(Username("other@example.com")), Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, @@ -253,11 +252,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - ptr.To(Username("testuser@headscale.net")), + new(Username("testuser@headscale.net")), }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Username("otheruser@headscale.net")), + Alias: new(Username("otheruser@headscale.net")), Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, @@ -546,7 +545,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(AutoGroup("autogroup:internet")), + Alias: new(AutoGroup("autogroup:internet")), Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, }, }, @@ -682,7 +681,7 @@ func TestUnmarshalPolicy(t *testing.T) { `, want: &Policy{ TagOwners: TagOwners{ - Tag("tag:web"): Owners{ptr.To(Username("admin@example.com"))}, + Tag("tag:web"): Owners{new(Username("admin@example.com"))}, }, SSHs: []SSH{ { @@ -691,7 +690,7 @@ func TestUnmarshalPolicy(t *testing.T) { tp("tag:web"), }, Destinations: SSHDstAliases{ - ptr.To(Username("admin@example.com")), + new(Username("admin@example.com")), }, Users: []SSHUser{ SSHUser("*"), @@ -733,7 +732,7 @@ func TestUnmarshalPolicy(t *testing.T) { gp("group:admins"), }, Destinations: SSHDstAliases{ - ptr.To(Username("admin@example.com")), + new(Username("admin@example.com")), }, Users: []SSHUser{ SSHUser("root"), @@ -1154,7 +1153,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(AutoGroup("autogroup:internet")), + Alias: new(AutoGroup("autogroup:internet")), Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, }, }, @@ -1491,7 +1490,7 @@ func TestUnmarshalPolicy(t *testing.T) { want: &Policy{ TagOwners: TagOwners{ Tag("tag:bigbrother"): {}, - Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:smallbrother"): {new(Tag("tag:bigbrother"))}, }, ACLs: []ACL{ { @@ -1502,7 +1501,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Tag("tag:smallbrother")), + Alias: new(Tag("tag:smallbrother")), Ports: []tailcfg.PortRange{{First: 9000, Last: 9000}}, }, }, @@ -1583,14 +1582,14 @@ func TestUnmarshalPolicy(t *testing.T) { } } -func gp(s string) *Group { return ptr.To(Group(s)) } -func up(s string) *Username { return ptr.To(Username(s)) } -func hp(s string) *Host { return ptr.To(Host(s)) } -func tp(s string) *Tag { return ptr.To(Tag(s)) } -func agp(s string) *AutoGroup { return ptr.To(AutoGroup(s)) } +func gp(s string) *Group { return new(Group(s)) } +func up(s string) *Username { return new(Username(s)) } +func hp(s string) *Host { return new(Host(s)) } +func tp(s string) *Tag { return new(Tag(s)) } +func agp(s string) *AutoGroup { return new(AutoGroup(s)) } func mp(pref string) netip.Prefix { return 
netip.MustParsePrefix(pref) } -func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) } -func pp(pref string) *Prefix { return ptr.To(Prefix(mp(pref))) } +func ap(addr string) *netip.Addr { return new(netip.MustParseAddr(addr)) } +func pp(pref string) *Prefix { return new(Prefix(mp(pref))) } func p(pref string) Prefix { return Prefix(mp(pref)) } func TestResolvePolicy(t *testing.T) { @@ -1636,31 +1635,31 @@ func TestResolvePolicy(t *testing.T) { }, { name: "username", - toResolve: ptr.To(Username("testuser@")), + toResolve: new(Username("testuser@")), nodes: types.Nodes{ // Not matching other user { - User: ptr.To(notme), + User: new(notme), IPv4: ap("100.100.101.1"), }, // Not matching forced tags { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.2"), }, // not matching because it's tagged (tags copied from AuthKey) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"alsotagged"}, IPv4: ap("100.100.101.3"), }, { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.103"), }, { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.104"), }, }, @@ -1668,31 +1667,31 @@ func TestResolvePolicy(t *testing.T) { }, { name: "group", - toResolve: ptr.To(Group("group:testgroup")), + toResolve: new(Group("group:testgroup")), nodes: types.Nodes{ // Not matching other user { - User: ptr.To(notme), + User: new(notme), IPv4: ap("100.100.101.4"), }, // Not matching forced tags { - User: ptr.To(groupuser), + User: new(groupuser), Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.5"), }, // not matching because it's tagged (tags copied from AuthKey) { - User: ptr.To(groupuser), + User: new(groupuser), Tags: []string{"tag:alsotagged"}, IPv4: ap("100.100.101.6"), }, { - User: ptr.To(groupuser), + User: new(groupuser), IPv4: ap("100.100.101.203"), }, { - User: ptr.To(groupuser), + User: new(groupuser), IPv4: ap("100.100.101.204"), }, }, @@ -1710,7 +1709,7 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: ptr.To(notme), + User: new(notme), IPv4: ap("100.100.101.9"), }, // Not matching forced tags @@ -1746,7 +1745,7 @@ func TestResolvePolicy(t *testing.T) { pol: &Policy{ TagOwners: TagOwners{ Tag("tag:bigbrother"): {}, - Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:smallbrother"): {new(Tag("tag:bigbrother"))}, }, }, nodes: types.Nodes{ @@ -1769,7 +1768,7 @@ func TestResolvePolicy(t *testing.T) { pol: &Policy{ TagOwners: TagOwners{ Tag("tag:bigbrother"): {}, - Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:smallbrother"): {new(Tag("tag:bigbrother"))}, }, }, nodes: types.Nodes{ @@ -1804,14 +1803,14 @@ func TestResolvePolicy(t *testing.T) { }, { name: "multiple-groups", - toResolve: ptr.To(Group("group:testgroup")), + toResolve: new(Group("group:testgroup")), nodes: types.Nodes{ { - User: ptr.To(groupuser1), + User: new(groupuser1), IPv4: ap("100.100.101.203"), }, { - User: ptr.To(groupuser2), + User: new(groupuser2), IPv4: ap("100.100.101.204"), }, }, @@ -1829,10 +1828,10 @@ func TestResolvePolicy(t *testing.T) { }, { name: "invalid-username", - toResolve: ptr.To(Username("invaliduser@")), + toResolve: new(Username("invaliduser@")), nodes: types.Nodes{ { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.103"), }, }, @@ -1860,47 +1859,47 @@ func TestResolvePolicy(t *testing.T) { }, { name: "autogroup-member-comprehensive", - toResolve: ptr.To(AutoGroup(AutoGroupMember)), + 
toResolve: new(AutoGroup(AutoGroupMember)), nodes: types.Nodes{ // Node with no tags (should be included - is a member) { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.1"), }, // Node with single tag (should be excluded - tagged nodes are not members) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with multiple tags, all defined in policy (should be excluded) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test", "tag:other"}, IPv4: ap("100.100.101.3"), }, // Node with tag not defined in policy (should be excluded - still tagged) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:undefined"}, IPv4: ap("100.100.101.4"), }, // Node with mixed tags - some defined, some not (should be excluded) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test", "tag:undefined"}, IPv4: ap("100.100.101.5"), }, // Another untagged node from different user (should be included) { - User: ptr.To(testuser2), + User: new(testuser2), IPv4: ap("100.100.101.6"), }, }, pol: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("testuser@"))}, - Tag("tag:other"): Owners{ptr.To(Username("testuser@"))}, + Tag("tag:test"): Owners{new(Username("testuser@"))}, + Tag("tag:other"): Owners{new(Username("testuser@"))}, }, }, want: []netip.Prefix{ @@ -1910,54 +1909,54 @@ func TestResolvePolicy(t *testing.T) { }, { name: "autogroup-tagged", - toResolve: ptr.To(AutoGroup(AutoGroupTagged)), + toResolve: new(AutoGroup(AutoGroupTagged)), nodes: types.Nodes{ // Node with no tags (should be excluded - not tagged) { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.1"), }, // Node with single tag defined in policy (should be included) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with multiple tags, all defined in policy (should be included) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test", "tag:other"}, IPv4: ap("100.100.101.3"), }, // Node with tag not defined in policy (should be included - still tagged) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:undefined"}, IPv4: ap("100.100.101.4"), }, // Node with mixed tags - some defined, some not (should be included) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test", "tag:undefined"}, IPv4: ap("100.100.101.5"), }, // Another untagged node from different user (should be excluded) { - User: ptr.To(testuser2), + User: new(testuser2), IPv4: ap("100.100.101.6"), }, // Tagged node from different user (should be included) { - User: ptr.To(testuser2), + User: new(testuser2), Tags: []string{"tag:server"}, IPv4: ap("100.100.101.7"), }, }, pol: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("testuser@"))}, - Tag("tag:other"): Owners{ptr.To(Username("testuser@"))}, - Tag("tag:server"): Owners{ptr.To(Username("testuser2@"))}, + Tag("tag:test"): Owners{new(Username("testuser@"))}, + Tag("tag:other"): Owners{new(Username("testuser@"))}, + Tag("tag:server"): Owners{new(Username("testuser2@"))}, }, }, want: []netip.Prefix{ @@ -1968,37 +1967,37 @@ func TestResolvePolicy(t *testing.T) { }, { name: "autogroup-self", - toResolve: ptr.To(AutoGroupSelf), + toResolve: new(AutoGroupSelf), nodes: types.Nodes{ { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.1"), }, { - User: ptr.To(testuser2), + User: new(testuser2), 
IPv4: ap("100.100.101.2"), }, { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.3"), }, { - User: ptr.To(testuser2), + User: new(testuser2), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.4"), }, }, pol: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("testuser@"))}, + Tag("tag:test"): Owners{new(Username("testuser@"))}, }, }, wantErr: "autogroup:self requires per-node resolution", }, { name: "autogroup-invalid", - toResolve: ptr.To(AutoGroup("autogroup:invalid")), + toResolve: new(AutoGroup("autogroup:invalid")), wantErr: "unknown autogroup", }, } @@ -2076,7 +2075,7 @@ func TestResolveAutoApprovers(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, + mp("10.0.0.0/24"): {new(Username("user1@"))}, }, }, }, @@ -2091,8 +2090,8 @@ func TestResolveAutoApprovers(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, - mp("10.0.1.0/24"): {ptr.To(Username("user2@"))}, + mp("10.0.0.0/24"): {new(Username("user1@"))}, + mp("10.0.1.0/24"): {new(Username("user2@"))}, }, }, }, @@ -2107,7 +2106,7 @@ func TestResolveAutoApprovers(t *testing.T) { name: "exit-node", policy: &Policy{ AutoApprovers: AutoApproverPolicy{ - ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + ExitNode: AutoApprovers{new(Username("user1@"))}, }, }, want: map[netip.Prefix]*netipx.IPSet{}, @@ -2122,7 +2121,7 @@ func TestResolveAutoApprovers(t *testing.T) { }, AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, + mp("10.0.0.0/24"): {new(Group("group:testgroup"))}, }, }, }, @@ -2137,20 +2136,20 @@ func TestResolveAutoApprovers(t *testing.T) { policy: &Policy{ TagOwners: TagOwners{ "tag:testtag": Owners{ - ptr.To(Username("user1@")), - ptr.To(Username("user2@")), + new(Username("user1@")), + new(Username("user2@")), }, "tag:exittest": Owners{ - ptr.To(Group("group:exitgroup")), + new(Group("group:exitgroup")), }, }, Groups: Groups{ "group:exitgroup": Usernames{"user2@"}, }, AutoApprovers: AutoApproverPolicy{ - ExitNode: AutoApprovers{ptr.To(Tag("tag:exittest"))}, + ExitNode: AutoApprovers{new(Tag("tag:exittest"))}, Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.1.0/24"): {ptr.To(Tag("tag:testtag"))}, + mp("10.0.1.0/24"): {new(Tag("tag:testtag"))}, }, }, }, @@ -2168,10 +2167,10 @@ func TestResolveAutoApprovers(t *testing.T) { }, AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, - mp("10.0.1.0/24"): {ptr.To(Username("user3@"))}, + mp("10.0.0.0/24"): {new(Group("group:testgroup"))}, + mp("10.0.1.0/24"): {new(Username("user3@"))}, }, - ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + ExitNode: AutoApprovers{new(Username("user1@"))}, }, }, want: map[netip.Prefix]*netipx.IPSet{ @@ -2388,7 +2387,7 @@ func TestNodeCanApproveRoute(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, + mp("10.0.0.0/24"): {new(Username("user1@"))}, }, }, }, @@ -2401,8 +2400,8 @@ func TestNodeCanApproveRoute(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, - mp("10.0.1.0/24"): 
{ptr.To(Username("user2@"))}, + mp("10.0.0.0/24"): {new(Username("user1@"))}, + mp("10.0.1.0/24"): {new(Username("user2@"))}, }, }, }, @@ -2414,7 +2413,7 @@ func TestNodeCanApproveRoute(t *testing.T) { name: "exit-node-approval", policy: &Policy{ AutoApprovers: AutoApproverPolicy{ - ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + ExitNode: AutoApprovers{new(Username("user1@"))}, }, }, node: nodes[0], @@ -2429,7 +2428,7 @@ func TestNodeCanApproveRoute(t *testing.T) { }, AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, + mp("10.0.0.0/24"): {new(Group("group:testgroup"))}, }, }, }, @@ -2445,10 +2444,10 @@ func TestNodeCanApproveRoute(t *testing.T) { }, AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, - mp("10.0.1.0/24"): {ptr.To(Username("user3@"))}, + mp("10.0.0.0/24"): {new(Group("group:testgroup"))}, + mp("10.0.1.0/24"): {new(Username("user3@"))}, }, - ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + ExitNode: AutoApprovers{new(Username("user1@"))}, }, }, node: nodes[0], @@ -2460,7 +2459,7 @@ func TestNodeCanApproveRoute(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user2@"))}, + mp("10.0.0.0/24"): {new(Username("user2@"))}, }, }, }, @@ -2518,7 +2517,7 @@ func TestResolveTagOwners(t *testing.T) { name: "single-tag-owner", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:test"): Owners{new(Username("user1@"))}, }, }, want: map[Tag]*netipx.IPSet{ @@ -2530,7 +2529,7 @@ func TestResolveTagOwners(t *testing.T) { name: "multiple-tag-owners", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, + Tag("tag:test"): Owners{new(Username("user1@")), new(Username("user2@"))}, }, }, want: map[Tag]*netipx.IPSet{ @@ -2545,7 +2544,7 @@ func TestResolveTagOwners(t *testing.T) { "group:testgroup": Usernames{"user1@", "user2@"}, }, TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + Tag("tag:test"): Owners{new(Group("group:testgroup"))}, }, }, want: map[Tag]*netipx.IPSet{ @@ -2557,8 +2556,8 @@ func TestResolveTagOwners(t *testing.T) { name: "tag-owns-tag", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:bigbrother"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:bigbrother"): Owners{new(Username("user1@"))}, + Tag("tag:smallbrother"): Owners{new(Tag("tag:bigbrother"))}, }, }, want: map[Tag]*netipx.IPSet{ @@ -2619,7 +2618,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "single-tag-owner", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:test"): Owners{new(Username("user1@"))}, }, }, node: nodes[0], @@ -2630,7 +2629,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "multiple-tag-owners", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, + Tag("tag:test"): Owners{new(Username("user1@")), new(Username("user2@"))}, }, }, node: nodes[1], @@ -2644,7 +2643,7 @@ func TestNodeCanHaveTag(t *testing.T) { "group:testgroup": Usernames{"user1@", "user2@"}, }, TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + Tag("tag:test"): Owners{new(Group("group:testgroup"))}, 
}, }, node: nodes[1], @@ -2658,7 +2657,7 @@ func TestNodeCanHaveTag(t *testing.T) { "group:testgroup": Usernames{"invalid"}, }, TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + Tag("tag:test"): Owners{new(Group("group:testgroup"))}, }, }, node: nodes[0], @@ -2670,7 +2669,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-cannot-have-tag", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:test"): Owners{new(Username("user2@"))}, }, }, node: nodes[0], @@ -2681,7 +2680,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-with-unauthorized-tag-different-user", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:prod"): Owners{new(Username("user1@"))}, }, }, node: nodes[2], // user3's node @@ -2692,8 +2691,8 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-with-multiple-tags-one-unauthorized", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:web"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:web"): Owners{new(Username("user1@"))}, + Tag("tag:database"): Owners{new(Username("user2@"))}, }, }, node: nodes[0], // user1's node @@ -2713,7 +2712,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "tag-not-in-tagowners", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:prod"): Owners{new(Username("user1@"))}, }, }, node: nodes[0], @@ -2726,13 +2725,13 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-without-ip-user-owns-tag", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:test"): Owners{new(Username("user1@"))}, }, }, node: &types.Node{ // No IPv4 or IPv6 - simulates new node registration User: &users[0], - UserID: ptr.To(users[0].ID), + UserID: new(users[0].ID), }, tag: "tag:test", want: true, // Should succeed via user-based fallback @@ -2741,13 +2740,13 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-without-ip-user-does-not-own-tag", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:test"): Owners{new(Username("user2@"))}, }, }, node: &types.Node{ // No IPv4 or IPv6 - simulates new node registration User: &users[0], // user1, but tag owned by user2 - UserID: ptr.To(users[0].ID), + UserID: new(users[0].ID), }, tag: "tag:test", want: false, // user1 does not own tag:test @@ -2759,13 +2758,13 @@ func TestNodeCanHaveTag(t *testing.T) { "group:admins": Usernames{"user1@", "user2@"}, }, TagOwners: TagOwners{ - Tag("tag:admin"): Owners{ptr.To(Group("group:admins"))}, + Tag("tag:admin"): Owners{new(Group("group:admins"))}, }, }, node: &types.Node{ // No IPv4 or IPv6 - simulates new node registration User: &users[1], // user2 is in group:admins - UserID: ptr.To(users[1].ID), + UserID: new(users[1].ID), }, tag: "tag:admin", want: true, // Should succeed via group membership @@ -2777,13 +2776,13 @@ func TestNodeCanHaveTag(t *testing.T) { "group:admins": Usernames{"user1@"}, }, TagOwners: TagOwners{ - Tag("tag:admin"): Owners{ptr.To(Group("group:admins"))}, + Tag("tag:admin"): Owners{new(Group("group:admins"))}, }, }, node: &types.Node{ // No IPv4 or IPv6 - simulates new node registration User: &users[1], // user2 is NOT in group:admins - UserID: ptr.To(users[1].ID), + UserID: new(users[1].ID), }, tag: "tag:admin", want: false, // user2 is not in group:admins @@ -2792,7 +2791,7 @@ func TestNodeCanHaveTag(t 
*testing.T) { name: "node-without-ip-no-user", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:test"): Owners{new(Username("user1@"))}, }, }, node: &types.Node{ @@ -2809,14 +2808,14 @@ func TestNodeCanHaveTag(t *testing.T) { }, TagOwners: TagOwners{ Tag("tag:server"): Owners{ - ptr.To(Username("user1@")), - ptr.To(Group("group:ops")), + new(Username("user1@")), + new(Group("group:ops")), }, }, }, node: &types.Node{ User: &users[0], // user1 directly owns the tag - UserID: ptr.To(users[0].ID), + UserID: new(users[0].ID), }, tag: "tag:server", want: true, @@ -2829,14 +2828,14 @@ func TestNodeCanHaveTag(t *testing.T) { }, TagOwners: TagOwners{ Tag("tag:server"): Owners{ - ptr.To(Username("user1@")), - ptr.To(Group("group:ops")), + new(Username("user1@")), + new(Group("group:ops")), }, }, }, node: &types.Node{ User: &users[2], // user3 is in group:ops - UserID: ptr.To(users[2].ID), + UserID: new(users[2].ID), }, tag: "tag:server", want: true, @@ -2881,14 +2880,14 @@ func TestUserMatchesOwner(t *testing.T) { name: "username-match", policy: &Policy{}, user: users[0], - owner: ptr.To(Username("user1@")), + owner: new(Username("user1@")), want: true, }, { name: "username-no-match", policy: &Policy{}, user: users[0], - owner: ptr.To(Username("user2@")), + owner: new(Username("user2@")), want: false, }, { @@ -2899,7 +2898,7 @@ func TestUserMatchesOwner(t *testing.T) { }, }, user: users[1], // user2 is in group:admins - owner: ptr.To(Group("group:admins")), + owner: new(Group("group:admins")), want: true, }, { @@ -2910,7 +2909,7 @@ func TestUserMatchesOwner(t *testing.T) { }, }, user: users[1], // user2 is NOT in group:admins - owner: ptr.To(Group("group:admins")), + owner: new(Group("group:admins")), want: false, }, { @@ -2919,7 +2918,7 @@ func TestUserMatchesOwner(t *testing.T) { Groups: Groups{}, }, user: users[0], - owner: ptr.To(Group("group:undefined")), + owner: new(Group("group:undefined")), want: false, }, { @@ -3261,20 +3260,20 @@ func TestFlattenTagOwners(t *testing.T) { { name: "tag-owns-tag", input: TagOwners{ - Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:bigbrother"): Owners{new(Group("group:user1"))}, + Tag("tag:smallbrother"): Owners{new(Tag("tag:bigbrother"))}, }, want: TagOwners{ - Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:smallbrother"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:bigbrother"): Owners{new(Group("group:user1"))}, + Tag("tag:smallbrother"): Owners{new(Group("group:user1"))}, }, wantErr: "", }, { name: "circular-reference", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))}, - Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, + Tag("tag:a"): Owners{new(Tag("tag:b"))}, + Tag("tag:b"): Owners{new(Tag("tag:a"))}, }, want: nil, wantErr: "circular reference detected: tag:a -> tag:b", @@ -3282,83 +3281,83 @@ func TestFlattenTagOwners(t *testing.T) { { name: "mixed-owners", input: TagOwners{ - Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))}, - Tag("tag:y"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:x"): Owners{new(Username("user1@")), new(Tag("tag:y"))}, + Tag("tag:y"): Owners{new(Username("user2@"))}, }, want: TagOwners{ - Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, - Tag("tag:y"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:x"): Owners{new(Username("user1@")), new(Username("user2@"))}, + Tag("tag:y"): 
Owners{new(Username("user2@"))}, }, wantErr: "", }, { name: "mixed-dupe-owners", input: TagOwners{ - Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))}, - Tag("tag:y"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:x"): Owners{new(Username("user1@")), new(Tag("tag:y"))}, + Tag("tag:y"): Owners{new(Username("user1@"))}, }, want: TagOwners{ - Tag("tag:x"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:y"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:x"): Owners{new(Username("user1@"))}, + Tag("tag:y"): Owners{new(Username("user1@"))}, }, wantErr: "", }, { name: "no-tag-owners", input: TagOwners{ - Tag("tag:solo"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:solo"): Owners{new(Username("user1@"))}, }, want: TagOwners{ - Tag("tag:solo"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:solo"): Owners{new(Username("user1@"))}, }, wantErr: "", }, { name: "tag-long-owner-chain", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, - Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))}, - Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))}, - Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))}, - Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))}, - Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))}, + Tag("tag:a"): Owners{new(Group("group:user1"))}, + Tag("tag:b"): Owners{new(Tag("tag:a"))}, + Tag("tag:c"): Owners{new(Tag("tag:b"))}, + Tag("tag:d"): Owners{new(Tag("tag:c"))}, + Tag("tag:e"): Owners{new(Tag("tag:d"))}, + Tag("tag:f"): Owners{new(Tag("tag:e"))}, + Tag("tag:g"): Owners{new(Tag("tag:f"))}, }, want: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:b"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:c"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:d"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:e"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:f"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:g"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:a"): Owners{new(Group("group:user1"))}, + Tag("tag:b"): Owners{new(Group("group:user1"))}, + Tag("tag:c"): Owners{new(Group("group:user1"))}, + Tag("tag:d"): Owners{new(Group("group:user1"))}, + Tag("tag:e"): Owners{new(Group("group:user1"))}, + Tag("tag:f"): Owners{new(Group("group:user1"))}, + Tag("tag:g"): Owners{new(Group("group:user1"))}, }, wantErr: "", }, { name: "tag-long-circular-chain", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Tag("tag:g"))}, - Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, - Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))}, - Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))}, - Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))}, - Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))}, - Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))}, + Tag("tag:a"): Owners{new(Tag("tag:g"))}, + Tag("tag:b"): Owners{new(Tag("tag:a"))}, + Tag("tag:c"): Owners{new(Tag("tag:b"))}, + Tag("tag:d"): Owners{new(Tag("tag:c"))}, + Tag("tag:e"): Owners{new(Tag("tag:d"))}, + Tag("tag:f"): Owners{new(Tag("tag:e"))}, + Tag("tag:g"): Owners{new(Tag("tag:f"))}, }, wantErr: "circular reference detected: tag:a -> tag:b -> tag:c -> tag:d -> tag:e -> tag:f -> tag:g", }, { name: "undefined-tag-reference", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Tag("tag:nonexistent"))}, + Tag("tag:a"): Owners{new(Tag("tag:nonexistent"))}, }, wantErr: `tag "tag:a" references undefined tag "tag:nonexistent"`, }, { name: "tag-with-empty-owners-is-valid", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:a"): Owners{new(Tag("tag:b"))}, Tag("tag:b"): Owners{}, // empty owners but 
exists }, want: TagOwners{ diff --git a/hscontrol/state/ephemeral_test.go b/hscontrol/state/ephemeral_test.go index 632af13c..9f713b3d 100644 --- a/hscontrol/state/ephemeral_test.go +++ b/hscontrol/state/ephemeral_test.go @@ -8,7 +8,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "tailscale.com/types/ptr" ) // TestEphemeralNodeDeleteWithConcurrentUpdate tests the race condition where UpdateNode and DeleteNode @@ -50,7 +49,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) go func() { updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) done <- true }() @@ -106,7 +105,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) { // Start UpdateNode in goroutine - it will queue and wait for batch go func() { node, ok := store.UpdateNode(node.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) resultChan <- struct { node types.NodeView @@ -156,7 +155,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) { // Simulate UpdateNode being called updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) require.True(t, ok, "UpdateNode should succeed") require.True(t, updatedNode.Valid(), "UpdateNode should return valid node") @@ -221,7 +220,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) go func() { updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) done <- true }() @@ -294,7 +293,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { go func() { node, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) endpoint := netip.MustParseAddrPort("10.0.0.1:41641") n.Endpoints = []netip.AddrPort{endpoint} }) @@ -363,7 +362,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { go func() { updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) updateDone <- struct { node types.NodeView @@ -417,7 +416,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { // UpdateNode returns a node updatedNode, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) require.True(t, ok, "UpdateNode should succeed") require.True(t, updatedNode.Valid(), "updated node should be valid") diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index 99f781d4..0fa81318 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -9,7 +9,6 @@ import ( "github.com/stretchr/testify/require" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestNetInfoFromMapRequest(t *testing.T) { @@ -149,7 +148,7 @@ func createTestNodeSimple(id types.NodeID) *types.Node { node := &types.Node{ ID: id, Hostname: "test-node", - UserID: ptr.To(uint(id)), + UserID: new(uint(id)), User: &user, MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), diff --git a/hscontrol/state/node_store_test.go 
b/hscontrol/state/node_store_test.go index 3d6184ba..745850cc 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestSnapshotFromNodes(t *testing.T) { @@ -174,7 +173,7 @@ func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) DiscoKey: discoKey.Public(), Hostname: hostname, GivenName: hostname, - UserID: ptr.To(userID), + UserID: new(userID), User: &types.User{ Name: username, DisplayName: username, @@ -856,7 +855,7 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { Hostname: hostname, MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), - UserID: ptr.To(uint(1)), + UserID: new(uint(1)), User: &types.User{ Name: "concurrent-test-user", }, diff --git a/hscontrol/types/node_tags_test.go b/hscontrol/types/node_tags_test.go index 72598b3c..97e01b2a 100644 --- a/hscontrol/types/node_tags_test.go +++ b/hscontrol/types/node_tags_test.go @@ -6,7 +6,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "gorm.io/gorm" - "tailscale.com/types/ptr" ) // TestNodeIsTagged tests the IsTagged() method for determining if a node is tagged. @@ -69,7 +68,7 @@ func TestNodeIsTagged(t *testing.T) { { name: "node with user and no tags - not tagged", node: Node{ - UserID: ptr.To(uint(42)), + UserID: new(uint(42)), Tags: []string{}, }, want: false, @@ -112,7 +111,7 @@ func TestNodeViewIsTagged(t *testing.T) { { name: "user-owned node", node: Node{ - UserID: ptr.To(uint(1)), + UserID: new(uint(1)), }, want: false, }, @@ -223,7 +222,7 @@ func TestNodeTagsImmutableAfterRegistration(t *testing.T) { // Test that a user-owned node is not tagged userNode := Node{ ID: 2, - UserID: ptr.To(uint(42)), + UserID: new(uint(42)), Tags: []string{}, RegisterMethod: util.RegisterMethodOIDC, } @@ -243,7 +242,7 @@ func TestNodeOwnershipModel(t *testing.T) { name: "tagged node has tags, UserID is informational", node: Node{ ID: 1, - UserID: ptr.To(uint(5)), // "created by" user 5 + UserID: new(uint(5)), // "created by" user 5 Tags: []string{"tag:server"}, }, wantIsTagged: true, @@ -253,7 +252,7 @@ func TestNodeOwnershipModel(t *testing.T) { name: "user-owned node has no tags", node: Node{ ID: 2, - UserID: ptr.To(uint(5)), + UserID: new(uint(5)), Tags: []string{}, }, wantIsTagged: false, @@ -265,7 +264,7 @@ func TestNodeOwnershipModel(t *testing.T) { name: "node with only authkey tags - not tagged (tags should be copied)", node: Node{ ID: 3, - UserID: ptr.To(uint(5)), // "created by" user 5 + UserID: new(uint(5)), // "created by" user 5 AuthKey: &PreAuthKey{ Tags: []string{"tag:database"}, }, diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index ec40492b..27aff519 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -93,7 +93,7 @@ func (u *User) StringID() string { } // TypedID returns a pointer to the user's ID as a UserID type. -// This is a convenience method to avoid ugly casting like ptr.To(types.UserID(user.ID)). +// This is a convenience method to avoid ugly casting like new(types.UserID(user.ID)). 
func (u *User) TypedID() *UserID { uid := UserID(u.ID) return &uid diff --git a/integration/acl_test.go b/integration/acl_test.go index c746f900..7a33240b 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -20,7 +20,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) var veryLargeDestination = []policyv2.AliasWithPorts{ @@ -1284,9 +1283,9 @@ func TestACLAutogroupMember(t *testing.T) { ACLs: []policyv2.ACL{ { Action: "accept", - Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Sources: []policyv2.Alias{new(policyv2.AutoGroupMember)}, Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(ptr.To(policyv2.AutoGroupMember), tailcfg.PortRangeAny), + aliasWithPorts(new(policyv2.AutoGroupMember), tailcfg.PortRangeAny), }, }, }, @@ -1372,9 +1371,9 @@ func TestACLAutogroupTagged(t *testing.T) { ACLs: []policyv2.ACL{ { Action: "accept", - Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupTagged)}, + Sources: []policyv2.Alias{new(policyv2.AutoGroupTagged)}, Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(ptr.To(policyv2.AutoGroupTagged), tailcfg.PortRangeAny), + aliasWithPorts(new(policyv2.AutoGroupTagged), tailcfg.PortRangeAny), }, }, }, @@ -1657,9 +1656,9 @@ func TestACLAutogroupSelf(t *testing.T) { ACLs: []policyv2.ACL{ { Action: "accept", - Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Sources: []policyv2.Alias{new(policyv2.AutoGroupMember)}, Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), + aliasWithPorts(new(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), }, }, { @@ -1956,9 +1955,9 @@ func TestACLPolicyPropagationOverTime(t *testing.T) { ACLs: []policyv2.ACL{ { Action: "accept", - Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Sources: []policyv2.Alias{new(policyv2.AutoGroupMember)}, Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), + aliasWithPorts(new(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), }, }, }, diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index df5f2455..223e4c8b 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -356,13 +356,13 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { lines := strings.Split(curlOutput, "\n") var httpCode string - var responseBody string + var responseBody strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBody.WriteString(line) } } @@ -372,7 +372,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { // Should contain user data var response v1.ListUsersResponse - err = protojson.Unmarshal([]byte(responseBody), &response) + err = protojson.Unmarshal([]byte(responseBody.String()), &response) assert.NoError(t, err, "Response should be valid protobuf JSON") users := response.GetUsers() assert.Len(t, users, 2, "Should have 2 users") diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index ba6a195b..9cf352bb 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -17,7 +17,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { @@ -608,7 +607,7 @@ func TestAuthKeyLogoutAndReloginRoutesPreserved(t 
*testing.T) { }, AutoApprovers: policyv2.AutoApproverPolicy{ Routes: map[netip.Prefix]policyv2.AutoApprovers{ - netip.MustParsePrefix(advertiseRoute): {ptr.To(policyv2.Username(user + "@test.no"))}, + netip.MustParsePrefix(advertiseRoute): {new(policyv2.Username(user + "@test.no"))}, }, }, }, diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 1ca291c0..04365eae 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) func isSSHNoAccessStdError(stderr string) bool { @@ -482,10 +481,10 @@ func TestSSHAutogroupSelf(t *testing.T) { { Action: "accept", Sources: policyv2.SSHSrcAliases{ - ptr.To(policyv2.AutoGroupMember), + new(policyv2.AutoGroupMember), }, Destinations: policyv2.SSHDstAliases{ - ptr.To(policyv2.AutoGroupSelf), + new(policyv2.AutoGroupSelf), }, Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, From ad7669a2d40d6631311e6498f014546dd78d4d6f Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 14:37:24 +0000 Subject: [PATCH 06/30] all: apply golangci-lint auto-fixes Apply auto-fixes from golangci-lint for the following linters: - wsl_v5: whitespace formatting and blank line adjustments - godot: add periods to comment sentences - nlreturn: add newlines before return statements - perfsprint: optimize fmt.Sprintf to more efficient alternatives Also add missing imports (errors, encoding/hex) where auto-fix added new code patterns that require them. --- cmd/headscale/cli/mockoidc.go | 2 + cmd/headscale/cli/policy.go | 3 + cmd/headscale/cli/root.go | 2 + cmd/headscale/cli/users.go | 1 + cmd/hi/cleanup.go | 5 + cmd/hi/docker.go | 34 ++++- cmd/hi/doctor.go | 4 + cmd/hi/main.go | 2 + cmd/hi/run.go | 7 +- cmd/hi/stats.go | 21 +++- cmd/mapresponses/main.go | 4 +- hscontrol/app.go | 25 +++- hscontrol/auth.go | 7 +- hscontrol/auth_test.go | 90 +++++++++++-- hscontrol/db/db.go | 5 + hscontrol/db/db_test.go | 3 + .../db/ephemeral_garbage_collector_test.go | 79 +++++++++--- hscontrol/db/ip_test.go | 1 + hscontrol/db/node.go | 6 + hscontrol/db/node_test.go | 7 ++ hscontrol/db/sqliteconfig/config.go | 5 + hscontrol/db/sqliteconfig/config_test.go | 2 + hscontrol/db/sqliteconfig/integration_test.go | 3 + hscontrol/db/text_serialiser.go | 3 + hscontrol/db/users.go | 2 + hscontrol/debug.go | 30 +++++ hscontrol/derp/derp.go | 3 + hscontrol/derp/derp_test.go | 2 + hscontrol/derp/server/derp_server.go | 11 +- hscontrol/dns/extrarecords.go | 6 + hscontrol/handlers.go | 2 + hscontrol/mapper/batcher_lockfree.go | 26 +++- hscontrol/mapper/batcher_test.go | 24 ++++ hscontrol/mapper/builder.go | 6 +- hscontrol/mapper/mapper.go | 6 +- hscontrol/mapper/mapper_test.go | 7 ++ hscontrol/noise.go | 3 + hscontrol/oidc.go | 14 +++ hscontrol/policy/matcher/matcher.go | 3 + hscontrol/policy/pm.go | 8 +- hscontrol/policy/policy.go | 1 + hscontrol/policy/policy_autoapprove_test.go | 7 +- hscontrol/policy/policy_test.go | 22 +++- hscontrol/policy/policyutil/reduce.go | 1 + hscontrol/policy/policyutil/reduce_test.go | 8 +- hscontrol/policy/route_approval_test.go | 2 + hscontrol/policy/v2/filter.go | 27 +++- hscontrol/policy/v2/filter_test.go | 33 +++-- hscontrol/policy/v2/policy.go | 28 ++++- hscontrol/policy/v2/policy_test.go | 8 +- hscontrol/policy/v2/types.go | 97 +++++++++++--- hscontrol/policy/v2/types_test.go | 13 +- hscontrol/policy/v2/utils.go | 3 + hscontrol/policy/v2/utils_test.go | 4 + hscontrol/poll.go | 
4 + hscontrol/routes/primary.go | 9 ++ hscontrol/routes/primary_test.go | 19 ++- hscontrol/state/debug.go | 10 ++ hscontrol/state/ephemeral_test.go | 30 ++++- hscontrol/state/maprequest.go | 1 + hscontrol/state/maprequest_test.go | 2 +- hscontrol/state/node_store.go | 25 ++++ hscontrol/state/node_store_test.go | 119 +++++++++++++++--- hscontrol/tailsql.go | 1 + hscontrol/types/common.go | 1 + hscontrol/types/config.go | 3 + hscontrol/types/config_test.go | 4 + hscontrol/types/node.go | 5 +- hscontrol/types/preauth_key.go | 1 + hscontrol/types/users.go | 9 ++ hscontrol/types/users_test.go | 4 + hscontrol/util/dns_test.go | 2 + hscontrol/util/prompt.go | 2 + hscontrol/util/prompt_test.go | 8 ++ hscontrol/util/string.go | 2 + hscontrol/util/util.go | 22 +++- hscontrol/util/util_test.go | 15 +++ integration/api_auth_test.go | 63 +++++++--- integration/auth_key_test.go | 38 +++++- integration/auth_oidc_test.go | 39 ++++++ integration/auth_web_flow_test.go | 14 +++ integration/derp_verify_endpoint_test.go | 3 + integration/dockertestutil/config.go | 1 + integration/dockertestutil/execute.go | 2 + integration/dockertestutil/logs.go | 1 + integration/dockertestutil/network.go | 3 + integration/dsic/dsic.go | 8 ++ integration/helpers.go | 63 ++++++++-- integration/hsic/hsic.go | 17 ++- integration/integrationutil/util.go | 3 + integration/route_test.go | 101 ++++++++++++++- integration/scenario.go | 33 ++++- integration/scenario_test.go | 2 + 93 files changed, 1262 insertions(+), 155 deletions(-) diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 9969f7c6..af28ce9f 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -73,6 +73,7 @@ func mockOIDC() error { } var users []mockoidc.MockUser + err := json.Unmarshal([]byte(userStr), &users) if err != nil { return fmt.Errorf("unmarshalling users: %w", err) @@ -137,6 +138,7 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Info().Msgf("Request: %+v", r) h.ServeHTTP(w, r) + if r.Response != nil { log.Info().Msgf("Response: %+v", r.Response) } diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 2aaebcfa..f3921a64 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -29,13 +29,16 @@ func init() { if err := setPolicy.MarkFlagRequired("file"); err != nil { log.Fatal().Err(err).Msg("") } + setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running") policyCmd.AddCommand(setPolicy) checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") + if err := checkPolicy.MarkFlagRequired("file"); err != nil { log.Fatal().Err(err).Msg("") } + policyCmd.AddCommand(checkPolicy) } diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index d7cdabb6..d67c2df8 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -80,6 +80,7 @@ func initConfig() { Repository: "headscale", TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }), } + res, err := latest.Check(githubTag, versionInfo.Version) if err == nil && res.Outdated { //nolint @@ -101,6 +102,7 @@ func isPreReleaseVersion(version string) bool { return true } } + return false } diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 9a816c78..6e4bdd02 100644 --- a/cmd/headscale/cli/users.go +++ 
b/cmd/headscale/cli/users.go @@ -23,6 +23,7 @@ func usernameAndIDFlag(cmd *cobra.Command) { // If both are empty, it will exit the program with an error. func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) { username, _ := cmd.Flags().GetString("name") + identifier, _ := cmd.Flags().GetInt64("identifier") if username == "" && identifier < 0 { err := errors.New("--name or --identifier flag is required") diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go index 7c5b5214..e0268fd8 100644 --- a/cmd/hi/cleanup.go +++ b/cmd/hi/cleanup.go @@ -69,8 +69,10 @@ func killTestContainers(ctx context.Context) error { } removed := 0 + for _, cont := range containers { shouldRemove := false + for _, name := range cont.Names { if strings.Contains(name, "headscale-test-suite") || strings.Contains(name, "hs-") || @@ -259,8 +261,10 @@ func cleanOldImages(ctx context.Context) error { } removed := 0 + for _, img := range images { shouldRemove := false + for _, tag := range img.RepoTags { if strings.Contains(tag, "hs-") || strings.Contains(tag, "headscale-integration") || @@ -302,6 +306,7 @@ func cleanCacheVolume(ctx context.Context) error { defer cli.Close() volumeName := "hs-integration-go-cache" + err = cli.VolumeRemove(ctx, volumeName, true) if err != nil { if errdefs.IsNotFound(err) { diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index a6b94b25..3ad70173 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -60,6 +60,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if config.Verbose { log.Printf("Running pre-test cleanup...") } + if err := cleanupBeforeTest(ctx); err != nil && config.Verbose { log.Printf("Warning: pre-test cleanup failed: %v", err) } @@ -95,13 +96,16 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { // Start stats collection for container resource monitoring (if enabled) var statsCollector *StatsCollector + if config.Stats { var err error + statsCollector, err = NewStatsCollector() if err != nil { if config.Verbose { log.Printf("Warning: failed to create stats collector: %v", err) } + statsCollector = nil } @@ -140,6 +144,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if len(violations) > 0 { log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:") log.Printf("=================================") + for _, violation := range violations { log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB", violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB) @@ -347,6 +352,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC maxWaitTime := 10 * time.Second checkInterval := 500 * time.Millisecond timeout := time.After(maxWaitTime) + ticker := time.NewTicker(checkInterval) defer ticker.Stop() @@ -356,6 +362,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction") } + return nil case <-ticker.C: allFinalized := true @@ -366,12 +373,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err) } + continue } // Check if container is in a final state if !isContainerFinalized(inspect.State) { allFinalized = false + if verbose { log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status) } @@ -384,6 +393,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if 
verbose { log.Printf("All test containers finalized, ready for artifact extraction") } + return nil } } @@ -403,10 +413,12 @@ func findProjectRoot(startPath string) string { if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { return current } + parent := filepath.Dir(current) if parent == current { return startPath } + current = parent } } @@ -416,6 +428,7 @@ func boolToInt(b bool) int { if b { return 1 } + return 0 } @@ -435,6 +448,7 @@ func createDockerClient() (*client.Client, error) { } var clientOpts []client.Opt + clientOpts = append(clientOpts, client.WithAPIVersionNegotiation()) if contextInfo != nil { @@ -444,6 +458,7 @@ func createDockerClient() (*client.Client, error) { if runConfig.Verbose { log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host) } + clientOpts = append(clientOpts, client.WithHost(host)) } } @@ -460,6 +475,7 @@ func createDockerClient() (*client.Client, error) { // getCurrentDockerContext retrieves the current Docker context information. func getCurrentDockerContext() (*DockerContext, error) { cmd := exec.Command("docker", "context", "inspect") + output, err := cmd.Output() if err != nil { return nil, fmt.Errorf("failed to get docker context: %w", err) @@ -491,6 +507,7 @@ func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageNa if client.IsErrNotFound(err) { return false, nil } + return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err) } @@ -509,6 +526,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str if verbose { log.Printf("Image %s is available locally", imageName) } + return nil } @@ -533,6 +551,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str if err != nil { return fmt.Errorf("failed to read pull output: %w", err) } + log.Printf("Image %s pulled successfully", imageName) } @@ -547,9 +566,11 @@ func listControlFiles(logsDir string) { return } - var logFiles []string - var dataFiles []string - var dataDirs []string + var ( + logFiles []string + dataFiles []string + dataDirs []string + ) for _, entry := range entries { name := entry.Name() @@ -578,6 +599,7 @@ func listControlFiles(logsDir string) { if len(logFiles) > 0 { log.Printf("Headscale logs:") + for _, file := range logFiles { log.Printf(" %s", file) } @@ -585,9 +607,11 @@ func listControlFiles(logsDir string) { if len(dataFiles) > 0 || len(dataDirs) > 0 { log.Printf("Headscale data:") + for _, file := range dataFiles { log.Printf(" %s", file) } + for _, dir := range dataDirs { log.Printf(" %s/", dir) } @@ -612,6 +636,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose) extractedCount := 0 + for _, cont := range currentTestContainers { // Extract container logs and tar files if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil { @@ -622,6 +647,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi if verbose { log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12]) } + extractedCount++ } } @@ -645,11 +671,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st // Find the test container to get its run ID label var runID string + for _, cont := range containers { if cont.ID == testContainerID { if cont.Labels != nil { runID = cont.Labels["hi.run-id"] } + break } } diff --git a/cmd/hi/doctor.go 
b/cmd/hi/doctor.go index 8af6051f..8ebda159 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -266,6 +266,7 @@ func checkGoInstallation() DoctorResult { } cmd := exec.Command("go", "version") + output, err := cmd.Output() if err != nil { return DoctorResult{ @@ -287,6 +288,7 @@ func checkGoInstallation() DoctorResult { // checkGitRepository verifies we're in a git repository. func checkGitRepository() DoctorResult { cmd := exec.Command("git", "rev-parse", "--git-dir") + err := cmd.Run() if err != nil { return DoctorResult{ @@ -316,6 +318,7 @@ func checkRequiredFiles() DoctorResult { } var missingFiles []string + for _, file := range requiredFiles { cmd := exec.Command("test", "-e", file) if err := cmd.Run(); err != nil { @@ -350,6 +353,7 @@ func displayDoctorResults(results []DoctorResult) { for _, result := range results { var icon string + switch result.Status { case "PASS": icon = "✅" diff --git a/cmd/hi/main.go b/cmd/hi/main.go index baecc6f3..0c9adc30 100644 --- a/cmd/hi/main.go +++ b/cmd/hi/main.go @@ -82,9 +82,11 @@ func cleanAll(ctx context.Context) error { if err := killTestContainers(ctx); err != nil { return err } + if err := pruneDockerNetworks(ctx); err != nil { return err } + if err := cleanOldImages(ctx); err != nil { return err } diff --git a/cmd/hi/run.go b/cmd/hi/run.go index 1694399d..e6c52634 100644 --- a/cmd/hi/run.go +++ b/cmd/hi/run.go @@ -48,6 +48,7 @@ func runIntegrationTest(env *command.Env) error { if runConfig.Verbose { log.Printf("Running pre-flight system checks...") } + if err := runDoctorCheck(env.Context()); err != nil { return fmt.Errorf("pre-flight checks failed: %w", err) } @@ -94,8 +95,10 @@ func detectGoVersion() string { // splitLines splits a string into lines without using strings.Split. func splitLines(s string) []string { - var lines []string - var current string + var ( + lines []string + current string + ) for _, char := range s { if char == '\n' { diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index c1bb9cfe..1c17df84 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -71,10 +71,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver // Start monitoring existing containers sc.wg.Add(1) + go sc.monitorExistingContainers(ctx, runID, verbose) // Start Docker events monitoring for new containers sc.wg.Add(1) + go sc.monitorDockerEvents(ctx, runID, verbose) if verbose { @@ -88,10 +90,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver func (sc *StatsCollector) StopCollection() { // Check if already stopped without holding lock sc.mutex.RLock() + if !sc.collectionStarted { sc.mutex.RUnlock() return } + sc.mutex.RUnlock() // Signal stop to all goroutines @@ -115,6 +119,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s if verbose { log.Printf("Failed to list existing containers: %v", err) } + return } @@ -168,6 +173,7 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, if verbose { log.Printf("Error in Docker events stream: %v", err) } + return } } @@ -214,6 +220,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI } sc.wg.Add(1) + go sc.collectStatsForContainer(ctx, containerID, verbose) } @@ -227,11 +234,13 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe if verbose { log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err) } + return } defer statsResponse.Body.Close() decoder := json.NewDecoder(statsResponse.Body) 
+ var prevStats *container.Stats for { @@ -247,6 +256,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe if err.Error() != "EOF" && verbose { log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err) } + return } @@ -262,8 +272,10 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe // Store the sample (skip first sample since CPU calculation needs previous stats) if prevStats != nil { // Get container stats reference without holding the main mutex - var containerStats *ContainerStats - var exists bool + var ( + containerStats *ContainerStats + exists bool + ) sc.mutex.RLock() containerStats, exists = sc.containers[containerID] @@ -332,10 +344,12 @@ type StatsSummary struct { func (sc *StatsCollector) GetSummary() []ContainerStatsSummary { // Take snapshot of container references without holding main lock long sc.mutex.RLock() + containerRefs := make([]*ContainerStats, 0, len(sc.containers)) for _, containerStats := range sc.containers { containerRefs = append(containerRefs, containerStats) } + sc.mutex.RUnlock() summaries := make([]ContainerStatsSummary, 0, len(containerRefs)) @@ -393,9 +407,11 @@ func calculateStatsSummary(values []float64) StatsSummary { if value < min { min = value } + if value > max { max = value } + sum += value } @@ -435,6 +451,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo } summaries := sc.GetSummary() + var violations []MemoryViolation for _, summary := range summaries { diff --git a/cmd/mapresponses/main.go b/cmd/mapresponses/main.go index 5d7ad07d..af35bc48 100644 --- a/cmd/mapresponses/main.go +++ b/cmd/mapresponses/main.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "os" @@ -40,7 +41,7 @@ func main() { // runIntegrationTest executes the integration test workflow. 
func runOnline(env *command.Env) error { if mapConfig.Directory == "" { - return fmt.Errorf("directory is required") + return errors.New("directory is required") } resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory) @@ -57,5 +58,6 @@ func runOnline(env *command.Env) error { os.Stderr.Write(out) os.Stderr.Write([]byte("\n")) + return nil } diff --git a/hscontrol/app.go b/hscontrol/app.go index aa011503..8ce1066f 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -142,6 +142,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { if !ok { log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed") log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore") + return } @@ -157,10 +158,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.ephemeralGC = ephemeralGC var authProvider AuthProvider + authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + oidcProvider, err := NewAuthProviderOIDC( ctx, &app, @@ -177,6 +180,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { authProvider = oidcProvider } } + app.authProvider = authProvider if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS @@ -251,9 +255,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { lastExpiryCheck := time.Unix(0, 0) derpTickerChan := make(<-chan time.Time) + if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 { derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency) defer derpTicker.Stop() + derpTickerChan = derpTicker.C } @@ -271,8 +277,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { return case <-expireTicker.C: - var expiredNodeChanges []change.Change - var changed bool + var ( + expiredNodeChanges []change.Change + changed bool + ) lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck) @@ -287,11 +295,13 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { case <-derpTickerChan: log.Info().Msg("Fetching DERPMap updates") + derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { derpMap, err := derp.GetDERPMap(h.cfg.DERP) if err != nil { return nil, err } + if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { region, _ := h.DERPServer.GenerateRegion() derpMap.Regions[region.RegionID] = &region @@ -303,6 +313,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { log.Error().Err(err).Msg("failed to build new DERPMap, retrying later") continue } + h.state.SetDERPMap(derpMap) h.Change(change.DERPMap()) @@ -311,6 +322,7 @@ if !ok { continue } + h.cfg.TailcfgDNSConfig.ExtraRecords = records h.Change(change.ExtraRecords()) @@ -390,6 +402,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler writeUnauthorized := func(statusCode int) { writer.WriteHeader(statusCode) + if _, err := writer.Write([]byte("Unauthorized")); err != nil { log.Error().Err(err).Msg("writing HTTP response failed") } @@ -486,6 +499,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { // Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error { var err error + capver.CanOldCodeBeCleanedUp() if profilingEnabled { @@ -512,6 +526,7 @@ func (h *Headscale) Serve() error { Msg("Clients with a lower minimum version will be rejected") h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state) + h.mapBatcher.Start() defer h.mapBatcher.Close() @@ -545,6 +560,7 @@ func (h *Headscale) Serve() error { // around between restarts, they will reconnect and the GC will // be cancelled. go h.ephemeralGC.Start() + ephmNodes := h.state.ListEphemeralNodes() for _, node := range ephmNodes.All() { h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout) @@ -555,7 +571,9 @@ func (h *Headscale) Serve() error { if err != nil { return fmt.Errorf("setting up extrarecord manager: %w", err) } + h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records() + go h.extraRecordMan.Run() defer h.extraRecordMan.Close() } @@ -564,6 +582,7 @@ func (h *Headscale) Serve() error { // records updates scheduleCtx, scheduleCancel := context.WithCancel(context.Background()) defer scheduleCancel() + go h.scheduledTasks(scheduleCtx) if zl.GlobalLevel() == zl.TraceLevel { @@ -751,7 +770,6 @@ func (h *Headscale) Serve() error { log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)") } - var tailsqlContext context.Context if tailsqlEnabled { if h.cfg.Database.Type != types.DatabaseSqlite { @@ -863,6 +881,7 @@ func (h *Headscale) Serve() error { // Close state connections info("closing state and database") + err = h.state.Close() if err != nil { log.Error().Err(err).Msg("failed to close state") diff --git a/hscontrol/auth.go b/hscontrol/auth.go index aa7088d7..c5fa91c2 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -51,6 +51,7 @@ func (h *Headscale) handleRegister( if err != nil { return nil, fmt.Errorf("handling logout: %w", err) } + if resp != nil { return resp, nil } @@ -131,7 +132,7 @@ func (h *Headscale) handleRegister( } // handleLogout checks if the [tailcfg.RegisterRequest] is a -// logout attempt from a node. If the node is not attempting to +// logout attempt from a node. If the node is not attempting to. func (h *Headscale) handleLogout( node types.NodeView, req tailcfg.RegisterRequest, @@ -158,6 +159,7 @@ func (h *Headscale) handleLogout( Interface("reg.req", req). Bool("unexpected", true). Msg("Node key expired, forcing re-authentication") + return &tailcfg.RegisterResponse{ NodeKeyExpired: true, MachineAuthorized: false, @@ -277,6 +279,7 @@ func (h *Headscale) waitForFollowup( // registration is expired in the cache, instruct the client to try a new registration return h.reqToNewRegisterResponse(req, machineKey) } + return nodeToRegisterResponse(node.View()), nil } } @@ -342,6 +345,7 @@ func (h *Headscale) handleRegisterWithAuthKey( if errors.Is(err, gorm.ErrRecordNotFound) { return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) } + if perr, ok := errors.AsType[types.PAKError](err); ok { return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil) } @@ -432,6 +436,7 @@ func (h *Headscale) handleRegisterInteractive( Str("generated.hostname", hostname). 
Msg("Received registration request with empty hostname, generated default") } + hostinfo.Hostname = hostname nodeToRegister := types.NewRegisterNode( diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 1677642f..8a012ff6 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -2,6 +2,7 @@ package hscontrol import ( "context" + "errors" "fmt" "net/url" "strings" @@ -16,14 +17,14 @@ import ( "tailscale.com/types/key" ) -// Interactive step type constants +// Interactive step type constants. const ( stepTypeInitialRequest = "initial_request" stepTypeAuthCompletion = "auth_completion" stepTypeFollowupRequest = "followup_request" ) -// interactiveStep defines a step in the interactive authentication workflow +// interactiveStep defines a step in the interactive authentication workflow. type interactiveStep struct { stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest expectAuthURL bool @@ -75,6 +76,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -129,6 +131,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) if err != nil { return "", err @@ -163,6 +166,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify both nodes exist node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) assert.True(t, found2) assert.Equal(t, "reusable-node-1", node1.Hostname()) @@ -196,6 +200,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) if err != nil { return "", err @@ -227,6 +232,7 @@ func TestAuthenticationFlows(t *testing.T) { // First node should exist, second should not _, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) _, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) assert.False(t, found2) }, @@ -272,6 +278,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -391,6 +398,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -400,8 +408,10 @@ func TestAuthenticationFlows(t *testing.T) { // Wait for node to be available in NodeStore with debug info var attemptCount int + require.EventuallyWithT(t, func(c *assert.CollectT) { attemptCount++ + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) if assert.True(c, found, "node should be available in NodeStore") { t.Logf("Node found in NodeStore after %d attempts", attemptCount) @@ -451,6 +461,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -500,6 +511,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -549,25 +561,31 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = 
app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err } // Wait for node to be available in NodeStore - var node types.NodeView - var found bool + var ( + node types.NodeView + found bool + ) + require.EventuallyWithT(t, func(c *assert.CollectT) { node, found = app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(c, found, "node should be available in NodeStore") }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + if !found { - return "", fmt.Errorf("node not found after setup") + return "", errors.New("node not found after setup") } // Expire the node expiredTime := time.Now().Add(-1 * time.Hour) _, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime) + return "", err }, request: func(_ string) tailcfg.RegisterRequest { @@ -610,6 +628,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -673,6 +692,7 @@ func TestAuthenticationFlows(t *testing.T) { // and handleRegister will receive the value when it starts waiting go func() { user := app.state.CreateUserForTest("followup-user") + node := app.state.CreateNodeForTest(user, "followup-success-node") registered <- node }() @@ -782,6 +802,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -821,6 +842,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -865,6 +887,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -898,6 +921,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -922,6 +946,7 @@ func TestAuthenticationFlows(t *testing.T) { node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found) assert.Equal(t, "tagged-pak-node", node.Hostname()) + if node.AuthKey().Valid() { assert.NotEmpty(t, node.AuthKey().Tags()) } @@ -1031,6 +1056,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1047,6 +1073,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak2.Key, nil }, request: func(newAuthKey string) tailcfg.RegisterRequest { @@ -1099,6 +1126,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1161,6 +1189,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1177,6 +1206,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pakRotation.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1226,6 +1256,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1265,6 +1296,7 @@ func 
TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1429,6 +1461,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify custom hostinfo was preserved through interactive workflow node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be found after interactive registration") + if found { assert.Equal(t, "custom-interactive-node", node.Hostname()) assert.Equal(t, "linux", node.Hostinfo().OS()) @@ -1455,6 +1488,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1520,6 +1554,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify registration ID was properly generated and used node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered after interactive workflow") + if found { assert.Equal(t, "registration-id-test-node", node.Hostname()) assert.Equal(t, "test-os", node.Hostinfo().OS()) @@ -1535,6 +1570,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1577,6 +1613,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1632,6 +1669,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1648,6 +1686,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak2.Key, nil }, request: func(user2AuthKey string) tailcfg.RegisterRequest { @@ -1712,6 +1751,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) if err != nil { return "", err @@ -1838,6 +1878,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1932,6 +1973,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) if err != nil { return "", err @@ -2097,6 +2139,7 @@ func TestAuthenticationFlows(t *testing.T) { // Collect results - at least one should succeed successCount := 0 + for range numConcurrent { select { case err := <-results: @@ -2217,6 +2260,7 @@ func TestAuthenticationFlows(t *testing.T) { // Should handle nil hostinfo gracefully node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered despite nil hostinfo") + if found { // Should have some default hostname or handle nil gracefully hostname := node.Hostname() @@ -2315,12 +2359,14 @@ func TestAuthenticationFlows(t *testing.T) { resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) require.NoError(t, err) + authURL2 := resp2.AuthURL assert.Contains(t, authURL2, "/register/") // Both should have different registration IDs regID1, err1 := extractRegistrationIDFromAuthURL(authURL1) regID2, err2 := extractRegistrationIDFromAuthURL(authURL2) + require.NoError(t, err1) require.NoError(t, 
err2) assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") @@ -2328,6 +2374,7 @@ func TestAuthenticationFlows(t *testing.T) { // Both cache entries should exist simultaneously _, found1 := app.state.GetRegistrationCacheEntry(regID1) _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first registration cache entry should exist") assert.True(t, found2, "second registration cache entry should exist") @@ -2371,6 +2418,7 @@ func TestAuthenticationFlows(t *testing.T) { resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) require.NoError(t, err) + authURL2 := resp2.AuthURL regID2, err := extractRegistrationIDFromAuthURL(authURL2) require.NoError(t, err) @@ -2378,6 +2426,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify both exist _, found1 := app.state.GetRegistrationCacheEntry(regID1) _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first cache entry should exist") assert.True(t, found2, "second cache entry should exist") @@ -2403,6 +2452,7 @@ func TestAuthenticationFlows(t *testing.T) { errorChan <- err return } + responseChan <- resp }() @@ -2430,6 +2480,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify the node was created with the second registration's data node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered") + if found { assert.Equal(t, "pending-node-2", node.Hostname()) assert.Equal(t, "second-registration-user", node.User().Name()) @@ -2463,8 +2514,10 @@ func TestAuthenticationFlows(t *testing.T) { // Set up context with timeout for followup tests ctx := context.Background() + if req.Followup != "" { var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() } @@ -2516,7 +2569,7 @@ func TestAuthenticationFlows(t *testing.T) { } } -// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow +// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow. func runInteractiveWorkflowTest(t *testing.T, tt struct { name string setupFunc func(*testing.T, *Headscale) (string, error) @@ -2597,6 +2650,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { errorChan <- err return } + responseChan <- resp }() @@ -2650,24 +2704,27 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { if responseToValidate == nil { responseToValidate = initialResp } + tt.validate(t, responseToValidate, app) } } -// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL +// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL. func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { // AuthURL format: "http://localhost/register/abc123" const registerPrefix = "/register/" + idx := strings.LastIndex(authURL, registerPrefix) if idx == -1 { return "", fmt.Errorf("invalid AuthURL format: %s", authURL) } idStr := authURL[idx+len(registerPrefix):] + return types.RegistrationIDFromString(idStr) } -// validateCompleteRegistrationResponse performs comprehensive validation of a registration response +// validateCompleteRegistrationResponse performs comprehensive validation of a registration response. 
func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) { // Basic response validation require.NotNil(t, resp, "response should not be nil") @@ -2681,7 +2738,7 @@ func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterRe // Additional validation can be added here as needed } -// Simple test to validate basic node creation and lookup +// Simple test to validate basic node creation and lookup. func TestNodeStoreLookup(t *testing.T) { app := createTestApp(t) @@ -2713,8 +2770,10 @@ func TestNodeStoreLookup(t *testing.T) { // Wait for node to be available in NodeStore var node types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { var found bool + node, found = app.state.GetNodeByNodeKey(nodeKey.Public()) assert.True(c, found, "Node should be found in NodeStore") }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore") @@ -2783,8 +2842,10 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Get the node ID var registeredNode types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { var found bool + registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public()) assert.True(c, found, "Node should be found in NodeStore") }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available") @@ -2796,6 +2857,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Verify initial state: user1 has 2 nodes, user2 has 2 nodes user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID)) user2Nodes := app.state.ListNodesByUser(types.UserID(user2.ID)) + require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially") require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially") @@ -2876,6 +2938,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Verify new nodes were created for user1 with the same machine keys t.Logf("Verifying new nodes created for user1 from user2's machine keys...") + for i := 2; i < 4; i++ { node := nodes[i] // Should be able to find a node with user1 and this machine key (the new one) @@ -2899,7 +2962,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Expected behavior: // - User1's original node should STILL EXIST (expired) // - User2 should get a NEW node created (NOT transfer) -// - Both nodes share the same machine key (same physical device) +// - Both nodes share the same machine key (same physical device). func TestWebFlowReauthDifferentUser(t *testing.T) { machineKey := key.NewMachine() nodeKey1 := key.NewNode() @@ -3043,6 +3106,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { // Count nodes per user user1Nodes := 0 user2Nodes := 0 + for i := 0; i < allNodesSlice.Len(); i++ { n := allNodesSlice.At(i) if n.UserID().Get() == user1.ID { @@ -3060,7 +3124,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { }) } -// Helper function to create test app +// Helper function to create test app. 
func createTestApp(t *testing.T) *Headscale { t.Helper() @@ -3147,6 +3211,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { } t.Log("Step 1: Initial registration with pre-auth key") + initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public()) require.NoError(t, err, "initial registration should succeed") require.NotNil(t, initialResp) @@ -3172,6 +3237,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // - System reboots // The Tailscale client persists the pre-auth key in its state and sends it on every registration t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key") + restartReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ AuthKey: pakNew.Key, // Same key, now marked as Used=true @@ -3189,9 +3255,11 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // This is the assertion that currently FAILS in v0.27.0 assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed") + if err != nil { t.Logf("Error received (this is the bug): %v", err) t.Logf("Expected behavior: Node should be able to re-register with the same pre-auth key it used initially") + return // Stop here to show the bug clearly } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 05a4c7c8..988675b9 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -155,6 +155,7 @@ AND auth_key_id NOT IN ( nodeRoutes := map[uint64][]netip.Prefix{} var routes []types.Route + err = tx.Find(&routes).Error if err != nil { return fmt.Errorf("fetching routes: %w", err) @@ -255,9 +256,11 @@ AND auth_key_id NOT IN ( // Check if routes table exists and drop it (should have been migrated already) var routesExists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists) if err == nil && routesExists { log.Info().Msg("Dropping leftover routes table") + if err := tx.Exec("DROP TABLE routes").Error; err != nil { return fmt.Errorf("dropping routes table: %w", err) } @@ -280,6 +283,7 @@ AND auth_key_id NOT IN ( for _, table := range tablesToRename { // Check if table exists before renaming var exists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists) if err != nil { return fmt.Errorf("checking if table %s exists: %w", table, err) @@ -761,6 +765,7 @@ AND auth_key_id NOT IN ( // or else it blocks... 
sqlConn.SetMaxIdleConns(maxIdleConns) + sqlConn.SetMaxOpenConns(maxOpenConns) defer sqlConn.SetMaxIdleConns(1) defer sqlConn.SetMaxOpenConns(1) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 3cd0d14e..47a527b9 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -44,6 +44,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { // Verify api_keys data preservation var apiKeyCount int + err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error require.NoError(t, err) assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema") @@ -186,6 +187,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { func requireConstraintFailed(t *testing.T, err error) { t.Helper() require.Error(t, err) + if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") { require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error()) } @@ -401,6 +403,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase { // skip already-applied migrations and only run new ones. func TestSQLiteAllTestdataMigrations(t *testing.T) { t.Parallel() + schemas, err := os.ReadDir("testdata/sqlite") require.NoError(t, err) diff --git a/hscontrol/db/ephemeral_garbage_collector_test.go b/hscontrol/db/ephemeral_garbage_collector_test.go index d118b7fd..2ad50885 100644 --- a/hscontrol/db/ephemeral_garbage_collector_test.go +++ b/hscontrol/db/ephemeral_garbage_collector_test.go @@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Basic deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex - var deletionWg sync.WaitGroup + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + deletionWg sync.WaitGroup + ) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionWg.Done() } @@ -43,10 +47,13 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { go gc.Start() // Schedule several nodes for deletion with short expiry - const expiry = fifty - const numNodes = 100 + const ( + expiry = fifty + numNodes = 100 + ) // Set up wait group for expected deletions + deletionWg.Add(numNodes) for i := 1; i <= numNodes; i++ { @@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { // and then reschedules it with a shorter expiry, and verifies that the node is deleted only once. func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionNotifier <- nodeID @@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // Start GC gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() defer gc.Close() - const shortExpiry = fifty - const longExpiry = 1 * time.Hour + const ( + shortExpiry = fifty + longExpiry = 1 * time.Hour + ) nodeID := types.NodeID(1) @@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // and verifies that the node is deleted only once. 
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) { // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) + deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + deletionNotifier <- nodeID } // Start the GC gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() defer gc.Close() nodeID := types.NodeID(1) + const expiry = fifty // Schedule node for deletion @@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) { // It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted. func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) { // Deletion tracking - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionNotifier <- nodeID @@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Deletion tracking - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) + nodeDeleted := make(chan struct{}) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() close(nodeDeleted) // Signal that deletion happened } @@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { // Use a WaitGroup to ensure the GC has started var startWg sync.WaitGroup startWg.Add(1) + go func() { startWg.Done() // Signal that the goroutine has started gc.Start() }() + startWg.Wait() // Wait for the GC to start // Close GC right away @@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { // Check no node was deleted deleteMutex.Lock() + nodesDeleted := len(deletedIDs) + deleteMutex.Unlock() assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close") @@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() } @@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { go gc.Start() // Number of concurrent scheduling goroutines - const numSchedulers = 10 - const nodesPerScheduler = 50 + const ( + numSchedulers = 10 + nodesPerScheduler = 50 + ) const closeAfterNodes = 25 // Close GC after this many nodes per scheduler diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index 73895876..7827e002 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -483,6 +483,7 @@ func TestBackfillIPAddresses(t *testing.T) { func TestIPAllocatorNextNoReservedIPs(t *testing.T) { db, err := newSQLiteTestDB() require.NoError(t, err) + defer db.Close() alloc, err 
:= NewIPAllocator( diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 3887350b..7c818a75 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -206,6 +206,7 @@ func SetTags( slices.Sort(tags) tags = slices.Compact(tags) + b, err := json.Marshal(tags) if err != nil { return err @@ -378,6 +379,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n if ipv4 == nil { ipv4 = oldNode.IPv4 } + if ipv6 == nil { ipv6 = oldNode.IPv6 } @@ -406,6 +408,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n node.IPv6 = ipv6 var err error + node.Hostname, err = util.NormaliseHostname(node.Hostname) if err != nil { newHostname := util.InvalidString() @@ -693,9 +696,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname . } var registeredNode *types.Node + err = hsdb.DB.Transaction(func(tx *gorm.DB) error { var err error + registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6) + return err }) if err != nil { diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index e82cdb62..3696aa2e 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -497,6 +497,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes1) == 0 { expectedRoutes1 = nil } + if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } @@ -508,6 +509,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes2) == 0 { expectedRoutes2 = nil } + if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } @@ -745,12 +747,15 @@ func TestNodeNaming(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) if err != nil { return err } + _, err = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil) _, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil) + return err }) require.NoError(t, err) @@ -999,6 +1004,7 @@ func TestListPeers(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err @@ -1084,6 +1090,7 @@ func TestListNodes(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err diff --git a/hscontrol/db/sqliteconfig/config.go b/hscontrol/db/sqliteconfig/config.go index d27977a4..23cb4b50 100644 --- a/hscontrol/db/sqliteconfig/config.go +++ b/hscontrol/db/sqliteconfig/config.go @@ -372,18 +372,23 @@ func (c *Config) ToURL() (string, error) { if c.BusyTimeout > 0 { pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout)) } + if c.JournalMode != "" { pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode)) } + if c.AutoVacuum != "" { pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum)) } + if c.WALAutocheckpoint >= 0 { pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint)) } + if c.Synchronous != "" { pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous)) } + if c.ForeignKeys { pragmas = append(pragmas, "foreign_keys=ON") } diff --git a/hscontrol/db/sqliteconfig/config_test.go b/hscontrol/db/sqliteconfig/config_test.go index 66955bb9..7829d9e9 100644 --- a/hscontrol/db/sqliteconfig/config_test.go +++ b/hscontrol/db/sqliteconfig/config_test.go @@ -294,6 +294,7 @@ 
func TestConfigToURL(t *testing.T) { t.Errorf("Config.ToURL() error = %v", err) return } + if got != tt.want { t.Errorf("Config.ToURL() = %q, want %q", got, tt.want) } @@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) { Path: "", BusyTimeout: -1, } + _, err := config.ToURL() if err == nil { t.Error("Config.ToURL() with invalid config should return error") diff --git a/hscontrol/db/sqliteconfig/integration_test.go b/hscontrol/db/sqliteconfig/integration_test.go index bb54ea1e..b411daeb 100644 --- a/hscontrol/db/sqliteconfig/integration_test.go +++ b/hscontrol/db/sqliteconfig/integration_test.go @@ -109,7 +109,9 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { for pragma, expectedValue := range tt.expected { t.Run("pragma_"+pragma, func(t *testing.T) { var actualValue any + query := "PRAGMA " + pragma + err := db.QueryRow(query).Scan(&actualValue) if err != nil { t.Fatalf("Failed to query %s: %v", query, err) @@ -249,6 +251,7 @@ func TestJournalModeValidation(t *testing.T) { defer db.Close() var actualMode string + err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode) if err != nil { t.Fatalf("Failed to query journal_mode: %v", err) diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go index 46bd154f..102c0e9c 100644 --- a/hscontrol/db/text_serialiser.go +++ b/hscontrol/db/text_serialiser.go @@ -42,6 +42,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect if dbValue != nil { var bytes []byte + switch v := dbValue.(type) { case []byte: bytes = v @@ -55,6 +56,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect maybeInstantiatePtr(fieldValue) f := fieldValue.MethodByName("UnmarshalText") args := []reflect.Value{reflect.ValueOf(bytes)} + ret := f.Call(args) if !ret[0].IsNil() { return decodingError(field.Name, ret[0].Interface().(error)) @@ -89,6 +91,7 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) { return nil, nil } + b, err := v.MarshalText() if err != nil { return nil, err diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 6aff9ed1..650dbd49 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -88,10 +88,12 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user") // not exist or if another User exists with the new name. 
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { var err error + oldUser, err := GetUserByID(tx, uid) if err != nil { return err } + if err = util.ValidateHostname(newName); err != nil { return err } diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 629b7be1..4fdcac11 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -25,17 +25,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { overview := h.state.DebugOverviewJSON() + overviewJSON, err := json.MarshalIndent(overview, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(overviewJSON) } else { // Default to text/plain for backward compatibility overview := h.state.DebugOverview() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(overview)) @@ -45,11 +48,13 @@ func (h *Headscale) debugHTTPServer() *http.Server { // Configuration endpoint debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { config := h.state.DebugConfig() + configJSON, err := json.MarshalIndent(config, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(configJSON) @@ -70,6 +75,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { } else { w.Header().Set("Content-Type", "text/plain") } + w.WriteHeader(http.StatusOK) w.Write([]byte(policy)) })) @@ -81,11 +87,13 @@ func (h *Headscale) debugHTTPServer() *http.Server { httpError(w, err) return } + filterJSON, err := json.MarshalIndent(filter, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(filterJSON) @@ -94,11 +102,13 @@ func (h *Headscale) debugHTTPServer() *http.Server { // SSH policies endpoint debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sshPolicies := h.state.DebugSSHPolicies() + sshJSON, err := json.MarshalIndent(sshPolicies, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(sshJSON) @@ -112,17 +122,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { derpInfo := h.state.DebugDERPJSON() + derpJSON, err := json.MarshalIndent(derpInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(derpJSON) } else { // Default to text/plain for backward compatibility derpInfo := h.state.DebugDERPMap() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(derpInfo)) @@ -137,17 +150,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { nodeStoreNodes := h.state.DebugNodeStoreJSON() + nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(nodeStoreJSON) } else { // Default to text/plain for backward compatibility nodeStoreInfo := h.state.DebugNodeStore() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(nodeStoreInfo)) @@ -157,11 +173,13 @@ func (h *Headscale) debugHTTPServer() *http.Server { // Registration cache endpoint debug.Handle("registration-cache", "Registration cache information", 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cacheInfo := h.state.DebugRegistrationCache() + cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(cacheJSON) @@ -175,17 +193,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { routes := h.state.DebugRoutes() + routesJSON, err := json.MarshalIndent(routes, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(routesJSON) } else { // Default to text/plain for backward compatibility routes := h.state.DebugRoutesString() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(routes)) @@ -200,17 +221,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { policyManagerInfo := h.state.DebugPolicyManagerJSON() + policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(policyManagerJSON) } else { // Default to text/plain for backward compatibility policyManagerInfo := h.state.DebugPolicyManager() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(policyManagerInfo)) @@ -227,6 +251,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { if res == nil { w.WriteHeader(http.StatusOK) w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) + return } @@ -235,6 +260,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(resJSON) @@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string { activeConnections: info.ActiveConnections, }) totalNodes++ + if info.Connected { connectedCount++ } @@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string { activeConnections: 0, }) totalNodes++ + if connected { connectedCount++ } + return true }) } @@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo { ActiveConnections: 0, } info.TotalNodes++ + return true }) } diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index 42d74abe..f3807e21 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -134,6 +134,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) { for id := range dm.Regions { ids = append(ids, id) } + slices.Sort(ids) for _, id := range ids { @@ -164,12 +165,14 @@ func derpRandom() *rand.Rand { rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) derpRandomInst = rnd }) + return derpRandomInst } func resetDerpRandomForTesting() { derpRandomMu.Lock() defer derpRandomMu.Unlock() + derpRandomOnce = sync.Once{} derpRandomInst = nil } diff --git a/hscontrol/derp/derp_test.go b/hscontrol/derp/derp_test.go index 91d605a6..445c1044 100644 --- a/hscontrol/derp/derp_test.go +++ b/hscontrol/derp/derp_test.go @@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { viper.Set("dns.base_domain", tt.baseDomain) + defer viper.Reset() + resetDerpRandomForTesting() testMap := tt.derpMap.View().AsStruct() diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index c736da28..bf292d03 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -74,9 
+74,11 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { if err != nil { return tailcfg.DERPRegion{}, err } - var host string - var port int - var portStr string + var ( + host string + port int + portStr string + ) // Extract hostname and port from URL host, portStr, err = net.SplitHostPort(serverURL.Host) @@ -205,6 +207,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques return } defer websocketConn.Close(websocket.StatusInternalError, "closing") + if websocketConn.Subprotocol() != "derp" { websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol") @@ -309,6 +312,7 @@ func DERPBootstrapDNSHandler( resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute) defer cancel() var resolver net.Resolver + for _, region := range derpMap.Regions().All() { for _, node := range region.Nodes().All() { // we don't care if we override some nodes addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName()) @@ -320,6 +324,7 @@ func DERPBootstrapDNSHandler( continue } + dnsEntries[node.HostName()] = addrs } } diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index 82b3078b..5d16c675 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -85,12 +85,15 @@ func (e *ExtraRecordsMan) Run() { log.Error().Caller().Msgf("file watcher event channel closing") return } + switch event.Op { case fsnotify.Create, fsnotify.Write, fsnotify.Chmod: log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event") + if event.Name != e.path { continue } + e.updateRecords() // If a file is removed or renamed, fsnotify will loose track of it @@ -123,6 +126,7 @@ func (e *ExtraRecordsMan) Run() { log.Error().Caller().Msgf("file watcher error channel closing") return } + log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err) } } @@ -165,6 +169,7 @@ func (e *ExtraRecordsMan) updateRecords() { e.hashes[e.path] = newHash log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len()) + e.updateCh <- e.records.Slice() } @@ -183,6 +188,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error } var records []tailcfg.DNSRecord + err = json.Unmarshal(b, &records) if err != nil { return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err) diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 2aee3cb2..7ec26994 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -181,6 +181,7 @@ func (h *Headscale) HealthHandler( json.NewEncoder(writer).Encode(res) } + err := h.state.PingDB(req.Context()) if err != nil { respond(err) @@ -217,6 +218,7 @@ func (h *Headscale) VersionHandler( writer.WriteHeader(http.StatusOK) versionInfo := types.GetVersionInfo() + err := json.NewEncoder(writer).Encode(versionInfo) if err != nil { log.Error(). 
diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 1d9c2c32..918b7049 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -2,6 +2,7 @@ package mapper import ( "crypto/rand" + "encoding/hex" "errors" "fmt" "sync" @@ -77,6 +78,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse if err != nil { log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed") nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) } @@ -86,10 +88,11 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse case c <- initialMap: // Success case <-time.After(5 * time.Second): - log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout") + log.Error().Uint64("node.id", id.Uint64()).Err(errors.New("timeout")).Msg("Initial map send timeout") log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second). Msg("Initial map send timed out because channel was blocked or receiver not ready") nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to send initial map to node %d: timeout", id) } @@ -129,6 +132,7 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo log.Debug().Caller().Uint64("node.id", id.Uint64()). Int("active.connections", nodeConn.getActiveConnectionCount()). Msg("Node connection removed but keeping online because other connections remain") + return true // Node still has active connections } @@ -211,10 +215,12 @@ func (b *LockFreeBatcher) worker(workerID int) { // This is used for synchronous map generation. if w.resultCh != nil { var result workResult + if nc, exists := b.nodes.Load(w.nodeID); exists { var err error result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) + result.err = err if result.err != nil { b.workErrors.Add(1) @@ -397,6 +403,7 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() { } } } + return true }) @@ -449,6 +456,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { if nodeConn.hasActiveConnections() { ret.Store(id, true) } + return true }) @@ -464,6 +472,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { ret.Store(id, false) } } + return true }) @@ -518,7 +527,8 @@ type multiChannelNodeConn struct { func generateConnectionID() string { bytes := make([]byte, 8) rand.Read(bytes) - return fmt.Sprintf("%x", bytes) + + return hex.EncodeToString(bytes) } // newMultiChannelNodeConn creates a new multi-channel node connection. @@ -545,11 +555,14 @@ func (mc *multiChannelNodeConn) close() { // addConnection adds a new connection. func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { mutexWaitStart := time.Now() + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id). Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT") mc.mutex.Lock() + mutexWaitDur := time.Since(mutexWaitStart) + defer mc.mutex.Unlock() mc.connections = append(mc.connections, entry) @@ -571,9 +584,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)). Int("remaining_connections", len(mc.connections)). 
Msg("Successfully removed connection") + return true } } + return false } @@ -607,6 +622,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { // This is not an error - the node will receive a full map when it reconnects log.Debug().Caller().Uint64("node.id", mc.id.Uint64()). Msg("send: skipping send to node with no active connections (likely rapid reconnection)") + return nil // Return success instead of error } @@ -615,7 +631,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Msg("send: broadcasting to all connections") var lastErr error + successCount := 0 + var failedConnections []int // Track failed connections for removal // Send to all connections @@ -626,6 +644,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { if err := conn.send(data); err != nil { lastErr = err + failedConnections = append(failedConnections, i) log.Warn().Err(err). Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). @@ -633,6 +652,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Msg("send: connection send failed") } else { successCount++ + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). Str("conn.id", conn.id).Int("connection_index", i). Msg("send: successfully sent to connection") @@ -797,6 +817,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { Connected: connected, ActiveConnections: activeConnCount, } + return true }) @@ -811,6 +832,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { ActiveConnections: 0, } } + return true }) diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 00053892..595fb252 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -677,6 +677,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { connectedCount := 0 + for i := range allNodes { node := &allNodes[i] @@ -694,6 +695,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { }, 5*time.Minute, 5*time.Second, "waiting for full connectivity") t.Logf("✅ All nodes achieved full connectivity!") + totalTime := time.Since(startTime) // Disconnect all nodes @@ -1309,6 +1311,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { for range i % 3 { runtime.Gosched() // Introduce timing variability } + batcher.RemoveNode(testNode.n.ID, ch) // Yield to allow workers to process and close channels @@ -1392,6 +1395,7 @@ func TestBatcherConcurrentClients(t *testing.T) { // Channel was closed, exit gracefully return } + if valid, reason := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1449,7 +1453,9 @@ func TestBatcherConcurrentClients(t *testing.T) { ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) churningChannelsMutex.Lock() + churningChannels[nodeID] = ch + churningChannelsMutex.Unlock() batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) @@ -1463,6 +1469,7 @@ func TestBatcherConcurrentClients(t *testing.T) { // Channel was closed, exit gracefully return } + if valid, _ := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1495,6 +1502,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for range i % 5 { runtime.Gosched() // Introduce timing variability } + churningChannelsMutex.Lock() ch, exists := churningChannels[nodeID] @@ -1879,6 +1887,7 @@ func XTestBatcherScalability(t *testing.T) { channel, tailcfg.CapabilityVersion(100), ) + 
connectedNodesMutex.Lock() connectedNodes[nodeID] = true @@ -2287,6 +2296,7 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 1: Connect all nodes initially t.Logf("Phase 1: Connecting all nodes...") + for i, node := range allNodes { err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) if err != nil { @@ -2303,6 +2313,7 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 2: Rapid disconnect ALL nodes (simulating nodes going down) t.Logf("Phase 2: Rapid disconnect all nodes...") + for i, node := range allNodes { removed := batcher.RemoveNode(node.n.ID, node.ch) t.Logf("Node %d RemoveNode result: %t", i, removed) @@ -2310,9 +2321,11 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up) t.Logf("Phase 3: Rapid reconnect with new channels...") + newChannels := make([]chan *tailcfg.MapResponse, len(allNodes)) for i, node := range allNodes { newChannels[i] = make(chan *tailcfg.MapResponse, 10) + err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100)) if err != nil { t.Errorf("Failed to reconnect node %d: %v", i, err) @@ -2343,11 +2356,13 @@ func TestBatcherRapidReconnection(t *testing.T) { if infoMap, ok := info.(map[string]any); ok { if connected, ok := infoMap["connected"].(bool); ok && !connected { disconnectedCount++ + t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i) } } } else { disconnectedCount++ + t.Logf("Node %d missing from debug info entirely", i) } @@ -2382,6 +2397,7 @@ func TestBatcherRapidReconnection(t *testing.T) { case update := <-newChannels[i]: if update != nil { receivedCount++ + t.Logf("Node %d received update successfully", i) } case <-timeout: @@ -2414,6 +2430,7 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 1: Connect first node with initial connection t.Logf("Phase 1: Connecting node 1 with first connection...") + err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add node1: %v", err) @@ -2433,7 +2450,9 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 2: Add second connection for node1 (multi-connection scenario) t.Logf("Phase 2: Adding second connection for node 1...") + secondChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add second connection for node1: %v", err) @@ -2444,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 3: Add third connection for node1 t.Logf("Phase 3: Adding third connection for node 1...") + thirdChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add third connection for node1: %v", err) @@ -2455,6 +2476,7 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 4: Verify debug status shows correct connection count t.Logf("Phase 4: Verifying debug status shows multiple connections...") + if debugBatcher, ok := batcher.(interface { Debug() map[types.NodeID]any }); ok { @@ -2462,6 +2484,7 @@ func TestBatcherMultiConnection(t *testing.T) { if info, exists := debugInfo[node1.n.ID]; exists { t.Logf("Node1 debug info: %+v", info) + if infoMap, ok := info.(map[string]any); ok { if activeConnections, ok := infoMap["active_connections"].(int); ok { if activeConnections != 3 { @@ -2470,6 +2493,7 @@ func TestBatcherMultiConnection(t 
*testing.T) { t.Logf("SUCCESS: Node1 correctly shows 3 active connections") } } + if connected, ok := infoMap["connected"].(bool); ok && !connected { t.Errorf("Node1 should show as connected with 3 active connections") } diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index b6f0b534..df0693e3 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -37,6 +37,7 @@ const ( // NewMapResponseBuilder creates a new builder with basic fields set. func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { now := time.Now() + return &MapResponseBuilder{ resp: &tailcfg.MapResponse{ KeepAlive: false, @@ -124,6 +125,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { b.resp.Debug = &tailcfg.Debug{ DisableLogTail: !b.mapper.cfg.LogTail.Enabled, } + return b } @@ -281,16 +283,18 @@ func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapRe for _, id := range removedIDs { tailscaleIDs = append(tailscaleIDs, id.NodeID()) } + b.resp.PeersRemoved = tailscaleIDs return b } -// Build finalizes the response and returns marshaled bytes +// Build finalizes the response and returns marshaled bytes. func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) { if len(b.errs) > 0 { return nil, multierr.New(b.errs...) } + if debugDumpMapResponsePath != "" { writeDebugMapResponse(b.resp, b.debugType, b.nodeID) } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 616d470f..843729c7 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -60,7 +60,6 @@ func newMapper( state *state.State, ) *mapper { // uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) - return &mapper{ state: state, cfg: cfg, @@ -80,6 +79,7 @@ func generateUserProfiles( userID := user.Model().ID userMap[userID] = &user ids = append(ids, userID) + for _, peer := range peers.All() { peerUser := peer.Owner() peerUserID := peerUser.Model().ID @@ -90,6 +90,7 @@ func generateUserProfiles( slices.Sort(ids) ids = slices.Compact(ids) var profiles []tailcfg.UserProfile + for _, id := range ids { if userMap[id] != nil { profiles = append(profiles, userMap[id].TailscaleUserProfile()) @@ -306,6 +307,7 @@ func writeDebugMapResponse( perms := fs.FileMode(debugMapResponsePerm) mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID)) + err = os.MkdirAll(mPath, perms) if err != nil { panic(err) @@ -319,6 +321,7 @@ func writeDebugMapResponse( ) log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) + err = os.WriteFile(mapResponsePath, body, perms) if err != nil { panic(err) @@ -375,6 +378,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe } var resp tailcfg.MapResponse + err = json.Unmarshal(body, &resp) if err != nil { log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name()) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index a503c08c..4852ce04 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -98,6 +98,7 @@ func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) { if m.polMan == nil { return tailcfg.FilterAllowAll, nil } + return m.polMan.Filter() } @@ -105,6 +106,7 @@ func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { if m.polMan == nil { return nil, nil } + return m.polMan.SSHPolicy(node) } @@ -112,6 +114,7 @@ func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool { if m.polMan == nil { return false } + 
return m.polMan.NodeCanHaveTag(node, tag) } @@ -119,6 +122,7 @@ func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix { if m.primary == nil { return nil } + return m.primary.PrimaryRoutes(nodeID) } @@ -126,6 +130,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ if len(peerIDs) > 0 { // Filter peers by the provided IDs var filtered types.Nodes + for _, peer := range m.peers { if slices.Contains(peerIDs, peer.ID) { filtered = append(filtered, peer) @@ -136,6 +141,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ } // Return all peers except the node itself var filtered types.Nodes + for _, peer := range m.peers { if peer.ID != nodeID { filtered = append(filtered, peer) @@ -149,6 +155,7 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { if len(nodeIDs) > 0 { // Filter nodes by the provided IDs var filtered types.Nodes + for _, node := range m.nodes { if slices.Contains(nodeIDs, node.ID) { filtered = append(filtered, node) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index f0e2fefa..869fe3f3 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -243,10 +243,12 @@ func (ns *noiseServer) NoiseRegistrationHandler( registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { var resp *tailcfg.RegisterResponse + body, err := io.ReadAll(req.Body) if err != nil { return &tailcfg.RegisterRequest{}, regErr(err) } + var regReq tailcfg.RegisterRequest if err := json.Unmarshal(body, ®Req); err != nil { return ®Req, regErr(err) @@ -260,6 +262,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( resp = &tailcfg.RegisterResponse{ Error: httpErr.Msg, } + return ®Req, resp } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 7013b8ed..836e8763 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -163,6 +163,7 @@ func (a *AuthProviderOIDC) RegisterHandler( for k, v := range a.cfg.ExtraParams { extras = append(extras, oauth2.SetAuthURLParam(k, v)) } + extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info @@ -190,6 +191,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } stateCookieName := getCookieName("state", state) + cookieState, err := req.Cookie(stateCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) @@ -212,17 +214,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( httpError(writer, err) return } + if idToken.Nonce == "" { httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err)) return } nonceCookieName := getCookieName("nonce", idToken.Nonce) + nonce, err := req.Cookie(nonceCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) return } + if idToken.Nonce != nonce.Value { httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil)) return @@ -239,6 +244,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Fetch user information (email, groups, name, etc) from the userinfo endpoint // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo var userinfo *oidc.UserInfo + userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token)) if err != nil { util.LogErr(err, "could not get userinfo; only using claims from id token") @@ -255,6 +261,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified) claims.Username = 
cmp.Or(userinfo2.PreferredUsername, claims.Username) claims.Name = cmp.Or(userinfo2.Name, claims.Name) + claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL) if userinfo2.Groups != nil { claims.Groups = userinfo2.Groups @@ -279,6 +286,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( Msgf("could not create or update user") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) + _, werr := writer.Write([]byte("Could not create or update user")) if werr != nil { log.Error(). @@ -299,6 +307,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Register the node if it does not exist. if registrationId != nil { verb := "Reauthenticated" + newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { @@ -307,7 +316,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } + httpError(writer, err) + return } @@ -324,6 +335,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) + if _, err := writer.Write(content.Bytes()); err != nil { util.LogErr(err, "Failed to write HTTP response") } @@ -370,6 +382,7 @@ func (a *AuthProviderOIDC) getOauth2Token( if !ok { return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } + if regInfo.Verifier != nil { exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} } @@ -516,6 +529,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( newUser bool c change.Change ) + user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err) diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index afc3cf68..0c84bae0 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -21,10 +21,13 @@ func (m Match) DebugString() string { sb.WriteString("Match:\n") sb.WriteString(" Sources:\n") + for _, prefix := range m.srcs.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } + sb.WriteString(" Destinations:\n") + for _, prefix := range m.dests.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index f4db88a4..ee112609 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -38,8 +38,11 @@ type PolicyManager interface { // NewPolicyManager returns a new policy manager. 
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) { - var polMan PolicyManager - var err error + var ( + polMan PolicyManager + err error + ) + polMan, err = policyv2.NewPolicyManager(pol, users, nodes) if err != nil { return nil, err @@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ if err != nil { return nil, err } + polMans = append(polMans, pm) } diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 24d2865e..42942f61 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -125,6 +125,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove if !slices.Equal(sortedCurrent, newApproved) { // Log what changed var added, kept []netip.Prefix + for _, route := range newApproved { if !slices.Contains(sortedCurrent, route) { added = append(added, route) diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go index b7a758e6..21c2a66e 100644 --- a/hscontrol/policy/policy_autoapprove_test.go +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -312,8 +312,11 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { nodes := types.Nodes{&node} // Create policy manager or use nil if specified - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + if tt.name != "nil_policy_manager" { pm, err = pmf(users, nodes.ViewSlice()) assert.NoError(t, err) diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index eb3d85b6..ee4818aa 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -32,6 +32,7 @@ func TestReduceNodes(t *testing.T) { rules []tailcfg.FilterRule node *types.Node } + tests := []struct { name string args args @@ -782,9 +783,11 @@ func TestReduceNodes(t *testing.T) { for _, v := range gotViews.All() { got = append(got, v.AsStruct()) } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff) t.Log("Matchers: ") + for _, m := range matchers { t.Log("\t+", m.DebugString()) } @@ -1031,8 +1034,11 @@ func TestReduceNodesFromPolicy(t *testing.T) { for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + pm, err = pmf(nil, tt.nodes.ViewSlice()) require.NoError(t, err) @@ -1050,9 +1056,11 @@ func TestReduceNodesFromPolicy(t *testing.T) { for _, v := range gotViews.All() { got = append(got, v.AsStruct()) } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff) t.Log("Matchers: ") + for _, m := range matchers { t.Log("\t+", m.DebugString()) } @@ -1405,13 +1413,17 @@ func TestSSHPolicyRules(t *testing.T) { for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice()) if tt.expectErr { require.Error(t, err) require.Contains(t, err.Error(), tt.errorMessage) + return } @@ -1434,6 +1446,7 @@ func TestReduceRoutes(t *testing.T) { routes []netip.Prefix rules []tailcfg.FilterRule } + 
tests := []struct { name string args args @@ -2055,6 +2068,7 @@ func TestReduceRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { matchers := matcher.MatchesFromFilterRules(tt.args.rules) + got := ReduceRoutes( tt.args.node.View(), tt.args.routes, diff --git a/hscontrol/policy/policyutil/reduce.go b/hscontrol/policy/policyutil/reduce.go index e4549c10..6d95a297 100644 --- a/hscontrol/policy/policyutil/reduce.go +++ b/hscontrol/policy/policyutil/reduce.go @@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf for _, rule := range rules { // record if the rule is actually relevant for the given node. var dests []tailcfg.NetPortRange + DEST_LOOP: for _, dest := range rule.DstPorts { expanded, err := util.ParseIPSet(dest.IP, nil) diff --git a/hscontrol/policy/policyutil/reduce_test.go b/hscontrol/policy/policyutil/reduce_test.go index bd975d23..0b674981 100644 --- a/hscontrol/policy/policyutil/reduce_test.go +++ b/hscontrol/policy/policyutil/reduce_test.go @@ -823,10 +823,14 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm policy.PolicyManager - var err error + var ( + pm policy.PolicyManager + err error + ) + pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice()) require.NoError(t, err) + got, _ := pm.Filter() t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) got = policyutil.ReduceFilterRules(tt.node.View(), got) diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 5aa5e28c..3d070a25 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -829,6 +829,7 @@ func TestNodeCanApproveRoute(t *testing.T) { if tt.name == "empty policy" { // We expect this one to have a valid but empty policy require.NoError(t, err) + if err != nil { return } @@ -843,6 +844,7 @@ func TestNodeCanApproveRoute(t *testing.T) { if diff := cmp.Diff(tt.canApprove, result); diff != "" { t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff) } + assert.Equal(t, tt.canApprove, result, "Unexpected route approval result") }) } diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 78c6ebc5..3f72cdda 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -45,6 +45,7 @@ func (pol *Policy) compileFilterRules( protocols, _ := acl.Protocol.parseProtocol() var destPorts []tailcfg.NetPortRange + for _, dest := range acl.Destinations { ips, err := dest.Resolve(pol, users, nodes) if err != nil { @@ -127,8 +128,10 @@ func (pol *Policy) compileACLWithAutogroupSelf( node types.NodeView, nodes views.Slice[types.NodeView], ) ([]*tailcfg.FilterRule, error) { - var autogroupSelfDests []AliasWithPorts - var otherDests []AliasWithPorts + var ( + autogroupSelfDests []AliasWithPorts + otherDests []AliasWithPorts + ) for _, dest := range acl.Destinations { if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -139,13 +142,14 @@ func (pol *Policy) compileACLWithAutogroupSelf( } protocols, _ := acl.Protocol.parseProtocol() + var rules []*tailcfg.FilterRule var resolvedSrcIPs []*netipx.IPSet for _, src := range acl.Sources { if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - return nil, fmt.Errorf("autogroup:self cannot be used in sources") + return nil, errors.New("autogroup:self cannot be used 
in sources") } ips, err := src.Resolve(pol, users, nodes) @@ -167,6 +171,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if len(autogroupSelfDests) > 0 { // Pre-filter to same-user untagged devices once - reuse for both sources and destinations sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { if n.User().ID() == node.User().ID() && !n.IsTagged() { sameUserNodes = append(sameUserNodes, n) @@ -176,6 +181,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if len(sameUserNodes) > 0 { // Filter sources to only same-user untagged devices var srcIPs netipx.IPSetBuilder + for _, ips := range resolvedSrcIPs { for _, n := range sameUserNodes { // Check if any of this node's IPs are in the source set @@ -192,6 +198,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if srcSet != nil && len(srcSet.Prefixes()) > 0 { var destPorts []tailcfg.NetPortRange + for _, dest := range autogroupSelfDests { for _, n := range sameUserNodes { for _, port := range dest.Ports { @@ -297,8 +304,10 @@ func (pol *Policy) compileSSHPolicy( // Separate destinations into autogroup:self and others // This is needed because autogroup:self requires filtering sources to same-user only, // while other destinations should use all resolved sources - var autogroupSelfDests []Alias - var otherDests []Alias + var ( + autogroupSelfDests []Alias + otherDests []Alias + ) for _, dst := range rule.Destinations { if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -321,6 +330,7 @@ func (pol *Policy) compileSSHPolicy( } var action tailcfg.SSHAction + switch rule.Action { case SSHActionAccept: action = sshAction(true, 0) @@ -336,9 +346,11 @@ func (pol *Policy) compileSSHPolicy( // by default, we do not allow root unless explicitly stated userMap["root"] = "" } + if rule.Users.ContainsRoot() { userMap["root"] = "root" } + for _, u := range rule.Users.NormalUsers() { userMap[u.String()] = u.String() } @@ -348,6 +360,7 @@ func (pol *Policy) compileSSHPolicy( if len(autogroupSelfDests) > 0 && !node.IsTagged() { // Build destination set for autogroup:self (same-user untagged devices only) var dest netipx.IPSetBuilder + for _, n := range nodes.All() { if n.User().ID() == node.User().ID() && !n.IsTagged() { n.AppendToIPSet(&dest) @@ -364,6 +377,7 @@ func (pol *Policy) compileSSHPolicy( // Filter sources to only same-user untagged devices // Pre-filter to same-user untagged devices for efficiency sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { if n.User().ID() == node.User().ID() && !n.IsTagged() { sameUserNodes = append(sameUserNodes, n) @@ -371,6 +385,7 @@ func (pol *Policy) compileSSHPolicy( } var filteredSrcIPs netipx.IPSetBuilder + for _, n := range sameUserNodes { // Check if any of this node's IPs are in the source set if slices.ContainsFunc(n.IPs(), srcIPs.Contains) { @@ -406,12 +421,14 @@ func (pol *Policy) compileSSHPolicy( if len(otherDests) > 0 { // Build destination set for other destinations var dest netipx.IPSetBuilder + for _, dst := range otherDests { ips, err := dst.Resolve(pol, users, nodes) if err != nil { log.Trace().Caller().Err(err).Msgf("resolving destination ips") continue } + if ips != nil { dest.AddSet(ips) } diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index d798b5f7..663e3d6b 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -589,7 +589,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { if sshPolicy == nil { return // Expected empty result } + assert.Empty(t, 
sshPolicy.Rules, "SSH policy should be empty when no rules match") + return } @@ -670,7 +672,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { } // TestSSHIntegrationReproduction reproduces the exact scenario from the integration test -// TestSSHOneUserToAll that was failing with empty sshUsers +// TestSSHOneUserToAll that was failing with empty sshUsers. func TestSSHIntegrationReproduction(t *testing.T) { // Create users matching the integration test users := types.Users{ @@ -735,7 +737,7 @@ func TestSSHIntegrationReproduction(t *testing.T) { } // TestSSHJSONSerialization verifies that the SSH policy can be properly serialized -// to JSON and that the sshUsers field is not empty +// to JSON and that the sshUsers field is not empty. func TestSSHJSONSerialization(t *testing.T) { users := types.Users{ {Name: "user1", Model: gorm.Model{ID: 1}}, @@ -775,6 +777,7 @@ func TestSSHJSONSerialization(t *testing.T) { // Parse back to verify structure var parsed tailcfg.SSHPolicy + err = json.Unmarshal(jsonData, &parsed) require.NoError(t, err) @@ -859,6 +862,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + if len(rules) != 1 { t.Fatalf("expected 1 rule, got %d", len(rules)) } @@ -875,6 +879,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { found := false addr := netip.MustParseAddr(expectedIP) + for _, prefix := range rule.SrcIPs { pref := netip.MustParsePrefix(prefix) if pref.Contains(addr) { @@ -892,6 +897,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"} for _, excludedIP := range excludedSourceIPs { addr := netip.MustParseAddr(excludedIP) + for _, prefix := range rule.SrcIPs { pref := netip.MustParsePrefix(prefix) if pref.Contains(addr) { @@ -1325,14 +1331,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) { assert.Empty(t, rules3, "user3 should have no rules") } -// Helper function to create IP addresses for testing +// Helper function to create IP addresses for testing. func createAddr(ip string) *netip.Addr { addr, _ := netip.ParseAddr(ip) return &addr } // TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly -// with autogroup:self in destinations +// with autogroup:self in destinations. 
func TestSSHWithAutogroupSelfInDestination(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1380,6 +1386,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // Test for user2's first node @@ -1398,12 +1405,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { for i, p := range rule2.Principals { principalIPs2[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2) // Test for tagged node (should have no SSH rules) node5 := nodes[4].View() sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy3 != nil { assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self") } @@ -1411,7 +1420,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user // is in the source and autogroup:self in destination, only that user's devices -// can SSH (and only if they match the target user) +// can SSH (and only if they match the target user). func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1453,18 +1462,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // For user2's node: should have no rules (user1's devices can't match user2's self) node3 := nodes[2].View() sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1") } } -// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations +// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations. func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1511,19 +1522,21 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // For user3's node: should have no rules (not in group:admins) node5 := nodes[4].View() sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)") } } // TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices -// are excluded from both sources and destinations when autogroup:self is used +// are excluded from both sources and destinations when autogroup:self is used. 
func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1568,6 +1581,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs, "should only include untagged devices") @@ -1575,6 +1589,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { node3 := nodes[2].View() sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self") } @@ -1623,10 +1638,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Verify autogroup:self rule has filtered sources (only same-user devices) selfRule := sshPolicy1.Rules[0] require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices") + selfPrincipals := make([]string, len(selfRule.Principals)) for i, p := range selfRule.Principals { selfPrincipals[i] = p.NodeIP } + require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals, "autogroup:self rule should only include same-user untagged devices") @@ -1638,10 +1655,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") routerRule := sshPolicyRouter.Rules[0] + routerPrincipals := make([]string, len(routerRule.Principals)) for i, p := range routerRule.Principals { routerPrincipals[i] = p.NodeIP } + require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)") require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)") require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)") diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 042c2723..8c07e6cc 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -111,6 +111,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Filter: filter, Policy: pm.pol, }) + filterChanged := filterHash != pm.filterHash if filterChanged { log.Debug(). @@ -120,7 +121,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("filter.rules.new", len(filter)). Msg("Policy filter hash changed") } + pm.filter = filter + pm.filterHash = filterHash if filterChanged { pm.matchers = matcher.MatchesFromFilterRules(pm.filter) @@ -135,6 +138,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { } tagOwnerMapHash := deephash.Hash(&tagMap) + tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash if tagOwnerChanged { log.Debug(). @@ -144,6 +148,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("tagOwners.new", len(tagMap)). Msg("Tag owner hash changed") } + pm.tagOwnerMap = tagMap pm.tagOwnerMapHash = tagOwnerMapHash @@ -153,6 +158,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { } autoApproveMapHash := deephash.Hash(&autoMap) + autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash if autoApproveChanged { log.Debug(). @@ -162,10 +168,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("autoApprovers.new", len(autoMap)). 
Msg("Auto-approvers hash changed") } + pm.autoApproveMap = autoMap pm.autoApproveMapHash = autoApproveMapHash exitSetHash := deephash.Hash(&exitSet) + exitSetChanged := exitSetHash != pm.exitSetHash if exitSetChanged { log.Debug(). @@ -173,6 +181,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Str("exitSet.hash.new", exitSetHash.String()[:8]). Msg("Exit node set hash changed") } + pm.exitSet = exitSet pm.exitSetHash = exitSetHash @@ -199,6 +208,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { if !needsUpdate { log.Trace(). Msg("Policy evaluation detected no changes - all hashes match") + return false, nil } @@ -224,6 +234,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err if err != nil { return nil, fmt.Errorf("compiling SSH policy: %w", err) } + pm.sshPolicyMap[node.ID()] = sshPol return sshPol, nil @@ -318,6 +329,7 @@ func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[typ if err != nil || len(filter) == 0 { continue } + nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter) } @@ -398,6 +410,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil reducedFilter := policyutil.ReduceFilterRules(node, pm.filter) pm.filterRulesMap[node.ID()] = reducedFilter + return reducedFilter, nil } @@ -442,7 +455,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul // This is different from FilterForNode which returns REDUCED rules for packet filtering. // // For global policies: returns the global matchers (same for all nodes) -// For autogroup:self: returns node-specific matchers from unreduced compiled rules +// For autogroup:self: returns node-specific matchers from unreduced compiled rules. func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) { if pm == nil { return nil, nil @@ -474,6 +487,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() + pm.users = users // Clear SSH policy map when users change to force SSH policy recomputation @@ -685,6 +699,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr if pm.exitSet == nil { return false } + if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) { return true } @@ -748,8 +763,10 @@ func (pm *PolicyManager) DebugString() string { } fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap)) + for prefix, approveAddrs := range pm.autoApproveMap { fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range approveAddrs.Ranges() { fmt.Fprintf(&sb, "\t\t%s\n", iprange) } @@ -758,14 +775,17 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap)) + for prefix, tagOwners := range pm.tagOwnerMap { fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range tagOwners.Ranges() { fmt.Fprintf(&sb, "\t\t%s\n", iprange) } } sb.WriteString("\n\n") + if pm.filter != nil { filter, err := json.MarshalIndent(pm.filter, "", " ") if err == nil { @@ -778,6 +798,7 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") sb.WriteString("Matchers:\n") sb.WriteString("an internal structure used to filter nodes and routes\n") + for _, match := range pm.matchers { sb.WriteString(match.DebugString()) sb.WriteString("\n") @@ -785,6 +806,7 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") sb.WriteString("Nodes:\n") + for _, node := range pm.nodes.All() { 
sb.WriteString(node.String()) sb.WriteString("\n") @@ -841,6 +863,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S // Check if IPs changed (simple check - could be more sophisticated) oldIPs := oldNode.IPs() + newIPs := newNode.IPs() if len(oldIPs) != len(newIPs) { affectedUsers[newNode.User().ID()] = struct{}{} @@ -862,6 +885,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S for nodeID := range pm.filterRulesMap { // Find the user for this cached node var nodeUserID uint + found := false // Check in new nodes first @@ -869,6 +893,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S if node.ID() == nodeID { nodeUserID = node.User().ID() found = true + break } } @@ -879,6 +904,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S if node.ID() == nodeID { nodeUserID = node.User().ID() found = true + break } } diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 80c08eed..4477e8b1 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -56,6 +56,7 @@ func TestPolicyManager(t *testing.T) { if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff( tt.wantMatchers, matchers, @@ -176,13 +177,16 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { t.Run(tt.name, func(t *testing.T) { for i, n := range tt.newNodes { found := false + for _, origNode := range initialNodes { if n.Hostname == origNode.Hostname { n.ID = origNode.ID found = true + break } } + if !found { n.ID = types.NodeID(len(initialNodes) + i + 1) } @@ -369,7 +373,7 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) { // TestAutogroupSelfReducedVsUnreducedRules verifies that: // 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships -// 2. FilterForNode returns reduced compiled rules for packet filters +// 2. FilterForNode returns reduced compiled rules for packet filters. func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"} user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"} @@ -409,6 +413,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { // FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations // For node1, destinations should only be node1's IPs node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"} + for _, rule := range filterNode1 { for _, dst := range rule.DstPorts { require.Contains(t, node1IPs, dst.IP, @@ -418,6 +423,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { // For node2, destinations should only be node2's IPs node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"} + for _, rule := range filterNode2 { for _, dst := range rule.DstPorts { require.Contains(t, node2IPs, dst.IP, diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index fbce8a2b..3fe5a0d4 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -21,7 +21,7 @@ import ( "tailscale.com/util/slicesx" ) -// Global JSON options for consistent parsing across all struct unmarshaling +// Global JSON options for consistent parsing across all struct unmarshaling. 
var policyJSONOpts = []json.Options{ json.DefaultOptionsV2(), json.MatchCaseInsensitiveNames(true), @@ -58,6 +58,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { } var alias string + switch v := a.Alias.(type) { case *Username: alias = string(*v) @@ -89,6 +90,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { // Otherwise, format as "alias:ports" var ports []string + for _, port := range a.Ports { if port.First == port.Last { ports = append(ports, strconv.FormatUint(uint64(port.First), 10)) @@ -123,6 +125,7 @@ func (u Username) Validate() error { if isUser(string(u)) { return nil } + return fmt.Errorf("Username has to contain @, got: %q", u) } @@ -194,8 +197,10 @@ func (u Username) resolveUser(users types.Users) (types.User, error) { } func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) user, err := u.resolveUser(users) if err != nil { @@ -228,6 +233,7 @@ func (g Group) Validate() error { if isGroup(string(g)) { return nil } + return fmt.Errorf(`Group has to start with "group:", got: %q`, g) } @@ -268,8 +274,10 @@ func (g Group) MarshalJSON() ([]byte, error) { } func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) for _, user := range p.Groups[g] { uips, err := user.Resolve(nil, users, nodes) @@ -290,6 +298,7 @@ func (t Tag) Validate() error { if isTag(string(t)) { return nil } + return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) } @@ -339,6 +348,7 @@ func (h Host) Validate() error { if isHost(string(h)) { return nil } + return fmt.Errorf("Hostname %q is invalid", h) } @@ -352,13 +362,16 @@ func (h *Host) UnmarshalJSON(b []byte) error { } func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) pref, ok := p.Hosts[h] if !ok { return nil, fmt.Errorf("unable to resolve host: %q", h) } + err := pref.Validate() if err != nil { errs = append(errs, err) @@ -376,6 +389,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView if err != nil { errs = append(errs, err) } + for _, node := range nodes.All() { if node.InIPSet(ipsTemp) { node.AppendToIPSet(&ips) @@ -391,6 +405,7 @@ func (p Prefix) Validate() error { if netip.Prefix(p).IsValid() { return nil } + return fmt.Errorf("Prefix %q is invalid", p) } @@ -404,6 +419,7 @@ func (p *Prefix) parseString(addr string) error { if err != nil { return err } + addrPref, err := addr.Prefix(addr.BitLen()) if err != nil { return err @@ -418,6 +434,7 @@ func (p *Prefix) parseString(addr string) error { if err != nil { return err } + *p = Prefix(pref) return nil @@ -428,6 +445,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { if err != nil { return err } + if err := p.Validate(); err != nil { return err } @@ -441,8 +459,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { // // See [Policy], [types.Users], and [types.Nodes] for more details. 
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) ips.AddPrefix(netip.Prefix(p)) // If the IP is a single host, look for a node to ensure we add all the IPs of @@ -587,8 +607,10 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { switch vs := v.(type) { case string: - var portsPart string - var err error + var ( + portsPart string + err error + ) if strings.Contains(vs, ":") { vs, portsPart, err = splitDestinationAndPort(vs) @@ -600,6 +622,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Ports = ports } else { return errors.New(`hostport must contain a colon (":")`) @@ -609,6 +632,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { if err != nil { return err } + if err := ve.Validate(); err != nil { return err } @@ -646,6 +670,7 @@ func isHost(str string) bool { func parseAlias(vs string) (Alias, error) { var pref Prefix + err := pref.parseString(vs) if err == nil { return &pref, nil @@ -690,6 +715,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Alias = ptr return nil @@ -699,6 +725,7 @@ type Aliases []Alias func (a *Aliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -744,8 +771,10 @@ func (a Aliases) MarshalJSON() ([]byte, error) { } func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) for _, alias := range a { aips, err := alias.Resolve(p, users, nodes) @@ -770,6 +799,7 @@ func unmarshalPointer[T any]( parseFunc func(string) (T, error), ) (T, error) { var s string + err := json.Unmarshal(b, &s) if err != nil { var t T @@ -789,6 +819,7 @@ type AutoApprovers []AutoApprover func (aa *AutoApprovers) UnmarshalJSON(b []byte) error { var autoApprovers []AutoApproverEnc + err := json.Unmarshal(b, &autoApprovers, policyJSONOpts...) if err != nil { return err @@ -854,6 +885,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.AutoApprover = ptr return nil @@ -876,6 +908,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Owner = ptr return nil @@ -885,6 +918,7 @@ type Owners []Owner func (o *Owners) UnmarshalJSON(b []byte) error { var owners []OwnerEnc + err := json.Unmarshal(b, &owners, policyJSONOpts...) 
if err != nil { return err @@ -979,11 +1013,13 @@ func (g *Groups) UnmarshalJSON(b []byte) error { // Then validate each field can be converted to []string rawGroups := make(map[string][]string) + for key, value := range rawMap { switch v := value.(type) { case []any: // Convert []interface{} to []string var stringSlice []string + for _, item := range v { if str, ok := item.(string); ok { stringSlice = append(stringSlice, str) @@ -991,6 +1027,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item) } } + rawGroups[key] = stringSlice case string: return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v) @@ -1000,6 +1037,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { } *g = make(Groups) + for key, value := range rawGroups { group := Group(key) // Group name already validated above @@ -1014,6 +1052,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { return err } + usernames = append(usernames, username) } @@ -1033,6 +1072,7 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { } *h = make(Hosts) + for key, value := range rawHosts { host := Host(key) if err := host.Validate(); err != nil { @@ -1076,6 +1116,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { } rawTagOwners := make(map[string][]string) + for tag, owners := range to { tagStr := string(tag) ownerStrs := make([]string, len(owners)) @@ -1152,6 +1193,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. if p == nil { return nil, nil, nil } + var err error routes := make(map[netip.Prefix]*netipx.IPSetBuilder) @@ -1160,6 +1202,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. if _, ok := routes[prefix]; !ok { routes[prefix] = new(netipx.IPSetBuilder) } + for _, autoApprover := range autoApprovers { aa, ok := autoApprover.(Alias) if !ok { @@ -1173,6 +1216,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } var exitNodeSetBuilder netipx.IPSetBuilder + if len(p.AutoApprovers.ExitNode) > 0 { for _, autoApprover := range p.AutoApprovers.ExitNode { aa, ok := autoApprover.(Alias) @@ -1187,11 +1231,13 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } ret := make(map[netip.Prefix]*netipx.IPSet) + for prefix, builder := range routes { ipSet, err := builder.IPSet() if err != nil { return nil, nil, err } + ret[prefix] = ipSet } @@ -1235,6 +1281,7 @@ func (a *Action) UnmarshalJSON(b []byte) error { default: return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept) } + return nil } @@ -1259,6 +1306,7 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error { default: return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str) } + return nil } @@ -1399,7 +1447,7 @@ func (p Protocol) validate() error { return nil case ProtocolWildcard: // Wildcard "*" is not allowed - Tailscale rejects it - return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") + return errors.New("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") default: // Try to parse as a numeric protocol number str := string(p) @@ -1427,7 +1475,7 @@ func (p Protocol) MarshalJSON() ([]byte, error) { return json.Marshal(string(p)) } -// Protocol constants matching the IANA numbers +// Protocol constants matching the IANA numbers. 
const ( protocolICMP = 1 // Internet Control Message protocolIGMP = 2 // Internet Group Management @@ -1464,6 +1512,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error { // Remove any fields that start with '#' filtered := make(map[string]any) + for key, value := range raw { if !strings.HasPrefix(key, "#") { filtered[key] = value @@ -1478,6 +1527,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error { // Create a type alias to avoid infinite recursion type aclAlias ACL + var temp aclAlias // Unmarshal into the temporary struct using the v2 JSON options @@ -1487,6 +1537,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error { // Copy the result back to the original struct *a = ACL(temp) + return nil } @@ -1733,6 +1784,7 @@ func (p *Policy) validate() error { } } } + for _, dst := range ssh.Destinations { switch dst := dst.(type) { case *AutoGroup: @@ -1846,6 +1898,7 @@ func (g Groups) MarshalJSON() ([]byte, error) { for i, username := range usernames { users[i] = string(username) } + raw[string(group)] = users } @@ -1854,6 +1907,7 @@ func (g Groups) MarshalJSON() ([]byte, error) { func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -1877,6 +1931,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -1960,8 +2015,10 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { } func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) for _, alias := range a { aips, err := alias.Resolve(p, users, nodes) @@ -2015,18 +2072,22 @@ func unmarshalPolicy(b []byte) (*Policy, error) { } var policy Policy + ast, err := hujson.Parse(b) if err != nil { return nil, fmt.Errorf("parsing HuJSON: %w", err) } ast.Standardize() + if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { if serr, ok := errors.AsType[*json.SemanticError](err); ok && serr.Err == json.ErrUnknownName { ptr := serr.JSONPointer name := ptr.LastToken() + return nil, fmt.Errorf("unknown field %q", name) } + return nil, fmt.Errorf("parsing policy from bytes: %w", err) } @@ -2073,6 +2134,7 @@ func (p *Policy) usesAutogroupSelf() bool { return true } } + for _, dest := range acl.Destinations { if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { return true @@ -2087,6 +2149,7 @@ func (p *Policy) usesAutogroupSelf() bool { return true } } + for _, dest := range ssh.Destinations { if ag, ok := dest.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { return true diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 8f4f7a85..79d005a3 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -81,6 +81,7 @@ func TestMarshalJSON(t *testing.T) { // Unmarshal back to verify round trip var roundTripped Policy + err = json.Unmarshal(marshalled, &roundTripped) require.NoError(t, err) @@ -2020,6 +2021,7 @@ func TestResolvePolicy(t *testing.T) { } var prefs []netip.Prefix + if ips != nil { if p := ips.Prefixes(); len(p) > 0 { prefs = p @@ -2191,9 +2193,11 @@ func TestResolveAutoApprovers(t *testing.T) { t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr) return } + if diff := cmp.Diff(tt.want, got, cmps...); diff 
!= "" { t.Errorf("resolveAutoApprovers() mismatch (-want +got):\n%s", diff) } + if tt.wantAllIPRoutes != nil { if gotAllIPRoutes == nil { t.Error("resolveAutoApprovers() expected non-nil allIPRoutes, got nil") @@ -2340,6 +2344,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet { for _, p := range prefixes { builder.AddPrefix(mp(p)) } + ipSet, _ := builder.IPSet() return ipSet @@ -2349,6 +2354,7 @@ func ipSetComparer(x, y *netipx.IPSet) bool { if x == nil || y == nil { return x == y } + return cmp.Equal(x.Prefixes(), y.Prefixes(), util.Comparers...) } @@ -2577,6 +2583,7 @@ func TestResolveTagOwners(t *testing.T) { t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr) return } + if diff := cmp.Diff(tt.want, got, cmps...); diff != "" { t.Errorf("resolveTagOwners() mismatch (-want +got):\n%s", diff) } @@ -2852,6 +2859,7 @@ func TestNodeCanHaveTag(t *testing.T) { require.ErrorContains(t, err, tt.wantErr) return } + require.NoError(t, err) got := pm.NodeCanHaveTag(tt.node.View(), tt.tag) @@ -3112,6 +3120,7 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var acl ACL + err := json.Unmarshal([]byte(tt.input), &acl) if tt.wantErr { @@ -3163,6 +3172,7 @@ func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) { // Unmarshal back var unmarshaled ACL + err = json.Unmarshal(jsonBytes, &unmarshaled) require.NoError(t, err) @@ -3241,12 +3251,13 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) { assert.Contains(t, err.Error(), `invalid action "deny"`) } -// Helper function to parse aliases for testing +// Helper function to parse aliases for testing. func mustParseAlias(s string) Alias { alias, err := parseAlias(s) if err != nil { panic(err) } + return alias } diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go index a4367775..80de52bc 100644 --- a/hscontrol/policy/v2/utils.go +++ b/hscontrol/policy/v2/utils.go @@ -18,9 +18,11 @@ func splitDestinationAndPort(input string) (string, string, error) { if lastColonIndex == -1 { return "", "", errors.New("input must contain a colon character separating destination and port") } + if lastColonIndex == 0 { return "", "", errors.New("input cannot start with a colon character") } + if lastColonIndex == len(input)-1 { return "", "", errors.New("input cannot end with a colon character") } @@ -45,6 +47,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { for part := range parts { if strings.Contains(part, "-") { rangeParts := strings.Split(part, "-") + rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool { return e == "" }) diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go index 2084b22f..a845e7a9 100644 --- a/hscontrol/policy/v2/utils_test.go +++ b/hscontrol/policy/v2/utils_test.go @@ -58,9 +58,11 @@ func TestParsePort(t *testing.T) { if err != nil && err.Error() != test.err { t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err) } + if err == nil && test.err != "" { t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err) } + if result != test.expected { t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected) } @@ -92,9 +94,11 @@ func TestParsePortRange(t *testing.T) { if err != nil && err.Error() != test.err { t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err) } + if err == nil && test.err != "" { t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, 
test.err) } + if diff := cmp.Diff(result, test.expected); diff != "" { t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff) } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 02275751..d3c9f1ef 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -152,6 +152,7 @@ func (m *mapSession) serveLongPoll() { // This is not my favourite solution, but it kind of works in our eventually consistent world. ticker := time.NewTicker(time.Second) defer ticker.Stop() + disconnected := true // Wait up to 10 seconds for the node to reconnect. // 10 seconds was arbitrary chosen as a reasonable time to reconnect. @@ -160,6 +161,7 @@ func (m *mapSession) serveLongPoll() { disconnected = false break } + <-ticker.C } @@ -215,8 +217,10 @@ func (m *mapSession) serveLongPoll() { if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { m.errf(err, "failed to add node to batcher") log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session") + return } + log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher") m.h.Change(mapReqChange) diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 72eb2a5b..e3708a13 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -107,9 +107,11 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { Msg("Current primary no longer available") } } + if len(nodes) >= 1 { pr.primaries[prefix] = nodes[0] changed = true + log.Debug(). Caller(). Str("prefix", prefix.String()). @@ -126,6 +128,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { Str("prefix", prefix.String()). Msg("Cleaning up primary route that no longer has available nodes") delete(pr.primaries, prefix) + changed = true } } @@ -161,14 +164,18 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) // If no routes are being set, remove the node from the routes map. if len(prefixes) == 0 { wasPresent := false + if _, ok := pr.routes[node]; ok { delete(pr.routes, node) + wasPresent = true + log.Debug(). Caller(). Uint64("node.id", node.Uint64()). Msg("Removed node from primary routes (no prefixes)") } + changed := pr.updatePrimaryLocked() log.Debug(). Caller(). 
@@ -254,12 +261,14 @@ func (pr *PrimaryRoutes) stringLocked() string { ids := types.NodeIDs(xmaps.Keys(pr.routes)) slices.Sort(ids) + for _, id := range ids { prefixes := pr.routes[id] fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", ")) } fmt.Fprintln(&sb, "\n\nCurrent primary routes:") + for route, nodeID := range pr.primaries { fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID) } diff --git a/hscontrol/routes/primary_test.go b/hscontrol/routes/primary_test.go index 7a9767b2..b03c8f81 100644 --- a/hscontrol/routes/primary_test.go +++ b/hscontrol/routes/primary_test.go @@ -130,6 +130,7 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) pr.SetRoutes(2, mp("192.168.2.0/24")) pr.SetRoutes(1) // Deregister by setting no routes + return pr.SetRoutes(1, mp("192.168.3.0/24")) }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -153,8 +154,9 @@ func TestPrimaryRoutes(t *testing.T) { { name: "multiple-nodes-register-same-route", operations: func(pr *PrimaryRoutes) bool { - pr.SetRoutes(1, mp("192.168.1.0/24")) // false - pr.SetRoutes(2, mp("192.168.1.0/24")) // true + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true + return pr.SetRoutes(3, mp("192.168.1.0/24")) // false }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -182,7 +184,8 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) // false pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary - return pr.SetRoutes(1) // true, 2 primary + + return pr.SetRoutes(1) // true, 2 primary }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ 2: { @@ -393,6 +396,7 @@ func TestPrimaryRoutes(t *testing.T) { operations: func(pr *PrimaryRoutes) bool { pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0")) pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0")) + return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0")) }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -413,15 +417,20 @@ func TestPrimaryRoutes(t *testing.T) { operations: func(pr *PrimaryRoutes) bool { var wg sync.WaitGroup wg.Add(2) + var change1, change2 bool + go func() { defer wg.Done() + change1 = pr.SetRoutes(1, mp("192.168.1.0/24")) }() go func() { defer wg.Done() + change2 = pr.SetRoutes(2, mp("192.168.2.0/24")) }() + wg.Wait() return change1 || change2 @@ -449,17 +458,21 @@ func TestPrimaryRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pr := New() + change := tt.operations(pr) if change != tt.expectedChange { t.Errorf("change = %v, want %v", change, tt.expectedChange) } + comps := append(util.Comparers, cmpopts.EquateEmpty()) if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" { t.Errorf("routes mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" { t.Errorf("primaries mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" { t.Errorf("isPrimary mismatch (-want +got):\n%s", diff) } diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 3ed1d79f..9cad1c04 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -77,6 +77,7 @@ func (s *State) DebugOverview() string { ephemeralCount := 0 now := time.Now() + for _, node := range allNodes.All() { if node.Valid() { userName := node.Owner().Name() @@ -103,17 +104,21 @@ func (s *State) 
DebugOverview() string { // User statistics sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users))) + for userName, nodeCount := range userNodeCounts { sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount)) } + sb.WriteString("\n") // Policy information sb.WriteString("Policy:\n") sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode)) + if s.cfg.Policy.Mode == types.PolicyModeFile { sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path)) } + sb.WriteString("\n") // DERP information @@ -123,6 +128,7 @@ func (s *State) DebugOverview() string { } else { sb.WriteString("DERP: not configured\n") } + sb.WriteString("\n") // Route information @@ -130,6 +136,7 @@ func (s *State) DebugOverview() string { if s.primaryRoutes.String() == "" { routeCount = 0 } + sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount)) sb.WriteString("\n") @@ -165,10 +172,12 @@ func (s *State) DebugDERPMap() string { for _, node := range region.Nodes { sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n", node.Name, node.HostName, node.DERPPort)) + if node.STUNPort != 0 { sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort)) } } + sb.WriteString("\n") } @@ -319,6 +328,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo { if s.primaryRoutes.String() == "" { routeCount = 0 } + info.PrimaryRoutes = routeCount return info diff --git a/hscontrol/state/ephemeral_test.go b/hscontrol/state/ephemeral_test.go index 9f713b3d..5c755687 100644 --- a/hscontrol/state/ephemeral_test.go +++ b/hscontrol/state/ephemeral_test.go @@ -20,6 +20,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { // Create NodeStore store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -43,20 +44,26 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { // 6. If DELETE came after UPDATE, the returned node should be invalid done := make(chan bool, 2) - var updatedNode types.NodeView - var updateOk bool + + var ( + updatedNode types.NodeView + updateOk bool + ) // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) + go func() { updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) { n.LastSeen = new(time.Now()) }) + done <- true }() // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) go func() { store.DeleteNode(node.ID) + done <- true }() @@ -90,6 +97,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -147,6 +155,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) { node := createTestNode(3, 1, "test-user", "test-node-3") store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -203,6 +212,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -213,8 +223,11 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // 1. UpdateNode (from UpdateNodeFromMapRequest during polling) // 2. 
DeleteNode (from handleLogout when client sends logout request) - var updatedNode types.NodeView - var updateOk bool + var ( + updatedNode types.NodeView + updateOk bool + ) + done := make(chan bool, 2) // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) @@ -222,12 +235,14 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { n.LastSeen = new(time.Now()) }) + done <- true }() // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) go func() { store.DeleteNode(ephemeralNode.ID) + done <- true }() @@ -266,7 +281,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // 5. UpdateNode and DeleteNode batch together // 6. UpdateNode returns a valid node (from before delete in batch) // 7. persistNodeToDB is called with the stale valid node -// 8. Node gets re-inserted into database instead of staying deleted +// 8. Node gets re-inserted into database instead of staying deleted. func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5") ephemeralNode.AuthKey = &types.PreAuthKey{ @@ -278,6 +293,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -348,6 +364,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -398,7 +415,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { // 3. UpdateNode and DeleteNode batch together // 4. UpdateNode returns a valid node (from before delete in batch) // 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node -// 6. persistNodeToDB must detect the node is deleted and refuse to persist +// 6. persistNodeToDB must detect the node is deleted and refuse to persist. func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7") ephemeralNode.AuthKey = &types.PreAuthKey{ @@ -408,6 +425,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { } store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() diff --git a/hscontrol/state/maprequest.go b/hscontrol/state/maprequest.go index e7dfc11c..d8cddaa1 100644 --- a/hscontrol/state/maprequest.go +++ b/hscontrol/state/maprequest.go @@ -29,6 +29,7 @@ func netInfoFromMapRequest( Uint64("node.id", nodeID.Uint64()). Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP). Msg("using NetInfo from previous Hostinfo in MapRequest") + return currentHostinfo.NetInfo } diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index 0fa81318..a7d50a07 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -136,7 +136,7 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { }) } -// Simple helper function for tests +// Simple helper function for tests. 
func createTestNodeSimple(id types.NodeID) *types.Node { user := types.User{ Name: "test-user", diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index 6327b46b..5d8d6e85 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -97,6 +97,7 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batc for _, n := range allNodes { nodes[n.ID] = *n } + snap := snapshotFromNodes(nodes, peersFunc) store := &NodeStore{ @@ -165,11 +166,14 @@ func (s *NodeStore) PutNode(n types.Node) types.NodeView { } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("put").Inc() return resultNode @@ -205,11 +209,14 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node) } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("update").Inc() // Return the node and whether it exists (is valid) @@ -229,7 +236,9 @@ func (s *NodeStore) DeleteNode(id types.NodeID) { } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() @@ -262,8 +271,10 @@ func (s *NodeStore) processWrite() { if len(batch) != 0 { s.applyBatch(batch) } + return } + batch = append(batch, w) if len(batch) >= s.batchSize { s.applyBatch(batch) @@ -321,6 +332,7 @@ func (s *NodeStore) applyBatch(batch []work) { w.updateFn(&n) nodes[w.nodeID] = n } + if w.nodeResult != nil { nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) } @@ -349,12 +361,14 @@ func (s *NodeStore) applyBatch(batch []work) { nodeView := node.View() for _, w := range workItems { w.nodeResult <- nodeView + close(w.nodeResult) } } else { // Node was deleted or doesn't exist for _, w := range workItems { w.nodeResult <- types.NodeView{} // Send invalid view + close(w.nodeResult) } } @@ -400,6 +414,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S peersByNode: func() map[types.NodeID][]types.NodeView { peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration) defer peersTimer.ObserveDuration() + return peersFunc(allNodes) }(), nodesByUser: make(map[types.UserID][]types.NodeView), @@ -417,6 +432,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S if newSnap.nodesByMachineKey[n.MachineKey] == nil { newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView) } + newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView } @@ -511,10 +527,12 @@ func (s *NodeStore) DebugString() string { // User distribution (shows internal UserID tracking, not display owner) sb.WriteString("Nodes by Internal User ID:\n") + for userID, nodes := range snapshot.nodesByUser { if len(nodes) > 0 { userName := "unknown" taggedCount := 0 + if len(nodes) > 0 && nodes[0].Valid() { userName = nodes[0].User().Name() // Count tagged nodes (which have UserID set but are owned by "tagged-devices") @@ -532,23 +550,29 @@ func (s *NodeStore) DebugString() string { } } } + sb.WriteString("\n") // Peer relationships summary sb.WriteString("Peer Relationships:\n") + totalPeers := 0 + for nodeID, peers := range snapshot.peersByNode { peerCount := len(peers) + totalPeers += peerCount if node, exists := snapshot.nodesByID[nodeID]; exists { sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n", nodeID, node.Hostname, peerCount)) } } + if len(snapshot.peersByNode) > 0 
{ avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode)) sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers)) } + sb.WriteString("\n") // Node key index @@ -591,6 +615,7 @@ func (s *NodeStore) RebuildPeerMaps() { } s.writeQueue <- w + <-result } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 745850cc..23068b97 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -44,6 +44,7 @@ func TestSnapshotFromNodes(t *testing.T) { nodes := map[types.NodeID]types.Node{ 1: createTestNode(1, 1, "user1", "node1"), } + return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { @@ -192,11 +193,13 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView ret := make(map[types.NodeID][]types.NodeView, len(nodes)) for _, node := range nodes { var peers []types.NodeView + for _, n := range nodes { if n.ID() != node.ID() { peers = append(peers, n) } } + ret[node.ID()] = peers } @@ -207,6 +210,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView ret := make(map[types.NodeID][]types.NodeView, len(nodes)) for _, node := range nodes { var peers []types.NodeView + nodeIsOdd := node.ID()%2 == 1 for _, n := range nodes { @@ -221,6 +225,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView peers = append(peers, n) } } + ret[node.ID()] = peers } @@ -454,10 +459,13 @@ func TestNodeStoreOperations(t *testing.T) { // Add nodes in sequence n1 := store.PutNode(createTestNode(1, 1, "user1", "node1")) assert.True(t, n1.Valid()) + n2 := store.PutNode(createTestNode(2, 2, "user2", "node2")) assert.True(t, n2.Valid()) + n3 := store.PutNode(createTestNode(3, 3, "user3", "node3")) assert.True(t, n3.Valid()) + n4 := store.PutNode(createTestNode(4, 4, "user4", "node4")) assert.True(t, n4.Valid()) @@ -525,16 +533,20 @@ func TestNodeStoreOperations(t *testing.T) { done2 := make(chan struct{}) done3 := make(chan struct{}) - var resultNode1, resultNode2 types.NodeView - var newNode3 types.NodeView - var ok1, ok2 bool + var ( + resultNode1, resultNode2 types.NodeView + newNode3 types.NodeView + ok1, ok2 bool + ) // These should all be processed in the same batch + go func() { resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "batch-updated-node1" n.GivenName = "batch-given-1" }) + close(done1) }() @@ -543,12 +555,14 @@ func TestNodeStoreOperations(t *testing.T) { n.Hostname = "batch-updated-node2" n.GivenName = "batch-given-2" }) + close(done2) }() go func() { node3 := createTestNode(3, 1, "user1", "node3") newNode3 = store.PutNode(node3) + close(done3) }() @@ -601,20 +615,23 @@ func TestNodeStoreOperations(t *testing.T) { // This test verifies that when multiple updates to the same node // are batched together, each returned node reflects ALL changes // in the batch, not just the individual update's changes. 
- done1 := make(chan struct{}) done2 := make(chan struct{}) done3 := make(chan struct{}) - var resultNode1, resultNode2, resultNode3 types.NodeView - var ok1, ok2, ok3 bool + var ( + resultNode1, resultNode2, resultNode3 types.NodeView + ok1, ok2, ok3 bool + ) // These updates all modify node 1 and should be batched together // The final state should have all three modifications applied + go func() { resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "multi-update-hostname" }) + close(done1) }() @@ -622,6 +639,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) { n.GivenName = "multi-update-givenname" }) + close(done2) }() @@ -629,6 +647,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) { n.Tags = []string{"tag1", "tag2"} }) + close(done3) }() @@ -722,14 +741,18 @@ func TestNodeStoreOperations(t *testing.T) { done2 := make(chan struct{}) done3 := make(chan struct{}) - var result1, result2, result3 types.NodeView - var ok1, ok2, ok3 bool + var ( + result1, result2, result3 types.NodeView + ok1, ok2, ok3 bool + ) // Start concurrent updates + go func() { result1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "concurrent-db-hostname" }) + close(done1) }() @@ -737,6 +760,7 @@ func TestNodeStoreOperations(t *testing.T) { result2, ok2 = store.UpdateNode(1, func(n *types.Node) { n.GivenName = "concurrent-db-given" }) + close(done2) }() @@ -744,6 +768,7 @@ func TestNodeStoreOperations(t *testing.T) { result3, ok3 = store.UpdateNode(1, func(n *types.Node) { n.Tags = []string{"concurrent-tag"} }) + close(done3) }() @@ -827,6 +852,7 @@ func TestNodeStoreOperations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := tt.setupFunc(t) + store.Start() defer store.Stop() @@ -846,10 +872,11 @@ type testStep struct { // --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests --- -// Helper for concurrent test nodes +// Helper for concurrent test nodes. func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { machineKey := key.NewMachine() nodeKey := key.NewNode() + return types.Node{ ID: id, Hostname: hostname, @@ -862,72 +889,90 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { } } -// --- Concurrency: concurrent PutNode operations --- +// --- Concurrency: concurrent PutNode operations ---. func TestNodeStoreConcurrentPutNode(t *testing.T) { const concurrentOps = 20 store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() var wg sync.WaitGroup + results := make(chan bool, concurrentOps) for i := range concurrentOps { wg.Add(1) + go func(nodeID int) { defer wg.Done() + node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") + resultNode := store.PutNode(node) results <- resultNode.Valid() }(i + 1) } + wg.Wait() close(results) successCount := 0 + for success := range results { if success { successCount++ } } + require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed") } -// --- Batching: concurrent ops fit in one batch --- +// --- Batching: concurrent ops fit in one batch ---. 
func TestNodeStoreBatchingEfficiency(t *testing.T) { const batchSize = 10 + const ops = 15 // more than batchSize store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() var wg sync.WaitGroup + results := make(chan bool, ops) for i := range ops { wg.Add(1) + go func(nodeID int) { defer wg.Done() + node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") + resultNode := store.PutNode(node) results <- resultNode.Valid() }(i + 1) } + wg.Wait() close(results) successCount := 0 + for success := range results { if success { successCount++ } } + require.Equal(t, ops, successCount, "All batch PutNode operations should succeed") } -// --- Race conditions: many goroutines on same node --- +// --- Race conditions: many goroutines on same node ---. func TestNodeStoreRaceConditions(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -936,13 +981,18 @@ func TestNodeStoreRaceConditions(t *testing.T) { resultNode := store.PutNode(node) require.True(t, resultNode.Valid()) - const numGoroutines = 30 - const opsPerGoroutine = 10 + const ( + numGoroutines = 30 + opsPerGoroutine = 10 + ) + var wg sync.WaitGroup + errors := make(chan error, numGoroutines*opsPerGoroutine) for i := range numGoroutines { wg.Add(1) + go func(gid int) { defer wg.Done() @@ -962,6 +1012,7 @@ func TestNodeStoreRaceConditions(t *testing.T) { } case 2: newNode := createConcurrentTestNode(nodeID, "race-put") + resultNode := store.PutNode(newNode) if !resultNode.Valid() { errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j) @@ -970,23 +1021,28 @@ func TestNodeStoreRaceConditions(t *testing.T) { } }(i) } + wg.Wait() close(errors) errorCount := 0 + for err := range errors { t.Error(err) + errorCount++ } + if errorCount > 0 { t.Fatalf("Race condition test failed with %d errors", errorCount) } } -// --- Resource cleanup: goroutine leak detection --- +// --- Resource cleanup: goroutine leak detection ---. func TestNodeStoreResourceCleanup(t *testing.T) { // initialGoroutines := runtime.NumGoroutine() store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1009,10 +1065,12 @@ func TestNodeStoreResourceCleanup(t *testing.T) { }) retrieved, found := store.GetNode(nodeID) assert.True(t, found && retrieved.Valid()) + if i%10 == 9 { store.DeleteNode(nodeID) } } + runtime.GC() // Wait for goroutines to settle and check for leaks @@ -1023,9 +1081,10 @@ func TestNodeStoreResourceCleanup(t *testing.T) { }, time.Second, 10*time.Millisecond, "goroutines should not leak") } -// --- Timeout/deadlock: operations complete within reasonable time --- +// --- Timeout/deadlock: operations complete within reasonable time ---. 
func TestNodeStoreOperationTimeout(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1033,36 +1092,47 @@ func TestNodeStoreOperationTimeout(t *testing.T) { defer cancel() const ops = 30 + var wg sync.WaitGroup + putResults := make([]error, ops) updateResults := make([]error, ops) // Launch all PutNode operations concurrently for i := 1; i <= ops; i++ { nodeID := types.NodeID(i) + wg.Add(1) + go func(idx int, id types.NodeID) { defer wg.Done() + startPut := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id) node := createConcurrentTestNode(id, "timeout-node") resultNode := store.PutNode(node) endPut := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut)) + if !resultNode.Valid() { putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id) } }(i, nodeID) } + wg.Wait() // Launch all UpdateNode operations concurrently wg = sync.WaitGroup{} + for i := 1; i <= ops; i++ { nodeID := types.NodeID(i) + wg.Add(1) + go func(idx int, id types.NodeID) { defer wg.Done() + startUpdate := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id) resultNode, ok := store.UpdateNode(id, func(n *types.Node) { @@ -1070,31 +1140,40 @@ func TestNodeStoreOperationTimeout(t *testing.T) { }) endUpdate := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate)) + if !ok || !resultNode.Valid() { updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id) } }(i, nodeID) } + done := make(chan struct{}) + go func() { wg.Wait() close(done) }() + select { case <-done: errorCount := 0 + for _, err := range putResults { if err != nil { t.Error(err) + errorCount++ } } + for _, err := range updateResults { if err != nil { t.Error(err) + errorCount++ } } + if errorCount == 0 { t.Log("All concurrent operations completed successfully within timeout") } else { @@ -1106,13 +1185,15 @@ func TestNodeStoreOperationTimeout(t *testing.T) { } } -// --- Edge case: update non-existent node --- +// --- Edge case: update non-existent node ---. func TestNodeStoreUpdateNonExistentNode(t *testing.T) { for i := range 10 { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) store.Start() + nonExistentID := types.NodeID(999 + i) updateCallCount := 0 + fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID) resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) { updateCallCount++ @@ -1126,9 +1207,10 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) { } } -// --- Allocation benchmark --- +// --- Allocation benchmark ---. 
func BenchmarkNodeStoreAllocations(b *testing.B) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1140,6 +1222,7 @@ func BenchmarkNodeStoreAllocations(b *testing.B) { n.Hostname = "bench-updated" }) store.GetNode(nodeID) + if i%10 == 9 { store.DeleteNode(nodeID) } diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go index 1a949173..efce647d 100644 --- a/hscontrol/tailsql.go +++ b/hscontrol/tailsql.go @@ -93,6 +93,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s mux := tsql.NewMux() tsweb.Debugger(mux) go http.Serve(lst, mux) + logf("TailSQL started") <-ctx.Done() logf("TailSQL shutting down...") diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index f4814519..be3756a0 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -177,6 +177,7 @@ func RegistrationIDFromString(str string) (RegistrationID, error) { if len(str) != RegistrationIDLength { return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength) } + return RegistrationID(str), nil } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4068d72e..fffe166d 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -301,6 +301,7 @@ func validatePKCEMethod(method string) error { if method != PKCEMethodPlain && method != PKCEMethodS256 { return errInvalidPKCEMethod } + return nil } @@ -1082,6 +1083,7 @@ func LoadServerConfig() (*Config, error) { if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 { return workers } + return DefaultBatcherWorkers() }(), RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"), @@ -1117,6 +1119,7 @@ func isSafeServerURL(serverURL, baseDomain string) error { } s := len(serverDomainParts) + b := len(baseDomainParts) for i := range baseDomainParts { if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] { diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 6b9fc2ef..13a3a418 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -363,6 +363,7 @@ noise: // Populate a custom config file configFilePath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configFilePath, configYaml, 0o600) if err != nil { t.Fatalf("Couldn't write file %s", configFilePath) @@ -398,10 +399,12 @@ server_url: http://127.0.0.1:8080 tls_letsencrypt_hostname: example.com tls_letsencrypt_challenge_type: TLS-ALPN-01 `) + err = os.WriteFile(configFilePath, configYaml, 0o600) if err != nil { t.Fatalf("Couldn't write file %s", configFilePath) } + err = LoadConfig(tmpDir, false) require.NoError(t, err) } @@ -463,6 +466,7 @@ func TestSafeServerURL(t *testing.T) { return } + assert.NoError(t, err) }) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 1a66341d..5140bc44 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -156,6 +156,7 @@ func (node *Node) GivenNameHasBeenChanged() bool { // Strip invalid DNS characters for givenName comparison normalised := strings.ToLower(node.Hostname) normalised = invalidDNSRegex.ReplaceAllString(normalised, "") + return node.GivenName == normalised } @@ -464,7 +465,7 @@ func (node *Node) IsSubnetRouter() bool { return len(node.SubnetRoutes()) > 0 } -// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes +// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes. 
func (node *Node) AllApprovedRoutes() []netip.Prefix { return append(node.SubnetRoutes(), node.ExitRoutes()...) } @@ -579,6 +580,7 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { Str("rejected_hostname", hostInfo.Hostname). Err(err). Msg("Rejecting invalid hostname update from hostinfo") + return } @@ -670,6 +672,7 @@ func (nodes Nodes) IDMap() map[NodeID]*Node { func (nodes Nodes) DebugString() string { var sb strings.Builder sb.WriteString("Nodes:\n") + for _, node := range nodes { sb.WriteString(node.DebugString()) sb.WriteString("\n") diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 2ce02f02..3b3e59e2 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -128,6 +128,7 @@ func (pak *PreAuthKey) Validate() error { if pak.Expiration != nil { return *pak.Expiration } + return time.Time{} }()). Time("now", time.Now()). diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 27aff519..dbcf4f44 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -40,9 +40,11 @@ var TaggedDevices = User{ func (u Users) String() string { var sb strings.Builder sb.WriteString("[ ") + for _, user := range u { fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name) } + sb.WriteString(" ]") return sb.String() @@ -89,6 +91,7 @@ func (u *User) StringID() string { if u == nil { return "" } + return strconv.FormatUint(uint64(u.ID), 10) } @@ -203,6 +206,7 @@ type FlexibleBoolean bool func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { var val any + err := json.Unmarshal(data, &val) if err != nil { return fmt.Errorf("could not unmarshal data: %w", err) @@ -216,6 +220,7 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { if err != nil { return fmt.Errorf("could not parse %s as boolean: %w", v, err) } + *bit = FlexibleBoolean(pv) default: @@ -253,9 +258,11 @@ func (c *OIDCClaims) Identifier() string { if c.Iss == "" && c.Sub == "" { return "" } + if c.Iss == "" { return CleanIdentifier(c.Sub) } + if c.Sub == "" { return CleanIdentifier(c.Iss) } @@ -340,6 +347,7 @@ func CleanIdentifier(identifier string) string { cleanParts = append(cleanParts, trimmed) } } + if len(cleanParts) == 0 { return "" } @@ -382,6 +390,7 @@ func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) { if claims.Iss == "" && !strings.HasPrefix(identifier, "/") { identifier = "/" + identifier } + u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true} u.DisplayName = claims.Name u.ProfilePicURL = claims.ProfilePictureURL diff --git a/hscontrol/types/users_test.go b/hscontrol/types/users_test.go index 15386553..acd88434 100644 --- a/hscontrol/types/users_test.go +++ b/hscontrol/types/users_test.go @@ -70,6 +70,7 @@ func TestUnmarshallOIDCClaims(t *testing.T) { t.Errorf("UnmarshallOIDCClaims() error = %v", err) return } + if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff) } @@ -190,6 +191,7 @@ func TestOIDCClaimsIdentifier(t *testing.T) { } result := claims.Identifier() assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { t.Errorf("Identifier() mismatch (-want +got):\n%s", diff) } @@ -282,6 +284,7 @@ func TestCleanIdentifier(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := CleanIdentifier(tt.identifier) assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff) } @@ 
-487,6 +490,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) { var user User user.FromClaim(&got, tt.emailVerifiedRequired) + if diff := cmp.Diff(user, tt.want); diff != "" { t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff) } diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go index b492e4d6..4f9a338f 100644 --- a/hscontrol/util/dns_test.go +++ b/hscontrol/util/dns_test.go @@ -90,6 +90,7 @@ func TestNormaliseHostname(t *testing.T) { t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr) return } + if !tt.wantErr && got != tt.want { t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want) } @@ -172,6 +173,7 @@ func TestValidateHostname(t *testing.T) { t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.wantErr && tt.errorContains != "" { if err == nil || !strings.Contains(err.Error(), tt.errorContains) { t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains) diff --git a/hscontrol/util/prompt.go b/hscontrol/util/prompt.go index 098f1979..7d9cdbdf 100644 --- a/hscontrol/util/prompt.go +++ b/hscontrol/util/prompt.go @@ -15,10 +15,12 @@ func YesNo(msg string) bool { var resp string fmt.Scanln(&resp) + resp = strings.ToLower(resp) switch resp { case "y", "yes", "sure": return true } + return false } diff --git a/hscontrol/util/prompt_test.go b/hscontrol/util/prompt_test.go index d726ec60..fbed2ff8 100644 --- a/hscontrol/util/prompt_test.go +++ b/hscontrol/util/prompt_test.go @@ -86,6 +86,7 @@ func TestYesNo(t *testing.T) { // Write test input go func() { defer w.Close() + w.WriteString(tt.input) }() @@ -95,6 +96,7 @@ func TestYesNo(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Check the result @@ -108,6 +110,7 @@ func TestYesNo(t *testing.T) { stderrR.Close() expectedPrompt := "Test question [y/n] " + actualPrompt := stderrBuf.String() if actualPrompt != expectedPrompt { t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) @@ -130,6 +133,7 @@ func TestYesNoPromptMessage(t *testing.T) { // Write test input go func() { defer w.Close() + w.WriteString("n\n") }() @@ -140,6 +144,7 @@ func TestYesNoPromptMessage(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Check that the custom message was included in the prompt @@ -148,6 +153,7 @@ func TestYesNoPromptMessage(t *testing.T) { stderrR.Close() expectedPrompt := customMessage + " [y/n] " + actualPrompt := stderrBuf.String() if actualPrompt != expectedPrompt { t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) @@ -186,6 +192,7 @@ func TestYesNoCaseInsensitive(t *testing.T) { // Write test input go func() { defer w.Close() + w.WriteString(tc.input) }() @@ -195,6 +202,7 @@ func TestYesNoCaseInsensitive(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Drain stderr diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go index d1d7ece7..0a37ec87 100644 --- a/hscontrol/util/string.go +++ b/hscontrol/util/string.go @@ -33,6 +33,7 @@ func GenerateRandomStringURLSafe(n int) (string, error) { b, err := GenerateRandomBytes(n) uenc := base64.RawURLEncoding.EncodeToString(b) + return uenc[:n], err } @@ -99,6 +100,7 @@ func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string { DstIPs: %v } `, rule.SrcIPs, rule.DstPorts)) + if index < len(rules)-1 { sb.WriteString(", ") } diff --git a/hscontrol/util/util.go 
b/hscontrol/util/util.go index 4d828d02..53189656 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -30,6 +30,7 @@ func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { // It returns an error if not exactly one URL is found. func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { lines := strings.Split(output, "\n") + var urlStr string for _, line := range lines { @@ -38,6 +39,7 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { if urlStr != "" { return nil, fmt.Errorf("multiple URLs found: %s and %s", urlStr, line) } + urlStr = line } } @@ -94,6 +96,7 @@ func ParseTraceroute(output string) (Traceroute, error) { // Parse the header line - handle both 'traceroute' and 'tracert' (Windows) headerRegex := regexp.MustCompile(`(?i)(?:traceroute|tracing route) to ([^ ]+) (?:\[([^\]]+)\]|\(([^)]+)\))`) + headerMatches := headerRegex.FindStringSubmatch(lines[0]) if len(headerMatches) < 2 { return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0]) @@ -105,6 +108,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if ipStr == "" { ipStr = headerMatches[3] } + ip, err := netip.ParseAddr(ipStr) if err != nil { return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err) @@ -144,13 +148,17 @@ func ParseTraceroute(output string) (Traceroute, error) { } remainder := strings.TrimSpace(matches[2]) - var hopHostname string - var hopIP netip.Addr - var latencies []time.Duration + + var ( + hopHostname string + hopIP netip.Addr + latencies []time.Duration + ) // Check for Windows tracert format which has latencies before hostname // Format: " 1 <1 ms <1 ms <1 ms router.local [192.168.1.1]" latencyFirst := false + if strings.Contains(remainder, " ms ") && !strings.HasPrefix(remainder, "*") { // Check if latencies appear before any hostname/IP firstSpace := strings.Index(remainder, " ") @@ -171,12 +179,14 @@ func ParseTraceroute(output string) (Traceroute, error) { } // Extract and remove the latency from the beginning latStr := strings.TrimPrefix(remainder[latMatch[2]:latMatch[3]], "<") + ms, err := strconv.ParseFloat(latStr, 64) if err == nil { // Round to nearest microsecond to avoid floating point precision issues duration := time.Duration(ms * float64(time.Millisecond)) latencies = append(latencies, duration.Round(time.Microsecond)) } + remainder = strings.TrimSpace(remainder[latMatch[1]:]) } } @@ -205,6 +215,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if ip, err := netip.ParseAddr(parts[0]); err == nil { hopIP = ip } + remainder = strings.TrimSpace(strings.Join(parts[1:], " ")) } } @@ -216,6 +227,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if len(match) > 1 { // Remove '<' prefix if present (e.g., "<1 ms") latStr := strings.TrimPrefix(match[1], "<") + ms, err := strconv.ParseFloat(latStr, 64) if err == nil { // Round to nearest microsecond to avoid floating point precision issues @@ -280,11 +292,13 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri if key == "" { return "unknown-node" } + keyPrefix := key if len(key) > 8 { keyPrefix = key[:8] } - return fmt.Sprintf("node-%s", keyPrefix) + + return "node-" + keyPrefix } lowercased := strings.ToLower(hostinfo.Hostname) diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 33f27b7a..a064a852 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -180,6 +180,7 @@ Success.`, if err != nil { t.Errorf("ParseLoginURLFromCLILogin() error = %v, wantErr 
%v", err, tt.wantErr) } + if gotURL.String() != tt.wantURL { t.Errorf("ParseLoginURLFromCLILogin() = %v, want %v", gotURL, tt.wantURL) } @@ -1066,6 +1067,7 @@ func TestEnsureHostname(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.want, "invalid-") { @@ -1103,9 +1105,11 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "test-node" { t.Errorf("hostname = %v, want test-node", hi.Hostname) } + if hi.OS != "linux" { t.Errorf("OS = %v, want linux", hi.OS) } @@ -1147,6 +1151,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "node-nkey1234" { t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname) } @@ -1162,6 +1167,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "unknown-node" { t.Errorf("hostname = %v, want unknown-node", hi.Hostname) } @@ -1179,6 +1185,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "unknown-node" { t.Errorf("hostname = %v, want unknown-node", hi.Hostname) } @@ -1200,18 +1207,23 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "test" { t.Errorf("hostname = %v, want test", hi.Hostname) } + if hi.OS != "windows" { t.Errorf("OS = %v, want windows", hi.OS) } + if hi.OSVersion != "10.0.19044" { t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion) } + if hi.DeviceModel != "test-device" { t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel) } + if hi.BackendLogID != "log123" { t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID) } @@ -1229,6 +1241,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if len(hi.Hostname) != 63 { t.Errorf("hostname length = %v, want 63", len(hi.Hostname)) } @@ -1239,6 +1252,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.wantHostname, "invalid-") { @@ -1265,6 +1279,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) { for i, hostname := range testCases { t.Run(cmp.Diff("", ""), func(t *testing.T) { hostinfo := &tailcfg.Hostinfo{Hostname: hostname} + result := EnsureHostname(hostinfo, "mkey", "nkey") if len(result) > 63 { t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index 223e4c8b..825f3d17 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -35,6 +35,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -46,6 +47,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Create an API key using the CLI var validAPIKey string + assert.EventuallyWithT(t, func(ct *assert.CollectT) { apiKeyOutput, err := headscale.Execute( []string{ @@ -63,7 +65,7 @@ func 
TestAPIAuthenticationBypass(t *testing.T) { // Get the API endpoint endpoint := headscale.GetEndpoint() - apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + apiURL := endpoint + "/api/v1/user" // Create HTTP client client := &http.Client{ @@ -81,6 +83,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -99,6 +102,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Should NOT contain user data after "Unauthorized" // This is the security bypass - if users array is present, auth was bypassed var jsonCheck map[string]any + jsonErr := json.Unmarshal(body, &jsonCheck) // If we can unmarshal JSON and it contains "users", that's the bypass @@ -132,6 +136,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -165,6 +170,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -193,10 +199,11 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Expected: Should return 200 with user data (this is the authorized case) req, err := http.NewRequest("GET", apiURL, nil) require.NoError(t, err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", validAPIKey)) + req.Header.Set("Authorization", "Bearer "+validAPIKey) resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -208,16 +215,19 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Should be able to parse as protobuf JSON var response v1.ListUsersResponse + err = protojson.Unmarshal(body, &response) assert.NoError(t, err, "Response should be valid protobuf JSON with valid API key") // Should contain our test users users := response.GetUsers() assert.Len(t, users, 3, "Should have 3 users") + userNames := make([]string, len(users)) for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "user1") assert.Contains(t, userNames, "user2") assert.Contains(t, userNames, "user3") @@ -234,6 +244,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -254,10 +265,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) endpoint := headscale.GetEndpoint() - apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + apiURL := endpoint + "/api/v1/user" t.Run("Curl_NoAuth", func(t *testing.T) { // Execute curl from inside the headscale container without auth @@ -274,16 +286,21 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { // Parse the output lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody string + var ( + httpCode string + responseBody string + ) + + var responseBodySb295 strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBodySb295.WriteString(line) } } + responseBody += responseBodySb295.String() // Should return 401 assert.Equal(t, "401", httpCode, @@ -320,16 +337,21 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { require.NoError(t, err) lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody string + var ( + httpCode string + responseBody string + ) 
+ + var responseBodySb344 strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBodySb344.WriteString(line) } } + responseBody += responseBodySb344.String() assert.Equal(t, "401", httpCode) assert.Contains(t, responseBody, "Unauthorized") @@ -346,7 +368,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { "curl", "-s", "-H", - fmt.Sprintf("Authorization: Bearer %s", validAPIKey), + "Authorization: Bearer " + validAPIKey, "-w", "\nHTTP_CODE:%{http_code}", apiURL, @@ -355,8 +377,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { require.NoError(t, err) lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody strings.Builder + + var ( + httpCode string + responseBody strings.Builder + ) for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { @@ -372,8 +397,10 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { // Should contain user data var response v1.ListUsersResponse + err = protojson.Unmarshal([]byte(responseBody.String()), &response) assert.NoError(t, err, "Response should be valid protobuf JSON") + users := response.GetUsers() assert.Len(t, users, 2, "Should have 2 users") }) @@ -391,6 +418,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -420,11 +448,12 @@ func TestGRPCAuthenticationBypass(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) // Get the gRPC endpoint // For gRPC, we need to use the hostname and port 50443 - grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + grpcAddress := headscale.GetHostname() + ":50443" t.Run("gRPC_NoAPIKey", func(t *testing.T) { // Test 1: Try to use CLI without API key (should fail) @@ -487,6 +516,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { // CLI outputs the users array directly, not wrapped in ListUsersResponse // Parse as JSON array (CLI uses json.Marshal, not protojson) var users []*v1.User + err = json.Unmarshal([]byte(output), &users) assert.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") @@ -495,6 +525,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "grpcuser1") assert.Contains(t, userNames, "grpcuser2") }) @@ -513,6 +544,7 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -540,9 +572,10 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) - grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + grpcAddress := headscale.GetHostname() + ":50443" // Create a config file for testing configWithoutKey := fmt.Sprintf(` @@ -643,6 +676,7 @@ cli: // CLI outputs the users array directly, not wrapped in ListUsersResponse // Parse as JSON array (CLI uses json.Marshal, not protojson) var users []*v1.User + err = json.Unmarshal([]byte(output), &users) assert.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") @@ -651,6 +685,7 @@ cli: for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "cliuser1") assert.Contains(t, userNames, "cliuser2") }) diff --git 
a/integration/auth_key_test.go b/integration/auth_key_test.go index 9cf352bb..47c55a37 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -30,6 +30,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -68,18 +69,24 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { ips, err := client.IPs() if err != nil { t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) } + clientIPs[client] = ips } - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) @@ -110,6 +117,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after logout") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -147,6 +155,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after relogin") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -200,6 +209,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, nodeCountBeforeLogout) @@ -254,10 +264,14 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) @@ -300,9 +314,11 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user1Nodes, err = headscale.ListNodes("user1") assert.NoError(ct, err, "Failed to list nodes for user1 after relogin") assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes)) @@ -322,15 +338,18 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { // When nodes re-authenticate with a 
different user's pre-auth key, NEW nodes are created // for the new user. The original nodes remain with the original user. var user2Nodes []*v1.Node + t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user2Nodes, err = headscale.ListNodes("user2") assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin") assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes)) }, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)") t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() @@ -351,6 +370,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -376,11 +396,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { ips, err := client.IPs() if err != nil { t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) } + clientIPs[client] = ips } @@ -394,10 +416,14 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index c1d066f8..18c5c3a9 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -149,6 +149,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -176,6 +177,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { syncCompleteTime := time.Now() err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) + loginDuration := time.Since(syncCompleteTime) t.Logf("Login and sync completed in %v", loginDuration) @@ -207,6 +209,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Check each client's status individually to provide better diagnostics expiredCount := 0 + for _, client := range allClients { status, err := client.Status() if assert.NoError(ct, err, "failed to get status for client %s", client.Hostname()) { @@ -356,6 +359,7 @@ func TestOIDC024UserCreation(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -413,6 +417,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -470,6 +475,7 @@ func 
TestOIDCReloginSameNodeNewUser(t *testing.T) { oidcMockUser("user1", true), }, }) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -508,6 +514,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during initial validation") assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -528,9 +535,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var listNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes during initial validation") assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes)) @@ -538,14 +548,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Collect expected node IDs for validation after user1 initial login expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { status := ts.MustStatus() assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) assert.NoError(ct, err, "Failed to parse node ID from status") }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) // Validate initial connection state for user1 @@ -583,6 +598,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users after user2 login") assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -638,10 +654,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Security validation: Only user2's node should be active after user switch var activeUser2NodeID types.NodeID + for _, node := range listNodesAfterNewUserLogin { if node.GetUser().GetId() == 2 { // user2 activeUser2NodeID = types.NodeID(node.GetId()) t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break } } @@ -655,6 +673,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Check user2 node is online if node, exists := nodeStore[activeUser2NodeID]; exists { assert.NotNil(c, node.IsOnline, "User2 node should have online status") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User2 node should be online after login") } @@ -747,6 +766,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during final validation") assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -816,10 +836,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Security validation: Only user1's node should be active after relogin var activeUser1NodeID types.NodeID + for _, node := range listNodesAfterLoggingBackIn { if node.GetUser().GetId() == 1 { // user1 activeUser1NodeID = types.NodeID(node.GetId()) t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), 
node.GetUser().GetName()) + break } } @@ -833,6 +855,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Check user1 node is online if node, exists := nodeStore[activeUser1NodeID]; exists { assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User1 node should be online after relogin") } @@ -907,6 +930,7 @@ func TestOIDCFollowUpUrl(t *testing.T) { time.Sleep(2 * time.Minute) var newUrl *url.URL + assert.EventuallyWithT(t, func(c *assert.CollectT) { st, err := ts.Status() assert.NoError(c, err) @@ -1103,6 +1127,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { oidcMockUser("user1", true), // Relogin with same user }, }) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -1142,6 +1167,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during initial validation") assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -1162,9 +1188,12 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var initialNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + initialNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes during initial validation") assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes)) @@ -1172,14 +1201,19 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Collect expected node IDs for validation after user1 initial login expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { status := ts.MustStatus() assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) assert.NoError(ct, err, "Failed to parse node ID from status") }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) // Validate initial connection state for user1 @@ -1236,6 +1270,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during final validation") assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -1256,6 +1291,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin cycle") var finalNodes []*v1.Node + t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { finalNodes, err = headscale.ListNodes() @@ -1279,6 +1315,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Security validation: user1's node should be active after relogin activeUser1NodeID := types.NodeID(finalNodes[0].GetId()) + t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat)) require.EventuallyWithT(t, func(c *assert.CollectT) { 
nodeStore, err := headscale.DebugNodeStore() @@ -1287,6 +1324,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Check user1 node is online if node, exists := nodeStore[activeUser1NodeID]; exists { assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin") } @@ -1356,6 +1394,7 @@ func TestOIDCExpiryAfterRestart(t *testing.T) { // Verify initial expiry is set var initialExpiry time.Time + assert.EventuallyWithT(t, func(ct *assert.CollectT) { nodes, err := headscale.ListNodes() assert.NoError(ct, err) diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 5dd546f3..a102b493 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -67,6 +67,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -106,13 +107,16 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after web authentication") assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) }, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication") + nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -152,6 +156,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after web flow logout") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -226,6 +231,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -256,13 +262,16 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after initial web authentication") assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) }, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication") + nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -313,9 +322,11 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { t.Logf("all clients logged back in as user1") var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", 
time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user1Nodes, err = headscale.ListNodes("user1") assert.NoError(ct, err, "Failed to list nodes for user1 after web flow relogin") assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes)) @@ -333,15 +344,18 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { // Validate that user2's old nodes still exist in database (but are expired/offline) // When CLI registration creates new nodes for user1, user2's old nodes remain var user2Nodes []*v1.Node + t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user2Nodes, err = headscale.ListNodes("user2") assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1") assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes)) }, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1") t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index 60260bb1..d2aec30f 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -25,6 +25,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { // Generate random hostname for the headscale instance hash, err := util.GenerateRandomStringDNSSafe(6) require.NoError(t, err) + testName := "derpverify" hostname := fmt.Sprintf("hs-%s-%s", testName, hash) @@ -40,6 +41,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -107,6 +109,7 @@ func DERPVerify( if err := c.Connect(t.Context()); err != nil { result = fmt.Errorf("client Connect: %w", err) } + if m, err := c.Recv(); err != nil { result = fmt.Errorf("client first Recv: %w", err) } else if v, ok := m.(derp.ServerInfoMessage); !ok { diff --git a/integration/dockertestutil/config.go b/integration/dockertestutil/config.go index c0c57a3e..88b2712c 100644 --- a/integration/dockertestutil/config.go +++ b/integration/dockertestutil/config.go @@ -34,6 +34,7 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) { if opts.Labels == nil { opts.Labels = make(map[string]string) } + opts.Labels["hi.run-id"] = runID opts.Labels["hi.test-type"] = testType } diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index b09e0d40..4a172471 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -41,6 +41,7 @@ type buffer struct { func (b *buffer) Write(p []byte) (n int, err error) { b.mutex.Lock() defer b.mutex.Unlock() + return b.store.Write(p) } @@ -49,6 +50,7 @@ func (b *buffer) Write(p []byte) (n int, err error) { func (b *buffer) String() string { b.mutex.Lock() defer b.mutex.Unlock() + return b.store.String() } diff --git a/integration/dockertestutil/logs.go b/integration/dockertestutil/logs.go index 
7d104e43..d5911ca7 100644 --- a/integration/dockertestutil/logs.go +++ b/integration/dockertestutil/logs.go @@ -47,6 +47,7 @@ func SaveLog( } var stdout, stderr bytes.Buffer + err = WriteLog(pool, resource, &stdout, &stderr) if err != nil { return "", "", err diff --git a/integration/dockertestutil/network.go b/integration/dockertestutil/network.go index 42483247..d07841f1 100644 --- a/integration/dockertestutil/network.go +++ b/integration/dockertestutil/network.go @@ -18,6 +18,7 @@ func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (*dockertest.Ne if err != nil { return nil, fmt.Errorf("looking up network names: %w", err) } + if len(networks) == 0 { if _, err := pool.CreateNetwork(name); err == nil { // Create does not give us an updated version of the resource, so we need to @@ -90,6 +91,7 @@ func RandomFreeHostPort() (int, error) { // CleanUnreferencedNetworks removes networks that are not referenced by any containers. func CleanUnreferencedNetworks(pool *dockertest.Pool) error { filter := "name=hs-" + networks, err := pool.NetworksByName(filter) if err != nil { return fmt.Errorf("getting networks by filter %q: %w", filter, err) @@ -122,6 +124,7 @@ func CleanImagesInCI(pool *dockertest.Pool) error { } removedCount := 0 + for _, image := range images { // Only remove dangling (untagged) images to avoid forcing rebuilds // Dangling images have no RepoTags or only have ":" diff --git a/integration/dsic/dsic.go b/integration/dsic/dsic.go index d8a77575..344d93f7 100644 --- a/integration/dsic/dsic.go +++ b/integration/dsic/dsic.go @@ -159,10 +159,12 @@ func New( } else { hostname = fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash) } + tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname) if err != nil { return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err) } + dsic := &DERPServerInContainer{ version: version, hostname: hostname, @@ -185,6 +187,7 @@ func New( fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort) fmt.Fprintf(&cmdArgs, " --stun=true") fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort) + if dsic.withVerifyClientURL != "" { fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL) } @@ -214,11 +217,13 @@ func New( } var container *dockertest.Resource + buildOptions := &dockertest.BuildOptions{ Dockerfile: "Dockerfile.derper", ContextDir: dockerContextPath, BuildArgs: []docker.BuildArg{}, } + switch version { case "head": buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{ @@ -249,6 +254,7 @@ func New( err, ) } + log.Printf("Created %s container\n", hostname) dsic.container = container @@ -259,12 +265,14 @@ func New( return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) } } + if len(dsic.tlsCert) != 0 { err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert) if err != nil { return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) } } + if len(dsic.tlsKey) != 0 { err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey) if err != nil { diff --git a/integration/helpers.go b/integration/helpers.go index 5acf4729..4a00342c 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -3,6 +3,7 @@ package integration import ( "bufio" "bytes" + "errors" "fmt" "io" "net/netip" @@ -47,7 +48,7 @@ const ( TimestampFormatRunID = "20060102-150405" ) -// NodeSystemStatus represents the status of a node across different systems +// NodeSystemStatus 
represents the status of a node across different systems. type NodeSystemStatus struct { Batcher bool BatcherConnCount int @@ -104,7 +105,7 @@ func requireNoErrLogout(t *testing.T, err error) { require.NoError(t, err, "failed to log out tailscale nodes") } -// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes +// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes. func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID { t.Helper() @@ -113,8 +114,10 @@ func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.Nod status := client.MustStatus() nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) } + return expectedNodes } @@ -148,15 +151,17 @@ func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNode } // requireAllClientsOnline validates that all nodes are online/offline across all headscale systems -// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems +// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems. func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { t.Helper() startTime := time.Now() + stateStr := "offline" if expectedOnline { stateStr = "online" } + t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message) if expectedOnline { @@ -171,15 +176,17 @@ func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNode t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message) } -// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state +// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state. 
func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { t.Helper() var prevReport string + require.EventuallyWithT(t, func(c *assert.CollectT) { // Get batcher state debugInfo, err := headscale.DebugBatcher() assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { return } @@ -187,6 +194,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Get map responses mapResponses, err := headscale.GetAllMapReponses() assert.NoError(c, err, "Failed to get map responses") + if err != nil { return } @@ -194,6 +202,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Get nodestore state nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } @@ -264,6 +273,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if id == nodeID { continue // Skip self-references } + expectedPeerMaps++ if online, exists := peerMap[nodeID]; exists && online { @@ -278,6 +288,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer } } } + assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") // Update status with map response data @@ -301,10 +312,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Verify all systems show nodes in expected state and report failures allMatch := true + var failureReport strings.Builder ids := types.NodeIDs(maps.Keys(nodeStatus)) slices.Sort(ids) + for _, nodeID := range ids { status := nodeStatus[nodeID] systemsMatch := (status.Batcher == expectedOnline) && @@ -313,10 +326,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if !systemsMatch { allMatch = false + stateStr := "offline" if expectedOnline { stateStr = "online" } + failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat))) failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline)) failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) @@ -331,6 +346,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer t.Logf("Previous report:\n%s", prevReport) t.Logf("Current report:\n%s", failureReport.String()) t.Logf("Report diff:\n%s", diff) + prevReport = failureReport.String() } @@ -344,11 +360,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if expectedOnline { stateStr = "online" } + assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr)) }, timeout, 2*time.Second, message) } -// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components +// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components. 
func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) { t.Helper() @@ -357,18 +374,22 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { debugInfo, err := headscale.DebugBatcher() assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { return } allBatcherOffline := true + for _, nodeID := range expectedNodes { nodeIDStr := fmt.Sprintf("%d", nodeID) if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected { allBatcherOffline = false + assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID) } } + assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher") }, 15*time.Second, 1*time.Second, "batcher disconnection validation") @@ -377,20 +398,24 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } allNodeStoreOffline := true + for _, nodeID := range expectedNodes { if node, exists := nodeStore[nodeID]; exists { isOnline := node.IsOnline != nil && *node.IsOnline if isOnline { allNodeStoreOffline = false + assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID) } } } + assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore") }, 20*time.Second, 1*time.Second, "nodestore offline validation") @@ -399,6 +424,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { mapResponses, err := headscale.GetAllMapReponses() assert.NoError(c, err, "Failed to get map responses") + if err != nil { return } @@ -411,6 +437,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec for nodeID := range onlineMap { if slices.Contains(expectedNodes, nodeID) { allMapResponsesOffline = false + assert.False(c, true, "Node %d should not appear in map responses", nodeID) } } @@ -421,13 +448,16 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec if id == nodeID { continue // Skip self-references } + if online, exists := peerMap[nodeID]; exists && online { allMapResponsesOffline = false + assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id) } } } } + assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses") }, 60*time.Second, 2*time.Second, "map response propagation validation") @@ -447,6 +477,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Get nodestore state nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } @@ -461,12 +492,14 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe for _, nodeID := range expectedNodes { node, exists := nodeStore[nodeID] assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID) + if !exists { continue } // Validate that the node has Hostinfo assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname) + if node.Hostinfo == nil { t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, 
time.Now().Format(TimestampFormat)) continue @@ -474,6 +507,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Validate that the node has NetInfo assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname) + if node.Hostinfo.NetInfo == nil { t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) continue @@ -524,6 +558,7 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { // Returns the total number of successful ping operations. func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { t.Helper() + success := 0 for _, client := range clients { @@ -545,6 +580,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts // for validating NAT traversal and relay functionality. Returns success count. func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { t.Helper() + success := 0 for _, client := range clients { @@ -602,9 +638,12 @@ func assertClientsState(t *testing.T, clients []TailscaleClient) { for _, client := range clients { wg.Add(1) + c := client // Avoid loop pointer + go func() { defer wg.Done() + assertValidStatus(t, c) assertValidNetcheck(t, c) assertValidNetmap(t, c) @@ -635,6 +674,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { assert.NoError(c, err, "getting netmap for %q", client.Hostname()) assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) + if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) } @@ -653,6 +693,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) + if hi := peer.Hostinfo(); hi.Valid() { assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) @@ -681,6 +722,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { // and network map presence. This test is not suitable for ACL/partial connection tests. func assertValidStatus(t *testing.T, client TailscaleClient) { t.Helper() + status, err := client.Status(true) if err != nil { t.Fatalf("getting status for %q: %s", client.Hostname(), err) @@ -738,6 +780,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) { // which is essential for NAT traversal and connectivity in restricted networks. 
func assertValidNetcheck(t *testing.T, client TailscaleClient) { t.Helper() + report, err := client.Netcheck() if err != nil { t.Fatalf("getting status for %q: %s", client.Hostname(), err) @@ -792,6 +835,7 @@ func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool { t.Helper() buf := &bytes.Buffer{} + err := client.WriteLogs(buf, buf) if err != nil { t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err) @@ -815,6 +859,7 @@ func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) scanner := bufio.NewScanner(in) { const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB + buff := make([]byte, logBufferInitialSize) scanner.Buffer(buff, len(buff)) scanner.Split(bufio.ScanLines) @@ -941,17 +986,20 @@ func GetUserByName(headscale ControlServer, username string) (*v1.User, error) { func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) { for _, client := range updated { isOriginal := false + for _, origClient := range original { if client.Hostname() == origClient.Hostname() { isOriginal = true break } } + if !isOriginal { return client, nil } } - return nil, fmt.Errorf("no new client found") + + return nil, errors.New("no new client found") } // AddAndLoginClient adds a new tailscale client to a user and logs it in. @@ -959,7 +1007,7 @@ func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) // 1. Creating a new node // 2. Finding the new node in the client list // 3. Getting the user to create a preauth key -// 4. Logging in the new node +// 4. Logging in the new node. func (s *Scenario) AddAndLoginClient( t *testing.T, username string, @@ -1037,5 +1085,6 @@ func (s *Scenario) MustAddAndLoginClient( client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...) 
require.NoError(t, err) + return client } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 202f2014..a08ee7af 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -725,12 +725,14 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { // Find the top-level directory to strip var topLevelDir string + firstPass := tar.NewReader(bytes.NewReader(tarData)) for { header, err := firstPass.Next() if err == io.EOF { break } + if err != nil { return fmt.Errorf("failed to read tar header: %w", err) } @@ -747,6 +749,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { if err == io.EOF { break } + if err != nil { return fmt.Errorf("failed to read tar header: %w", err) } @@ -794,6 +797,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { outFile.Close() return fmt.Errorf("failed to copy file contents: %w", err) } + outFile.Close() // Set file permissions @@ -844,10 +848,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Check if the database file exists and has a schema dbPath := "/tmp/integration_test_db.sqlite3" + fileInfo, err := t.Execute([]string{"ls", "-la", dbPath}) if err != nil { return fmt.Errorf("database file does not exist at %s: %w", dbPath, err) } + log.Printf("Database file info: %s", fileInfo) // Check if the database has any tables (schema) @@ -872,6 +878,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { if err == io.EOF { break } + if err != nil { return fmt.Errorf("failed to read tar header: %w", err) } @@ -886,6 +893,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Extract the first regular file we find if header.Typeflag == tar.TypeReg { dbPath := path.Join(savePath, t.hostname+".db") + outFile, err := os.Create(dbPath) if err != nil { return fmt.Errorf("failed to create database file: %w", err) @@ -893,6 +901,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { written, err := io.Copy(outFile, tarReader) outFile.Close() + if err != nil { return fmt.Errorf("failed to copy database file: %w", err) } @@ -1059,6 +1068,7 @@ func (t *HeadscaleInContainer) CreateUser( } var u v1.User + err = json.Unmarshal([]byte(result), &u) if err != nil { return nil, fmt.Errorf("failed to unmarshal user: %w", err) @@ -1195,6 +1205,7 @@ func (t *HeadscaleInContainer) ListNodes( users ...string, ) ([]*v1.Node, error) { var ret []*v1.Node + execUnmarshal := func(command []string) error { result, _, err := dockertestutil.ExecuteCommand( t.container, @@ -1206,6 +1217,7 @@ func (t *HeadscaleInContainer) ListNodes( } var nodes []*v1.Node + err = json.Unmarshal([]byte(result), &nodes) if err != nil { return fmt.Errorf("failed to unmarshal nodes: %w", err) @@ -1245,7 +1257,7 @@ func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error { "nodes", "delete", "--identifier", - fmt.Sprintf("%d", nodeID), + strconv.FormatUint(nodeID, 10), "--output", "json", "--force", @@ -1309,6 +1321,7 @@ func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) { } var users []*v1.User + err = json.Unmarshal([]byte(result), &users) if err != nil { return nil, fmt.Errorf("failed to unmarshal nodes: %w", err) @@ -1439,6 +1452,7 @@ func (h *HeadscaleInContainer) PID() (int, error) { if pidInt == 1 { continue } + pids = append(pids, pidInt) } @@ -1494,6 +1508,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) ( } var node *v1.Node + err = json.Unmarshal([]byte(result), &node) if err != nil { 
return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err) diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index 4ddc7ae9..5604af32 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -28,6 +28,7 @@ func PeerSyncTimeout() time.Duration { if util.IsCI() { return 120 * time.Second } + return 60 * time.Second } @@ -205,6 +206,7 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type res := make(map[types.NodeID]map[types.NodeID]bool) for nid, mrs := range all { res[nid] = make(map[types.NodeID]bool) + for _, mr := range mrs { for _, peer := range mr.Peers { if peer.Online != nil { @@ -225,5 +227,6 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type } } } + return res } diff --git a/integration/route_test.go b/integration/route_test.go index 6d0a1be2..b6fc8d85 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -48,6 +48,7 @@ func TestEnablingRoutes(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -90,6 +91,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) @@ -126,6 +128,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route approvals to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) @@ -148,9 +151,11 @@ func TestEnablingRoutes(t *testing.T) { assert.NotNil(c, peerStatus.PrimaryRoutes) assert.NotNil(c, peerStatus.AllowedIPs) + if peerStatus.AllowedIPs != nil { assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) } } @@ -171,6 +176,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route state changes to propagate to nodes assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) @@ -270,6 +276,7 @@ func TestHASubnetRouterFailover(t *testing.T) { prefp, err := scenario.SubnetOfNetwork("usernet1") require.NoError(t, err) + pref := *prefp t.Logf("usernet1 prefix: %s", pref.String()) @@ -289,6 +296,7 @@ func TestHASubnetRouterFailover(t *testing.T) { slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { statusA := a.MustStatus() statusB := b.MustStatus() + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) @@ -308,6 +316,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" - Router 2 (%s): Advertising route %s - will be STANDBY when approved", subRouter2.Hostname(), pref.String()) t.Logf(" - Router 3 (%s): Advertising route %s - will be STANDBY when approved", subRouter3.Hostname(), pref.String()) t.Logf(" Expected: All 3 routers advertise the same route for redundancy, but only one will be primary at a time") + for _, client := range allClients[:3] { command := []string{ "tailscale", @@ -323,6 +332,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // Wait for route configuration changes after advertising routes var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { nodes, err = headscale.ListNodes() assert.NoError(c, err) @@ -362,10 +372,12 @@ func 
TestHASubnetRouterFailover(t *testing.T) { checkFailureAndPrintRoutes := func(t *testing.T, client TailscaleClient) { if t.Failed() { t.Logf("[%s] Test failed at this checkpoint", time.Now().Format(TimestampFormat)) + status, err := client.Status() if err == nil { printCurrentRouteMap(t, xmaps.Values(status.Peer)...) } + t.FailNow() } } @@ -384,6 +396,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 1 becomes PRIMARY with route %s active", pref.String()) t.Logf(" Expected: Routers 2 & 3 remain with advertised but unapproved routes") t.Logf(" Expected: Client can access webservice through router 1 only") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -454,10 +467,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1") @@ -481,6 +496,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 becomes STANDBY (approved but not primary)") t.Logf(" Expected: Router 1 remains PRIMARY (no flapping - stability preferred)") t.Logf(" Expected: HA is now active - if router 1 fails, router 2 can take over") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter2.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -492,6 +508,7 @@ func TestHASubnetRouterFailover(t *testing.T) { nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 6) + if len(nodes) >= 3 { requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) @@ -567,10 +584,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 1 in HA mode") @@ -596,6 +615,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 3 becomes second STANDBY (approved but not primary)") t.Logf(" Expected: Router 1 remains PRIMARY, Router 2 remains first STANDBY") t.Logf(" Expected: Full HA configuration with 1 PRIMARY + 2 STANDBY routers") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -670,12 +690,14 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.NotEmpty(c, ips, "subRouter1 should have IP addresses") var expectedIP netip.Addr + for _, ip := range ips { if ip.Is4() { expectedIP = ip break } } + assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address") assertTracerouteViaIPWithCollect(c, tr, expectedIP) @@ -752,10 +774,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after failover") @@ -823,10 +847,12 @@ func 
TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter3.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after second failover") @@ -851,6 +877,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 3 remains PRIMARY (stability - no unnecessary failover)") t.Logf(" Expected: Router 1 becomes STANDBY (ready for HA)") t.Logf(" Expected: HA is restored with 2 routers available") + err = subRouter1.Up() require.NoError(t, err) @@ -900,10 +927,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter3.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 3 after router 1 recovery") @@ -930,6 +959,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 1 (%s) remains first STANDBY", subRouter1.Hostname()) t.Logf(" Expected: Router 2 (%s) becomes second STANDBY", subRouter2.Hostname()) t.Logf(" Expected: Full HA restored with all 3 routers online") + err = subRouter2.Up() require.NoError(t, err) @@ -980,10 +1010,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter3.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after full recovery") @@ -1065,10 +1097,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1 after route disable") @@ -1151,10 +1185,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after second route disable") @@ -1180,6 +1216,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability - no unnecessary flapping)", subRouter2.Hostname()) t.Logf(" Expected: Router 1 (%s) becomes STANDBY (approved but not primary)", subRouter1.Hostname()) t.Logf(" Expected: HA fully restored with Router 2 PRIMARY and Router 1 STANDBY") + r1Node := MustFindNode(subRouter1.Hostname(), nodes) _, err = headscale.ApproveRoutes( r1Node.GetId(), @@ -1235,10 +1272,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() 
if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 2 after route re-enable") @@ -1264,6 +1303,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability preferred)", subRouter2.Hostname()) t.Logf(" Expected: Routers 1 & 3 are both STANDBY") t.Logf(" Expected: Full HA restored with all 3 routers available") + r3Node := MustFindNode(subRouter3.Hostname(), nodes) _, err = headscale.ApproveRoutes( r3Node.GetId(), @@ -1313,6 +1353,7 @@ func TestSubnetRouteACL(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -1360,6 +1401,7 @@ func TestSubnetRouteACL(t *testing.T) { slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { statusA := a.MustStatus() statusB := b.MustStatus() + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) @@ -1389,15 +1431,20 @@ func TestSubnetRouteACL(t *testing.T) { // Wait for route advertisements to propagate to the server var nodes []*v1.Node + require.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 2) // Find the node that should have the route by checking node IDs - var routeNode *v1.Node - var otherNode *v1.Node + var ( + routeNode *v1.Node + otherNode *v1.Node + ) + for _, node := range nodes { nodeIDStr := strconv.FormatUint(node.GetId(), 10) if _, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute { @@ -1460,6 +1507,7 @@ func TestSubnetRouteACL(t *testing.T) { srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + if srs1PeerStatus == nil { return } @@ -1570,6 +1618,7 @@ func TestEnablingExitRoutes(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario") defer scenario.ShutdownAssertNoPanics(t) @@ -1591,8 +1640,10 @@ func TestEnablingExitRoutes(t *testing.T) { requireNoErrSync(t, err) var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 2) @@ -1650,6 +1701,7 @@ func TestEnablingExitRoutes(t *testing.T) { peerStatus := status.Peer[peerKey] assert.NotNil(c, peerStatus.AllowedIPs) + if peerStatus.AllowedIPs != nil { assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4) assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4()) @@ -1680,6 +1732,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -1710,10 +1763,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { if s.User[s.Self.UserID].LoginName == "user1@test.no" { user1c = c } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { user2c = c } } + require.NotNil(t, user1c) require.NotNil(t, user2c) @@ -1730,6 +1785,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) assert.Len(ct, nodes, 2) @@ -1760,6 +1816,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { // Wait for route state changes to 
propagate to nodes assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 2) @@ -1777,6 +1834,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref}) } }, 10*time.Second, 500*time.Millisecond, "routes should be visible to client") @@ -1803,10 +1861,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := user2c.Traceroute(webip) assert.NoError(c, err) + ip, err := user1c.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for user1c") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, 5*time.Second, 200*time.Millisecond, "Verifying traceroute goes through subnet router") } @@ -1827,6 +1887,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -1854,10 +1915,12 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { if s.User[s.Self.UserID].LoginName == "user1@test.no" { user1c = c } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { user2c = c } } + require.NotNil(t, user1c) require.NotNil(t, user2c) @@ -1874,6 +1937,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) assert.Len(ct, nodes, 2) @@ -1956,6 +2020,7 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node { return node } } + panic("node not found") } @@ -2239,10 +2304,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } scenario, err := NewScenario(tt.spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) var nodes []*v1.Node + opts := []hsic.Option{ hsic.WithTestName("autoapprovemulti"), hsic.WithEmbeddedDERPServerOnly(), @@ -2298,6 +2365,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // Add the Docker network route to the auto-approvers // Keep existing auto-approvers (like bigRoute) in place var approvers policyv2.AutoApprovers + switch { case strings.HasPrefix(tt.approver, "tag:"): approvers = append(approvers, tagApprover(tt.approver)) @@ -2366,6 +2434,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } else { pak, err = scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false) } + require.NoError(t, err) err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey()) @@ -2404,6 +2473,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { statusA := a.MustStatus() statusB := b.MustStatus() + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) @@ -2456,11 +2526,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) { t.Logf("Client %s sees %d peers", client.Hostname(), len(status.Peers())) routerPeerFound := false + for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] if peerStatus.ID == routerUsernet1ID.StableID() { routerPeerFound = true + t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v", peerStatus.HostName, peerStatus.ID, @@ -2468,9 +2540,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { peerStatus.PrimaryRoutes) 
assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2507,10 +2581,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through auto-approved router") @@ -2547,9 +2623,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.ID == routerUsernet1ID.StableID() { assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2569,10 +2647,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, assertTimeout, 200*time.Millisecond, "Verifying traceroute still goes through router after policy change") @@ -2606,6 +2686,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // Add the route back to the auto approver in the policy, the route should // now become available again. var newApprovers policyv2.AutoApprovers + switch { case strings.HasPrefix(tt.approver, "tag:"): newApprovers = append(newApprovers, tagApprover(tt.approver)) @@ -2639,9 +2720,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.ID == routerUsernet1ID.StableID() { assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2661,10 +2744,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through router after re-approval") @@ -2700,11 +2785,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else if peerStatus.ID == "2" { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), subRoute) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{subRoute}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2742,9 +2829,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.ID == routerUsernet1ID.StableID() { assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, 
peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2782,6 +2871,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else if peerStatus.ID == "3" { requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}) @@ -2816,10 +2906,12 @@ func SortPeerStatus(a, b *ipnstate.PeerStatus) int { func printCurrentRouteMap(t *testing.T, routers ...*ipnstate.PeerStatus) { t.Logf("== Current routing map ==") slices.SortFunc(routers, SortPeerStatus) + for _, router := range routers { got := filterNonRoutes(router) t.Logf(" Router %s (%s) is serving:", router.HostName, router.ID) t.Logf(" AllowedIPs: %v", got) + if router.PrimaryRoutes != nil { t.Logf(" PrimaryRoutes: %v", router.PrimaryRoutes.AsSlice()) } @@ -2832,6 +2924,7 @@ func filterNonRoutes(status *ipnstate.PeerStatus) []netip.Prefix { if tsaddr.IsExitRoute(p) { return true } + return !slices.ContainsFunc(status.TailscaleIPs, p.Contains) }) } @@ -2883,6 +2976,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -3023,6 +3117,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { // List nodes and verify the router has 3 available routes var err error + nodes, err := headscale.NodesByUser() assert.NoError(c, err) assert.Len(c, nodes, 2) @@ -3058,10 +3153,12 @@ func TestSubnetRouteACLFiltering(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := nodeClient.Traceroute(webip) assert.NoError(c, err) + ip, err := routerClient.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerClient") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, 60*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router") } diff --git a/integration/scenario.go b/integration/scenario.go index 35fee73e..0108a1de 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -191,9 +191,11 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { } var userToNetwork map[string]*dockertest.Network + if spec.Networks != nil || len(spec.Networks) != 0 { for name, users := range s.spec.Networks { networkName := testHashPrefix + "-" + name + network, err := s.AddNetwork(networkName) if err != nil { return nil, err @@ -203,6 +205,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { if n2, ok := userToNetwork[user]; ok { return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name) } + mak.Set(&userToNetwork, user, network) } } @@ -219,6 +222,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { if err != nil { return nil, err } + mak.Set(&s.extraServices, s.prefixedNetworkName(network), append(s.extraServices[s.prefixedNetworkName(network)], svc)) } } @@ -230,6 +234,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { if spec.OIDCAccessTTL != 0 { ttl = spec.OIDCAccessTTL } + err = s.runMockOIDC(ttl, spec.OIDCUsers) if err != nil { return nil, err @@ -268,6 +273,7 @@ func (s *Scenario) Networks() []*dockertest.Network { if len(s.networks) == 0 { 
panic("Scenario.Networks called with empty network list") } + return xmaps.Values(s.networks) } @@ -337,6 +343,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { for userName, user := range s.users { for _, client := range user.Clients { log.Printf("removing client %s in user %s", client.Hostname(), userName) + stdoutPath, stderrPath, err := client.Shutdown() if err != nil { log.Printf("failed to tear down client: %s", err) @@ -353,6 +360,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { } } } + s.mu.Unlock() for _, derp := range s.derpServers { @@ -373,6 +381,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { if s.mockOIDC.r != nil { s.mockOIDC.r.Close() + if err := s.mockOIDC.r.Close(); err != nil { log.Printf("failed to tear down oidc server: %s", err) } @@ -552,6 +561,7 @@ func (s *Scenario) CreateTailscaleNode( s.mu.Lock() defer s.mu.Unlock() + opts = append(opts, tsic.WithCACert(cert), tsic.WithHeadscaleName(hostname), @@ -591,6 +601,7 @@ func (s *Scenario) CreateTailscaleNodesInUser( ) error { if user, ok := s.users[userStr]; ok { var versions []string + for i := range count { version := requestedVersion if requestedVersion == "all" { @@ -749,10 +760,12 @@ func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Durat for _, client := range user.Clients { c := client expectedCount := expectedPeers + user.syncWaitGroup.Go(func() error { return c.WaitForPeers(expectedCount, timeout, retryInterval) }) } + if err := user.syncWaitGroup.Wait(); err != nil { allErrors = append(allErrors, err) } @@ -871,6 +884,7 @@ func (s *Scenario) createHeadscaleEnvWithTags( } else { key, err = s.CreatePreAuthKey(u.GetId(), true, false) } + if err != nil { return err } @@ -887,9 +901,11 @@ func (s *Scenario) createHeadscaleEnvWithTags( func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { log.Printf("running tailscale up for user %s", userStr) + if user, ok := s.users[userStr]; ok { for _, client := range user.Clients { tsc := client + user.joinWaitGroup.Go(func() error { loginURL, err := tsc.LoginWithURL(loginServer) if err != nil { @@ -945,6 +961,7 @@ func newDebugJar() (*debugJar, error) { if err != nil { return nil, err } + return &debugJar{ inner: jar, store: make(map[string]map[string]map[string]*http.Cookie), @@ -961,20 +978,25 @@ func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) { if c == nil || c.Name == "" { continue } + domain := c.Domain if domain == "" { domain = u.Hostname() } + path := c.Path if path == "" { path = "/" } + if _, ok := j.store[domain]; !ok { j.store[domain] = make(map[string]map[string]*http.Cookie) } + if _, ok := j.store[domain][path]; !ok { j.store[domain][path] = make(map[string]*http.Cookie) } + j.store[domain][path][c.Name] = copyCookie(c) } } @@ -989,8 +1011,10 @@ func (j *debugJar) Dump(w io.Writer) { for domain, paths := range j.store { fmt.Fprintf(w, "Domain: %s\n", domain) + for path, byName := range paths { fmt.Fprintf(w, " Path: %s\n", path) + for _, c := range byName { fmt.Fprintf( w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n", @@ -1054,7 +1078,9 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f } log.Printf("%s logging in with url: %s", hostname, loginURL.String()) + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) if err != nil { return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err) @@ -1066,6 +1092,7 @@ func 
doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f return http.ErrUseLastResponse } } + defer func() { hc.CheckRedirect = originalRedirect }() @@ -1080,6 +1107,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f if err != nil { return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err) } + body := string(bodyBytes) var redirectURL *url.URL @@ -1126,6 +1154,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { if len(keySep) != 2 { return errParseAuthPage } + key := keySep[1] key = strings.SplitN(key, " ", 2)[0] log.Printf("registering node %s", key) @@ -1154,6 +1183,7 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error noTls := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint } + resp, err := noTls.RoundTrip(req) if err != nil { return nil, err @@ -1361,6 +1391,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse if err != nil { log.Fatalf("could not find an open port: %s", err) } + portNotation := fmt.Sprintf("%d/tcp", port) hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) @@ -1421,6 +1452,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse ipAddr := s.mockOIDC.r.GetIPInNetwork(network) log.Println("Waiting for headscale mock oidc to be ready for tests") + hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) if err := s.pool.Retry(func() error { @@ -1468,7 +1500,6 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) { // log.Fatalf("could not find an open port: %s", err) // } // portNotation := fmt.Sprintf("%d/tcp", port) - hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength) hostname := "hs-webservice-" + hash diff --git a/integration/scenario_test.go b/integration/scenario_test.go index 1e2a151a..71998fca 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -35,6 +35,7 @@ func TestHeadscale(t *testing.T) { user := "test-space" scenario, err := NewScenario(ScenarioSpec{}) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -83,6 +84,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { count := 1 scenario, err := NewScenario(ScenarioSpec{}) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) From 15d0efbf9d392d8d096584834472623cc88fdbed Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 14:42:02 +0000 Subject: [PATCH 07/30] all: update deprecated Docker and xsync APIs MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update deprecated Docker SDK types and functions: - errdefs.IsNotFound → cerrdefs.IsNotFound - errdefs.IsConflict → cerrdefs.IsConflict - cli.ImageInspectWithRaw → cli.ImageInspect - client.IsErrNotFound → cerrdefs.IsNotFound - event.ID → event.Actor.ID - types.Container → container.Summary - container.Stats → container.StatsResponse - xsync.MapOf/NewMapOf → xsync.Map/NewMap These updates align with the Docker SDK v28+ and xsync v4 API changes. 
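The xsync rename is purely mechanical; a minimal sketch of the v3 → v4 pattern
(illustrative only — `servers`/`cs` are placeholder names, not code from this
series):

Before:

    // v3: the generic map type and constructor carry the "Of" suffix
    servers := xsync.NewMapOf[string, ControlServer]()
    servers.Store("headscale1", cs)

After:

    // v4: only the type and constructor are renamed; call sites are unchanged
    servers := xsync.NewMap[string, ControlServer]()
    servers.Store("headscale1", cs)

The Docker SDK updates follow the same shape: import paths and type names move,
while the structure of each call site stays the same.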
--- cmd/hi/cleanup.go | 6 +++--- cmd/hi/docker.go | 5 +++-- cmd/hi/stats.go | 15 +++++++-------- integration/scenario.go | 4 ++-- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go index e0268fd8..8dc57b4b 100644 --- a/cmd/hi/cleanup.go +++ b/cmd/hi/cleanup.go @@ -10,11 +10,11 @@ import ( "time" "github.com/cenkalti/backoff/v5" + cerrdefs "github.com/containerd/errdefs" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/image" "github.com/docker/docker/client" - "github.com/docker/docker/errdefs" ) // cleanupBeforeTest performs cleanup operations before running tests. @@ -309,9 +309,9 @@ func cleanCacheVolume(ctx context.Context) error { err = cli.VolumeRemove(ctx, volumeName, true) if err != nil { - if errdefs.IsNotFound(err) { + if cerrdefs.IsNotFound(err) { fmt.Printf("Go module cache volume not found: %s\n", volumeName) - } else if errdefs.IsConflict(err) { + } else if cerrdefs.IsConflict(err) { fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName) } else { fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 3ad70173..fbc2dba6 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -14,6 +14,7 @@ import ( "strings" "time" + cerrdefs "github.com/containerd/errdefs" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/mount" @@ -502,9 +503,9 @@ func getDockerSocketPath() string { // checkImageAvailableLocally checks if the specified Docker image is available locally. func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) { - _, _, err := cli.ImageInspectWithRaw(ctx, imageName) + _, err := cli.ImageInspect(ctx, imageName) if err != nil { - if client.IsErrNotFound(err) { + if cerrdefs.IsNotFound(err) { return false, nil } diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index 1c17df84..e80ee8d1 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -12,7 +12,6 @@ import ( "sync" "time" - "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/filters" @@ -153,13 +152,13 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, case event := <-events: if event.Type == "container" && event.Action == "start" { // Get container details - containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) + containerInfo, err := sc.client.ContainerInspect(ctx, event.Actor.ID) if err != nil { continue } - // Convert to types.Container format for consistency - cont := types.Container{ + // Convert to container.Summary format for consistency + cont := container.Summary{ ID: containerInfo.ID, Names: []string{containerInfo.Name}, Labels: containerInfo.Config.Labels, @@ -180,7 +179,7 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, } // shouldMonitorContainer determines if a container should be monitored. 
-func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { +func (sc *StatsCollector) shouldMonitorContainer(cont container.Summary, runID string) bool { // Check if it has the correct run ID label if cont.Labels == nil || cont.Labels["hi.run-id"] != runID { return false @@ -241,7 +240,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe decoder := json.NewDecoder(statsResponse.Body) - var prevStats *container.Stats + var prevStats *container.StatsResponse for { select { @@ -250,7 +249,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe case <-ctx.Done(): return default: - var stats container.Stats + var stats container.StatsResponse if err := decoder.Decode(&stats); err != nil { // EOF is expected when container stops or stream ends if err.Error() != "EOF" && verbose { @@ -299,7 +298,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe } // calculateCPUPercent calculates CPU usage percentage from Docker stats. -func calculateCPUPercent(prevStats, stats *container.Stats) float64 { +func calculateCPUPercent(prevStats, stats *container.StatsResponse) float64 { // CPU calculation based on Docker's implementation cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage) systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage) diff --git a/integration/scenario.go b/integration/scenario.go index 0108a1de..e209f6ea 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -96,7 +96,7 @@ type User struct { type Scenario struct { // TODO(kradalby): support multiple headcales for later, currently only // use one. - controlServers *xsync.MapOf[string, ControlServer] + controlServers *xsync.Map[string, ControlServer] derpServers []*dsic.DERPServerInContainer users map[string]*User @@ -180,7 +180,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { testHashPrefix := "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength) s := &Scenario{ - controlServers: xsync.NewMapOf[string, ControlServer](), + controlServers: xsync.NewMap[string, ControlServer](), users: make(map[string]*User), pool: pool, From 52fc725cf1c5409bf6f1a356637e438ef700eb93 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 14:52:28 +0000 Subject: [PATCH 08/30] all: check unchecked error returns Fix errcheck and errchkjson lint issues by properly handling or explicitly discarding error return values. 
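The fixes take one of two shapes; an illustrative composite (not lifted from any
single file, `w`/`resp`/`payload` are placeholders):

Before:

    json.NewEncoder(w).Encode(resp) // errcheck: unhandled error
    w.Write(payload)                // errcheck: unhandled error

After:

    // handler-level failures are worth logging
    if err := json.NewEncoder(w).Encode(resp); err != nil {
        log.Error().Err(err).Msg("failed to encode response")
    }

    // best-effort writes are discarded explicitly
    _, _ = w.Write(payload)

Errors that are not actionable at the call site (prompt reads, test fixtures,
the closing VACUUM) are discarded with "_ ="; errors the server can still
usefully report are logged.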
Changes include: - cmd/headscale/cli: Check MarkFlagRequired and AddMiddleware errors - hscontrol/db: Check VACUUM exec result - hscontrol/debug: Check Write errors in HTTP handlers - hscontrol/handlers: Add error logging for JSON encoding - hscontrol/mapper: Check AddNode errors in tests - hscontrol/platform_config: Check Write error - hscontrol/util/prompt: Check Scanln and Write errors - integration: Check json.Marshal and cleanup function errors --- cmd/headscale/cli/mockoidc.go | 2 +- cmd/headscale/cli/nodes.go | 4 +-- cmd/headscale/cli/users.go | 2 +- hscontrol/db/db.go | 2 +- hscontrol/debug.go | 38 ++++++++++++++-------------- hscontrol/handlers.go | 10 +++++--- hscontrol/mapper/batcher_lockfree.go | 2 +- hscontrol/mapper/batcher_test.go | 24 +++++++++--------- hscontrol/platform_config.go | 6 ++--- hscontrol/util/prompt.go | 2 +- hscontrol/util/prompt_test.go | 12 ++++----- integration/auth_web_flow_test.go | 5 +++- integration/dns_test.go | 15 +++++++---- integration/scenario.go | 10 ++++---- 14 files changed, 73 insertions(+), 61 deletions(-) diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index af28ce9f..c80c2a28 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -134,7 +134,7 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser ErrorQueue: &mockoidc.ErrorQueue{}, } - mock.AddMiddleware(func(h http.Handler) http.Handler { + _ = mock.AddMiddleware(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Info().Msgf("Request: %+v", r) h.ServeHTTP(w, r) diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 882460dd..219f69c9 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -72,12 +72,12 @@ func init() { nodeCmd.AddCommand(deleteNodeCmd) tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - tagCmd.MarkFlagRequired("identifier") + _ = tagCmd.MarkFlagRequired("identifier") tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node") nodeCmd.AddCommand(tagCmd) approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - approveRoutesCmd.MarkFlagRequired("identifier") + _ = approveRoutesCmd.MarkFlagRequired("identifier") approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. 
"10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`) nodeCmd.AddCommand(approveRoutesCmd) diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 6e4bdd02..9f0954c6 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -51,7 +51,7 @@ func init() { userCmd.AddCommand(renameUserCmd) usernameAndIDFlag(renameUserCmd) renameUserCmd.Flags().StringP("new-name", "r", "", "New username") - renameNodeCmd.MarkFlagRequired("new-name") + _ = renameUserCmd.MarkFlagRequired("new-name") } var errMissingParameter = errors.New("missing parameters") diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 988675b9..1ef767ce 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -1035,7 +1035,7 @@ func (hsdb *HSDatabase) Close() error { } if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog { - db.Exec("VACUUM") + _, _ = db.Exec("VACUUM") } return db.Close() diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 4fdcac11..93200b95 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -34,14 +34,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(overviewJSON) + _, _ = w.Write(overviewJSON) } else { // Default to text/plain for backward compatibility overview := h.state.DebugOverview() w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(overview)) + _, _ = w.Write([]byte(overview)) } })) @@ -57,7 +57,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(configJSON) + _, _ = w.Write(configJSON) })) // Policy endpoint @@ -77,7 +77,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { } w.WriteHeader(http.StatusOK) - w.Write([]byte(policy)) + _, _ = w.Write([]byte(policy)) })) // Filter rules endpoint @@ -96,7 +96,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(filterJSON) + _, _ = w.Write(filterJSON) })) // SSH policies endpoint @@ -111,7 +111,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(sshJSON) + _, _ = w.Write(sshJSON) })) // DERP map endpoint @@ -131,14 +131,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(derpJSON) + _, _ = w.Write(derpJSON) } else { // Default to text/plain for backward compatibility derpInfo := h.state.DebugDERPMap() w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(derpInfo)) + _, _ = w.Write([]byte(derpInfo)) } })) @@ -159,14 +159,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(nodeStoreJSON) + _, _ = w.Write(nodeStoreJSON) } else { // Default to text/plain for backward compatibility nodeStoreInfo := h.state.DebugNodeStore() w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(nodeStoreInfo)) + _, _ = w.Write([]byte(nodeStoreInfo)) } })) @@ -182,7 +182,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(cacheJSON) + _, _ = w.Write(cacheJSON) })) // Routes endpoint @@ 
-202,14 +202,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(routesJSON) + _, _ = w.Write(routesJSON) } else { // Default to text/plain for backward compatibility routes := h.state.DebugRoutesString() w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(routes)) + _, _ = w.Write([]byte(routes)) } })) @@ -230,14 +230,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(policyManagerJSON) + _, _ = w.Write(policyManagerJSON) } else { // Default to text/plain for backward compatibility policyManagerInfo := h.state.DebugPolicyManager() w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(policyManagerInfo)) + _, _ = w.Write([]byte(policyManagerInfo)) } })) @@ -250,7 +250,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { if res == nil { w.WriteHeader(http.StatusOK) - w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) + _, _ = w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) return } @@ -263,7 +263,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(resJSON) + _, _ = w.Write(resJSON) })) // Batcher endpoint @@ -283,14 +283,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(batcherJSON) + _, _ = w.Write(batcherJSON) } else { // Default to text/plain for backward compatibility batcherInfo := h.debugBatcher() w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(batcherInfo)) + _, _ = w.Write([]byte(batcherInfo)) } })) diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 7ec26994..ef214536 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -154,7 +154,9 @@ func (h *Headscale) KeyHandler( } writer.Header().Set("Content-Type", "application/json") - json.NewEncoder(writer).Encode(resp) + if err := json.NewEncoder(writer).Encode(resp); err != nil { + log.Error().Err(err).Msg("failed to encode key response") + } return } @@ -179,7 +181,9 @@ func (h *Headscale) HealthHandler( res.Status = "fail" } - json.NewEncoder(writer).Encode(res) + if err := json.NewEncoder(writer).Encode(res); err != nil { + log.Error().Err(err).Msg("failed to encode health response") + } } err := h.state.PingDB(req.Context()) @@ -268,7 +272,7 @@ func (a *AuthProviderWeb) RegisterHandler( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write([]byte(templates.RegisterWeb(registrationId).Render())) + _, _ = writer.Write([]byte(templates.RegisterWeb(registrationId).Render())) } func FaviconHandler(writer http.ResponseWriter, req *http.Request) { diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 918b7049..3ff3406b 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -526,7 +526,7 @@ type multiChannelNodeConn struct { // generateConnectionID generates a unique connection identifier. 
func generateConnectionID() string { bytes := make([]byte, 8) - rand.Read(bytes) + _, _ = rand.Read(bytes) return hex.EncodeToString(bytes) } diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 595fb252..0ff03404 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -549,7 +549,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { testNode.start() // Connect the node to the batcher - batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) // Wait for connection to be established assert.EventuallyWithT(t, func(c *assert.CollectT) { @@ -658,7 +658,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { for i := range allNodes { node := &allNodes[i] - batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) // Issue full update after each join to ensure connectivity batcher.AddWork(change.FullUpdate()) @@ -827,7 +827,7 @@ func TestBatcherBasicOperations(t *testing.T) { tn2 := testData.Nodes[1] // Test AddNode with real node ID - batcher.AddNode(tn.n.ID, tn.ch, 100) + _ = batcher.AddNode(tn.n.ID, tn.ch, 100) if !batcher.IsConnected(tn.n.ID) { t.Error("Node should be connected after AddNode") @@ -848,7 +848,7 @@ func TestBatcherBasicOperations(t *testing.T) { drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) // Add the second node and verify update message - batcher.AddNode(tn2.n.ID, tn2.ch, 100) + _ = batcher.AddNode(tn2.n.ID, tn2.ch, 100) assert.True(t, batcher.IsConnected(tn2.n.ID)) // First node should get an update that second node has connected. @@ -1053,7 +1053,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) { testNodes := testData.Nodes ch := make(chan *tailcfg.MapResponse, 10) - batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) // Track update content for validation var receivedUpdates []*tailcfg.MapResponse @@ -1157,7 +1157,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { ch1 := make(chan *tailcfg.MapResponse, 1) wg.Go(func() { - batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) }) // Add real work during connection chaos @@ -1170,7 +1170,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { wg.Go(func() { runtime.Gosched() // Yield to introduce timing variability - batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) }) // Remove second connection @@ -1261,7 +1261,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 5) // Add node and immediately queue real work - batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) batcher.AddWork(change.DERPMap()) // Consumer goroutine to validate data and detect channel issues @@ -1384,7 +1384,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for _, node := range stableNodes { ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) stableChannels[node.n.ID] = ch - batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) // Monitor updates for each stable client go 
func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { @@ -1458,7 +1458,7 @@ func TestBatcherConcurrentClients(t *testing.T) { churningChannelsMutex.Unlock() - batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) // Consume updates to prevent blocking go func() { @@ -1771,7 +1771,7 @@ func XTestBatcherScalability(t *testing.T) { for i := range testNodes { node := &testNodes[i] - batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) connectedNodesMutex.Lock() connectedNodes[node.n.ID] = true @@ -2149,7 +2149,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Connect nodes one at a time and wait for each to be connected for i, node := range allNodes { - batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) t.Logf("Connected node %d (ID: %d)", i, node.n.ID) // Wait for node to be connected diff --git a/hscontrol/platform_config.go b/hscontrol/platform_config.go index 23c4d25d..c8cc3fd4 100644 --- a/hscontrol/platform_config.go +++ b/hscontrol/platform_config.go @@ -19,7 +19,7 @@ func (h *Headscale) WindowsConfigMessage( ) { writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render())) + _, _ = writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render())) } // AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it. @@ -29,7 +29,7 @@ func (h *Headscale) AppleConfigMessage( ) { writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render())) + _, _ = writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render())) } func (h *Headscale) ApplePlatformConfig( @@ -98,7 +98,7 @@ func (h *Headscale) ApplePlatformConfig( writer.Header(). 
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write(content.Bytes()) + _, _ = writer.Write(content.Bytes()) } type AppleMobileConfig struct { diff --git a/hscontrol/util/prompt.go b/hscontrol/util/prompt.go index 7d9cdbdf..410f6c2e 100644 --- a/hscontrol/util/prompt.go +++ b/hscontrol/util/prompt.go @@ -14,7 +14,7 @@ func YesNo(msg string) bool { fmt.Fprint(os.Stderr, msg+" [y/n] ") var resp string - fmt.Scanln(&resp) + _, _ = fmt.Scanln(&resp) resp = strings.ToLower(resp) switch resp { diff --git a/hscontrol/util/prompt_test.go b/hscontrol/util/prompt_test.go index fbed2ff8..c6fcb702 100644 --- a/hscontrol/util/prompt_test.go +++ b/hscontrol/util/prompt_test.go @@ -87,7 +87,7 @@ func TestYesNo(t *testing.T) { go func() { defer w.Close() - w.WriteString(tt.input) + _, _ = w.WriteString(tt.input) }() // Call the function @@ -106,7 +106,7 @@ func TestYesNo(t *testing.T) { // Check that the prompt was written to stderr var stderrBuf bytes.Buffer - io.Copy(&stderrBuf, stderrR) + _, _ = io.Copy(&stderrBuf, stderrR) stderrR.Close() expectedPrompt := "Test question [y/n] " @@ -134,7 +134,7 @@ func TestYesNoPromptMessage(t *testing.T) { go func() { defer w.Close() - w.WriteString("n\n") + _, _ = w.WriteString("n\n") }() // Call the function with a custom message @@ -149,7 +149,7 @@ func TestYesNoPromptMessage(t *testing.T) { // Check that the custom message was included in the prompt var stderrBuf bytes.Buffer - io.Copy(&stderrBuf, stderrR) + _, _ = io.Copy(&stderrBuf, stderrR) stderrR.Close() expectedPrompt := customMessage + " [y/n] " @@ -193,7 +193,7 @@ func TestYesNoCaseInsensitive(t *testing.T) { go func() { defer w.Close() - w.WriteString(tc.input) + _, _ = w.WriteString(tc.input) }() // Call the function @@ -206,7 +206,7 @@ func TestYesNoCaseInsensitive(t *testing.T) { stderrW.Close() // Drain stderr - io.Copy(io.Discard, stderrR) + _, _ = io.Copy(io.Discard, stderrR) stderrR.Close() if result != tc.expected { diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index a102b493..256d7e4d 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -308,7 +308,10 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { // Register all clients as user1 (this is where cross-user registration happens) // This simulates: headscale nodes register --user user1 --key - scenario.runHeadscaleRegister("user1", body) + err = scenario.runHeadscaleRegister("user1", body) + if err != nil { + t.Fatalf("failed to register client %s: %s", client.Hostname(), err) + } } // Wait for all clients to reach running state diff --git a/integration/dns_test.go b/integration/dns_test.go index e937a421..3432eb9b 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -93,7 +93,8 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { Value: "6.6.6.6", }, } - b, _ := json.Marshal(extraRecords) + b, err := json.Marshal(extraRecords) + require.NoError(t, err) err = scenario.CreateHeadscaleEnv([]tsic.Option{ tsic.WithPackages("python3", "curl", "bind-tools"), @@ -133,13 +134,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { require.NoError(t, err) // Write the file directly into place from the docker API. 
- b0, _ := json.Marshal([]tailcfg.DNSRecord{ + b0, err := json.Marshal([]tailcfg.DNSRecord{ { Name: "docker.myvpn.example.com", Type: "A", Value: "2.2.2.2", }, }) + require.NoError(t, err) err = hs.WriteFile(erPath, b0) require.NoError(t, err) @@ -155,7 +157,8 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { Type: "A", Value: "7.7.7.7", }) - b2, _ := json.Marshal(extraRecords) + b2, err := json.Marshal(extraRecords) + require.NoError(t, err) err = hs.WriteFile(erPath+"2", b2) require.NoError(t, err) @@ -169,13 +172,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { // Write a new file and copy it to the path to ensure the reload // works when a file is copied into place. - b3, _ := json.Marshal([]tailcfg.DNSRecord{ + b3, err := json.Marshal([]tailcfg.DNSRecord{ { Name: "copy.myvpn.example.com", Type: "A", Value: "8.8.8.8", }, }) + require.NoError(t, err) err = hs.WriteFile(erPath+"3", b3) require.NoError(t, err) @@ -187,13 +191,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { } // Write in place to ensure pipe like behaviour works - b4, _ := json.Marshal([]tailcfg.DNSRecord{ + b4, err := json.Marshal([]tailcfg.DNSRecord{ { Name: "docker.myvpn.example.com", Type: "A", Value: "9.9.9.9", }, }) + require.NoError(t, err) command := []string{"echo", fmt.Sprintf("'%s'", string(b4)), ">", erPath} _, err = hs.Execute([]string{"bash", "-c", strings.Join(command, " ")}) require.NoError(t, err) diff --git a/integration/scenario.go b/integration/scenario.go index e209f6ea..0b388c0a 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -169,8 +169,8 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { // Opportunity to clean up unreferenced networks. // This might be a no op, but it is worth a try as we sometime // dont clean up nicely after ourselves. - dockertestutil.CleanUnreferencedNetworks(pool) - dockertestutil.CleanImagesInCI(pool) + _ = dockertestutil.CleanUnreferencedNetworks(pool) + _ = dockertestutil.CleanImagesInCI(pool) if spec.MaxWait == 0 { pool.MaxWait = dockertestMaxWait() @@ -314,8 +314,8 @@ func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { } func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { - defer dockertestutil.CleanUnreferencedNetworks(s.pool) - defer dockertestutil.CleanImagesInCI(s.pool) + defer func() { _ = dockertestutil.CleanUnreferencedNetworks(s.pool) }() + defer func() { _ = dockertestutil.CleanImagesInCI(s.pool) }() s.controlServers.Range(func(_ string, control ControlServer) bool { stdoutPath, stderrPath, err := control.Shutdown() @@ -920,7 +920,7 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { // If the URL is not a OIDC URL, then we need to // run the register command to fully log in the client. 
if !strings.Contains(loginURL.String(), "/oidc/") { - s.runHeadscaleRegister(userStr, body) + _ = s.runHeadscaleRegister(userStr, body) } return nil From 956bcb368042102530119670fdbf41dba7885c08 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 15:03:04 +0000 Subject: [PATCH 09/30] all: remove unused code and simplify function signatures Remove unused functions, constants, types, and parameters: - Remove unused const reservedResponseHeaderSize from batcher_test.go - Remove unused func getStats from batcher_test.go - Remove unused func fullMapResponse from mapper.go - Remove unused type mockState and methods from mapper_test.go - Remove unused func createTestNodeSimple from maprequest_test.go - Remove unused const batchSize from node_store_test.go - Remove unused func ptrTo from change.go - Remove unused funcs assertClientsState, assertValidNetmap, assertValidStatus, assertValidNetcheck, groupOwner from helpers.go Simplify function signatures by removing unused parameters/returns: - nodeRoutesToPtables: remove always-nil error return - parseUpdateAndAnalyze: remove always-nil error return - parseProtocol: remove unused bool return value - requireAllClientsOfflineStaged: mark unused params as _ - requireAllClientsNetInfoAndDERP: hardcode constant timeout - pingAllHelper: remove never-used opts variadic parameter - node helper in policy_test.go: mark unused hostinfo as _ - drainChannelTimeout: mark unused name parameter as _ --- cmd/headscale/cli/nodes.go | 11 +- hscontrol/mapper/batcher_test.go | 83 +++++-------- hscontrol/mapper/mapper.go | 23 ---- hscontrol/mapper/mapper_test.go | 91 --------------- hscontrol/policy/v2/filter.go | 4 +- hscontrol/policy/v2/policy_test.go | 3 +- hscontrol/policy/v2/types.go | 37 +++--- hscontrol/state/maprequest_test.go | 24 ---- hscontrol/state/node_store_test.go | 4 +- hscontrol/types/change/change.go | 7 -- integration/auth_key_test.go | 10 +- integration/helpers.go | 182 ++--------------------------- 12 files changed, 67 insertions(+), 412 deletions(-) diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 219f69c9..01f20ad0 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -233,10 +233,7 @@ var listNodeRoutesCmd = &cobra.Command{ return } - tableData, err := nodeRoutesToPtables(nodes) - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - } + tableData := nodeRoutesToPtables(nodes) err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { @@ -601,9 +598,7 @@ func nodesToPtables( return tableData, nil } -func nodeRoutesToPtables( - nodes []*v1.Node, -) (pterm.TableData, error) { +func nodeRoutesToPtables(nodes []*v1.Node) pterm.TableData { tableHeader := []string{ "ID", "Hostname", @@ -627,7 +622,7 @@ func nodeRoutesToPtables( ) } - return tableData, nil + return tableData } var tagCmd = &cobra.Command{ diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 0ff03404..1dac3705 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -127,11 +127,9 @@ const ( // Channel configuration. NORMAL_BUFFER_SIZE = 50 - SMALL_BUFFER_SIZE = 3 - TINY_BUFFER_SIZE = 1 // For maximum contention - LARGE_BUFFER_SIZE = 200 - - reservedResponseHeaderSize = 4 + SMALL_BUFFER_SIZE = 3 + TINY_BUFFER_SIZE = 1 // For maximum contention + LARGE_BUFFER_SIZE = 200 ) // TestData contains all test entities created for a test scenario. 
@@ -319,23 +317,6 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) { stats.LastUpdate = time.Now() } -// getStats returns a copy of the statistics for a node. -func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats { - ut.mu.RLock() - defer ut.mu.RUnlock() - - if stats, exists := ut.stats[nodeID]; exists { - // Return a copy to avoid race conditions - return UpdateStats{ - TotalUpdates: stats.TotalUpdates, - UpdateSizes: slices.Clone(stats.UpdateSizes), - LastUpdate: stats.LastUpdate, - } - } - - return UpdateStats{} -} - // getAllStats returns a copy of all statistics. func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats { ut.mu.RLock() @@ -387,16 +368,14 @@ type UpdateInfo struct { } // parseUpdateAndAnalyze parses an update and returns detailed information. -func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) { - info := UpdateInfo{ +func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo { + return UpdateInfo{ PeerCount: len(resp.Peers), PatchCount: len(resp.PeersChangedPatch), IsFull: len(resp.Peers) > 0, IsPatch: len(resp.PeersChangedPatch) > 0, IsDERP: resp.DERPMap != nil, } - - return info, nil } // start begins consuming updates from the node's channel and tracking stats. @@ -418,36 +397,36 @@ func (n *node) start() { atomic.AddInt64(&n.updateCount, 1) // Parse update and track detailed stats - if info, err := parseUpdateAndAnalyze(data); err == nil { - // Track update types - if info.IsFull { - atomic.AddInt64(&n.fullCount, 1) - n.lastPeerCount.Store(int64(info.PeerCount)) - // Update max peers seen using compare-and-swap for thread safety - for { - current := n.maxPeersCount.Load() - if int64(info.PeerCount) <= current { - break - } + info := parseUpdateAndAnalyze(data) - if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) { - break - } + // Track update types + if info.IsFull { + atomic.AddInt64(&n.fullCount, 1) + n.lastPeerCount.Store(int64(info.PeerCount)) + // Update max peers seen using compare-and-swap for thread safety + for { + current := n.maxPeersCount.Load() + if int64(info.PeerCount) <= current { + break + } + + if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) { + break } } + } - if info.IsPatch { - atomic.AddInt64(&n.patchCount, 1) - // For patches, we track how many patch items using compare-and-swap - for { - current := n.maxPeersCount.Load() - if int64(info.PatchCount) <= current { - break - } + if info.IsPatch { + atomic.AddInt64(&n.patchCount, 1) + // For patches, we track how many patch items using compare-and-swap + for { + current := n.maxPeersCount.Load() + if int64(info.PatchCount) <= current { + break + } - if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) { - break - } + if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) { + break } } } @@ -914,7 +893,7 @@ func TestBatcherBasicOperations(t *testing.T) { } } -func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) { +func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, _ string, timeout time.Duration) { count := 0 timer := time.NewTimer(timeout) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 843729c7..329c9b58 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -139,29 +139,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { } } -// fullMapResponse returns a MapResponse for the given node. 
-func (m *mapper) fullMapResponse( - nodeID types.NodeID, - capVer tailcfg.CapabilityVersion, -) (*tailcfg.MapResponse, error) { - peers := m.state.ListPeers(nodeID) - - return m.NewMapResponseBuilder(nodeID). - WithDebugType(fullResponseDebug). - WithCapabilityVersion(capVer). - WithSelfNode(). - WithDERPMap(). - WithDomain(). - WithCollectServicesDisabled(). - WithDebugConfig(). - WithSSHPolicy(). - WithDNSConfig(). - WithUserProfiles(peers). - WithPacketFilters(). - WithPeers(peers). - Build() -} - func (m *mapper) selfMapResponse( nodeID types.NodeID, capVer tailcfg.CapabilityVersion, diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 4852ce04..ae2900d9 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -3,14 +3,10 @@ package mapper import ( "fmt" "net/netip" - "slices" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/policy/matcher" - "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" @@ -81,93 +77,6 @@ func TestDNSConfigMapResponse(t *testing.T) { } } -// mockState is a mock implementation that provides the required methods. -type mockState struct { - polMan policy.PolicyManager - derpMap *tailcfg.DERPMap - primary *routes.PrimaryRoutes - nodes types.Nodes - peers types.Nodes -} - -func (m *mockState) DERPMap() *tailcfg.DERPMap { - return m.derpMap -} - -func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) { - if m.polMan == nil { - return tailcfg.FilterAllowAll, nil - } - - return m.polMan.Filter() -} - -func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { - if m.polMan == nil { - return nil, nil - } - - return m.polMan.SSHPolicy(node) -} - -func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool { - if m.polMan == nil { - return false - } - - return m.polMan.NodeCanHaveTag(node, tag) -} - -func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix { - if m.primary == nil { - return nil - } - - return m.primary.PrimaryRoutes(nodeID) -} - -func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { - if len(peerIDs) > 0 { - // Filter peers by the provided IDs - var filtered types.Nodes - - for _, peer := range m.peers { - if slices.Contains(peerIDs, peer.ID) { - filtered = append(filtered, peer) - } - } - - return filtered, nil - } - // Return all peers except the node itself - var filtered types.Nodes - - for _, peer := range m.peers { - if peer.ID != nodeID { - filtered = append(filtered, peer) - } - } - - return filtered, nil -} - -func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { - if len(nodeIDs) > 0 { - // Filter nodes by the provided IDs - var filtered types.Nodes - - for _, node := range m.nodes { - if slices.Contains(nodeIDs, node.ID) { - filtered = append(filtered, node) - } - } - - return filtered, nil - } - - return m.nodes, nil -} - func Test_fullMapResponse(t *testing.T) { t.Skip("Test needs to be refactored for new state-based architecture") // TODO: Refactor this test to work with the new state-based mapper diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 3f72cdda..958902a2 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -42,7 +42,7 @@ func (pol *Policy) 
compileFilterRules( continue } - protocols, _ := acl.Protocol.parseProtocol() + protocols := acl.Protocol.parseProtocol() var destPorts []tailcfg.NetPortRange @@ -141,7 +141,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( } } - protocols, _ := acl.Protocol.parseProtocol() + protocols := acl.Protocol.parseProtocol() var rules []*tailcfg.FilterRule diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 4477e8b1..3e3c70f5 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -13,7 +13,7 @@ import ( "tailscale.com/tailcfg" ) -func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { +func node(name, ipv4, ipv6 string, user types.User, _ *tailcfg.Hostinfo) *types.Node { return &types.Node{ ID: 0, Hostname: name, @@ -21,7 +21,6 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) IPv6: ap(ipv6), User: new(user), UserID: new(user.ID), - Hostinfo: hostinfo, } } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 3fe5a0d4..c99f1156 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1377,49 +1377,44 @@ func (p Protocol) Description() string { } } -// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement. +// parseProtocol converts a Protocol to its IANA protocol numbers. // Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. -func (p Protocol) parseProtocol() ([]int, bool) { +func (p Protocol) parseProtocol() []int { switch p { case "": // Empty protocol applies to TCP and UDP traffic only - return []int{protocolTCP, protocolUDP}, false + return []int{protocolTCP, protocolUDP} case ProtocolWildcard: // Wildcard protocol - defensive handling (should not reach here due to validation) - return nil, false + return nil case ProtocolIGMP: - return []int{protocolIGMP}, true + return []int{protocolIGMP} case ProtocolIPv4, ProtocolIPInIP: - return []int{protocolIPv4}, true + return []int{protocolIPv4} case ProtocolTCP: - return []int{protocolTCP}, false + return []int{protocolTCP} case ProtocolEGP: - return []int{protocolEGP}, true + return []int{protocolEGP} case ProtocolIGP: - return []int{protocolIGP}, true + return []int{protocolIGP} case ProtocolUDP: - return []int{protocolUDP}, false + return []int{protocolUDP} case ProtocolGRE: - return []int{protocolGRE}, true + return []int{protocolGRE} case ProtocolESP: - return []int{protocolESP}, true + return []int{protocolESP} case ProtocolAH: - return []int{protocolAH}, true + return []int{protocolAH} case ProtocolSCTP: - return []int{protocolSCTP}, false + return []int{protocolSCTP} case ProtocolICMP: - return []int{protocolICMP, protocolIPv6ICMP}, true + return []int{protocolICMP, protocolIPv6ICMP} default: // Try to parse as a numeric protocol number // This should not fail since validation happened during unmarshaling protocolNumber, _ := strconv.Atoi(string(p)) - // Determine if wildcard is needed based on protocol number - needsWildcard := protocolNumber != protocolTCP && - protocolNumber != protocolUDP && - protocolNumber != protocolSCTP - - return []int{protocolNumber}, needsWildcard + return []int{protocolNumber} } } diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index a7d50a07..ce6804e4 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -1,14 +1,12 @@ package state import ( - "net/netip" 
"testing" "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/key" ) func TestNetInfoFromMapRequest(t *testing.T) { @@ -136,25 +134,3 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { }) } -// Simple helper function for tests. -func createTestNodeSimple(id types.NodeID) *types.Node { - user := types.User{ - Name: "test-user", - } - - machineKey := key.NewMachine() - nodeKey := key.NewNode() - - node := &types.Node{ - ID: id, - Hostname: "test-node", - UserID: new(uint(id)), - User: &user, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - IPv4: &netip.Addr{}, - IPv6: &netip.Addr{}, - } - - return node -} diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 23068b97..b90956aa 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -930,9 +930,7 @@ func TestNodeStoreConcurrentPutNode(t *testing.T) { // --- Batching: concurrent ops fit in one batch ---. func TestNodeStoreBatchingEfficiency(t *testing.T) { - const batchSize = 10 - - const ops = 15 // more than batchSize + const ops = 15 store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index 6913d7d9..37a63e80 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -365,13 +365,6 @@ func KeyExpiry(nodeID types.NodeID, expiry *time.Time) Change { } } -// ptrTo returns a pointer to the given value. -// -//go:fix inline -func ptrTo[T any](v T) *T { - return new(v) -} - // High-level change constructors // NodeAdded returns a Change for when a node is added or updated. 
diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index 47c55a37..0bced1ed 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -64,7 +64,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 120*time.Second) // Validate that all nodes have NetInfo and DERP servers before logout - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout") // assertClientsState(t, allClients) @@ -172,7 +172,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { requireNoErrSync(t, err) // Validate that all nodes have NetInfo and DERP servers after reconnection - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection") err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) @@ -262,7 +262,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { // Validate initial connection state requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login") var ( listNodes []*v1.Node @@ -332,7 +332,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { // Validate connection state after relogin as user1 requireAllClientsOnline(t, headscale, expectedUser1Nodes, true, "all user1 nodes should be connected after relogin", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, headscale, expectedUser1Nodes, "all user1 nodes should have NetInfo and DERP after relogin", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedUser1Nodes, "all user1 nodes should have NetInfo and DERP after relogin") // Validate that user2 still has their original nodes after user1's re-authentication // When nodes re-authenticate with a different user's pre-auth key, NEW nodes are created @@ -414,7 +414,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { // Validate initial connection state requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login") var ( listNodes []*v1.Node diff --git a/integration/helpers.go b/integration/helpers.go index 4a00342c..59b87cff 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -9,7 +9,6 @@ import ( "net/netip" "strconv" "strings" - "sync" "testing" "time" @@ -128,7 +127,7 @@ func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNo t.Helper() requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, 
headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login") } // validateLogoutComplete performs comprehensive validation after client logout. @@ -147,7 +146,7 @@ func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNode t.Helper() requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin") } // requireAllClientsOnline validates that all nodes are online/offline across all headscale systems @@ -366,7 +365,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer } // requireAllClientsOfflineStaged validates offline state with staged timeouts for different components. -func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) { +func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, _ string, _ time.Duration) { t.Helper() // Stage 1: Verify batcher disconnection (should be immediate) @@ -467,9 +466,11 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec // requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database // and a valid DERP server based on the NetInfo. This function follows the pattern of // requireAllClientsOnline by using hsic.DebugNodeStore to get the database state. -func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) { +func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string) { t.Helper() + const timeout = 3 * time.Minute + startTime := time.Now() t.Logf("requireAllClientsNetInfoAndDERP: Starting NetInfo/DERP validation for %d nodes at %s - %s", len(expectedNodes), startTime.Format(TimestampFormat), message) @@ -556,14 +557,14 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { // pingAllHelper performs ping tests between all clients and addresses, returning success count. // This is used to validate network connectivity in integration tests. // Returns the total number of successful ping operations. -func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { +func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { t.Helper() success := 0 for _, client := range clients { for _, addr := range addrs { - err := client.Ping(addr, opts...) + err := client.Ping(addr) if err != nil { t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err) } else { @@ -628,167 +629,6 @@ func isSelfClient(client TailscaleClient, addr string) bool { return false } -// assertClientsState validates the status and netmap of a list of clients for general connectivity. -// Runs parallel validation of status, netcheck, and netmap for all clients to ensure -// they have proper network configuration for all-to-all connectivity tests. 
-func assertClientsState(t *testing.T, clients []TailscaleClient) { - t.Helper() - - var wg sync.WaitGroup - - for _, client := range clients { - wg.Add(1) - - c := client // Avoid loop pointer - - go func() { - defer wg.Done() - - assertValidStatus(t, c) - assertValidNetcheck(t, c) - assertValidNetmap(t, c) - }() - } - - t.Logf("waiting for client state checks to finish") - wg.Wait() -} - -// assertValidNetmap validates that a client's netmap has all required fields for proper operation. -// Checks self node and all peers for essential networking data including hostinfo, addresses, -// endpoints, and DERP configuration. Skips validation for Tailscale versions below 1.56. -// This test is not suitable for ACL/partial connection tests. -func assertValidNetmap(t *testing.T, client TailscaleClient) { - t.Helper() - - if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) { - t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version()) - - return - } - - t.Logf("Checking netmap of %q", client.Hostname()) - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - netmap, err := client.Netmap() - assert.NoError(c, err, "getting netmap for %q", client.Hostname()) - - assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) - - if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { - assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) - } - - assert.NotEmptyf(c, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) - assert.NotEmptyf(c, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) - - assert.Truef(c, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname()) - - assert.Falsef(c, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) - assert.Falsef(c, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) - assert.Falsef(c, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) - - for _, peer := range netmap.Peers { - assert.NotEqualf(c, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) - assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) - - assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) - - if hi := peer.Hostinfo(); hi.Valid() { - assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) - - // Netinfo is not always set - // assert.Truef(c, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname()) - if ni := hi.NetInfo(); ni.Valid() { - assert.NotEqualf(c, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP()) - } - } - - assert.NotEmptyf(c, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname()) - assert.NotEmptyf(c, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), 
client.Hostname()) - assert.NotEmptyf(c, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname()) - - assert.Truef(c, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname()) - - assert.Falsef(c, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(c, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(c, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname()) - } - }, 10*time.Second, 200*time.Millisecond, "Waiting for valid netmap for %q", client.Hostname()) -} - -// assertValidStatus validates that a client's status has all required fields for proper operation. -// Checks self and peer status for essential data including hostinfo, tailscale IPs, endpoints, -// and network map presence. This test is not suitable for ACL/partial connection tests. -func assertValidStatus(t *testing.T, client TailscaleClient) { - t.Helper() - - status, err := client.Status(true) - if err != nil { - t.Fatalf("getting status for %q: %s", client.Hostname(), err) - } - - assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname()) - assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname()) - assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname()) - - assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname()) - - // This seem to not appear until version 1.56 - if status.Self.AllowedIPs != nil { - assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname()) - } - - assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname()) - - assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname()) - - assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname()) - - // This isn't really relevant for Self as it won't be in its own socket/wireguard. - // assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname()) - // assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname()) - - for _, peer := range status.Peer { - assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname()) - assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname()) - assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname()) - - assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname()) - - // This seem to not appear until version 1.56 - if peer.AllowedIPs != nil { - assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname()) - } - - // Addrs does not seem to appear in the status from peers. 
- // assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname()) - - assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname()) - - assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname()) - assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname()) - - // TODO(kradalby): InEngine is only true when a proper tunnel is set up, - // there might be some interesting stuff to test here in the future. - // assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname()) - } -} - -// assertValidNetcheck validates that a client has a proper DERP relay configured. -// Ensures the client has discovered and selected a DERP server for relay functionality, -// which is essential for NAT traversal and connectivity in restricted networks. -func assertValidNetcheck(t *testing.T, client TailscaleClient) { - t.Helper() - - report, err := client.Netcheck() - if err != nil { - t.Fatalf("getting status for %q: %s", client.Hostname(), err) - } - - assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname()) -} - // assertCommandOutputContains executes a command with exponential backoff retry until the output // contains the expected string or timeout is reached (10 seconds). // This implements eventual consistency patterns and should be used instead of time.Sleep @@ -927,12 +767,6 @@ func usernameOwner(name string) policyv2.Owner { return new(policyv2.Username(name)) } -// groupOwner returns a Group as an Owner for use in TagOwners policies. -// Specifies which groups can assign and manage specific tags in ACL configurations. -func groupOwner(name string) policyv2.Owner { - return new(policyv2.Group(name)) -} - // usernameApprover returns a Username as an AutoApprover for subnet route policies. // Specifies which users can automatically approve subnet route advertisements. func usernameApprover(name string) policyv2.AutoApprover { From 676273ee9dea4ef87e20e4f71fa1f3f39a2ef7b6 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 15:08:44 +0000 Subject: [PATCH 10/30] all: fix staticcheck issues - Apply De Morgan's law to simplify condition in types.go - Simplify return statement in batcher_test.go - Fix potential nil pointer dereference by using t.Fatal in util_test.go - Remove unused variable assignment in auth_oidc_test.go - Remove duplicate tarReader assignment in hsic.go - Add empty line after embedded field in batcher_test.go and types.go Note: Many staticcheck SA4006 warnings are false positives due to Go 1.26's new `new(value)` syntax which creates a pointer to a value. The staticcheck tool hasn't been updated to understand this syntax. --- hscontrol/mapper/batcher_test.go | 8 ++------ hscontrol/policy/v2/types.go | 3 ++- hscontrol/util/util_test.go | 8 ++++---- integration/auth_oidc_test.go | 2 +- integration/hsic/hsic.go | 4 +--- 5 files changed, 10 insertions(+), 15 deletions(-) diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 1dac3705..0fe1266d 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -36,6 +36,7 @@ type batcherTestCase struct { // that would normally be sent by poll.go in production. 
type testBatcherWrapper struct { Batcher + state *state.State } @@ -81,12 +82,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe } // Finally remove from the real batcher - removed := t.Batcher.RemoveNode(id, c) - if !removed { - return false - } - - return true + return t.Batcher.RemoveNode(id, c) } // wrapBatcherForTest wraps a batcher with test-specific behavior. diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index c99f1156..41e7e0d9 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -596,6 +596,7 @@ type Alias interface { type AliasWithPorts struct { Alias + Ports []tailcfg.PortRange } @@ -2107,7 +2108,7 @@ func validateProtocolPortCompatibility(protocol Protocol, destinations []AliasWi for _, dst := range destinations { for _, portRange := range dst.Ports { // Check if it's not a wildcard port (0-65535) - if !(portRange.First == 0 && portRange.Last == 65535) { + if portRange.First != 0 || portRange.Last != 65535 { return fmt.Errorf("protocol %q does not support specific ports; only \"*\" is allowed", protocol) } } diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index a064a852..ec72250e 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1103,7 +1103,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { wantHostname: "test-node", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } if hi.Hostname != "test-node" { @@ -1149,7 +1149,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { wantHostname: "node-nkey1234", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } if hi.Hostname != "node-nkey1234" { @@ -1165,7 +1165,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { wantHostname: "unknown-node", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } if hi.Hostname != "unknown-node" { @@ -1183,7 +1183,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { wantHostname: "unknown-node", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } if hi.Hostname != "unknown-node" { diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 18c5c3a9..bdd5bce2 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -1052,7 +1052,7 @@ func TestOIDCMultipleOpenedLoginUrls(t *testing.T) { require.NotEqual(t, redirect1.String(), redirect2.String()) // complete auth with the first opened "browser tab" - _, redirect1, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true) + _, _, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true) require.NoError(t, err) listUsers, err = headscale.ListUsers() diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index a08ee7af..c4eaeea8 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -721,8 +721,6 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { return fmt.Errorf("failed to create directory %s: %w", targetDir, err) } - tarReader := tar.NewReader(bytes.NewReader(tarData)) - // Find the top-level directory to strip var topLevelDir string @@ -743,7 +741,7 @@ func 
extractTarToDirectory(tarData []byte, targetDir string) error { } } - tarReader = tar.NewReader(bytes.NewReader(tarData)) + tarReader := tar.NewReader(bytes.NewReader(tarData)) for { header, err := tarReader.Next() if err == io.EOF { From 3843036d13947db7f8b91a03620304d780b2f4a3 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 15:13:52 +0000 Subject: [PATCH 11/30] all: use context-aware methods for exec, database, and HTTP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace direct calls with context-aware versions: - exec.Command → exec.CommandContext - db.Exec → db.ExecContext - db.Ping → db.PingContext - db.QueryRow → db.QueryRowContext - http.NewRequest → http.NewRequestWithContext - net.LookupIP → net.DefaultResolver.LookupIPAddr --- cmd/hi/docker.go | 2 +- cmd/hi/doctor.go | 6 +++--- hscontrol/db/db.go | 2 +- hscontrol/db/db_test.go | 5 +++-- hscontrol/db/sqliteconfig/integration_test.go | 17 +++++++++-------- hscontrol/derp/server/derp_server.go | 6 +++--- integration/api_auth_test.go | 9 +++++---- 7 files changed, 25 insertions(+), 22 deletions(-) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index fbc2dba6..81f1d729 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -475,7 +475,7 @@ func createDockerClient() (*client.Client, error) { // getCurrentDockerContext retrieves the current Docker context information. func getCurrentDockerContext() (*DockerContext, error) { - cmd := exec.Command("docker", "context", "inspect") + cmd := exec.CommandContext(context.Background(), "docker", "context", "inspect") output, err := cmd.Output() if err != nil { diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index 8ebda159..2bfc41fd 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -265,7 +265,7 @@ func checkGoInstallation() DoctorResult { } } - cmd := exec.Command("go", "version") + cmd := exec.CommandContext(context.Background(), "go", "version") output, err := cmd.Output() if err != nil { @@ -287,7 +287,7 @@ func checkGoInstallation() DoctorResult { // checkGitRepository verifies we're in a git repository. 
func checkGitRepository() DoctorResult { - cmd := exec.Command("git", "rev-parse", "--git-dir") + cmd := exec.CommandContext(context.Background(), "git", "rev-parse", "--git-dir") err := cmd.Run() if err != nil { @@ -320,7 +320,7 @@ func checkRequiredFiles() DoctorResult { var missingFiles []string for _, file := range requiredFiles { - cmd := exec.Command("test", "-e", file) + cmd := exec.CommandContext(context.Background(), "test", "-e", file) if err := cmd.Run(); err != nil { missingFiles = append(missingFiles, file) } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 1ef767ce..ff9379c1 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -1035,7 +1035,7 @@ func (hsdb *HSDatabase) Close() error { } if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog { - _, _ = db.Exec("VACUUM") + _, _ = db.ExecContext(context.Background(), "VACUUM") } return db.Close() diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 47a527b9..f93b9ef8 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "os" "os/exec" @@ -177,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { return err } - _, err = db.Exec(string(schemaContent)) + _, err = db.ExecContext(context.Background(), string(schemaContent)) return err } @@ -322,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) { } // Construct the pg_restore command - cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) + cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) // Set the output streams cmd.Stdout = os.Stdout diff --git a/hscontrol/db/sqliteconfig/integration_test.go b/hscontrol/db/sqliteconfig/integration_test.go index b411daeb..fa39f958 100644 --- a/hscontrol/db/sqliteconfig/integration_test.go +++ b/hscontrol/db/sqliteconfig/integration_test.go @@ -1,6 +1,7 @@ package sqliteconfig import ( + "context" "database/sql" "path/filepath" "strings" @@ -101,7 +102,7 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { defer db.Close() // Test connection - if err := db.Ping(); err != nil { + if err := db.PingContext(context.Background()); err != nil { t.Fatalf("Failed to ping database: %v", err) } @@ -112,7 +113,7 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { query := "PRAGMA " + pragma - err := db.QueryRow(query).Scan(&actualValue) + err := db.QueryRowContext(context.Background(), query).Scan(&actualValue) if err != nil { t.Fatalf("Failed to query %s: %v", query, err) } @@ -180,23 +181,23 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { ); ` - if _, err := db.Exec(schema); err != nil { + if _, err := db.ExecContext(context.Background(), schema); err != nil { t.Fatalf("Failed to create schema: %v", err) } // Insert parent record - if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { + if _, err := db.ExecContext(context.Background(), "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { t.Fatalf("Failed to insert parent: %v", err) } // Test 1: Valid foreign key should work - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") + _, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") if err != nil { t.Fatalf("Valid foreign key insert 
failed: %v", err) } // Test 2: Invalid foreign key should fail - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") + _, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") if err == nil { t.Error("Expected foreign key constraint violation, but insert succeeded") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -206,7 +207,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { } // Test 3: Deleting referenced parent should fail - _, err = db.Exec("DELETE FROM parent WHERE id = 1") + _, err = db.ExecContext(context.Background(), "DELETE FROM parent WHERE id = 1") if err == nil { t.Error("Expected foreign key constraint violation when deleting referenced parent") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -252,7 +253,7 @@ func TestJournalModeValidation(t *testing.T) { var actualMode string - err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode) + err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode) if err != nil { t.Fatalf("Failed to query journal_mode: %v", err) } diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index bf292d03..562061e2 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -99,12 +99,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { // If debug flag is set, resolve hostname to IP address if debugUseDERPIP { - ips, err := net.LookupIP(host) + addrs, err := net.DefaultResolver.LookupIPAddr(context.Background(), host) if err != nil { log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host) - } else if len(ips) > 0 { + } else if len(addrs) > 0 { // Use the first IP address - ipStr := ips[0].String() + ipStr := addrs[0].IP.String() log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr) host = ipStr } diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index 825f3d17..ed4a1f4d 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -1,6 +1,7 @@ package integration import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -78,7 +79,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_NoAuthHeader", func(t *testing.T) { // Test 1: Request without any Authorization header // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) require.NoError(t, err) resp, err := client.Do(req) @@ -130,7 +131,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_InvalidAuthHeader", func(t *testing.T) { // Test 2: Request with invalid Authorization header (missing "Bearer " prefix) // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "InvalidToken") @@ -164,7 +165,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Test 3: Request with Bearer prefix but invalid token // Expected: Should return 401 with ONLY "Unauthorized" text, no user data // Note: Both malformed and properly formatted invalid tokens should return 401 - req, err := http.NewRequest("GET", apiURL, nil) + req, err := 
http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "Bearer invalid-token-12345") @@ -197,7 +198,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_ValidAPIKey", func(t *testing.T) { // Test 4: Request with valid API key // Expected: Should return 200 with user data (this is the authorized case) - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "Bearer "+validAPIKey) From 144c79aedf42126fa4398302b9c64efb5770a8cb Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 15:17:28 +0000 Subject: [PATCH 12/30] hscontrol/mapper: fix copylocks govet warnings Change Nodes field in TestData from []node to []*node to avoid copying sync/atomic.Int64 values (which contain noCopy sentinel). This fixes all govet copylocks warnings in batcher_test.go. --- hscontrol/mapper/batcher_test.go | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 0fe1266d..4f950d15 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -132,7 +132,7 @@ const ( type TestData struct { Database *db.HSDatabase Users []*types.User - Nodes []node + Nodes []*node State *state.State Config *types.Config Batcher Batcher @@ -218,11 +218,11 @@ func setupBatcherWithTestData( // Create test users and nodes in the database users := database.CreateUsersForTest(userCount, "testuser") - allNodes := make([]node, 0, userCount*nodesPerUser) + allNodes := make([]*node, 0, userCount*nodesPerUser) for _, user := range users { dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node") for i := range dbNodes { - allNodes = append(allNodes, node{ + allNodes = append(allNodes, &node{ n: dbNodes[i], ch: make(chan *tailcfg.MapResponse, bufferSize), }) @@ -516,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { defer cleanup() batcher := testData.Batcher - testNode := &testData.Nodes[0] + testNode := testData.Nodes[0] t.Logf("Testing enhanced tracking with node ID %d", testNode.n.ID) @@ -632,7 +632,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) for i := range allNodes { - node := &allNodes[i] + node := allNodes[i] _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) // Issue full update after each join to ensure connectivity @@ -654,7 +654,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { connectedCount := 0 for i := range allNodes { - node := &allNodes[i] + node := allNodes[i] currentMaxPeers := int(node.maxPeersCount.Load()) if currentMaxPeers >= expectedPeers { @@ -675,7 +675,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { // Disconnect all nodes for i := range allNodes { - node := &allNodes[i] + node := allNodes[i] batcher.RemoveNode(node.n.ID, node.ch) } @@ -696,7 +696,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { nodeDetails := make([]string, 0, min(10, len(allNodes))) for i := range allNodes { - node := &allNodes[i] + node := allNodes[i] stats := node.cleanup() totalUpdates += stats.TotalUpdates @@ -1745,7 +1745,7 @@ func XTestBatcherScalability(t *testing.T) { var connectedNodesMutex sync.RWMutex for i := range testNodes { - node := &testNodes[i] + node := testNodes[i] _ = batcher.AddNode(node.n.ID, node.ch, 
tailcfg.CapabilityVersion(100)) connectedNodesMutex.Lock() @@ -1976,7 +1976,7 @@ func XTestBatcherScalability(t *testing.T) { // Now disconnect all nodes from batcher to stop new updates for i := range testNodes { - node := &testNodes[i] + node := testNodes[i] batcher.RemoveNode(node.n.ID, node.ch) } @@ -1995,7 +1995,7 @@ func XTestBatcherScalability(t *testing.T) { nodeStatsReport := make([]string, 0, len(testNodes)) for i := range testNodes { - node := &testNodes[i] + node := testNodes[i] stats := node.cleanup() totalUpdates += stats.TotalUpdates totalPatches += stats.PatchUpdates @@ -2651,9 +2651,9 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) { batcher := testData.Batcher st := testData.State - node1 := &testData.Nodes[0] - node2 := &testData.Nodes[1] - node3 := &testData.Nodes[2] + node1 := testData.Nodes[0] + node2 := testData.Nodes[1] + node3 := testData.Nodes[2] t.Logf("Testing issue #2924: Node1=%d, Node2=%d, Node3=%d", node1.n.ID, node2.n.ID, node3.n.ID) From 25fdad39495f029555f0867e00c516185018a3c8 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 15:35:44 +0000 Subject: [PATCH 13/30] hscontrol/policy/v2: define sentinel errors Add comprehensive sentinel errors for all error conditions in the policy engine and use consistent error wrapping patterns with fmt.Errorf("%w: ...). Update test expectations to match the new error message formats. --- hscontrol/policy/v2/filter.go | 5 +- hscontrol/policy/v2/types.go | 199 ++++++++++++++++++------------ hscontrol/policy/v2/types_test.go | 87 ++++++------- hscontrol/policy/v2/utils.go | 28 +++-- hscontrol/policy/v2/utils_test.go | 62 ++++------ 5 files changed, 213 insertions(+), 168 deletions(-) diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 958902a2..ced8531c 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -1,7 +1,6 @@ package v2 import ( - "errors" "fmt" "slices" "time" @@ -14,8 +13,6 @@ import ( "tailscale.com/types/views" ) -var ErrInvalidAction = errors.New("invalid action") - // compileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *Policy) compileFilterRules( @@ -149,7 +146,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( for _, src := range acl.Sources { if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - return nil, errors.New("autogroup:self cannot be used in sources") + return nil, ErrAutogroupSelfInSource } ips, err := src.Resolve(pol, users, nodes) diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 41e7e0d9..4ff5dd1a 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -36,6 +36,73 @@ var ErrCircularReference = errors.New("circular reference detected") var ErrUndefinedTagReference = errors.New("references undefined tag") +// Sentinel errors for type/alias validation. +var ( + ErrUnknownAliasType = errors.New("unknown alias type") + ErrUnknownOwnerType = errors.New("unknown owner type") + ErrUnknownAutoApproverType = errors.New("unknown auto approver type") + ErrInvalidAlias = errors.New("invalid alias") + ErrInvalidAutoApprover = errors.New("invalid auto approver") + ErrInvalidOwner = errors.New("invalid owner") +) + +// Sentinel errors for format validation. 
+var ( + ErrUsernameMissingAt = errors.New("username must contain @") + ErrGroupMissingPrefix = errors.New("group must start with 'group:'") + ErrTagMissingPrefix = errors.New("tag must start with 'tag:'") + ErrInvalidHostname = errors.New("invalid hostname") + ErrInvalidPrefix = errors.New("invalid prefix") + ErrInvalidAutoGroup = errors.New("invalid autogroup") + ErrInvalidAction = errors.New("invalid action") + ErrInvalidSSHAction = errors.New("invalid SSH action") + ErrInvalidProtocol = errors.New("invalid protocol") + ErrProtocolOutOfRange = errors.New("protocol number out of range") + ErrLeadingZeroProtocol = errors.New("leading zero not permitted in protocol number") + ErrHostportMissingColon = errors.New("hostport must contain a colon") + ErrUnsupportedType = errors.New("unsupported type") +) + +// Sentinel errors for resolution/lookup failures. +var ( + ErrUserNotFound = errors.New("user not found") + ErrMultipleUsersFound = errors.New("multiple users found") + ErrHostNotResolved = errors.New("unable to resolve host") + ErrGroupNotDefined = errors.New("group not defined in policy") + ErrTagNotDefined = errors.New("tag not defined in policy") + ErrHostNotDefined = errors.New("host not defined in policy") + ErrInvalidIPAddress = errors.New("invalid IP address") + ErrNestedGroups = errors.New("nested groups not allowed") + ErrInvalidGroupMember = errors.New("invalid group member type") + ErrGroupValueNotArray = errors.New("group value must be an array") + ErrAutoApproverNotAlias = errors.New("auto approver is not an alias") +) + +// Sentinel errors for autogroup context validation. +var ( + ErrAutogroupInternetInSource = errors.New("autogroup:internet can only be used in ACL destinations") + ErrAutogroupSelfInSource = errors.New("autogroup:self can only be used in ACL destinations") + ErrAutogroupNotSupportedSource = errors.New("autogroup not supported for source") + ErrAutogroupNotSupportedDest = errors.New("autogroup not supported for destination") + ErrAutogroupNotSupportedSSH = errors.New("autogroup not supported for SSH") + ErrAutogroupNotSupported = errors.New("autogroup not supported in headscale") + ErrAliasNotSupportedSSH = errors.New("alias type not supported for SSH") +) + +// Sentinel errors for SSH aliases. +var ( + ErrAliasNotSupportedSSHSrc = errors.New("alias type not supported for SSH source") + ErrAliasNotSupportedSSHDst = errors.New("alias type not supported for SSH destination") + ErrUnknownSSHSrcAliasType = errors.New("unknown SSH source alias type") + ErrUnknownSSHDstAliasType = errors.New("unknown SSH destination alias type") +) + +// Sentinel errors for policy parsing. 
+var ( + ErrUnknownField = errors.New("unknown field in policy") + ErrProtocolNoSpecificPorts = errors.New("protocol does not support specific ports") +) + type Asterix int func (a Asterix) Validate() error { @@ -75,7 +142,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { case Asterix: alias = "*" default: - return nil, fmt.Errorf("unknown alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v) } // If no ports are specified @@ -126,7 +193,7 @@ func (u Username) Validate() error { return nil } - return fmt.Errorf("Username has to contain @, got: %q", u) + return fmt.Errorf("%w: got %q", ErrUsernameMissingAt, u) } func (u *Username) String() string { @@ -186,11 +253,11 @@ func (u Username) resolveUser(users types.Users) (types.User, error) { } if len(potentialUsers) == 0 { - return types.User{}, fmt.Errorf("user with token %q not found", u.String()) + return types.User{}, fmt.Errorf("%w: token %q", ErrUserNotFound, u.String()) } if len(potentialUsers) > 1 { - return types.User{}, fmt.Errorf("multiple users with token %q found: %s", u.String(), potentialUsers.String()) + return types.User{}, fmt.Errorf("%w: token %q found %s", ErrMultipleUsersFound, u.String(), potentialUsers.String()) } return potentialUsers[0], nil @@ -234,7 +301,7 @@ func (g Group) Validate() error { return nil } - return fmt.Errorf(`Group has to start with "group:", got: %q`, g) + return fmt.Errorf("%w: got %q", ErrGroupMissingPrefix, g) } func (g *Group) UnmarshalJSON(b []byte) error { @@ -299,7 +366,7 @@ func (t Tag) Validate() error { return nil } - return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) + return fmt.Errorf("%w: got %q", ErrTagMissingPrefix, t) } func (t *Tag) UnmarshalJSON(b []byte) error { @@ -349,7 +416,7 @@ func (h Host) Validate() error { return nil } - return fmt.Errorf("Hostname %q is invalid", h) + return fmt.Errorf("%w: %q", ErrInvalidHostname, h) } func (h *Host) UnmarshalJSON(b []byte) error { @@ -369,7 +436,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView pref, ok := p.Hosts[h] if !ok { - return nil, fmt.Errorf("unable to resolve host: %q", h) + return nil, fmt.Errorf("%w: %q", ErrHostNotResolved, h) } err := pref.Validate() @@ -406,7 +473,7 @@ func (p Prefix) Validate() error { return nil } - return fmt.Errorf("Prefix %q is invalid", p) + return fmt.Errorf("%w: %q", ErrInvalidPrefix, p) } func (p Prefix) String() string { @@ -510,7 +577,7 @@ func (ag AutoGroup) Validate() error { return nil } - return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups) + return fmt.Errorf("%w: got %q, must be one of %v", ErrInvalidAutoGroup, ag, autogroups) } func (ag *AutoGroup) UnmarshalJSON(b []byte) error { @@ -570,7 +637,7 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[type return nil, ErrAutogroupSelfRequiresPerNodeResolution default: - return nil, fmt.Errorf("unknown autogroup %q", ag) + return nil, fmt.Errorf("%w: %q", ErrInvalidAutoGroup, ag) } } @@ -626,7 +693,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { ve.Ports = ports } else { - return errors.New(`hostport must contain a colon (":")`) + return ErrHostportMissingColon } ve.Alias, err = parseAlias(vs) @@ -639,7 +706,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { } default: - return fmt.Errorf("type %T not supported", vs) + return fmt.Errorf("%w: %T", ErrUnsupportedType, vs) } return nil @@ -694,15 +761,7 @@ func parseAlias(vs string) (Alias, error) { return new(Host(vs)), nil } 
- return nil, fmt.Errorf(`Invalid alias %q. An alias must be one of the following types: -- wildcard (*) -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") -- autogroup (starting with "autogroup:") -- host - -Please check the format and try again.`, vs) + return nil, fmt.Errorf("%w: %q", ErrInvalidAlias, vs) } // AliasEnc is used to deserialize a Alias. @@ -764,7 +823,7 @@ func (a Aliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v) } } @@ -850,7 +909,7 @@ func (aa AutoApprovers) MarshalJSON() ([]byte, error) { case *Group: approvers[i] = string(*v) default: - return nil, fmt.Errorf("unknown auto approver type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAutoApproverType, v) } } @@ -867,12 +926,7 @@ func parseAutoApprover(s string) (AutoApprover, error) { return new(Tag(s)), nil } - return nil, fmt.Errorf(`Invalid AutoApprover %q. An alias must be one of the following types: -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") - -Please check the format and try again.`, s) + return nil, fmt.Errorf("%w: %q", ErrInvalidAutoApprover, s) } // AutoApproverEnc is used to deserialize a AutoApprover. @@ -949,7 +1003,7 @@ func (o Owners) MarshalJSON() ([]byte, error) { case *Tag: owners[i] = string(*v) default: - return nil, fmt.Errorf("unknown owner type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v) } } @@ -966,12 +1020,7 @@ func parseOwner(s string) (Owner, error) { return new(Tag(s)), nil } - return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") - -Please check the format and try again.`, s) + return nil, fmt.Errorf("%w: %q", ErrInvalidOwner, s) } type Usernames []Username @@ -990,7 +1039,7 @@ func (g Groups) Contains(group *Group) error { } } - return fmt.Errorf(`Group %q is not defined in the Policy, please define or remove the reference to it`, group) + return fmt.Errorf("%w: %q", ErrGroupNotDefined, group) } // UnmarshalJSON overrides the default JSON unmarshalling for Groups to ensure @@ -1025,15 +1074,15 @@ func (g *Groups) UnmarshalJSON(b []byte) error { if str, ok := item.(string); ok { stringSlice = append(stringSlice, str) } else { - return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item) + return fmt.Errorf("%w: group %q got %T", ErrInvalidGroupMember, key, item) } } rawGroups[key] = stringSlice case string: - return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v) + return fmt.Errorf("%w: group %q got string %q", ErrGroupValueNotArray, key, v) default: - return fmt.Errorf(`Group "%s" value must be an array of users, got %T`, key, v) + return fmt.Errorf("%w: group %q got %T", ErrGroupValueNotArray, key, v) } } @@ -1048,7 +1097,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { username := Username(u) if err := username.Validate(); err != nil { if isGroup(u) { - return fmt.Errorf("Nested groups are not allowed, found %q inside %q", u, group) + return fmt.Errorf("%w: found %q inside %q", ErrNestedGroups, u, group) } return err @@ -1082,7 +1131,7 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { var prefix Prefix if err := prefix.parseString(value); err != nil { - return fmt.Errorf(`Hostname "%s" contains an invalid IP address: "%s"`, 
key, value) + return fmt.Errorf("%w: hostname %q value %q", ErrInvalidIPAddress, key, value) } (*h)[host] = prefix @@ -1131,7 +1180,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { case *Tag: ownerStrs[i] = string(*v) default: - return nil, fmt.Errorf("unknown owner type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v) } } @@ -1155,7 +1204,7 @@ func (to TagOwners) Contains(tagOwner *Tag) error { } } - return fmt.Errorf(`Tag %q is not defined in the Policy, please define or remove the reference to it`, tagOwner) + return fmt.Errorf("%w: %q", ErrTagNotDefined, tagOwner) } type AutoApproverPolicy struct { @@ -1208,7 +1257,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. aa, ok := autoApprover.(Alias) if !ok { // Should never happen - return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. ips, _ := aa.Resolve(p, users, nodes) @@ -1223,7 +1272,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. aa, ok := autoApprover.(Alias) if !ok { // Should never happen - return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. ips, _ := aa.Resolve(p, users, nodes) @@ -1280,7 +1329,7 @@ func (a *Action) UnmarshalJSON(b []byte) error { case "accept": *a = ActionAccept default: - return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept) + return fmt.Errorf("%w: %q, must be %q", ErrInvalidAction, str, ActionAccept) } return nil @@ -1305,7 +1354,7 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error { case "check": *a = SSHActionCheck default: - return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str) + return fmt.Errorf("%w: %q, must be one of: accept, check", ErrInvalidSSHAction, str) } return nil @@ -1443,23 +1492,23 @@ func (p Protocol) validate() error { return nil case ProtocolWildcard: // Wildcard "*" is not allowed - Tailscale rejects it - return errors.New("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") + return fmt.Errorf("%w: use protocol number 0-255 or protocol name", ErrInvalidProtocol) default: // Try to parse as a numeric protocol number str := string(p) // Check for leading zeros (not allowed by Tailscale) if str == "0" || (len(str) > 1 && str[0] == '0') { - return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str) + return fmt.Errorf("%w: %q", ErrLeadingZeroProtocol, str) } protocolNumber, err := strconv.Atoi(str) if err != nil { - return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p) + return fmt.Errorf("%w: %q must be a known protocol name or valid protocol number 0-255", ErrInvalidProtocol, p) } if protocolNumber < 0 || protocolNumber > 255 { - return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber) + return fmt.Errorf("%w: %d", ErrProtocolOutOfRange, protocolNumber) } return nil @@ -1577,7 +1626,7 @@ func validateAutogroupSupported(ag *AutoGroup) error { } if slices.Contains(autogroupNotSupported, *ag) { - return fmt.Errorf("autogroup %q is not supported in headscale", *ag) + return fmt.Errorf("%w: %q", ErrAutogroupNotSupported, *ag) } 
return nil @@ -1589,15 +1638,15 @@ func validateAutogroupForSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`) + return ErrAutogroupInternetInSource } if src.Is(AutoGroupSelf) { - return errors.New(`"autogroup:self" used in source, it can only be used in ACL destinations`) + return ErrAutogroupSelfInSource } if !slices.Contains(autogroupForSrc, *src) { - return fmt.Errorf("autogroup %q is not supported for ACL sources, can be %v", *src, autogroupForSrc) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedSource, *src, autogroupForSrc) } return nil @@ -1609,7 +1658,7 @@ func validateAutogroupForDst(dst *AutoGroup) error { } if !slices.Contains(autogroupForDst, *dst) { - return fmt.Errorf("autogroup %q is not supported for ACL destinations, can be %v", *dst, autogroupForDst) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedDest, *dst, autogroupForDst) } return nil @@ -1621,11 +1670,11 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) + return fmt.Errorf("%w: autogroup:internet in SSH source", ErrAutogroupNotSupportedSSH) } if !slices.Contains(autogroupForSSHSrc, *src) { - return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *src, autogroupForSSHSrc) + return fmt.Errorf("%w: %q for SSH sources, can be %v", ErrAutogroupNotSupportedSSH, *src, autogroupForSSHSrc) } return nil @@ -1637,11 +1686,11 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error { } if dst.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) + return fmt.Errorf("%w: autogroup:internet in SSH destination", ErrAutogroupNotSupportedSSH) } if !slices.Contains(autogroupForSSHDst, *dst) { - return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *dst, autogroupForSSHDst) + return fmt.Errorf("%w: %q for SSH destinations, can be %v", ErrAutogroupNotSupportedSSH, *dst, autogroupForSSHDst) } return nil @@ -1653,7 +1702,7 @@ func validateAutogroupForSSHUser(user *AutoGroup) error { } if !slices.Contains(autogroupForSSHUser, *user) { - return fmt.Errorf("autogroup %q is not supported for SSH user, can be %v", *user, autogroupForSSHUser) + return fmt.Errorf("%w: %q for SSH user, can be %v", ErrAutogroupNotSupportedSSH, *user, autogroupForSSHUser) } return nil @@ -1678,7 +1727,7 @@ func (p *Policy) validate() error { case *Host: h := src if !p.Hosts.exist(*h) { - errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + errs = append(errs, fmt.Errorf("%w: %q - please define or remove the reference", ErrHostNotDefined, *h)) } case *AutoGroup: ag := src @@ -1710,7 +1759,7 @@ func (p *Policy) validate() error { case *Host: h := dst.Alias.(*Host) if !p.Hosts.exist(*h) { - errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + errs = append(errs, fmt.Errorf("%w: %q - please define or remove the reference", ErrHostNotDefined, *h)) } case *AutoGroup: ag := dst.Alias.(*AutoGroup) @@ -1915,10 +1964,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { case *Username, *Group, *Tag, *AutoGroup: (*a)[i] = alias.Alias default: - return fmt.Errorf( - "alias %T is not supported for SSH source", - alias.Alias, 
- ) + return fmt.Errorf("%w: %T", ErrAliasNotSupportedSSHSrc, alias.Alias) } } @@ -1946,10 +1992,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { Asterix: (*a)[i] = alias.Alias default: - return fmt.Errorf( - "alias %T is not supported for SSH destination", - alias.Alias, - ) + return fmt.Errorf("%w: %T", ErrAliasNotSupportedSSHDst, alias.Alias) } } @@ -1976,7 +2019,7 @@ func (a SSHDstAliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown SSH destination alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownSSHDstAliasType, v) } } @@ -2003,7 +2046,7 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown SSH source alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownSSHSrcAliasType, v) } } @@ -2077,11 +2120,11 @@ func unmarshalPolicy(b []byte) (*Policy, error) { ast.Standardize() if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { - if serr, ok := errors.AsType[*json.SemanticError](err); ok && serr.Err == json.ErrUnknownName { + if serr, ok := errors.AsType[*json.SemanticError](err); ok && errors.Is(serr.Err, json.ErrUnknownName) { ptr := serr.JSONPointer name := ptr.LastToken() - return nil, fmt.Errorf("unknown field %q", name) + return nil, fmt.Errorf("%w: %q", ErrUnknownField, name) } return nil, fmt.Errorf("parsing policy from bytes: %w", err) @@ -2109,7 +2152,7 @@ func validateProtocolPortCompatibility(protocol Protocol, destinations []AliasWi for _, portRange := range dst.Ports { // Check if it's not a wildcard port (0-65535) if portRange.First != 0 || portRange.Last != 65535 { - return fmt.Errorf("protocol %q does not support specific ports; only \"*\" is allowed", protocol) + return fmt.Errorf("%w: %q only allows \"*\"", ErrProtocolNoSpecificPorts, protocol) } } } diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 79d005a3..b5e5a210 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -366,7 +366,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: "alias v2.Asterix is not supported for SSH source", + wantErr: "alias type not supported for SSH source: v2.Asterix", }, { name: "invalid-username", @@ -380,7 +380,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Username has to contain @, got: "invalid"`, + wantErr: `username must contain @: got "invalid"`, }, { name: "invalid-group", @@ -393,7 +393,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group has to start with "group:", got: "grou:example"`, + wantErr: `group must start with 'group:': got "grou:example"`, }, { name: "group-in-group", @@ -408,7 +408,7 @@ func TestUnmarshalPolicy(t *testing.T) { } `, // wantErr: `Username has to contain @, got: "group:inner"`, - wantErr: `Nested groups are not allowed, found "group:inner" inside "group:example"`, + wantErr: `nested groups not allowed: found "group:inner" inside "group:example"`, }, { name: "invalid-addr", @@ -419,7 +419,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Hostname "derp" contains an invalid IP address: "10.0"`, + wantErr: `invalid IP address: hostname "derp" value "10.0"`, }, { name: "invalid-prefix", @@ -430,7 +430,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Hostname "derp" contains an invalid IP address: "10.0/42"`, + wantErr: `invalid IP address: hostname "derp" value "10.0/42"`, }, // TODO(kradalby): Figure out 
why this doesn't work. // { @@ -459,7 +459,7 @@ func TestUnmarshalPolicy(t *testing.T) { ], } `, - wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`, + wantErr: `invalid autogroup: got "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`, }, { name: "undefined-hostname-errors-2490", @@ -478,7 +478,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Host "user1" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `host not defined in policy: "user1" - please define or remove the reference`, }, { name: "defined-hostname-does-not-err-2490", @@ -571,7 +571,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in source, it can only be used in ACL destinations`, + wantErr: `autogroup:internet can only be used in ACL destinations`, }, { name: "autogroup:internet-in-ssh-src-not-allowed", @@ -590,7 +590,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in SSH source, it can only be used in ACL destinations`, + wantErr: `autogroup not supported for SSH: autogroup:internet in SSH source`, }, { name: "autogroup:internet-in-ssh-dst-not-allowed", @@ -609,7 +609,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`, + wantErr: `autogroup not supported for SSH: autogroup:internet in SSH destination`, }, { name: "ssh-basic", @@ -760,7 +760,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-dst", @@ -779,7 +779,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-ssh-src", @@ -798,7 +798,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-tagOwner", @@ -809,7 +809,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-autoapprover-route", @@ -822,7 +822,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-autoapprover-exitnode", @@ -833,7 +833,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "tag-must-be-defined-acl-src", @@ -852,7 +852,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or 
remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-dst", @@ -871,7 +871,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-ssh-src", @@ -890,7 +890,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-ssh-dst", @@ -912,7 +912,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-autoapprover-route", @@ -925,7 +925,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-autoapprover-exitnode", @@ -936,7 +936,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "missing-dst-port-is-err", @@ -955,7 +955,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `hostport must contain a colon (":")`, + wantErr: `hostport must contain a colon`, }, { name: "dst-port-zero-is-err", @@ -985,7 +985,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "rules"`, + wantErr: `unknown field in policy: "rules"`, }, { name: "disallow-unsupported-fields-nested", @@ -1008,7 +1008,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`, + wantErr: `group must start with 'group:': got "INVALID_GROUP_FIELD"`, }, { name: "invalid-group-datatype", @@ -1020,7 +1020,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Group "group:invalid" value must be an array of users, got string: "should fail"`, + wantErr: `group value must be an array: group "group:invalid" got string "should fail"`, }, { name: "invalid-group-name-and-datatype-fails-on-name-first", @@ -1032,7 +1032,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`, + wantErr: `group must start with 'group:': got "INVALID_GROUP_FIELD"`, }, { name: "disallow-unsupported-fields-hosts-level", @@ -1044,7 +1044,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Hostname "INVALID_HOST_FIELD" contains an invalid IP address: "should fail"`, + wantErr: `invalid IP address: hostname "INVALID_HOST_FIELD" value "should fail"`, }, { name: "disallow-unsupported-fields-tagowners-level", @@ -1056,7 +1056,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `tag has to start with "tag:", got: "INVALID_TAG_FIELD"`, + wantErr: `tag must start with 'tag:': got "INVALID_TAG_FIELD"`, }, { name: "disallow-unsupported-fields-acls-level", @@ -1073,7 +1073,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "INVALID_ACL_FIELD"`, + wantErr: `unknown field in policy: "INVALID_ACL_FIELD"`, }, { name: 
"disallow-unsupported-fields-ssh-level", @@ -1090,7 +1090,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "INVALID_SSH_FIELD"`, + wantErr: `unknown field in policy: "INVALID_SSH_FIELD"`, }, { name: "disallow-unsupported-fields-policy-level", @@ -1107,7 +1107,7 @@ func TestUnmarshalPolicy(t *testing.T) { "INVALID_POLICY_FIELD": "should fail at policy level" } `, - wantErr: `unknown field "INVALID_POLICY_FIELD"`, + wantErr: `unknown field in policy: "INVALID_POLICY_FIELD"`, }, { name: "disallow-unsupported-fields-autoapprovers-level", @@ -1122,7 +1122,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `unknown field "INVALID_AUTO_APPROVER_FIELD"`, + wantErr: `unknown field in policy: "INVALID_AUTO_APPROVER_FIELD"`, }, // headscale-admin uses # in some field names to add metadata, so we will ignore // those to ensure it doesnt break. @@ -1181,7 +1181,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "proto"`, + wantErr: `unknown field in policy: "proto"`, }, { name: "protocol-wildcard-not-allowed", @@ -1197,7 +1197,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `proto name "*" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)`, + wantErr: `invalid protocol: use protocol number 0-255 or protocol name`, }, { name: "protocol-case-insensitive-uppercase", @@ -1277,7 +1277,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `leading 0 not permitted in protocol number "0"`, + wantErr: `leading zero not permitted in protocol number: "0"`, }, { name: "protocol-empty-applies-to-tcp-udp-only", @@ -1324,7 +1324,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `protocol "icmp" does not support specific ports; only "*" is allowed`, + wantErr: `protocol does not support specific ports: "icmp" only allows "*"`, }, { name: "protocol-icmp-with-wildcard-port-allowed", @@ -1372,7 +1372,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `protocol "gre" does not support specific ports; only "*" is allowed`, + wantErr: `protocol does not support specific ports: "gre" only allows "*"`, }, { name: "protocol-tcp-with-specific-port-allowed", @@ -1836,7 +1836,7 @@ func TestResolvePolicy(t *testing.T) { IPv4: ap("100.100.101.103"), }, }, - wantErr: `user with token "invaliduser@" not found`, + wantErr: `user not found: token "invaliduser@"`, }, { name: "invalid-tag", @@ -1999,7 +1999,7 @@ func TestResolvePolicy(t *testing.T) { { name: "autogroup-invalid", toResolve: new(AutoGroup("autogroup:invalid")), - wantErr: "unknown autogroup", + wantErr: "invalid autogroup", }, } @@ -2670,7 +2670,7 @@ func TestNodeCanHaveTag(t *testing.T) { node: nodes[0], tag: "tag:test", want: false, - wantErr: "Username has to contain @", + wantErr: "username must contain @", }, { name: "node-cannot-have-tag", @@ -3248,7 +3248,8 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) { _, err := unmarshalPolicy([]byte(policyJSON)) require.Error(t, err) - assert.Contains(t, err.Error(), `invalid action "deny"`) + assert.Contains(t, err.Error(), `invalid action`) + assert.Contains(t, err.Error(), `deny`) } // Helper function to parse aliases for testing. diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go index 80de52bc..3fb0d38b 100644 --- a/hscontrol/policy/v2/utils.go +++ b/hscontrol/policy/v2/utils.go @@ -9,6 +9,18 @@ import ( "tailscale.com/tailcfg" ) +// Sentinel errors for port and destination parsing. 
+var ( + ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port") + ErrInputStartsWithColon = errors.New("input cannot start with a colon character") + ErrInputEndsWithColon = errors.New("input cannot end with a colon character") + ErrInvalidPortRange = errors.New("invalid port range format") + ErrPortRangeInverted = errors.New("invalid port range: first port is greater than last port") + ErrPortMustBePositive = errors.New("first port must be >0, or use '*' for wildcard") + ErrInvalidPortNumber = errors.New("invalid port number") + ErrPortOutOfRange = errors.New("port number out of range") +) + // splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid. func splitDestinationAndPort(input string) (string, string, error) { // Find the last occurrence of the colon character @@ -16,15 +28,15 @@ func splitDestinationAndPort(input string) (string, string, error) { // Check if the colon character is present and not at the beginning or end of the string if lastColonIndex == -1 { - return "", "", errors.New("input must contain a colon character separating destination and port") + return "", "", ErrInputMissingColon } if lastColonIndex == 0 { - return "", "", errors.New("input cannot start with a colon character") + return "", "", ErrInputStartsWithColon } if lastColonIndex == len(input)-1 { - return "", "", errors.New("input cannot end with a colon character") + return "", "", ErrInputEndsWithColon } // Split the string into destination and port based on the last colon @@ -52,7 +64,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { return e == "" }) if len(rangeParts) != 2 { - return nil, errors.New("invalid port range format") + return nil, ErrInvalidPortRange } first, err := parsePort(rangeParts[0]) @@ -66,7 +78,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { } if first > last { - return nil, errors.New("invalid port range: first port is greater than last port") + return nil, ErrPortRangeInverted } portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last}) @@ -77,7 +89,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { } if port < 1 { - return nil, errors.New("first port must be >0, or use '*' for wildcard") + return nil, ErrPortMustBePositive } portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port}) @@ -91,11 +103,11 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { func parsePort(portStr string) (uint16, error) { port, err := strconv.Atoi(portStr) if err != nil { - return 0, errors.New("invalid port number") + return 0, ErrInvalidPortNumber } if port < 0 || port > 65535 { - return 0, errors.New("port number out of range") + return 0, ErrPortOutOfRange } return uint16(port), nil diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go index a845e7a9..2ce95537 100644 --- a/hscontrol/policy/v2/utils_test.go +++ b/hscontrol/policy/v2/utils_test.go @@ -24,14 +24,14 @@ func TestParseDestinationAndPort(t *testing.T) { {"tag:api-server:443", "tag:api-server", "443", nil}, {"example-host-1:*", "example-host-1", "*", nil}, {"hostname:80-90", "hostname", "80-90", nil}, - {"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")}, - {":invalid", "", "", errors.New("input cannot start with a colon character")}, - {"invalid:", "", "", errors.New("input cannot end with a colon character")}, 
+ {"invalidinput", "", "", ErrInputMissingColon}, + {":invalid", "", "", ErrInputStartsWithColon}, + {"invalid:", "", "", ErrInputEndsWithColon}, } for _, testCase := range testCases { dst, port, err := splitDestinationAndPort(testCase.input) - if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) { + if dst != testCase.expectedDst || port != testCase.expectedPort || !errors.Is(err, testCase.expectedErr) { t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)", testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr) } @@ -42,27 +42,23 @@ func TestParsePort(t *testing.T) { tests := []struct { input string expected uint16 - err string + err error }{ - {"80", 80, ""}, - {"0", 0, ""}, - {"65535", 65535, ""}, - {"-1", 0, "port number out of range"}, - {"65536", 0, "port number out of range"}, - {"abc", 0, "invalid port number"}, - {"", 0, "invalid port number"}, + {"80", 80, nil}, + {"0", 0, nil}, + {"65535", 65535, nil}, + {"-1", 0, ErrPortOutOfRange}, + {"65536", 0, ErrPortOutOfRange}, + {"abc", 0, ErrInvalidPortNumber}, + {"", 0, ErrInvalidPortNumber}, } for _, test := range tests { result, err := parsePort(test.input) - if err != nil && err.Error() != test.err { + if !errors.Is(err, test.err) { t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err) } - if err == nil && test.err != "" { - t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err) - } - if result != test.expected { t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected) } @@ -73,32 +69,28 @@ func TestParsePortRange(t *testing.T) { tests := []struct { input string expected []tailcfg.PortRange - err string + err error }{ - {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""}, - {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""}, - {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""}, - {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""}, - {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""}, - {"80-", nil, "invalid port range format"}, - {"-90", nil, "invalid port range format"}, - {"80-90,", nil, "invalid port number"}, - {"80,90-", nil, "invalid port range format"}, - {"80-90,abc", nil, "invalid port number"}, - {"80-90,65536", nil, "port number out of range"}, - {"80-90,90-80", nil, "invalid port range: first port is greater than last port"}, + {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, nil}, + {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, nil}, + {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, nil}, + {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, nil}, + {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, nil}, + {"80-", nil, ErrInvalidPortRange}, + {"-90", nil, ErrInvalidPortRange}, + {"80-90,", nil, ErrInvalidPortNumber}, + {"80,90-", nil, ErrInvalidPortRange}, + {"80-90,abc", nil, ErrInvalidPortNumber}, + {"80-90,65536", nil, ErrPortOutOfRange}, + {"80-90,90-80", nil, ErrPortRangeInverted}, } for _, test := range tests { result, err := parsePortRange(test.input) - if err != nil && err.Error() != test.err { + if !errors.Is(err, test.err) { t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err) } - if err == nil && test.err != "" { - t.Errorf("parsePortRange(%q) expected error = %v, got 
nil", test.input, test.err) - } - if diff := cmp.Diff(result, test.expected); diff != "" { t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff) } From 7cbd3d8d91e9da228205d51789f8761236df83c5 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 15:47:41 +0000 Subject: [PATCH 14/30] all: define sentinel errors for err113 compliance Add sentinel errors across multiple packages to satisfy the err113 linter: - cmd/headscale/cli: Add ErrNameOrIDRequired, ErrMultipleUsersFoundUseID, errMockOidcUsersNotDefined - cmd/hi: Add ErrMemoryLimitExceeded, ErrStatsCollectionAlreadyStarted - hscontrol/db: Add ErrNameNotUnique, ErrTextUnmarshalFailed, ErrUnsupportedType, ErrTextMarshalerOnly, ErrTooManyWhereArgs, ErrMultipleUsers, fix pak.ID usage - hscontrol/dns: Add ErrPathIsDirectory - hscontrol/noise: Add ErrUnsupportedClientVersion - hscontrol/tailsql: Add ErrNoCertDomains - hscontrol/types: Add ErrInvalidRegIDLength, errNoPrefixConfigured, errInvalidAllocationStrategy, ErrCannotParseBool - hscontrol/util: Add username/hostname validation errors and traceroute parsing errors --- cmd/headscale/cli/mockoidc.go | 4 ++-- cmd/headscale/cli/users.go | 15 ++++++++----- cmd/hi/docker.go | 3 ++- cmd/hi/stats.go | 5 ++++- hscontrol/db/node.go | 5 +++-- hscontrol/db/text_serialiser.go | 14 +++++++++--- hscontrol/db/users.go | 6 ++++-- hscontrol/dns/extrarecords.go | 6 +++++- hscontrol/noise.go | 5 ++++- hscontrol/tailsql.go | 5 ++++- hscontrol/types/common.go | 7 ++++-- hscontrol/types/config.go | 15 +++++++------ hscontrol/types/node.go | 2 +- hscontrol/types/users.go | 6 +++++- hscontrol/util/dns.go | 38 +++++++++++++++++++-------------- hscontrol/util/util.go | 23 +++++++++++++++----- hscontrol/util/util_test.go | 5 ++--- 17 files changed, 111 insertions(+), 53 deletions(-) diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index c80c2a28..d1374ec5 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -2,7 +2,6 @@ package cli import ( "encoding/json" - "errors" "fmt" "net" "net/http" @@ -19,6 +18,7 @@ const ( errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined") errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined") errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined") + errMockOidcUsersNotDefined = Error("MOCKOIDC_USERS not defined") refreshTTL = 60 * time.Minute ) @@ -69,7 +69,7 @@ func mockOIDC() error { userStr := os.Getenv("MOCKOIDC_USERS") if userStr == "" { - return errors.New("MOCKOIDC_USERS not defined") + return errMockOidcUsersNotDefined } var users []mockoidc.MockUser diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 9f0954c6..084548a9 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -14,6 +14,12 @@ import ( "google.golang.org/grpc/status" ) +// Sentinel errors for CLI commands. 
+var ( + ErrNameOrIDRequired = errors.New("--name or --identifier flag is required") + ErrMultipleUsersFoundUseID = errors.New("unable to determine user, query returned multiple users, use ID") +) + func usernameAndIDFlag(cmd *cobra.Command) { cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)") cmd.Flags().StringP("name", "n", "", "Username") @@ -26,10 +32,9 @@ func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) { identifier, _ := cmd.Flags().GetInt64("identifier") if username == "" && identifier < 0 { - err := errors.New("--name or --identifier flag is required") ErrorOutput( - err, - "Cannot rename user: "+status.Convert(err).Message(), + ErrNameOrIDRequired, + "Cannot rename user: "+status.Convert(ErrNameOrIDRequired).Message(), "", ) } @@ -149,7 +154,7 @@ var destroyUserCmd = &cobra.Command{ } if len(users.GetUsers()) != 1 { - err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") + err := ErrMultipleUsersFoundUseID ErrorOutput( err, "Error: "+status.Convert(err).Message(), @@ -277,7 +282,7 @@ var renameUserCmd = &cobra.Command{ } if len(users.GetUsers()) != 1 { - err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") + err := ErrMultipleUsersFoundUseID ErrorOutput( err, "Error: "+status.Convert(err).Message(), diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 81f1d729..698e9d54 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -27,6 +27,7 @@ var ( ErrTestFailed = errors.New("test failed") ErrUnexpectedContainerWait = errors.New("unexpected end of container wait") ErrNoDockerContext = errors.New("no docker context found") + ErrMemoryLimitExceeded = errors.New("container exceeded memory limits") ) // runTestContainer executes integration tests in a Docker container. @@ -151,7 +152,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB) } - return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations)) + return fmt.Errorf("%w: %d container(s)", ErrMemoryLimitExceeded, len(violations)) } } diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index e80ee8d1..00a6cc4f 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -18,6 +18,9 @@ import ( "github.com/docker/docker/client" ) +// Sentinel errors for stats collection. +var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started") + // ContainerStats represents statistics for a single container. type ContainerStats struct { ContainerID string @@ -63,7 +66,7 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver defer sc.mutex.Unlock() if sc.collectionStarted { - return errors.New("stats collection already started") + return ErrStatsCollectionAlreadyStarted } sc.collectionStarted = true diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 7c818a75..04f9a621 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -36,6 +36,7 @@ var ( "node not found in registration cache", ) ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface") + ErrNameNotUnique = errors.New("name is not unique") ) // ListPeers returns peers of node, regardless of any Policy or if the node is expired. 
@@ -288,7 +289,7 @@ func RenameNode(tx *gorm.DB, } if count > 0 { - return errors.New("name is not unique") + return ErrNameNotUnique } if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { @@ -670,7 +671,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) Hostname: nodeName, UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: new(pak.ID), + AuthKeyID: &pak.ID, } err = hsdb.DB.Save(node).Error diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go index 102c0e9c..b1d294ea 100644 --- a/hscontrol/db/text_serialiser.go +++ b/hscontrol/db/text_serialiser.go @@ -3,12 +3,20 @@ package db import ( "context" "encoding" + "errors" "fmt" "reflect" "gorm.io/gorm/schema" ) +// Sentinel errors for text serialisation. +var ( + ErrTextUnmarshalFailed = errors.New("failed to unmarshal text value") + ErrUnsupportedType = errors.New("unsupported type") + ErrTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported") +) + // Got from https://github.com/xdg-go/strum/blob/main/types.go var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() @@ -49,7 +57,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect case string: bytes = []byte(v) default: - return fmt.Errorf("failed to unmarshal text value: %#v", dbValue) + return fmt.Errorf("%w: %#v", ErrTextUnmarshalFailed, dbValue) } if isTextUnmarshaler(fieldValue) { @@ -75,7 +83,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect return nil } else { - return fmt.Errorf("unsupported type: %T", fieldValue.Interface()) + return fmt.Errorf("%w: %T", ErrUnsupportedType, fieldValue.Interface()) } } @@ -99,6 +107,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec return string(b), nil default: - return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v) + return nil, fmt.Errorf("%w, got %T", ErrTextMarshalerOnly, v) } } diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 650dbd49..be073999 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -15,6 +15,8 @@ var ( ErrUserExists = errors.New("user already exists") ErrUserNotFound = errors.New("user not found") ErrUserStillHasNodes = errors.New("user not empty: node(s) found") + ErrTooManyWhereArgs = errors.New("expect 0 or 1 where User structs") + ErrMultipleUsers = errors.New("expected exactly one user") ) func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) { @@ -153,7 +155,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { // ListUsers gets all the existing users. 
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { if len(where) > 1 { - return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where)) + return nil, fmt.Errorf("%w, got %d", ErrTooManyWhereArgs, len(where)) } var user *types.User @@ -182,7 +184,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { } if len(users) != 1 { - return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) + return nil, fmt.Errorf("%w, found %d", ErrMultipleUsers, len(users)) } return &users[0], nil diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index 5d16c675..7cd88abe 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/json" + "errors" "fmt" "os" "sync" @@ -15,6 +16,9 @@ import ( "tailscale.com/util/set" ) +// Sentinel errors for extra records. +var ErrPathIsDirectory = errors.New("path is a directory, only file is supported") + type ExtraRecordsMan struct { mu sync.RWMutex records set.Set[tailcfg.DNSRecord] @@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) { } if fi.IsDir() { - return nil, fmt.Errorf("path is a directory, only file is supported: %s", path) + return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path) } records, hash, err := readExtraRecordsFromPath(path) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 869fe3f3..d8e83154 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -31,6 +31,9 @@ const ( earlyPayloadMagic = "\xff\xff\xffTS" ) +// Sentinel errors for noise server. +var ErrUnsupportedClientVersion = errors.New("unsupported client version") + type noiseServer struct { headscale *Headscale @@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler( } func unsupportedClientError(version tailcfg.CapabilityVersion) error { - return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version) + return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version) } func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error { diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go index efce647d..82cf9d58 100644 --- a/hscontrol/tailsql.go +++ b/hscontrol/tailsql.go @@ -13,6 +13,9 @@ import ( "tailscale.com/types/logger" ) +// Sentinel errors for tailsql service. +var ErrNoCertDomains = errors.New("no cert domains available for HTTPS") + func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath string) error { opts := tailsql.Options{ Hostname: "tailsql-headscale", @@ -71,7 +74,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s // When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443. 
certDomains := tsNode.CertDomains() if len(certDomains) == 0 { - return errors.New("no cert domains available for HTTPS") + return ErrNoCertDomains } base := "https://" + certDomains[0] go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index be3756a0..e0f4fcdd 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -20,7 +20,10 @@ const ( DatabaseSqlite = "sqlite3" ) -var ErrCannotParsePrefix = errors.New("cannot parse prefix") +var ( + ErrCannotParsePrefix = errors.New("cannot parse prefix") + ErrInvalidRegIDLength = errors.New("registration ID has invalid length") +) type StateUpdateType int @@ -175,7 +178,7 @@ func MustRegistrationID() RegistrationID { func RegistrationIDFromString(str string) (RegistrationID, error) { if len(str) != RegistrationIDLength { - return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength) + return "", fmt.Errorf("%w: expected %d characters", ErrInvalidRegIDLength, RegistrationIDLength) } return RegistrationID(str), nil diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index fffe166d..e947e104 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -33,10 +33,12 @@ const ( ) var ( - errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") - errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") - errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") - errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") + errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") + errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") + errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + errNoPrefixConfigured = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + errInvalidAllocationStrategy = errors.New("invalid prefixes.allocation strategy") ) type IPAllocationStrategy string @@ -929,7 +931,7 @@ func LoadServerConfig() (*Config, error) { } if prefix4 == nil && prefix6 == nil { - return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + return nil, errNoPrefixConfigured } allocStr := viper.GetString("prefixes.allocation") @@ -941,7 +943,8 @@ func LoadServerConfig() (*Config, error) { alloc = IPAllocationStrategyRandom default: return nil, fmt.Errorf( - "config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", + "%w: %q, allowed options: %s, %s", + errInvalidAllocationStrategy, allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom, diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 5140bc44..ea96284c 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -844,7 +844,7 @@ func (nv NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.Peer // GetFQDN returns the fully qualified domain name for the node. 
func (nv NodeView) GetFQDN(baseDomain string) (string, error) { if !nv.Valid() { - return "", errors.New("failed to create valid FQDN: node view is invalid") + return "", fmt.Errorf("failed to create valid FQDN: %w", ErrInvalidNodeView) } return nv.ж.GetFQDN(baseDomain) diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index dbcf4f44..c724c909 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -4,6 +4,7 @@ import ( "cmp" "database/sql" "encoding/json" + "errors" "fmt" "net/mail" "net/url" @@ -18,6 +19,9 @@ import ( "tailscale.com/tailcfg" ) +// Sentinel errors for user types. +var ErrCannotParseBool = errors.New("could not parse value as boolean") + type UserID uint64 type Users []User @@ -224,7 +228,7 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { *bit = FlexibleBoolean(pv) default: - return fmt.Errorf("could not parse %v as boolean", v) + return fmt.Errorf("%w: %v", ErrCannotParseBool, v) } return nil diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index dcd58528..bc48f592 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -26,6 +26,21 @@ var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") var ErrInvalidHostName = errors.New("invalid hostname") +// Sentinel errors for username validation. +var ( + ErrUsernameTooShort = errors.New("username must be at least 2 characters long") + ErrUsernameMustStartLetter = errors.New("username must start with a letter") + ErrUsernameTooManyAt = errors.New("username cannot contain more than one '@'") + ErrUsernameInvalidChar = errors.New("username contains invalid character") +) + +// Sentinel errors for hostname validation. +var ( + ErrHostnameTooShort = errors.New("hostname too short, must be at least 2 characters") + ErrHostnameHyphenEnds = errors.New("hostname cannot start or end with a hyphen") + ErrHostnameDotEnds = errors.New("hostname cannot start or end with a dot") +) + // ValidateUsername checks if a username is valid. // It must be at least 2 characters long, start with a letter, and contain // only letters, numbers, hyphens, dots, and underscores. @@ -34,12 +49,12 @@ var ErrInvalidHostName = errors.New("invalid hostname") func ValidateUsername(username string) error { // Ensure the username meets the minimum length requirement if len(username) < 2 { - return errors.New("username must be at least 2 characters long") + return ErrUsernameTooShort } // Ensure the username starts with a letter if !unicode.IsLetter(rune(username[0])) { - return errors.New("username must start with a letter") + return ErrUsernameMustStartLetter } atCount := 0 @@ -55,10 +70,10 @@ func ValidateUsername(username string) error { case char == '@': atCount++ if atCount > 1 { - return errors.New("username cannot contain more than one '@'") + return ErrUsernameTooManyAt } default: - return fmt.Errorf("username contains invalid character: '%c'", char) + return fmt.Errorf("%w: '%c'", ErrUsernameInvalidChar, char) } } @@ -70,10 +85,7 @@ func ValidateUsername(username string) error { // The hostname must already be lowercase and contain only valid characters. 
func ValidateHostname(name string) error { if len(name) < 2 { - return fmt.Errorf( - "hostname %q is too short, must be at least 2 characters", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameTooShort, name) } if len(name) > LabelHostnameLength { return fmt.Errorf( @@ -90,17 +102,11 @@ func ValidateHostname(name string) error { } if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") { - return fmt.Errorf( - "hostname %q cannot start or end with a hyphen", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameHyphenEnds, name) } if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") { - return fmt.Errorf( - "hostname %q cannot start or end with a dot", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameDotEnds, name) } if invalidDNSRegex.MatchString(name) { diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index 53189656..b4ca0c51 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -16,6 +16,19 @@ import ( "tailscale.com/util/cmpver" ) +// Sentinel errors for URL parsing. +var ( + ErrMultipleURLsFound = errors.New("multiple URLs found") + ErrNoURLFound = errors.New("no URL found") +) + +// Sentinel errors for traceroute parsing. +var ( + ErrTracerouteEmpty = errors.New("empty traceroute output") + ErrTracerouteHeader = errors.New("parsing traceroute header") + ErrTracerouteNotReached = errors.New("traceroute did not reach target") +) + func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { if cmpver.Compare(minimum, toCheck) <= 0 || toCheck == "unstable" || @@ -37,7 +50,7 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { line = strings.TrimSpace(line) if strings.HasPrefix(line, "http://") || strings.HasPrefix(line, "https://") { if urlStr != "" { - return nil, fmt.Errorf("multiple URLs found: %s and %s", urlStr, line) + return nil, fmt.Errorf("%w: %s and %s", ErrMultipleURLsFound, urlStr, line) } urlStr = line @@ -45,7 +58,7 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { } if urlStr == "" { - return nil, errors.New("no URL found") + return nil, ErrNoURLFound } loginURL, err := url.Parse(urlStr) @@ -91,7 +104,7 @@ type Traceroute struct { func ParseTraceroute(output string) (Traceroute, error) { lines := strings.Split(strings.TrimSpace(output), "\n") if len(lines) < 1 { - return Traceroute{}, errors.New("empty traceroute output") + return Traceroute{}, ErrTracerouteEmpty } // Parse the header line - handle both 'traceroute' and 'tracert' (Windows) @@ -99,7 +112,7 @@ func ParseTraceroute(output string) (Traceroute, error) { headerMatches := headerRegex.FindStringSubmatch(lines[0]) if len(headerMatches) < 2 { - return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0]) + return Traceroute{}, fmt.Errorf("%w: %s", ErrTracerouteHeader, lines[0]) } hostname := headerMatches[1] @@ -255,7 +268,7 @@ func ParseTraceroute(output string) (Traceroute, error) { // If we didn't reach the target, it's unsuccessful if !result.Success { - result.Err = errors.New("traceroute did not reach target") + result.Err = ErrTracerouteNotReached } return result, nil diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index ec72250e..22788bff 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1,7 +1,6 @@ package util import ( - "errors" "net/netip" "strings" "testing" @@ -322,7 +321,7 @@ func TestParseTraceroute(t *testing.T) { }, }, Success: false, - Err: errors.New("traceroute did not reach target"), + Err: ErrTracerouteNotReached, }, wantErr: false, }, 
@@ -490,7 +489,7 @@ over a maximum of 30 hops: }, }, Success: false, - Err: errors.New("traceroute did not reach target"), + Err: ErrTracerouteNotReached, }, wantErr: false, }, From 1a997b649405edb2cdaa21326503a0cc3c5b7018 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 15:50:57 +0000 Subject: [PATCH 15/30] hscontrol/mapper: define sentinel errors for err113 compliance Add sentinel errors to mapper package for batcher operations: - ErrInvalidNodeID, ErrMapperNil, ErrNodeConnectionNil in batcher.go - ErrInitialMapTimeout, ErrNodeNotFound, ErrBatcherShutdown, ErrConnectionTimeout in batcher_lockfree.go - Update builder.go to use ErrNodeNotFound from same package --- hscontrol/mapper/batcher.go | 13 ++++++++++--- hscontrol/mapper/batcher_lockfree.go | 19 +++++++++++++------ hscontrol/mapper/builder.go | 13 ++++++------- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 0a1e30d0..c1349f75 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -15,6 +15,13 @@ import ( "tailscale.com/tailcfg" ) +// Sentinel errors for batcher operations. +var ( + ErrInvalidNodeID = errors.New("invalid nodeID") + ErrMapperNil = errors.New("mapper is nil") + ErrNodeConnectionNil = errors.New("nodeConnection is nil") +) + var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "headscale", Name: "mapresponse_generated_total", @@ -80,11 +87,11 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t } if nodeID == 0 { - return nil, fmt.Errorf("invalid nodeID: %d", nodeID) + return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID) } if mapper == nil { - return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID) + return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID) } // Handle self-only responses @@ -135,7 +142,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change]. func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error { if nc == nil { - return errors.New("nodeConnection is nil") + return ErrNodeConnectionNil } nodeID := nc.nodeID() diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 3ff3406b..988f0b35 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -16,7 +16,14 @@ import ( "tailscale.com/tailcfg" ) -var errConnectionClosed = errors.New("connection channel already closed") +// Sentinel errors for lock-free batcher operations. +var ( + errConnectionClosed = errors.New("connection channel already closed") + ErrInitialMapTimeout = errors.New("failed to send initial map: timeout") + ErrNodeNotFound = errors.New("node not found") + ErrBatcherShutdown = errors.New("batcher shutting down") + ErrConnectionTimeout = errors.New("connection timeout sending to channel (likely stale connection)") +) // LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention. 
type LockFreeBatcher struct { @@ -88,12 +95,12 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse case c <- initialMap: // Success case <-time.After(5 * time.Second): - log.Error().Uint64("node.id", id.Uint64()).Err(errors.New("timeout")).Msg("Initial map send timeout") + log.Error().Uint64("node.id", id.Uint64()).Err(ErrInitialMapTimeout).Msg("Initial map send timeout") log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second). Msg("Initial map send timed out because channel was blocked or receiver not ready") nodeConn.removeConnectionByChannel(c) - return fmt.Errorf("failed to send initial map to node %d: timeout", id) + return fmt.Errorf("%w for node %d", ErrInitialMapTimeout, id) } // Update connection status @@ -234,7 +241,7 @@ func (b *LockFreeBatcher) worker(workerID int) { nc.updateSentPeers(result.mapResponse) } } else { - result.err = fmt.Errorf("node %d not found", w.nodeID) + result.err = fmt.Errorf("%w: %d", ErrNodeNotFound, w.nodeID) b.workErrors.Add(1) log.Error().Err(result.err). @@ -492,7 +499,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang case result := <-resultCh: return result.mapResponse, result.err case <-b.done: - return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id) + return nil, fmt.Errorf("%w: generating map response for node %d", ErrBatcherShutdown, id) } } @@ -707,7 +714,7 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { case <-time.After(50 * time.Millisecond): // Connection is likely stale - client isn't reading from channel // This catches the case where Docker containers are killed but channels remain open - return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id) + return fmt.Errorf("%w: connection %s", ErrConnectionTimeout, entry.id) } } diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index df0693e3..cd1d9a9d 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -2,7 +2,6 @@ package mapper import ( "cmp" - "errors" "net/netip" "slices" "time" @@ -71,7 +70,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { nv, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -133,7 +132,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -152,7 +151,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -165,7 +164,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -178,7 +177,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) func (b *MapResponseBuilder) WithPacketFilters() 
*MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -232,7 +231,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - return nil, errors.New("node not found") + return nil, ErrNodeNotFound } // Get unreduced matchers for peer relationship determination. From 7b4c49a91fe472df8c960c2b2a000487357892ec Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 15:52:19 +0000 Subject: [PATCH 16/30] hscontrol/state: fix err113 issues using sentinel errors Update state.go to use ErrNodeNotFound sentinel error for node disconnect error, and debug.go to use ErrUnsupportedPolicyMode from state.go. --- hscontrol/state/debug.go | 2 +- hscontrol/state/state.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 9cad1c04..abb34eb0 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -245,7 +245,7 @@ func (s *State) DebugPolicy() (string, error) { return string(pol), nil default: - return "", fmt.Errorf("unsupported policy mode: %s", s.cfg.Policy.Mode) + return "", fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, s.cfg.Policy.Mode) } } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 1004151e..bb929faa 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -500,7 +500,7 @@ func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { }) if !ok { - return nil, fmt.Errorf("node not found: %d", id) + return nil, fmt.Errorf("%w: %d", ErrNodeNotFound, id) } log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node disconnected") From f969745db498e0127c332b6e3e9d9c13ac398caa Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 16:06:19 +0000 Subject: [PATCH 17/30] integration: define sentinel errors for err113 compliance Add sentinel errors across integration test infrastructure and test files to comply with err113 linter requirements. This replaces inline dynamic errors with wrapped static sentinel errors. 
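For illustration, a minimal sketch of the pattern applied throughout this patch (the errUserNotFound sentinel and the username variable are taken from the integration/helpers.go change below):

Before:

    return nil, fmt.Errorf("user %s not found", username)

After:

    // Sentinel declared once at package level, wrapped where the error occurs.
    var errUserNotFound = errors.New("user not found")

    return nil, fmt.Errorf("%w: %s", errUserNotFound, username)

Callers and tests can then match with errors.Is(err, errUserNotFound) instead of comparing error strings.
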
Files updated: - integration/tsic/tsic.go: Add errNoNetworkSet, errLogoutFailed, errNoIPsReturned, errNoIPv4AddressFound, errBackendStateTimeout, errPeerWaitTimeout, errPeerNotOnline, errPeerNoHostname, errPeerNoDERP, errFileEmpty, errTailscaleVersionRequired - integration/scenario.go: Add errUserAlreadyInNetwork, errNoNetworkNamed, errNoIPAMConfig, errHTTPClientNil, errLoginURLNil, errUnexpectedStatusCode, errNetworkDoesNotExist - integration/helpers.go: Add errExpectedStringNotFound, errUserNotFound, errNoNewClientFound, errUnexpectedClientCount - integration/hsic/hsic.go: Add errDatabaseEmptySchema, errDatabaseFileEmpty, errNoRegularFileInTar - integration/derp_verify_endpoint_test.go: Add errUnexpectedRecvType - cmd/mapresponses/main.go: Add errDirectoryRequired - hscontrol/auth_test.go: Add errNodeNotFoundAfterSetup, errInvalidAuthURLFormat - hscontrol/state/node_store_test.go: Add errTestUpdateNodeFailed, errTestGetNodeFailed, errTestPutNodeFailed --- cmd/mapresponses/main.go | 4 ++- hscontrol/auth_test.go | 10 +++++-- hscontrol/state/node_store_test.go | 18 ++++++++---- integration/derp_verify_endpoint_test.go | 5 +++- integration/helpers.go | 16 ++++++++--- integration/hsic/hsic.go | 10 +++++-- integration/scenario.go | 33 +++++++++++++--------- integration/tsic/tsic.go | 35 ++++++++++++++++-------- 8 files changed, 90 insertions(+), 41 deletions(-) diff --git a/cmd/mapresponses/main.go b/cmd/mapresponses/main.go index af35bc48..1951ca4b 100644 --- a/cmd/mapresponses/main.go +++ b/cmd/mapresponses/main.go @@ -12,6 +12,8 @@ import ( "github.com/juanfont/headscale/integration/integrationutil" ) +var errDirectoryRequired = errors.New("directory is required") + type MapConfig struct { Directory string `flag:"directory,Directory to read map responses from"` } @@ -41,7 +43,7 @@ func main() { // runIntegrationTest executes the integration test workflow. func runOnline(env *command.Env) error { if mapConfig.Directory == "" { - return errors.New("directory is required") + return errDirectoryRequired } resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory) diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 8a012ff6..73048d9e 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -17,6 +17,12 @@ import ( "tailscale.com/types/key" ) +// Test sentinel errors. +var ( + errNodeNotFoundAfterSetup = errors.New("node not found after setup") + errInvalidAuthURLFormat = errors.New("invalid AuthURL format") +) + // Interactive step type constants. 
const ( stepTypeInitialRequest = "initial_request" @@ -579,7 +585,7 @@ func TestAuthenticationFlows(t *testing.T) { }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") if !found { - return "", errors.New("node not found after setup") + return "", errNodeNotFoundAfterSetup } // Expire the node @@ -2716,7 +2722,7 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err idx := strings.LastIndex(authURL, registerPrefix) if idx == -1 { - return "", fmt.Errorf("invalid AuthURL format: %s", authURL) + return "", fmt.Errorf("%w: %s", errInvalidAuthURLFormat, authURL) } idStr := authURL[idx+len(registerPrefix):] diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index b90956aa..9740d063 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -2,6 +2,7 @@ package state import ( "context" + "errors" "fmt" "net/netip" "runtime" @@ -15,6 +16,13 @@ import ( "tailscale.com/types/key" ) +// Test sentinel errors for concurrent operations. +var ( + errTestUpdateNodeFailed = errors.New("UpdateNode failed") + errTestGetNodeFailed = errors.New("GetNode failed") + errTestPutNodeFailed = errors.New("PutNode failed") +) + func TestSnapshotFromNodes(t *testing.T) { tests := []struct { name string @@ -1001,19 +1009,19 @@ func TestNodeStoreRaceConditions(t *testing.T) { n.Hostname = "race-updated" }) if !resultNode.Valid() { - errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j) + errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestUpdateNodeFailed, gid, j) } case 1: retrieved, found := store.GetNode(nodeID) if !found || !retrieved.Valid() { - errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j) + errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestGetNodeFailed, gid, j) } case 2: newNode := createConcurrentTestNode(nodeID, "race-put") resultNode := store.PutNode(newNode) if !resultNode.Valid() { - errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j) + errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestPutNodeFailed, gid, j) } } } @@ -1113,7 +1121,7 @@ func TestNodeStoreOperationTimeout(t *testing.T) { fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut)) if !resultNode.Valid() { - putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id) + putResults[idx-1] = fmt.Errorf("%w for node %d", errTestPutNodeFailed, id) } }(i, nodeID) } @@ -1140,7 +1148,7 @@ func TestNodeStoreOperationTimeout(t *testing.T) { fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate)) if !ok || !resultNode.Valid() { - updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id) + updateResults[idx-1] = fmt.Errorf("%w for node %d", errTestUpdateNodeFailed, id) } }(i, nodeID) } diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index d2aec30f..c92a25ee 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -1,6 +1,7 @@ package integration import ( + "errors" "fmt" "net" "strconv" @@ -19,6 +20,8 @@ import ( "tailscale.com/types/key" ) +var errUnexpectedRecvType = errors.New("client first Recv was unexpected type") + func TestDERPVerifyEndpoint(t *testing.T) { IntegrationSkip(t) @@ 
-113,7 +116,7 @@ func DERPVerify( if m, err := c.Recv(); err != nil { result = fmt.Errorf("client first Recv: %w", err) } else if v, ok := m.(derp.ServerInfoMessage); !ok { - result = fmt.Errorf("client first Recv was unexpected type %T", v) + result = fmt.Errorf("%w: %T", errUnexpectedRecvType, v) } if expectSuccess && result != nil { diff --git a/integration/helpers.go b/integration/helpers.go index 59b87cff..38abfdb2 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -28,6 +28,14 @@ import ( "tailscale.com/tailcfg" ) +// Sentinel errors for integration test helpers. +var ( + errExpectedStringNotFound = errors.New("expected string not found in output") + errUserNotFound = errors.New("user not found") + errNoNewClientFound = errors.New("no new client found") + errUnexpectedClientCount = errors.New("unexpected client count") +) + const ( // derpPingTimeout defines the timeout for individual DERP ping operations // Used in DERP connectivity tests to verify relay server communication. @@ -646,7 +654,7 @@ func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []stri } if !strings.Contains(stdout, contains) { - return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout) + return struct{}{}, fmt.Errorf("executing command, %w: %q not found in %q", errExpectedStringNotFound, contains, stdout) } return struct{}{}, nil @@ -811,7 +819,7 @@ func GetUserByName(headscale ControlServer, username string) (*v1.User, error) { } } - return nil, fmt.Errorf("user %s not found", username) + return nil, fmt.Errorf("%w: %s", errUserNotFound, username) } // FindNewClient finds a client that is in the new list but not in the original list. @@ -833,7 +841,7 @@ func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) } } - return nil, errors.New("no new client found") + return nil, errNoNewClientFound } // AddAndLoginClient adds a new tailscale client to a user and logs it in. 
@@ -873,7 +881,7 @@ func (s *Scenario) AddAndLoginClient( } if len(updatedClients) != len(originalClients)+1 { - return struct{}{}, fmt.Errorf("expected %d clients, got %d", len(originalClients)+1, len(updatedClients)) + return struct{}{}, fmt.Errorf("%w: expected %d clients, got %d", errUnexpectedClientCount, len(originalClients)+1, len(updatedClients)) } newClient, err = FindNewClient(originalClients, updatedClients) diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index c4eaeea8..f2fe5b30 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -53,6 +53,9 @@ var ( errInvalidHeadscaleImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_HEADSCALE_IMAGE format, expected repository:tag") errHeadscaleImageRequiredInCI = errors.New("HEADSCALE_INTEGRATION_HEADSCALE_IMAGE must be set in CI") errInvalidPostgresImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_POSTGRES_IMAGE format, expected repository:tag") + errDatabaseEmptySchema = errors.New("database file exists but has no schema") + errDatabaseFileEmpty = errors.New("database file is empty") + errNoRegularFileInTar = errors.New("no regular file found in database tar archive") ) type fileInContainer struct { @@ -861,7 +864,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { } if strings.TrimSpace(schemaCheck) == "" { - return errors.New("database file exists but has no schema (empty database)") + return errDatabaseEmptySchema } tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3") @@ -914,7 +917,8 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Check if we actually wrote something if written == 0 { return fmt.Errorf( - "database file is empty (size: %d, header size: %d)", + "%w (size: %d, header size: %d)", + errDatabaseFileEmpty, written, header.Size, ) @@ -924,7 +928,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { } } - return errors.New("no regular file found in database tar archive") + return errNoRegularFileInTar } // Execute runs a command inside the Headscale container and returns the diff --git a/integration/scenario.go b/integration/scenario.go index 0b388c0a..d1ebdd51 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -51,9 +51,16 @@ const ( var usePostgresForTest = envknob.Bool("HEADSCALE_INTEGRATION_POSTGRES") var ( - errNoHeadscaleAvailable = errors.New("no headscale available") - errNoUserAvailable = errors.New("no user available") - errNoClientFound = errors.New("client not found") + errNoHeadscaleAvailable = errors.New("no headscale available") + errNoUserAvailable = errors.New("no user available") + errNoClientFound = errors.New("client not found") + errUserAlreadyInNetwork = errors.New("users can only have nodes placed in one network") + errNoNetworkNamed = errors.New("no network named") + errNoIPAMConfig = errors.New("no IPAM config found in network") + errHTTPClientNil = errors.New("http client is nil") + errLoginURLNil = errors.New("login url is nil") + errUnexpectedStatusCode = errors.New("unexpected status code") + errNetworkDoesNotExist = errors.New("network does not exist") // AllVersions represents a list of Tailscale versions the suite // uses to test compatibility with the ControlServer. 
@@ -203,7 +210,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { for _, user := range users { if n2, ok := userToNetwork[user]; ok { - return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name) + return nil, fmt.Errorf("%w: %s into %s but already in %s", errUserAlreadyInNetwork, user, network.Network.Name, n2.Network.Name) } mak.Set(&userToNetwork, user, network) @@ -280,7 +287,7 @@ func (s *Scenario) Networks() []*dockertest.Network { func (s *Scenario) Network(name string) (*dockertest.Network, error) { net, ok := s.networks[s.prefixedNetworkName(name)] if !ok { - return nil, fmt.Errorf("no network named: %s", name) + return nil, fmt.Errorf("%w: %s", errNoNetworkNamed, name) } return net, nil @@ -289,11 +296,11 @@ func (s *Scenario) Network(name string) (*dockertest.Network, error) { func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { net, ok := s.networks[s.prefixedNetworkName(name)] if !ok { - return nil, fmt.Errorf("no network named: %s", name) + return nil, fmt.Errorf("%w: %s", errNoNetworkNamed, name) } if len(net.Network.IPAM.Config) == 0 { - return nil, fmt.Errorf("no IPAM config found in network: %s", name) + return nil, fmt.Errorf("%w: %s", errNoIPAMConfig, name) } pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet) @@ -307,7 +314,7 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { res, ok := s.extraServices[s.prefixedNetworkName(name)] if !ok { - return nil, fmt.Errorf("no network named: %s", name) + return nil, fmt.Errorf("%w: %s", errNoNetworkNamed, name) } return res, nil @@ -1070,11 +1077,11 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f error, ) { if hc == nil { - return "", nil, fmt.Errorf("%s http client is nil", hostname) + return "", nil, fmt.Errorf("%s %w", hostname, errHTTPClientNil) } if loginURL == nil { - return "", nil, fmt.Errorf("%s login url is nil", hostname) + return "", nil, fmt.Errorf("%s %w", hostname, errLoginURLNil) } log.Printf("%s logging in with url: %s", hostname, loginURL.String()) @@ -1121,13 +1128,13 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f if followRedirects && resp.StatusCode != http.StatusOK { log.Printf("body: %s", body) - return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + return body, redirectURL, fmt.Errorf("%s %w %d", hostname, errUnexpectedStatusCode, resp.StatusCode) } if resp.StatusCode >= http.StatusBadRequest { log.Printf("body: %s", body) - return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + return body, redirectURL, fmt.Errorf("%s %w %d", hostname, errUnexpectedStatusCode, resp.StatusCode) } if hc.Jar != nil { @@ -1506,7 +1513,7 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) { network, ok := s.networks[s.prefixedNetworkName(networkName)] if !ok { - return nil, fmt.Errorf("network does not exist: %s", networkName) + return nil, fmt.Errorf("%w: %s", errNetworkDoesNotExist, networkName) } webOpts := &dockertest.RunOptions{ diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index fb07896b..3136b6ae 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -59,6 +59,17 @@ var ( errTailscaleImageRequiredInCI = errors.New("HEADSCALE_INTEGRATION_TAILSCALE_IMAGE must be set in 
CI for HEAD version") errContainerNotInitialized = errors.New("container not initialized") errFQDNNotYetAvailable = errors.New("FQDN not yet available") + errNoNetworkSet = errors.New("no network set") + errLogoutFailed = errors.New("failed to logout") + errNoIPsReturned = errors.New("no IPs returned yet") + errNoIPv4AddressFound = errors.New("no IPv4 address found") + errBackendStateTimeout = errors.New("timeout waiting for backend state") + errPeerWaitTimeout = errors.New("timeout waiting for peers") + errPeerNotOnline = errors.New("peer is not online") + errPeerNoHostname = errors.New("peer does not have a hostname") + errPeerNoDERP = errors.New("peer does not have a DERP relay") + errFileEmpty = errors.New("file is empty") + errTailscaleVersionRequired = errors.New("tailscale version requirement not met") ) const ( @@ -338,7 +349,7 @@ func New( } if tsic.network == nil { - return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack())) + return nil, fmt.Errorf("%w, called from: \n%s", errNoNetworkSet, string(debug.Stack())) } tailscaleOptions := &dockertest.RunOptions{ @@ -720,7 +731,7 @@ func (t *TailscaleInContainer) Logout() error { stdout, stderr, _ = t.Execute([]string{"tailscale", "status"}) if !strings.Contains(stdout+stderr, "Logged out.") { - return fmt.Errorf("failed to logout, stdout: %s, stderr: %s", stdout, stderr) + return fmt.Errorf("%w: stdout: %s, stderr: %s", errLogoutFailed, stdout, stderr) } return t.waitForBackendState("NeedsLogin", integrationutil.PeerSyncTimeout()) @@ -832,7 +843,7 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) { } if len(ips) == 0 { - return nil, fmt.Errorf("no IPs returned yet for %s", t.hostname) + return nil, fmt.Errorf("%w for %s", errNoIPsReturned, t.hostname) } return ips, nil @@ -866,7 +877,7 @@ func (t *TailscaleInContainer) IPv4() (netip.Addr, error) { } } - return netip.Addr{}, fmt.Errorf("no IPv4 address found for %s", t.hostname) + return netip.Addr{}, fmt.Errorf("%w for %s", errNoIPv4AddressFound, t.hostname) } func (t *TailscaleInContainer) MustIPv4() netip.Addr { @@ -1211,7 +1222,7 @@ func (t *TailscaleInContainer) waitForBackendState(state string, timeout time.Du for { select { case <-ctx.Done(): - return fmt.Errorf("timeout waiting for backend state %s on %s after %v", state, t.hostname, timeout) + return fmt.Errorf("%w %s on %s after %v", errBackendStateTimeout, state, t.hostname, timeout) case <-ticker.C: status, err := t.Status() if err != nil { @@ -1253,10 +1264,10 @@ func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval select { case <-ctx.Done(): if len(lastErrs) > 0 { - return fmt.Errorf("timeout waiting for %d peers on %s after %v, errors: %w", expected, t.hostname, timeout, multierr.New(lastErrs...)) + return fmt.Errorf("%w for %d peers on %s after %v, errors: %w", errPeerWaitTimeout, expected, t.hostname, timeout, multierr.New(lastErrs...)) } - return fmt.Errorf("timeout waiting for %d peers on %s after %v", expected, t.hostname, timeout) + return fmt.Errorf("%w for %d peers on %s after %v", errPeerWaitTimeout, expected, t.hostname, timeout) case <-ticker.C: status, err := t.Status() if err != nil { @@ -1284,15 +1295,15 @@ func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval peer := status.Peer[peerKey] if !peer.Online { - peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %w: %s", 
t.hostname, errPeerNotOnline, peer.HostName)) } if peer.HostName == "" { - peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName)) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %w: %s", t.hostname, errPeerNoHostname, peer.HostName)) } if peer.Relay == "" { - peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName)) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %w: %s", t.hostname, errPeerNoDERP, peer.HostName)) } } @@ -1578,7 +1589,7 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) { } if out.Len() == 0 { - return nil, errors.New("file is empty") + return nil, errFileEmpty } return out.Bytes(), nil @@ -1617,7 +1628,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { // This is useful for verifying that policy changes have propagated to the client. func (t *TailscaleInContainer) PacketFilter() ([]filter.Match, error) { if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { - return nil, fmt.Errorf("tsic.PacketFilter() requires Tailscale 1.56+, current version: %s", t.version) + return nil, fmt.Errorf("%w: PacketFilter() requires Tailscale 1.56+, current version: %s", errTailscaleVersionRequired, t.version) } nm, err := t.Netmap() From 1462847878e03aa1ead5121a34b31bd6aac926cc Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 16:11:39 +0000 Subject: [PATCH 18/30] hscontrol: extract magic numbers to named constants Define named constants for magic numbers to improve code clarity: - Batcher timeouts and intervals (initialMapSendTimeout, etc.) - Work channel multiplier - OIDC provider init timeout - CSRF token length - SQLite WAL autocheckpoint default --- hscontrol/app.go | 5 +++- hscontrol/db/sqliteconfig/config.go | 5 +++- hscontrol/mapper/batcher.go | 7 ++++-- hscontrol/mapper/batcher_lockfree.go | 34 +++++++++++++++++++++------- hscontrol/oidc.go | 3 ++- hscontrol/types/config.go | 5 ++-- 6 files changed, 44 insertions(+), 15 deletions(-) diff --git a/hscontrol/app.go b/hscontrol/app.go index 8ce1066f..f7d9ba90 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -67,6 +67,9 @@ var ( ) ) +// oidcProviderInitTimeout is the timeout for initializing the OIDC provider. +const oidcProviderInitTimeout = 30 * time.Second + var ( debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK") debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT") @@ -161,7 +164,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), oidcProviderInitTimeout) defer cancel() oidcProvider, err := NewAuthProviderOIDC( diff --git a/hscontrol/db/sqliteconfig/config.go b/hscontrol/db/sqliteconfig/config.go index 23cb4b50..e6386718 100644 --- a/hscontrol/db/sqliteconfig/config.go +++ b/hscontrol/db/sqliteconfig/config.go @@ -22,6 +22,9 @@ var ( const ( // DefaultBusyTimeout is the default busy timeout in milliseconds. DefaultBusyTimeout = 10000 + // DefaultWALAutocheckpoint is the default WAL autocheckpoint value (number of pages). + // SQLite default is 1000 pages. + DefaultWALAutocheckpoint = 1000 ) // JournalMode represents SQLite journal_mode pragma values. 
@@ -310,7 +313,7 @@ func Default(path string) *Config { BusyTimeout: DefaultBusyTimeout, JournalMode: JournalModeWAL, AutoVacuum: AutoVacuumIncremental, - WALAutocheckpoint: 1000, + WALAutocheckpoint: DefaultWALAutocheckpoint, Synchronous: SynchronousNormal, ForeignKeys: true, TxLock: TxLockImmediate, diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index c1349f75..c5bbda48 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -22,6 +22,10 @@ var ( ErrNodeConnectionNil = errors.New("nodeConnection is nil") ) +// workChannelMultiplier is the multiplier for work channel capacity based on worker count. +// The size is arbitrary chosen, the sizing should be revisited. +const workChannelMultiplier = 200 + var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "headscale", Name: "mapresponse_generated_total", @@ -49,8 +53,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB workers: workers, tick: time.NewTicker(batchTime), - // The size of this channel is arbitrary chosen, the sizing should be revisited. - workCh: make(chan work, workers*200), + workCh: make(chan work, workers*workChannelMultiplier), nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), connected: xsync.NewMap[types.NodeID, *time.Time](), pendingChanges: xsync.NewMap[types.NodeID, []change.Change](), diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 988f0b35..2580b7f4 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -25,6 +25,25 @@ var ( ErrConnectionTimeout = errors.New("connection timeout sending to channel (likely stale connection)") ) +// Batcher configuration constants. +const ( + // initialMapSendTimeout is the timeout for sending the initial map response to a new connection. + initialMapSendTimeout = 5 * time.Second + + // offlineNodeCleanupThreshold is how long a node must be offline before it's cleaned up. + offlineNodeCleanupThreshold = 15 * time.Minute + + // offlineNodeCleanupInterval is the interval between cleanup runs. + offlineNodeCleanupInterval = 5 * time.Minute + + // connectionSendTimeout is the timeout for detecting stale connections. + // Kept short to quickly detect Docker containers that are forcefully terminated. + connectionSendTimeout = 50 * time.Millisecond + + // connectionIDBytes is the number of random bytes used for connection IDs. + connectionIDBytes = 8 +) + // LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention. type LockFreeBatcher struct { tick *time.Ticker @@ -94,9 +113,9 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse select { case c <- initialMap: // Success - case <-time.After(5 * time.Second): + case <-time.After(initialMapSendTimeout): log.Error().Uint64("node.id", id.Uint64()).Err(ErrInitialMapTimeout).Msg("Initial map send timeout") - log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second). + log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", initialMapSendTimeout). 
Msg("Initial map send timed out because channel was blocked or receiver not ready") nodeConn.removeConnectionByChannel(c) @@ -187,7 +206,7 @@ func (b *LockFreeBatcher) doWork() { } // Create a cleanup ticker for removing truly disconnected nodes - cleanupTicker := time.NewTicker(5 * time.Minute) + cleanupTicker := time.NewTicker(offlineNodeCleanupInterval) defer cleanupTicker.Stop() for { @@ -395,14 +414,13 @@ func (b *LockFreeBatcher) processBatchedChanges() { // cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. // TODO(kradalby): reevaluate if we want to keep this. func (b *LockFreeBatcher) cleanupOfflineNodes() { - cleanupThreshold := 15 * time.Minute now := time.Now() var nodesToCleanup []types.NodeID // Find nodes that have been offline for too long b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool { - if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold { + if disconnectTime != nil && now.Sub(*disconnectTime) > offlineNodeCleanupThreshold { // Double-check the node doesn't have active connections if nodeConn, exists := b.nodes.Load(nodeID); exists { if !nodeConn.hasActiveConnections() { @@ -417,7 +435,7 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() { // Clean up the identified nodes for _, nodeID := range nodesToCleanup { log.Info().Uint64("node.id", nodeID.Uint64()). - Dur("offline_duration", cleanupThreshold). + Dur("offline_duration", offlineNodeCleanupThreshold). Msg("Cleaning up node that has been offline for too long") b.nodes.Delete(nodeID) @@ -532,7 +550,7 @@ type multiChannelNodeConn struct { // generateConnectionID generates a unique connection identifier. func generateConnectionID() string { - bytes := make([]byte, 8) + bytes := make([]byte, connectionIDBytes) _, _ = rand.Read(bytes) return hex.EncodeToString(bytes) @@ -711,7 +729,7 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { // Update last used timestamp on successful send entry.lastUsed.Store(time.Now().Unix()) return nil - case <-time.After(50 * time.Millisecond): + case <-time.After(connectionSendTimeout): // Connection is likely stale - client isn't reading from channel // This catches the case where Docker containers are killed but channels remain open return fmt.Errorf("%w: connection %s", ErrConnectionTimeout, entry.id) diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 836e8763..6ab70f78 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -28,6 +28,7 @@ const ( defaultOAuthOptionsCount = 3 registerCacheExpiration = time.Minute * 15 registerCacheCleanup = time.Minute * 20 + csrfTokenLength = 64 ) var ( @@ -614,7 +615,7 @@ func getCookieName(baseName, value string) string { } func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) { - val, err := util.GenerateRandomStringURLSafe(64) + val, err := util.GenerateRandomStringURLSafe(csrfTokenLength) if err != nil { return val, err } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index e947e104..5a32147d 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -29,7 +29,8 @@ const ( PKCEMethodPlain string = "plain" PKCEMethodS256 string = "S256" - defaultNodeStoreBatchSize = 100 + defaultNodeStoreBatchSize = 100 + defaultWALAutocheckpoint = 1000 // SQLite default ) var ( @@ -380,7 +381,7 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600) viper.SetDefault("database.sqlite.write_ahead_log", true) - 
viper.SetDefault("database.sqlite.wal_autocheckpoint", 1000) // SQLite default + viper.SetDefault("database.sqlite.wal_autocheckpoint", defaultWALAutocheckpoint) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.only_start_if_oidc_is_available", true) From b1463dff1e1b9a307ecceabfc423f7353e1210b9 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 16:14:05 +0000 Subject: [PATCH 19/30] hscontrol: extract more magic numbers to named constants - PolicyVersion for policy manager version - portRangeParts for port range parsing - invalidStringRandomLength for random string generation - hostIPRegexGroups for regex match validation - minTracerouteHeaderMatch for traceroute parsing - nodeKeyPrefixLen for node key truncation - minNameLength for username/hostname validation --- hscontrol/policy/v2/policy.go | 5 ++++- hscontrol/policy/v2/utils.go | 5 ++++- hscontrol/util/dns.go | 7 +++++-- hscontrol/util/string.go | 5 ++++- hscontrol/util/util.go | 24 ++++++++++++++++-------- 5 files changed, 33 insertions(+), 13 deletions(-) diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 8c07e6cc..bc968c3c 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -21,6 +21,9 @@ import ( "tailscale.com/util/deephash" ) +// PolicyVersion is the version number of this policy implementation. +const PolicyVersion = 2 + // ErrInvalidTagOwner is returned when a tag owner is not an Alias type. var ErrInvalidTagOwner = errors.New("tag owner is not an Alias") @@ -739,7 +742,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr } func (pm *PolicyManager) Version() int { - return 2 + return PolicyVersion } func (pm *PolicyManager) DebugString() string { diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go index 3fb0d38b..68c5984b 100644 --- a/hscontrol/policy/v2/utils.go +++ b/hscontrol/policy/v2/utils.go @@ -9,6 +9,9 @@ import ( "tailscale.com/tailcfg" ) +// portRangeParts is the expected number of parts in a port range (start-end). +const portRangeParts = 2 + // Sentinel errors for port and destination parsing. var ( ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port") @@ -63,7 +66,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool { return e == "" }) - if len(rangeParts) != 2 { + if len(rangeParts) != portRangeParts { return nil, ErrInvalidPortRange } diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index bc48f592..c816efc6 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -20,6 +20,9 @@ const ( // value related to RFC 1123 and 952. LabelHostnameLength = 63 + + // minNameLength is the minimum length for usernames and hostnames. + minNameLength = 2 ) var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") @@ -48,7 +51,7 @@ var ( // It cannot contain invalid characters. func ValidateUsername(username string) error { // Ensure the username meets the minimum length requirement - if len(username) < 2 { + if len(username) < minNameLength { return ErrUsernameTooShort } @@ -84,7 +87,7 @@ func ValidateUsername(username string) error { // This function does NOT modify the input - it only validates. // The hostname must already be lowercase and contain only valid characters. 
func ValidateHostname(name string) error { - if len(name) < 2 { + if len(name) < minNameLength { return fmt.Errorf("%w: %q", ErrHostnameTooShort, name) } if len(name) > LabelHostnameLength { diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go index 0a37ec87..60f99420 100644 --- a/hscontrol/util/string.go +++ b/hscontrol/util/string.go @@ -9,6 +9,9 @@ import ( "tailscale.com/tailcfg" ) +// invalidStringRandomLength is the length of random bytes for invalid string generation. +const invalidStringRandomLength = 8 + // GenerateRandomBytes returns securely generated random bytes. // It will return an error if the system's secure random // number generator fails to function correctly, in which @@ -68,7 +71,7 @@ func MustGenerateRandomStringDNSSafe(size int) string { } func InvalidString() string { - hash, _ := GenerateRandomStringDNSSafe(8) + hash, _ := GenerateRandomStringDNSSafe(invalidStringRandomLength) return "invalid-" + hash } diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index b4ca0c51..77c83ece 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -24,9 +24,17 @@ var ( // Sentinel errors for traceroute parsing. var ( - ErrTracerouteEmpty = errors.New("empty traceroute output") - ErrTracerouteHeader = errors.New("parsing traceroute header") - ErrTracerouteNotReached = errors.New("traceroute did not reach target") + ErrTracerouteEmpty = errors.New("empty traceroute output") + ErrTracerouteHeader = errors.New("parsing traceroute header") + ErrTracerouteNotReached = errors.New("traceroute did not reach target") +) + +// Regex match group constants for traceroute parsing. +// The regexes capture hostname (group 1) and IP (group 2), plus the full match (group 0). +const ( + hostIPRegexGroups = 3 + nodeKeyPrefixLen = 8 + minTracerouteHeaderMatch = 2 // full match + hostname ) func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { @@ -111,7 +119,7 @@ func ParseTraceroute(output string) (Traceroute, error) { headerRegex := regexp.MustCompile(`(?i)(?:traceroute|tracing route) to ([^ ]+) (?:\[([^\]]+)\]|\(([^)]+)\))`) headerMatches := headerRegex.FindStringSubmatch(lines[0]) - if len(headerMatches) < 2 { + if len(headerMatches) < minTracerouteHeaderMatch { return Traceroute{}, fmt.Errorf("%w: %s", ErrTracerouteHeader, lines[0]) } @@ -210,12 +218,12 @@ func ParseTraceroute(output string) (Traceroute, error) { hopHostname = "*" // Skip any remaining asterisks remainder = strings.TrimLeft(remainder, "* ") - } else if hostMatch := hostIPRegex.FindStringSubmatch(remainder); len(hostMatch) >= 3 { + } else if hostMatch := hostIPRegex.FindStringSubmatch(remainder); len(hostMatch) >= hostIPRegexGroups { // Format: hostname (IP) hopHostname = hostMatch[1] hopIP, _ = netip.ParseAddr(hostMatch[2]) remainder = strings.TrimSpace(remainder[len(hostMatch[0]):]) - } else if hostMatch := hostIPBracketRegex.FindStringSubmatch(remainder); len(hostMatch) >= 3 { + } else if hostMatch := hostIPBracketRegex.FindStringSubmatch(remainder); len(hostMatch) >= hostIPRegexGroups { // Format: hostname [IP] (Windows) hopHostname = hostMatch[1] hopIP, _ = netip.ParseAddr(hostMatch[2]) @@ -307,8 +315,8 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri } keyPrefix := key - if len(key) > 8 { - keyPrefix = key[:8] + if len(key) > nodeKeyPrefixLen { + keyPrefix = key[:nodeKeyPrefixLen] } return "node-" + keyPrefix From 12b3da01810c820777674facdaac1eb5e6dd4eaa Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 16:18:07 +0000 
Subject: [PATCH 20/30] all: extract remaining magic numbers to named constants Define named constants for various timeout and configuration values: - Connection validation and retry timeouts in helpers - Peer sync timeouts in integrationutil - Run ID hash length and parts in dockertestutil - Container memory limits and directory permissions - HTML parsing split count in scenario - Container restart and backoff timeouts in tsic - Stats calculation constants in cmd/hi --- cmd/hi/docker.go | 13 +++++++++--- cmd/hi/stats.go | 10 +++++++-- integration/dockertestutil/config.go | 11 +++++++--- integration/dockertestutil/network.go | 8 ++++++- integration/helpers.go | 30 ++++++++++++++++++--------- integration/hsic/hsic.go | 5 +++-- integration/integrationutil/util.go | 16 +++++++++++--- integration/scenario.go | 8 ++++--- integration/tsic/tsic.go | 19 +++++++++++------ 9 files changed, 87 insertions(+), 33 deletions(-) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 698e9d54..bcb6dbc5 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -30,6 +30,13 @@ var ( ErrMemoryLimitExceeded = errors.New("container exceeded memory limits") ) +// Docker container constants. +const ( + containerFinalStateWait = 10 * time.Second + containerStateCheckInterval = 500 * time.Millisecond + dirPermissions = 0o755 +) + // runTestContainer executes integration tests in a Docker container. func runTestContainer(ctx context.Context, config *RunConfig) error { cli, err := createDockerClient() @@ -351,8 +358,8 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC testContainers := getCurrentTestContainers(containers, testContainerID, verbose) // Wait for all test containers to reach a final state - maxWaitTime := 10 * time.Second - checkInterval := 500 * time.Millisecond + maxWaitTime := containerFinalStateWait + checkInterval := containerStateCheckInterval timeout := time.After(maxWaitTime) ticker := time.NewTicker(checkInterval) @@ -720,7 +727,7 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st // extractContainerArtifacts saves logs and tar files from a container. func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error { // Ensure the logs directory exists - if err := os.MkdirAll(logsDir, 0o755); err != nil { + if err := os.MkdirAll(logsDir, dirPermissions); err != nil { return fmt.Errorf("failed to create logs directory: %w", err) } diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index 00a6cc4f..173b5332 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -21,6 +21,12 @@ import ( // Sentinel errors for stats collection. var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started") +// Stats calculation constants. +const ( + bytesPerKB = 1024 + percentageMultiplier = 100.0 +) + // ContainerStats represents statistics for a single container. 
type ContainerStats struct { ContainerID string @@ -269,7 +275,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe } // Calculate memory usage in MB - memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024) + memoryMB := float64(stats.MemoryStats.Usage) / (bytesPerKB * bytesPerKB) // Store the sample (skip first sample since CPU calculation needs previous stats) if prevStats != nil { @@ -314,7 +320,7 @@ func calculateCPUPercent(prevStats, stats *container.StatsResponse) float64 { numCPUs = 1.0 } - return (cpuDelta / systemDelta) * numCPUs * 100.0 + return (cpuDelta / systemDelta) * numCPUs * percentageMultiplier } return 0.0 diff --git a/integration/dockertestutil/config.go b/integration/dockertestutil/config.go index 88b2712c..75bc872c 100644 --- a/integration/dockertestutil/config.go +++ b/integration/dockertestutil/config.go @@ -14,6 +14,11 @@ const ( // TimestampFormatRunID is used for generating unique run identifiers // Format: "20060102-150405" provides compact date-time for file/directory names. TimestampFormatRunID = "20060102-150405" + + // runIDHashLength is the length of the random hash in run IDs. + runIDHashLength = 6 + // runIDParts is the number of parts in a run ID (YYYYMMDD-HHMMSS-HASH). + runIDParts = 3 ) // GetIntegrationRunID returns the run ID for the current integration test session. @@ -46,7 +51,7 @@ func GenerateRunID() string { timestamp := now.Format(TimestampFormatRunID) // Add a short random hash to ensure uniqueness - randomHash := util.MustGenerateRandomStringDNSSafe(6) + randomHash := util.MustGenerateRandomStringDNSSafe(runIDHashLength) return fmt.Sprintf("%s-%s", timestamp, randomHash) } @@ -55,9 +60,9 @@ func GenerateRunID() string { // Expects format: "prefix-YYYYMMDD-HHMMSS-HASH". func ExtractRunIDFromContainerName(containerName string) string { parts := strings.Split(containerName, "-") - if len(parts) >= 3 { + if len(parts) >= runIDParts { // Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH) - return strings.Join(parts[len(parts)-3:], "-") + return strings.Join(parts[len(parts)-runIDParts:], "-") } panic("unexpected container name format: " + containerName) diff --git a/integration/dockertestutil/network.go b/integration/dockertestutil/network.go index d07841f1..95b69b88 100644 --- a/integration/dockertestutil/network.go +++ b/integration/dockertestutil/network.go @@ -13,6 +13,12 @@ import ( var ErrContainerNotFound = errors.New("container not found") +// Docker memory constants. +const ( + bytesPerKB = 1024 + containerMemoryGB = 2 +) + func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (*dockertest.Network, error) { networks, err := pool.NetworksByName(name) if err != nil { @@ -172,6 +178,6 @@ func DockerAllowNetworkAdministration(config *docker.HostConfig) { // DockerMemoryLimit sets memory limit and disables OOM kill for containers. func DockerMemoryLimit(config *docker.HostConfig) { - config.Memory = 2 * 1024 * 1024 * 1024 // 2GB in bytes + config.Memory = containerMemoryGB * bytesPerKB * bytesPerKB * bytesPerKB // 2GB in bytes config.OOMKillDisable = true } diff --git a/integration/helpers.go b/integration/helpers.go index 38abfdb2..4f9c1018 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -53,6 +53,16 @@ const ( // TimestampFormatRunID is used for generating unique run identifiers // Format: "20060102-150405" provides compact date-time for file/directory names. TimestampFormatRunID = "20060102-150405" + + // Connection validation timeouts. 
+ connectionValidationTimeout = 120 * time.Second + onlineCheckRetryInterval = 2 * time.Second + batcherValidationTimeout = 15 * time.Second + nodestoreValidationTimeout = 20 * time.Second + mapResponseTimeout = 60 * time.Second + netInfoRetryInterval = 5 * time.Second + backoffMaxElapsedTime = 10 * time.Second + backoffRetryInterval = 500 * time.Millisecond ) // NodeSystemStatus represents the status of a node across different systems. @@ -134,7 +144,7 @@ func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.Nod func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { t.Helper() - requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", connectionValidationTimeout) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login") } @@ -144,7 +154,7 @@ func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNo func validateLogoutComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { t.Helper() - requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", connectionValidationTimeout) } // validateReloginComplete performs comprehensive validation after client relogin. @@ -153,7 +163,7 @@ func validateLogoutComplete(t *testing.T, headscale ControlServer, expectedNodes func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { t.Helper() - requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", 120*time.Second) + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", connectionValidationTimeout) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin") } @@ -369,7 +379,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer } assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr)) - }, timeout, 2*time.Second, message) + }, timeout, onlineCheckRetryInterval, message) } // requireAllClientsOfflineStaged validates offline state with staged timeouts for different components. 
@@ -398,7 +408,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec } assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher") - }, 15*time.Second, 1*time.Second, "batcher disconnection validation") + }, batcherValidationTimeout, 1*time.Second, "batcher disconnection validation") // Stage 2: Verify nodestore offline status (up to 15 seconds due to disconnect detection delay) t.Logf("Stage 2: Verifying nodestore offline status for %d nodes (allowing for 10s disconnect detection delay)", len(expectedNodes)) @@ -424,7 +434,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec } assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore") - }, 20*time.Second, 1*time.Second, "nodestore offline validation") + }, nodestoreValidationTimeout, 1*time.Second, "nodestore offline validation") // Stage 3: Verify map response propagation (longest delay due to peer update timing) t.Logf("Stage 3: Verifying map response propagation for %d nodes (allowing for peer map update delays)", len(expectedNodes)) @@ -466,7 +476,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec } assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses") - }, 60*time.Second, 2*time.Second, "map response propagation validation") + }, mapResponseTimeout, onlineCheckRetryInterval, "map response propagation validation") t.Logf("All stages completed: nodes are fully offline across all systems") } @@ -528,7 +538,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe t.Logf("Node %d (%s) has valid NetInfo with DERP server %d at %s", nodeID, node.Hostname, preferredDERP, time.Now().Format(TimestampFormat)) } - }, timeout, 5*time.Second, message) + }, timeout, netInfoRetryInterval, message) endTime := time.Now() duration := endTime.Sub(startTime) @@ -658,7 +668,7 @@ func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []stri } return struct{}{}, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(backoffMaxElapsedTime)) assert.NoError(t, err) } @@ -890,7 +900,7 @@ func (s *Scenario) AddAndLoginClient( } return struct{}{}, nil - }, backoff.WithBackOff(backoff.NewConstantBackOff(500*time.Millisecond)), backoff.WithMaxElapsedTime(10*time.Second)) + }, backoff.WithBackOff(backoff.NewConstantBackOff(backoffRetryInterval)), backoff.WithMaxElapsedTime(backoffMaxElapsedTime)) if err != nil { return nil, fmt.Errorf("timeout waiting for new client: %w", err) } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index f2fe5b30..1dc5c2d3 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -46,6 +46,7 @@ const ( tlsKeyPath = "/etc/headscale/tls.key" headscaleDefaultPort = 8080 IntegrationTestDockerFileName = "Dockerfile.integration" + dirPermissions = 0o755 ) var ( @@ -720,7 +721,7 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error { // extractTarToDirectory extracts a tar archive to a directory. 
func extractTarToDirectory(tarData []byte, targetDir string) error { - if err := os.MkdirAll(targetDir, 0o755); err != nil { + if err := os.MkdirAll(targetDir, dirPermissions); err != nil { return fmt.Errorf("failed to create directory %s: %w", targetDir, err) } @@ -784,7 +785,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { } case tar.TypeReg: // Ensure parent directories exist - if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + if err := os.MkdirAll(filepath.Dir(targetPath), dirPermissions); err != nil { return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err) } diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index 5604af32..3e257a8e 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -22,19 +22,29 @@ import ( "tailscale.com/tailcfg" ) +// Integration test timing constants. +const ( + // peerSyncTimeoutCI is the peer sync timeout for CI environments. + peerSyncTimeoutCI = 120 * time.Second + // peerSyncTimeoutDev is the peer sync timeout for development environments. + peerSyncTimeoutDev = 60 * time.Second + // peerSyncRetryIntervalMs is the retry interval for peer sync checks. + peerSyncRetryIntervalMs = 100 +) + // PeerSyncTimeout returns the timeout for peer synchronization based on environment: // 60s for dev, 120s for CI. func PeerSyncTimeout() time.Duration { if util.IsCI() { - return 120 * time.Second + return peerSyncTimeoutCI } - return 60 * time.Second + return peerSyncTimeoutDev } // PeerSyncRetryInterval returns the retry interval for peer synchronization checks. func PeerSyncRetryInterval() time.Duration { - return 100 * time.Millisecond + return peerSyncRetryIntervalMs * time.Millisecond } func WriteFileToContainer( diff --git a/integration/scenario.go b/integration/scenario.go index d1ebdd51..73a45edd 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -46,6 +46,8 @@ import ( const ( scenarioHashLength = 6 + // expectedHTMLSplitParts is the expected number of parts when splitting HTML for key extraction. + expectedHTMLSplitParts = 2 ) var usePostgresForTest = envknob.Bool("HEADSCALE_INTEGRATION_POSTGRES") @@ -1153,17 +1155,17 @@ var errParseAuthPage = errors.New("failed to parse auth page") func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { // see api.go HTML template codeSep := strings.Split(string(body), "") - if len(codeSep) != 2 { + if len(codeSep) != expectedHTMLSplitParts { return errParseAuthPage } keySep := strings.Split(codeSep[0], "key ") - if len(keySep) != 2 { + if len(keySep) != expectedHTMLSplitParts { return errParseAuthPage } key := keySep[1] - key = strings.SplitN(key, " ", 2)[0] + key = strings.SplitN(key, " ", expectedHTMLSplitParts)[0] log.Printf("registering node %s", key) if headscale, err := s.Headscale(); err == nil { diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index 3136b6ae..b943ee13 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -44,6 +44,13 @@ const ( dockerContextPath = "../." caCertRoot = "/usr/local/share/ca-certificates" dockerExecuteTimeout = 60 * time.Second + + // Container restart and backoff timeouts. 
+ containerRestartTimeout = 30 // seconds, used by Docker API + tailscaleVersionTimeout = 5 * time.Second + containerRestartBackoff = 30 * time.Second + backoffMaxElapsedTime = 10 * time.Second + curlFailFastMaxTime = 2 * time.Second ) var ( @@ -747,7 +754,7 @@ func (t *TailscaleInContainer) Restart() error { } // Use Docker API to restart the container - err := t.pool.Client.RestartContainer(t.container.Container.ID, 30) + err := t.pool.Client.RestartContainer(t.container.Container.ID, containerRestartTimeout) if err != nil { return fmt.Errorf("failed to restart container %s: %w", t.hostname, err) } @@ -756,13 +763,13 @@ func (t *TailscaleInContainer) Restart() error { // We use exponential backoff to poll until we can successfully execute a command _, err = backoff.Retry(context.Background(), func() (struct{}, error) { // Try to execute a simple command to verify the container is responsive - _, _, err := t.Execute([]string{"tailscale", "version"}, dockertestutil.ExecuteCommandTimeout(5*time.Second)) + _, _, err := t.Execute([]string{"tailscale", "version"}, dockertestutil.ExecuteCommandTimeout(tailscaleVersionTimeout)) if err != nil { return struct{}{}, fmt.Errorf("container not ready: %w", err) } return struct{}{}, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(30*time.Second)) + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(containerRestartBackoff)) if err != nil { return fmt.Errorf("timeout waiting for container %s to restart and become ready: %w", t.hostname, err) } @@ -847,7 +854,7 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) { } return ips, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(backoffMaxElapsedTime)) if err != nil { return nil, fmt.Errorf("failed to get IPs for %s after retries: %w", t.hostname, err) } @@ -1151,7 +1158,7 @@ func (t *TailscaleInContainer) FQDN() (string, error) { } return status.Self.DNSName, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(backoffMaxElapsedTime)) if err != nil { return "", fmt.Errorf("failed to get FQDN for %s after retries: %w", t.hostname, err) } @@ -1507,7 +1514,7 @@ func (t *TailscaleInContainer) CurlFailFast(url string) (string, error) { // Use aggressive timeouts for fast failure detection return t.Curl(url, WithCurlConnectionTimeout(1*time.Second), - WithCurlMaxTime(2*time.Second), + WithCurlMaxTime(curlFailFastMaxTime), WithCurlRetry(1)) } From 8bfd508cf0eceea6cb09b59801827007ea6dd91f Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 16:20:53 +0000 Subject: [PATCH 21/30] all: apply testifylint fixes and correct auto-fix issues - Apply testifylint auto-fixes (assert.Positive, fmt.Sprintf in assertions) - Fix incorrect := to = conversions introduced by auto-fixer - Revert broken slices.AppendSeq FIXME placeholder --- cmd/hi/docker.go | 15 ++-- cmd/hi/main.go | 9 ++- cmd/hi/stats.go | 5 +- hscontrol/db/node.go | 16 ++-- hscontrol/policy/v2/types.go | 137 ++++++++++++++++++++++------------- integration/helpers.go | 4 +- integration/hsic/hsic.go | 6 +- integration/scenario.go | 32 ++++---- integration/tsic/tsic.go | 10 +-- 9 files changed, 143 insertions(+), 91 deletions(-) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 
bcb6dbc5..3062e5ec 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -70,7 +70,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { log.Printf("Running pre-test cleanup...") } - if err := cleanupBeforeTest(ctx); err != nil && config.Verbose { + err := cleanupBeforeTest(ctx) + if err != nil && config.Verbose { log.Printf("Warning: pre-test cleanup failed: %v", err) } } @@ -123,7 +124,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { // Start stats collection immediately - no need for complex retry logic // The new implementation monitors Docker events and will catch containers as they start - if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil { + err := statsCollector.StartCollection(ctx, runID, config.Verbose) + if err != nil { if config.Verbose { log.Printf("Warning: failed to start stats collection: %v", err) } @@ -135,7 +137,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { exitCode, err := streamAndWait(ctx, cli, resp.ID) // Ensure all containers have finished and logs are flushed before extracting artifacts - if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose { + waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose) + if waitErr != nil && config.Verbose { log.Printf("Warning: failed to wait for container finalization: %v", waitErr) } @@ -648,7 +651,8 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi for _, cont := range currentTestContainers { // Extract container logs and tar files - if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil { + err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose) + if err != nil { if verbose { log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err) } @@ -727,7 +731,8 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st // extractContainerArtifacts saves logs and tar files from a container. func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error { // Ensure the logs directory exists - if err := os.MkdirAll(logsDir, dirPermissions); err != nil { + err := os.MkdirAll(logsDir, dirPermissions) + if err != nil { return fmt.Errorf("failed to create logs directory: %w", err) } diff --git a/cmd/hi/main.go b/cmd/hi/main.go index 0c9adc30..2bbfefe0 100644 --- a/cmd/hi/main.go +++ b/cmd/hi/main.go @@ -79,15 +79,18 @@ func main() { } func cleanAll(ctx context.Context) error { - if err := killTestContainers(ctx); err != nil { + err := killTestContainers(ctx) + if err != nil { return err } - if err := pruneDockerNetworks(ctx); err != nil { + err = pruneDockerNetworks(ctx) + if err != nil { return err } - if err := cleanOldImages(ctx); err != nil { + err = cleanOldImages(ctx) + if err != nil { return err } diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index 173b5332..47da89c4 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -23,7 +23,7 @@ var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already star // Stats calculation constants. 
const ( - bytesPerKB = 1024 + bytesPerKB = 1024 percentageMultiplier = 100.0 ) @@ -259,7 +259,8 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe return default: var stats container.StatsResponse - if err := decoder.Decode(&stats); err != nil { + err := decoder.Decode(&stats) + if err != nil { // EOF is expected when container stops or stream ends if err.Error() != "EOF" && verbose { log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 04f9a621..10412d36 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -36,7 +36,7 @@ var ( "node not found in registration cache", ) ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface") - ErrNameNotUnique = errors.New("name is not unique") + ErrNameNotUnique = errors.New("name is not unique") ) // ListPeers returns peers of node, regardless of any Policy or if the node is expired. @@ -229,7 +229,8 @@ func SetApprovedRoutes( ) error { if len(routes) == 0 { // if no routes are provided, we remove all - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil { + err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error + if err != nil { return fmt.Errorf("removing approved routes: %w", err) } @@ -278,13 +279,15 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { func RenameNode(tx *gorm.DB, nodeID types.NodeID, newName string, ) error { - if err := util.ValidateHostname(newName); err != nil { + err := util.ValidateHostname(newName) + if err != nil { return fmt.Errorf("renaming node: %w", err) } // Check if the new name is unique var count int64 - if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil { + err = tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error + if err != nil { return fmt.Errorf("failed to check name uniqueness: %w", err) } @@ -494,8 +497,9 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { func isUniqueName(tx *gorm.DB, name string) (bool, error) { nodes := types.Nodes{} - if err := tx. - Where("given_name = ?", name).Find(&nodes).Error; err != nil { + err := tx. + Where("given_name = ?", name).Find(&nodes).Error + if err != nil { return false, err } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 4ff5dd1a..59b8ba6f 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -38,12 +38,12 @@ var ErrUndefinedTagReference = errors.New("references undefined tag") // Sentinel errors for type/alias validation. var ( - ErrUnknownAliasType = errors.New("unknown alias type") - ErrUnknownOwnerType = errors.New("unknown owner type") + ErrUnknownAliasType = errors.New("unknown alias type") + ErrUnknownOwnerType = errors.New("unknown owner type") ErrUnknownAutoApproverType = errors.New("unknown auto approver type") - ErrInvalidAlias = errors.New("invalid alias") - ErrInvalidAutoApprover = errors.New("invalid auto approver") - ErrInvalidOwner = errors.New("invalid owner") + ErrInvalidAlias = errors.New("invalid alias") + ErrInvalidAutoApprover = errors.New("invalid auto approver") + ErrInvalidOwner = errors.New("invalid owner") ) // Sentinel errors for format validation. @@ -65,16 +65,16 @@ var ( // Sentinel errors for resolution/lookup failures. 
var ( - ErrUserNotFound = errors.New("user not found") - ErrMultipleUsersFound = errors.New("multiple users found") - ErrHostNotResolved = errors.New("unable to resolve host") - ErrGroupNotDefined = errors.New("group not defined in policy") - ErrTagNotDefined = errors.New("tag not defined in policy") - ErrHostNotDefined = errors.New("host not defined in policy") - ErrInvalidIPAddress = errors.New("invalid IP address") - ErrNestedGroups = errors.New("nested groups not allowed") - ErrInvalidGroupMember = errors.New("invalid group member type") - ErrGroupValueNotArray = errors.New("group value must be an array") + ErrUserNotFound = errors.New("user not found") + ErrMultipleUsersFound = errors.New("multiple users found") + ErrHostNotResolved = errors.New("unable to resolve host") + ErrGroupNotDefined = errors.New("group not defined in policy") + ErrTagNotDefined = errors.New("tag not defined in policy") + ErrHostNotDefined = errors.New("host not defined in policy") + ErrInvalidIPAddress = errors.New("invalid IP address") + ErrNestedGroups = errors.New("nested groups not allowed") + ErrInvalidGroupMember = errors.New("invalid group member type") + ErrGroupValueNotArray = errors.New("group value must be an array") ErrAutoApproverNotAlias = errors.New("auto approver is not an alias") ) @@ -91,10 +91,10 @@ var ( // Sentinel errors for SSH aliases. var ( - ErrAliasNotSupportedSSHSrc = errors.New("alias type not supported for SSH source") - ErrAliasNotSupportedSSHDst = errors.New("alias type not supported for SSH destination") - ErrUnknownSSHSrcAliasType = errors.New("unknown SSH source alias type") - ErrUnknownSSHDstAliasType = errors.New("unknown SSH destination alias type") + ErrAliasNotSupportedSSHSrc = errors.New("alias type not supported for SSH source") + ErrAliasNotSupportedSSHDst = errors.New("alias type not supported for SSH destination") + ErrUnknownSSHSrcAliasType = errors.New("unknown SSH source alias type") + ErrUnknownSSHDstAliasType = errors.New("unknown SSH destination alias type") ) // Sentinel errors for policy parsing. 
@@ -212,7 +212,8 @@ func (p Prefix) MarshalJSON() ([]byte, error) { func (u *Username) UnmarshalJSON(b []byte) error { *u = Username(strings.Trim(string(b), `"`)) - if err := u.Validate(); err != nil { + err := u.Validate() + if err != nil { return err } @@ -306,7 +307,8 @@ func (g Group) Validate() error { func (g *Group) UnmarshalJSON(b []byte) error { *g = Group(strings.Trim(string(b), `"`)) - if err := g.Validate(); err != nil { + err := g.Validate() + if err != nil { return err } @@ -371,7 +373,8 @@ func (t Tag) Validate() error { func (t *Tag) UnmarshalJSON(b []byte) error { *t = Tag(strings.Trim(string(b), `"`)) - if err := t.Validate(); err != nil { + err := t.Validate() + if err != nil { return err } @@ -421,7 +424,8 @@ func (h Host) Validate() error { func (h *Host) UnmarshalJSON(b []byte) error { *h = Host(strings.Trim(string(b), `"`)) - if err := h.Validate(); err != nil { + err := h.Validate() + if err != nil { return err } @@ -582,7 +586,8 @@ func (ag AutoGroup) Validate() error { func (ag *AutoGroup) UnmarshalJSON(b []byte) error { *ag = AutoGroup(strings.Trim(string(b), `"`)) - if err := ag.Validate(); err != nil { + err := ag.Validate() + if err != nil { return err } @@ -669,7 +674,8 @@ type AliasWithPorts struct { func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { var v any - if err := json.Unmarshal(b, &v); err != nil { + err := json.Unmarshal(b, &v) + if err != nil { return err } @@ -1049,14 +1055,16 @@ func (g Groups) Contains(group *Group) error { func (g *Groups) UnmarshalJSON(b []byte) error { // First unmarshal as a generic map to validate group names first var rawMap map[string]any - if err := json.Unmarshal(b, &rawMap); err != nil { + err := json.Unmarshal(b, &rawMap) + if err != nil { return err } // Validate group names first before checking data types for key := range rawMap { group := Group(key) - if err := group.Validate(); err != nil { + err := group.Validate() + if err != nil { return err } } @@ -1095,7 +1103,8 @@ func (g *Groups) UnmarshalJSON(b []byte) error { for _, u := range value { username := Username(u) - if err := username.Validate(); err != nil { + err := username.Validate() + if err != nil { if isGroup(u) { return fmt.Errorf("%w: found %q inside %q", ErrNestedGroups, u, group) } @@ -1117,7 +1126,8 @@ type Hosts map[Host]Prefix func (h *Hosts) UnmarshalJSON(b []byte) error { var rawHosts map[string]string - if err := json.Unmarshal(b, &rawHosts, policyJSONOpts...); err != nil { + err := json.Unmarshal(b, &rawHosts, policyJSONOpts...) 
+ if err != nil { return err } @@ -1125,12 +1135,14 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { for key, value := range rawHosts { host := Host(key) - if err := host.Validate(); err != nil { + err := host.Validate() + if err != nil { return err } var prefix Prefix - if err := prefix.parseString(value); err != nil { + err = prefix.parseString(value) + if err != nil { return fmt.Errorf("%w: hostname %q value %q", ErrInvalidIPAddress, key, value) } @@ -1476,7 +1488,8 @@ func (p *Protocol) UnmarshalJSON(b []byte) error { *p = Protocol(strings.ToLower(str)) // Validate the protocol - if err := p.validate(); err != nil { + err := p.validate() + if err != nil { return err } @@ -1732,12 +1745,14 @@ func (p *Policy) validate() error { case *AutoGroup: ag := src - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForSrc(ag); err != nil { + err = validateAutogroupForSrc(ag) + if err != nil { errs = append(errs, err) continue } @@ -1748,7 +1763,8 @@ func (p *Policy) validate() error { } case *Tag: tagOwner := src - if err := p.TagOwners.Contains(tagOwner); err != nil { + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1764,12 +1780,14 @@ func (p *Policy) validate() error { case *AutoGroup: ag := dst.Alias.(*AutoGroup) - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForDst(ag); err != nil { + err = validateAutogroupForDst(ag) + if err != nil { errs = append(errs, err) continue } @@ -1780,14 +1798,16 @@ func (p *Policy) validate() error { } case *Tag: tagOwner := dst.Alias.(*Tag) - if err := p.TagOwners.Contains(tagOwner); err != nil { + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } } // Validate protocol-port compatibility - if err := validateProtocolPortCompatibility(acl.Protocol, acl.Destinations); err != nil { + err := validateProtocolPortCompatibility(acl.Protocol, acl.Destinations) + if err != nil { errs = append(errs, err) } } @@ -1796,7 +1816,8 @@ func (p *Policy) validate() error { for _, user := range ssh.Users { if strings.HasPrefix(string(user), "autogroup:") { maybeAuto := AutoGroup(user) - if err := validateAutogroupForSSHUser(&maybeAuto); err != nil { + err := validateAutogroupForSSHUser(&maybeAuto) + if err != nil { errs = append(errs, err) continue } @@ -1808,23 +1829,27 @@ func (p *Policy) validate() error { case *AutoGroup: ag := src - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForSSHSrc(ag); err != nil { + err = validateAutogroupForSSHSrc(ag) + if err != nil { errs = append(errs, err) continue } case *Group: g := src - if err := p.Groups.Contains(g); err != nil { + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := src - if err := p.TagOwners.Contains(tagOwner); err != nil { + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1834,18 +1859,21 @@ func (p *Policy) validate() error { switch dst := dst.(type) { case *AutoGroup: ag := dst - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForSSHDst(ag); err 
!= nil { + err = validateAutogroupForSSHDst(ag) + if err != nil { errs = append(errs, err) continue } case *Tag: tagOwner := dst - if err := p.TagOwners.Contains(tagOwner); err != nil { + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1857,7 +1885,8 @@ func (p *Policy) validate() error { switch tagOwner := tagOwner.(type) { case *Group: g := tagOwner - if err := p.Groups.Contains(g); err != nil { + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: @@ -1882,12 +1911,14 @@ func (p *Policy) validate() error { switch approver := approver.(type) { case *Group: g := approver - if err := p.Groups.Contains(g); err != nil { + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := approver - if err := p.TagOwners.Contains(tagOwner); err != nil { + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1898,12 +1929,14 @@ func (p *Policy) validate() error { switch approver := approver.(type) { case *Group: g := approver - if err := p.Groups.Contains(g); err != nil { + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := approver - if err := p.TagOwners.Contains(tagOwner); err != nil { + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } diff --git a/integration/helpers.go b/integration/helpers.go index 4f9c1018..b1701f8f 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -378,7 +378,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer stateStr = "online" } - assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr)) + assert.True(c, allMatch, "Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr) }, timeout, onlineCheckRetryInterval, message) } @@ -534,7 +534,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Validate that the node has a valid DERP server (PreferredDERP should be > 0) preferredDERP := node.Hostinfo.NetInfo.PreferredDERP - assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP) + assert.Positive(c, preferredDERP, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP) t.Logf("Node %d (%s) has valid NetInfo with DERP server %d at %s", nodeID, node.Hostname, preferredDERP, time.Now().Format(TimestampFormat)) } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 1dc5c2d3..3eef2d97 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -721,7 +721,8 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error { // extractTarToDirectory extracts a tar archive to a directory. 
func extractTarToDirectory(tarData []byte, targetDir string) error { - if err := os.MkdirAll(targetDir, dirPermissions); err != nil { + err := os.MkdirAll(targetDir, dirPermissions) + if err != nil { return fmt.Errorf("failed to create directory %s: %w", targetDir, err) } @@ -780,7 +781,8 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { switch header.Typeflag { case tar.TypeDir: // Create directory - if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil { + err := os.MkdirAll(targetPath, os.FileMode(header.Mode)) + if err != nil { return fmt.Errorf("failed to create directory %s: %w", targetPath, err) } case tar.TypeReg: diff --git a/integration/scenario.go b/integration/scenario.go index 73a45edd..9ae0c4fc 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -53,16 +53,16 @@ const ( var usePostgresForTest = envknob.Bool("HEADSCALE_INTEGRATION_POSTGRES") var ( - errNoHeadscaleAvailable = errors.New("no headscale available") - errNoUserAvailable = errors.New("no user available") - errNoClientFound = errors.New("client not found") - errUserAlreadyInNetwork = errors.New("users can only have nodes placed in one network") - errNoNetworkNamed = errors.New("no network named") - errNoIPAMConfig = errors.New("no IPAM config found in network") - errHTTPClientNil = errors.New("http client is nil") - errLoginURLNil = errors.New("login url is nil") - errUnexpectedStatusCode = errors.New("unexpected status code") - errNetworkDoesNotExist = errors.New("network does not exist") + errNoHeadscaleAvailable = errors.New("no headscale available") + errNoUserAvailable = errors.New("no user available") + errNoClientFound = errors.New("client not found") + errUserAlreadyInNetwork = errors.New("users can only have nodes placed in one network") + errNoNetworkNamed = errors.New("no network named") + errNoIPAMConfig = errors.New("no IPAM config found in network") + errHTTPClientNil = errors.New("http client is nil") + errLoginURLNil = errors.New("login url is nil") + errUnexpectedStatusCode = errors.New("unexpected status code") + errNetworkDoesNotExist = errors.New("network does not exist") // AllVersions represents a list of Tailscale versions the suite // uses to test compatibility with the ControlServer. @@ -391,13 +391,15 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { if s.mockOIDC.r != nil { s.mockOIDC.r.Close() - if err := s.mockOIDC.r.Close(); err != nil { + err := s.mockOIDC.r.Close() + if err != nil { log.Printf("failed to tear down oidc server: %s", err) } } for _, network := range s.networks { - if err := network.Close(); err != nil { + err := network.Close() + if err != nil { log.Printf("failed to tear down network: %s", err) } } @@ -775,7 +777,8 @@ func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Durat }) } - if err := user.syncWaitGroup.Wait(); err != nil { + err := user.syncWaitGroup.Wait() + if err != nil { allErrors = append(allErrors, err) } } @@ -938,7 +941,8 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { log.Printf("client %s is ready", client.Hostname()) } - if err := user.joinWaitGroup.Wait(); err != nil { + err := user.joinWaitGroup.Wait() + if err != nil { return err } diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index b943ee13..b270fab3 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -46,11 +46,11 @@ const ( dockerExecuteTimeout = 60 * time.Second // Container restart and backoff timeouts. 
- containerRestartTimeout = 30 // seconds, used by Docker API - tailscaleVersionTimeout = 5 * time.Second - containerRestartBackoff = 30 * time.Second - backoffMaxElapsedTime = 10 * time.Second - curlFailFastMaxTime = 2 * time.Second + containerRestartTimeout = 30 // seconds, used by Docker API + tailscaleVersionTimeout = 5 * time.Second + containerRestartBackoff = 30 * time.Second + backoffMaxElapsedTime = 10 * time.Second + curlFailFastMaxTime = 2 * time.Second ) var ( From 58b532ae3c791bf13cd62a5c405a9fb4d4566bd0 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 16:28:35 +0000 Subject: [PATCH 22/30] all: apply lint auto-fixes and update test expectations Apply additional golangci-lint auto-fixes (wsl_v5, formatting) and update SSH policy test error message expectations to match the new sentinel error formats introduced in the err113 fixes. --- cmd/headscale/cli/policy.go | 3 ++- cmd/headscale/cli/users.go | 2 +- cmd/hi/doctor.go | 3 ++- cmd/hi/run.go | 3 ++- hscontrol/db/db.go | 30 ++++++++++++++++-------- hscontrol/db/sqliteconfig/config.go | 3 ++- hscontrol/db/text_serialiser.go | 6 ++--- hscontrol/db/users.go | 6 +++-- hscontrol/derp/server/derp_server.go | 4 +++- hscontrol/handlers.go | 3 ++- hscontrol/mapper/batcher_lockfree.go | 3 ++- hscontrol/mapper/batcher_test.go | 8 ++++--- hscontrol/noise.go | 3 ++- hscontrol/policy/policy_test.go | 4 ++-- hscontrol/policy/v2/policy_test.go | 2 +- hscontrol/policy/v2/types_test.go | 8 +++---- hscontrol/poll.go | 6 +++-- hscontrol/state/maprequest_test.go | 1 - hscontrol/state/state.go | 9 ++++--- hscontrol/types/config.go | 23 ++++++++++-------- hscontrol/types/node.go | 3 ++- hscontrol/types/users_test.go | 6 +++-- hscontrol/util/dns.go | 12 +++++----- hscontrol/util/prompt.go | 1 + hscontrol/util/prompt_test.go | 2 ++ hscontrol/util/util.go | 3 ++- integration/api_auth_test.go | 12 ++++++---- integration/auth_key_test.go | 2 +- integration/auth_web_flow_test.go | 3 +-- integration/derp_verify_endpoint_test.go | 3 ++- integration/dns_test.go | 1 + 31 files changed, 110 insertions(+), 68 deletions(-) diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index f3921a64..4cdfe126 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -35,7 +35,8 @@ func init() { checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") - if err := checkPolicy.MarkFlagRequired("file"); err != nil { + err := checkPolicy.MarkFlagRequired("file") + if err != nil { log.Fatal().Err(err).Msg("") } diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 084548a9..086a82b6 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -16,7 +16,7 @@ import ( // Sentinel errors for CLI commands. 
var ( - ErrNameOrIDRequired = errors.New("--name or --identifier flag is required") + ErrNameOrIDRequired = errors.New("--name or --identifier flag is required") ErrMultipleUsersFoundUseID = errors.New("unable to determine user, query returned multiple users, use ID") ) diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index 2bfc41fd..b1cad514 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -321,7 +321,8 @@ func checkRequiredFiles() DoctorResult { for _, file := range requiredFiles { cmd := exec.CommandContext(context.Background(), "test", "-e", file) - if err := cmd.Run(); err != nil { + err := cmd.Run() + if err != nil { missingFiles = append(missingFiles, file) } } diff --git a/cmd/hi/run.go b/cmd/hi/run.go index e6c52634..881be20f 100644 --- a/cmd/hi/run.go +++ b/cmd/hi/run.go @@ -49,7 +49,8 @@ func runIntegrationTest(env *command.Env) error { log.Printf("Running pre-flight system checks...") } - if err := runDoctorCheck(env.Context()); err != nil { + err := runDoctorCheck(env.Context()) + if err != nil { return fmt.Errorf("pre-flight checks failed: %w", err) } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index ff9379c1..5a3364c9 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -261,7 +261,8 @@ AND auth_key_id NOT IN ( if err == nil && routesExists { log.Info().Msg("Dropping leftover routes table") - if err := tx.Exec("DROP TABLE routes").Error; err != nil { + err := tx.Exec("DROP TABLE routes").Error + if err != nil { return fmt.Errorf("dropping routes table: %w", err) } } @@ -294,7 +295,8 @@ AND auth_key_id NOT IN ( _ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error // Rename current table to _old - if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil { + err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error + if err != nil { return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err) } } @@ -368,7 +370,8 @@ AND auth_key_id NOT IN ( } for _, createSQL := range tableCreationSQL { - if err := tx.Exec(createSQL).Error; err != nil { + err := tx.Exec(createSQL).Error + if err != nil { return fmt.Errorf("creating new table: %w", err) } } @@ -397,7 +400,8 @@ AND auth_key_id NOT IN ( } for _, copySQL := range dataCopySQL { - if err := tx.Exec(copySQL).Error; err != nil { + err := tx.Exec(copySQL).Error + if err != nil { return fmt.Errorf("copying data: %w", err) } } @@ -420,14 +424,16 @@ AND auth_key_id NOT IN ( } for _, indexSQL := range indexes { - if err := tx.Exec(indexSQL).Error; err != nil { + err := tx.Exec(indexSQL).Error + if err != nil { return fmt.Errorf("creating index: %w", err) } } // Drop old tables only after everything succeeds for _, table := range tablesToRename { - if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil { + err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error + if err != nil { log.Warn().Str("table", table+"_old").Err(err).Msg("Failed to drop old table, but migration succeeded") } } @@ -946,18 +952,21 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig if needsFKDisabled { // Disable foreign keys for this migration - if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil { + err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error + if err != nil { return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err) } } else { // Ensure foreign keys are enabled for this migration - if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { + err := 
dbConn.Exec("PRAGMA foreign_keys = ON").Error + if err != nil { return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err) } } // Run up to this specific migration (will only run the next pending migration) - if err := migrations.MigrateTo(migrationID); err != nil { + err := migrations.MigrateTo(migrationID) + if err != nil { return fmt.Errorf("running migration %s: %w", migrationID, err) } } @@ -1009,7 +1018,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig } } else { // PostgreSQL can run all migrations in one block - no foreign key issues - if err := migrations.Migrate(); err != nil { + err := migrations.Migrate() + if err != nil { return err } } diff --git a/hscontrol/db/sqliteconfig/config.go b/hscontrol/db/sqliteconfig/config.go index e6386718..0088fe86 100644 --- a/hscontrol/db/sqliteconfig/config.go +++ b/hscontrol/db/sqliteconfig/config.go @@ -365,7 +365,8 @@ func (c *Config) Validate() error { // ToURL builds a properly encoded SQLite connection string using _pragma parameters // compatible with modernc.org/sqlite driver. func (c *Config) ToURL() (string, error) { - if err := c.Validate(); err != nil { + err := c.Validate() + if err != nil { return "", fmt.Errorf("invalid config: %w", err) } diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go index b1d294ea..7a9f7010 100644 --- a/hscontrol/db/text_serialiser.go +++ b/hscontrol/db/text_serialiser.go @@ -12,9 +12,9 @@ import ( // Sentinel errors for text serialisation. var ( - ErrTextUnmarshalFailed = errors.New("failed to unmarshal text value") - ErrUnsupportedType = errors.New("unsupported type") - ErrTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported") + ErrTextUnmarshalFailed = errors.New("failed to unmarshal text value") + ErrUnsupportedType = errors.New("unsupported type") + ErrTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported") ) // Got from https://github.com/xdg-go/strum/blob/main/types.go diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index be073999..8af317dc 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -28,7 +28,8 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) { // CreateUser creates a new User. Returns error if could not be created // or another user already exists. 
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) { - if err := util.ValidateHostname(user.Name); err != nil { + err := util.ValidateHostname(user.Name) + if err != nil { return nil, err } if err := tx.Create(&user).Error; err != nil { @@ -164,7 +165,8 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { } users := []types.User{} - if err := tx.Where(user).Find(&users).Error; err != nil { + err := tx.Where(user).Find(&users).Error + if err != nil { return nil, err } diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index 562061e2..6a66a7e6 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -74,6 +74,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { if err != nil { return tailcfg.DERPRegion{}, err } + var ( host string port int @@ -416,7 +417,8 @@ type DERPVerifyTransport struct { func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) { buf := new(bytes.Buffer) - if err := t.handleVerifyRequest(req, buf); err != nil { + err := t.handleVerifyRequest(req, buf) + if err != nil { log.Error().Caller().Err(err).Msg("Failed to handle client verify request: ") return nil, err diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index ef214536..2e809061 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -154,7 +154,8 @@ func (h *Headscale) KeyHandler( } writer.Header().Set("Content-Type", "application/json") - if err := json.NewEncoder(writer).Encode(resp); err != nil { + err := json.NewEncoder(writer).Encode(resp) + if err != nil { log.Error().Err(err).Msg("failed to encode key response") } diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 2580b7f4..d4b5ff6f 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -667,7 +667,8 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Str("conn.id", conn.id).Int("connection_index", i). Msg("send: attempting to send to connection") - if err := conn.send(data); err != nil { + err := conn.send(data) + if err != nil { lastErr = err failedConnections = append(failedConnections, i) diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 4f950d15..7cc746a4 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -123,9 +123,9 @@ const ( // Channel configuration. NORMAL_BUFFER_SIZE = 50 - SMALL_BUFFER_SIZE = 3 - TINY_BUFFER_SIZE = 1 // For maximum contention - LARGE_BUFFER_SIZE = 200 + SMALL_BUFFER_SIZE = 3 + TINY_BUFFER_SIZE = 1 // For maximum contention + LARGE_BUFFER_SIZE = 200 ) // TestData contains all test entities created for a test scenario. 
@@ -1145,6 +1145,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { wg.Go(func() { runtime.Gosched() // Yield to introduce timing variability + _ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) }) @@ -1747,6 +1748,7 @@ func XTestBatcherScalability(t *testing.T) { for i := range testNodes { node := testNodes[i] _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + connectedNodesMutex.Lock() connectedNodes[node.n.ID] = true diff --git a/hscontrol/noise.go b/hscontrol/noise.go index d8e83154..67021d31 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -283,7 +283,8 @@ func (ns *noiseServer) NoiseRegistrationHandler( writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - if err := json.NewEncoder(writer).Encode(registerResponse); err != nil { + err := json.NewEncoder(writer).Encode(registerResponse) + if err != nil { log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse") return } diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index ee4818aa..a46f30d2 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -1214,7 +1214,7 @@ func TestSSHPolicyRules(t *testing.T) { ] }`, expectErr: true, - errorMessage: `invalid SSH action "invalid", must be one of: accept, check`, + errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`, }, { name: "invalid-check-period", @@ -1249,7 +1249,7 @@ func TestSSHPolicyRules(t *testing.T) { ] }`, expectErr: true, - errorMessage: "autogroup \"autogroup:invalid\" is not supported", + errorMessage: `autogroup not supported for SSH: "autogroup:invalid" for SSH user`, }, { name: "autogroup-nonroot-should-use-wildcard-with-root-excluded", diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 3e3c70f5..dc5969b5 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -106,7 +106,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { require.NoError(t, err) } - require.Equal(t, len(initialNodes), len(pm.filterRulesMap)) + require.Len(t, pm.filterRulesMap, len(initialNodes)) tests := []struct { name string diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index b5e5a210..8164f576 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -3131,8 +3131,8 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { require.NoError(t, err) assert.Equal(t, tt.expected.Action, acl.Action) assert.Equal(t, tt.expected.Protocol, acl.Protocol) - assert.Equal(t, len(tt.expected.Sources), len(acl.Sources)) - assert.Equal(t, len(tt.expected.Destinations), len(acl.Destinations)) + assert.Len(t, acl.Sources, len(tt.expected.Sources)) + assert.Len(t, acl.Destinations, len(tt.expected.Destinations)) // Compare sources for i, expectedSrc := range tt.expected.Sources { @@ -3179,8 +3179,8 @@ func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) { // Should be equal assert.Equal(t, original.Action, unmarshaled.Action) assert.Equal(t, original.Protocol, unmarshaled.Protocol) - assert.Equal(t, len(original.Sources), len(unmarshaled.Sources)) - assert.Equal(t, len(original.Destinations), len(unmarshaled.Destinations)) + assert.Len(t, unmarshaled.Sources, len(original.Sources)) + assert.Len(t, unmarshaled.Destinations, len(original.Destinations)) } func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) { diff --git 
a/hscontrol/poll.go b/hscontrol/poll.go index d3c9f1ef..464d252d 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -249,7 +249,8 @@ func (m *mapSession) serveLongPoll() { return } - if err := m.writeMap(update); err != nil { + err := m.writeMap(update) + if err != nil { m.errf(err, "cannot write update to client") return } @@ -258,7 +259,8 @@ func (m *mapSession) serveLongPoll() { m.resetKeepAlive() case <-m.keepAliveTicker.C: - if err := m.writeMap(&keepAlive); err != nil { + err := m.writeMap(&keepAlive) + if err != nil { m.errf(err, "cannot write keep alive") return } diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index ce6804e4..8a842e49 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -133,4 +133,3 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node") }) } - diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index bb929faa..c720c271 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -187,7 +187,8 @@ func NewState(cfg *types.Config) (*State, error) { func (s *State) Close() error { s.nodeStore.Stop() - if err := s.db.Close(); err != nil { + err := s.db.Close() + if err != nil { return fmt.Errorf("closing database: %w", err) } @@ -741,7 +742,8 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t // RenameNode changes the display name of a node. func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.Change, error) { - if err := util.ValidateHostname(newName); err != nil { + err := util.ValidateHostname(newName) + if err != nil { return types.NodeView{}, change.Change{}, fmt.Errorf("renaming node: %w", err) } @@ -1214,7 +1216,8 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro // New node - database first to get ID, then NodeStore savedNode, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { + err := tx.Save(&nodeToRegister).Error + if err != nil { return nil, fmt.Errorf("failed to save node: %w", err) } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 5a32147d..d57943f6 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -29,16 +29,16 @@ const ( PKCEMethodPlain string = "plain" PKCEMethodS256 string = "S256" - defaultNodeStoreBatchSize = 100 - defaultWALAutocheckpoint = 1000 // SQLite default + defaultNodeStoreBatchSize = 100 + defaultWALAutocheckpoint = 1000 // SQLite default ) var ( - errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") - errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") - errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") - errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") - errNoPrefixConfigured = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") + errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") + errServerURLSame = 
errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") + errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + errNoPrefixConfigured = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") errInvalidAllocationStrategy = errors.New("invalid prefixes.allocation strategy") ) @@ -406,7 +406,8 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential)) if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { + var configFileNotFoundError viper.ConfigFileNotFoundError + if errors.As(err, &configFileNotFoundError) { log.Warn().Msg("No config file found, using defaults") return nil } @@ -446,7 +447,8 @@ func validateServerConfig() error { depr.fatal("oidc.map_legacy_users") if viper.GetBool("oidc.enabled") { - if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil { + err := validatePKCEMethod(viper.GetString("oidc.pkce.method")) + if err != nil { return err } } @@ -984,7 +986,8 @@ func LoadServerConfig() (*Config, error) { // - Control plane runs on login.tailscale.com/controlplane.tailscale.com // - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net) if dnsConfig.BaseDomain != "" { - if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil { + err := isSafeServerURL(serverURL, dnsConfig.BaseDomain) + if err != nil { return nil, err } } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index ea96284c..e1fa5794 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -573,7 +573,8 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { } newHostname := strings.ToLower(hostInfo.Hostname) - if err := util.ValidateHostname(newHostname); err != nil { + err := util.ValidateHostname(newHostname) + if err != nil { log.Warn(). Str("node.id", node.ID.String()). Str("current_hostname", node.Hostname). diff --git a/hscontrol/types/users_test.go b/hscontrol/types/users_test.go index acd88434..35c84a26 100644 --- a/hscontrol/types/users_test.go +++ b/hscontrol/types/users_test.go @@ -66,7 +66,8 @@ func TestUnmarshallOIDCClaims(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got OIDCClaims - if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil { + err := json.Unmarshal([]byte(tt.jsonstr), &got) + if err != nil { t.Errorf("UnmarshallOIDCClaims() error = %v", err) return } @@ -482,7 +483,8 @@ func TestOIDCClaimsJSONToUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got OIDCClaims - if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil { + err := json.Unmarshal([]byte(tt.jsonstr), &got) + if err != nil { t.Errorf("TestOIDCClaimsJSONToUser() error = %v", err) return } diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index c816efc6..8ec40790 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -31,17 +31,17 @@ var ErrInvalidHostName = errors.New("invalid hostname") // Sentinel errors for username validation. 
var ( - ErrUsernameTooShort = errors.New("username must be at least 2 characters long") + ErrUsernameTooShort = errors.New("username must be at least 2 characters long") ErrUsernameMustStartLetter = errors.New("username must start with a letter") - ErrUsernameTooManyAt = errors.New("username cannot contain more than one '@'") - ErrUsernameInvalidChar = errors.New("username contains invalid character") + ErrUsernameTooManyAt = errors.New("username cannot contain more than one '@'") + ErrUsernameInvalidChar = errors.New("username contains invalid character") ) // Sentinel errors for hostname validation. var ( - ErrHostnameTooShort = errors.New("hostname too short, must be at least 2 characters") - ErrHostnameHyphenEnds = errors.New("hostname cannot start or end with a hyphen") - ErrHostnameDotEnds = errors.New("hostname cannot start or end with a dot") + ErrHostnameTooShort = errors.New("hostname too short, must be at least 2 characters") + ErrHostnameHyphenEnds = errors.New("hostname cannot start or end with a hyphen") + ErrHostnameDotEnds = errors.New("hostname cannot start or end with a dot") ) // ValidateUsername checks if a username is valid. diff --git a/hscontrol/util/prompt.go b/hscontrol/util/prompt.go index 410f6c2e..5f0adede 100644 --- a/hscontrol/util/prompt.go +++ b/hscontrol/util/prompt.go @@ -14,6 +14,7 @@ func YesNo(msg string) bool { fmt.Fprint(os.Stderr, msg+" [y/n] ") var resp string + _, _ = fmt.Scanln(&resp) resp = strings.ToLower(resp) diff --git a/hscontrol/util/prompt_test.go b/hscontrol/util/prompt_test.go index c6fcb702..ac405f8c 100644 --- a/hscontrol/util/prompt_test.go +++ b/hscontrol/util/prompt_test.go @@ -106,6 +106,7 @@ func TestYesNo(t *testing.T) { // Check that the prompt was written to stderr var stderrBuf bytes.Buffer + _, _ = io.Copy(&stderrBuf, stderrR) stderrR.Close() @@ -149,6 +150,7 @@ func TestYesNoPromptMessage(t *testing.T) { // Check that the custom message was included in the prompt var stderrBuf bytes.Buffer + _, _ = io.Copy(&stderrBuf, stderrR) stderrR.Close() diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index 77c83ece..8e777349 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -323,7 +323,8 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri } lowercased := strings.ToLower(hostinfo.Hostname) - if err := ValidateHostname(lowercased); err == nil { + err := ValidateHostname(lowercased) + if err == nil { return lowercased } diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index ed4a1f4d..989cb9d4 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -79,7 +79,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_NoAuthHeader", func(t *testing.T) { // Test 1: Request without any Authorization header // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) resp, err := client.Do(req) @@ -131,7 +131,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_InvalidAuthHeader", func(t *testing.T) { // Test 2: Request with invalid Authorization header (missing "Bearer " prefix) // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), 
http.MethodGet, apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "InvalidToken") @@ -165,7 +165,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Test 3: Request with Bearer prefix but invalid token // Expected: Should return 401 with ONLY "Unauthorized" text, no user data // Note: Both malformed and properly formatted invalid tokens should return 401 - req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "Bearer invalid-token-12345") @@ -198,7 +198,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_ValidAPIKey", func(t *testing.T) { // Test 4: Request with valid API key // Expected: Should return 200 with user data (this is the authorized case) - req, err := http.NewRequestWithContext(context.Background(), "GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "Bearer "+validAPIKey) @@ -294,6 +294,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { ) var responseBodySb295 strings.Builder + for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after @@ -301,6 +302,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { responseBodySb295.WriteString(line) } } + responseBody += responseBodySb295.String() // Should return 401 @@ -345,6 +347,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { ) var responseBodySb344 strings.Builder + for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after @@ -352,6 +355,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { responseBodySb344.WriteString(line) } } + responseBody += responseBodySb344.String() assert.Equal(t, "401", httpCode) diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index 0bced1ed..7e7747e6 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -355,7 +355,7 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { status, err := client.Status() assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) - }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after auth key user switch", client.Hostname())) + }, 30*time.Second, 2*time.Second, "validating %s is logged in as user1 after auth key user switch", client.Hostname()) } } diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 256d7e4d..6dbf6dfe 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -1,7 +1,6 @@ package integration import ( - "fmt" "net/netip" "slices" "testing" @@ -364,7 +363,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { status, err := client.Status() assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after web flow user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) - }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after web flow 
user switch", client.Hostname())) + }, 30*time.Second, 2*time.Second, "validating %s is logged in as user1 after web flow user switch", client.Hostname()) } // Test connectivity after user switch diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index c92a25ee..bd4cf6a9 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -109,7 +109,8 @@ func DERPVerify( defer c.Close() var result error - if err := c.Connect(t.Context()); err != nil { + err := c.Connect(t.Context()) + if err != nil { result = fmt.Errorf("client Connect: %w", err) } diff --git a/integration/dns_test.go b/integration/dns_test.go index 3432eb9b..08250e7b 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -199,6 +199,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { }, }) require.NoError(t, err) + command := []string{"echo", fmt.Sprintf("'%s'", string(b4)), ">", erPath} _, err = hs.Execute([]string{"bash", "-c", strings.Join(command, " ")}) require.NoError(t, err) From 3b59a9111293c2e5043beb692f520ea9d244ca3d Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 16:44:02 +0000 Subject: [PATCH 23/30] all: apply additional wsl_v5 whitespace fixes --- cmd/hi/docker.go | 3 ++- cmd/hi/doctor.go | 1 + cmd/hi/stats.go | 1 + hscontrol/db/db.go | 1 + hscontrol/db/node.go | 2 ++ hscontrol/db/users.go | 1 + hscontrol/derp/server/derp_server.go | 1 + hscontrol/handlers.go | 1 + hscontrol/policy/v2/types.go | 30 ++++++++++++++++++++++-- hscontrol/types/node.go | 1 + hscontrol/types/users_test.go | 2 ++ hscontrol/util/util.go | 1 + integration/derp_verify_endpoint_test.go | 1 + 13 files changed, 43 insertions(+), 3 deletions(-) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 3062e5ec..05aa7c9d 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -743,7 +743,8 @@ func extractContainerArtifacts(ctx context.Context, cli *client.Client, containe // Extract tar files for headscale containers only if strings.HasPrefix(containerName, "hs-") { - if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil { + err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose) + if err != nil { if verbose { log.Printf("Warning: failed to extract files from %s: %v", containerName, err) } diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index b1cad514..c30a1ca9 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -321,6 +321,7 @@ func checkRequiredFiles() DoctorResult { for _, file := range requiredFiles { cmd := exec.CommandContext(context.Background(), "test", "-e", file) + err := cmd.Run() if err != nil { missingFiles = append(missingFiles, file) diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index 47da89c4..aec28c50 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -259,6 +259,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe return default: var stats container.StatsResponse + err := decoder.Decode(&stats) if err != nil { // EOF is expected when container stops or stream ends diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 5a3364c9..61f192db 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -108,6 +108,7 @@ func NewHeadscaleDatabase( if err != nil { return fmt.Errorf("automigrating types.PreAuthKey: %w", err) } + err = tx.AutoMigrate(&types.Node{}) if err != nil { return fmt.Errorf("automigrating types.Node: %w", err) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go 
index 10412d36..56408809 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -286,6 +286,7 @@ func RenameNode(tx *gorm.DB, // Check if the new name is unique var count int64 + err = tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error if err != nil { return fmt.Errorf("failed to check name uniqueness: %w", err) @@ -497,6 +498,7 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { func isUniqueName(tx *gorm.DB, name string) (bool, error) { nodes := types.Nodes{} + err := tx. Where("given_name = ?", name).Find(&nodes).Error if err != nil { diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 8af317dc..9145ff20 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -165,6 +165,7 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { } users := []types.User{} + err := tx.Where(user).Find(&users).Error if err != nil { return nil, err diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index 6a66a7e6..b0f83fb6 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -417,6 +417,7 @@ type DERPVerifyTransport struct { func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) { buf := new(bytes.Buffer) + err := t.handleVerifyRequest(req, buf) if err != nil { log.Error().Caller().Err(err).Msg("Failed to handle client verify request: ") diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 2e809061..f1a2c88c 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -154,6 +154,7 @@ func (h *Headscale) KeyHandler( } writer.Header().Set("Content-Type", "application/json") + err := json.NewEncoder(writer).Encode(resp) if err != nil { log.Error().Err(err).Msg("failed to encode key response") diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 59b8ba6f..5443dfdb 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -212,6 +212,7 @@ func (p Prefix) MarshalJSON() ([]byte, error) { func (u *Username) UnmarshalJSON(b []byte) error { *u = Username(strings.Trim(string(b), `"`)) + err := u.Validate() if err != nil { return err @@ -307,6 +308,7 @@ func (g Group) Validate() error { func (g *Group) UnmarshalJSON(b []byte) error { *g = Group(strings.Trim(string(b), `"`)) + err := g.Validate() if err != nil { return err @@ -373,6 +375,7 @@ func (t Tag) Validate() error { func (t *Tag) UnmarshalJSON(b []byte) error { *t = Tag(strings.Trim(string(b), `"`)) + err := t.Validate() if err != nil { return err @@ -424,6 +427,7 @@ func (h Host) Validate() error { func (h *Host) UnmarshalJSON(b []byte) error { *h = Host(strings.Trim(string(b), `"`)) + err := h.Validate() if err != nil { return err @@ -586,6 +590,7 @@ func (ag AutoGroup) Validate() error { func (ag *AutoGroup) UnmarshalJSON(b []byte) error { *ag = AutoGroup(strings.Trim(string(b), `"`)) + err := ag.Validate() if err != nil { return err @@ -674,6 +679,7 @@ type AliasWithPorts struct { func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { var v any + err := json.Unmarshal(b, &v) if err != nil { return err @@ -1055,6 +1061,7 @@ func (g Groups) Contains(group *Group) error { func (g *Groups) UnmarshalJSON(b []byte) error { // First unmarshal as a generic map to validate group names first var rawMap map[string]any + err := json.Unmarshal(b, &rawMap) if err != nil { return err @@ -1063,6 +1070,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { // Validate group 
names first before checking data types for key := range rawMap { group := Group(key) + err := group.Validate() if err != nil { return err @@ -1103,6 +1111,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { for _, u := range value { username := Username(u) + err := username.Validate() if err != nil { if isGroup(u) { @@ -1126,6 +1135,7 @@ type Hosts map[Host]Prefix func (h *Hosts) UnmarshalJSON(b []byte) error { var rawHosts map[string]string + err := json.Unmarshal(b, &rawHosts, policyJSONOpts...) if err != nil { return err @@ -1135,12 +1145,14 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { for key, value := range rawHosts { host := Host(key) + err := host.Validate() if err != nil { return err } var prefix Prefix + err = prefix.parseString(value) if err != nil { return fmt.Errorf("%w: hostname %q value %q", ErrInvalidIPAddress, key, value) @@ -1758,11 +1770,13 @@ func (p *Policy) validate() error { } case *Group: g := src - if err := p.Groups.Contains(g); err != nil { + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := src + err := p.TagOwners.Contains(tagOwner) if err != nil { errs = append(errs, err) @@ -1793,11 +1807,13 @@ func (p *Policy) validate() error { } case *Group: g := dst.Alias.(*Group) - if err := p.Groups.Contains(g); err != nil { + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := dst.Alias.(*Tag) + err := p.TagOwners.Contains(tagOwner) if err != nil { errs = append(errs, err) @@ -1816,6 +1832,7 @@ func (p *Policy) validate() error { for _, user := range ssh.Users { if strings.HasPrefix(string(user), "autogroup:") { maybeAuto := AutoGroup(user) + err := validateAutogroupForSSHUser(&maybeAuto) if err != nil { errs = append(errs, err) @@ -1842,12 +1859,14 @@ func (p *Policy) validate() error { } case *Group: g := src + err := p.Groups.Contains(g) if err != nil { errs = append(errs, err) } case *Tag: tagOwner := src + err := p.TagOwners.Contains(tagOwner) if err != nil { errs = append(errs, err) @@ -1859,6 +1878,7 @@ func (p *Policy) validate() error { switch dst := dst.(type) { case *AutoGroup: ag := dst + err := validateAutogroupSupported(ag) if err != nil { errs = append(errs, err) @@ -1872,6 +1892,7 @@ func (p *Policy) validate() error { } case *Tag: tagOwner := dst + err := p.TagOwners.Contains(tagOwner) if err != nil { errs = append(errs, err) @@ -1885,6 +1906,7 @@ func (p *Policy) validate() error { switch tagOwner := tagOwner.(type) { case *Group: g := tagOwner + err := p.Groups.Contains(g) if err != nil { errs = append(errs, err) @@ -1911,12 +1933,14 @@ func (p *Policy) validate() error { switch approver := approver.(type) { case *Group: g := approver + err := p.Groups.Contains(g) if err != nil { errs = append(errs, err) } case *Tag: tagOwner := approver + err := p.TagOwners.Contains(tagOwner) if err != nil { errs = append(errs, err) @@ -1929,12 +1953,14 @@ func (p *Policy) validate() error { switch approver := approver.(type) { case *Group: g := approver + err := p.Groups.Contains(g) if err != nil { errs = append(errs, err) } case *Tag: tagOwner := approver + err := p.TagOwners.Contains(tagOwner) if err != nil { errs = append(errs, err) diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index e1fa5794..1ebc7033 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -573,6 +573,7 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { } newHostname := strings.ToLower(hostInfo.Hostname) + err := 
util.ValidateHostname(newHostname) if err != nil { log.Warn(). diff --git a/hscontrol/types/users_test.go b/hscontrol/types/users_test.go index 35c84a26..064388eb 100644 --- a/hscontrol/types/users_test.go +++ b/hscontrol/types/users_test.go @@ -66,6 +66,7 @@ func TestUnmarshallOIDCClaims(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got OIDCClaims + err := json.Unmarshal([]byte(tt.jsonstr), &got) if err != nil { t.Errorf("UnmarshallOIDCClaims() error = %v", err) @@ -483,6 +484,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got OIDCClaims + err := json.Unmarshal([]byte(tt.jsonstr), &got) if err != nil { t.Errorf("TestOIDCClaimsJSONToUser() error = %v", err) diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index 8e777349..c4b9dbd5 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -323,6 +323,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri } lowercased := strings.ToLower(hostinfo.Hostname) + err := ValidateHostname(lowercased) if err == nil { return lowercased diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index bd4cf6a9..c1c62f81 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -109,6 +109,7 @@ func DERPVerify( defer c.Close() var result error + err := c.Connect(t.Context()) if err != nil { result = fmt.Errorf("client Connect: %w", err) From d3bddbf1720ebaf566d75963d6a4dcde52cbc388 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 16:52:18 +0000 Subject: [PATCH 24/30] all: fix exptostd, ineffassign, unconvert, intrange, predeclared issues - Replace exp/maps and exp/slices with stdlib (Go 1.21+) - Fix ineffective assignments by checking errors or using blank identifier - Remove unnecessary type conversions - Use integer range syntax for loops (Go 1.22+) - Rename variables shadowing predeclared identifiers (min, max) --- cmd/hi/stats.go | 16 ++++++++-------- hscontrol/auth_test.go | 2 +- hscontrol/db/db.go | 3 +++ hscontrol/db/node_test.go | 3 +++ hscontrol/policy/v2/types_test.go | 4 ++-- hscontrol/state/state.go | 2 +- integration/auth_web_flow_test.go | 4 ++-- integration/helpers.go | 6 +++--- integration/scenario.go | 2 +- integration/tsic/tsic.go | 4 ++-- 10 files changed, 26 insertions(+), 20 deletions(-) diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index aec28c50..bd81d6da 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -409,25 +409,25 @@ func calculateStatsSummary(values []float64) StatsSummary { return StatsSummary{} } - min := values[0] - max := values[0] + minVal := values[0] + maxVal := values[0] sum := 0.0 for _, value := range values { - if value < min { - min = value + if value < minVal { + minVal = value } - if value > max { - max = value + if value > maxVal { + maxVal = value } sum += value } return StatsSummary{ - Min: min, - Max: max, + Min: minVal, + Max: maxVal, Average: sum / float64(len(values)), } } diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 73048d9e..14cfaad5 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -3113,7 +3113,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { user1Nodes := 0 user2Nodes := 0 - for i := 0; i < allNodesSlice.Len(); i++ { + for i := range allNodesSlice.Len() { n := allNodesSlice.At(i) if n.UserID().Get() == user1.ID { user1Nodes++ diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 
61f192db..f9fd9a20 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -173,6 +173,9 @@ AND auth_key_id NOT IN ( routes = slices.Compact(routes) data, err := json.Marshal(routes) + if err != nil { + return fmt.Errorf("marshaling routes for node %d: %w", nodeID, err) + } err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error if err != nil { diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 3696aa2e..0f16f8a4 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -754,6 +754,9 @@ func TestNodeNaming(t *testing.T) { } _, err = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil) + if err != nil { + return err + } _, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil) return err diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 8164f576..dc95e1f3 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1860,7 +1860,7 @@ func TestResolvePolicy(t *testing.T) { }, { name: "autogroup-member-comprehensive", - toResolve: new(AutoGroup(AutoGroupMember)), + toResolve: new(AutoGroupMember), nodes: types.Nodes{ // Node with no tags (should be included - is a member) { @@ -1910,7 +1910,7 @@ func TestResolvePolicy(t *testing.T) { }, { name: "autogroup-tagged", - toResolve: new(AutoGroup(AutoGroupTagged)), + toResolve: new(AutoGroupTagged), nodes: types.Nodes{ // Node with no tags (should be excluded - not tagged) { diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index c720c271..57e0ebde 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -749,7 +749,7 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, // Check name uniqueness against NodeStore allNodes := s.nodeStore.ListNodes() - for i := 0; i < allNodes.Len(); i++ { + for i := range allNodes.Len() { node := allNodes.At(i) if node.ID() != nodeID && node.AsStruct().GivenName == newName { return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName) diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 6dbf6dfe..8d596241 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -245,7 +245,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { allClients, err := scenario.ListTailscaleClients() requireNoErrListClients(t, err) - allIps, err := scenario.ListTailscaleClientsIPs() + _, err = scenario.ListTailscaleClientsIPs() requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() @@ -367,7 +367,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { } // Test connectivity after user switch - allIps, err = scenario.ListTailscaleClientsIPs() + allIps, err := scenario.ListTailscaleClientsIPs() requireNoErrListClientIPs(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { diff --git a/integration/helpers.go b/integration/helpers.go index b1701f8f..86986bc8 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -23,8 +23,8 @@ import ( "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" + "maps" + "slices" "tailscale.com/tailcfg" ) @@ -332,7 +332,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer var failureReport strings.Builder - ids := 
types.NodeIDs(maps.Keys(nodeStatus)) + ids := slices.Collect(maps.Keys(nodeStatus)) slices.Sort(ids) for _, nodeID := range ids { diff --git a/integration/scenario.go b/integration/scenario.go index 9ae0c4fc..bf3f4096 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -1158,7 +1158,7 @@ var errParseAuthPage = errors.New("failed to parse auth page") func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { // see api.go HTML template - codeSep := strings.Split(string(body), "") + codeSep := strings.Split(body, "") if len(codeSep) != expectedHTMLSplitParts { return errParseAuthPage } diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index b270fab3..d7ff1714 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -731,12 +731,12 @@ func (t *TailscaleInContainer) LoginWithURL( // Logout runs the logout routine on the given Tailscale instance. func (t *TailscaleInContainer) Logout() error { - stdout, stderr, err := t.Execute([]string{"tailscale", "logout"}) + _, _, err := t.Execute([]string{"tailscale", "logout"}) if err != nil { return err } - stdout, stderr, _ = t.Execute([]string{"tailscale", "status"}) + stdout, stderr, _ := t.Execute([]string{"tailscale", "status"}) if !strings.Contains(stdout+stderr, "Logged out.") { return fmt.Errorf("%w: stdout: %s, stderr: %s", errLogoutFailed, stdout, stderr) } From 00321fc282f33a53b62138fa388a2da36c457961 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 16:56:07 +0000 Subject: [PATCH 25/30] all: fix testifylint, thelper, usetesting, wsl_v5 issues - Use require instead of assert for error assertions - Add t.Helper() to test helper functions - Use t.TempDir() instead of os.MkdirTemp() - Replace useless assert.False(c, true, ...) 
with assert.Fail() - Add whitespace between statements per wsl_v5 rules --- hscontrol/auth_test.go | 2 +- hscontrol/db/node_test.go | 1 + hscontrol/policy/policy_autoapprove_test.go | 5 +++-- hscontrol/policy/v2/types.go | 2 ++ hscontrol/types/config_test.go | 9 +++------ hscontrol/util/util_test.go | 6 ++++++ integration/helpers.go | 2 +- 7 files changed, 17 insertions(+), 10 deletions(-) diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 14cfaad5..e6c46d73 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -3363,7 +3363,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) { } _, err = app.handleRegister(context.Background(), req, machineKey.Public()) - assert.Error(t, err, "expired pre-auth key should be rejected") + require.Error(t, err, "expired pre-auth key should be rejected") assert.Contains(t, err.Error(), "authkey expired", "error should mention key expiration") } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 0f16f8a4..a78eb84f 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -757,6 +757,7 @@ func TestNodeNaming(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil) return err diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go index 21c2a66e..68266645 100644 --- a/hscontrol/policy/policy_autoapprove_test.go +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/types/key" "tailscale.com/types/views" @@ -75,7 +76,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { }` pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()})) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -319,7 +320,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { if tt.name != "nil_policy_manager" { pm, err = pmf(users, nodes.ViewSlice()) - assert.NoError(t, err) + require.NoError(t, err) } else { pm = nil } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 5443dfdb..6906bd22 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1770,6 +1770,7 @@ func (p *Policy) validate() error { } case *Group: g := src + err := p.Groups.Contains(g) if err != nil { errs = append(errs, err) @@ -1807,6 +1808,7 @@ func (p *Policy) validate() error { } case *Group: g := dst.Alias.(*Group) + err := p.Groups.Contains(g) if err != nil { errs = append(errs, err) diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 13a3a418..6ed2ef47 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -349,11 +349,8 @@ func TestReadConfigFromEnv(t *testing.T) { } func TestTLSConfigValidation(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "headscale") - if err != nil { - t.Fatal(err) - } - // defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() + configYaml := []byte(`--- tls_letsencrypt_hostname: example.com tls_letsencrypt_challenge_type: "" @@ -364,7 +361,7 @@ noise: // Populate a custom config file configFilePath := filepath.Join(tmpDir, "config.yaml") - err = os.WriteFile(configFilePath, configYaml, 0o600) + err := 
os.WriteFile(configFilePath, configYaml, 0o600) if err != nil { t.Fatalf("Couldn't write file %s", configFilePath) } diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 22788bff..e3a8c3c5 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1101,6 +1101,8 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { nodeKey: "nkey12345678", wantHostname: "test-node", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + t.Helper() + if hi == nil { t.Fatal("hostinfo should not be nil") } @@ -1147,6 +1149,8 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { nodeKey: "nkey12345678", wantHostname: "node-nkey1234", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + t.Helper() + if hi == nil { t.Fatal("hostinfo should not be nil") } @@ -1163,6 +1167,8 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { nodeKey: "", wantHostname: "unknown-node", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + t.Helper() + if hi == nil { t.Fatal("hostinfo should not be nil") } diff --git a/integration/helpers.go b/integration/helpers.go index 86986bc8..8487bbfa 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -455,7 +455,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec if slices.Contains(expectedNodes, nodeID) { allMapResponsesOffline = false - assert.False(c, true, "Node %d should not appear in map responses", nodeID) + assert.Fail(c, "Node should not appear in map responses", "Node %d should not appear in map responses", nodeID) } } } else { From 3770015faa54fbfb27b8beabf58f40165cb682ff Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 17:08:19 +0000 Subject: [PATCH 26/30] all: fix staticcheck and contextcheck lint issues - Add nolint:staticcheck for SA1019 deprecation warnings on types.Route (kept for GORM migrations only, intentionally deprecated) - Add nolint:staticcheck for SA4006 false positives where variables are used inside new() expressions which staticcheck doesn't recognize - Fix SA5011 potential nil pointer dereferences in util_test.go by using t.Fatal instead of t.Error for nil checks - Add nolint:contextcheck for functions where context propagation would require significant architectural changes (Docker client creation, OIDC initialization, scheduled tasks, etc.) --- cmd/hi/cleanup.go | 6 ++++++ cmd/hi/docker.go | 3 +++ cmd/hi/doctor.go | 7 +++++++ hscontrol/app.go | 2 ++ hscontrol/db/db.go | 5 +++++ hscontrol/db/node_test.go | 4 ++++ hscontrol/db/users_test.go | 1 + hscontrol/noise.go | 1 + hscontrol/oidc.go | 1 + hscontrol/policy/policy_route_approval_test.go | 1 + hscontrol/policy/v2/types_test.go | 7 +++++++ hscontrol/state/state.go | 1 + hscontrol/util/util_test.go | 4 ++-- integration/helpers.go | 1 + 14 files changed, 42 insertions(+), 2 deletions(-) diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go index 8dc57b4b..70480239 100644 --- a/cmd/hi/cleanup.go +++ b/cmd/hi/cleanup.go @@ -55,6 +55,7 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI // killTestContainers terminates and removes all test containers. 
func killTestContainers(ctx context.Context) error { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -109,6 +110,7 @@ func killTestContainers(ctx context.Context) error { // This function filters containers by the hi.run-id label to only affect containers // belonging to the specified test run, leaving other concurrent test runs untouched. func killTestContainersByRunID(ctx context.Context, runID string) error { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -151,6 +153,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error { // This is useful for cleaning up leftover containers from previous crashed or interrupted test runs // without interfering with currently running concurrent tests. func cleanupStaleTestContainers(ctx context.Context) error { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -225,6 +228,7 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container // pruneDockerNetworks removes unused Docker networks. func pruneDockerNetworks(ctx context.Context) error { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -247,6 +251,7 @@ func pruneDockerNetworks(ctx context.Context) error { // cleanOldImages removes test-related and old dangling Docker images. func cleanOldImages(ctx context.Context) error { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -299,6 +304,7 @@ func cleanOldImages(ctx context.Context) error { // cleanCacheVolume removes the Docker volume used for Go module cache. func cleanCacheVolume(ctx context.Context) error { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 05aa7c9d..c9791098 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -39,6 +39,7 @@ const ( // runTestContainer executes integration tests in a Docker container. func runTestContainer(ctx context.Context, config *RunConfig) error { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -110,6 +111,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if config.Stats { var err error + //nolint:contextcheck // NewStatsCollector internal functions don't accept context statsCollector, err = NewStatsCollector() if err != nil { if config.Verbose { @@ -632,6 +634,7 @@ func listControlFiles(logsDir string) { // extractArtifactsFromContainers collects container logs and files from the specific test run. 
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index c30a1ca9..2fae4fbe 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -38,12 +38,15 @@ func runDoctorCheck(ctx context.Context) error { } // Check 3: Go installation + //nolint:contextcheck // These checks don't need context results = append(results, checkGoInstallation()) // Check 4: Git repository + //nolint:contextcheck // These checks don't need context results = append(results, checkGitRepository()) // Check 5: Required files + //nolint:contextcheck // These checks don't need context results = append(results, checkRequiredFiles()) // Display results @@ -86,6 +89,7 @@ func checkDockerBinary() DoctorResult { // checkDockerDaemon verifies Docker daemon is running and accessible. func checkDockerDaemon(ctx context.Context) DoctorResult { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return DoctorResult{ @@ -125,6 +129,7 @@ func checkDockerDaemon(ctx context.Context) DoctorResult { // checkDockerContext verifies Docker context configuration. func checkDockerContext(_ context.Context) DoctorResult { + //nolint:contextcheck // getCurrentDockerContext doesn't accept context contextInfo, err := getCurrentDockerContext() if err != nil { return DoctorResult{ @@ -155,6 +160,7 @@ func checkDockerContext(_ context.Context) DoctorResult { // checkDockerSocket verifies Docker socket accessibility. func checkDockerSocket(ctx context.Context) DoctorResult { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return DoctorResult{ @@ -192,6 +198,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult { // checkGolangImage verifies the golang Docker image is available locally or can be pulled. func checkGolangImage(ctx context.Context) DoctorResult { + //nolint:contextcheck // createDockerClient internal functions don't accept context cli, err := createDockerClient() if err != nil { return DoctorResult{ diff --git a/hscontrol/app.go b/hscontrol/app.go index f7d9ba90..a333c415 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -299,6 +299,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { case <-derpTickerChan: log.Info().Msg("Fetching DERPMap updates") + //nolint:contextcheck // GetDERPMap internal functions don't accept context derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { derpMap, err := derp.GetDERPMap(h.cfg.DERP) if err != nil { @@ -885,6 +886,7 @@ func (h *Headscale) Serve() error { // Close state connections info("closing state and database") + //nolint:contextcheck // Close method signature does not accept context err = h.state.Close() if err != nil { log.Error().Err(err).Msg("failed to close state") diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index f9fd9a20..02794627 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -75,6 +75,7 @@ func NewHeadscaleDatabase( ID: "202501221827", Migrate: func(tx *gorm.DB) error { // Remove any invalid routes associated with a node that does not exist. 
+ //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error if err != nil { @@ -83,6 +84,7 @@ func NewHeadscaleDatabase( } // Remove any invalid routes without a node_id. + //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only if tx.Migrator().HasTable(&types.Route{}) { err := tx.Exec("delete from routes where node_id is null").Error if err != nil { @@ -90,6 +92,7 @@ func NewHeadscaleDatabase( } } + //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only err := tx.AutoMigrate(&types.Route{}) if err != nil { return fmt.Errorf("automigrating types.Route: %w", err) @@ -155,6 +158,7 @@ AND auth_key_id NOT IN ( nodeRoutes := map[uint64][]netip.Prefix{} + //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only var routes []types.Route err = tx.Find(&routes).Error @@ -184,6 +188,7 @@ AND auth_key_id NOT IN ( } // Drop the old table. + //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only _ = tx.Migrator().DropTable(&types.Route{}) return nil diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index a78eb84f..58d36463 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -98,6 +98,7 @@ func TestExpireNode(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) + //nolint:staticcheck // SA4006: pak is used in new(pak.ID) below pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) @@ -142,6 +143,7 @@ func TestSetTags(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) + //nolint:staticcheck // SA4006: pak is used in new(pak.ID) below pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) @@ -637,9 +639,11 @@ func TestListEphemeralNodes(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) + //nolint:staticcheck // SA4006: pak is used in new(pak.ID) below pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) + //nolint:staticcheck // SA4006: pakEph is used in new(pakEph.ID) below pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) require.NoError(t, err) diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index bbb8e4d4..40d301b3 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -70,6 +70,7 @@ func TestDestroyUserErrors(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) + //nolint:staticcheck // SA4006: pak is used in new(pak.ID) below pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 67021d31..b28fa0bb 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -244,6 +244,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( return } + //nolint:contextcheck // IIFE uses context from outer scope implicitly registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { var resp *tailcfg.RegisterResponse diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 6ab70f78..de02b677 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -69,6 +69,7 @@ func NewAuthProviderOIDC( ) (*AuthProviderOIDC, error) { var err error // 
grab oidc config if it hasn't been already + //nolint:contextcheck // Initialization code - no parent context available oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) if err != nil { return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err) diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go index 0e974a1a..fbe6e4bb 100644 --- a/hscontrol/policy/policy_route_approval_test.go +++ b/hscontrol/policy/policy_route_approval_test.go @@ -327,6 +327,7 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { } func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { + //nolint:staticcheck // SA4006: user is used in new(user.ID) and new(user) below user := types.User{ Model: gorm.Model{ID: 1}, Name: "test", diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index dc95e1f3..05f2b085 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1604,11 +1604,18 @@ func TestResolvePolicy(t *testing.T) { } // Extract users to variables so we can take their addresses + // The variables below are all used in new() calls in the test cases. + //nolint:staticcheck // SA4006: testuser is used in new(testuser) below testuser := users["testuser"] + //nolint:staticcheck // SA4006: groupuser is used in new(groupuser) below groupuser := users["groupuser"] + //nolint:staticcheck // SA4006: groupuser1 is used in new(groupuser1) below groupuser1 := users["groupuser1"] + //nolint:staticcheck // SA4006: groupuser2 is used in new(groupuser2) below groupuser2 := users["groupuser2"] + //nolint:staticcheck // SA4006: notme is used in new(notme) below notme := users["notme"] + //nolint:staticcheck // SA4006: testuser2 is used in new(testuser2) below testuser2 := users["testuser2"] tests := []struct { diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 57e0ebde..7ddcb005 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -492,6 +492,7 @@ func (s *State) Connect(id types.NodeID) []change.Change { // Disconnect marks a node as disconnected and updates its primary routes in the state. func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { + //nolint:staticcheck // SA4006: now is used in new(now) below now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index e3a8c3c5..fc55328f 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1210,7 +1210,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { wantHostname: "test", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } if hi.Hostname != "test" { @@ -1244,7 +1244,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { wantHostname: "123456789012345678901234567890123456789012345678901234567890123", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } if len(hi.Hostname) != 63 { diff --git a/integration/helpers.go b/integration/helpers.go index 8487bbfa..6499792d 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -765,6 +765,7 @@ func tagp(name string) policyv2.Alias { // prefixp returns a pointer to a Prefix from a CIDR string for policy v2 configurations. 
// Converts CIDR notation to policy prefix format for network range specifications. func prefixp(cidr string) policyv2.Alias { + //nolint:staticcheck // SA4006: prefix is used in new(policyv2.Prefix(prefix)) below prefix := netip.MustParsePrefix(cidr) return new(policyv2.Prefix(prefix)) } From 667efb2ab19cb7169217f3440134e929b8ed3638 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 17:11:47 +0000 Subject: [PATCH 27/30] all: fix inamedparam and recvcheck lint issues - Add parameter names to interface methods in auth.go, pm.go, types.go, and control.go as required by inamedparam linter - Add nolint:recvcheck directives to types in policy/v2/types.go that intentionally use mixed pointer/value receivers (pointer for UnmarshalJSON, value for read-only methods like String/Validate) --- hscontrol/auth.go | 4 ++-- hscontrol/policy/pm.go | 8 ++++---- hscontrol/policy/v2/types.go | 32 ++++++++++++++++++++++++++++---- integration/control.go | 8 ++++---- 4 files changed, 38 insertions(+), 14 deletions(-) diff --git a/hscontrol/auth.go b/hscontrol/auth.go index c5fa91c2..dc10d1a2 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -19,8 +19,8 @@ import ( ) type AuthProvider interface { - RegisterHandler(http.ResponseWriter, *http.Request) - AuthURL(types.RegistrationID) string + RegisterHandler(w http.ResponseWriter, r *http.Request) + AuthURL(regID types.RegistrationID) string } func (h *Headscale) handleRegister( diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index ee112609..59627dbe 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -19,18 +19,18 @@ type PolicyManager interface { MatchersForNode(node types.NodeView) ([]matcher.Match, error) // BuildPeerMap constructs peer relationship maps for the given nodes BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView - SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error) - SetPolicy([]byte) (bool, error) + SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) + SetPolicy(data []byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) // NodeCanHaveTag reports whether the given node can have the given tag. - NodeCanHaveTag(types.NodeView, string) bool + NodeCanHaveTag(node types.NodeView, tag string) bool // TagExists reports whether the given tag is defined in the policy. TagExists(tag string) bool // NodeCanApproveRoute reports whether the given node can approve the given route. - NodeCanApproveRoute(types.NodeView, netip.Prefix) bool + NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool Version() int DebugString() string diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 6906bd22..3bff4561 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -186,6 +186,8 @@ func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeV } // Username is a string that represents a username, it must contain an @. +// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Username string func (u Username) Validate() error { @@ -296,6 +298,8 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. } // Group is a special string which is always prefixed with `group:`. 
+// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Group string func (g Group) Validate() error { @@ -363,6 +367,8 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod } // Tag is a special string which is always prefixed with `tag:`. +// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Tag string func (t Tag) Validate() error { @@ -415,6 +421,8 @@ func (t Tag) MarshalJSON() ([]byte, error) { } // Host is a string that represents a hostname. +// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Host string func (h Host) Validate() error { @@ -474,6 +482,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView return buildIPSetMultiErr(&ips, errs) } +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Prefix netip.Prefix func (p Prefix) Validate() error { @@ -562,6 +571,8 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild } // AutoGroup is a special string which is always prefixed with `autogroup:`. +// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type AutoGroup string const ( @@ -661,14 +672,14 @@ func (ag *AutoGroup) Is(c AutoGroup) bool { type Alias interface { Validate() error - UnmarshalJSON([]byte) error + UnmarshalJSON(data []byte) error // Resolve resolves the Alias to an IPSet. The IPSet will contain all the IP // addresses that the Alias represents within Headscale. It is the product // of the Alias and the Policy, Users and Nodes. // This is an interface definition and the implementation is independent of // the Alias type. - Resolve(*Policy, types.Users, views.Slice[types.NodeView]) (*netipx.IPSet, error) + Resolve(pol *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) } type AliasWithPorts struct { @@ -793,6 +804,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { return nil } +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Aliases []Alias func (a *Aliases) UnmarshalJSON(b []byte) error { @@ -883,7 +895,7 @@ func unmarshalPointer[T any]( type AutoApprover interface { CanBeAutoApprover() bool - UnmarshalJSON([]byte) error + UnmarshalJSON(data []byte) error String() string } @@ -960,7 +972,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { type Owner interface { CanBeTagOwner() bool - UnmarshalJSON([]byte) error + UnmarshalJSON(data []byte) error String() string } @@ -1038,6 +1050,8 @@ func parseOwner(s string) (Owner, error) { type Usernames []Username // Groups are a map of Group to a list of Username. +// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Groups map[Group]Usernames func (g Groups) Contains(group *Group) error { @@ -1131,6 +1145,8 @@ func (g *Groups) UnmarshalJSON(b []byte) error { } // Hosts are alias for IP addresses or subnets. +// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Hosts map[Host]Prefix func (h *Hosts) UnmarshalJSON(b []byte) error { @@ -1327,6 +1343,8 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } // Action represents the action to take for an ACL rule. 
+// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Action string const ( @@ -1334,6 +1352,8 @@ const ( ) // SSHAction represents the action to take for an SSH rule. +// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type SSHAction string const ( @@ -1390,6 +1410,8 @@ func (a SSHAction) MarshalJSON() ([]byte, error) { } // Protocol represents a network protocol with its IANA number and descriptions. +// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type Protocol string const ( @@ -1990,6 +2012,8 @@ type SSH struct { // SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule. // It can be a list of usernames, groups, tags or autogroups. +// +//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods type SSHSrcAliases []Alias // MarshalJSON marshals the Groups to JSON. diff --git a/integration/control.go b/integration/control.go index 58a061e3..612f0ff3 100644 --- a/integration/control.go +++ b/integration/control.go @@ -15,8 +15,8 @@ import ( type ControlServer interface { Shutdown() (string, string, error) - SaveLog(string) (string, string, error) - SaveProfile(string) error + SaveLog(dir string) (string, string, error) + SaveProfile(dir string) error Execute(command []string) (string, error) WriteFile(path string, content []byte) error ConnectToNetwork(network *dockertest.Network) error @@ -35,12 +35,12 @@ type ControlServer interface { ListUsers() ([]*v1.User, error) MapUsers() (map[string]*v1.User, error) DeleteUser(userID uint64) error - ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error) + ApproveRoutes(nodeID uint64, routes []netip.Prefix) (*v1.Node, error) SetNodeTags(nodeID uint64, tags []string) error GetCert() []byte GetHostname() string GetIPInNetwork(network *dockertest.Network) string - SetPolicy(*policyv2.Policy) error + SetPolicy(pol *policyv2.Policy) error GetAllMapReponses() (map[types.NodeID][]tailcfg.MapResponse, error) PrimaryRoutes() (*routes.DebugRoutes, error) DebugBatcher() (*hscontrol.DebugBatcherInfo, error) From de1f9b90d5ac8ed5db61a8350274bc328ab6e0bb Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Tue, 20 Jan 2026 17:19:49 +0000 Subject: [PATCH 28/30] all: fix godoclint, gosec, testifylint, and thelper lint issues - Fix godoclint: Ensure doc comments start with symbol name - Fix gosec: Add nolint directives for false positives (G101, G110, G115, G306, G404) - Fix testifylint: Use require instead of assert for error checks - Fix thelper: Add t.Helper() to test helper functions - Auto-fix gci: Format import statements --- cmd/headscale/cli/policy.go | 1 + cmd/hi/docker.go | 2 ++ cmd/hi/stats.go | 2 +- hscontrol/db/ephemeral_garbage_collector_test.go | 8 ++++---- hscontrol/db/node.go | 2 +- hscontrol/db/preauth_keys.go | 2 +- hscontrol/derp/derp.go | 2 ++ hscontrol/dns/extrarecords.go | 2 +- hscontrol/handlers.go | 2 +- hscontrol/noise.go | 2 +- hscontrol/policy/v2/policy_test.go | 4 ++-- hscontrol/state/node_store_test.go | 15 ++++++++++----- hscontrol/tailsql.go | 2 +- hscontrol/types/preauth_key.go | 2 +- hscontrol/types/users.go | 8 ++++---- hscontrol/util/util.go | 3 +-- integration/api_auth_test.go | 2 +- integration/cli_test.go | 2 +- integration/helpers.go | 4 ++-- integration/hsic/hsic.go | 6 +++++- integration/integrationutil/util.go | 6 +++--- integration/route_test.go | 1 + integration/tsic/tsic.go | 2 +- 23 files changed, 48 
insertions(+), 34 deletions(-) diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 4cdfe126..26ad8084 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -16,6 +16,7 @@ import ( ) const ( + //nolint:gosec // G101: This is a flag name, not a credential bypassFlag = "bypass-grpc-and-access-database-directly" ) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index c9791098..706e18e2 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -787,11 +787,13 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID, } // Write stdout logs + //nolint:gosec // G306: Log files are meant to be world-readable if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { return fmt.Errorf("failed to write stdout log: %w", err) } // Write stderr logs + //nolint:gosec // G306: Log files are meant to be world-readable if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { return fmt.Errorf("failed to write stderr log: %w", err) } diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index bd81d6da..dc02286b 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -18,7 +18,7 @@ import ( "github.com/docker/docker/client" ) -// Sentinel errors for stats collection. +// ErrStatsCollectionAlreadyStarted is returned when stats collection is already running. var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started") // Stats calculation constants. diff --git a/hscontrol/db/ephemeral_garbage_collector_test.go b/hscontrol/db/ephemeral_garbage_collector_test.go index 2ad50885..a1581c51 100644 --- a/hscontrol/db/ephemeral_garbage_collector_test.go +++ b/hscontrol/db/ephemeral_garbage_collector_test.go @@ -57,7 +57,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { deletionWg.Add(numNodes) for i := 1; i <= numNodes; i++ { - gc.Schedule(types.NodeID(i), expiry) + gc.Schedule(types.NodeID(i), expiry) //nolint:gosec // G115: Test code with controlled values } // Wait for all scheduled deletions to complete @@ -70,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { // Schedule and immediately cancel to test that part of the code for i := numNodes + 1; i <= numNodes*2; i++ { - nodeID := types.NodeID(i) + nodeID := types.NodeID(i) //nolint:gosec // G115: Test code with controlled values gc.Schedule(nodeID, time.Hour) gc.Cancel(nodeID) } @@ -394,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { case <-stopScheduling: return default: - nodeID := types.NodeID(baseNodeID + j + 1) - gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test + nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec // G115: Test code with controlled values + gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test atomic.AddInt64(&scheduledCount, 1) // Yield to other goroutines to introduce variability diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 56408809..3965c855 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -221,7 +221,7 @@ func SetTags( return nil } -// SetTags takes a Node struct pointer and update the forced tags. +// SetApprovedRoutes updates the approved routes for a node. 
func SetApprovedRoutes( tx *gorm.DB, nodeID types.NodeID, diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index c5904353..00c5985f 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -332,7 +332,7 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { return nil } -// MarkExpirePreAuthKey marks a PreAuthKey as expired. +// ExpirePreAuthKey marks a PreAuthKey as expired. func ExpirePreAuthKey(tx *gorm.DB, id uint64) error { now := time.Now() return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index f3807e21..3d4f64ee 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -161,7 +161,9 @@ func derpRandom() *rand.Rand { derpRandomOnce.Do(func() { seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String()) + //nolint:gosec // G404,G115: Intentionally using math/rand for deterministic DERP server ID rnd := rand.New(rand.NewSource(0)) + //nolint:gosec // G115: Checksum is always positive and fits in int64 rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) derpRandomInst = rnd }) diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index 7cd88abe..f119def9 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -16,7 +16,7 @@ import ( "tailscale.com/util/set" ) -// Sentinel errors for extra records. +// ErrPathIsDirectory is returned when a path is a directory instead of a file. var ErrPathIsDirectory = errors.New("path is a directory, only file is supported") type ExtraRecordsMan struct { diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index f1a2c88c..a904c533 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -55,7 +55,7 @@ type HTTPError struct { func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) } func (e HTTPError) Unwrap() error { return e.Err } -// Error returns an HTTPError containing the given information. +// NewHTTPError returns an HTTPError containing the given information. func NewHTTPError(code int, msg string, err error) HTTPError { return HTTPError{Code: code, Msg: msg, Err: err} } diff --git a/hscontrol/noise.go b/hscontrol/noise.go index b28fa0bb..bc097519 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -31,7 +31,7 @@ const ( earlyPayloadMagic = "\xff\xff\xffTS" ) -// Sentinel errors for noise server. +// ErrUnsupportedClientVersion is returned when a client version is not supported. 
var ErrUnsupportedClientVersion = errors.New("unsupported client version") type noiseServer struct { diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index dc5969b5..371bba5e 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -94,7 +94,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { } for i, n := range initialNodes { - n.ID = types.NodeID(i + 1) + n.ID = types.NodeID(i + 1) //nolint:gosec // G115: Test code with small values } pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice()) @@ -187,7 +187,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { } if !found { - n.ID = types.NodeID(len(initialNodes) + i + 1) + n.ID = types.NodeID(len(initialNodes) + i + 1) //nolint:gosec // G115: Test code with small values } } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 9740d063..2ce2aea8 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -40,6 +40,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, peersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Empty(t, snapshot.nodesByID) assert.Empty(t, snapshot.allNodes) assert.Empty(t, snapshot.peersByNode) @@ -56,6 +57,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Len(t, snapshot.nodesByID, 1) assert.Len(t, snapshot.allNodes, 1) assert.Len(t, snapshot.peersByNode, 1) @@ -79,6 +81,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Len(t, snapshot.nodesByID, 2) assert.Len(t, snapshot.allNodes, 2) assert.Len(t, snapshot.peersByNode, 2) @@ -915,6 +918,7 @@ func TestNodeStoreConcurrentPutNode(t *testing.T) { go func(nodeID int) { defer wg.Done() + //nolint:gosec // G115: Test code with controlled values node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") resultNode := store.PutNode(node) @@ -954,6 +958,7 @@ func TestNodeStoreBatchingEfficiency(t *testing.T) { go func(nodeID int) { defer wg.Done() + //nolint:gosec // G115: Test code with controlled values node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") resultNode := store.PutNode(node) @@ -1062,7 +1067,7 @@ func TestNodeStoreResourceCleanup(t *testing.T) { const ops = 100 for i := range ops { - nodeID := types.NodeID(i + 1) + nodeID := types.NodeID(i + 1) //nolint:gosec // G115: Test code with controlled values node := createConcurrentTestNode(nodeID, "cleanup-node") resultNode := store.PutNode(node) assert.True(t, resultNode.Valid()) @@ -1106,7 +1111,7 @@ func TestNodeStoreOperationTimeout(t *testing.T) { // Launch all PutNode operations concurrently for i := 1; i <= ops; i++ { - nodeID := types.NodeID(i) + nodeID := types.NodeID(i) //nolint:gosec // G115: Test code with controlled values wg.Add(1) @@ -1132,7 +1137,7 @@ func TestNodeStoreOperationTimeout(t *testing.T) { wg = sync.WaitGroup{} for i := 1; i <= ops; i++ { - nodeID := types.NodeID(i) + nodeID := types.NodeID(i) //nolint:gosec // G115: Test code with controlled values wg.Add(1) @@ -1197,7 +1202,7 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) store.Start() - nonExistentID 
:= types.NodeID(999 + i) + nonExistentID := types.NodeID(999 + i) //nolint:gosec // G115: Test code with controlled values updateCallCount := 0 fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID) @@ -1221,7 +1226,7 @@ func BenchmarkNodeStoreAllocations(b *testing.B) { defer store.Stop() for i := 0; b.Loop(); i++ { - nodeID := types.NodeID(i + 1) + nodeID := types.NodeID(i + 1) //nolint:gosec // G115: Benchmark code with controlled values node := createConcurrentTestNode(nodeID, "bench-node") store.PutNode(node) store.UpdateNode(nodeID, func(n *types.Node) { diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go index 82cf9d58..60292912 100644 --- a/hscontrol/tailsql.go +++ b/hscontrol/tailsql.go @@ -13,7 +13,7 @@ import ( "tailscale.com/types/logger" ) -// Sentinel errors for tailsql service. +// ErrNoCertDomains is returned when no cert domains are available for HTTPS. var ErrNoCertDomains = errors.New("no cert domains available for HTTPS") func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath string) error { diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 3b3e59e2..18956d7a 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -114,7 +114,7 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey { return &protoKey } -// canUsePreAuthKey checks if a pre auth key can be used. +// Validate checks if a pre auth key can be used. func (pak *PreAuthKey) Validate() error { if pak == nil { return PAKError("invalid authkey") diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index c724c909..2bf30a0c 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -19,7 +19,7 @@ import ( "tailscale.com/tailcfg" ) -// Sentinel errors for user types. +// ErrCannotParseBool is returned when a value cannot be parsed as a boolean. var ErrCannotParseBool = errors.New("could not parse value as boolean") type UserID uint64 @@ -155,7 +155,7 @@ func (u UserView) ID() uint { func (u *User) TailscaleLogin() tailcfg.Login { return tailcfg.Login{ - ID: tailcfg.LoginID(u.ID), + ID: tailcfg.LoginID(u.ID), //nolint:gosec // G115: User IDs are always positive and fit in int64 Provider: u.Provider, LoginName: u.Username(), DisplayName: u.Display(), @@ -201,8 +201,8 @@ func (u *User) Proto() *v1.User { } } -// JumpCloud returns a JSON where email_verified is returned as a -// string "true" or "false" instead of a boolean. +// FlexibleBoolean handles JSON where email_verified is returned as a +// string "true" or "false" instead of a boolean (e.g., JumpCloud). // This maps bool to a specific type with a custom unmarshaler to // ensure we can decode it from a string. // https://github.com/juanfont/headscale/issues/2293 diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index c4b9dbd5..5c9585be 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -294,11 +294,10 @@ func IsCI() bool { return false } -// SafeHostname extracts a hostname from Hostinfo, providing sensible defaults +// EnsureHostname extracts a hostname from Hostinfo, providing sensible defaults // if Hostinfo is nil or Hostname is empty. This prevents nil pointer dereferences // and ensures nodes always have a valid hostname. // The hostname is truncated to 63 characters to comply with DNS label length limits (RFC 1123). -// EnsureHostname guarantees a valid hostname for node registration. // This function never fails - it always returns a valid hostname. 
// // Strategy: diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index 989cb9d4..0fbec32f 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -404,7 +404,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { var response v1.ListUsersResponse err = protojson.Unmarshal([]byte(responseBody.String()), &response) - assert.NoError(t, err, "Response should be valid protobuf JSON") + require.NoError(t, err, "Response should be valid protobuf JSON") users := response.GetUsers() assert.Len(t, users, 2, "Should have 2 users") diff --git a/integration/cli_test.go b/integration/cli_test.go index 65d82444..1ca23f40 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1467,7 +1467,7 @@ func TestNodeRenameCommand(t *testing.T) { } nodes := make([]*v1.Node, len(regIDs)) - assert.NoError(t, err) + require.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( diff --git a/integration/helpers.go b/integration/helpers.go index 6499792d..df89b1ea 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -6,7 +6,9 @@ import ( "errors" "fmt" "io" + "maps" "net/netip" + "slices" "strconv" "strings" "testing" @@ -23,8 +25,6 @@ import ( "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "maps" - "slices" "tailscale.com/tailcfg" ) diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 3eef2d97..d4dbb85b 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -202,7 +202,7 @@ func WithPostgres() Option { } } -// WithPolicy sets the policy mode for headscale. +// WithPolicyMode sets the policy mode for headscale. func WithPolicyMode(mode types.PolicyMode) Option { return func(hsic *HeadscaleInContainer) { hsic.policyMode = mode @@ -781,6 +781,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { switch header.Typeflag { case tar.TypeDir: // Create directory + //nolint:gosec // G115: tar.Header.Mode is int64, safe to convert to uint32 for permissions err := os.MkdirAll(targetPath, os.FileMode(header.Mode)) if err != nil { return fmt.Errorf("failed to create directory %s: %w", targetPath, err) @@ -797,6 +798,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { return fmt.Errorf("failed to create file %s: %w", targetPath, err) } + //nolint:gosec // G110: Trusted tar archive from our own container if _, err := io.Copy(outFile, tarReader); err != nil { outFile.Close() return fmt.Errorf("failed to copy file contents: %w", err) @@ -805,6 +807,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { outFile.Close() // Set file permissions + //nolint:gosec // G115: tar.Header.Mode is int64, safe to convert to uint32 for permissions if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil { return fmt.Errorf("failed to set file permissions: %w", err) } @@ -903,6 +906,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { return fmt.Errorf("failed to create database file: %w", err) } + //nolint:gosec // G110: Trusted tar archive from our own container written, err := io.Copy(outFile, tarReader) outFile.Close() diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index 3e257a8e..71dd8897 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -220,19 +220,19 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type for _, mr := range mrs { for _, 
peer := range mr.Peers { if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online + res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec // G115: tailcfg.NodeID is int64, safe for test code } } for _, peer := range mr.PeersChanged { if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online + res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec // G115: tailcfg.NodeID is int64, safe for test code } } for _, peer := range mr.PeersChangedPatch { if peer.Online != nil { - res[nid][types.NodeID(peer.NodeID)] = *peer.Online + res[nid][types.NodeID(peer.NodeID)] = *peer.Online //nolint:gosec // G115: tailcfg.NodeID is int64, safe for test code } } } diff --git a/integration/route_test.go b/integration/route_test.go index b6fc8d85..3d24da99 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -2887,6 +2887,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } // assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT. +//nolint:testifylint // CollectT requires assert, not require func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) { assert.NotNil(c, tr) assert.True(c, tr.Success) diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index d7ff1714..9b103e53 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -639,7 +639,7 @@ func (t *TailscaleInContainer) Execute( return stdout, stderr, nil } -// Retrieve container logs. +// Logs retrieves container logs. func (t *TailscaleInContainer) Logs(stdout, stderr io.Writer) error { return dockertestutil.WriteLog( t.pool, From b36438bf90048df073e4879c81405b39c64f1824 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 21 Jan 2026 10:38:15 +0000 Subject: [PATCH 29/30] all: disable thelper linter and clean up nolint comments Disable the thelper linter which triggers on inline test closures in table-driven tests. These closures are intentionally not standalone helpers and don't benefit from t.Helper(). Also remove explanatory comments from nolint directives throughout the codebase as they add noise without providing significant value. 
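For example, in cmd/hi/cleanup.go the explanation after the directive is dropped while the directive itself stays:

Before:

    //nolint:contextcheck // createDockerClient internal functions don't accept context
    cli, err := createDockerClient()

After:

    //nolint:contextcheck
    cli, err := createDockerClient()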
--- .golangci.yaml | 1 + cmd/headscale/cli/policy.go | 3 +- cmd/headscale/cli/users.go | 1 + cmd/hi/cleanup.go | 13 +- cmd/hi/docker.go | 19 ++- cmd/hi/doctor.go | 14 +-- cmd/hi/run.go | 2 + hscontrol/app.go | 5 +- hscontrol/auth.go | 2 +- hscontrol/auth_test.go | 111 ++++++++++-------- hscontrol/db/api_key.go | 2 +- hscontrol/db/db.go | 14 ++- .../db/ephemeral_garbage_collector_test.go | 6 +- hscontrol/db/node.go | 3 +- hscontrol/db/node_test.go | 30 ++--- hscontrol/db/sqliteconfig/integration_test.go | 3 + hscontrol/db/text_serialiser.go | 3 +- hscontrol/db/users.go | 1 + hscontrol/db/users_test.go | 2 +- hscontrol/derp/derp.go | 4 +- hscontrol/derp/server/derp_server.go | 1 + hscontrol/dns/extrarecords.go | 1 + hscontrol/handlers.go | 2 + hscontrol/mapper/batcher.go | 4 +- hscontrol/mapper/batcher_test.go | 12 +- hscontrol/mapper/builder.go | 1 + hscontrol/mapper/builder_test.go | 2 +- hscontrol/mapper/mapper.go | 6 +- hscontrol/noise.go | 3 +- hscontrol/oidc.go | 4 +- hscontrol/policy/pm.go | 1 + .../policy/policy_route_approval_test.go | 2 +- hscontrol/policy/policy_test.go | 1 + hscontrol/policy/v2/filter.go | 5 +- hscontrol/policy/v2/policy_test.go | 21 ++-- hscontrol/policy/v2/types.go | 50 +++++--- hscontrol/policy/v2/types_test.go | 12 +- hscontrol/poll.go | 10 +- hscontrol/state/node_store.go | 4 +- hscontrol/state/node_store_test.go | 17 +-- hscontrol/state/state.go | 8 +- hscontrol/templates/design.go | 102 ++++++++-------- hscontrol/types/config.go | 1 + hscontrol/types/node.go | 8 +- hscontrol/types/node_test.go | 8 +- hscontrol/types/users.go | 4 +- hscontrol/types/version.go | 4 +- hscontrol/util/util.go | 4 +- hscontrol/util/util_test.go | 11 +- integration/acl_test.go | 3 +- integration/api_auth_test.go | 14 +-- integration/auth_key_test.go | 4 +- integration/auth_oidc_test.go | 2 +- integration/cli_test.go | 33 +++--- integration/derp_verify_endpoint_test.go | 1 + integration/dns_test.go | 1 + integration/dockertestutil/execute.go | 2 + integration/embedded_derp_test.go | 4 +- integration/helpers.go | 11 +- integration/hsic/hsic.go | 14 ++- integration/integrationutil/util.go | 6 +- integration/route_test.go | 9 +- integration/scenario.go | 5 + integration/ssh_test.go | 2 +- integration/tags_test.go | 12 +- integration/tsic/tsic.go | 2 + 66 files changed, 397 insertions(+), 276 deletions(-) diff --git a/.golangci.yaml b/.golangci.yaml index eda3bed4..a8a219d7 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -25,6 +25,7 @@ linters: - revive - tagliatelle - testpackage + - thelper - varnamelen - wrapcheck - wsl diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 26ad8084..f31d573a 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -16,7 +16,7 @@ import ( ) const ( - //nolint:gosec // G101: This is a flag name, not a credential + //nolint:gosec bypassFlag = "bypass-grpc-and-access-database-directly" ) @@ -178,6 +178,7 @@ var setPolicy = &cobra.Command{ defer cancel() defer conn.Close() + //nolint:noinlineerr if _, err := client.SetPolicy(ctx, request); err != nil { ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) } diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 086a82b6..f7db7ed4 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -100,6 +100,7 @@ var createUserCmd = &cobra.Command{ } if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" { + //nolint:noinlineerr if _, err := url.Parse(pictureURL); err != nil { 
ErrorOutput( err, diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go index 70480239..d56bb589 100644 --- a/cmd/hi/cleanup.go +++ b/cmd/hi/cleanup.go @@ -25,6 +25,7 @@ func cleanupBeforeTest(ctx context.Context) error { return fmt.Errorf("failed to clean stale test containers: %w", err) } + //nolint:noinlineerr if err := pruneDockerNetworks(ctx); err != nil { return fmt.Errorf("failed to prune networks: %w", err) } @@ -55,7 +56,7 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI // killTestContainers terminates and removes all test containers. func killTestContainers(ctx context.Context) error { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -110,7 +111,7 @@ func killTestContainers(ctx context.Context) error { // This function filters containers by the hi.run-id label to only affect containers // belonging to the specified test run, leaving other concurrent test runs untouched. func killTestContainersByRunID(ctx context.Context, runID string) error { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -153,7 +154,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error { // This is useful for cleaning up leftover containers from previous crashed or interrupted test runs // without interfering with currently running concurrent tests. func cleanupStaleTestContainers(ctx context.Context) error { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -228,7 +229,7 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container // pruneDockerNetworks removes unused Docker networks. func pruneDockerNetworks(ctx context.Context) error { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -251,7 +252,7 @@ func pruneDockerNetworks(ctx context.Context) error { // cleanOldImages removes test-related and old dangling Docker images. func cleanOldImages(ctx context.Context) error { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -304,7 +305,7 @@ func cleanOldImages(ctx context.Context) error { // cleanCacheVolume removes the Docker volume used for Go module cache. func cleanCacheVolume(ctx context.Context) error { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index 706e18e2..62b07f2f 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -38,8 +38,10 @@ const ( ) // runTestContainer executes integration tests in a Docker container. 
+// +//nolint:gocyclo func runTestContainer(ctx context.Context, config *RunConfig) error { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -62,6 +64,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { } const dirPerm = 0o755 + //nolint:noinlineerr if err := os.MkdirAll(absLogsDir, dirPerm); err != nil { return fmt.Errorf("failed to create logs directory: %w", err) } @@ -83,6 +86,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { } imageName := "golang:" + config.GoVersion + //nolint:noinlineerr if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil { return fmt.Errorf("failed to ensure image availability: %w", err) } @@ -96,6 +100,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { log.Printf("Created container: %s", resp.ID) } + //nolint:noinlineerr if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { return fmt.Errorf("failed to start container: %w", err) } @@ -111,7 +116,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if config.Stats { var err error - //nolint:contextcheck // NewStatsCollector internal functions don't accept context + //nolint:contextcheck statsCollector, err = NewStatsCollector() if err != nil { if config.Verbose { @@ -145,6 +150,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { } // Extract artifacts from test containers before cleanup + //nolint:noinlineerr if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose { log.Printf("Warning: failed to extract artifacts from containers: %v", err) } @@ -424,6 +430,7 @@ func isContainerFinalized(state *container.State) bool { func findProjectRoot(startPath string) string { current := startPath for { + //nolint:noinlineerr if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { return current } @@ -496,6 +503,7 @@ func getCurrentDockerContext() (*DockerContext, error) { } var contexts []DockerContext + //nolint:noinlineerr if err := json.Unmarshal(output, &contexts); err != nil { return nil, fmt.Errorf("failed to parse docker context: %w", err) } @@ -634,7 +642,7 @@ func listControlFiles(logsDir string) { // extractArtifactsFromContainers collects container logs and files from the specific test run. 
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -740,6 +748,7 @@ func extractContainerArtifacts(ctx context.Context, cli *client.Client, containe } // Extract container logs + //nolint:noinlineerr if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil { return fmt.Errorf("failed to extract logs: %w", err) } @@ -787,13 +796,13 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID, } // Write stdout logs - //nolint:gosec // G306: Log files are meant to be world-readable + //nolint:gosec,mnd,noinlineerr if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { return fmt.Errorf("failed to write stdout log: %w", err) } // Write stderr logs - //nolint:gosec // G306: Log files are meant to be world-readable + //nolint:gosec,mnd,noinlineerr if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { return fmt.Errorf("failed to write stderr log: %w", err) } diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index 2fae4fbe..0c3a4764 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -38,15 +38,15 @@ func runDoctorCheck(ctx context.Context) error { } // Check 3: Go installation - //nolint:contextcheck // These checks don't need context + //nolint:contextcheck results = append(results, checkGoInstallation()) // Check 4: Git repository - //nolint:contextcheck // These checks don't need context + //nolint:contextcheck results = append(results, checkGitRepository()) // Check 5: Required files - //nolint:contextcheck // These checks don't need context + //nolint:contextcheck results = append(results, checkRequiredFiles()) // Display results @@ -89,7 +89,7 @@ func checkDockerBinary() DoctorResult { // checkDockerDaemon verifies Docker daemon is running and accessible. func checkDockerDaemon(ctx context.Context) DoctorResult { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return DoctorResult{ @@ -129,7 +129,7 @@ func checkDockerDaemon(ctx context.Context) DoctorResult { // checkDockerContext verifies Docker context configuration. func checkDockerContext(_ context.Context) DoctorResult { - //nolint:contextcheck // getCurrentDockerContext doesn't accept context + //nolint:contextcheck contextInfo, err := getCurrentDockerContext() if err != nil { return DoctorResult{ @@ -160,7 +160,7 @@ func checkDockerContext(_ context.Context) DoctorResult { // checkDockerSocket verifies Docker socket accessibility. func checkDockerSocket(ctx context.Context) DoctorResult { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return DoctorResult{ @@ -198,7 +198,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult { // checkGolangImage verifies the golang Docker image is available locally or can be pulled. 
func checkGolangImage(ctx context.Context) DoctorResult { - //nolint:contextcheck // createDockerClient internal functions don't accept context + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return DoctorResult{ diff --git a/cmd/hi/run.go b/cmd/hi/run.go index 881be20f..4a0506e5 100644 --- a/cmd/hi/run.go +++ b/cmd/hi/run.go @@ -68,8 +68,10 @@ func runIntegrationTest(env *command.Env) error { func detectGoVersion() string { goModPath := filepath.Join("..", "..", "go.mod") + //nolint:noinlineerr if _, err := os.Stat("go.mod"); err == nil { goModPath = "go.mod" + //nolint:noinlineerr } else if _, err := os.Stat("../../go.mod"); err == nil { goModPath = "../../go.mod" } diff --git a/hscontrol/app.go b/hscontrol/app.go index a333c415..cadcd227 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -299,7 +299,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { case <-derpTickerChan: log.Info().Msg("Fetching DERPMap updates") - //nolint:contextcheck // GetDERPMap internal functions don't accept context + //nolint:contextcheck derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { derpMap, err := derp.GetDERPMap(h.cfg.DERP) if err != nil { @@ -407,6 +407,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler writeUnauthorized := func(statusCode int) { writer.WriteHeader(statusCode) + //nolint:noinlineerr if _, err := writer.Write([]byte("Unauthorized")); err != nil { log.Error().Err(err).Msg("writing HTTP response failed") } @@ -886,7 +887,7 @@ func (h *Headscale) Serve() error { // Close state connections info("closing state and database") - //nolint:contextcheck // Close method signature does not accept context + //nolint:contextcheck err = h.state.Close() if err != nil { log.Error().Err(err).Msg("failed to close state") diff --git a/hscontrol/auth.go b/hscontrol/auth.go index dc10d1a2..1d49f5b4 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -356,7 +356,7 @@ func (h *Headscale) handleRegisterWithAuthKey( // If node is not valid, it means an ephemeral node was deleted during logout if !node.Valid() { h.Change(changed) - return nil, nil + return nil, nil //nolint:nilnil } // This is a bit of a back and forth, but we have a bit of a chicken and egg diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index e6c46d73..bc6c7cc2 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -38,6 +38,7 @@ type interactiveStep struct { callAuthPath bool // Real call to HandleNodeFromAuthPath, not mocked } +//nolint:gocyclo func TestAuthenticationFlows(t *testing.T) { // Shared test keys for consistent behavior across test cases machineKey1 := key.NewMachine() @@ -76,6 +77,8 @@ func TestAuthenticationFlows(t *testing.T) { { name: "preauth_key_valid_new_node", setupFunc: func(t *testing.T, app *Headscale) (string, error) { + t.Helper() + user := app.state.CreateUserForTest("preauth-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -97,7 +100,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -119,6 +122,8 @@ func TestAuthenticationFlows(t *testing.T) { { name: "preauth_key_reusable_multiple_nodes", setupFunc: func(t *testing.T, app *Headscale) (string, error) { + t.Helper() + 
user := app.state.CreateUserForTest("reusable-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -163,7 +168,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, + machineKey: machineKey2.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -232,7 +237,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, + machineKey: machineKey2.Public, wantError: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // First node should exist, second should not @@ -266,7 +271,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -299,7 +304,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -336,7 +341,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -365,7 +370,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -433,7 +438,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), // Past expiry = logout } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, wantExpired: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { @@ -488,7 +493,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, // Different machine key + machineKey: machineKey2.Public, // Different machine key wantError: true, }, // TEST: Existing node cannot extend expiry without re-auth @@ -538,7 +543,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(48 * time.Hour), // Future time = extend attempt } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Expired node must re-authenticate @@ -601,7 +606,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), // Future expiry } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantExpired: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.NodeKeyExpired) @@ -655,7 +660,7 @@ 
func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), // Logout } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantExpired: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.NodeKeyExpired) @@ -711,7 +716,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -749,7 +754,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Invalid followup URL is rejected @@ -768,7 +773,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Non-existent registration ID is rejected @@ -787,7 +792,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -823,7 +828,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -861,7 +866,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -908,7 +913,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -942,7 +947,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1094,7 +1099,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(48 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1158,7 +1163,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(48 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuthURL: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.Contains(t, resp.AuthURL, "register/") @@ -1227,7 +1232,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, 
- machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1277,7 +1282,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Time{}, // Zero time } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1324,7 +1329,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1380,7 +1385,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: false, // Should not be authorized yet - needs to use new AuthURL validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // Should get a new AuthURL, not an error @@ -1405,7 +1410,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Wrong followup path format is rejected @@ -1424,7 +1429,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -1455,7 +1460,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -1509,7 +1514,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1548,7 +1553,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -1591,7 +1596,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1635,7 +1640,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(12 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, 
resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1707,7 +1712,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1781,7 +1786,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, // Same machine key + machineKey: machineKey1.Public, // Same machine key requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -1835,7 +1840,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: false, // Should not be authorized yet - needs to use new AuthURL validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // Should get a new AuthURL, not an error @@ -1845,13 +1850,13 @@ func TestAuthenticationFlows(t *testing.T) { // Verify the response contains a valid registration URL authURL, err := url.Parse(resp.AuthURL) - assert.NoError(t, err, "AuthURL should be a valid URL") + require.NoError(t, err, "AuthURL should be a valid URL") assert.True(t, strings.HasPrefix(authURL.Path, "/register/"), "AuthURL path should start with /register/") // Extract and validate the new registration ID exists in cache newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") newRegID, err := types.RegistrationIDFromString(newRegIDStr) - assert.NoError(t, err, "should be able to parse new registration ID") + require.NoError(t, err, "should be able to parse new registration ID") // Verify new registration entry exists in cache _, found := app.state.GetRegistrationCacheEntry(newRegID) @@ -1905,7 +1910,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now(), // Exactly now (edge case between past and future) } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, wantExpired: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { @@ -1937,7 +1942,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, + machineKey: machineKey2.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -2003,7 +2008,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -2049,7 +2054,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: 
true}, @@ -2103,7 +2108,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // This test validates concurrent interactive registration attempts assert.Contains(t, resp.AuthURL, "/register/") @@ -2211,7 +2216,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -2255,7 +2260,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -2293,7 +2298,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // Get initial AuthURL and extract registration ID authURL := resp.AuthURL @@ -2315,7 +2320,7 @@ func TestAuthenticationFlows(t *testing.T) { nil, "error-test-method", ) - assert.Error(t, err, "should fail with invalid user ID") + require.Error(t, err, "should fail with invalid user ID") // Cache entry should still exist after auth error (for retry scenarios) _, stillFound := app.state.GetRegistrationCacheEntry(registrationID) @@ -2347,7 +2352,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // Test multiple interactive registration attempts for the same node can coexist authURL1 := resp.AuthURL @@ -2407,7 +2412,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { authURL1 := resp.AuthURL regID1, err := extractRegistrationIDFromAuthURL(authURL1) @@ -2594,6 +2599,8 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { validateCompleteResponse bool }, app *Headscale, dynamicValue string, ) { + t.Helper() + // Build initial request req := tt.request(dynamicValue) machineKey := tt.machineKey() @@ -2731,7 +2738,9 @@ func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, err } // validateCompleteRegistrationResponse performs comprehensive validation of a registration response. 
-func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) { +func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, _ tailcfg.RegisterRequest) { + t.Helper() + // Basic response validation require.NotNil(t, resp, "response should not be nil") require.True(t, resp.MachineAuthorized, "machine should be authorized") @@ -3260,7 +3269,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { restartResp, err := app.handleRegister(context.Background(), restartReq, machineKey.Public()) // This is the assertion that currently FAILS in v0.27.0 - assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed") + require.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed") if err != nil { t.Logf("Error received (this is the bug): %v", err) diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index 7457670c..d179dca9 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -13,7 +13,7 @@ import ( ) const ( - apiKeyPrefix = "hskey-api-" //nolint:gosec // This is a prefix, not a credential + apiKeyPrefix = "hskey-api-" //nolint:gosec apiKeyPrefixLength = 12 apiKeyHashLength = 64 diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 02794627..b876ee84 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -75,7 +75,7 @@ func NewHeadscaleDatabase( ID: "202501221827", Migrate: func(tx *gorm.DB) error { // Remove any invalid routes associated with a node that does not exist. - //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only + //nolint:staticcheck if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error if err != nil { @@ -84,7 +84,7 @@ func NewHeadscaleDatabase( } // Remove any invalid routes without a node_id. - //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only + //nolint:staticcheck if tx.Migrator().HasTable(&types.Route{}) { err := tx.Exec("delete from routes where node_id is null").Error if err != nil { @@ -92,7 +92,7 @@ func NewHeadscaleDatabase( } } - //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only + //nolint:staticcheck err := tx.AutoMigrate(&types.Route{}) if err != nil { return fmt.Errorf("automigrating types.Route: %w", err) @@ -158,7 +158,7 @@ AND auth_key_id NOT IN ( nodeRoutes := map[uint64][]netip.Prefix{} - //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only + //nolint:staticcheck var routes []types.Route err = tx.Find(&routes).Error @@ -188,7 +188,7 @@ AND auth_key_id NOT IN ( } // Drop the old table. 
- //nolint:staticcheck // SA1019: types.Route kept for GORM migrations only + //nolint:staticcheck _ = tx.Migrator().DropTable(&types.Route{}) return nil @@ -798,6 +798,7 @@ AND auth_key_id NOT IN ( }, } + //nolint:noinlineerr if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { return nil, fmt.Errorf("validating schema: %w", err) } @@ -932,6 +933,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig // Get the current foreign key status var fkOriginallyEnabled int + //nolint:noinlineerr if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { return fmt.Errorf("checking foreign key status: %w", err) } @@ -980,11 +982,13 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig } } + //nolint:noinlineerr if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { return fmt.Errorf("restoring foreign keys: %w", err) } // Run the rest of the migrations + //nolint:noinlineerr if err := migrations.Migrate(); err != nil { return err } diff --git a/hscontrol/db/ephemeral_garbage_collector_test.go b/hscontrol/db/ephemeral_garbage_collector_test.go index a1581c51..8e8a1109 100644 --- a/hscontrol/db/ephemeral_garbage_collector_test.go +++ b/hscontrol/db/ephemeral_garbage_collector_test.go @@ -57,7 +57,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { deletionWg.Add(numNodes) for i := 1; i <= numNodes; i++ { - gc.Schedule(types.NodeID(i), expiry) //nolint:gosec // G115: Test code with controlled values + gc.Schedule(types.NodeID(i), expiry) //nolint:gosec } // Wait for all scheduled deletions to complete @@ -70,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { // Schedule and immediately cancel to test that part of the code for i := numNodes + 1; i <= numNodes*2; i++ { - nodeID := types.NodeID(i) //nolint:gosec // G115: Test code with controlled values + nodeID := types.NodeID(i) //nolint:gosec gc.Schedule(nodeID, time.Hour) gc.Cancel(nodeID) } @@ -394,7 +394,7 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { case <-stopScheduling: return default: - nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec // G115: Test code with controlled values + nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test atomic.AddInt64(&scheduledCount, 1) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 3965c855..e7468207 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -253,6 +253,7 @@ func SetApprovedRoutes( return err } + //nolint:noinlineerr if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { return fmt.Errorf("updating approved routes: %w", err) } @@ -655,7 +656,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) panic("CreateNodeForTest requires a valid user") } - nodeName := "testnode" + nodeName := "testnode" //nolint:goconst if len(hostname) > 0 && hostname[0] != "" { nodeName = hostname[0] } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 58d36463..9ff96eb9 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -98,7 +98,7 @@ func TestExpireNode(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - //nolint:staticcheck // SA4006: pak is used in new(pak.ID) below + //nolint:staticcheck pak, err := 
db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) @@ -143,7 +143,7 @@ func TestSetTags(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - //nolint:staticcheck // SA4006: pak is used in new(pak.ID) below + //nolint:staticcheck pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) @@ -468,10 +468,10 @@ func TestAutoApproveRoutes(t *testing.T) { require.NoError(t, err) users, err := adb.ListUsers() - assert.NoError(t, err) + require.NoError(t, err) nodes, err := adb.ListNodes() - assert.NoError(t, err) + require.NoError(t, err) pm, err := pmf(users, nodes.ViewSlice()) require.NoError(t, err) @@ -600,7 +600,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) { // Use shorter expiry for faster tests for i := range want { - go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec // test code, no overflow risk + go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec } // Wait for all deletions to complete @@ -639,11 +639,11 @@ func TestListEphemeralNodes(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - //nolint:staticcheck // SA4006: pak is used in new(pak.ID) below + //nolint:staticcheck pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) - //nolint:staticcheck // SA4006: pakEph is used in new(pakEph.ID) below + //nolint:staticcheck pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) require.NoError(t, err) @@ -724,6 +724,7 @@ func TestNodeNaming(t *testing.T) { // break your network, so they should be replaced when registering // a node. // https://github.com/juanfont/headscale/issues/2343 + //nolint:gosmopolitan nodeInvalidHostname := types.Node{ MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), @@ -822,25 +823,26 @@ func TestNodeNaming(t *testing.T) { err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[0].ID, "test") }) - assert.ErrorContains(t, err, "name is not unique") + require.ErrorContains(t, err, "name is not unique") // Rename invalid chars + //nolint:gosmopolitan err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[2].ID, "我的电脑") }) - assert.ErrorContains(t, err, "invalid characters") + require.ErrorContains(t, err, "invalid characters") // Rename too short err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[3].ID, "a") }) - assert.ErrorContains(t, err, "at least 2 characters") + require.ErrorContains(t, err, "at least 2 characters") // Rename with emoji err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[0].ID, "hostname-with-💩") }) - assert.ErrorContains(t, err, "invalid characters") + require.ErrorContains(t, err, "invalid characters") // Rename with only emoji err = db.Write(func(tx *gorm.DB) error { @@ -908,12 +910,12 @@ func TestRenameNodeComprehensive(t *testing.T) { }, { name: "chinese_chars_with_dash_rejected", - newName: "server-北京-01", + newName: "server-北京-01", //nolint:gosmopolitan wantErr: "invalid characters", }, { name: "chinese_only_rejected", - newName: "我的电脑", + newName: "我的电脑", //nolint:gosmopolitan wantErr: "invalid characters", }, { @@ -923,7 +925,7 @@ func TestRenameNodeComprehensive(t *testing.T) { }, { name: "mixed_chinese_emoji_rejected", - newName: "测试💻机器", + newName: "测试💻机器", //nolint:gosmopolitan wantErr: "invalid characters", }, { diff --git a/hscontrol/db/sqliteconfig/integration_test.go 
b/hscontrol/db/sqliteconfig/integration_test.go index fa39f958..3d1d07c7 100644 --- a/hscontrol/db/sqliteconfig/integration_test.go +++ b/hscontrol/db/sqliteconfig/integration_test.go @@ -102,6 +102,7 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { defer db.Close() // Test connection + //nolint:noinlineerr if err := db.PingContext(context.Background()); err != nil { t.Fatalf("Failed to ping database: %v", err) } @@ -181,11 +182,13 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { ); ` + //nolint:noinlineerr if _, err := db.ExecContext(context.Background(), schema); err != nil { t.Fatalf("Failed to create schema: %v", err) } // Insert parent record + //nolint:noinlineerr if _, err := db.ExecContext(context.Background(), "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { t.Fatalf("Failed to insert parent: %v", err) } diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go index 7a9f7010..8489c69c 100644 --- a/hscontrol/db/text_serialiser.go +++ b/hscontrol/db/text_serialiser.go @@ -67,6 +67,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect ret := f.Call(args) if !ret[0].IsNil() { + //nolint:forcetypeassert return decodingError(field.Name, ret[0].Interface().(error)) } @@ -97,7 +98,7 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec // always comparable, particularly when reflection is involved: // https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8 if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) { - return nil, nil + return nil, nil //nolint:nilnil } b, err := v.MarshalText() diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 9145ff20..213730cf 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -97,6 +97,7 @@ func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { return err } + //nolint:noinlineerr if err = util.ValidateHostname(newName); err != nil { return err } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 40d301b3..9d2740e5 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -70,7 +70,7 @@ func TestDestroyUserErrors(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) - //nolint:staticcheck // SA4006: pak is used in new(pak.ID) below + //nolint:staticcheck pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index 3d4f64ee..2cbc02e6 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -161,9 +161,9 @@ func derpRandom() *rand.Rand { derpRandomOnce.Do(func() { seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String()) - //nolint:gosec // G404,G115: Intentionally using math/rand for deterministic DERP server ID + //nolint:gosec rnd := rand.New(rand.NewSource(0)) - //nolint:gosec // G115: Checksum is always positive and fits in int64 + //nolint:gosec rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) derpRandomInst = rnd }) diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index b0f83fb6..56fb5de9 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -314,6 +314,7 @@ func DERPBootstrapDNSHandler( defer cancel() var resolver net.Resolver + //nolint:unqueryvet for _, region := range derpMap.Regions().All() { for _, node := range region.Nodes().All() { // 
we don't care if we override some nodes addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName()) diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index f119def9..9aad9a7d 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -104,6 +104,7 @@ func (e *ExtraRecordsMan) Run() { // and not watch it. We will therefore attempt to re-add it with a backoff. case fsnotify.Remove, fsnotify.Rename: _, err := backoff.Retry(context.Background(), func() (struct{}, error) { + //nolint:noinlineerr if _, err := os.Stat(e.path); err != nil { return struct{}{}, err } diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index a904c533..21794f99 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -91,6 +91,7 @@ func (h *Headscale) handleVerifyRequest( } var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest + //nolint:noinlineerr if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)) } @@ -183,6 +184,7 @@ func (h *Headscale) HealthHandler( res.Status = "fail" } + //nolint:noinlineerr if err := json.NewEncoder(writer).Encode(res); err != nil { log.Error().Err(err).Msg("failed to encode health response") } diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index c5bbda48..06ad7009 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -86,7 +86,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t version := nc.version() if r.IsEmpty() { - return nil, nil //nolint:nilnil // Empty response means nothing to send + return nil, nil //nolint:nilnil } if nodeID == 0 { @@ -99,7 +99,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t // Handle self-only responses if r.IsSelfOnly() && r.TargetNode != nodeID { - return nil, nil //nolint:nilnil // No response needed for other nodes when self-only + return nil, nil //nolint:nilnil } // Check if this is a self-update (the changed node is the receiving node). diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 7cc746a4..d0ebee6d 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -236,8 +236,8 @@ func setupBatcherWithTestData( } derpMap, err := derp.GetDERPMap(cfg.DERP) - assert.NoError(t, err) - assert.NotNil(t, derpMap) + require.NoError(t, err) + require.NotNil(t, derpMap) state.SetDERPMap(derpMap) @@ -1108,6 +1108,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) { // The test verifies that channels are closed synchronously and deterministically // even when real node updates are being processed, ensuring no race conditions // occur during channel replacement with actual workload. +// + func XTestBatcherChannelClosingRace(t *testing.T) { for _, batcherFunc := range allBatcherFunctions { t.Run(batcherFunc.name, func(t *testing.T) { @@ -1330,6 +1332,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // real node data. The test validates that stable clients continue to function // normally and receive proper updates despite the connection churn from other clients, // ensuring system stability under concurrent load. 
+// +//nolint:gocyclo func TestBatcherConcurrentClients(t *testing.T) { if testing.Short() { t.Skip("Skipping concurrent client test in short mode") @@ -1608,6 +1612,8 @@ func TestBatcherConcurrentClients(t *testing.T) { // It validates that the system remains stable with no deadlocks, panics, or // missed updates under sustained high load. The test uses real node data to // generate authentic update scenarios and tracks comprehensive statistics. +// +//nolint:gocyclo,thelper func XTestBatcherScalability(t *testing.T) { if testing.Short() { t.Skip("Skipping scalability test in short mode") @@ -1636,6 +1642,7 @@ func XTestBatcherScalability(t *testing.T) { description string } + //nolint:prealloc var testCases []testCase // Generate all combinations of the test matrix @@ -2393,6 +2400,7 @@ func TestBatcherRapidReconnection(t *testing.T) { } } +//nolint:gocyclo func TestBatcherMultiConnection(t *testing.T) { for _, batcherFunc := range allBatcherFunctions { t.Run(batcherFunc.name, func(t *testing.T) { diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index cd1d9a9d..801b3e17 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -278,6 +278,7 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) // WithPeersRemoved adds removed peer IDs. func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { + //nolint:prealloc var tailscaleIDs []tailcfg.NodeID for _, id := range removedIDs { tailscaleIDs = append(tailscaleIDs, id.NodeID()) diff --git a/hscontrol/mapper/builder_test.go b/hscontrol/mapper/builder_test.go index 978b2c0e..653da30b 100644 --- a/hscontrol/mapper/builder_test.go +++ b/hscontrol/mapper/builder_test.go @@ -340,7 +340,7 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) { // Build should return a multierr data, err := result.Build() assert.Nil(t, data) - assert.Error(t, err) + require.Error(t, err) // The error should contain information about multiple errors assert.Contains(t, err.Error(), "multiple errors") diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 329c9b58..abf2f062 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -192,7 +192,7 @@ func (m *mapper) policyChangeResponse( // Convert tailcfg.NodeID to types.NodeID for WithPeersRemoved removedIDs := make([]types.NodeID, len(removedPeers)) for i, id := range removedPeers { - removedIDs[i] = types.NodeID(id) //nolint:gosec // NodeID types are equivalent + removedIDs[i] = types.NodeID(id) //nolint:gosec } builder.WithPeersRemoved(removedIDs...) 
@@ -215,7 +215,7 @@ func (m *mapper) buildFromChange( resp *change.Change, ) (*tailcfg.MapResponse, error) { if resp.IsEmpty() { - return nil, nil //nolint:nilnil // Empty response means nothing to send, not an error + return nil, nil //nolint:nilnil } // If this is a self-update (the changed node is the receiving node), @@ -307,7 +307,7 @@ func writeDebugMapResponse( func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { if debugDumpMapResponsePath == "" { - return nil, nil + return nil, nil //nolint:nilnil } return ReadMapResponsesFromDirectory(debugDumpMapResponsePath) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index bc097519..7df6f77b 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -244,7 +244,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( return } - //nolint:contextcheck // IIFE uses context from outer scope implicitly + //nolint:contextcheck registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { var resp *tailcfg.RegisterResponse @@ -254,6 +254,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( } var regReq tailcfg.RegisterRequest + //nolint:noinlineerr if err := json.Unmarshal(body, ®Req); err != nil { return ®Req, regErr(err) } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index de02b677..81db5271 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -69,7 +69,7 @@ func NewAuthProviderOIDC( ) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already - //nolint:contextcheck // Initialization code - no parent context available + //nolint:contextcheck oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) if err != nil { return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err) @@ -238,6 +238,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( nodeExpiry := a.determineNodeExpiry(idToken.Expiry) var claims types.OIDCClaims + //nolint:noinlineerr if err := idToken.Claims(&claims); err != nil { httpError(writer, fmt.Errorf("decoding ID token claims: %w", err)) return @@ -338,6 +339,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) + //nolint:noinlineerr if _, err := writer.Write(content.Bytes()); err != nil { util.LogErr(err, "Failed to write HTTP response") } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 59627dbe..b130bc6b 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -70,6 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ } func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) { + //nolint:prealloc var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) { diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go index fbe6e4bb..be4f860c 100644 --- a/hscontrol/policy/policy_route_approval_test.go +++ b/hscontrol/policy/policy_route_approval_test.go @@ -327,7 +327,7 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { } func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { - //nolint:staticcheck // SA4006: user is used in new(user.ID) and new(user) below + //nolint:staticcheck user := types.User{ Model: gorm.Model{ID: 1}, Name: "test", diff --git 
a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index a46f30d2..7752f202 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -798,6 +798,7 @@ func TestReduceNodes(t *testing.T) { func TestReduceNodesFromPolicy(t *testing.T) { n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node { + //nolint:prealloc var routes []netip.Prefix for _, route := range routess { routes = append(routes, netip.MustParsePrefix(route)) diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index ced8531c..b9b7f5e7 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -119,6 +119,8 @@ func (pol *Policy) compileFilterRulesForNode( // It returns a slice of filter rules because when an ACL has both autogroup:self // and other destinations, they need to be split into separate rules with different // source filtering logic. +// +//nolint:gocyclo func (pol *Policy) compileACLWithAutogroupSelf( acl ACL, users types.Users, @@ -284,13 +286,14 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { } } +//nolint:gocyclo func (pol *Policy) compileSSHPolicy( users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], ) (*tailcfg.SSHPolicy, error) { if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 { - return nil, nil + return nil, nil //nolint:nilnil } log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname()) diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 371bba5e..f35cff0b 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -76,6 +76,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { {Model: gorm.Model{ID: 3}, Name: "user3", Email: "user3@headscale.net"}, } + //nolint:goconst policy := `{ "acls": [ { @@ -94,7 +95,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { } for i, n := range initialNodes { - n.ID = types.NodeID(i + 1) //nolint:gosec // G115: Test code with small values + n.ID = types.NodeID(i + 1) //nolint:gosec } pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice()) @@ -187,7 +188,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { } if !found { - n.ID = types.NodeID(len(initialNodes) + i + 1) //nolint:gosec // G115: Test code with small values + n.ID = types.NodeID(len(initialNodes) + i + 1) //nolint:gosec } } @@ -753,8 +754,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) { Hostname: "admin-device", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -764,8 +765,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) { Hostname: "user1-server", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), Tags: []string{"tag:server"}, Hostinfo: &tailcfg.Hostinfo{}, } @@ -836,8 +837,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) { Hostname: "device-a", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -847,8 +848,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) { Hostname: "device-b", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - 
UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), Tags: []string{"tag:web"}, Hostinfo: &tailcfg.Hostinfo{}, } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 3bff4561..48baad2d 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -187,7 +187,7 @@ func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeV // Username is a string that represents a username, it must contain an @. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Username string func (u Username) Validate() error { @@ -299,7 +299,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. // Group is a special string which is always prefixed with `group:`. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Group string func (g Group) Validate() error { @@ -368,7 +368,7 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod // Tag is a special string which is always prefixed with `tag:`. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Tag string func (t Tag) Validate() error { @@ -422,7 +422,7 @@ func (t Tag) MarshalJSON() ([]byte, error) { // Host is a string that represents a hostname. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Host string func (h Host) Validate() error { @@ -482,7 +482,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView return buildIPSetMultiErr(&ips, errs) } -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Prefix netip.Prefix func (p Prefix) Validate() error { @@ -530,6 +530,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { return err } + //nolint:noinlineerr if err := p.Validate(); err != nil { return err } @@ -572,7 +573,7 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild // AutoGroup is a special string which is always prefixed with `autogroup:`. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type AutoGroup string const ( @@ -622,6 +623,7 @@ func (ag AutoGroup) MarshalJSON() ([]byte, error) { func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var build netipx.IPSetBuilder + //nolint:exhaustive switch ag { case AutoGroupInternet: return util.TheInternet(), nil @@ -724,6 +726,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { return err } + //nolint:noinlineerr if err := ve.Validate(); err != nil { return err } @@ -804,7 +807,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { return nil } -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Aliases []Alias func (a *Aliases) UnmarshalJSON(b []byte) error { @@ -1051,7 +1054,7 @@ type Usernames []Username // Groups are a map of Group to a list of Username. 
// -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Groups map[Group]Usernames func (g Groups) Contains(group *Group) error { @@ -1146,7 +1149,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { // Hosts are alias for IP addresses or subnets. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Hosts map[Host]Prefix func (h *Hosts) UnmarshalJSON(b []byte) error { @@ -1344,7 +1347,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. // Action represents the action to take for an ACL rule. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Action string const ( @@ -1353,7 +1356,7 @@ const ( // SSHAction represents the action to take for an SSH rule. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type SSHAction string const ( @@ -1411,7 +1414,7 @@ func (a SSHAction) MarshalJSON() ([]byte, error) { // Protocol represents a network protocol with its IANA number and descriptions. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type Protocol string const ( @@ -1439,6 +1442,7 @@ func (p Protocol) String() string { // Description returns the human-readable description of the Protocol. func (p Protocol) Description() string { + //nolint:exhaustive switch p { case ProtocolICMP: return "Internet Control Message Protocol" @@ -1476,6 +1480,7 @@ func (p Protocol) Description() string { // parseProtocol converts a Protocol to its IANA protocol numbers. // Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. func (p Protocol) parseProtocol() []int { + //nolint:exhaustive switch p { case "": // Empty protocol applies to TCP and UDP traffic only @@ -1532,6 +1537,7 @@ func (p *Protocol) UnmarshalJSON(b []byte) error { // validate checks if the Protocol is valid. func (p Protocol) validate() error { + //nolint:exhaustive switch p { case "", ProtocolICMP, ProtocolIGMP, ProtocolIPv4, ProtocolIPInIP, ProtocolTCP, ProtocolEGP, ProtocolIGP, ProtocolUDP, ProtocolGRE, @@ -1598,6 +1604,7 @@ type ACL struct { func (a *ACL) UnmarshalJSON(b []byte) error { // First unmarshal into a map to filter out comment fields var raw map[string]any + //nolint:noinlineerr if err := json.Unmarshal(b, &raw, policyJSONOpts...); err != nil { return err } @@ -1623,6 +1630,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error { var temp aclAlias // Unmarshal into the temporary struct using the v2 JSON options + //nolint:noinlineerr if err := json.Unmarshal(filteredBytes, &temp, policyJSONOpts...); err != nil { return err } @@ -1759,6 +1767,8 @@ func validateAutogroupForSSHUser(user *AutoGroup) error { // the unmarshaling process. // It runs through all rules and checks if there are any inconsistencies // in the policy that needs to be addressed before it can be used. 
+// +//nolint:gocyclo func (p *Policy) validate() error { if p == nil { panic("passed nil policy") @@ -1808,14 +1818,15 @@ func (p *Policy) validate() error { } for _, dst := range acl.Destinations { + //nolint:gocritic switch dst.Alias.(type) { case *Host: - h := dst.Alias.(*Host) + h := dst.Alias.(*Host) //nolint:forcetypeassert if !p.Hosts.exist(*h) { errs = append(errs, fmt.Errorf("%w: %q - please define or remove the reference", ErrHostNotDefined, *h)) } case *AutoGroup: - ag := dst.Alias.(*AutoGroup) + ag := dst.Alias.(*AutoGroup) //nolint:forcetypeassert err := validateAutogroupSupported(ag) if err != nil { @@ -1829,14 +1840,14 @@ func (p *Policy) validate() error { continue } case *Group: - g := dst.Alias.(*Group) + g := dst.Alias.(*Group) //nolint:forcetypeassert err := p.Groups.Contains(g) if err != nil { errs = append(errs, err) } case *Tag: - tagOwner := dst.Alias.(*Tag) + tagOwner := dst.Alias.(*Tag) //nolint:forcetypeassert err := p.TagOwners.Contains(tagOwner) if err != nil { @@ -2013,7 +2024,7 @@ type SSH struct { // SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule. // It can be a list of usernames, groups, tags or autogroups. // -//nolint:recvcheck // Mixed receivers: pointer for UnmarshalJSON, value for read-only methods +//nolint:recvcheck type SSHSrcAliases []Alias // MarshalJSON marshals the Groups to JSON. @@ -2192,7 +2203,7 @@ func (u SSHUser) MarshalJSON() ([]byte, error) { // This is the only entrypoint of reading a policy from a file or other source. func unmarshalPolicy(b []byte) (*Policy, error) { if len(b) == 0 { - return nil, nil + return nil, nil //nolint:nilnil } var policy Policy @@ -2204,7 +2215,9 @@ func unmarshalPolicy(b []byte) (*Policy, error) { ast.Standardize() + //nolint:noinlineerr if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { + //nolint:noinlineerr if serr, ok := errors.AsType[*json.SemanticError](err); ok && errors.Is(serr.Err, json.ErrUnknownName) { ptr := serr.JSONPointer name := ptr.LastToken() @@ -2215,6 +2228,7 @@ func unmarshalPolicy(b []byte) (*Policy, error) { return nil, fmt.Errorf("parsing policy from bytes: %w", err) } + //nolint:noinlineerr if err := policy.validate(); err != nil { return nil, err } diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 05f2b085..ddc32fba 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -1605,17 +1605,17 @@ func TestResolvePolicy(t *testing.T) { // Extract users to variables so we can take their addresses // The variables below are all used in new() calls in the test cases. 
- //nolint:staticcheck // SA4006: testuser is used in new(testuser) below + //nolint:staticcheck testuser := users["testuser"] - //nolint:staticcheck // SA4006: groupuser is used in new(groupuser) below + //nolint:staticcheck groupuser := users["groupuser"] - //nolint:staticcheck // SA4006: groupuser1 is used in new(groupuser1) below + //nolint:staticcheck groupuser1 := users["groupuser1"] - //nolint:staticcheck // SA4006: groupuser2 is used in new(groupuser2) below + //nolint:staticcheck groupuser2 := users["groupuser2"] - //nolint:staticcheck // SA4006: notme is used in new(notme) below + //nolint:staticcheck notme := users["notme"] - //nolint:staticcheck // SA4006: testuser2 is used in new(testuser2) below + //nolint:staticcheck testuser2 := users["testuser2"] tests := []struct { diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 464d252d..9864983a 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -214,6 +214,7 @@ func (m *mapSession) serveLongPoll() { // adding this before connecting it to the state ensure that // it does not miss any updates that might be sent in the split // time between the node connecting and the batcher being ready. + //nolint:noinlineerr if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { m.errf(err, "failed to add node to batcher") log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session") @@ -288,8 +289,9 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error { jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression) } + //nolint:prealloc data := make([]byte, reservedResponseHeaderSize) - //nolint:gosec // G115: JSON response size will not exceed uint32 max + //nolint:gosec binary.LittleEndian.PutUint32(data, uint32(len(jsonBody))) data = append(data, jsonBody...) @@ -334,13 +336,13 @@ func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event { Str("node.name", m.node.Hostname) } -//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +//nolint:zerologlint func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) } -//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +//nolint:zerologlint func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) } -//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +//nolint:zerologlint func (m *mapSession) errf(err error, msg string, a ...any) { m.logf(log.Error().Caller()).Err(err).Msgf(msg, a...) 
} diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index 5d8d6e85..1c921d6d 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -55,8 +55,8 @@ var ( }) nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: prometheusNamespace, - Name: "nodestore_nodes_total", - Help: "Total number of nodes in the NodeStore", + Name: "nodestore_nodes", + Help: "Number of nodes in the NodeStore", }) nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{ Namespace: prometheusNamespace, diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 2ce2aea8..522bb64e 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -107,6 +107,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Len(t, snapshot.nodesByID, 3) assert.Len(t, snapshot.allNodes, 3) assert.Len(t, snapshot.peersByNode, 3) @@ -136,6 +137,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, peersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Len(t, snapshot.nodesByID, 4) assert.Len(t, snapshot.allNodes, 4) assert.Len(t, snapshot.peersByNode, 4) @@ -252,6 +254,7 @@ func TestNodeStoreOperations(t *testing.T) { { name: "create empty store and add single node", setupFunc: func(t *testing.T) *NodeStore { + t.Helper() return NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) }, steps: []testStep{ @@ -918,7 +921,7 @@ func TestNodeStoreConcurrentPutNode(t *testing.T) { go func(nodeID int) { defer wg.Done() - //nolint:gosec // G115: Test code with controlled values + //nolint:gosec node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") resultNode := store.PutNode(node) @@ -958,7 +961,7 @@ func TestNodeStoreBatchingEfficiency(t *testing.T) { go func(nodeID int) { defer wg.Done() - //nolint:gosec // G115: Test code with controlled values + //nolint:gosec node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") resultNode := store.PutNode(node) @@ -1067,7 +1070,7 @@ func TestNodeStoreResourceCleanup(t *testing.T) { const ops = 100 for i := range ops { - nodeID := types.NodeID(i + 1) //nolint:gosec // G115: Test code with controlled values + nodeID := types.NodeID(i + 1) //nolint:gosec node := createConcurrentTestNode(nodeID, "cleanup-node") resultNode := store.PutNode(node) assert.True(t, resultNode.Valid()) @@ -1111,7 +1114,7 @@ func TestNodeStoreOperationTimeout(t *testing.T) { // Launch all PutNode operations concurrently for i := 1; i <= ops; i++ { - nodeID := types.NodeID(i) //nolint:gosec // G115: Test code with controlled values + nodeID := types.NodeID(i) //nolint:gosec wg.Add(1) @@ -1137,7 +1140,7 @@ func TestNodeStoreOperationTimeout(t *testing.T) { wg = sync.WaitGroup{} for i := 1; i <= ops; i++ { - nodeID := types.NodeID(i) //nolint:gosec // G115: Test code with controlled values + nodeID := types.NodeID(i) //nolint:gosec wg.Add(1) @@ -1202,7 +1205,7 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) store.Start() - nonExistentID := types.NodeID(999 + i) //nolint:gosec // G115: Test code with controlled values + nonExistentID := types.NodeID(999 + i) //nolint:gosec updateCallCount := 0 fmt.Printf("[TestNodeStoreUpdateNonExistentNode] 
UpdateNode(%d) starting\n", nonExistentID) @@ -1226,7 +1229,7 @@ func BenchmarkNodeStoreAllocations(b *testing.B) { defer store.Stop() for i := 0; b.Loop(); i++ { - nodeID := types.NodeID(i + 1) //nolint:gosec // G115: Benchmark code with controlled values + nodeID := types.NodeID(i + 1) //nolint:gosec node := createConcurrentTestNode(nodeID, "bench-node") store.PutNode(node) store.UpdateNode(nodeID, func(n *types.Node) { diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 7ddcb005..fbb4c421 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -224,6 +224,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) { // propagate correctly when switching between policy types. s.nodeStore.RebuildPeerMaps() + //nolint:prealloc cs := []change.Change{change.PolicyChange()} // Always call autoApproveNodes during policy reload, regardless of whether @@ -254,6 +255,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) { // CreateUser creates a new user and updates the policy manager. // Returns the created user, change set, and any error. func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) { + //nolint:noinlineerr if err := s.db.DB.Save(&user).Error; err != nil { return nil, change.Change{}, fmt.Errorf("creating user: %w", err) } @@ -288,6 +290,7 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error return nil, err } + //nolint:noinlineerr if err := updateFn(user); err != nil { return nil, err } @@ -492,7 +495,7 @@ func (s *State) Connect(id types.NodeID) []change.Change { // Disconnect marks a node as disconnected and updates its primary routes in the state. func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { - //nolint:staticcheck // SA4006: now is used in new(now) below + //nolint:staticcheck now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { @@ -817,6 +820,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha var updates []change.Change + //nolint:unqueryvet for _, node := range s.nodeStore.ListNodes().All() { if !node.Valid() { continue @@ -1697,7 +1701,7 @@ func (s *State) HandleNodeFromPreAuthKey( } } - return nil, nil + return nil, nil //nolint:nilnil }) if err != nil { return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err) diff --git a/hscontrol/templates/design.go b/hscontrol/templates/design.go index 615c0e41..2033f245 100644 --- a/hscontrol/templates/design.go +++ b/hscontrol/templates/design.go @@ -15,43 +15,43 @@ import ( // Material for MkDocs design system - exact values from official docs. const ( // Text colors - from --md-default-fg-color CSS variables. - colorTextPrimary = "#000000de" //nolint:unused // rgba(0,0,0,0.87) - Body text - colorTextSecondary = "#0000008a" //nolint:unused // rgba(0,0,0,0.54) - Headings (--md-default-fg-color--light) - colorTextTertiary = "#00000052" //nolint:unused // rgba(0,0,0,0.32) - Lighter text - colorTextLightest = "#00000012" //nolint:unused // rgba(0,0,0,0.07) - Lightest text + colorTextPrimary = "#000000de" //nolint:unused + colorTextSecondary = "#0000008a" //nolint:unused + colorTextTertiary = "#00000052" //nolint:unused + colorTextLightest = "#00000012" //nolint:unused // Code colors - from --md-code-* CSS variables. 
- colorCodeFg = "#36464e" //nolint:unused // Code text color (--md-code-fg-color) - colorCodeBg = "#f5f5f5" //nolint:unused // Code background (--md-code-bg-color) + colorCodeFg = "#36464e" //nolint:unused + colorCodeBg = "#f5f5f5" //nolint:unused // Border colors. - colorBorderLight = "#e5e7eb" //nolint:unused // Light borders - colorBorderMedium = "#d1d5db" //nolint:unused // Medium borders + colorBorderLight = "#e5e7eb" //nolint:unused + colorBorderMedium = "#d1d5db" //nolint:unused // Background colors. - colorBackgroundPage = "#ffffff" //nolint:unused // Page background - colorBackgroundCard = "#ffffff" //nolint:unused // Card/content background + colorBackgroundPage = "#ffffff" //nolint:unused + colorBackgroundCard = "#ffffff" //nolint:unused // Accent colors - from --md-primary/accent-fg-color. - colorPrimaryAccent = "#4051b5" //nolint:unused // Primary accent (links) - colorAccent = "#526cfe" //nolint:unused // Secondary accent + colorPrimaryAccent = "#4051b5" //nolint:unused + colorAccent = "#526cfe" //nolint:unused // Success colors. - colorSuccess = "#059669" //nolint:unused // Success states - colorSuccessLight = "#d1fae5" //nolint:unused // Success backgrounds + colorSuccess = "#059669" //nolint:unused + colorSuccessLight = "#d1fae5" //nolint:unused ) // Spacing System // Based on 4px/8px base unit for consistent rhythm. // Uses rem units for scalability with user font size preferences. const ( - spaceXS = "0.25rem" //nolint:unused // 4px - Tight spacing - spaceS = "0.5rem" //nolint:unused // 8px - Small spacing - spaceM = "1rem" //nolint:unused // 16px - Medium spacing (base) - spaceL = "1.5rem" //nolint:unused // 24px - Large spacing - spaceXL = "2rem" //nolint:unused // 32px - Extra large spacing - space2XL = "3rem" //nolint:unused // 48px - 2x extra large spacing - space3XL = "4rem" //nolint:unused // 64px - 3x extra large spacing + spaceXS = "0.25rem" //nolint:unused + spaceS = "0.5rem" //nolint:unused + spaceM = "1rem" //nolint:unused + spaceL = "1.5rem" //nolint:unused + spaceXL = "2rem" //nolint:unused + space2XL = "3rem" //nolint:unused + space3XL = "4rem" //nolint:unused ) // Typography System @@ -63,26 +63,26 @@ const ( fontFamilyCode = `"Roboto Mono", "SF Mono", Monaco, "Cascadia Code", Consolas, "Courier New", monospace` //nolint:unused // Font sizes - from .md-typeset CSS rules. - fontSizeBase = "0.8rem" //nolint:unused // 12.8px - Base text (.md-typeset) - fontSizeH1 = "2em" //nolint:unused // 2x base - Main headings - fontSizeH2 = "1.5625em" //nolint:unused // 1.5625x base - Section headings - fontSizeH3 = "1.25em" //nolint:unused // 1.25x base - Subsection headings - fontSizeSmall = "0.8em" //nolint:unused // 0.8x base - Small text - fontSizeCode = "0.85em" //nolint:unused // 0.85x base - Inline code + fontSizeBase = "0.8rem" //nolint:unused + fontSizeH1 = "2em" //nolint:unused + fontSizeH2 = "1.5625em" //nolint:unused + fontSizeH3 = "1.25em" //nolint:unused + fontSizeSmall = "0.8em" //nolint:unused + fontSizeCode = "0.85em" //nolint:unused // Line heights - from .md-typeset CSS rules. 
- lineHeightBase = "1.6" //nolint:unused // Body text (.md-typeset) - lineHeightH1 = "1.3" //nolint:unused // H1 headings - lineHeightH2 = "1.4" //nolint:unused // H2 headings - lineHeightH3 = "1.5" //nolint:unused // H3 headings - lineHeightCode = "1.4" //nolint:unused // Code blocks (pre) + lineHeightBase = "1.6" //nolint:unused + lineHeightH1 = "1.3" //nolint:unused + lineHeightH2 = "1.4" //nolint:unused + lineHeightH3 = "1.5" //nolint:unused + lineHeightCode = "1.4" //nolint:unused ) // Responsive Container Component // Creates a centered container with responsive padding and max-width. // Mobile-first approach: starts at 100% width with padding, constrains on larger screens. // -//nolint:unused // Reserved for future use in Phase 4. +//nolint:unused func responsiveContainer(children ...elem.Node) *elem.Element { return elem.Div(attrs.Props{ attrs.Style: styles.Props{ @@ -100,7 +100,7 @@ func responsiveContainer(children ...elem.Node) *elem.Element { // - title: Optional title for the card (empty string for no title) // - children: Content elements to display in the card // -//nolint:unused // Reserved for future use in Phase 4. +//nolint:unused func card(title string, children ...elem.Node) *elem.Element { cardContent := children if title != "" { @@ -134,7 +134,7 @@ func card(title string, children ...elem.Node) *elem.Element { // EXTRACTED FROM: .md-typeset pre CSS rules // Exact styling from Material for MkDocs documentation. // -//nolint:unused // Used across apple.go, windows.go, register_web.go templates. +//nolint:unused func codeBlock(code string) *elem.Element { return elem.Pre(attrs.Props{ attrs.Style: styles.Props{ @@ -164,7 +164,7 @@ func codeBlock(code string) *elem.Element { // Returns inline styles for the main content container that matches .md-typeset. // EXTRACTED FROM: .md-typeset CSS rule from Material for MkDocs. // -//nolint:unused // Used in general.go for mdTypesetBody. +//nolint:unused func baseTypesetStyles() styles.Props { return styles.Props{ styles.FontSize: fontSizeBase, // 0.8rem @@ -180,7 +180,7 @@ func baseTypesetStyles() styles.Props { // Returns inline styles for H1 headings that match .md-typeset h1. // EXTRACTED FROM: .md-typeset h1 CSS rule from Material for MkDocs. // -//nolint:unused // Used across templates for main headings. +//nolint:unused func h1Styles() styles.Props { return styles.Props{ styles.Color: colorTextSecondary, // rgba(0, 0, 0, 0.54) @@ -198,7 +198,7 @@ func h1Styles() styles.Props { // Returns inline styles for H2 headings that match .md-typeset h2. // EXTRACTED FROM: .md-typeset h2 CSS rule from Material for MkDocs. // -//nolint:unused // Used across templates for section headings. +//nolint:unused func h2Styles() styles.Props { return styles.Props{ styles.FontSize: fontSizeH2, // 1.5625em @@ -216,7 +216,7 @@ func h2Styles() styles.Props { // Returns inline styles for H3 headings that match .md-typeset h3. // EXTRACTED FROM: .md-typeset h3 CSS rule from Material for MkDocs. // -//nolint:unused // Used across templates for subsection headings. +//nolint:unused func h3Styles() styles.Props { return styles.Props{ styles.FontSize: fontSizeH3, // 1.25em @@ -234,7 +234,7 @@ func h3Styles() styles.Props { // Returns inline styles for paragraphs that match .md-typeset p. // EXTRACTED FROM: .md-typeset p CSS rule from Material for MkDocs. // -//nolint:unused // Used for consistent paragraph spacing. 
+//nolint:unused func paragraphStyles() styles.Props { return styles.Props{ styles.Margin: "1em 0", @@ -250,7 +250,7 @@ func paragraphStyles() styles.Props { // Returns inline styles for ordered lists that match .md-typeset ol. // EXTRACTED FROM: .md-typeset ol CSS rule from Material for MkDocs. // -//nolint:unused // Used for numbered instruction lists. +//nolint:unused func orderedListStyles() styles.Props { return styles.Props{ styles.MarginBottom: "1em", @@ -268,7 +268,7 @@ func orderedListStyles() styles.Props { // Returns inline styles for unordered lists that match .md-typeset ul. // EXTRACTED FROM: .md-typeset ul CSS rule from Material for MkDocs. // -//nolint:unused // Used for bullet point lists. +//nolint:unused func unorderedListStyles() styles.Props { return styles.Props{ styles.MarginBottom: "1em", @@ -287,7 +287,7 @@ func unorderedListStyles() styles.Props { // EXTRACTED FROM: .md-typeset a CSS rule from Material for MkDocs. // Note: Hover states cannot be implemented with inline styles. // -//nolint:unused // Used for text links. +//nolint:unused func linkStyles() styles.Props { return styles.Props{ styles.Color: colorPrimaryAccent, // #4051b5 - var(--md-primary-fg-color) @@ -301,7 +301,7 @@ func linkStyles() styles.Props { // Returns inline styles for inline code that matches .md-typeset code. // EXTRACTED FROM: .md-typeset code CSS rule from Material for MkDocs. // -//nolint:unused // Used for inline code snippets. +//nolint:unused func inlineCodeStyles() styles.Props { return styles.Props{ styles.BackgroundColor: colorCodeBg, // #f5f5f5 @@ -317,7 +317,7 @@ func inlineCodeStyles() styles.Props { // Inline Code Component // For inline code snippets within text. // -//nolint:unused // Reserved for future inline code usage. +//nolint:unused func inlineCode(code string) *elem.Element { return elem.Code(attrs.Props{ attrs.Style: inlineCodeStyles().ToInline(), @@ -327,7 +327,7 @@ func inlineCode(code string) *elem.Element { // orDivider creates a visual "or" divider between sections. // Styled with lines on either side for better visual separation. // -//nolint:unused // Used in apple.go template. +//nolint:unused func orDivider() *elem.Element { return elem.Div(attrs.Props{ attrs.Style: styles.Props{ @@ -367,7 +367,7 @@ func orDivider() *elem.Element { // warningBox creates a warning message box with icon and content. // -//nolint:unused // Used in apple.go template. +//nolint:unused func warningBox(title, message string) *elem.Element { return elem.Div(attrs.Props{ attrs.Style: styles.Props{ @@ -404,7 +404,7 @@ func warningBox(title, message string) *elem.Element { // downloadButton creates a nice button-style link for downloads. // -//nolint:unused // Used in apple.go template. +//nolint:unused func downloadButton(href, text string) *elem.Element { return elem.A(attrs.Props{ attrs.Href: href, @@ -428,7 +428,7 @@ func downloadButton(href, text string) *elem.Element { // Creates a link with proper security attributes for external URLs. // Automatically adds rel="noreferrer noopener" and target="_blank". // -//nolint:unused // Used in apple.go, oidc_callback.go templates. +//nolint:unused func externalLink(href, text string) *elem.Element { return elem.A(attrs.Props{ attrs.Href: href, @@ -444,7 +444,7 @@ func externalLink(href, text string) *elem.Element { // Instruction Step Component // For numbered instruction lists with consistent formatting. // -//nolint:unused // Reserved for future use in Phase 4. 
+//nolint:unused func instructionStep(_ int, text string) *elem.Element { return elem.Li(attrs.Props{ attrs.Style: styles.Props{ @@ -457,7 +457,7 @@ func instructionStep(_ int, text string) *elem.Element { // Status Message Component // For displaying success/error/info messages with appropriate styling. // -//nolint:unused // Reserved for future use in Phase 4. +//nolint:unused func statusMessage(message string, isSuccess bool) *elem.Element { bgColor := colorSuccessLight textColor := colorSuccess diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index d57943f6..64410dd9 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -916,6 +916,7 @@ func LoadCLIConfig() (*Config, error) { // LoadServerConfig returns the full Headscale configuration to // host a Headscale server. This is called as part of `headscale serve`. func LoadServerConfig() (*Config, error) { + //nolint:noinlineerr if err := validateServerConfig(); err != nil { return nil, err } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 1ebc7033..d75da265 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -234,6 +234,7 @@ func (node *Node) RequestTags() []string { } func (node *Node) Prefixes() []netip.Prefix { + //nolint:prealloc var addrs []netip.Prefix for _, nodeAddress := range node.IPs() { ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen()) @@ -263,6 +264,7 @@ func (node *Node) IsExitNode() bool { } func (node *Node) IPsAsString() []string { + //nolint:prealloc var ret []string for _, ip := range node.IPs() { @@ -925,11 +927,11 @@ func (nv NodeView) TailscaleUserID() tailcfg.UserID { } if nv.IsTagged() { - //nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64 + //nolint:gosec return tailcfg.UserID(int64(TaggedDevices.ID)) } - //nolint:gosec // G115: UserID values are within int64 range + //nolint:gosec return tailcfg.UserID(int64(nv.UserID().Get())) } @@ -1054,7 +1056,7 @@ func (nv NodeView) TailNode( } tNode := tailcfg.Node{ - //nolint:gosec // G115: NodeID values are within int64 range + //nolint:gosec ID: tailcfg.NodeID(nv.ID()), StableID: nv.ID().StableID(), Name: hostname, diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 9518833f..5210e363 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -407,7 +407,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan }, want: Node{ GivenName: "valid-hostname", @@ -491,7 +491,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "server-北京-01", + Hostname: "server-北京-01", //nolint:gosmopolitan }, want: Node{ GivenName: "valid-hostname", @@ -505,7 +505,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan }, want: Node{ GivenName: "valid-hostname", @@ -533,7 +533,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "测试💻机器", + Hostname: "测试💻机器", //nolint:gosmopolitan }, want: Node{ GivenName: "valid-hostname", diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 2bf30a0c..f1120929 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -155,7 +155,7 @@ func (u UserView) ID() uint { func (u *User) TailscaleLogin() 
tailcfg.Login { return tailcfg.Login{ - ID: tailcfg.LoginID(u.ID), //nolint:gosec // G115: User IDs are always positive and fit in int64 + ID: tailcfg.LoginID(u.ID), //nolint:gosec Provider: u.Provider, LoginName: u.Username(), DisplayName: u.Display(), @@ -277,8 +277,10 @@ func (c *OIDCClaims) Identifier() string { var result string // Try to parse as URL to handle URL joining correctly + //nolint:noinlineerr if u, err := url.Parse(issuer); err == nil && u.Scheme != "" { // For URLs, use proper URL path joining + //nolint:noinlineerr if joined, err := url.JoinPath(issuer, subject); err == nil { result = joined } diff --git a/hscontrol/types/version.go b/hscontrol/types/version.go index 6676c92f..96dc58a6 100644 --- a/hscontrol/types/version.go +++ b/hscontrol/types/version.go @@ -38,9 +38,7 @@ func (v *VersionInfo) String() string { return sb.String() } -var buildInfo = sync.OnceValues(func() (*debug.BuildInfo, bool) { - return debug.ReadBuildInfo() -}) +var buildInfo = sync.OnceValues(debug.ReadBuildInfo) var GetVersionInfo = sync.OnceValue(func() *VersionInfo { info := &VersionInfo{ diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index 5c9585be..7fa4b222 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -185,6 +185,7 @@ func ParseTraceroute(output string) (Traceroute, error) { firstSpace := strings.Index(remainder, " ") if firstSpace > 0 { firstPart := remainder[:firstSpace] + //nolint:noinlineerr if _, err := strconv.ParseFloat(strings.TrimPrefix(firstPart, "<"), 64); err == nil { latencyFirst = true } @@ -233,6 +234,7 @@ func ParseTraceroute(output string) (Traceroute, error) { parts := strings.Fields(remainder) if len(parts) > 0 { hopHostname = parts[0] + //nolint:noinlineerr if ip, err := netip.ParseAddr(parts[0]); err == nil { hopIP = ip } @@ -336,7 +338,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri // it's purely for observability and correlating log entries during the registration process. 
func GenerateRegistrationKey() (string, error) { const ( - registerKeyPrefix = "hskey-reg-" //nolint:gosec // This is a vanity key for logging, not a credential + registerKeyPrefix = "hskey-reg-" //nolint:gosec registerKeyLength = 64 ) diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index fc55328f..98692882 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -902,7 +902,7 @@ func TestEnsureHostname(t *testing.T) { { name: "hostname_with_unicode", hostinfo: &tailcfg.Hostinfo{ - Hostname: "node-ñoño-测试", + Hostname: "node-ñoño-测试", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -983,7 +983,7 @@ func TestEnsureHostname(t *testing.T) { { name: "chinese_chars_with_dash_invalid", hostinfo: &tailcfg.Hostinfo{ - Hostname: "server-北京-01", + Hostname: "server-北京-01", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -992,7 +992,7 @@ func TestEnsureHostname(t *testing.T) { { name: "chinese_only_invalid", hostinfo: &tailcfg.Hostinfo{ - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -1010,7 +1010,7 @@ func TestEnsureHostname(t *testing.T) { { name: "mixed_chinese_emoji_invalid", hostinfo: &tailcfg.Hostinfo{ - Hostname: "测试💻机器", + Hostname: "测试💻机器", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -1173,6 +1173,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { t.Fatal("hostinfo should not be nil") } + //nolint:goconst if hi.Hostname != "unknown-node" { t.Errorf("hostname = %v, want unknown-node", hi.Hostname) } @@ -1283,6 +1284,8 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) { for i, hostname := range testCases { t.Run(cmp.Diff("", ""), func(t *testing.T) { + t.Parallel() + hostinfo := &tailcfg.Hostinfo{Hostname: hostname} result := EnsureHostname(hostinfo, "mkey", "nkey") diff --git a/integration/acl_test.go b/integration/acl_test.go index 7a33240b..e87e0587 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -1876,7 +1876,7 @@ func TestACLAutogroupSelf(t *testing.T) { result, err := client.Curl(url) assert.Empty(t, result, "user1 should not be able to access user2's regular devices (autogroup:self isolation)") - assert.Error(t, err, "connection from user1 to user2 regular device should fail") + require.Error(t, err, "connection from user1 to user2 regular device should fail") } } @@ -1895,6 +1895,7 @@ func TestACLAutogroupSelf(t *testing.T) { } } +//nolint:gocyclo func TestACLPolicyPropagationOverTime(t *testing.T) { IntegrationSkip(t) diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index 0fbec32f..f0218592 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -218,7 +218,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { var response v1.ListUsersResponse err = protojson.Unmarshal(body, &response) - assert.NoError(t, err, "Response should be valid protobuf JSON with valid API key") + require.NoError(t, err, "Response should be valid protobuf JSON with valid API key") // Should contain our test users users := response.GetUsers() @@ -486,7 +486,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { ) // Should fail with authentication error - assert.Error(t, err, + require.Error(t, err, "gRPC connection with invalid API key should fail") // Should contain authentication error message @@ -515,7 +515,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { ) // Should succeed - assert.NoError(t, err, + 
require.NoError(t, err, "gRPC connection with valid API key should succeed, output: %s", output) // CLI outputs the users array directly, not wrapped in ListUsersResponse @@ -523,7 +523,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { var users []*v1.User err = json.Unmarshal([]byte(output), &users) - assert.NoError(t, err, "Response should be valid JSON array") + require.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") userNames := make([]string, len(users)) @@ -640,7 +640,7 @@ cli: ) // Should fail - assert.Error(t, err, + require.Error(t, err, "CLI with invalid API key should fail") // Should indicate authentication failure @@ -675,7 +675,7 @@ cli: ) // Should succeed - assert.NoError(t, err, + require.NoError(t, err, "CLI with valid API key should succeed") // CLI outputs the users array directly, not wrapped in ListUsersResponse @@ -683,7 +683,7 @@ cli: var users []*v1.User err = json.Unmarshal([]byte(output), &users) - assert.NoError(t, err, "Response should be valid JSON array") + require.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") userNames := make([]string, len(users)) diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index 7e7747e6..abf31fec 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -133,7 +133,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 // https://github.com/juanfont/headscale/issues/2164 if !https { - //nolint:forbidigo // Intentional delay: Tailscale client requires 5 min wait before reconnecting over non-HTTPS + //nolint:forbidigo time.Sleep(5 * time.Minute) } @@ -453,7 +453,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 // https://github.com/juanfont/headscale/issues/2164 if !https { - //nolint:forbidigo // Intentional delay: Tailscale client requires 5 min wait before reconnecting over non-HTTPS + //nolint:forbidigo time.Sleep(5 * time.Minute) } diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index bdd5bce2..076f6565 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -926,7 +926,7 @@ func TestOIDCFollowUpUrl(t *testing.T) { // wait for the registration cache to expire // a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION (1m30s) - //nolint:forbidigo // Intentional delay: must wait for real-time cache expiration (HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION=1m30s) + //nolint:forbidigo time.Sleep(2 * time.Minute) var newUrl *url.URL diff --git a/integration/cli_test.go b/integration/cli_test.go index 1ca23f40..707c9992 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -203,7 +203,7 @@ func TestUserCommand(t *testing.T) { "--identifier=1", }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterIDDelete []*v1.User @@ -245,7 +245,7 @@ func TestUserCommand(t *testing.T) { "--name=newname", }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterNameDelete []v1.User @@ -571,8 +571,9 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { IntegrationSkip(t) + //nolint:goconst user1 := "user1" - user2 := 
"user2" + user2 := "user2" //nolint:goconst spec := ScenarioSpec{ NodesPerUser: 1, @@ -829,7 +830,7 @@ func TestApiKeyCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.NotEmpty(t, apiResult) keys[idx] = apiResult @@ -907,7 +908,7 @@ func TestApiKeyCommand(t *testing.T) { listedAPIKeys[idx].GetPrefix(), }, ) - assert.NoError(t, err) + require.NoError(t, err) expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true } @@ -952,7 +953,7 @@ func TestApiKeyCommand(t *testing.T) { "--prefix", listedAPIKeys[0].GetPrefix(), }) - assert.NoError(t, err) + require.NoError(t, err) var listedAPIKeysAfterDelete []v1.ApiKey @@ -1071,7 +1072,7 @@ func TestNodeCommand(t *testing.T) { } nodes := make([]*v1.Node, len(regIDs)) - assert.NoError(t, err) + require.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -1089,7 +1090,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) var node v1.Node @@ -1156,7 +1157,7 @@ func TestNodeCommand(t *testing.T) { } otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) - assert.NoError(t, err) + require.NoError(t, err) for index, regID := range otherUserRegIDs { _, err := headscale.Execute( @@ -1174,7 +1175,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) var node v1.Node @@ -1281,7 +1282,7 @@ func TestNodeCommand(t *testing.T) { "--force", }, ) - assert.NoError(t, err) + require.NoError(t, err) // Test: list main user after node is deleted var listOnlyMachineUserAfterDelete []v1.Node @@ -1348,7 +1349,7 @@ func TestNodeExpireCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) var node v1.Node @@ -1411,7 +1412,7 @@ func TestNodeExpireCommand(t *testing.T) { strconv.FormatUint(listAll[idx].GetId(), 10), }, ) - assert.NoError(t, err) + require.NoError(t, err) } var listAllAfterExpiry []v1.Node @@ -1549,7 +1550,7 @@ func TestNodeRenameCommand(t *testing.T) { fmt.Sprintf("newnode-%d", idx+1), }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, res, "Node renamed") } @@ -1590,7 +1591,7 @@ func TestNodeRenameCommand(t *testing.T) { strings.Repeat("t", 64), }, ) - assert.ErrorContains(t, err, "must not exceed 63 characters") + require.ErrorContains(t, err, "must not exceed 63 characters") var listAllAfterRenameAttempt []v1.Node @@ -1763,7 +1764,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath, }, ) - assert.ErrorContains(t, err, `invalid action "unknown-action"`) + require.ErrorContains(t, err, `invalid action "unknown-action"`) // The new policy was invalid, the old one should still be in place, which // is none. 
diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index c1c62f81..20ea930c 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -115,6 +115,7 @@ func DERPVerify( result = fmt.Errorf("client Connect: %w", err) } + //nolint:noinlineerr if m, err := c.Recv(); err != nil { result = fmt.Errorf("client first Recv: %w", err) } else if v, ok := m.(derp.ServerInfoMessage); !ok { diff --git a/integration/dns_test.go b/integration/dns_test.go index 08250e7b..0d3bce21 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -86,6 +86,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { const erPath = "/tmp/extra_records.json" + //nolint:prealloc extraRecords := []tailcfg.DNSRecord{ { Name: "test.myvpn.example.com", diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index 4a172471..0143ee53 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -38,6 +38,8 @@ type buffer struct { // Write appends the contents of p to the buffer, growing the buffer as needed. It returns // the number of bytes written. +// +//nolint:nonamedreturns func (b *buffer) Write(p []byte) (n int, err error) { b.mutex.Lock() defer b.mutex.Unlock() diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index 89154f63..7164e113 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -99,7 +99,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) { // we *want* it to show up in stacktraces, // so marking it as a test helper would be counterproductive. // -//nolint:thelper + func derpServerScenario( t *testing.T, spec ScenarioSpec, @@ -179,7 +179,7 @@ func derpServerScenario( // Let the DERP updater run a couple of times to ensure it does not // break the DERPMap. The updater runs on a 10s interval by default. - //nolint:forbidigo // Intentional delay: must wait for DERP updater to run multiple times (interval-based) + //nolint:forbidigo time.Sleep(30 * time.Second) success = pingDerpAllHelper(t, allClients, allHostnames) diff --git a/integration/helpers.go b/integration/helpers.go index df89b1ea..7e5fcc2f 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -174,9 +174,10 @@ func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNode startTime := time.Now() + //nolint:goconst stateStr := "offline" if expectedOnline { - stateStr = "online" + stateStr = "online" //nolint:goconst } t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message) @@ -194,6 +195,8 @@ func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNode } // requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state. +// +//nolint:gocyclo func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { t.Helper() @@ -548,6 +551,8 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // assertLastSeenSet validates that a node has a non-nil LastSeen timestamp. // Critical for ensuring node activity tracking is functioning properly. 
func assertLastSeenSet(t *testing.T, node *v1.Node) { + t.Helper() + assert.NotNil(t, node) assert.NotNil(t, node.GetLastSeen()) } @@ -566,7 +571,7 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { for _, client := range clients { status, err := client.Status() - assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) + assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) //nolint:testifylint assert.Equal(t, "NeedsLogin", status.BackendState, "client %s should be logged out", client.Hostname()) } @@ -765,7 +770,7 @@ func tagp(name string) policyv2.Alias { // prefixp returns a pointer to a Prefix from a CIDR string for policy v2 configurations. // Converts CIDR notation to policy prefix format for network range specifications. func prefixp(cidr string) policyv2.Alias { - //nolint:staticcheck // SA4006: prefix is used in new(policyv2.Prefix(prefix)) below + //nolint:staticcheck prefix := netip.MustParsePrefix(cidr) return new(policyv2.Prefix(prefix)) } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index d4dbb85b..cfe89dfc 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -781,13 +781,14 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { switch header.Typeflag { case tar.TypeDir: // Create directory - //nolint:gosec // G115: tar.Header.Mode is int64, safe to convert to uint32 for permissions + //nolint:gosec err := os.MkdirAll(targetPath, os.FileMode(header.Mode)) if err != nil { return fmt.Errorf("failed to create directory %s: %w", targetPath, err) } case tar.TypeReg: // Ensure parent directories exist + //nolint:noinlineerr if err := os.MkdirAll(filepath.Dir(targetPath), dirPermissions); err != nil { return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err) } @@ -798,7 +799,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { return fmt.Errorf("failed to create file %s: %w", targetPath, err) } - //nolint:gosec // G110: Trusted tar archive from our own container + //nolint:gosec,noinlineerr if _, err := io.Copy(outFile, tarReader); err != nil { outFile.Close() return fmt.Errorf("failed to copy file contents: %w", err) @@ -807,7 +808,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { outFile.Close() // Set file permissions - //nolint:gosec // G115: tar.Header.Mode is int64, safe to convert to uint32 for permissions + //nolint:gosec,noinlineerr if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil { return fmt.Errorf("failed to set file permissions: %w", err) } @@ -906,7 +907,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { return fmt.Errorf("failed to create database file: %w", err) } - //nolint:gosec // G110: Trusted tar archive from our own container + //nolint:gosec written, err := io.Copy(outFile, tarReader) outFile.Close() @@ -1593,6 +1594,7 @@ func (t *HeadscaleInContainer) GetAllMapReponses() (map[types.NodeID][]tailcfg.M } var res map[types.NodeID][]tailcfg.MapResponse + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &res); err != nil { return nil, fmt.Errorf("decoding routes response: %w", err) } @@ -1613,6 +1615,7 @@ func (t *HeadscaleInContainer) PrimaryRoutes() (*routes.DebugRoutes, error) { } var debugRoutes routes.DebugRoutes + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &debugRoutes); err != nil { return nil, fmt.Errorf("decoding routes response: %w", err) } @@ -1633,6 +1636,7 @@ func (t 
*HeadscaleInContainer) DebugBatcher() (*hscontrol.DebugBatcherInfo, erro } var debugInfo hscontrol.DebugBatcherInfo + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &debugInfo); err != nil { return nil, fmt.Errorf("decoding batcher debug response: %w", err) } @@ -1653,6 +1657,7 @@ func (t *HeadscaleInContainer) DebugNodeStore() (map[types.NodeID]types.Node, er } var nodeStore map[types.NodeID]types.Node + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &nodeStore); err != nil { return nil, fmt.Errorf("decoding nodestore debug response: %w", err) } @@ -1673,6 +1678,7 @@ func (t *HeadscaleInContainer) DebugFilter() ([]tailcfg.FilterRule, error) { } var filterRules []tailcfg.FilterRule + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &filterRules); err != nil { return nil, fmt.Errorf("decoding filter response: %w", err) } diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index 71dd8897..0563999e 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -220,19 +220,19 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type for _, mr := range mrs { for _, peer := range mr.Peers { if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec // G115: tailcfg.NodeID is int64, safe for test code + res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec } } for _, peer := range mr.PeersChanged { if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec // G115: tailcfg.NodeID is int64, safe for test code + res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec } } for _, peer := range mr.PeersChangedPatch { if peer.Online != nil { - res[nid][types.NodeID(peer.NodeID)] = *peer.Online //nolint:gosec // G115: tailcfg.NodeID is int64, safe for test code + res[nid][types.NodeID(peer.NodeID)] = *peer.Online //nolint:gosec } } } diff --git a/integration/route_test.go b/integration/route_test.go index 3d24da99..ea1cf3b8 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -1610,7 +1610,7 @@ func TestSubnetRouteACL(t *testing.T) { func TestEnablingExitRoutes(t *testing.T) { IntegrationSkip(t) - user := "user2" + user := "user2" //nolint:goconst spec := ScenarioSpec{ NodesPerUser: 2, @@ -2040,6 +2040,8 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node { // - Verify that peers can no longer use node // - Policy is changed back to auto approve route, check that routes already existing is approved. // - Verify that routes can now be seen by peers. +// +//nolint:gocyclo func TestAutoApproveMultiNetwork(t *testing.T) { IntegrationSkip(t) @@ -2887,7 +2889,8 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } // assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT. 
-//nolint:testifylint // CollectT requires assert, not require +// +//nolint:testifylint func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) { assert.NotNil(c, tr) assert.True(c, tr.Success) @@ -2905,6 +2908,8 @@ func SortPeerStatus(a, b *ipnstate.PeerStatus) int { } func printCurrentRouteMap(t *testing.T, routers ...*ipnstate.PeerStatus) { + t.Helper() + t.Logf("== Current routing map ==") slices.SortFunc(routers, SortPeerStatus) diff --git a/integration/scenario.go b/integration/scenario.go index bf3f4096..743c5830 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -323,6 +323,8 @@ func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { } func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { + t.Helper() + defer func() { _ = dockertestutil.CleanUnreferencedNetworks(s.pool) }() defer func() { _ = dockertestutil.CleanImagesInCI(s.pool) }() @@ -1172,6 +1174,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { key = strings.SplitN(key, " ", expectedHTMLSplitParts)[0] log.Printf("registering node %s", key) + //nolint:noinlineerr if headscale, err := s.Headscale(); err == nil { _, err = headscale.Execute( []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, @@ -1449,6 +1452,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc") + //nolint:noinlineerr if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( headscaleBuildOptions, mockOidcOptions, @@ -1468,6 +1472,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) + //nolint:noinlineerr if err := s.pool.Retry(func() error { oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) httpClient := &http.Client{} diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 04365eae..8329155f 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -452,7 +452,7 @@ func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient func assertSSHNoAccessStdError(t *testing.T, err error, stderr string) { t.Helper() - assert.Error(t, err) + require.Error(t, err) if !isSSHNoAccessStdError(stderr) { t.Errorf("expected stderr output suggesting access denied, got: %s", stderr) diff --git a/integration/tags_test.go b/integration/tags_test.go index 91c771c4..9c8391c2 100644 --- a/integration/tags_test.go +++ b/integration/tags_test.go @@ -85,7 +85,7 @@ func assertNodeHasNoTagsWithCollect(c *assert.CollectT, node *v1.Node) { // This validates that tag updates have propagated to the node's own status (issue #2978). 
func assertNodeSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClient, expectedTags []string) { status, err := client.Status() - //nolint:testifylint // must use assert with CollectT in EventuallyWithT + //nolint:testifylint assert.NoError(c, err, "failed to get client status") if status == nil || status.Self == nil { @@ -556,7 +556,7 @@ func TestTagsAuthKeyWithTagAdminOverrideReauthPreserves(t *testing.T) { "--authkey=" + authKey.GetKey(), "--force-reauth", } - //nolint:errcheck // Intentionally ignoring error - we check results below + //nolint:errcheck client.Execute(command) // Verify admin tags are preserved even after reauth - admin decisions are authoritative (server-side) @@ -2490,7 +2490,7 @@ func TestTagsAdminAPICannotRemoveAllTags(t *testing.T) { // This validates at a deeper level than status - directly from tailscale debug netmap. func assertNetmapSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClient, expectedTags []string) { nm, err := client.Netmap() - //nolint:testifylint // must use assert with CollectT in EventuallyWithT + //nolint:testifylint assert.NoError(c, err, "failed to get client netmap") if nm == nil { @@ -2501,6 +2501,7 @@ func assertNetmapSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClie var actualTagsSlice []string if nm.SelfNode.Valid() { + //nolint:unqueryvet for _, tag := range nm.SelfNode.Tags().All() { actualTagsSlice = append(actualTagsSlice, tag) } @@ -2623,7 +2624,7 @@ func TestTagsIssue2978ReproTagReplacement(t *testing.T) { // We wait 10 seconds and check - if the client STILL shows the OLD tag, // that demonstrates the bug. If the client shows the NEW tag, the bug is fixed. t.Log("Step 2b: Waiting 10 seconds to see if client self view updates (bug: it should NOT)") - //nolint:forbidigo // intentional sleep to demonstrate bug timing - client should get update immediately, not after waiting + //nolint:forbidigo time.Sleep(10 * time.Second) // Check client status after waiting @@ -2646,6 +2647,7 @@ func TestTagsIssue2978ReproTagReplacement(t *testing.T) { var netmapTagsAfterFirstCall []string if nmErr == nil && nm != nil && nm.SelfNode.Valid() { + //nolint:unqueryvet for _, tag := range nm.SelfNode.Tags().All() { netmapTagsAfterFirstCall = append(netmapTagsAfterFirstCall, tag) } @@ -2692,7 +2694,7 @@ func TestTagsIssue2978ReproTagReplacement(t *testing.T) { // Wait and check - bug means client still shows old tag t.Log("Step 4b: Waiting 10 seconds to see if client self view updates (bug: it should NOT)") - //nolint:forbidigo // intentional sleep to demonstrate bug timing - client should get update immediately, not after waiting + //nolint:forbidigo time.Sleep(10 * time.Second) status, err = client.Status() diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index 9b103e53..6c7228a0 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -1609,6 +1609,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { } store := &mem.Store{} + //nolint:noinlineerr if err = store.LoadFromJSON(state); err != nil { return nil, fmt.Errorf("failed to unmarshal state file: %w", err) } @@ -1624,6 +1625,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { } p := &ipn.Prefs{} + //nolint:noinlineerr if err = json.Unmarshal(currentProfile, &p); err != nil { return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err) } From 35364bfc9a7e4ffdadb8271fc3f3d8ec527e603c Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 21 Jan 2026 
16:21:44 +0000 Subject: [PATCH 30/30] nix: update to Go 1.26 and fix deprecations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Update nixpkgs to master branch to get Go 1.26rc2. This is needed for the new Go 1.26 features used in the codebase. Changes: - Use buildGo126Module and go_1_26 in build dependencies - Update vendorHash for new dependencies - Fix deprecated flake outputs: overlay→overlays.default, devShell→devShells.default, defaultPackage→packages.default - Fix stdenv.isLinux→stdenv.hostPlatform.isLinux deprecation --- flake.lock | 8 ++++---- flake.nix | 19 ++++++++++--------- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/flake.lock b/flake.lock index 50a7dde2..29f6b326 100644 --- a/flake.lock +++ b/flake.lock @@ -20,16 +20,16 @@ }, "nixpkgs": { "locked": { - "lastModified": 1766840161, - "narHash": "sha256-Ss/LHpJJsng8vz1Pe33RSGIWUOcqM1fjrehjUkdrWio=", + "lastModified": 1769011238, + "narHash": "sha256-WPiOcgZv7GQ/AVd9giOrlZjzXHwBNM4yQ+JzLrgI3Xk=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "3edc4a30ed3903fdf6f90c837f961fa6b49582d1", + "rev": "a895ec2c048eba3bceab06d5dfee5026a6b1c875", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixpkgs-unstable", + "ref": "master", "repo": "nixpkgs", "type": "github" } diff --git a/flake.nix b/flake.nix index 48aa075c..8aa8d32b 100644 --- a/flake.nix +++ b/flake.nix @@ -2,7 +2,8 @@ description = "headscale - Open Source Tailscale Control server"; inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + # TODO: Move back to nixpkgs-unstable once Go 1.26 is available there + nixpkgs.url = "github:NixOS/nixpkgs/master"; flake-utils.url = "github:numtide/flake-utils"; }; @@ -23,11 +24,11 @@ default = headscale; }; - overlay = _: prev: + overlays.default = _: prev: let pkgs = nixpkgs.legacyPackages.${prev.system}; - buildGo = pkgs.buildGo125Module; - vendorHash = "sha256-escboufgbk+lEitw48eWEIltXbaCPdysb/g4YR+extg="; + buildGo = pkgs.buildGo126Module; + vendorHash = "sha256-hL9vHunaxodGt3g/CIVirXy4OjZKTI3XwbVPPRb34OY="; in { headscale = buildGo { @@ -129,10 +130,10 @@ (system: let pkgs = import nixpkgs { - overlays = [ self.overlay ]; + overlays = [ self.overlays.default ]; inherit system; }; - buildDeps = with pkgs; [ git go_1_25 gnumake ]; + buildDeps = with pkgs; [ git go_1_26 gnumake ]; devDeps = with pkgs; buildDeps ++ [ @@ -167,7 +168,7 @@ clang-tools # clang-format protobuf-language-server ] - ++ lib.optional pkgs.stdenv.isLinux [ traceroute ]; + ++ lib.optional pkgs.stdenv.hostPlatform.isLinux [ traceroute ]; # Add entry to build a docker image with headscale # caveat: only works on Linux @@ -184,7 +185,7 @@ in rec { # `nix develop` - devShell = pkgs.mkShell { + devShells.default = pkgs.mkShell { buildInputs = devDeps ++ [ @@ -219,8 +220,8 @@ packages = with pkgs; { inherit headscale; inherit headscale-docker; + default = headscale; }; - defaultPackage = pkgs.headscale; # `nix run` apps.headscale = flake-utils.lib.mkApp {
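
Note on the ptr.To -> new() replacements that appear throughout the diffs in this
series: they rely on the extended new() builtin that the series targets in Go 1.26,
which (as assumed here) accepts an arbitrary expression, allocates a fresh variable
initialised to that value, and returns a pointer to it. A minimal, self-contained
sketch with illustrative types (not headscale's real ones):

    package main

    import (
    	"fmt"
    	"time"
    )

    // record stands in for a struct with optional (pointer) fields,
    // similar in shape to the test fixtures touched above.
    type record struct {
    	lastSeen *time.Time
    	userID   *uint
    }

    func main() {
    	id := uint(7)

    	// Pre-1.26 style: a small helper is needed to take the address
    	// of an arbitrary expression (e.g. tailscale.com/types/ptr.To).
    	ptrTo := func(v uint) *uint { return &v }
    	before := record{userID: ptrTo(id)}

    	// Go 1.26 style, as used in the diffs above: new(expr) copies the
    	// value and returns a pointer to the copy, so the helper (and its
    	// import) can be dropped.
    	after := record{
    		lastSeen: new(time.Now()),
    		userID:   new(id),
    	}

    	fmt.Println(*before.userID, *after.userID, after.lastSeen.IsZero())
    }

Because new(expr) copies its operand, an expression like new(users[0]) in the tests
points at a copy of the user value, matching what ptr.To(users[0]) produced before.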