diff --git a/.golangci.yaml b/.golangci.yaml index eda3bed4..a8a219d7 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -25,6 +25,7 @@ linters: - revive - tagliatelle - testpackage + - thelper - varnamelen - wrapcheck - wsl diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 9969f7c6..d1374ec5 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -2,7 +2,6 @@ package cli import ( "encoding/json" - "errors" "fmt" "net" "net/http" @@ -19,6 +18,7 @@ const ( errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined") errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined") errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined") + errMockOidcUsersNotDefined = Error("MOCKOIDC_USERS not defined") refreshTTL = 60 * time.Minute ) @@ -69,10 +69,11 @@ func mockOIDC() error { userStr := os.Getenv("MOCKOIDC_USERS") if userStr == "" { - return errors.New("MOCKOIDC_USERS not defined") + return errMockOidcUsersNotDefined } var users []mockoidc.MockUser + err := json.Unmarshal([]byte(userStr), &users) if err != nil { return fmt.Errorf("unmarshalling users: %w", err) @@ -133,10 +134,11 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser ErrorQueue: &mockoidc.ErrorQueue{}, } - mock.AddMiddleware(func(h http.Handler) http.Handler { + _ = mock.AddMiddleware(func(h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Info().Msgf("Request: %+v", r) h.ServeHTTP(w, r) + if r.Response != nil { log.Info().Msgf("Response: %+v", r.Response) } diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index 882460dd..01f20ad0 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -72,12 +72,12 @@ func init() { nodeCmd.AddCommand(deleteNodeCmd) tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - tagCmd.MarkFlagRequired("identifier") + _ = tagCmd.MarkFlagRequired("identifier") tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node") nodeCmd.AddCommand(tagCmd) approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - approveRoutesCmd.MarkFlagRequired("identifier") + _ = approveRoutesCmd.MarkFlagRequired("identifier") approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. 
"10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`) nodeCmd.AddCommand(approveRoutesCmd) @@ -233,10 +233,7 @@ var listNodeRoutesCmd = &cobra.Command{ return } - tableData, err := nodeRoutesToPtables(nodes) - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - } + tableData := nodeRoutesToPtables(nodes) err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() if err != nil { @@ -601,9 +598,7 @@ func nodesToPtables( return tableData, nil } -func nodeRoutesToPtables( - nodes []*v1.Node, -) (pterm.TableData, error) { +func nodeRoutesToPtables(nodes []*v1.Node) pterm.TableData { tableHeader := []string{ "ID", "Hostname", @@ -627,7 +622,7 @@ func nodeRoutesToPtables( ) } - return tableData, nil + return tableData } var tagCmd = &cobra.Command{ diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 2aaebcfa..f31d573a 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -16,6 +16,7 @@ import ( ) const ( + //nolint:gosec bypassFlag = "bypass-grpc-and-access-database-directly" ) @@ -29,13 +30,17 @@ func init() { if err := setPolicy.MarkFlagRequired("file"); err != nil { log.Fatal().Err(err).Msg("") } + setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running") policyCmd.AddCommand(setPolicy) checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") - if err := checkPolicy.MarkFlagRequired("file"); err != nil { + + err := checkPolicy.MarkFlagRequired("file") + if err != nil { log.Fatal().Err(err).Msg("") } + policyCmd.AddCommand(checkPolicy) } @@ -173,6 +178,7 @@ var setPolicy = &cobra.Command{ defer cancel() defer conn.Close() + //nolint:noinlineerr if _, err := client.SetPolicy(ctx, request); err != nil { ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) } diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index d7cdabb6..d67c2df8 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -80,6 +80,7 @@ func initConfig() { Repository: "headscale", TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }), } + res, err := latest.Check(githubTag, versionInfo.Version) if err == nil && res.Outdated { //nolint @@ -101,6 +102,7 @@ func isPreReleaseVersion(version string) bool { return true } } + return false } diff --git a/cmd/headscale/cli/serve.go b/cmd/headscale/cli/serve.go index 8f05f851..f815f9f9 100644 --- a/cmd/headscale/cli/serve.go +++ b/cmd/headscale/cli/serve.go @@ -23,8 +23,7 @@ var serveCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { app, err := newHeadscaleServerWithConfig() if err != nil { - var squibbleErr squibble.ValidationError - if errors.As(err, &squibbleErr) { + if squibbleErr, ok := errors.AsType[squibble.ValidationError](err); ok { fmt.Printf("SQLite schema failed to validate:\n") fmt.Println(squibbleErr.Diff) } diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 9a816c78..f7db7ed4 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -14,6 +14,12 @@ import ( "google.golang.org/grpc/status" ) +// Sentinel errors for CLI commands. 
+var ( + ErrNameOrIDRequired = errors.New("--name or --identifier flag is required") + ErrMultipleUsersFoundUseID = errors.New("unable to determine user, query returned multiple users, use ID") +) + func usernameAndIDFlag(cmd *cobra.Command) { cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)") cmd.Flags().StringP("name", "n", "", "Username") @@ -23,12 +29,12 @@ func usernameAndIDFlag(cmd *cobra.Command) { // If both are empty, it will exit the program with an error. func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) { username, _ := cmd.Flags().GetString("name") + identifier, _ := cmd.Flags().GetInt64("identifier") if username == "" && identifier < 0 { - err := errors.New("--name or --identifier flag is required") ErrorOutput( - err, - "Cannot rename user: "+status.Convert(err).Message(), + ErrNameOrIDRequired, + "Cannot rename user: "+status.Convert(ErrNameOrIDRequired).Message(), "", ) } @@ -50,7 +56,7 @@ func init() { userCmd.AddCommand(renameUserCmd) usernameAndIDFlag(renameUserCmd) renameUserCmd.Flags().StringP("new-name", "r", "", "New username") - renameNodeCmd.MarkFlagRequired("new-name") + _ = renameUserCmd.MarkFlagRequired("new-name") } var errMissingParameter = errors.New("missing parameters") @@ -94,6 +100,7 @@ var createUserCmd = &cobra.Command{ } if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" { + //nolint:noinlineerr if _, err := url.Parse(pictureURL); err != nil { ErrorOutput( err, @@ -148,7 +155,7 @@ var destroyUserCmd = &cobra.Command{ } if len(users.GetUsers()) != 1 { - err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") + err := ErrMultipleUsersFoundUseID ErrorOutput( err, "Error: "+status.Convert(err).Message(), @@ -276,7 +283,7 @@ var renameUserCmd = &cobra.Command{ } if len(users.GetUsers()) != 1 { - err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") + err := ErrMultipleUsersFoundUseID ErrorOutput( err, "Error: "+status.Convert(err).Message(), diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go index 7c5b5214..d56bb589 100644 --- a/cmd/hi/cleanup.go +++ b/cmd/hi/cleanup.go @@ -10,11 +10,11 @@ import ( "time" "github.com/cenkalti/backoff/v5" + cerrdefs "github.com/containerd/errdefs" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/filters" "github.com/docker/docker/api/types/image" "github.com/docker/docker/client" - "github.com/docker/docker/errdefs" ) // cleanupBeforeTest performs cleanup operations before running tests. @@ -25,6 +25,7 @@ func cleanupBeforeTest(ctx context.Context) error { return fmt.Errorf("failed to clean stale test containers: %w", err) } + //nolint:noinlineerr if err := pruneDockerNetworks(ctx); err != nil { return fmt.Errorf("failed to prune networks: %w", err) } @@ -55,6 +56,7 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI // killTestContainers terminates and removes all test containers. 
func killTestContainers(ctx context.Context) error { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -69,8 +71,10 @@ func killTestContainers(ctx context.Context) error { } removed := 0 + for _, cont := range containers { shouldRemove := false + for _, name := range cont.Names { if strings.Contains(name, "headscale-test-suite") || strings.Contains(name, "hs-") || @@ -107,6 +111,7 @@ func killTestContainers(ctx context.Context) error { // This function filters containers by the hi.run-id label to only affect containers // belonging to the specified test run, leaving other concurrent test runs untouched. func killTestContainersByRunID(ctx context.Context, runID string) error { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -149,6 +154,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error { // This is useful for cleaning up leftover containers from previous crashed or interrupted test runs // without interfering with currently running concurrent tests. func cleanupStaleTestContainers(ctx context.Context) error { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -223,6 +229,7 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container // pruneDockerNetworks removes unused Docker networks. func pruneDockerNetworks(ctx context.Context) error { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -245,6 +252,7 @@ func pruneDockerNetworks(ctx context.Context) error { // cleanOldImages removes test-related and old dangling Docker images. func cleanOldImages(ctx context.Context) error { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -259,8 +267,10 @@ func cleanOldImages(ctx context.Context) error { } removed := 0 + for _, img := range images { shouldRemove := false + for _, tag := range img.RepoTags { if strings.Contains(tag, "hs-") || strings.Contains(tag, "headscale-integration") || @@ -295,6 +305,7 @@ func cleanOldImages(ctx context.Context) error { // cleanCacheVolume removes the Docker volume used for Go module cache. 
func cleanCacheVolume(ctx context.Context) error { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -302,11 +313,12 @@ func cleanCacheVolume(ctx context.Context) error { defer cli.Close() volumeName := "hs-integration-go-cache" + err = cli.VolumeRemove(ctx, volumeName, true) if err != nil { - if errdefs.IsNotFound(err) { + if cerrdefs.IsNotFound(err) { fmt.Printf("Go module cache volume not found: %s\n", volumeName) - } else if errdefs.IsConflict(err) { + } else if cerrdefs.IsConflict(err) { fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName) } else { fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err) diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index a6b94b25..62b07f2f 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -14,6 +14,7 @@ import ( "strings" "time" + cerrdefs "github.com/containerd/errdefs" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/image" "github.com/docker/docker/api/types/mount" @@ -26,10 +27,21 @@ var ( ErrTestFailed = errors.New("test failed") ErrUnexpectedContainerWait = errors.New("unexpected end of container wait") ErrNoDockerContext = errors.New("no docker context found") + ErrMemoryLimitExceeded = errors.New("container exceeded memory limits") +) + +// Docker container constants. +const ( + containerFinalStateWait = 10 * time.Second + containerStateCheckInterval = 500 * time.Millisecond + dirPermissions = 0o755 ) // runTestContainer executes integration tests in a Docker container. +// +//nolint:gocyclo func runTestContainer(ctx context.Context, config *RunConfig) error { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -52,6 +64,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { } const dirPerm = 0o755 + //nolint:noinlineerr if err := os.MkdirAll(absLogsDir, dirPerm); err != nil { return fmt.Errorf("failed to create logs directory: %w", err) } @@ -60,7 +73,9 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if config.Verbose { log.Printf("Running pre-test cleanup...") } - if err := cleanupBeforeTest(ctx); err != nil && config.Verbose { + + err := cleanupBeforeTest(ctx) + if err != nil && config.Verbose { log.Printf("Warning: pre-test cleanup failed: %v", err) } } @@ -71,6 +86,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { } imageName := "golang:" + config.GoVersion + //nolint:noinlineerr if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil { return fmt.Errorf("failed to ensure image availability: %w", err) } @@ -84,6 +100,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { log.Printf("Created container: %s", resp.ID) } + //nolint:noinlineerr if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { return fmt.Errorf("failed to start container: %w", err) } @@ -95,13 +112,17 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { // Start stats collection for container resource monitoring (if enabled) var statsCollector *StatsCollector + if config.Stats { var err error + + //nolint:contextcheck statsCollector, err = NewStatsCollector() if err != nil { if config.Verbose { log.Printf("Warning: failed to create stats collector: %v", err) } + statsCollector = nil } @@ -110,7 +131,8 @@ func runTestContainer(ctx 
context.Context, config *RunConfig) error { // Start stats collection immediately - no need for complex retry logic // The new implementation monitors Docker events and will catch containers as they start - if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil { + err := statsCollector.StartCollection(ctx, runID, config.Verbose) + if err != nil { if config.Verbose { log.Printf("Warning: failed to start stats collection: %v", err) } @@ -122,11 +144,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { exitCode, err := streamAndWait(ctx, cli, resp.ID) // Ensure all containers have finished and logs are flushed before extracting artifacts - if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose { + waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose) + if waitErr != nil && config.Verbose { log.Printf("Warning: failed to wait for container finalization: %v", waitErr) } // Extract artifacts from test containers before cleanup + //nolint:noinlineerr if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose { log.Printf("Warning: failed to extract artifacts from containers: %v", err) } @@ -140,12 +164,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if len(violations) > 0 { log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:") log.Printf("=================================") + for _, violation := range violations { log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB", violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB) } - return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations)) + return fmt.Errorf("%w: %d container(s)", ErrMemoryLimitExceeded, len(violations)) } } @@ -344,9 +369,10 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC testContainers := getCurrentTestContainers(containers, testContainerID, verbose) // Wait for all test containers to reach a final state - maxWaitTime := 10 * time.Second - checkInterval := 500 * time.Millisecond + maxWaitTime := containerFinalStateWait + checkInterval := containerStateCheckInterval timeout := time.After(maxWaitTime) + ticker := time.NewTicker(checkInterval) defer ticker.Stop() @@ -356,6 +382,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction") } + return nil case <-ticker.C: allFinalized := true @@ -366,12 +393,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err) } + continue } // Check if container is in a final state if !isContainerFinalized(inspect.State) { allFinalized = false + if verbose { log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status) } @@ -384,6 +413,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("All test containers finalized, ready for artifact extraction") } + return nil } } @@ -400,13 +430,16 @@ func isContainerFinalized(state *container.State) bool { func findProjectRoot(startPath string) string { current := startPath for { + //nolint:noinlineerr if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { return current } + parent := filepath.Dir(current) if parent == 
current { return startPath } + current = parent } } @@ -416,6 +449,7 @@ func boolToInt(b bool) int { if b { return 1 } + return 0 } @@ -435,6 +469,7 @@ func createDockerClient() (*client.Client, error) { } var clientOpts []client.Opt + clientOpts = append(clientOpts, client.WithAPIVersionNegotiation()) if contextInfo != nil { @@ -444,6 +479,7 @@ func createDockerClient() (*client.Client, error) { if runConfig.Verbose { log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host) } + clientOpts = append(clientOpts, client.WithHost(host)) } } @@ -459,13 +495,15 @@ func createDockerClient() (*client.Client, error) { // getCurrentDockerContext retrieves the current Docker context information. func getCurrentDockerContext() (*DockerContext, error) { - cmd := exec.Command("docker", "context", "inspect") + cmd := exec.CommandContext(context.Background(), "docker", "context", "inspect") + output, err := cmd.Output() if err != nil { return nil, fmt.Errorf("failed to get docker context: %w", err) } var contexts []DockerContext + //nolint:noinlineerr if err := json.Unmarshal(output, &contexts); err != nil { return nil, fmt.Errorf("failed to parse docker context: %w", err) } @@ -486,11 +524,12 @@ func getDockerSocketPath() string { // checkImageAvailableLocally checks if the specified Docker image is available locally. func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) { - _, _, err := cli.ImageInspectWithRaw(ctx, imageName) + _, err := cli.ImageInspect(ctx, imageName) if err != nil { - if client.IsErrNotFound(err) { + if cerrdefs.IsNotFound(err) { return false, nil } + return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err) } @@ -509,6 +548,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str if verbose { log.Printf("Image %s is available locally", imageName) } + return nil } @@ -533,6 +573,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str if err != nil { return fmt.Errorf("failed to read pull output: %w", err) } + log.Printf("Image %s pulled successfully", imageName) } @@ -547,9 +588,11 @@ func listControlFiles(logsDir string) { return } - var logFiles []string - var dataFiles []string - var dataDirs []string + var ( + logFiles []string + dataFiles []string + dataDirs []string + ) for _, entry := range entries { name := entry.Name() @@ -578,6 +621,7 @@ func listControlFiles(logsDir string) { if len(logFiles) > 0 { log.Printf("Headscale logs:") + for _, file := range logFiles { log.Printf(" %s", file) } @@ -585,9 +629,11 @@ func listControlFiles(logsDir string) { if len(dataFiles) > 0 || len(dataDirs) > 0 { log.Printf("Headscale data:") + for _, file := range dataFiles { log.Printf(" %s", file) } + for _, dir := range dataDirs { log.Printf(" %s/", dir) } @@ -596,6 +642,7 @@ func listControlFiles(logsDir string) { // extractArtifactsFromContainers collects container logs and files from the specific test run. 
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return fmt.Errorf("failed to create Docker client: %w", err) @@ -612,9 +659,11 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose) extractedCount := 0 + for _, cont := range currentTestContainers { // Extract container logs and tar files - if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil { + err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose) + if err != nil { if verbose { log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err) } @@ -622,6 +671,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi if verbose { log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12]) } + extractedCount++ } } @@ -645,11 +695,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st // Find the test container to get its run ID label var runID string + for _, cont := range containers { if cont.ID == testContainerID { if cont.Labels != nil { runID = cont.Labels["hi.run-id"] } + break } } @@ -690,18 +742,21 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st // extractContainerArtifacts saves logs and tar files from a container. func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error { // Ensure the logs directory exists - if err := os.MkdirAll(logsDir, 0o755); err != nil { + err := os.MkdirAll(logsDir, dirPermissions) + if err != nil { return fmt.Errorf("failed to create logs directory: %w", err) } // Extract container logs + //nolint:noinlineerr if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil { return fmt.Errorf("failed to extract logs: %w", err) } // Extract tar files for headscale containers only if strings.HasPrefix(containerName, "hs-") { - if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil { + err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose) + if err != nil { if verbose { log.Printf("Warning: failed to extract files from %s: %v", containerName, err) } @@ -741,11 +796,13 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID, } // Write stdout logs + //nolint:gosec,mnd,noinlineerr if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { return fmt.Errorf("failed to write stdout log: %w", err) } // Write stderr logs + //nolint:gosec,mnd,noinlineerr if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { return fmt.Errorf("failed to write stderr log: %w", err) } diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index 8af6051f..0c3a4764 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -38,12 +38,15 @@ func runDoctorCheck(ctx context.Context) error { } // Check 3: Go installation + //nolint:contextcheck results = append(results, checkGoInstallation()) // Check 4: Git repository + //nolint:contextcheck results = append(results, checkGitRepository()) // Check 5: Required files + //nolint:contextcheck results = append(results, checkRequiredFiles()) // Display results @@ -86,6 +89,7 @@ func 
checkDockerBinary() DoctorResult { // checkDockerDaemon verifies Docker daemon is running and accessible. func checkDockerDaemon(ctx context.Context) DoctorResult { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return DoctorResult{ @@ -125,6 +129,7 @@ func checkDockerDaemon(ctx context.Context) DoctorResult { // checkDockerContext verifies Docker context configuration. func checkDockerContext(_ context.Context) DoctorResult { + //nolint:contextcheck contextInfo, err := getCurrentDockerContext() if err != nil { return DoctorResult{ @@ -155,6 +160,7 @@ func checkDockerContext(_ context.Context) DoctorResult { // checkDockerSocket verifies Docker socket accessibility. func checkDockerSocket(ctx context.Context) DoctorResult { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return DoctorResult{ @@ -192,6 +198,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult { // checkGolangImage verifies the golang Docker image is available locally or can be pulled. func checkGolangImage(ctx context.Context) DoctorResult { + //nolint:contextcheck cli, err := createDockerClient() if err != nil { return DoctorResult{ @@ -265,7 +272,8 @@ func checkGoInstallation() DoctorResult { } } - cmd := exec.Command("go", "version") + cmd := exec.CommandContext(context.Background(), "go", "version") + output, err := cmd.Output() if err != nil { return DoctorResult{ @@ -286,7 +294,8 @@ func checkGoInstallation() DoctorResult { // checkGitRepository verifies we're in a git repository. func checkGitRepository() DoctorResult { - cmd := exec.Command("git", "rev-parse", "--git-dir") + cmd := exec.CommandContext(context.Background(), "git", "rev-parse", "--git-dir") + err := cmd.Run() if err != nil { return DoctorResult{ @@ -316,9 +325,12 @@ func checkRequiredFiles() DoctorResult { } var missingFiles []string + for _, file := range requiredFiles { - cmd := exec.Command("test", "-e", file) - if err := cmd.Run(); err != nil { + cmd := exec.CommandContext(context.Background(), "test", "-e", file) + + err := cmd.Run() + if err != nil { missingFiles = append(missingFiles, file) } } @@ -350,6 +362,7 @@ func displayDoctorResults(results []DoctorResult) { for _, result := range results { var icon string + switch result.Status { case "PASS": icon = "✅" diff --git a/cmd/hi/main.go b/cmd/hi/main.go index baecc6f3..2bbfefe0 100644 --- a/cmd/hi/main.go +++ b/cmd/hi/main.go @@ -79,13 +79,18 @@ func main() { } func cleanAll(ctx context.Context) error { - if err := killTestContainers(ctx); err != nil { + err := killTestContainers(ctx) + if err != nil { return err } - if err := pruneDockerNetworks(ctx); err != nil { + + err = pruneDockerNetworks(ctx) + if err != nil { return err } - if err := cleanOldImages(ctx); err != nil { + + err = cleanOldImages(ctx) + if err != nil { return err } diff --git a/cmd/hi/run.go b/cmd/hi/run.go index 1694399d..4a0506e5 100644 --- a/cmd/hi/run.go +++ b/cmd/hi/run.go @@ -48,7 +48,9 @@ func runIntegrationTest(env *command.Env) error { if runConfig.Verbose { log.Printf("Running pre-flight system checks...") } - if err := runDoctorCheck(env.Context()); err != nil { + + err := runDoctorCheck(env.Context()) + if err != nil { return fmt.Errorf("pre-flight checks failed: %w", err) } @@ -66,8 +68,10 @@ func runIntegrationTest(env *command.Env) error { func detectGoVersion() string { goModPath := filepath.Join("..", "..", "go.mod") + //nolint:noinlineerr if _, err := os.Stat("go.mod"); err == nil { goModPath = "go.mod" + //nolint:noinlineerr } else if _, err := 
os.Stat("../../go.mod"); err == nil { goModPath = "../../go.mod" } @@ -94,8 +98,10 @@ func detectGoVersion() string { // splitLines splits a string into lines without using strings.Split. func splitLines(s string) []string { - var lines []string - var current string + var ( + lines []string + current string + ) for _, char := range s { if char == '\n' { diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index b68215a6..dc02286b 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -1,23 +1,32 @@ package main import ( + "cmp" "context" "encoding/json" "errors" "fmt" "log" - "sort" + "slices" "strings" "sync" "time" - "github.com/docker/docker/api/types" "github.com/docker/docker/api/types/container" "github.com/docker/docker/api/types/events" "github.com/docker/docker/api/types/filters" "github.com/docker/docker/client" ) +// ErrStatsCollectionAlreadyStarted is returned when stats collection is already running. +var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started") + +// Stats calculation constants. +const ( + bytesPerKB = 1024 + percentageMultiplier = 100.0 +) + // ContainerStats represents statistics for a single container. type ContainerStats struct { ContainerID string @@ -63,17 +72,19 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver defer sc.mutex.Unlock() if sc.collectionStarted { - return errors.New("stats collection already started") + return ErrStatsCollectionAlreadyStarted } sc.collectionStarted = true // Start monitoring existing containers sc.wg.Add(1) + go sc.monitorExistingContainers(ctx, runID, verbose) // Start Docker events monitoring for new containers sc.wg.Add(1) + go sc.monitorDockerEvents(ctx, runID, verbose) if verbose { @@ -87,10 +98,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver func (sc *StatsCollector) StopCollection() { // Check if already stopped without holding lock sc.mutex.RLock() + if !sc.collectionStarted { sc.mutex.RUnlock() return } + sc.mutex.RUnlock() // Signal stop to all goroutines @@ -114,6 +127,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s if verbose { log.Printf("Failed to list existing containers: %v", err) } + return } @@ -147,13 +161,13 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, case event := <-events: if event.Type == "container" && event.Action == "start" { // Get container details - containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) + containerInfo, err := sc.client.ContainerInspect(ctx, event.Actor.ID) if err != nil { continue } - // Convert to types.Container format for consistency - cont := types.Container{ + // Convert to container.Summary format for consistency + cont := container.Summary{ ID: containerInfo.ID, Names: []string{containerInfo.Name}, Labels: containerInfo.Config.Labels, @@ -167,13 +181,14 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, if verbose { log.Printf("Error in Docker events stream: %v", err) } + return } } } // shouldMonitorContainer determines if a container should be monitored. 
-func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { +func (sc *StatsCollector) shouldMonitorContainer(cont container.Summary, runID string) bool { // Check if it has the correct run ID label if cont.Labels == nil || cont.Labels["hi.run-id"] != runID { return false @@ -213,6 +228,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI } sc.wg.Add(1) + go sc.collectStatsForContainer(ctx, containerID, verbose) } @@ -226,12 +242,14 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe if verbose { log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err) } + return } defer statsResponse.Body.Close() decoder := json.NewDecoder(statsResponse.Body) - var prevStats *container.Stats + + var prevStats *container.StatsResponse for { select { @@ -240,12 +258,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe case <-ctx.Done(): return default: - var stats container.Stats - if err := decoder.Decode(&stats); err != nil { + var stats container.StatsResponse + + err := decoder.Decode(&stats) + if err != nil { // EOF is expected when container stops or stream ends if err.Error() != "EOF" && verbose { log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err) } + return } @@ -256,13 +277,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe } // Calculate memory usage in MB - memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024) + memoryMB := float64(stats.MemoryStats.Usage) / (bytesPerKB * bytesPerKB) // Store the sample (skip first sample since CPU calculation needs previous stats) if prevStats != nil { // Get container stats reference without holding the main mutex - var containerStats *ContainerStats - var exists bool + var ( + containerStats *ContainerStats + exists bool + ) sc.mutex.RLock() containerStats, exists = sc.containers[containerID] @@ -286,7 +309,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe } // calculateCPUPercent calculates CPU usage percentage from Docker stats. 
-func calculateCPUPercent(prevStats, stats *container.Stats) float64 { +func calculateCPUPercent(prevStats, stats *container.StatsResponse) float64 { // CPU calculation based on Docker's implementation cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage) systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage) @@ -299,7 +322,7 @@ func calculateCPUPercent(prevStats, stats *container.Stats) float64 { numCPUs = 1.0 } - return (cpuDelta / systemDelta) * numCPUs * 100.0 + return (cpuDelta / systemDelta) * numCPUs * percentageMultiplier } return 0.0 @@ -331,10 +354,12 @@ type StatsSummary struct { func (sc *StatsCollector) GetSummary() []ContainerStatsSummary { // Take snapshot of container references without holding main lock long sc.mutex.RLock() + containerRefs := make([]*ContainerStats, 0, len(sc.containers)) for _, containerStats := range sc.containers { containerRefs = append(containerRefs, containerStats) } + sc.mutex.RUnlock() summaries := make([]ContainerStatsSummary, 0, len(containerRefs)) @@ -371,8 +396,8 @@ func (sc *StatsCollector) GetSummary() []ContainerStatsSummary { } // Sort by container name for consistent output - sort.Slice(summaries, func(i, j int) bool { - return summaries[i].ContainerName < summaries[j].ContainerName + slices.SortFunc(summaries, func(a, b ContainerStatsSummary) int { + return cmp.Compare(a.ContainerName, b.ContainerName) }) return summaries @@ -384,23 +409,25 @@ func calculateStatsSummary(values []float64) StatsSummary { return StatsSummary{} } - min := values[0] - max := values[0] + minVal := values[0] + maxVal := values[0] sum := 0.0 for _, value := range values { - if value < min { - min = value + if value < minVal { + minVal = value } - if value > max { - max = value + + if value > maxVal { + maxVal = value } + sum += value } return StatsSummary{ - Min: min, - Max: max, + Min: minVal, + Max: maxVal, Average: sum / float64(len(values)), } } @@ -434,6 +461,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo } summaries := sc.GetSummary() + var violations []MemoryViolation for _, summary := range summaries { diff --git a/cmd/mapresponses/main.go b/cmd/mapresponses/main.go index 5d7ad07d..1951ca4b 100644 --- a/cmd/mapresponses/main.go +++ b/cmd/mapresponses/main.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "os" @@ -11,6 +12,8 @@ import ( "github.com/juanfont/headscale/integration/integrationutil" ) +var errDirectoryRequired = errors.New("directory is required") + type MapConfig struct { Directory string `flag:"directory,Directory to read map responses from"` } @@ -40,7 +43,7 @@ func main() { // runIntegrationTest executes the integration test workflow. 
func runOnline(env *command.Env) error { if mapConfig.Directory == "" { - return fmt.Errorf("directory is required") + return errDirectoryRequired } resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory) @@ -57,5 +60,6 @@ func runOnline(env *command.Env) error { os.Stderr.Write(out) os.Stderr.Write([]byte("\n")) + return nil } diff --git a/flake.lock b/flake.lock index 50a7dde2..29f6b326 100644 --- a/flake.lock +++ b/flake.lock @@ -20,16 +20,16 @@ }, "nixpkgs": { "locked": { - "lastModified": 1766840161, - "narHash": "sha256-Ss/LHpJJsng8vz1Pe33RSGIWUOcqM1fjrehjUkdrWio=", + "lastModified": 1769011238, + "narHash": "sha256-WPiOcgZv7GQ/AVd9giOrlZjzXHwBNM4yQ+JzLrgI3Xk=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "3edc4a30ed3903fdf6f90c837f961fa6b49582d1", + "rev": "a895ec2c048eba3bceab06d5dfee5026a6b1c875", "type": "github" }, "original": { "owner": "NixOS", - "ref": "nixpkgs-unstable", + "ref": "master", "repo": "nixpkgs", "type": "github" } diff --git a/flake.nix b/flake.nix index 48aa075c..8aa8d32b 100644 --- a/flake.nix +++ b/flake.nix @@ -2,7 +2,8 @@ description = "headscale - Open Source Tailscale Control server"; inputs = { - nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable"; + # TODO: Move back to nixpkgs-unstable once Go 1.26 is available there + nixpkgs.url = "github:NixOS/nixpkgs/master"; flake-utils.url = "github:numtide/flake-utils"; }; @@ -23,11 +24,11 @@ default = headscale; }; - overlay = _: prev: + overlays.default = _: prev: let pkgs = nixpkgs.legacyPackages.${prev.system}; - buildGo = pkgs.buildGo125Module; - vendorHash = "sha256-escboufgbk+lEitw48eWEIltXbaCPdysb/g4YR+extg="; + buildGo = pkgs.buildGo126Module; + vendorHash = "sha256-hL9vHunaxodGt3g/CIVirXy4OjZKTI3XwbVPPRb34OY="; in { headscale = buildGo { @@ -129,10 +130,10 @@ (system: let pkgs = import nixpkgs { - overlays = [ self.overlay ]; + overlays = [ self.overlays.default ]; inherit system; }; - buildDeps = with pkgs; [ git go_1_25 gnumake ]; + buildDeps = with pkgs; [ git go_1_26 gnumake ]; devDeps = with pkgs; buildDeps ++ [ @@ -167,7 +168,7 @@ clang-tools # clang-format protobuf-language-server ] - ++ lib.optional pkgs.stdenv.isLinux [ traceroute ]; + ++ lib.optional pkgs.stdenv.hostPlatform.isLinux [ traceroute ]; # Add entry to build a docker image with headscale # caveat: only works on Linux @@ -184,7 +185,7 @@ in rec { # `nix develop` - devShell = pkgs.mkShell { + devShells.default = pkgs.mkShell { buildInputs = devDeps ++ [ @@ -219,8 +220,8 @@ packages = with pkgs; { inherit headscale; inherit headscale-docker; + default = headscale; }; - defaultPackage = pkgs.headscale; # `nix run` apps.headscale = flake-utils.lib.mkApp { diff --git a/go.mod b/go.mod index 905a27db..5616acda 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/juanfont/headscale -go 1.25 +go 1.26rc2 require ( github.com/arl/statsviz v0.7.2 diff --git a/hscontrol/app.go b/hscontrol/app.go index aa011503..cadcd227 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -67,6 +67,9 @@ var ( ) ) +// oidcProviderInitTimeout is the timeout for initializing the OIDC provider. 
+const oidcProviderInitTimeout = 30 * time.Second + var ( debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK") debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT") @@ -142,6 +145,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { if !ok { log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed") log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore") + return } @@ -157,10 +161,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.ephemeralGC = ephemeralGC var authProvider AuthProvider + authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), oidcProviderInitTimeout) defer cancel() + oidcProvider, err := NewAuthProviderOIDC( ctx, &app, @@ -177,6 +183,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { authProvider = oidcProvider } } + app.authProvider = authProvider if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS @@ -251,9 +258,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { lastExpiryCheck := time.Unix(0, 0) derpTickerChan := make(<-chan time.Time) + if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 { derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency) defer derpTicker.Stop() + derpTickerChan = derpTicker.C } @@ -271,8 +280,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { return case <-expireTicker.C: - var expiredNodeChanges []change.Change - var changed bool + var ( + expiredNodeChanges []change.Change + changed bool + ) lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck) @@ -287,11 +298,14 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { case <-derpTickerChan: log.Info().Msg("Fetching DERPMap updates") + + //nolint:contextcheck derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { derpMap, err := derp.GetDERPMap(h.cfg.DERP) if err != nil { return nil, err } + if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { region, _ := h.DERPServer.GenerateRegion() derpMap.Regions[region.RegionID] = ®ion @@ -303,6 +317,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { log.Error().Err(err).Msg("failed to build new DERPMap, retrying later") continue } + h.state.SetDERPMap(derpMap) h.Change(change.DERPMap()) @@ -311,6 +326,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { if !ok { continue } + h.cfg.TailcfgDNSConfig.ExtraRecords = records h.Change(change.ExtraRecords()) @@ -390,6 +406,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler writeUnauthorized := func(statusCode int) { writer.WriteHeader(statusCode) + + //nolint:noinlineerr if _, err := writer.Write([]byte("Unauthorized")); err != nil { log.Error().Err(err).Msg("writing HTTP response failed") } @@ -486,6 +504,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { // Serve launches the HTTP and gRPC server service Headscale and the API. 
func (h *Headscale) Serve() error { var err error + capver.CanOldCodeBeCleanedUp() if profilingEnabled { @@ -512,6 +531,7 @@ func (h *Headscale) Serve() error { Msg("Clients with a lower minimum version will be rejected") h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state) + h.mapBatcher.Start() defer h.mapBatcher.Close() @@ -545,6 +565,7 @@ func (h *Headscale) Serve() error { // around between restarts, they will reconnect and the GC will // be cancelled. go h.ephemeralGC.Start() + ephmNodes := h.state.ListEphemeralNodes() for _, node := range ephmNodes.All() { h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout) @@ -555,7 +576,9 @@ func (h *Headscale) Serve() error { if err != nil { return fmt.Errorf("setting up extrarecord manager: %w", err) } + h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records() + go h.extraRecordMan.Run() defer h.extraRecordMan.Close() } @@ -564,6 +587,7 @@ func (h *Headscale) Serve() error { // records updates scheduleCtx, scheduleCancel := context.WithCancel(context.Background()) defer scheduleCancel() + go h.scheduledTasks(scheduleCtx) if zl.GlobalLevel() == zl.TraceLevel { @@ -751,7 +775,6 @@ func (h *Headscale) Serve() error { log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)") } - var tailsqlContext context.Context if tailsqlEnabled { if h.cfg.Database.Type != types.DatabaseSqlite { @@ -863,6 +886,8 @@ func (h *Headscale) Serve() error { // Close state connections info("closing state and database") + + //nolint:contextcheck err = h.state.Close() if err != nil { log.Error().Err(err).Msg("failed to close state") diff --git a/hscontrol/auth.go b/hscontrol/auth.go index ac5968e3..1d49f5b4 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -16,12 +16,11 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) type AuthProvider interface { - RegisterHandler(http.ResponseWriter, *http.Request) - AuthURL(types.RegistrationID) string + RegisterHandler(w http.ResponseWriter, r *http.Request) + AuthURL(regID types.RegistrationID) string } func (h *Headscale) handleRegister( @@ -52,6 +51,7 @@ func (h *Headscale) handleRegister( if err != nil { return nil, fmt.Errorf("handling logout: %w", err) } + if resp != nil { return resp, nil } @@ -113,8 +113,7 @@ func (h *Headscale) handleRegister( resp, err := h.handleRegisterWithAuthKey(req, machineKey) if err != nil { // Preserve HTTPError types so they can be handled properly by the HTTP layer - var httpErr HTTPError - if errors.As(err, &httpErr) { + if httpErr, ok := errors.AsType[HTTPError](err); ok { return nil, httpErr } @@ -133,7 +132,7 @@ func (h *Headscale) handleRegister( } // handleLogout checks if the [tailcfg.RegisterRequest] is a -// logout attempt from a node. If the node is not attempting to +// logout attempt from a node. If the node is not attempting to. func (h *Headscale) handleLogout( node types.NodeView, req tailcfg.RegisterRequest, @@ -160,6 +159,7 @@ func (h *Headscale) handleLogout( Interface("reg.req", req). Bool("unexpected", true). 
Msg("Node key expired, forcing re-authentication") + return &tailcfg.RegisterResponse{ NodeKeyExpired: true, MachineAuthorized: false, @@ -279,6 +279,7 @@ func (h *Headscale) waitForFollowup( // registration is expired in the cache, instruct the client to try a new registration return h.reqToNewRegisterResponse(req, machineKey) } + return nodeToRegisterResponse(node.View()), nil } } @@ -316,7 +317,7 @@ func (h *Headscale) reqToNewRegisterResponse( MachineKey: machineKey, NodeKey: req.NodeKey, Hostinfo: hostinfo, - LastSeen: ptr.To(time.Now()), + LastSeen: new(time.Now()), }, ) @@ -344,8 +345,8 @@ func (h *Headscale) handleRegisterWithAuthKey( if errors.Is(err, gorm.ErrRecordNotFound) { return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) } - var perr types.PAKError - if errors.As(err, &perr) { + + if perr, ok := errors.AsType[types.PAKError](err); ok { return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil) } @@ -355,7 +356,7 @@ func (h *Headscale) handleRegisterWithAuthKey( // If node is not valid, it means an ephemeral node was deleted during logout if !node.Valid() { h.Change(changed) - return nil, nil + return nil, nil //nolint:nilnil } // This is a bit of a back and forth, but we have a bit of a chicken and egg @@ -435,6 +436,7 @@ func (h *Headscale) handleRegisterInteractive( Str("generated.hostname", hostname). Msg("Received registration request with empty hostname, generated default") } + hostinfo.Hostname = hostname nodeToRegister := types.NewRegisterNode( @@ -443,7 +445,7 @@ func (h *Headscale) handleRegisterInteractive( MachineKey: machineKey, NodeKey: req.NodeKey, Hostinfo: hostinfo, - LastSeen: ptr.To(time.Now()), + LastSeen: new(time.Now()), }, ) diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 1677642f..bc6c7cc2 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -2,6 +2,7 @@ package hscontrol import ( "context" + "errors" "fmt" "net/url" "strings" @@ -16,14 +17,20 @@ import ( "tailscale.com/types/key" ) -// Interactive step type constants +// Test sentinel errors. +var ( + errNodeNotFoundAfterSetup = errors.New("node not found after setup") + errInvalidAuthURLFormat = errors.New("invalid AuthURL format") +) + +// Interactive step type constants. const ( stepTypeInitialRequest = "initial_request" stepTypeAuthCompletion = "auth_completion" stepTypeFollowupRequest = "followup_request" ) -// interactiveStep defines a step in the interactive authentication workflow +// interactiveStep defines a step in the interactive authentication workflow. 
type interactiveStep struct { stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest expectAuthURL bool @@ -31,6 +38,7 @@ type interactiveStep struct { callAuthPath bool // Real call to HandleNodeFromAuthPath, not mocked } +//nolint:gocyclo func TestAuthenticationFlows(t *testing.T) { // Shared test keys for consistent behavior across test cases machineKey1 := key.NewMachine() @@ -69,12 +77,15 @@ func TestAuthenticationFlows(t *testing.T) { { name: "preauth_key_valid_new_node", setupFunc: func(t *testing.T, app *Headscale) (string, error) { + t.Helper() + user := app.state.CreateUserForTest("preauth-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -89,7 +100,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -111,6 +122,8 @@ func TestAuthenticationFlows(t *testing.T) { { name: "preauth_key_reusable_multiple_nodes", setupFunc: func(t *testing.T, app *Headscale) (string, error) { + t.Helper() + user := app.state.CreateUserForTest("reusable-user") pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) @@ -129,6 +142,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) if err != nil { return "", err @@ -154,7 +168,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, + machineKey: machineKey2.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -163,6 +177,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify both nodes exist node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) assert.True(t, found2) assert.Equal(t, "reusable-node-1", node1.Hostname()) @@ -196,6 +211,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) if err != nil { return "", err @@ -221,12 +237,13 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, + machineKey: machineKey2.Public, wantError: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // First node should exist, second should not _, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) _, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) assert.False(t, found2) }, @@ -254,7 +271,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -272,6 +289,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) 
tailcfg.RegisterRequest { @@ -286,7 +304,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -323,7 +341,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -352,7 +370,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -391,6 +409,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -400,8 +419,10 @@ func TestAuthenticationFlows(t *testing.T) { // Wait for node to be available in NodeStore with debug info var attemptCount int + require.EventuallyWithT(t, func(c *assert.CollectT) { attemptCount++ + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) if assert.True(c, found, "node should be available in NodeStore") { t.Logf("Node found in NodeStore after %d attempts", attemptCount) @@ -417,7 +438,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), // Past expiry = logout } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, wantExpired: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { @@ -451,6 +472,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -471,7 +493,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, // Different machine key + machineKey: machineKey2.Public, // Different machine key wantError: true, }, // TEST: Existing node cannot extend expiry without re-auth @@ -500,6 +522,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -520,7 +543,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(48 * time.Hour), // Future time = extend attempt } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Expired node must re-authenticate @@ -549,25 +572,31 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err } // Wait for node to be available in NodeStore - var node types.NodeView - var found bool + var ( + node types.NodeView + found bool + ) + require.EventuallyWithT(t, 
func(c *assert.CollectT) { node, found = app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(c, found, "node should be available in NodeStore") }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + if !found { - return "", fmt.Errorf("node not found after setup") + return "", errNodeNotFoundAfterSetup } // Expire the node expiredTime := time.Now().Add(-1 * time.Hour) _, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime) + return "", err }, request: func(_ string) tailcfg.RegisterRequest { @@ -577,7 +606,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), // Future expiry } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantExpired: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.NodeKeyExpired) @@ -610,6 +639,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -630,7 +660,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(-1 * time.Hour), // Logout } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantExpired: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.NodeKeyExpired) @@ -673,6 +703,7 @@ func TestAuthenticationFlows(t *testing.T) { // and handleRegister will receive the value when it starts waiting go func() { user := app.state.CreateUserForTest("followup-user") + node := app.state.CreateNodeForTest(user, "followup-success-node") registered <- node }() @@ -685,7 +716,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -723,7 +754,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Invalid followup URL is rejected @@ -742,7 +773,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Non-existent registration ID is rejected @@ -761,7 +792,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -782,6 +813,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -796,7 +828,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -821,6 +853,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) 
tailcfg.RegisterRequest { @@ -833,7 +866,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -865,6 +898,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -879,7 +913,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -898,6 +932,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -912,7 +947,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -922,6 +957,7 @@ func TestAuthenticationFlows(t *testing.T) { node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found) assert.Equal(t, "tagged-pak-node", node.Hostname()) + if node.AuthKey().Valid() { assert.NotEmpty(t, node.AuthKey().Tags()) } @@ -1031,6 +1067,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1047,6 +1084,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak2.Key, nil }, request: func(newAuthKey string) tailcfg.RegisterRequest { @@ -1061,7 +1099,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(48 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1099,6 +1137,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1124,7 +1163,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(48 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuthURL: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.Contains(t, resp.AuthURL, "register/") @@ -1161,6 +1200,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1177,6 +1217,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pakRotation.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1191,7 +1232,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: 
true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1226,6 +1267,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1240,7 +1282,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Time{}, // Zero time } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1265,6 +1307,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1286,7 +1329,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1342,7 +1385,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: false, // Should not be authorized yet - needs to use new AuthURL validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // Should get a new AuthURL, not an error @@ -1367,7 +1410,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, // TEST: Wrong followup path format is rejected @@ -1386,7 +1429,7 @@ func TestAuthenticationFlows(t *testing.T) { NodeKey: nodeKey1.Public(), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantError: true, }, @@ -1417,7 +1460,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -1429,6 +1472,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify custom hostinfo was preserved through interactive workflow node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be found after interactive registration") + if found { assert.Equal(t, "custom-interactive-node", node.Hostname()) assert.Equal(t, "linux", node.Hostinfo().OS()) @@ -1455,6 +1499,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1469,7 +1514,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1508,7 +1553,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return 
machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -1520,6 +1565,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify registration ID was properly generated and used node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered after interactive workflow") + if found { assert.Equal(t, "registration-id-test-node", node.Hostname()) assert.Equal(t, "test-os", node.Hostinfo().OS()) @@ -1535,6 +1581,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1549,7 +1596,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1577,6 +1624,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1592,7 +1640,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(12 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1632,6 +1680,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1648,6 +1697,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak2.Key, nil }, request: func(user2AuthKey string) tailcfg.RegisterRequest { @@ -1662,7 +1712,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { assert.True(t, resp.MachineAuthorized) @@ -1712,6 +1762,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) if err != nil { return "", err @@ -1735,7 +1786,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, // Same machine key + machineKey: machineKey1.Public, // Same machine key requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -1789,7 +1840,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: false, // Should not be authorized yet - needs to use new AuthURL validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // Should get a new AuthURL, not an error @@ -1799,13 +1850,13 @@ func TestAuthenticationFlows(t *testing.T) { // Verify the 
response contains a valid registration URL authURL, err := url.Parse(resp.AuthURL) - assert.NoError(t, err, "AuthURL should be a valid URL") + require.NoError(t, err, "AuthURL should be a valid URL") assert.True(t, strings.HasPrefix(authURL.Path, "/register/"), "AuthURL path should start with /register/") // Extract and validate the new registration ID exists in cache newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") newRegID, err := types.RegistrationIDFromString(newRegIDStr) - assert.NoError(t, err, "should be able to parse new registration ID") + require.NoError(t, err, "should be able to parse new registration ID") // Verify new registration entry exists in cache _, found := app.state.GetRegistrationCacheEntry(newRegID) @@ -1838,6 +1889,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1858,7 +1910,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now(), // Exactly now (edge case between past and future) } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, wantAuth: true, wantExpired: true, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { @@ -1890,7 +1942,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey2.Public() }, + machineKey: machineKey2.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -1932,6 +1984,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) if err != nil { return "", err @@ -1955,7 +2008,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -2001,7 +2054,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -2055,7 +2108,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // This test validates concurrent interactive registration attempts assert.Contains(t, resp.AuthURL, "/register/") @@ -2097,6 +2150,7 @@ func TestAuthenticationFlows(t *testing.T) { // Collect results - at least one should succeed successCount := 0 + for range numConcurrent { select { case err := <-results: @@ -2162,7 +2216,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ 
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -2206,7 +2260,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, requiresInteractiveFlow: true, interactiveSteps: []interactiveStep{ {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, @@ -2217,6 +2271,7 @@ func TestAuthenticationFlows(t *testing.T) { // Should handle nil hostinfo gracefully node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered despite nil hostinfo") + if found { // Should have some default hostname or handle nil gracefully hostname := node.Hostname() @@ -2243,7 +2298,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // Get initial AuthURL and extract registration ID authURL := resp.AuthURL @@ -2265,7 +2320,7 @@ func TestAuthenticationFlows(t *testing.T) { nil, "error-test-method", ) - assert.Error(t, err, "should fail with invalid user ID") + require.Error(t, err, "should fail with invalid user ID") // Cache entry should still exist after auth error (for retry scenarios) _, stillFound := app.state.GetRegistrationCacheEntry(registrationID) @@ -2297,7 +2352,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { // Test multiple interactive registration attempts for the same node can coexist authURL1 := resp.AuthURL @@ -2315,12 +2370,14 @@ func TestAuthenticationFlows(t *testing.T) { resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) require.NoError(t, err) + authURL2 := resp2.AuthURL assert.Contains(t, authURL2, "/register/") // Both should have different registration IDs regID1, err1 := extractRegistrationIDFromAuthURL(authURL1) regID2, err2 := extractRegistrationIDFromAuthURL(authURL2) + require.NoError(t, err1) require.NoError(t, err2) assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") @@ -2328,6 +2385,7 @@ func TestAuthenticationFlows(t *testing.T) { // Both cache entries should exist simultaneously _, found1 := app.state.GetRegistrationCacheEntry(regID1) _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first registration cache entry should exist") assert.True(t, found2, "second registration cache entry should exist") @@ -2354,7 +2412,7 @@ func TestAuthenticationFlows(t *testing.T) { Expiry: time.Now().Add(24 * time.Hour), } }, - machineKey: func() key.MachinePublic { return machineKey1.Public() }, + machineKey: machineKey1.Public, validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { authURL1 := resp.AuthURL regID1, err := extractRegistrationIDFromAuthURL(authURL1) @@ -2371,6 +2429,7 @@ func TestAuthenticationFlows(t *testing.T) { resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) require.NoError(t, err) + authURL2 := resp2.AuthURL regID2, err := extractRegistrationIDFromAuthURL(authURL2) require.NoError(t, err) @@ -2378,6 +2437,7 @@ 
func TestAuthenticationFlows(t *testing.T) { // Verify both exist _, found1 := app.state.GetRegistrationCacheEntry(regID1) _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first cache entry should exist") assert.True(t, found2, "second cache entry should exist") @@ -2403,6 +2463,7 @@ func TestAuthenticationFlows(t *testing.T) { errorChan <- err return } + responseChan <- resp }() @@ -2430,6 +2491,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify the node was created with the second registration's data node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered") + if found { assert.Equal(t, "pending-node-2", node.Hostname()) assert.Equal(t, "second-registration-user", node.User().Name()) @@ -2463,8 +2525,10 @@ func TestAuthenticationFlows(t *testing.T) { // Set up context with timeout for followup tests ctx := context.Background() + if req.Followup != "" { var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() } @@ -2516,7 +2580,7 @@ func TestAuthenticationFlows(t *testing.T) { } } -// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow +// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow. func runInteractiveWorkflowTest(t *testing.T, tt struct { name string setupFunc func(*testing.T, *Headscale) (string, error) @@ -2535,6 +2599,8 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { validateCompleteResponse bool }, app *Headscale, dynamicValue string, ) { + t.Helper() + // Build initial request req := tt.request(dynamicValue) machineKey := tt.machineKey() @@ -2597,6 +2663,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { errorChan <- err return } + responseChan <- resp }() @@ -2650,25 +2717,30 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { if responseToValidate == nil { responseToValidate = initialResp } + tt.validate(t, responseToValidate, app) } } -// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL +// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL. func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { // AuthURL format: "http://localhost/register/abc123" const registerPrefix = "/register/" + idx := strings.LastIndex(authURL, registerPrefix) if idx == -1 { - return "", fmt.Errorf("invalid AuthURL format: %s", authURL) + return "", fmt.Errorf("%w: %s", errInvalidAuthURLFormat, authURL) } idStr := authURL[idx+len(registerPrefix):] + return types.RegistrationIDFromString(idStr) } -// validateCompleteRegistrationResponse performs comprehensive validation of a registration response -func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) { +// validateCompleteRegistrationResponse performs comprehensive validation of a registration response. 
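The table entries above drop the one-line closures around `machineKey1.Public()` in favour of method values. A minimal standalone sketch of why the two forms are interchangeable for a nullary function field; the `machineKey` type below is a stand-in for illustration, not the real `key.MachinePublic`:

```go
package main

import "fmt"

// machineKey is a stand-in for the test's key type; only the accessor shape matters.
type machineKey struct{ id string }

func (k machineKey) Public() string { return "pub-" + k.id }

type testCase struct {
	// Same field shape as the table's machineKey entry: a nullary function.
	machineKey func() string
}

func main() {
	k := machineKey{id: "node1"}

	before := testCase{machineKey: func() string { return k.Public() }} // explicit closure
	after := testCase{machineKey: k.Public}                             // method value; receiver bound at assignment

	fmt.Println(before.machineKey() == after.machineKey()) // true
}
```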
+func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, _ tailcfg.RegisterRequest) { + t.Helper() + // Basic response validation require.NotNil(t, resp, "response should not be nil") require.True(t, resp.MachineAuthorized, "machine should be authorized") @@ -2681,7 +2753,7 @@ func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterRe // Additional validation can be added here as needed } -// Simple test to validate basic node creation and lookup +// Simple test to validate basic node creation and lookup. func TestNodeStoreLookup(t *testing.T) { app := createTestApp(t) @@ -2713,8 +2785,10 @@ func TestNodeStoreLookup(t *testing.T) { // Wait for node to be available in NodeStore var node types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { var found bool + node, found = app.state.GetNodeByNodeKey(nodeKey.Public()) assert.True(c, found, "Node should be found in NodeStore") }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore") @@ -2783,8 +2857,10 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Get the node ID var registeredNode types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { var found bool + registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public()) assert.True(c, found, "Node should be found in NodeStore") }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available") @@ -2796,6 +2872,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Verify initial state: user1 has 2 nodes, user2 has 2 nodes user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID)) user2Nodes := app.state.ListNodesByUser(types.UserID(user2.ID)) + require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially") require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially") @@ -2876,6 +2953,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Verify new nodes were created for user1 with the same machine keys t.Logf("Verifying new nodes created for user1 from user2's machine keys...") + for i := 2; i < 4; i++ { node := nodes[i] // Should be able to find a node with user1 and this machine key (the new one) @@ -2899,7 +2977,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Expected behavior: // - User1's original node should STILL EXIST (expired) // - User2 should get a NEW node created (NOT transfer) -// - Both nodes share the same machine key (same physical device) +// - Both nodes share the same machine key (same physical device). func TestWebFlowReauthDifferentUser(t *testing.T) { machineKey := key.NewMachine() nodeKey1 := key.NewNode() @@ -3043,7 +3121,8 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { // Count nodes per user user1Nodes := 0 user2Nodes := 0 - for i := 0; i < allNodesSlice.Len(); i++ { + + for i := range allNodesSlice.Len() { n := allNodesSlice.At(i) if n.UserID().Get() == user1.ID { user1Nodes++ @@ -3060,7 +3139,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { }) } -// Helper function to create test app +// Helper function to create test app. 
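`validateCompleteRegistrationResponse` and `runInteractiveWorkflowTest` now begin with `t.Helper()`. A small sketch of what that buys; the helper below is invented for illustration and is not one of the real helpers:

```go
package demo

import "testing"

// requireAuthorized is an illustrative helper. Calling t.Helper() tells the
// testing package to skip this frame when reporting failures, so a failing
// assertion is attributed to the calling test rather than to the helper's
// own line.
func requireAuthorized(t *testing.T, authorized bool) {
	t.Helper()

	if !authorized {
		t.Fatal("machine should be authorized")
	}
}

func TestHelperAttribution(t *testing.T) {
	requireAuthorized(t, true) // a failure here would be reported at this line
}
```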
func createTestApp(t *testing.T) *Headscale { t.Helper() @@ -3147,6 +3226,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { } t.Log("Step 1: Initial registration with pre-auth key") + initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public()) require.NoError(t, err, "initial registration should succeed") require.NotNil(t, initialResp) @@ -3172,6 +3252,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // - System reboots // The Tailscale client persists the pre-auth key in its state and sends it on every registration t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key") + restartReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ AuthKey: pakNew.Key, // Same key, now marked as Used=true @@ -3188,10 +3269,12 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { restartResp, err := app.handleRegister(context.Background(), restartReq, machineKey.Public()) // This is the assertion that currently FAILS in v0.27.0 - assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed") + require.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed") + if err != nil { t.Logf("Error received (this is the bug): %v", err) t.Logf("Expected behavior: Node should be able to re-register with the same pre-auth key it used initially") + return // Stop here to show the bug clearly } @@ -3289,7 +3372,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) { } _, err = app.handleRegister(context.Background(), req, machineKey.Public()) - assert.Error(t, err, "expired pre-auth key should be rejected") + require.Error(t, err, "expired pre-auth key should be rejected") assert.Contains(t, err.Error(), "authkey expired", "error should mention key expiration") } diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index 7457670c..d179dca9 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -13,7 +13,7 @@ import ( ) const ( - apiKeyPrefix = "hskey-api-" //nolint:gosec // This is a prefix, not a credential + apiKeyPrefix = "hskey-api-" //nolint:gosec apiKeyPrefixLength = 12 apiKeyHashLength = 64 diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index a1429aa6..b876ee84 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -24,7 +24,6 @@ import ( "gorm.io/gorm" "gorm.io/gorm/logger" "gorm.io/gorm/schema" - "tailscale.com/net/tsaddr" "zgo.at/zcache/v2" ) @@ -76,6 +75,7 @@ func NewHeadscaleDatabase( ID: "202501221827", Migrate: func(tx *gorm.DB) error { // Remove any invalid routes associated with a node that does not exist. + //nolint:staticcheck if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error if err != nil { @@ -84,6 +84,7 @@ func NewHeadscaleDatabase( } // Remove any invalid routes without a node_id. 
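Several checks above move from `assert.NoError`/`assert.Error` to their `require` counterparts. The difference is whether the test keeps running after a failure; a hedged sketch, where the `register` function and its error text are invented for illustration:

```go
package demo

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// register is a stand-in for the handler under test.
func register(expired bool) (string, error) {
	if expired {
		return "", errors.New("authkey expired")
	}

	return "node-id", nil
}

func TestRequireStopsEarly(t *testing.T) {
	_, err := register(true)

	// require.Error calls t.FailNow on failure, so the dereference of err
	// below can never run against a nil error.
	require.Error(t, err, "expired pre-auth key should be rejected")

	// assert records a failure but lets the test continue; fine for
	// independent follow-up checks.
	assert.Contains(t, err.Error(), "authkey expired")
}
```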
+ //nolint:staticcheck if tx.Migrator().HasTable(&types.Route{}) { err := tx.Exec("delete from routes where node_id is null").Error if err != nil { @@ -91,6 +92,7 @@ func NewHeadscaleDatabase( } } + //nolint:staticcheck err := tx.AutoMigrate(&types.Route{}) if err != nil { return fmt.Errorf("automigrating types.Route: %w", err) @@ -109,6 +111,7 @@ func NewHeadscaleDatabase( if err != nil { return fmt.Errorf("automigrating types.PreAuthKey: %w", err) } + err = tx.AutoMigrate(&types.Node{}) if err != nil { return fmt.Errorf("automigrating types.Node: %w", err) @@ -155,7 +158,9 @@ AND auth_key_id NOT IN ( nodeRoutes := map[uint64][]netip.Prefix{} + //nolint:staticcheck var routes []types.Route + err = tx.Find(&routes).Error if err != nil { return fmt.Errorf("fetching routes: %w", err) @@ -168,10 +173,13 @@ AND auth_key_id NOT IN ( } for nodeID, routes := range nodeRoutes { - tsaddr.SortPrefixes(routes) + slices.SortFunc(routes, netip.Prefix.Compare) routes = slices.Compact(routes) data, err := json.Marshal(routes) + if err != nil { + return fmt.Errorf("marshaling routes for node %d: %w", nodeID, err) + } err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error if err != nil { @@ -180,6 +188,7 @@ AND auth_key_id NOT IN ( } // Drop the old table. + //nolint:staticcheck _ = tx.Migrator().DropTable(&types.Route{}) return nil @@ -256,10 +265,13 @@ AND auth_key_id NOT IN ( // Check if routes table exists and drop it (should have been migrated already) var routesExists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists) if err == nil && routesExists { log.Info().Msg("Dropping leftover routes table") - if err := tx.Exec("DROP TABLE routes").Error; err != nil { + + err := tx.Exec("DROP TABLE routes").Error + if err != nil { return fmt.Errorf("dropping routes table: %w", err) } } @@ -281,6 +293,7 @@ AND auth_key_id NOT IN ( for _, table := range tablesToRename { // Check if table exists before renaming var exists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists) if err != nil { return fmt.Errorf("checking if table %s exists: %w", table, err) @@ -291,7 +304,8 @@ AND auth_key_id NOT IN ( _ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error // Rename current table to _old - if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil { + err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error + if err != nil { return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err) } } @@ -365,7 +379,8 @@ AND auth_key_id NOT IN ( } for _, createSQL := range tableCreationSQL { - if err := tx.Exec(createSQL).Error; err != nil { + err := tx.Exec(createSQL).Error + if err != nil { return fmt.Errorf("creating new table: %w", err) } } @@ -394,7 +409,8 @@ AND auth_key_id NOT IN ( } for _, copySQL := range dataCopySQL { - if err := tx.Exec(copySQL).Error; err != nil { + err := tx.Exec(copySQL).Error + if err != nil { return fmt.Errorf("copying data: %w", err) } } @@ -417,14 +433,16 @@ AND auth_key_id NOT IN ( } for _, indexSQL := range indexes { - if err := tx.Exec(indexSQL).Error; err != nil { + err := tx.Exec(indexSQL).Error + if err != nil { return fmt.Errorf("creating index: %w", err) } } // Drop old tables only after everything succeeds for _, table := range tablesToRename { - if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil { + err := tx.Exec("DROP 
TABLE IF EXISTS " + table + "_old").Error + if err != nil { log.Warn().Str("table", table+"_old").Err(err).Msg("Failed to drop old table, but migration succeeded") } } @@ -762,6 +780,7 @@ AND auth_key_id NOT IN ( // or else it blocks... sqlConn.SetMaxIdleConns(maxIdleConns) + sqlConn.SetMaxOpenConns(maxOpenConns) defer sqlConn.SetMaxIdleConns(1) defer sqlConn.SetMaxOpenConns(1) @@ -779,6 +798,7 @@ AND auth_key_id NOT IN ( }, } + //nolint:noinlineerr if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { return nil, fmt.Errorf("validating schema: %w", err) } @@ -913,6 +933,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig // Get the current foreign key status var fkOriginallyEnabled int + //nolint:noinlineerr if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { return fmt.Errorf("checking foreign key status: %w", err) } @@ -942,27 +963,32 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig if needsFKDisabled { // Disable foreign keys for this migration - if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil { + err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error + if err != nil { return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err) } } else { // Ensure foreign keys are enabled for this migration - if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { + err := dbConn.Exec("PRAGMA foreign_keys = ON").Error + if err != nil { return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err) } } // Run up to this specific migration (will only run the next pending migration) - if err := migrations.MigrateTo(migrationID); err != nil { + err := migrations.MigrateTo(migrationID) + if err != nil { return fmt.Errorf("running migration %s: %w", migrationID, err) } } + //nolint:noinlineerr if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { return fmt.Errorf("restoring foreign keys: %w", err) } // Run the rest of the migrations + //nolint:noinlineerr if err := migrations.Migrate(); err != nil { return err } @@ -1005,7 +1031,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig } } else { // PostgreSQL can run all migrations in one block - no foreign key issues - if err := migrations.Migrate(); err != nil { + err := migrations.Migrate() + if err != nil { return err } } @@ -1031,7 +1058,7 @@ func (hsdb *HSDatabase) Close() error { } if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog { - db.Exec("VACUUM") + _, _ = db.ExecContext(context.Background(), "VACUUM") } return db.Close() diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 3cd0d14e..f93b9ef8 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -1,6 +1,7 @@ package db import ( + "context" "database/sql" "os" "os/exec" @@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { // Verify api_keys data preservation var apiKeyCount int + err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error require.NoError(t, err) assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema") @@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { return err } - _, err = db.Exec(string(schemaContent)) + _, err = db.ExecContext(context.Background(), string(schemaContent)) return err } @@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath 
string) error { func requireConstraintFailed(t *testing.T, err error) { t.Helper() require.Error(t, err) + if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") { require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error()) } @@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) { } // Construct the pg_restore command - cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) + cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) // Set the output streams cmd.Stdout = os.Stdout @@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase { // skip already-applied migrations and only run new ones. func TestSQLiteAllTestdataMigrations(t *testing.T) { t.Parallel() + schemas, err := os.ReadDir("testdata/sqlite") require.NoError(t, err) diff --git a/hscontrol/db/ephemeral_garbage_collector_test.go b/hscontrol/db/ephemeral_garbage_collector_test.go index d118b7fd..8e8a1109 100644 --- a/hscontrol/db/ephemeral_garbage_collector_test.go +++ b/hscontrol/db/ephemeral_garbage_collector_test.go @@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Basic deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex - var deletionWg sync.WaitGroup + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + deletionWg sync.WaitGroup + ) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionWg.Done() } @@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { go gc.Start() // Schedule several nodes for deletion with short expiry - const expiry = fifty - const numNodes = 100 + const ( + expiry = fifty + numNodes = 100 + ) // Set up wait group for expected deletions + deletionWg.Add(numNodes) for i := 1; i <= numNodes; i++ { - gc.Schedule(types.NodeID(i), expiry) + gc.Schedule(types.NodeID(i), expiry) //nolint:gosec } // Wait for all scheduled deletions to complete @@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { // Schedule and immediately cancel to test that part of the code for i := numNodes + 1; i <= numNodes*2; i++ { - nodeID := types.NodeID(i) + nodeID := types.NodeID(i) //nolint:gosec gc.Schedule(nodeID, time.Hour) gc.Cancel(nodeID) } @@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { // and then reschedules it with a shorter expiry, and verifies that the node is deleted only once. 
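The routes migration in db.go above swaps `tsaddr.SortPrefixes` for the standard library's `slices.SortFunc` with `netip.Prefix.Compare`, followed by `slices.Compact`. A minimal sketch of that combination on its own:

```go
package main

import (
	"fmt"
	"net/netip"
	"slices"
)

func main() {
	routes := []netip.Prefix{
		netip.MustParsePrefix("192.168.0.0/24"),
		netip.MustParsePrefix("10.0.0.0/8"),
		netip.MustParsePrefix("10.0.0.0/8"), // duplicate, as could occur in the old routes table
	}

	// netip.Prefix.Compare is a total order, so SortFunc puts equal prefixes
	// next to each other and Compact can then drop the duplicates.
	slices.SortFunc(routes, netip.Prefix.Compare)
	routes = slices.Compact(routes)

	fmt.Println(routes) // [10.0.0.0/8 192.168.0.0/24]
}
```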
func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionNotifier <- nodeID @@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // Start GC gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() defer gc.Close() - const shortExpiry = fifty - const longExpiry = 1 * time.Hour + const ( + shortExpiry = fifty + longExpiry = 1 * time.Hour + ) nodeID := types.NodeID(1) @@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // and verifies that the node is deleted only once. func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) { // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) + deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + deletionNotifier <- nodeID } // Start the GC gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() defer gc.Close() nodeID := types.NodeID(1) + const expiry = fifty // Schedule node for deletion @@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) { // It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted. func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) { // Deletion tracking - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionNotifier <- nodeID @@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Deletion tracking - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) + nodeDeleted := make(chan struct{}) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() close(nodeDeleted) // Signal that deletion happened } @@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { // Use a WaitGroup to ensure the GC has started var startWg sync.WaitGroup startWg.Add(1) + go func() { startWg.Done() // Signal that the goroutine has started gc.Start() }() + startWg.Wait() // Wait for the GC to start // Close GC right away @@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { // Check no node was deleted deleteMutex.Lock() + nodesDeleted := len(deletedIDs) + deleteMutex.Unlock() assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close") @@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs 
[]types.NodeID + deleteMutex sync.Mutex + ) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() } @@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { go gc.Start() // Number of concurrent scheduling goroutines - const numSchedulers = 10 - const nodesPerScheduler = 50 + const ( + numSchedulers = 10 + nodesPerScheduler = 50 + ) const closeAfterNodes = 25 // Close GC after this many nodes per scheduler @@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { case <-stopScheduling: return default: - nodeID := types.NodeID(baseNodeID + j + 1) - gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test + nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec + gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test atomic.AddInt64(&scheduledCount, 1) // Yield to other goroutines to introduce variability diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index 7ba335e8..7827e002 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/net/tsaddr" - "tailscale.com/types/ptr" ) var mpp = func(pref string) *netip.Prefix { @@ -484,12 +483,13 @@ func TestBackfillIPAddresses(t *testing.T) { func TestIPAllocatorNextNoReservedIPs(t *testing.T) { db, err := newSQLiteTestDB() require.NoError(t, err) + defer db.Close() alloc, err := NewIPAllocator( db, - ptr.To(tsaddr.CGNATRange()), - ptr.To(tsaddr.TailscaleULARange()), + new(tsaddr.CGNATRange()), + new(tsaddr.TailscaleULARange()), types.IPAllocationStrategySequential, ) if err != nil { @@ -497,17 +497,17 @@ func TestIPAllocatorNextNoReservedIPs(t *testing.T) { } // Validate that we do not give out 100.100.100.100 - nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange())) + nextQuad100, err := alloc.next(na("100.100.100.99"), new(tsaddr.CGNATRange())) require.NoError(t, err) assert.Equal(t, na("100.100.100.101"), *nextQuad100) // Validate that we do not give out fd7a:115c:a1e0::53 - nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange())) + nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), new(tsaddr.TailscaleULARange())) require.NoError(t, err) assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6) // Validate that we do not give out fd7a:115c:a1e0::53 - nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange())) + nextChrome, err := alloc.next(na("100.115.91.255"), new(tsaddr.CGNATRange())) t.Logf("chrome: %s", nextChrome.String()) require.NoError(t, err) assert.Equal(t, na("100.115.94.0"), *nextChrome) diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index bf407bb4..e7468207 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -1,13 +1,13 @@ package db import ( + "cmp" "encoding/json" "errors" "fmt" "net/netip" "regexp" "slices" - "sort" "strconv" "strings" "sync" @@ -20,7 +20,6 @@ import ( "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) const ( @@ -37,6 +36,7 @@ var ( "node not found in registration cache", ) ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface") + ErrNameNotUnique = errors.New("name is not unique") ) // ListPeers returns peers of node, regardless of any Policy or if the node is expired. 
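The `ListPeers` hunk that follows replaces `sort.Slice` with `slices.SortFunc` and a `cmp.Compare`-based comparator. A small sketch of the equivalence; the `node` type here is a placeholder for `*types.Node`:

```go
package main

import (
	"cmp"
	"fmt"
	"slices"
	"sort"
)

// node stands in for *types.Node; only the ID field matters here.
type node struct{ ID uint64 }

func main() {
	nodes := []*node{{ID: 3}, {ID: 1}, {ID: 2}}

	// Old form: index-based boolean "less" function.
	sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID })

	// New form: element-based three-way comparison; cmp.Compare works for any
	// ordered type and matches what the generic slices API expects.
	slices.SortFunc(nodes, func(a, b *node) int { return cmp.Compare(a.ID, b.ID) })

	for _, n := range nodes {
		fmt.Print(n.ID, " ") // 1 2 3
	}

	fmt.Println()
}
```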
@@ -60,7 +60,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types return types.Nodes{}, err } - sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) + slices.SortFunc(nodes, func(a, b *types.Node) int { return cmp.Compare(a.ID, b.ID) }) return nodes, nil } @@ -207,6 +207,7 @@ func SetTags( slices.Sort(tags) tags = slices.Compact(tags) + b, err := json.Marshal(tags) if err != nil { return err @@ -220,7 +221,7 @@ func SetTags( return nil } -// SetTags takes a Node struct pointer and update the forced tags. +// SetApprovedRoutes updates the approved routes for a node. func SetApprovedRoutes( tx *gorm.DB, nodeID types.NodeID, @@ -228,7 +229,8 @@ func SetApprovedRoutes( ) error { if len(routes) == 0 { // if no routes are provided, we remove all - if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil { + err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error + if err != nil { return fmt.Errorf("removing approved routes: %w", err) } @@ -251,6 +253,7 @@ func SetApprovedRoutes( return err } + //nolint:noinlineerr if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil { return fmt.Errorf("updating approved routes: %w", err) } @@ -277,18 +280,21 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error { func RenameNode(tx *gorm.DB, nodeID types.NodeID, newName string, ) error { - if err := util.ValidateHostname(newName); err != nil { + err := util.ValidateHostname(newName) + if err != nil { return fmt.Errorf("renaming node: %w", err) } // Check if the new name is unique var count int64 - if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil { + + err = tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error + if err != nil { return fmt.Errorf("failed to check name uniqueness: %w", err) } if count > 0 { - return errors.New("name is not unique") + return ErrNameNotUnique } if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { @@ -379,6 +385,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n if ipv4 == nil { ipv4 = oldNode.IPv4 } + if ipv6 == nil { ipv6 = oldNode.IPv6 } @@ -407,6 +414,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n node.IPv6 = ipv6 var err error + node.Hostname, err = util.NormaliseHostname(node.Hostname) if err != nil { newHostname := util.InvalidString() @@ -491,8 +499,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { func isUniqueName(tx *gorm.DB, name string) (bool, error) { nodes := types.Nodes{} - if err := tx. - Where("given_name = ?", name).Find(&nodes).Error; err != nil { + + err := tx. 
+ Where("given_name = ?", name).Find(&nodes).Error + if err != nil { return false, err } @@ -646,7 +656,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) panic("CreateNodeForTest requires a valid user") } - nodeName := "testnode" + nodeName := "testnode" //nolint:goconst if len(hostname) > 0 && hostname[0] != "" { nodeName = hostname[0] } @@ -668,7 +678,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) Hostname: nodeName, UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: &pak.ID, } err = hsdb.DB.Save(node).Error @@ -694,9 +704,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname . } var registeredNode *types.Node + err = hsdb.DB.Transaction(func(tx *gorm.DB) error { var err error + registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6) + return err }) if err != nil { diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 7e00f9ca..9ff96eb9 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -22,7 +22,6 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestGetNode(t *testing.T) { @@ -99,6 +98,7 @@ func TestExpireNode(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) + //nolint:staticcheck pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) @@ -115,7 +115,7 @@ func TestExpireNode(t *testing.T) { Hostname: "testnode", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: new(pak.ID), Expiry: &time.Time{}, } db.DB.Save(node) @@ -143,6 +143,7 @@ func TestSetTags(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) + //nolint:staticcheck pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) @@ -159,7 +160,7 @@ func TestSetTags(t *testing.T) { Hostname: "testnode", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: new(pak.ID), } trx := db.DB.Save(node) @@ -443,7 +444,7 @@ func TestAutoApproveRoutes(t *testing.T) { Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.routes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), } err = adb.DB.Save(&node).Error @@ -460,17 +461,17 @@ func TestAutoApproveRoutes(t *testing.T) { RoutableIPs: tt.routes, }, Tags: []string{"tag:exit"}, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + IPv4: new(netip.MustParseAddr("100.64.0.2")), } err = adb.DB.Save(&nodeTagged).Error require.NoError(t, err) users, err := adb.ListUsers() - assert.NoError(t, err) + require.NoError(t, err) nodes, err := adb.ListNodes() - assert.NoError(t, err) + require.NoError(t, err) pm, err := pmf(users, nodes.ViewSlice()) require.NoError(t, err) @@ -498,6 +499,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes1) == 0 { expectedRoutes1 = nil } + if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } @@ -509,6 +511,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes2) == 0 { expectedRoutes2 = nil } + if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) 
} @@ -597,7 +600,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) { // Use shorter expiry for faster tests for i := range want { - go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec // test code, no overflow risk + go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec } // Wait for all deletions to complete @@ -636,9 +639,11 @@ func TestListEphemeralNodes(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) + //nolint:staticcheck pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) + //nolint:staticcheck pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) require.NoError(t, err) @@ -649,7 +654,7 @@ func TestListEphemeralNodes(t *testing.T) { Hostname: "test", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: new(pak.ID), } nodeEph := types.Node{ @@ -659,7 +664,7 @@ func TestListEphemeralNodes(t *testing.T) { Hostname: "ephemeral", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pakEph.ID), + AuthKeyID: new(pakEph.ID), } err = db.DB.Save(&node).Error @@ -719,6 +724,7 @@ func TestNodeNaming(t *testing.T) { // break your network, so they should be replaced when registering // a node. // https://github.com/juanfont/headscale/issues/2343 + //nolint:gosmopolitan nodeInvalidHostname := types.Node{ MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), @@ -746,12 +752,19 @@ func TestNodeNaming(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) if err != nil { return err } - _, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil) - _, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil) + + _, err = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil) + if err != nil { + return err + } + + _, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil) + return err }) require.NoError(t, err) @@ -810,25 +823,26 @@ func TestNodeNaming(t *testing.T) { err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[0].ID, "test") }) - assert.ErrorContains(t, err, "name is not unique") + require.ErrorContains(t, err, "name is not unique") // Rename invalid chars + //nolint:gosmopolitan err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[2].ID, "我的电脑") }) - assert.ErrorContains(t, err, "invalid characters") + require.ErrorContains(t, err, "invalid characters") // Rename too short err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[3].ID, "a") }) - assert.ErrorContains(t, err, "at least 2 characters") + require.ErrorContains(t, err, "at least 2 characters") // Rename with emoji err = db.Write(func(tx *gorm.DB) error { return RenameNode(tx, nodes[0].ID, "hostname-with-💩") }) - assert.ErrorContains(t, err, "invalid characters") + require.ErrorContains(t, err, "invalid characters") // Rename with only emoji err = db.Write(func(tx *gorm.DB) error { @@ -896,12 +910,12 @@ func TestRenameNodeComprehensive(t *testing.T) { }, { name: "chinese_chars_with_dash_rejected", - newName: "server-北京-01", + newName: "server-北京-01", //nolint:gosmopolitan wantErr: "invalid characters", }, { name: "chinese_only_rejected", - newName: "我的电脑", + newName: "我的电脑", //nolint:gosmopolitan wantErr: "invalid characters", }, { @@ -911,7 +925,7 @@ func 
TestRenameNodeComprehensive(t *testing.T) { }, { name: "mixed_chinese_emoji_rejected", - newName: "测试💻机器", + newName: "测试💻机器", //nolint:gosmopolitan wantErr: "invalid characters", }, { @@ -1000,6 +1014,7 @@ func TestListPeers(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err @@ -1085,6 +1100,7 @@ func TestListNodes(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index c5904353..00c5985f 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -332,7 +332,7 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { return nil } -// MarkExpirePreAuthKey marks a PreAuthKey as expired. +// ExpirePreAuthKey marks a PreAuthKey as expired. func ExpirePreAuthKey(tx *gorm.DB, id uint64) error { now := time.Now() return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 7c5dcbd7..2f28d449 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -11,7 +11,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "tailscale.com/types/ptr" ) func TestCreatePreAuthKey(t *testing.T) { @@ -24,7 +23,7 @@ func TestCreatePreAuthKey(t *testing.T) { test: func(t *testing.T, db *HSDatabase) { t.Helper() - _, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil) + _, err := db.CreatePreAuthKey(new(types.UserID(12345)), true, false, nil, nil) assert.Error(t, err) }, }, @@ -127,7 +126,7 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) { Hostname: "testest", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(key.ID), + AuthKeyID: new(key.ID), } db.DB.Save(&node) diff --git a/hscontrol/db/sqliteconfig/config.go b/hscontrol/db/sqliteconfig/config.go index d27977a4..0088fe86 100644 --- a/hscontrol/db/sqliteconfig/config.go +++ b/hscontrol/db/sqliteconfig/config.go @@ -22,6 +22,9 @@ var ( const ( // DefaultBusyTimeout is the default busy timeout in milliseconds. DefaultBusyTimeout = 10000 + // DefaultWALAutocheckpoint is the default WAL autocheckpoint value (number of pages). + // SQLite default is 1000 pages. + DefaultWALAutocheckpoint = 1000 ) // JournalMode represents SQLite journal_mode pragma values. @@ -310,7 +313,7 @@ func Default(path string) *Config { BusyTimeout: DefaultBusyTimeout, JournalMode: JournalModeWAL, AutoVacuum: AutoVacuumIncremental, - WALAutocheckpoint: 1000, + WALAutocheckpoint: DefaultWALAutocheckpoint, Synchronous: SynchronousNormal, ForeignKeys: true, TxLock: TxLockImmediate, @@ -362,7 +365,8 @@ func (c *Config) Validate() error { // ToURL builds a properly encoded SQLite connection string using _pragma parameters // compatible with modernc.org/sqlite driver. 
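The integration-test hunks that follow switch from `Ping`/`Exec`/`QueryRow` to their `Context` variants. A minimal sketch of that pattern, assuming the modernc.org/sqlite driver and an in-memory database purely for illustration:

```go
package main

import (
	"context"
	"database/sql"
	"fmt"
	"log"
	"time"

	_ "modernc.org/sqlite" // assumed driver, matching the package under test
)

func main() {
	db, err := sql.Open("sqlite", ":memory:")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// The *Context variants behave like Ping/Exec/QueryRow but respect
	// cancellation and deadlines, which is the point of the rewrites even
	// when the caller only has context.Background() to hand.
	if err := db.PingContext(ctx); err != nil {
		log.Fatal(err)
	}

	if _, err := db.ExecContext(ctx, "CREATE TABLE t (id INTEGER PRIMARY KEY)"); err != nil {
		log.Fatal(err)
	}

	var journalMode string
	if err := db.QueryRowContext(ctx, "PRAGMA journal_mode").Scan(&journalMode); err != nil {
		log.Fatal(err)
	}

	fmt.Println("journal_mode:", journalMode)
}
```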
func (c *Config) ToURL() (string, error) { - if err := c.Validate(); err != nil { + err := c.Validate() + if err != nil { return "", fmt.Errorf("invalid config: %w", err) } @@ -372,18 +376,23 @@ func (c *Config) ToURL() (string, error) { if c.BusyTimeout > 0 { pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout)) } + if c.JournalMode != "" { pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode)) } + if c.AutoVacuum != "" { pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum)) } + if c.WALAutocheckpoint >= 0 { pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint)) } + if c.Synchronous != "" { pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous)) } + if c.ForeignKeys { pragmas = append(pragmas, "foreign_keys=ON") } diff --git a/hscontrol/db/sqliteconfig/config_test.go b/hscontrol/db/sqliteconfig/config_test.go index 66955bb9..7829d9e9 100644 --- a/hscontrol/db/sqliteconfig/config_test.go +++ b/hscontrol/db/sqliteconfig/config_test.go @@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) { t.Errorf("Config.ToURL() error = %v", err) return } + if got != tt.want { t.Errorf("Config.ToURL() = %q, want %q", got, tt.want) } @@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) { Path: "", BusyTimeout: -1, } + _, err := config.ToURL() if err == nil { t.Error("Config.ToURL() with invalid config should return error") diff --git a/hscontrol/db/sqliteconfig/integration_test.go b/hscontrol/db/sqliteconfig/integration_test.go index bb54ea1e..3d1d07c7 100644 --- a/hscontrol/db/sqliteconfig/integration_test.go +++ b/hscontrol/db/sqliteconfig/integration_test.go @@ -1,6 +1,7 @@ package sqliteconfig import ( + "context" "database/sql" "path/filepath" "strings" @@ -101,7 +102,8 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { defer db.Close() // Test connection - if err := db.Ping(); err != nil { + //nolint:noinlineerr + if err := db.PingContext(context.Background()); err != nil { t.Fatalf("Failed to ping database: %v", err) } @@ -109,8 +111,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { for pragma, expectedValue := range tt.expected { t.Run("pragma_"+pragma, func(t *testing.T) { var actualValue any + query := "PRAGMA " + pragma - err := db.QueryRow(query).Scan(&actualValue) + + err := db.QueryRowContext(context.Background(), query).Scan(&actualValue) if err != nil { t.Fatalf("Failed to query %s: %v", query, err) } @@ -178,23 +182,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { ); ` - if _, err := db.Exec(schema); err != nil { + //nolint:noinlineerr + if _, err := db.ExecContext(context.Background(), schema); err != nil { t.Fatalf("Failed to create schema: %v", err) } // Insert parent record - if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { + //nolint:noinlineerr + if _, err := db.ExecContext(context.Background(), "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { t.Fatalf("Failed to insert parent: %v", err) } // Test 1: Valid foreign key should work - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") + _, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") if err != nil { t.Fatalf("Valid foreign key insert failed: %v", err) } // Test 2: Invalid foreign key should fail - _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") + _, err = db.ExecContext(context.Background(), "INSERT 
INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") if err == nil { t.Error("Expected foreign key constraint violation, but insert succeeded") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -204,7 +210,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) { } // Test 3: Deleting referenced parent should fail - _, err = db.Exec("DELETE FROM parent WHERE id = 1") + _, err = db.ExecContext(context.Background(), "DELETE FROM parent WHERE id = 1") if err == nil { t.Error("Expected foreign key constraint violation when deleting referenced parent") } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { @@ -249,7 +255,8 @@ func TestJournalModeValidation(t *testing.T) { defer db.Close() var actualMode string - err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode) + + err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode) if err != nil { t.Fatalf("Failed to query journal_mode: %v", err) } diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go index 6172e7e0..8489c69c 100644 --- a/hscontrol/db/text_serialiser.go +++ b/hscontrol/db/text_serialiser.go @@ -3,12 +3,20 @@ package db import ( "context" "encoding" + "errors" "fmt" "reflect" "gorm.io/gorm/schema" ) +// Sentinel errors for text serialisation. +var ( + ErrTextUnmarshalFailed = errors.New("failed to unmarshal text value") + ErrUnsupportedType = errors.New("unsupported type") + ErrTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported") +) + // Got from https://github.com/xdg-go/strum/blob/main/types.go var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() @@ -17,7 +25,7 @@ func isTextUnmarshaler(rv reflect.Value) bool { } func maybeInstantiatePtr(rv reflect.Value) { - if rv.Kind() == reflect.Ptr && rv.IsNil() { + if rv.Kind() == reflect.Pointer && rv.IsNil() { np := reflect.New(rv.Type().Elem()) rv.Set(np) } @@ -36,27 +44,30 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect // If the field is a pointer, we need to dereference it to get the actual type // so we do not end with a second pointer. - if fieldValue.Elem().Kind() == reflect.Ptr { + if fieldValue.Elem().Kind() == reflect.Pointer { fieldValue = fieldValue.Elem() } if dbValue != nil { var bytes []byte + switch v := dbValue.(type) { case []byte: bytes = v case string: bytes = []byte(v) default: - return fmt.Errorf("failed to unmarshal text value: %#v", dbValue) + return fmt.Errorf("%w: %#v", ErrTextUnmarshalFailed, dbValue) } if isTextUnmarshaler(fieldValue) { maybeInstantiatePtr(fieldValue) f := fieldValue.MethodByName("UnmarshalText") args := []reflect.Value{reflect.ValueOf(bytes)} + ret := f.Call(args) if !ret[0].IsNil() { + //nolint:forcetypeassert return decodingError(field.Name, ret[0].Interface().(error)) } @@ -65,7 +76,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect // If it is not a pointer, we need to assign the value to the // field. 
dstField := field.ReflectValueOf(ctx, dst) - if dstField.Kind() == reflect.Ptr { + if dstField.Kind() == reflect.Pointer { dstField.Set(fieldValue) } else { dstField.Set(fieldValue.Elem()) @@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect return nil } else { - return fmt.Errorf("unsupported type: %T", fieldValue.Interface()) + return fmt.Errorf("%w: %T", ErrUnsupportedType, fieldValue.Interface()) } } @@ -86,9 +97,10 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec // If the value is nil, we return nil, however, go nil values are not // always comparable, particularly when reflection is involved: // https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8 - if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) { - return nil, nil + if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) { + return nil, nil //nolint:nilnil } + b, err := v.MarshalText() if err != nil { return nil, err @@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec return string(b), nil default: - return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v) + return nil, fmt.Errorf("%w, got %T", ErrTextMarshalerOnly, v) } } diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 6aff9ed1..213730cf 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -15,6 +15,8 @@ var ( ErrUserExists = errors.New("user already exists") ErrUserNotFound = errors.New("user not found") ErrUserStillHasNodes = errors.New("user not empty: node(s) found") + ErrTooManyWhereArgs = errors.New("expect 0 or 1 where User structs") + ErrMultipleUsers = errors.New("expected exactly one user") ) func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) { @@ -26,7 +28,8 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) { // CreateUser creates a new User. Returns error if could not be created // or another user already exists. func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) { - if err := util.ValidateHostname(user.Name); err != nil { + err := util.ValidateHostname(user.Name) + if err != nil { return nil, err } if err := tx.Create(&user).Error; err != nil { @@ -88,10 +91,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user") // not exist or if another User exists with the new name. func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { var err error + oldUser, err := GetUserByID(tx, uid) if err != nil { return err } + + //nolint:noinlineerr if err = util.ValidateHostname(newName); err != nil { return err } @@ -151,7 +157,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { // ListUsers gets all the existing users. 
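Editor's note: text_serialiser.go and users.go above replace one-off fmt.Errorf strings with package-level sentinels (ErrTextUnmarshalFailed, ErrTooManyWhereArgs, ErrMultipleUsers) wrapped via %w. A minimal sketch of why that matters: callers can branch with errors.Is regardless of the formatted detail. The names below are illustrative stand-ins, not the real headscale sentinels.

package main

import (
	"errors"
	"fmt"
)

// ErrTooManyArgs mirrors the style of the sentinels introduced above.
var ErrTooManyArgs = errors.New("expect 0 or 1 where structs")

func listThings(where ...string) ([]string, error) {
	if len(where) > 1 {
		// %w keeps the sentinel in the error chain while adding detail.
		return nil, fmt.Errorf("%w, got %d", ErrTooManyArgs, len(where))
	}

	return []string{"thing"}, nil
}

func main() {
	_, err := listThings("a", "b")
	// errors.Is matches the wrapped sentinel independent of the message text.
	fmt.Println(errors.Is(err, ErrTooManyArgs)) // true
}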
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { if len(where) > 1 { - return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where)) + return nil, fmt.Errorf("%w, got %d", ErrTooManyWhereArgs, len(where)) } var user *types.User @@ -160,7 +166,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { } users := []types.User{} - if err := tx.Where(user).Find(&users).Error; err != nil { + + err := tx.Where(user).Find(&users).Error + if err != nil { return nil, err } @@ -180,7 +188,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { } if len(users) != 1 { - return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) + return nil, fmt.Errorf("%w, found %d", ErrMultipleUsers, len(users)) } return &users[0], nil diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index a3fd49b3..9d2740e5 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -8,7 +8,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" - "tailscale.com/types/ptr" ) func TestCreateAndDestroyUser(t *testing.T) { @@ -71,6 +70,7 @@ func TestDestroyUserErrors(t *testing.T) { user, err := db.CreateUser(types.User{Name: "test"}) require.NoError(t, err) + //nolint:staticcheck pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) require.NoError(t, err) @@ -79,7 +79,7 @@ func TestDestroyUserErrors(t *testing.T) { Hostname: "testnode", UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: ptr.To(pak.ID), + AuthKeyID: new(pak.ID), } trx := db.DB.Save(&node) require.NoError(t, trx.Error) diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 629b7be1..93200b95 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -25,34 +25,39 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { overview := h.state.DebugOverviewJSON() + overviewJSON, err := json.MarshalIndent(overview, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(overviewJSON) + _, _ = w.Write(overviewJSON) } else { // Default to text/plain for backward compatibility overview := h.state.DebugOverview() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(overview)) + _, _ = w.Write([]byte(overview)) } })) // Configuration endpoint debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { config := h.state.DebugConfig() + configJSON, err := json.MarshalIndent(config, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(configJSON) + _, _ = w.Write(configJSON) })) // Policy endpoint @@ -70,8 +75,9 @@ func (h *Headscale) debugHTTPServer() *http.Server { } else { w.Header().Set("Content-Type", "text/plain") } + w.WriteHeader(http.StatusOK) - w.Write([]byte(policy)) + _, _ = w.Write([]byte(policy)) })) // Filter rules endpoint @@ -81,27 +87,31 @@ func (h *Headscale) debugHTTPServer() *http.Server { httpError(w, err) return } + filterJSON, err := json.MarshalIndent(filter, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(filterJSON) + _, _ = w.Write(filterJSON) })) // SSH policies endpoint debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w 
http.ResponseWriter, r *http.Request) { sshPolicies := h.state.DebugSSHPolicies() + sshJSON, err := json.MarshalIndent(sshPolicies, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(sshJSON) + _, _ = w.Write(sshJSON) })) // DERP map endpoint @@ -112,20 +122,23 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { derpInfo := h.state.DebugDERPJSON() + derpJSON, err := json.MarshalIndent(derpInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(derpJSON) + _, _ = w.Write(derpJSON) } else { // Default to text/plain for backward compatibility derpInfo := h.state.DebugDERPMap() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(derpInfo)) + _, _ = w.Write([]byte(derpInfo)) } })) @@ -137,34 +150,39 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { nodeStoreNodes := h.state.DebugNodeStoreJSON() + nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(nodeStoreJSON) + _, _ = w.Write(nodeStoreJSON) } else { // Default to text/plain for backward compatibility nodeStoreInfo := h.state.DebugNodeStore() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(nodeStoreInfo)) + _, _ = w.Write([]byte(nodeStoreInfo)) } })) // Registration cache endpoint debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cacheInfo := h.state.DebugRegistrationCache() + cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(cacheJSON) + _, _ = w.Write(cacheJSON) })) // Routes endpoint @@ -175,20 +193,23 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { routes := h.state.DebugRoutes() + routesJSON, err := json.MarshalIndent(routes, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(routesJSON) + _, _ = w.Write(routesJSON) } else { // Default to text/plain for backward compatibility routes := h.state.DebugRoutesString() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(routes)) + _, _ = w.Write([]byte(routes)) } })) @@ -200,20 +221,23 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { policyManagerInfo := h.state.DebugPolicyManagerJSON() + policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(policyManagerJSON) + _, _ = w.Write(policyManagerJSON) } else { // Default to text/plain for backward compatibility policyManagerInfo := h.state.DebugPolicyManager() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(policyManagerInfo)) + _, _ = w.Write([]byte(policyManagerInfo)) } })) @@ -226,7 +250,8 @@ func (h *Headscale) debugHTTPServer() *http.Server { if res == nil { w.WriteHeader(http.StatusOK) - w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) + _, _ = 
w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) + return } @@ -235,9 +260,10 @@ func (h *Headscale) debugHTTPServer() *http.Server { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(resJSON) + _, _ = w.Write(resJSON) })) // Batcher endpoint @@ -257,14 +283,14 @@ func (h *Headscale) debugHTTPServer() *http.Server { w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) - w.Write(batcherJSON) + _, _ = w.Write(batcherJSON) } else { // Default to text/plain for backward compatibility batcherInfo := h.debugBatcher() w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) - w.Write([]byte(batcherInfo)) + _, _ = w.Write([]byte(batcherInfo)) } })) @@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string { activeConnections: info.ActiveConnections, }) totalNodes++ + if info.Connected { connectedCount++ } @@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string { activeConnections: 0, }) totalNodes++ + if connected { connectedCount++ } + return true }) } @@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo { ActiveConnections: 0, } info.TotalNodes++ + return true }) } diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index 42d74abe..2cbc02e6 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -134,6 +134,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) { for id := range dm.Regions { ids = append(ids, id) } + slices.Sort(ids) for _, id := range ids { @@ -160,16 +161,20 @@ func derpRandom() *rand.Rand { derpRandomOnce.Do(func() { seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String()) + //nolint:gosec rnd := rand.New(rand.NewSource(0)) + //nolint:gosec rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) derpRandomInst = rnd }) + return derpRandomInst } func resetDerpRandomForTesting() { derpRandomMu.Lock() defer derpRandomMu.Unlock() + derpRandomOnce = sync.Once{} derpRandomInst = nil } diff --git a/hscontrol/derp/derp_test.go b/hscontrol/derp/derp_test.go index 91d605a6..445c1044 100644 --- a/hscontrol/derp/derp_test.go +++ b/hscontrol/derp/derp_test.go @@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { viper.Set("dns.base_domain", tt.baseDomain) + defer viper.Reset() + resetDerpRandomForTesting() testMap := tt.derpMap.View().AsStruct() diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index c736da28..56fb5de9 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -74,9 +74,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { if err != nil { return tailcfg.DERPRegion{}, err } - var host string - var port int - var portStr string + + var ( + host string + port int + portStr string + ) // Extract hostname and port from URL host, portStr, err = net.SplitHostPort(serverURL.Host) @@ -97,12 +100,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { // If debug flag is set, resolve hostname to IP address if debugUseDERPIP { - ips, err := net.LookupIP(host) + addrs, err := net.DefaultResolver.LookupIPAddr(context.Background(), host) if err != nil { log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host) - } else if len(ips) > 0 { + } else if len(addrs) > 0 { // Use the first IP address - ipStr := ips[0].String() + ipStr := addrs[0].IP.String() 
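Editor's note: derpRandom above seeds a math/rand source from a CRC-64 checksum of dns.base_domain (falling back to the current time), so the region shuffle is stable per deployment but differs between deployments. A compact sketch of that seeding idea; the ISO table choice is an assumption here, the real code uses its own crc64 table.

package main

import (
	"cmp"
	"fmt"
	"hash/crc64"
	"math/rand"
	"time"
)

// deterministicRand is an illustrative stand-in for derpRandom: the same seed
// string always yields the same shuffle order.
func deterministicRand(seedStr string) *rand.Rand {
	seed := cmp.Or(seedStr, time.Now().String())
	table := crc64.MakeTable(crc64.ISO)

	return rand.New(rand.NewSource(int64(crc64.Checksum([]byte(seed), table))))
}

func main() {
	regions := []int{1, 2, 3, 4, 5}
	rnd := deterministicRand("example.com")
	rnd.Shuffle(len(regions), func(i, j int) { regions[i], regions[j] = regions[j], regions[i] })
	fmt.Println(regions) // same order on every run for the same seed string
}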
log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr) host = ipStr } @@ -205,6 +208,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques return } defer websocketConn.Close(websocket.StatusInternalError, "closing") + if websocketConn.Subprotocol() != "derp" { websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol") @@ -309,6 +313,8 @@ func DERPBootstrapDNSHandler( resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute) defer cancel() var resolver net.Resolver + + //nolint:unqueryvet for _, region := range derpMap.Regions().All() { for _, node := range region.Nodes().All() { // we don't care if we override some nodes addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName()) @@ -320,6 +326,7 @@ func DERPBootstrapDNSHandler( continue } + dnsEntries[node.HostName()] = addrs } } @@ -411,7 +418,9 @@ type DERPVerifyTransport struct { func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) { buf := new(bytes.Buffer) - if err := t.handleVerifyRequest(req, buf); err != nil { + + err := t.handleVerifyRequest(req, buf) + if err != nil { log.Error().Caller().Err(err).Msg("Failed to handle client verify request: ") return nil, err diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index 82b3078b..9aad9a7d 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -4,6 +4,7 @@ import ( "context" "crypto/sha256" "encoding/json" + "errors" "fmt" "os" "sync" @@ -15,6 +16,9 @@ import ( "tailscale.com/util/set" ) +// ErrPathIsDirectory is returned when a path is a directory instead of a file. +var ErrPathIsDirectory = errors.New("path is a directory, only file is supported") + type ExtraRecordsMan struct { mu sync.RWMutex records set.Set[tailcfg.DNSRecord] @@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) { } if fi.IsDir() { - return nil, fmt.Errorf("path is a directory, only file is supported: %s", path) + return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path) } records, hash, err := readExtraRecordsFromPath(path) @@ -85,18 +89,22 @@ func (e *ExtraRecordsMan) Run() { log.Error().Caller().Msgf("file watcher event channel closing") return } + switch event.Op { case fsnotify.Create, fsnotify.Write, fsnotify.Chmod: log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event") + if event.Name != e.path { continue } + e.updateRecords() // If a file is removed or renamed, fsnotify will loose track of it // and not watch it. We will therefore attempt to re-add it with a backoff. 
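Editor's note: the comment above explains that fsnotify stops tracking a file once it is removed or renamed, so the watch has to be re-added with a backoff. A simplified sketch of that recovery step using a plain fixed-interval retry and a generic addWatch callback; the real code uses a backoff library and the manager's own watcher.

package main

import (
	"fmt"
	"os"
	"time"
)

// reAddWatch waits for the file to reappear, then re-registers it with
// whatever watcher is in use. Polling stands in for exponential backoff.
func reAddWatch(path string, addWatch func(string) error, attempts int, wait time.Duration) error {
	var lastErr error

	for range attempts {
		if _, err := os.Stat(path); err != nil {
			lastErr = err
			time.Sleep(wait)

			continue
		}

		return addWatch(path)
	}

	return fmt.Errorf("file never reappeared at %q: %w", path, lastErr)
}

func main() {
	err := reAddWatch("/tmp/extra-records.json", func(p string) error {
		fmt.Println("re-watching", p)
		return nil
	}, 3, 100*time.Millisecond)
	fmt.Println(err)
}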
case fsnotify.Remove, fsnotify.Rename: _, err := backoff.Retry(context.Background(), func() (struct{}, error) { + //nolint:noinlineerr if _, err := os.Stat(e.path); err != nil { return struct{}{}, err } @@ -123,6 +131,7 @@ func (e *ExtraRecordsMan) Run() { log.Error().Caller().Msgf("file watcher error channel closing") return } + log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err) } } @@ -165,6 +174,7 @@ func (e *ExtraRecordsMan) updateRecords() { e.hashes[e.path] = newHash log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len()) + e.updateCh <- e.records.Slice() } @@ -183,6 +193,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error } var records []tailcfg.DNSRecord + err = json.Unmarshal(b, &records) if err != nil { return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a35a73af..3605be60 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -4,6 +4,7 @@ package hscontrol import ( + "cmp" "context" "errors" "fmt" @@ -11,7 +12,6 @@ import ( "net/netip" "os" "slices" - "sort" "strings" "time" @@ -135,8 +135,8 @@ func (api headscaleV1APIServer) ListUsers( response[index] = user.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.User) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListUsersResponse{Users: response}, nil @@ -221,8 +221,8 @@ func (api headscaleV1APIServer) ListPreAuthKeys( response[index] = key.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.PreAuthKey) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil @@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes( newApproved = append(newApproved, prefix) } } - tsaddr.SortPrefixes(newApproved) + slices.SortFunc(newApproved, netip.Prefix.Compare) newApproved = slices.Compact(newApproved) node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved) @@ -535,8 +535,8 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N response[index] = resp } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.Node) int { + return cmp.Compare(a.Id, b.Id) }) return response @@ -632,8 +632,8 @@ func (api headscaleV1APIServer) ListApiKeys( response[index] = key.Proto() } - sort.Slice(response, func(i, j int) bool { - return response[i].Id < response[j].Id + slices.SortFunc(response, func(a, b *v1.ApiKey) int { + return cmp.Compare(a.Id, b.Id) }) return &v1.ListApiKeysResponse{ApiKeys: response}, nil diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index dc693dae..21794f99 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -36,8 +36,7 @@ const ( // httpError logs an error and sends an HTTP error response with the given. 
func httpError(w http.ResponseWriter, err error) { - var herr HTTPError - if errors.As(err, &herr) { + if herr, ok := errors.AsType[HTTPError](err); ok { http.Error(w, herr.Msg, herr.Code) log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg) } else { @@ -56,7 +55,7 @@ type HTTPError struct { func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) } func (e HTTPError) Unwrap() error { return e.Err } -// Error returns an HTTPError containing the given information. +// NewHTTPError returns an HTTPError containing the given information. func NewHTTPError(code int, msg string, err error) HTTPError { return HTTPError{Code: code, Msg: msg, Err: err} } @@ -92,6 +91,7 @@ func (h *Headscale) handleVerifyRequest( } var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest + //nolint:noinlineerr if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)) } @@ -155,7 +155,11 @@ func (h *Headscale) KeyHandler( } writer.Header().Set("Content-Type", "application/json") - json.NewEncoder(writer).Encode(resp) + + err := json.NewEncoder(writer).Encode(resp) + if err != nil { + log.Error().Err(err).Msg("failed to encode key response") + } return } @@ -180,8 +184,12 @@ func (h *Headscale) HealthHandler( res.Status = "fail" } - json.NewEncoder(writer).Encode(res) + //nolint:noinlineerr + if err := json.NewEncoder(writer).Encode(res); err != nil { + log.Error().Err(err).Msg("failed to encode health response") + } } + err := h.state.PingDB(req.Context()) if err != nil { respond(err) @@ -218,6 +226,7 @@ func (h *Headscale) VersionHandler( writer.WriteHeader(http.StatusOK) versionInfo := types.GetVersionInfo() + err := json.NewEncoder(writer).Encode(versionInfo) if err != nil { log.Error(). @@ -267,7 +276,7 @@ func (a *AuthProviderWeb) RegisterHandler( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write([]byte(templates.RegisterWeb(registrationId).Render())) + _, _ = writer.Write([]byte(templates.RegisterWeb(registrationId).Render())) } func FaviconHandler(writer http.ResponseWriter, req *http.Request) { diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go index 0a1e30d0..06ad7009 100644 --- a/hscontrol/mapper/batcher.go +++ b/hscontrol/mapper/batcher.go @@ -15,6 +15,17 @@ import ( "tailscale.com/tailcfg" ) +// Sentinel errors for batcher operations. +var ( + ErrInvalidNodeID = errors.New("invalid nodeID") + ErrMapperNil = errors.New("mapper is nil") + ErrNodeConnectionNil = errors.New("nodeConnection is nil") +) + +// workChannelMultiplier is the multiplier for work channel capacity based on worker count. +// The size is arbitrary chosen, the sizing should be revisited. +const workChannelMultiplier = 200 + var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: "headscale", Name: "mapresponse_generated_total", @@ -42,8 +53,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB workers: workers, tick: time.NewTicker(batchTime), - // The size of this channel is arbitrary chosen, the sizing should be revisited. 
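Editor's note: httpError above moves from a declared target plus errors.As to the generic errors.AsType[T] form. A small sketch of the equivalence on an illustrative error type; the errors.As branch below shows the same check for toolchains that predate errors.AsType.

package main

import (
	"errors"
	"fmt"
)

// StatusError is an illustrative typed error, standing in for HTTPError.
type StatusError struct{ Code int }

func (e StatusError) Error() string { return fmt.Sprintf("status %d", e.Code) }

func handle(err error) {
	// Pre-generics form: declare a target and let errors.As fill it in.
	// The diff replaces this with: se, ok := errors.AsType[StatusError](err)
	var se StatusError
	if errors.As(err, &se) {
		fmt.Println("typed error, code:", se.Code)
		return
	}

	fmt.Println("generic error:", err)
}

func main() {
	handle(fmt.Errorf("wrapped: %w", StatusError{Code: 404}))
	handle(errors.New("plain"))
}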
- workCh: make(chan work, workers*200), + workCh: make(chan work, workers*workChannelMultiplier), nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), connected: xsync.NewMap[types.NodeID, *time.Time](), pendingChanges: xsync.NewMap[types.NodeID, []change.Change](), @@ -76,20 +86,20 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t version := nc.version() if r.IsEmpty() { - return nil, nil //nolint:nilnil // Empty response means nothing to send + return nil, nil //nolint:nilnil } if nodeID == 0 { - return nil, fmt.Errorf("invalid nodeID: %d", nodeID) + return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID) } if mapper == nil { - return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID) + return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID) } // Handle self-only responses if r.IsSelfOnly() && r.TargetNode != nodeID { - return nil, nil //nolint:nilnil // No response needed for other nodes when self-only + return nil, nil //nolint:nilnil } // Check if this is a self-update (the changed node is the receiving node). @@ -135,7 +145,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t // handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change]. func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error { if nc == nil { - return errors.New("nodeConnection is nil") + return ErrNodeConnectionNil } nodeID := nc.nodeID() diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index e00512b6..d4b5ff6f 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -2,6 +2,7 @@ package mapper import ( "crypto/rand" + "encoding/hex" "errors" "fmt" "sync" @@ -13,10 +14,35 @@ import ( "github.com/puzpuzpuz/xsync/v4" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) -var errConnectionClosed = errors.New("connection channel already closed") +// Sentinel errors for lock-free batcher operations. +var ( + errConnectionClosed = errors.New("connection channel already closed") + ErrInitialMapTimeout = errors.New("failed to send initial map: timeout") + ErrNodeNotFound = errors.New("node not found") + ErrBatcherShutdown = errors.New("batcher shutting down") + ErrConnectionTimeout = errors.New("connection timeout sending to channel (likely stale connection)") +) + +// Batcher configuration constants. +const ( + // initialMapSendTimeout is the timeout for sending the initial map response to a new connection. + initialMapSendTimeout = 5 * time.Second + + // offlineNodeCleanupThreshold is how long a node must be offline before it's cleaned up. + offlineNodeCleanupThreshold = 15 * time.Minute + + // offlineNodeCleanupInterval is the interval between cleanup runs. + offlineNodeCleanupInterval = 5 * time.Minute + + // connectionSendTimeout is the timeout for detecting stale connections. + // Kept short to quickly detect Docker containers that are forcefully terminated. + connectionSendTimeout = 50 * time.Millisecond + + // connectionIDBytes is the number of random bytes used for connection IDs. + connectionIDBytes = 8 +) // LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention. 
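Editor's note: workCh above is now sized via the named workChannelMultiplier constant instead of a bare 200. A toy worker-pool sketch showing the shape of that queue sizing; the work type and numbers are illustrative only.

package main

import (
	"fmt"
	"sync"
)

type work struct{ id int }

// Named multiplier, mirroring workChannelMultiplier above; the value remains
// an arbitrary sizing choice, just a documented one.
const workChannelMultiplier = 200

func main() {
	const workers = 4

	workCh := make(chan work, workers*workChannelMultiplier)

	var wg sync.WaitGroup
	for range workers {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for item := range workCh {
				_ = item.id // stand-in for handling a queued change
			}
		}()
	}

	for i := range 10 {
		workCh <- work{id: i}
	}
	close(workCh)
	wg.Wait()
	fmt.Println("all queued work drained")
}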
type LockFreeBatcher struct { @@ -78,6 +104,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse if err != nil { log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed") nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) } @@ -86,12 +113,13 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse select { case c <- initialMap: // Success - case <-time.After(5 * time.Second): - log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout") - log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second). + case <-time.After(initialMapSendTimeout): + log.Error().Uint64("node.id", id.Uint64()).Err(ErrInitialMapTimeout).Msg("Initial map send timeout") + log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", initialMapSendTimeout). Msg("Initial map send timed out because channel was blocked or receiver not ready") nodeConn.removeConnectionByChannel(c) - return fmt.Errorf("failed to send initial map to node %d: timeout", id) + + return fmt.Errorf("%w for node %d", ErrInitialMapTimeout, id) } // Update connection status @@ -130,13 +158,14 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo log.Debug().Caller().Uint64("node.id", id.Uint64()). Int("active.connections", nodeConn.getActiveConnectionCount()). Msg("Node connection removed but keeping online because other connections remain") + return true // Node still has active connections } // No active connections - keep the node entry alive for rapid reconnections // The node will get a fresh full map when it reconnects log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection") - b.connected.Store(id, ptr.To(time.Now())) + b.connected.Store(id, new(time.Now())) return false } @@ -177,7 +206,7 @@ func (b *LockFreeBatcher) doWork() { } // Create a cleanup ticker for removing truly disconnected nodes - cleanupTicker := time.NewTicker(5 * time.Minute) + cleanupTicker := time.NewTicker(offlineNodeCleanupInterval) defer cleanupTicker.Stop() for { @@ -212,10 +241,12 @@ func (b *LockFreeBatcher) worker(workerID int) { // This is used for synchronous map generation. if w.resultCh != nil { var result workResult + if nc, exists := b.nodes.Load(w.nodeID); exists { var err error result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) + result.err = err if result.err != nil { b.workErrors.Add(1) @@ -229,7 +260,7 @@ func (b *LockFreeBatcher) worker(workerID int) { nc.updateSentPeers(result.mapResponse) } } else { - result.err = fmt.Errorf("node %d not found", w.nodeID) + result.err = fmt.Errorf("%w: %d", ErrNodeNotFound, w.nodeID) b.workErrors.Add(1) log.Error().Err(result.err). @@ -383,14 +414,13 @@ func (b *LockFreeBatcher) processBatchedChanges() { // cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. // TODO(kradalby): reevaluate if we want to keep this. 
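Editor's note: AddNode above delivers the initial map with a select against time.After(initialMapSendTimeout) so a blocked receiver cannot wedge registration, and connectionEntry.send later uses the same shape with connectionSendTimeout. A stripped-down sketch of that send-with-deadline pattern; the payload and error here are placeholders.

package main

import (
	"errors"
	"fmt"
	"time"
)

var errSendTimeout = errors.New("timeout sending to channel")

// sendWithTimeout delivers the value if the receiver is ready within the
// deadline, otherwise gives up with a sentinel error.
func sendWithTimeout[T any](ch chan<- T, v T, timeout time.Duration) error {
	select {
	case ch <- v:
		return nil
	case <-time.After(timeout):
		return errSendTimeout
	}
}

func main() {
	ch := make(chan string) // unbuffered with no reader: the send times out
	err := sendWithTimeout(ch, "initial map", 50*time.Millisecond)
	fmt.Println(err) // timeout sending to channel
}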
func (b *LockFreeBatcher) cleanupOfflineNodes() { - cleanupThreshold := 15 * time.Minute now := time.Now() var nodesToCleanup []types.NodeID // Find nodes that have been offline for too long b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool { - if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold { + if disconnectTime != nil && now.Sub(*disconnectTime) > offlineNodeCleanupThreshold { // Double-check the node doesn't have active connections if nodeConn, exists := b.nodes.Load(nodeID); exists { if !nodeConn.hasActiveConnections() { @@ -398,13 +428,14 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() { } } } + return true }) // Clean up the identified nodes for _, nodeID := range nodesToCleanup { log.Info().Uint64("node.id", nodeID.Uint64()). - Dur("offline_duration", cleanupThreshold). + Dur("offline_duration", offlineNodeCleanupThreshold). Msg("Cleaning up node that has been offline for too long") b.nodes.Delete(nodeID) @@ -450,6 +481,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { if nodeConn.hasActiveConnections() { ret.Store(id, true) } + return true }) @@ -465,6 +497,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { ret.Store(id, false) } } + return true }) @@ -484,7 +517,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang case result := <-resultCh: return result.mapResponse, result.err case <-b.done: - return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id) + return nil, fmt.Errorf("%w: generating map response for node %d", ErrBatcherShutdown, id) } } @@ -517,9 +550,10 @@ type multiChannelNodeConn struct { // generateConnectionID generates a unique connection identifier. func generateConnectionID() string { - bytes := make([]byte, 8) - rand.Read(bytes) - return fmt.Sprintf("%x", bytes) + bytes := make([]byte, connectionIDBytes) + _, _ = rand.Read(bytes) + + return hex.EncodeToString(bytes) } // newMultiChannelNodeConn creates a new multi-channel node connection. @@ -546,11 +580,14 @@ func (mc *multiChannelNodeConn) close() { // addConnection adds a new connection. func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { mutexWaitStart := time.Now() + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id). Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT") mc.mutex.Lock() + mutexWaitDur := time.Since(mutexWaitStart) + defer mc.mutex.Unlock() mc.connections = append(mc.connections, entry) @@ -572,9 +609,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)). Int("remaining_connections", len(mc.connections)). Msg("Successfully removed connection") + return true } } + return false } @@ -608,6 +647,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { // This is not an error - the node will receive a full map when it reconnects log.Debug().Caller().Uint64("node.id", mc.id.Uint64()). 
Msg("send: skipping send to node with no active connections (likely rapid reconnection)") + return nil // Return success instead of error } @@ -616,7 +656,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Msg("send: broadcasting to all connections") var lastErr error + successCount := 0 + var failedConnections []int // Track failed connections for removal // Send to all connections @@ -625,8 +667,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Str("conn.id", conn.id).Int("connection_index", i). Msg("send: attempting to send to connection") - if err := conn.send(data); err != nil { + err := conn.send(data) + if err != nil { lastErr = err + failedConnections = append(failedConnections, i) log.Warn().Err(err). Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). @@ -634,6 +678,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Msg("send: connection send failed") } else { successCount++ + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). Str("conn.id", conn.id).Int("connection_index", i). Msg("send: successfully sent to connection") @@ -685,10 +730,10 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { // Update last used timestamp on successful send entry.lastUsed.Store(time.Now().Unix()) return nil - case <-time.After(50 * time.Millisecond): + case <-time.After(connectionSendTimeout): // Connection is likely stale - client isn't reading from channel // This catches the case where Docker containers are killed but channels remain open - return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id) + return fmt.Errorf("%w: connection %s", ErrConnectionTimeout, entry.id) } } @@ -798,6 +843,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { Connected: connected, ActiveConnections: activeConnCount, } + return true }) @@ -812,6 +858,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { ActiveConnections: 0, } } + return true }) diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 70d5e377..d0ebee6d 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -5,6 +5,7 @@ import ( "fmt" "net/netip" "runtime" + "slices" "strings" "sync" "sync/atomic" @@ -35,6 +36,7 @@ type batcherTestCase struct { // that would normally be sent by poll.go in production. type testBatcherWrapper struct { Batcher + state *state.State } @@ -80,12 +82,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe } // Finally remove from the real batcher - removed := t.Batcher.RemoveNode(id, c) - if !removed { - return false - } - - return true + return t.Batcher.RemoveNode(id, c) } // wrapBatcherForTest wraps a batcher with test-specific behavior. @@ -129,15 +126,13 @@ const ( SMALL_BUFFER_SIZE = 3 TINY_BUFFER_SIZE = 1 // For maximum contention LARGE_BUFFER_SIZE = 200 - - reservedResponseHeaderSize = 4 ) // TestData contains all test entities created for a test scenario. 
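Editor's note: generateConnectionID in the hunks above swaps fmt.Sprintf with "%x" for hex.EncodeToString over crypto/rand bytes. A standalone sketch of that helper; the 8-byte length mirrors connectionIDBytes, and the ignored error matches the batcher's own `_, _ = rand.Read(bytes)`.

package main

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
)

const connectionIDBytes = 8 // matches the constant introduced above

// newConnectionID returns a random hex identifier for a connection.
func newConnectionID() string {
	b := make([]byte, connectionIDBytes)
	_, _ = rand.Read(b)

	return hex.EncodeToString(b)
}

func main() {
	fmt.Println(newConnectionID()) // e.g. a 16-character hex string
}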
type TestData struct { Database *db.HSDatabase Users []*types.User - Nodes []node + Nodes []*node State *state.State Config *types.Config Batcher Batcher @@ -223,11 +218,11 @@ func setupBatcherWithTestData( // Create test users and nodes in the database users := database.CreateUsersForTest(userCount, "testuser") - allNodes := make([]node, 0, userCount*nodesPerUser) + allNodes := make([]*node, 0, userCount*nodesPerUser) for _, user := range users { dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node") for i := range dbNodes { - allNodes = append(allNodes, node{ + allNodes = append(allNodes, &node{ n: dbNodes[i], ch: make(chan *tailcfg.MapResponse, bufferSize), }) @@ -241,8 +236,8 @@ func setupBatcherWithTestData( } derpMap, err := derp.GetDERPMap(cfg.DERP) - assert.NoError(t, err) - assert.NotNil(t, derpMap) + require.NoError(t, err) + require.NotNil(t, derpMap) state.SetDERPMap(derpMap) @@ -318,23 +313,6 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) { stats.LastUpdate = time.Now() } -// getStats returns a copy of the statistics for a node. -func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats { - ut.mu.RLock() - defer ut.mu.RUnlock() - - if stats, exists := ut.stats[nodeID]; exists { - // Return a copy to avoid race conditions - return UpdateStats{ - TotalUpdates: stats.TotalUpdates, - UpdateSizes: append([]int{}, stats.UpdateSizes...), - LastUpdate: stats.LastUpdate, - } - } - - return UpdateStats{} -} - // getAllStats returns a copy of all statistics. func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats { ut.mu.RLock() @@ -344,7 +322,7 @@ func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats { for nodeID, stats := range ut.stats { result[nodeID] = UpdateStats{ TotalUpdates: stats.TotalUpdates, - UpdateSizes: append([]int{}, stats.UpdateSizes...), + UpdateSizes: slices.Clone(stats.UpdateSizes), LastUpdate: stats.LastUpdate, } } @@ -386,16 +364,14 @@ type UpdateInfo struct { } // parseUpdateAndAnalyze parses an update and returns detailed information. -func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) { - info := UpdateInfo{ +func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo { + return UpdateInfo{ PeerCount: len(resp.Peers), PatchCount: len(resp.PeersChangedPatch), IsFull: len(resp.Peers) > 0, IsPatch: len(resp.PeersChangedPatch) > 0, IsDERP: resp.DERPMap != nil, } - - return info, nil } // start begins consuming updates from the node's channel and tracking stats. 
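Editor's note: getAllStats above swaps append([]int{}, stats.UpdateSizes...) for slices.Clone, which copies the backing array so later mutations of the tracker's slice cannot leak into the returned stats. A tiny sketch of the equivalence:

package main

import (
	"fmt"
	"slices"
)

func main() {
	original := []int{10, 20, 30}

	oldStyle := append([]int{}, original...) // manual copy
	newStyle := slices.Clone(original)       // same result, clearer intent

	original[0] = 99 // mutating the source affects neither copy

	fmt.Println(oldStyle, newStyle) // [10 20 30] [10 20 30]
}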
@@ -417,36 +393,36 @@ func (n *node) start() { atomic.AddInt64(&n.updateCount, 1) // Parse update and track detailed stats - if info, err := parseUpdateAndAnalyze(data); err == nil { - // Track update types - if info.IsFull { - atomic.AddInt64(&n.fullCount, 1) - n.lastPeerCount.Store(int64(info.PeerCount)) - // Update max peers seen using compare-and-swap for thread safety - for { - current := n.maxPeersCount.Load() - if int64(info.PeerCount) <= current { - break - } + info := parseUpdateAndAnalyze(data) - if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) { - break - } + // Track update types + if info.IsFull { + atomic.AddInt64(&n.fullCount, 1) + n.lastPeerCount.Store(int64(info.PeerCount)) + // Update max peers seen using compare-and-swap for thread safety + for { + current := n.maxPeersCount.Load() + if int64(info.PeerCount) <= current { + break + } + + if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) { + break } } + } - if info.IsPatch { - atomic.AddInt64(&n.patchCount, 1) - // For patches, we track how many patch items using compare-and-swap - for { - current := n.maxPeersCount.Load() - if int64(info.PatchCount) <= current { - break - } + if info.IsPatch { + atomic.AddInt64(&n.patchCount, 1) + // For patches, we track how many patch items using compare-and-swap + for { + current := n.maxPeersCount.Load() + if int64(info.PatchCount) <= current { + break + } - if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) { - break - } + if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) { + break } } } @@ -540,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { defer cleanup() batcher := testData.Batcher - testNode := &testData.Nodes[0] + testNode := testData.Nodes[0] t.Logf("Testing enhanced tracking with node ID %d", testNode.n.ID) @@ -548,7 +524,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) { testNode.start() // Connect the node to the batcher - batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) // Wait for connection to be established assert.EventuallyWithT(t, func(c *assert.CollectT) { @@ -656,8 +632,8 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) for i := range allNodes { - node := &allNodes[i] - batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + node := allNodes[i] + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) // Issue full update after each join to ensure connectivity batcher.AddWork(change.FullUpdate()) @@ -676,8 +652,9 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { connectedCount := 0 + for i := range allNodes { - node := &allNodes[i] + node := allNodes[i] currentMaxPeers := int(node.maxPeersCount.Load()) if currentMaxPeers >= expectedPeers { @@ -693,11 +670,12 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { }, 5*time.Minute, 5*time.Second, "waiting for full connectivity") t.Logf("✅ All nodes achieved full connectivity!") + totalTime := time.Since(startTime) // Disconnect all nodes for i := range allNodes { - node := &allNodes[i] + node := allNodes[i] batcher.RemoveNode(node.n.ID, node.ch) } @@ -718,7 +696,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { nodeDetails := make([]string, 0, min(10, len(allNodes))) for i := range allNodes { - node := &allNodes[i] + node := allNodes[i] stats := node.cleanup() 
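Editor's note: node.start() above tracks the highest peer count seen with a compare-and-swap loop on an atomic counter. A minimal sketch of that max-tracking pattern, with illustrative names:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// storeMax raises the atomic to v unless another goroutine has already stored
// something larger; the CAS loop retries only when it loses a race.
func storeMax(m *atomic.Int64, v int64) {
	for {
		current := m.Load()
		if v <= current {
			return
		}

		if m.CompareAndSwap(current, v) {
			return
		}
	}
}

func main() {
	var maxPeers atomic.Int64

	var wg sync.WaitGroup
	for i := int64(1); i <= 100; i++ {
		wg.Add(1)
		go func(n int64) {
			defer wg.Done()
			storeMax(&maxPeers, n)
		}(i)
	}
	wg.Wait()

	fmt.Println(maxPeers.Load()) // 100
}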
totalUpdates += stats.TotalUpdates @@ -824,7 +802,7 @@ func TestBatcherBasicOperations(t *testing.T) { tn2 := testData.Nodes[1] // Test AddNode with real node ID - batcher.AddNode(tn.n.ID, tn.ch, 100) + _ = batcher.AddNode(tn.n.ID, tn.ch, 100) if !batcher.IsConnected(tn.n.ID) { t.Error("Node should be connected after AddNode") @@ -845,7 +823,7 @@ func TestBatcherBasicOperations(t *testing.T) { drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) // Add the second node and verify update message - batcher.AddNode(tn2.n.ID, tn2.ch, 100) + _ = batcher.AddNode(tn2.n.ID, tn2.ch, 100) assert.True(t, batcher.IsConnected(tn2.n.ID)) // First node should get an update that second node has connected. @@ -911,7 +889,7 @@ func TestBatcherBasicOperations(t *testing.T) { } } -func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) { +func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, _ string, timeout time.Duration) { count := 0 timer := time.NewTimer(timeout) @@ -1050,7 +1028,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) { testNodes := testData.Nodes ch := make(chan *tailcfg.MapResponse, 10) - batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) // Track update content for validation var receivedUpdates []*tailcfg.MapResponse @@ -1130,6 +1108,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) { // The test verifies that channels are closed synchronously and deterministically // even when real node updates are being processed, ensuring no race conditions // occur during channel replacement with actual workload. +// + func XTestBatcherChannelClosingRace(t *testing.T) { for _, batcherFunc := range allBatcherFunctions { t.Run(batcherFunc.name, func(t *testing.T) { @@ -1154,7 +1134,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) { ch1 := make(chan *tailcfg.MapResponse, 1) wg.Go(func() { - batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) }) // Add real work during connection chaos @@ -1167,7 +1147,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) { wg.Go(func() { runtime.Gosched() // Yield to introduce timing variability - batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) + + _ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) }) // Remove second connection @@ -1258,7 +1239,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { ch := make(chan *tailcfg.MapResponse, 5) // Add node and immediately queue real work - batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) batcher.AddWork(change.DERPMap()) // Consumer goroutine to validate data and detect channel issues @@ -1308,6 +1289,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { for range i % 3 { runtime.Gosched() // Introduce timing variability } + batcher.RemoveNode(testNode.n.ID, ch) // Yield to allow workers to process and close channels @@ -1350,6 +1332,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { // real node data. The test validates that stable clients continue to function // normally and receive proper updates despite the connection churn from other clients, // ensuring system stability under concurrent load. 
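Editor's note: drainChannelTimeout above (its unused name parameter is now `_ string`) discards queued map responses until a timer fires. A generic sketch of that drain-until-deadline helper, under the assumption that counting the dropped values is all that matters:

package main

import (
	"fmt"
	"time"
)

// drainUntil reads and discards values until the timeout elapses or the
// channel is closed, returning how many values were dropped.
func drainUntil[T any](ch <-chan T, timeout time.Duration) int {
	timer := time.NewTimer(timeout)
	defer timer.Stop()

	count := 0
	for {
		select {
		case _, ok := <-ch:
			if !ok {
				return count
			}
			count++
		case <-timer.C:
			return count
		}
	}
}

func main() {
	ch := make(chan int, 3)
	ch <- 1
	ch <- 2
	close(ch)

	fmt.Println(drainUntil(ch, 50*time.Millisecond)) // 2
}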
+// +//nolint:gocyclo func TestBatcherConcurrentClients(t *testing.T) { if testing.Short() { t.Skip("Skipping concurrent client test in short mode") @@ -1380,7 +1364,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for _, node := range stableNodes { ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) stableChannels[node.n.ID] = ch - batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) // Monitor updates for each stable client go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { @@ -1391,6 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) { // Channel was closed, exit gracefully return } + if valid, reason := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1448,10 +1433,12 @@ func TestBatcherConcurrentClients(t *testing.T) { ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) churningChannelsMutex.Lock() + churningChannels[nodeID] = ch + churningChannelsMutex.Unlock() - batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) // Consume updates to prevent blocking go func() { @@ -1462,6 +1449,7 @@ func TestBatcherConcurrentClients(t *testing.T) { // Channel was closed, exit gracefully return } + if valid, _ := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1494,6 +1482,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for range i % 5 { runtime.Gosched() // Introduce timing variability } + churningChannelsMutex.Lock() ch, exists := churningChannels[nodeID] @@ -1623,6 +1612,8 @@ func TestBatcherConcurrentClients(t *testing.T) { // It validates that the system remains stable with no deadlocks, panics, or // missed updates under sustained high load. The test uses real node data to // generate authentic update scenarios and tracks comprehensive statistics. 
+// +//nolint:gocyclo,thelper func XTestBatcherScalability(t *testing.T) { if testing.Short() { t.Skip("Skipping scalability test in short mode") @@ -1651,6 +1642,7 @@ func XTestBatcherScalability(t *testing.T) { description string } + //nolint:prealloc var testCases []testCase // Generate all combinations of the test matrix @@ -1761,8 +1753,9 @@ func XTestBatcherScalability(t *testing.T) { var connectedNodesMutex sync.RWMutex for i := range testNodes { - node := &testNodes[i] - batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + node := testNodes[i] + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + connectedNodesMutex.Lock() connectedNodes[node.n.ID] = true @@ -1878,6 +1871,7 @@ func XTestBatcherScalability(t *testing.T) { channel, tailcfg.CapabilityVersion(100), ) + connectedNodesMutex.Lock() connectedNodes[nodeID] = true @@ -1991,7 +1985,7 @@ func XTestBatcherScalability(t *testing.T) { // Now disconnect all nodes from batcher to stop new updates for i := range testNodes { - node := &testNodes[i] + node := testNodes[i] batcher.RemoveNode(node.n.ID, node.ch) } @@ -2010,7 +2004,7 @@ func XTestBatcherScalability(t *testing.T) { nodeStatsReport := make([]string, 0, len(testNodes)) for i := range testNodes { - node := &testNodes[i] + node := testNodes[i] stats := node.cleanup() totalUpdates += stats.TotalUpdates totalPatches += stats.PatchUpdates @@ -2139,7 +2133,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) { // Connect nodes one at a time and wait for each to be connected for i, node := range allNodes { - batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + _ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) t.Logf("Connected node %d (ID: %d)", i, node.n.ID) // Wait for node to be connected @@ -2286,6 +2280,7 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 1: Connect all nodes initially t.Logf("Phase 1: Connecting all nodes...") + for i, node := range allNodes { err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) if err != nil { @@ -2302,6 +2297,7 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 2: Rapid disconnect ALL nodes (simulating nodes going down) t.Logf("Phase 2: Rapid disconnect all nodes...") + for i, node := range allNodes { removed := batcher.RemoveNode(node.n.ID, node.ch) t.Logf("Node %d RemoveNode result: %t", i, removed) @@ -2309,9 +2305,11 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up) t.Logf("Phase 3: Rapid reconnect with new channels...") + newChannels := make([]chan *tailcfg.MapResponse, len(allNodes)) for i, node := range allNodes { newChannels[i] = make(chan *tailcfg.MapResponse, 10) + err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100)) if err != nil { t.Errorf("Failed to reconnect node %d: %v", i, err) @@ -2342,11 +2340,13 @@ func TestBatcherRapidReconnection(t *testing.T) { if infoMap, ok := info.(map[string]any); ok { if connected, ok := infoMap["connected"].(bool); ok && !connected { disconnectedCount++ + t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i) } } } else { disconnectedCount++ + t.Logf("Node %d missing from debug info entirely", i) } @@ -2381,6 +2381,7 @@ func TestBatcherRapidReconnection(t *testing.T) { case update := <-newChannels[i]: if update != nil { receivedCount++ + t.Logf("Node %d received update successfully", i) } case <-timeout: @@ -2399,6 +2400,7 @@ func 
TestBatcherRapidReconnection(t *testing.T) { } } +//nolint:gocyclo func TestBatcherMultiConnection(t *testing.T) { for _, batcherFunc := range allBatcherFunctions { t.Run(batcherFunc.name, func(t *testing.T) { @@ -2413,6 +2415,7 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 1: Connect first node with initial connection t.Logf("Phase 1: Connecting node 1 with first connection...") + err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add node1: %v", err) @@ -2432,7 +2435,9 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 2: Add second connection for node1 (multi-connection scenario) t.Logf("Phase 2: Adding second connection for node 1...") + secondChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add second connection for node1: %v", err) @@ -2443,7 +2448,9 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 3: Add third connection for node1 t.Logf("Phase 3: Adding third connection for node 1...") + thirdChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add third connection for node1: %v", err) @@ -2454,6 +2461,7 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 4: Verify debug status shows correct connection count t.Logf("Phase 4: Verifying debug status shows multiple connections...") + if debugBatcher, ok := batcher.(interface { Debug() map[types.NodeID]any }); ok { @@ -2461,6 +2469,7 @@ func TestBatcherMultiConnection(t *testing.T) { if info, exists := debugInfo[node1.n.ID]; exists { t.Logf("Node1 debug info: %+v", info) + if infoMap, ok := info.(map[string]any); ok { if activeConnections, ok := infoMap["active_connections"].(int); ok { if activeConnections != 3 { @@ -2469,6 +2478,7 @@ func TestBatcherMultiConnection(t *testing.T) { t.Logf("SUCCESS: Node1 correctly shows 3 active connections") } } + if connected, ok := infoMap["connected"].(bool); ok && !connected { t.Errorf("Node1 should show as connected with 3 active connections") } @@ -2651,9 +2661,9 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) { batcher := testData.Batcher st := testData.State - node1 := &testData.Nodes[0] - node2 := &testData.Nodes[1] - node3 := &testData.Nodes[2] + node1 := testData.Nodes[0] + node2 := testData.Nodes[1] + node3 := testData.Nodes[2] t.Logf("Testing issue #2924: Node1=%d, Node2=%d, Node3=%d", node1.n.ID, node2.n.ID, node3.n.ID) diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index c666ff24..801b3e17 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -1,9 +1,9 @@ package mapper import ( - "errors" + "cmp" "net/netip" - "sort" + "slices" "time" "github.com/juanfont/headscale/hscontrol/policy" @@ -36,6 +36,7 @@ const ( // NewMapResponseBuilder creates a new builder with basic fields set. 
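The builder hunks below replace the repeated inline errors.New("node not found") calls with a shared ErrNodeNotFound sentinel. A minimal sketch of the pattern, assuming the sentinel is declared elsewhere in the mapper package (its declaration is not part of the hunks shown here) and using hypothetical helpers for illustration; wrapping with %w keeps errors.Is working for callers:

package mapper

import (
	"errors"
	"fmt"
)

// Assumed declaration; the PR defines the sentinel outside the hunks shown here.
var ErrNodeNotFound = errors.New("node not found")

// buildSelf is a hypothetical helper showing how a call site reports the sentinel.
func buildSelf(found bool) error {
	if !found {
		// %w keeps the sentinel in the chain so errors.Is can match it.
		return fmt.Errorf("building self node: %w", ErrNodeNotFound)
	}
	return nil
}

// isNotFound lets callers branch on the condition without string matching.
func isNotFound(err error) bool {
	return errors.Is(err, ErrNodeNotFound)
}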
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { now := time.Now() + return &MapResponseBuilder{ resp: &tailcfg.MapResponse{ KeepAlive: false, @@ -69,7 +70,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { nv, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -123,6 +124,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { b.resp.Debug = &tailcfg.Debug{ DisableLogTail: !b.mapper.cfg.LogTail.Enabled, } + return b } @@ -130,7 +132,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -149,7 +151,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -162,7 +164,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -175,7 +177,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - b.addError(errors.New("node not found")) + b.addError(ErrNodeNotFound) return b } @@ -229,7 +231,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) { node, ok := b.mapper.state.GetNodeByID(b.nodeID) if !ok { - return nil, errors.New("node not found") + return nil, ErrNodeNotFound } // Get unreduced matchers for peer relationship determination. @@ -261,8 +263,8 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ( } // Peers is always returned sorted by Node.ID. - sort.SliceStable(tailPeers, func(x, y int) bool { - return tailPeers[x].ID < tailPeers[y].ID + slices.SortStableFunc(tailPeers, func(a, b *tailcfg.Node) int { + return cmp.Compare(a.ID, b.ID) }) return tailPeers, nil @@ -276,20 +278,23 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) // WithPeersRemoved adds removed peer IDs. func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { + //nolint:prealloc var tailscaleIDs []tailcfg.NodeID for _, id := range removedIDs { tailscaleIDs = append(tailscaleIDs, id.NodeID()) } + b.resp.PeersRemoved = tailscaleIDs return b } -// Build finalizes the response and returns marshaled bytes +// Build finalizes the response and returns marshaled bytes. func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) { if len(b.errs) > 0 { return nil, multierr.New(b.errs...) 
} + if debugDumpMapResponsePath != "" { writeDebugMapResponse(b.resp, b.debugType, b.nodeID) } diff --git a/hscontrol/mapper/builder_test.go b/hscontrol/mapper/builder_test.go index 978b2c0e..653da30b 100644 --- a/hscontrol/mapper/builder_test.go +++ b/hscontrol/mapper/builder_test.go @@ -340,7 +340,7 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) { // Build should return a multierr data, err := result.Build() assert.Nil(t, data) - assert.Error(t, err) + require.Error(t, err) // The error should contain information about multiple errors assert.Contains(t, err.Error(), "multiple errors") diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 616d470f..abf2f062 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -60,7 +60,6 @@ func newMapper( state *state.State, ) *mapper { // uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) - return &mapper{ state: state, cfg: cfg, @@ -80,6 +79,7 @@ func generateUserProfiles( userID := user.Model().ID userMap[userID] = &user ids = append(ids, userID) + for _, peer := range peers.All() { peerUser := peer.Owner() peerUserID := peerUser.Model().ID @@ -90,6 +90,7 @@ func generateUserProfiles( slices.Sort(ids) ids = slices.Compact(ids) var profiles []tailcfg.UserProfile + for _, id := range ids { if userMap[id] != nil { profiles = append(profiles, userMap[id].TailscaleUserProfile()) @@ -138,29 +139,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { } } -// fullMapResponse returns a MapResponse for the given node. -func (m *mapper) fullMapResponse( - nodeID types.NodeID, - capVer tailcfg.CapabilityVersion, -) (*tailcfg.MapResponse, error) { - peers := m.state.ListPeers(nodeID) - - return m.NewMapResponseBuilder(nodeID). - WithDebugType(fullResponseDebug). - WithCapabilityVersion(capVer). - WithSelfNode(). - WithDERPMap(). - WithDomain(). - WithCollectServicesDisabled(). - WithDebugConfig(). - WithSSHPolicy(). - WithDNSConfig(). - WithUserProfiles(peers). - WithPacketFilters(). - WithPeers(peers). - Build() -} - func (m *mapper) selfMapResponse( nodeID types.NodeID, capVer tailcfg.CapabilityVersion, @@ -214,7 +192,7 @@ func (m *mapper) policyChangeResponse( // Convert tailcfg.NodeID to types.NodeID for WithPeersRemoved removedIDs := make([]types.NodeID, len(removedPeers)) for i, id := range removedPeers { - removedIDs[i] = types.NodeID(id) //nolint:gosec // NodeID types are equivalent + removedIDs[i] = types.NodeID(id) //nolint:gosec } builder.WithPeersRemoved(removedIDs...) 
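For reference, the buildTailPeers change above swaps sort.SliceStable for the generic slices.SortStableFunc paired with cmp.Compare. A self-contained sketch of the equivalent sort (the peer slice here is made up for illustration):

package main

import (
	"cmp"
	"fmt"
	"slices"

	"tailscale.com/tailcfg"
)

func main() {
	peers := []*tailcfg.Node{{ID: 3}, {ID: 1}, {ID: 2}}

	// Stable sort by node ID, ascending; cmp.Compare returns -1, 0 or +1.
	slices.SortStableFunc(peers, func(a, b *tailcfg.Node) int {
		return cmp.Compare(a.ID, b.ID)
	})

	for _, p := range peers {
		fmt.Println(p.ID) // 1, 2, 3
	}
}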
@@ -237,7 +215,7 @@ func (m *mapper) buildFromChange( resp *change.Change, ) (*tailcfg.MapResponse, error) { if resp.IsEmpty() { - return nil, nil //nolint:nilnil // Empty response means nothing to send, not an error + return nil, nil //nolint:nilnil } // If this is a self-update (the changed node is the receiving node), @@ -306,6 +284,7 @@ func writeDebugMapResponse( perms := fs.FileMode(debugMapResponsePerm) mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID)) + err = os.MkdirAll(mPath, perms) if err != nil { panic(err) @@ -319,6 +298,7 @@ func writeDebugMapResponse( ) log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) + err = os.WriteFile(mapResponsePath, body, perms) if err != nil { panic(err) @@ -327,7 +307,7 @@ func writeDebugMapResponse( func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { if debugDumpMapResponsePath == "" { - return nil, nil + return nil, nil //nolint:nilnil } return ReadMapResponsesFromDirectory(debugDumpMapResponsePath) @@ -375,6 +355,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe } var resp tailcfg.MapResponse + err = json.Unmarshal(body, &resp) if err != nil { log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name()) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 1bafd135..ae2900d9 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -3,18 +3,13 @@ package mapper import ( "fmt" "net/netip" - "slices" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/policy/matcher" - "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" - "tailscale.com/types/ptr" ) var iap = func(ipStr string) *netip.Addr { @@ -51,7 +46,7 @@ func TestDNSConfigMapResponse(t *testing.T) { mach := func(hostname, username string, userid uint) *types.Node { return &types.Node{ Hostname: hostname, - UserID: ptr.To(userid), + UserID: new(userid), User: &types.User{ Name: username, }, @@ -82,86 +77,6 @@ func TestDNSConfigMapResponse(t *testing.T) { } } -// mockState is a mock implementation that provides the required methods. 
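The test hunks in this file (and throughout the rest of the PR) drop the tailscale.com/types/ptr helper in favour of the builtin new with an expression operand, which allocates a pointer to a copy of the value. A small sketch of the semantics, using a local generic helper that mirrors ptr.To so it runs on any recent toolchain; the new(expr) spelling itself assumes a Go version that accepts an expression argument to new:

package main

import "fmt"

// to mirrors tailscale.com/types/ptr.To: it returns a pointer to a copy of v.
// The PR replaces calls like ptr.To(user.ID) with new(user.ID); the semantics
// are the same.
func to[T any](v T) *T { return &v }

func main() {
	id := uint(42)
	p := to(id) // equivalent to new(id) on an expression-capable toolchain
	id = 7
	fmt.Println(*p) // 42: the pointer holds a copy taken at call time
}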
-type mockState struct { - polMan policy.PolicyManager - derpMap *tailcfg.DERPMap - primary *routes.PrimaryRoutes - nodes types.Nodes - peers types.Nodes -} - -func (m *mockState) DERPMap() *tailcfg.DERPMap { - return m.derpMap -} - -func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) { - if m.polMan == nil { - return tailcfg.FilterAllowAll, nil - } - return m.polMan.Filter() -} - -func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { - if m.polMan == nil { - return nil, nil - } - return m.polMan.SSHPolicy(node) -} - -func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool { - if m.polMan == nil { - return false - } - return m.polMan.NodeCanHaveTag(node, tag) -} - -func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix { - if m.primary == nil { - return nil - } - return m.primary.PrimaryRoutes(nodeID) -} - -func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { - if len(peerIDs) > 0 { - // Filter peers by the provided IDs - var filtered types.Nodes - for _, peer := range m.peers { - if slices.Contains(peerIDs, peer.ID) { - filtered = append(filtered, peer) - } - } - - return filtered, nil - } - // Return all peers except the node itself - var filtered types.Nodes - for _, peer := range m.peers { - if peer.ID != nodeID { - filtered = append(filtered, peer) - } - } - - return filtered, nil -} - -func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { - if len(nodeIDs) > 0 { - // Filter nodes by the provided IDs - var filtered types.Nodes - for _, node := range m.nodes { - if slices.Contains(nodeIDs, node.ID) { - filtered = append(filtered, node) - } - } - - return filtered, nil - } - - return m.nodes, nil -} - func Test_fullMapResponse(t *testing.T) { t.Skip("Test needs to be refactored for new state-based architecture") // TODO: Refactor this test to work with the new state-based mapper diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 5b7030de..dc1dd1c0 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -13,7 +13,6 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestTailNode(t *testing.T) { @@ -95,7 +94,7 @@ func TestTailNode(t *testing.T) { IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: ptr.To(uint(0)), + UserID: new(uint(0)), User: &types.User{ Name: "mini", }, @@ -136,10 +135,10 @@ func TestTailNode(t *testing.T) { ), Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, AllowedIPs: []netip.Prefix{ - tsaddr.AllIPv4(), - netip.MustParsePrefix("192.168.0.0/24"), - netip.MustParsePrefix("100.64.0.1/32"), - tsaddr.AllIPv6(), + tsaddr.AllIPv4(), // 0.0.0.0/0 + netip.MustParsePrefix("100.64.0.1/32"), // lower IPv4 + netip.MustParsePrefix("192.168.0.0/24"), // higher IPv4 + tsaddr.AllIPv6(), // ::/0 (IPv6) }, PrimaryRoutes: []netip.Prefix{ netip.MustParsePrefix("192.168.0.0/24"), diff --git a/hscontrol/noise.go b/hscontrol/noise.go index a667cd1f..7df6f77b 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -31,6 +31,9 @@ const ( earlyPayloadMagic = "\xff\xff\xffTS" ) +// ErrUnsupportedClientVersion is returned when a client version is not supported. 
+var ErrUnsupportedClientVersion = errors.New("unsupported client version") + type noiseServer struct { headscale *Headscale @@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler( } func unsupportedClientError(version tailcfg.CapabilityVersion) error { - return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version) + return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version) } func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error { @@ -241,13 +244,17 @@ func (ns *noiseServer) NoiseRegistrationHandler( return } + //nolint:contextcheck registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { var resp *tailcfg.RegisterResponse + body, err := io.ReadAll(req.Body) if err != nil { return &tailcfg.RegisterRequest{}, regErr(err) } + var regReq tailcfg.RegisterRequest + //nolint:noinlineerr if err := json.Unmarshal(body, ®Req); err != nil { return ®Req, regErr(err) } @@ -256,11 +263,11 @@ func (ns *noiseServer) NoiseRegistrationHandler( resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer()) if err != nil { - var httpErr HTTPError - if errors.As(err, &httpErr) { + if httpErr, ok := errors.AsType[HTTPError](err); ok { resp = &tailcfg.RegisterResponse{ Error: httpErr.Msg, } + return ®Req, resp } @@ -278,7 +285,8 @@ func (ns *noiseServer) NoiseRegistrationHandler( writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) - if err := json.NewEncoder(writer).Encode(registerResponse); err != nil { + err := json.NewEncoder(writer).Encode(registerResponse) + if err != nil { log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse") return } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 7013b8ed..81db5271 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -28,6 +28,7 @@ const ( defaultOAuthOptionsCount = 3 registerCacheExpiration = time.Minute * 15 registerCacheCleanup = time.Minute * 20 + csrfTokenLength = 64 ) var ( @@ -68,6 +69,7 @@ func NewAuthProviderOIDC( ) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already + //nolint:contextcheck oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) if err != nil { return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err) @@ -163,6 +165,7 @@ func (a *AuthProviderOIDC) RegisterHandler( for k, v := range a.cfg.ExtraParams { extras = append(extras, oauth2.SetAuthURLParam(k, v)) } + extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info @@ -190,6 +193,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } stateCookieName := getCookieName("state", state) + cookieState, err := req.Cookie(stateCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) @@ -212,17 +216,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( httpError(writer, err) return } + if idToken.Nonce == "" { httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err)) return } nonceCookieName := getCookieName("nonce", idToken.Nonce) + nonce, err := req.Cookie(nonceCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) return } + if idToken.Nonce != nonce.Value { httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil)) return @@ -231,6 +238,7 @@ func (a *AuthProviderOIDC) 
OIDCCallbackHandler( nodeExpiry := a.determineNodeExpiry(idToken.Expiry) var claims types.OIDCClaims + //nolint:noinlineerr if err := idToken.Claims(&claims); err != nil { httpError(writer, fmt.Errorf("decoding ID token claims: %w", err)) return @@ -239,6 +247,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Fetch user information (email, groups, name, etc) from the userinfo endpoint // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo var userinfo *oidc.UserInfo + userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token)) if err != nil { util.LogErr(err, "could not get userinfo; only using claims from id token") @@ -255,6 +264,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified) claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username) claims.Name = cmp.Or(userinfo2.Name, claims.Name) + claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL) if userinfo2.Groups != nil { claims.Groups = userinfo2.Groups @@ -279,6 +289,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( Msgf("could not create or update user") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) + _, werr := writer.Write([]byte("Could not create or update user")) if werr != nil { log.Error(). @@ -299,6 +310,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Register the node if it does not exist. if registrationId != nil { verb := "Reauthenticated" + newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { @@ -307,7 +319,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } + httpError(writer, err) + return } @@ -324,6 +338,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) + + //nolint:noinlineerr if _, err := writer.Write(content.Bytes()); err != nil { util.LogErr(err, "Failed to write HTTP response") } @@ -370,6 +386,7 @@ func (a *AuthProviderOIDC) getOauth2Token( if !ok { return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } + if regInfo.Verifier != nil { exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} } @@ -516,6 +533,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( newUser bool c change.Change ) + user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err) @@ -600,7 +618,7 @@ func getCookieName(baseName, value string) string { } func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) { - val, err := util.GenerateRandomStringURLSafe(64) + val, err := util.GenerateRandomStringURLSafe(csrfTokenLength) if err != nil { return val, err } diff --git a/hscontrol/platform_config.go b/hscontrol/platform_config.go index 23c4d25d..c8cc3fd4 100644 --- a/hscontrol/platform_config.go +++ b/hscontrol/platform_config.go @@ -19,7 +19,7 @@ func (h *Headscale) WindowsConfigMessage( ) { writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render())) + _, _ = writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render())) } // AppleConfigMessage shows a 
simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it. @@ -29,7 +29,7 @@ func (h *Headscale) AppleConfigMessage( ) { writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render())) + _, _ = writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render())) } func (h *Headscale) ApplePlatformConfig( @@ -98,7 +98,7 @@ func (h *Headscale) ApplePlatformConfig( writer.Header(). Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8") writer.WriteHeader(http.StatusOK) - writer.Write(content.Bytes()) + _, _ = writer.Write(content.Bytes()) } type AppleMobileConfig struct { diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index afc3cf68..0c84bae0 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -21,10 +21,13 @@ func (m Match) DebugString() string { sb.WriteString("Match:\n") sb.WriteString(" Sources:\n") + for _, prefix := range m.srcs.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } + sb.WriteString(" Destinations:\n") + for _, prefix := range m.dests.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index f4db88a4..b130bc6b 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -19,18 +19,18 @@ type PolicyManager interface { MatchersForNode(node types.NodeView) ([]matcher.Match, error) // BuildPeerMap constructs peer relationship maps for the given nodes BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView - SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error) - SetPolicy([]byte) (bool, error) + SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) + SetPolicy(data []byte) (bool, error) SetUsers(users []types.User) (bool, error) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) // NodeCanHaveTag reports whether the given node can have the given tag. - NodeCanHaveTag(types.NodeView, string) bool + NodeCanHaveTag(node types.NodeView, tag string) bool // TagExists reports whether the given tag is defined in the policy. TagExists(tag string) bool // NodeCanApproveRoute reports whether the given node can approve the given route. - NodeCanApproveRoute(types.NodeView, netip.Prefix) bool + NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool Version() int DebugString() string @@ -38,8 +38,11 @@ type PolicyManager interface { // NewPolicyManager returns a new policy manager. 
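The noise.go hunk above replaces the errors.As pattern (declare a target variable, pass its address) with the generic errors.AsType, which returns the typed value and a boolean in one call. A sketch of the two forms side by side, assuming a Go release that provides errors.AsType; validationError is a made-up stand-in for a concrete error type such as HTTPError:

package main

import (
	"errors"
	"fmt"
)

// validationError stands in for any concrete error type used with errors.As.
type validationError struct{ msg string }

func (e validationError) Error() string { return e.msg }

func main() {
	err := fmt.Errorf("wrapped: %w", validationError{msg: "bad input"})

	// Classic form: declare a target and pass a pointer to errors.As.
	var ve validationError
	if errors.As(err, &ve) {
		fmt.Println("As:", ve.msg)
	}

	// Generic form used in this PR; requires a toolchain that ships errors.AsType.
	if ve, ok := errors.AsType[validationError](err); ok {
		fmt.Println("AsType:", ve.msg)
	}
}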
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) { - var polMan PolicyManager - var err error + var ( + polMan PolicyManager + err error + ) + polMan, err = policyv2.NewPolicyManager(pol, users, nodes) if err != nil { return nil, err @@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ if err != nil { return nil, err } + polMans = append(polMans, pm) } @@ -66,6 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ } func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) { + //nolint:prealloc var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) { diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 677cb854..42942f61 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -9,7 +9,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "github.com/samber/lo" - "tailscale.com/net/tsaddr" "tailscale.com/types/views" ) @@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove } // Sort and deduplicate - tsaddr.SortPrefixes(newApproved) + slices.SortFunc(newApproved, netip.Prefix.Compare) newApproved = slices.Compact(newApproved) newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool { return route.IsValid() @@ -120,12 +119,13 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove // Sort the current approved for comparison sortedCurrent := make([]netip.Prefix, len(currentApproved)) copy(sortedCurrent, currentApproved) - tsaddr.SortPrefixes(sortedCurrent) + slices.SortFunc(sortedCurrent, netip.Prefix.Compare) // Only update if the routes actually changed if !slices.Equal(sortedCurrent, newApproved) { // Log what changed var added, kept []netip.Prefix + for _, route := range newApproved { if !slices.Contains(sortedCurrent, route) { added = append(added, route) diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go index 61c69067..68266645 100644 --- a/hscontrol/policy/policy_autoapprove_test.go +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -3,16 +3,16 @@ package policy import ( "fmt" "net/netip" + "slices" "testing" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gorm.io/gorm" - "tailscale.com/net/tsaddr" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/types/views" ) @@ -32,10 +32,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "test-node", - UserID: ptr.To(user1.ID), - User: ptr.To(user1), + UserID: new(user1.ID), + User: new(user1), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), Tags: []string{"tag:test"}, } @@ -44,10 +44,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "other-node", - UserID: 
ptr.To(user2.ID), - User: ptr.To(user2), + UserID: new(user2.ID), + User: new(user2), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + IPv4: new(netip.MustParseAddr("100.64.0.2")), } // Create a policy that auto-approves specific routes @@ -76,7 +76,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { }` pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()})) - assert.NoError(t, err) + require.NoError(t, err) tests := []struct { name string @@ -194,7 +194,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description) // Sort for comparison since ApproveRoutesWithPolicy sorts the results - tsaddr.SortPrefixes(tt.wantApproved) + slices.SortFunc(tt.wantApproved, netip.Prefix.Compare) assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description) // Verify that all previously approved routes are still present @@ -304,20 +304,23 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, } nodes := types.Nodes{&node} // Create policy manager or use nil if specified - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + if tt.name != "nil_policy_manager" { pm, err = pmf(users, nodes.ViewSlice()) - assert.NoError(t, err) + require.NoError(t, err) } else { pm = nil } @@ -330,7 +333,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { if tt.wantApproved == nil { assert.Nil(t, gotApproved, "expected nil approved routes") } else { - tsaddr.SortPrefixes(tt.wantApproved) + slices.SortFunc(tt.wantApproved, netip.Prefix.Compare) assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch") } }) diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go index 70aa6a21..be4f860c 100644 --- a/hscontrol/policy/policy_route_approval_test.go +++ b/hscontrol/policy/policy_route_approval_test.go @@ -13,7 +13,6 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { @@ -91,9 +90,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }, announcedRoutes: []netip.Prefix{}, // No routes announced anymore nodeUser: "test", + // Sorted by netip.Prefix.Compare: by IP address then by prefix length wantApproved: []netip.Prefix{ - netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), netip.MustParsePrefix("192.168.0.0/24"), }, wantChanged: false, @@ -123,9 +123,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { }, nodeUser: "test", nodeTags: []string{"tag:approved"}, + // Sorted by netip.Prefix.Compare: by IP address then by prefix length wantApproved: []netip.Prefix{ - netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved + netip.MustParsePrefix("172.16.0.0/16"), // New 
tag-approved }, wantChanged: true, }, @@ -168,13 +169,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: tt.nodeHostname, - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, Tags: tt.nodeTags, } @@ -294,13 +295,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tt.announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: tt.currentApproved, } nodes := types.Nodes{&node} @@ -326,6 +327,7 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { } func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { + //nolint:staticcheck user := types.User{ Model: gorm.Model{ID: 1}, Name: "test", @@ -343,13 +345,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { MachineKey: key.NewMachine().Public(), NodeKey: key.NewNode().Public(), Hostname: "testnode", - UserID: ptr.To(user.ID), - User: ptr.To(user), + UserID: new(user.ID), + User: new(user), RegisterMethod: util.RegisterMethodAuthKey, Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: announcedRoutes, }, - IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + IPv4: new(netip.MustParseAddr("100.64.0.1")), ApprovedRoutes: currentApproved, } diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index 87142dd9..7752f202 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -14,7 +14,6 @@ import ( "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) var ap = func(ipStr string) *netip.Addr { @@ -33,6 +32,7 @@ func TestReduceNodes(t *testing.T) { rules []tailcfg.FilterRule node *types.Node } + tests := []struct { name string args args @@ -783,9 +783,11 @@ func TestReduceNodes(t *testing.T) { for _, v := range gotViews.All() { got = append(got, v.AsStruct()) } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff) t.Log("Matchers: ") + for _, m := range matchers { t.Log("\t+", m.DebugString()) } @@ -796,6 +798,7 @@ func TestReduceNodes(t *testing.T) { func TestReduceNodesFromPolicy(t *testing.T) { n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node { + //nolint:prealloc var routes []netip.Prefix for _, route := range routess { routes = append(routes, netip.MustParsePrefix(route)) @@ -1032,8 +1035,11 @@ func TestReduceNodesFromPolicy(t *testing.T) { for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + pm, err = pmf(nil, tt.nodes.ViewSlice()) require.NoError(t, err) @@ -1051,9 +1057,11 @@ func TestReduceNodesFromPolicy(t *testing.T) { for _, v := range gotViews.All() { got = 
append(got, v.AsStruct()) } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff) t.Log("Matchers: ") + for _, m := range matchers { t.Log("\t+", m.DebugString()) } @@ -1074,21 +1082,21 @@ func TestSSHPolicyRules(t *testing.T) { nodeUser1 := types.Node{ Hostname: "user1-device", IPv4: ap("100.64.0.1"), - UserID: ptr.To(uint(1)), - User: ptr.To(users[0]), + UserID: new(uint(1)), + User: new(users[0]), } nodeUser2 := types.Node{ Hostname: "user2-device", IPv4: ap("100.64.0.2"), - UserID: ptr.To(uint(2)), - User: ptr.To(users[1]), + UserID: new(uint(2)), + User: new(users[1]), } taggedClient := types.Node{ Hostname: "tagged-client", IPv4: ap("100.64.0.4"), - UserID: ptr.To(uint(2)), - User: ptr.To(users[1]), + UserID: new(uint(2)), + User: new(users[1]), Tags: []string{"tag:client"}, } @@ -1207,7 +1215,7 @@ func TestSSHPolicyRules(t *testing.T) { ] }`, expectErr: true, - errorMessage: `invalid SSH action "invalid", must be one of: accept, check`, + errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`, }, { name: "invalid-check-period", @@ -1242,7 +1250,7 @@ func TestSSHPolicyRules(t *testing.T) { ] }`, expectErr: true, - errorMessage: "autogroup \"autogroup:invalid\" is not supported", + errorMessage: `autogroup not supported for SSH: "autogroup:invalid" for SSH user`, }, { name: "autogroup-nonroot-should-use-wildcard-with-root-excluded", @@ -1406,13 +1414,17 @@ func TestSSHPolicyRules(t *testing.T) { for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice()) if tt.expectErr { require.Error(t, err) require.Contains(t, err.Error(), tt.errorMessage) + return } @@ -1435,6 +1447,7 @@ func TestReduceRoutes(t *testing.T) { routes []netip.Prefix rules []tailcfg.FilterRule } + tests := []struct { name string args args @@ -2056,6 +2069,7 @@ func TestReduceRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { matchers := matcher.MatchesFromFilterRules(tt.args.rules) + got := ReduceRoutes( tt.args.node.View(), tt.args.routes, diff --git a/hscontrol/policy/policyutil/reduce.go b/hscontrol/policy/policyutil/reduce.go index e4549c10..6d95a297 100644 --- a/hscontrol/policy/policyutil/reduce.go +++ b/hscontrol/policy/policyutil/reduce.go @@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf for _, rule := range rules { // record if the rule is actually relevant for the given node. 
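The policy.go hunk earlier in this section replaces tsaddr.SortPrefixes with slices.SortFunc using the netip.Prefix.Compare method expression, which is also why the expected route lists in the tests above were reordered. A self-contained sketch of that sort-and-deduplicate step:

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

func main() {
	routes := []netip.Prefix{
		netip.MustParsePrefix("192.168.0.0/24"),
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("10.0.0.0/24"), // duplicate on purpose
		netip.MustParsePrefix("172.16.0.0/16"),
	}

	// netip.Prefix.Compare as a method expression has the
	// func(a, b netip.Prefix) int shape that slices.SortFunc expects.
	slices.SortFunc(routes, netip.Prefix.Compare)

	// Compact removes adjacent duplicates, so it must run after the sort.
	routes = slices.Compact(routes)

	fmt.Println(routes)
}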
var dests []tailcfg.NetPortRange + DEST_LOOP: for _, dest := range rule.DstPorts { expanded, err := util.ParseIPSet(dest.IP, nil) diff --git a/hscontrol/policy/policyutil/reduce_test.go b/hscontrol/policy/policyutil/reduce_test.go index 35f5b472..0b674981 100644 --- a/hscontrol/policy/policyutil/reduce_test.go +++ b/hscontrol/policy/policyutil/reduce_test.go @@ -16,7 +16,6 @@ import ( "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" "tailscale.com/util/must" ) @@ -144,13 +143,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - User: ptr.To(users[0]), + User: new(users[0]), }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - User: ptr.To(users[0]), + User: new(users[0]), }, }, want: []tailcfg.FilterRule{}, @@ -191,7 +190,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{ netip.MustParsePrefix("10.33.0.0/16"), @@ -202,7 +201,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -283,19 +282,19 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, // "internal" exit node &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -344,7 +343,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -353,12 +352,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -453,7 +452,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: tsaddr.ExitRoutes(), }, @@ -462,12 +461,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -565,7 +564,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, }, @@ -574,12 +573,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: 
ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -655,7 +654,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, }, @@ -664,12 +663,12 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[2]), + User: new(users[2]), }, &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -737,7 +736,7 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.100"), IPv6: ap("fd7a:115c:a1e0::100"), - User: ptr.To(users[3]), + User: new(users[3]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, }, @@ -747,7 +746,7 @@ func TestReduceFilterRules(t *testing.T) { &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), }, }, want: []tailcfg.FilterRule{ @@ -804,13 +803,13 @@ func TestReduceFilterRules(t *testing.T) { node: &types.Node{ IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[3]), + User: new(users[3]), }, peers: types.Nodes{ &types.Node{ IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[1]), + User: new(users[1]), Hostinfo: &tailcfg.Hostinfo{ RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, }, @@ -824,10 +823,14 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm policy.PolicyManager - var err error + var ( + pm policy.PolicyManager + err error + ) + pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice()) require.NoError(t, err) + got, _ := pm.Filter() t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) got = policyutil.ReduceFilterRules(tt.node.View(), got) diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 39b15cee..3d070a25 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -10,7 +10,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "gorm.io/gorm" - "tailscale.com/types/ptr" ) func TestNodeCanApproveRoute(t *testing.T) { @@ -25,24 +24,24 @@ func TestNodeCanApproveRoute(t *testing.T) { ID: 1, Hostname: "user1-device", IPv4: ap("100.64.0.1"), - UserID: ptr.To(uint(1)), - User: ptr.To(users[0]), + UserID: new(uint(1)), + User: new(users[0]), } exitNode := types.Node{ ID: 2, Hostname: "user2-device", IPv4: ap("100.64.0.2"), - UserID: ptr.To(uint(2)), - User: ptr.To(users[1]), + UserID: new(uint(2)), + User: new(users[1]), } taggedNode := types.Node{ ID: 3, Hostname: "tagged-server", IPv4: ap("100.64.0.3"), - UserID: ptr.To(uint(3)), - User: ptr.To(users[2]), + UserID: new(uint(3)), + User: new(users[2]), Tags: []string{"tag:router"}, } @@ -50,8 +49,8 @@ func TestNodeCanApproveRoute(t *testing.T) { ID: 4, Hostname: 
"multi-tag-node", IPv4: ap("100.64.0.4"), - UserID: ptr.To(uint(2)), - User: ptr.To(users[1]), + UserID: new(uint(2)), + User: new(users[1]), Tags: []string{"tag:router", "tag:server"}, } @@ -830,6 +829,7 @@ func TestNodeCanApproveRoute(t *testing.T) { if tt.name == "empty policy" { // We expect this one to have a valid but empty policy require.NoError(t, err) + if err != nil { return } @@ -844,6 +844,7 @@ func TestNodeCanApproveRoute(t *testing.T) { if diff := cmp.Diff(tt.canApprove, result); diff != "" { t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff) } + assert.Equal(t, tt.canApprove, result, "Unexpected route approval result") }) } diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 78c6ebc5..b9b7f5e7 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -1,7 +1,6 @@ package v2 import ( - "errors" "fmt" "slices" "time" @@ -14,8 +13,6 @@ import ( "tailscale.com/types/views" ) -var ErrInvalidAction = errors.New("invalid action") - // compileFilterRules takes a set of nodes and an ACLPolicy and generates a // set of Tailscale compatible FilterRules used to allow traffic on clients. func (pol *Policy) compileFilterRules( @@ -42,9 +39,10 @@ func (pol *Policy) compileFilterRules( continue } - protocols, _ := acl.Protocol.parseProtocol() + protocols := acl.Protocol.parseProtocol() var destPorts []tailcfg.NetPortRange + for _, dest := range acl.Destinations { ips, err := dest.Resolve(pol, users, nodes) if err != nil { @@ -121,14 +119,18 @@ func (pol *Policy) compileFilterRulesForNode( // It returns a slice of filter rules because when an ACL has both autogroup:self // and other destinations, they need to be split into separate rules with different // source filtering logic. +// +//nolint:gocyclo func (pol *Policy) compileACLWithAutogroupSelf( acl ACL, users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], ) ([]*tailcfg.FilterRule, error) { - var autogroupSelfDests []AliasWithPorts - var otherDests []AliasWithPorts + var ( + autogroupSelfDests []AliasWithPorts + otherDests []AliasWithPorts + ) for _, dest := range acl.Destinations { if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -138,14 +140,15 @@ func (pol *Policy) compileACLWithAutogroupSelf( } } - protocols, _ := acl.Protocol.parseProtocol() + protocols := acl.Protocol.parseProtocol() + var rules []*tailcfg.FilterRule var resolvedSrcIPs []*netipx.IPSet for _, src := range acl.Sources { if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - return nil, fmt.Errorf("autogroup:self cannot be used in sources") + return nil, ErrAutogroupSelfInSource } ips, err := src.Resolve(pol, users, nodes) @@ -167,6 +170,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if len(autogroupSelfDests) > 0 { // Pre-filter to same-user untagged devices once - reuse for both sources and destinations sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { if n.User().ID() == node.User().ID() && !n.IsTagged() { sameUserNodes = append(sameUserNodes, n) @@ -176,6 +180,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if len(sameUserNodes) > 0 { // Filter sources to only same-user untagged devices var srcIPs netipx.IPSetBuilder + for _, ips := range resolvedSrcIPs { for _, n := range sameUserNodes { // Check if any of this node's IPs are in the source set @@ -192,6 +197,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if srcSet != nil && len(srcSet.Prefixes()) > 0 { var destPorts []tailcfg.NetPortRange + for _, dest := 
range autogroupSelfDests { for _, n := range sameUserNodes { for _, port := range dest.Ports { @@ -280,13 +286,14 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { } } +//nolint:gocyclo func (pol *Policy) compileSSHPolicy( users types.Users, node types.NodeView, nodes views.Slice[types.NodeView], ) (*tailcfg.SSHPolicy, error) { if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 { - return nil, nil + return nil, nil //nolint:nilnil } log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname()) @@ -297,8 +304,10 @@ func (pol *Policy) compileSSHPolicy( // Separate destinations into autogroup:self and others // This is needed because autogroup:self requires filtering sources to same-user only, // while other destinations should use all resolved sources - var autogroupSelfDests []Alias - var otherDests []Alias + var ( + autogroupSelfDests []Alias + otherDests []Alias + ) for _, dst := range rule.Destinations { if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -321,6 +330,7 @@ func (pol *Policy) compileSSHPolicy( } var action tailcfg.SSHAction + switch rule.Action { case SSHActionAccept: action = sshAction(true, 0) @@ -336,9 +346,11 @@ func (pol *Policy) compileSSHPolicy( // by default, we do not allow root unless explicitly stated userMap["root"] = "" } + if rule.Users.ContainsRoot() { userMap["root"] = "root" } + for _, u := range rule.Users.NormalUsers() { userMap[u.String()] = u.String() } @@ -348,6 +360,7 @@ func (pol *Policy) compileSSHPolicy( if len(autogroupSelfDests) > 0 && !node.IsTagged() { // Build destination set for autogroup:self (same-user untagged devices only) var dest netipx.IPSetBuilder + for _, n := range nodes.All() { if n.User().ID() == node.User().ID() && !n.IsTagged() { n.AppendToIPSet(&dest) @@ -364,6 +377,7 @@ func (pol *Policy) compileSSHPolicy( // Filter sources to only same-user untagged devices // Pre-filter to same-user untagged devices for efficiency sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { if n.User().ID() == node.User().ID() && !n.IsTagged() { sameUserNodes = append(sameUserNodes, n) @@ -371,6 +385,7 @@ func (pol *Policy) compileSSHPolicy( } var filteredSrcIPs netipx.IPSetBuilder + for _, n := range sameUserNodes { // Check if any of this node's IPs are in the source set if slices.ContainsFunc(n.IPs(), srcIPs.Contains) { @@ -406,12 +421,14 @@ func (pol *Policy) compileSSHPolicy( if len(otherDests) > 0 { // Build destination set for other destinations var dest netipx.IPSetBuilder + for _, dst := range otherDests { ips, err := dst.Resolve(pol, users, nodes) if err != nil { log.Trace().Caller().Err(err).Msgf("resolving destination ips") continue } + if ips != nil { dest.AddSet(ips) } diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index 46f544c9..663e3d6b 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -15,7 +15,6 @@ import ( "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) // aliasWithPorts creates an AliasWithPorts structure from an alias and ports. 
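The autogroup:self handling above narrows sources to same-user, untagged devices by checking whether any of a candidate node's addresses fall inside the resolved source IP set. A stripped-down sketch of that membership check, using go4.org/netipx as the hunks do; fakeNode is a made-up stand-in for types.NodeView:

package main

import (
	"fmt"
	"net/netip"
	"slices"

	"go4.org/netipx"
)

// fakeNode is a stand-in carrying only the addresses needed for the check.
type fakeNode struct {
	name string
	ips  []netip.Addr
}

func main() {
	// Resolved source set, e.g. an ACL source after Resolve.
	var b netipx.IPSetBuilder
	b.AddPrefix(netip.MustParsePrefix("100.64.0.0/31")) // covers .0 and .1
	srcIPs, err := b.IPSet()
	if err != nil {
		panic(err)
	}

	nodes := []fakeNode{
		{name: "same-user", ips: []netip.Addr{netip.MustParseAddr("100.64.0.1")}},
		{name: "other-user", ips: []netip.Addr{netip.MustParseAddr("100.64.0.3")}},
	}

	for _, n := range nodes {
		// Keep the node if any of its IPs is contained in the source set.
		if slices.ContainsFunc(n.ips, srcIPs.Contains) {
			fmt.Println("keep", n.name)
		}
	}
}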
@@ -410,14 +409,14 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { nodeUser1 := types.Node{ Hostname: "user1-device", IPv4: createAddr("100.64.0.1"), - UserID: ptr.To(users[0].ID), - User: ptr.To(users[0]), + UserID: new(users[0].ID), + User: new(users[0]), } nodeUser2 := types.Node{ Hostname: "user2-device", IPv4: createAddr("100.64.0.2"), - UserID: ptr.To(users[1].ID), - User: ptr.To(users[1]), + UserID: new(users[1].ID), + User: new(users[1]), } nodes := types.Nodes{&nodeUser1, &nodeUser2} @@ -590,7 +589,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { if sshPolicy == nil { return // Expected empty result } + assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match") + return } @@ -622,14 +623,14 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { nodeUser1 := types.Node{ Hostname: "user1-device", IPv4: createAddr("100.64.0.1"), - UserID: ptr.To(users[0].ID), - User: ptr.To(users[0]), + UserID: new(users[0].ID), + User: new(users[0]), } nodeUser2 := types.Node{ Hostname: "user2-device", IPv4: createAddr("100.64.0.2"), - UserID: ptr.To(users[1].ID), - User: ptr.To(users[1]), + UserID: new(users[1].ID), + User: new(users[1]), } nodes := types.Nodes{&nodeUser1, &nodeUser2} @@ -671,7 +672,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { } // TestSSHIntegrationReproduction reproduces the exact scenario from the integration test -// TestSSHOneUserToAll that was failing with empty sshUsers +// TestSSHOneUserToAll that was failing with empty sshUsers. func TestSSHIntegrationReproduction(t *testing.T) { // Create users matching the integration test users := types.Users{ @@ -683,15 +684,15 @@ func TestSSHIntegrationReproduction(t *testing.T) { node1 := &types.Node{ Hostname: "user1-node", IPv4: createAddr("100.64.0.1"), - UserID: ptr.To(users[0].ID), - User: ptr.To(users[0]), + UserID: new(users[0].ID), + User: new(users[0]), } node2 := &types.Node{ Hostname: "user2-node", IPv4: createAddr("100.64.0.2"), - UserID: ptr.To(users[1].ID), - User: ptr.To(users[1]), + UserID: new(users[1].ID), + User: new(users[1]), } nodes := types.Nodes{node1, node2} @@ -736,7 +737,7 @@ func TestSSHIntegrationReproduction(t *testing.T) { } // TestSSHJSONSerialization verifies that the SSH policy can be properly serialized -// to JSON and that the sshUsers field is not empty +// to JSON and that the sshUsers field is not empty. 
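Several test hunks in this section switch assert.NoError and assert.Error to the require variants; require aborts the test immediately on failure, so later code never runs against values that were never initialised, while assert records the failure and keeps going. A minimal sketch of the distinction with testify:

package example

import (
	"errors"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestRequireVersusAssert(t *testing.T) {
	err := errors.New("boom")

	assert.Error(t, err)  // on failure: mark the test failed, keep running
	require.Error(t, err) // on failure: mark the test failed and stop (t.FailNow)

	// Dereference or inspect results only after require-style checks succeed.
}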
func TestSSHJSONSerialization(t *testing.T) { users := types.Users{ {Name: "user1", Model: gorm.Model{ID: 1}}, @@ -776,6 +777,7 @@ func TestSSHJSONSerialization(t *testing.T) { // Parse back to verify structure var parsed tailcfg.SSHPolicy + err = json.Unmarshal(jsonData, &parsed) require.NoError(t, err) @@ -806,19 +808,19 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { nodes := types.Nodes{ { - User: ptr.To(users[0]), + User: new(users[0]), IPv4: ap("100.64.0.1"), }, { - User: ptr.To(users[0]), + User: new(users[0]), IPv4: ap("100.64.0.2"), }, { - User: ptr.To(users[1]), + User: new(users[1]), IPv4: ap("100.64.0.3"), }, { - User: ptr.To(users[1]), + User: new(users[1]), IPv4: ap("100.64.0.4"), }, // Tagged device for user1 @@ -860,6 +862,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + if len(rules) != 1 { t.Fatalf("expected 1 rule, got %d", len(rules)) } @@ -876,6 +879,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { found := false addr := netip.MustParseAddr(expectedIP) + for _, prefix := range rule.SrcIPs { pref := netip.MustParsePrefix(prefix) if pref.Contains(addr) { @@ -893,6 +897,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"} for _, excludedIP := range excludedSourceIPs { addr := netip.MustParseAddr(excludedIP) + for _, prefix := range rule.SrcIPs { pref := netip.MustParsePrefix(prefix) if pref.Contains(addr) { @@ -938,11 +943,11 @@ func TestTagUserMutualExclusivity(t *testing.T) { nodes := types.Nodes{ // User-owned nodes { - User: ptr.To(users[0]), + User: new(users[0]), IPv4: ap("100.64.0.1"), }, { - User: ptr.To(users[1]), + User: new(users[1]), IPv4: ap("100.64.0.2"), }, // Tagged nodes @@ -960,8 +965,8 @@ func TestTagUserMutualExclusivity(t *testing.T) { policy := &Policy{ TagOwners: TagOwners{ - Tag("tag:server"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:server"): Owners{new(Username("user1@"))}, + Tag("tag:database"): Owners{new(Username("user2@"))}, }, ACLs: []ACL{ // Rule 1: user1 (user-owned) should NOT be able to reach tagged nodes @@ -1056,11 +1061,11 @@ func TestAutogroupTagged(t *testing.T) { nodes := types.Nodes{ // User-owned nodes (not tagged) { - User: ptr.To(users[0]), + User: new(users[0]), IPv4: ap("100.64.0.1"), }, { - User: ptr.To(users[1]), + User: new(users[1]), IPv4: ap("100.64.0.2"), }, // Tagged nodes @@ -1083,10 +1088,10 @@ func TestAutogroupTagged(t *testing.T) { policy := &Policy{ TagOwners: TagOwners{ - Tag("tag:server"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, - Tag("tag:web"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:server"): Owners{new(Username("user1@"))}, + Tag("tag:database"): Owners{new(Username("user2@"))}, + Tag("tag:web"): Owners{new(Username("user1@"))}, + Tag("tag:prod"): Owners{new(Username("user1@"))}, }, ACLs: []ACL{ // Rule: autogroup:tagged can reach user-owned nodes @@ -1206,10 +1211,10 @@ func TestAutogroupSelfWithSpecificUserSource(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, + {User: new(users[0]), IPv4: 
ap("100.64.0.1")}, + {User: new(users[0]), IPv4: ap("100.64.0.2")}, + {User: new(users[1]), IPv4: ap("100.64.0.3")}, + {User: new(users[1]), IPv4: ap("100.64.0.4")}, } policy := &Policy{ @@ -1273,11 +1278,11 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, - {User: ptr.To(users[2]), IPv4: ap("100.64.0.5")}, + {User: new(users[0]), IPv4: ap("100.64.0.1")}, + {User: new(users[0]), IPv4: ap("100.64.0.2")}, + {User: new(users[1]), IPv4: ap("100.64.0.3")}, + {User: new(users[1]), IPv4: ap("100.64.0.4")}, + {User: new(users[2]), IPv4: ap("100.64.0.5")}, } policy := &Policy{ @@ -1326,14 +1331,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) { assert.Empty(t, rules3, "user3 should have no rules") } -// Helper function to create IP addresses for testing +// Helper function to create IP addresses for testing. func createAddr(ip string) *netip.Addr { addr, _ := netip.ParseAddr(ip) return &addr } // TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly -// with autogroup:self in destinations +// with autogroup:self in destinations. func TestSSHWithAutogroupSelfInDestination(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1342,13 +1347,13 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { nodes := types.Nodes{ // User1's nodes - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"}, + {User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"}, + {User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"}, // User2's nodes - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"}, + {User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"}, + {User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"}, // Tagged node for user1 (should be excluded) - {User: ptr.To(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}}, + {User: new(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}}, } policy := &Policy{ @@ -1381,6 +1386,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // Test for user2's first node @@ -1399,12 +1405,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { for i, p := range rule2.Principals { principalIPs2[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2) // Test for tagged node (should have no SSH rules) node5 := nodes[4].View() sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy3 != nil { assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self") } @@ -1412,7 +1420,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user // is in the source and autogroup:self in destination, only that user's devices -// can SSH (and only if they match the target user) +// can SSH (and only 
if they match the target user). func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1420,10 +1428,10 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, + {User: new(users[0]), IPv4: ap("100.64.0.1")}, + {User: new(users[0]), IPv4: ap("100.64.0.2")}, + {User: new(users[1]), IPv4: ap("100.64.0.3")}, + {User: new(users[1]), IPv4: ap("100.64.0.4")}, } policy := &Policy{ @@ -1454,18 +1462,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // For user2's node: should have no rules (user1's devices can't match user2's self) node3 := nodes[2].View() sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1") } } -// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations +// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations. func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1474,11 +1484,11 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, - {User: ptr.To(users[2]), IPv4: ap("100.64.0.5")}, + {User: new(users[0]), IPv4: ap("100.64.0.1")}, + {User: new(users[0]), IPv4: ap("100.64.0.2")}, + {User: new(users[1]), IPv4: ap("100.64.0.3")}, + {User: new(users[1]), IPv4: ap("100.64.0.4")}, + {User: new(users[2]), IPv4: ap("100.64.0.5")}, } policy := &Policy{ @@ -1512,29 +1522,31 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // For user3's node: should have no rules (not in group:admins) node5 := nodes[4].View() sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)") } } // TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices -// are excluded from both sources and destinations when autogroup:self is used +// are excluded from both sources and destinations when autogroup:self is used. 
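The autogroup:self SSH tests above all assert the same shape of result: only untagged devices that belong to the same user as the target node end up as principals. A minimal sketch of that selection, using simplified stand-in types (node, selfPeers) rather than the project's real implementation.

package main

import "fmt"

// node is a simplified stand-in for the repository's node type, reduced to
// the fields these tests exercise.
type node struct {
	Hostname string
	IP       string
	UserID   uint
	Tags     []string
}

// selfPeers mirrors the behaviour the tests assert for autogroup:self: from
// the perspective of target, only untagged devices owned by the same user are
// kept.
func selfPeers(target node, all []node) []node {
	var out []node
	for _, n := range all {
		if n.UserID == target.UserID && len(n.Tags) == 0 {
			out = append(out, n)
		}
	}
	return out
}

func main() {
	nodes := []node{
		{Hostname: "untagged1", IP: "100.64.0.1", UserID: 1},
		{Hostname: "untagged2", IP: "100.64.0.2", UserID: 1},
		{Hostname: "tagged1", IP: "100.64.0.3", UserID: 1, Tags: []string{"tag:server"}},
		{Hostname: "other-user", IP: "100.64.0.4", UserID: 2},
	}

	for _, n := range selfPeers(nodes[0], nodes) {
		fmt.Println(n.Hostname) // untagged1, untagged2
	}
}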
func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}}, + {User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"}, + {User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"}, + {User: new(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}}, + {User: new(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}}, } policy := &Policy{ @@ -1569,6 +1581,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs, "should only include untagged devices") @@ -1576,6 +1589,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { node3 := nodes[2].View() sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self") } @@ -1591,10 +1605,10 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { } nodes := types.Nodes{ - {User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"}, - {User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"}, - {User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}}, + {User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"}, + {User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"}, + {User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"}, + {User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}}, } policy := &Policy{ @@ -1624,10 +1638,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Verify autogroup:self rule has filtered sources (only same-user devices) selfRule := sshPolicy1.Rules[0] require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices") + selfPrincipals := make([]string, len(selfRule.Principals)) for i, p := range selfRule.Principals { selfPrincipals[i] = p.NodeIP } + require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals, "autogroup:self rule should only include same-user untagged devices") @@ -1639,10 +1655,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") routerRule := sshPolicyRouter.Rules[0] + routerPrincipals := make([]string, len(routerRule.Principals)) for i, p := range routerRule.Principals { routerPrincipals[i] = p.NodeIP } + require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)") require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)") require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's 
device (unfiltered sources)") diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 54196e6b..bc968c3c 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -21,6 +21,9 @@ import ( "tailscale.com/util/deephash" ) +// PolicyVersion is the version number of this policy implementation. +const PolicyVersion = 2 + // ErrInvalidTagOwner is returned when a tag owner is not an Alias type. var ErrInvalidTagOwner = errors.New("tag owner is not an Alias") @@ -111,6 +114,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Filter: filter, Policy: pm.pol, }) + filterChanged := filterHash != pm.filterHash if filterChanged { log.Debug(). @@ -120,7 +124,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("filter.rules.new", len(filter)). Msg("Policy filter hash changed") } + pm.filter = filter + pm.filterHash = filterHash if filterChanged { pm.matchers = matcher.MatchesFromFilterRules(pm.filter) @@ -135,6 +141,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { } tagOwnerMapHash := deephash.Hash(&tagMap) + tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash if tagOwnerChanged { log.Debug(). @@ -144,6 +151,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("tagOwners.new", len(tagMap)). Msg("Tag owner hash changed") } + pm.tagOwnerMap = tagMap pm.tagOwnerMapHash = tagOwnerMapHash @@ -153,6 +161,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { } autoApproveMapHash := deephash.Hash(&autoMap) + autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash if autoApproveChanged { log.Debug(). @@ -162,10 +171,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("autoApprovers.new", len(autoMap)). Msg("Auto-approvers hash changed") } + pm.autoApproveMap = autoMap pm.autoApproveMapHash = autoApproveMapHash exitSetHash := deephash.Hash(&exitSet) + exitSetChanged := exitSetHash != pm.exitSetHash if exitSetChanged { log.Debug(). @@ -173,6 +184,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Str("exitSet.hash.new", exitSetHash.String()[:8]). Msg("Exit node set hash changed") } + pm.exitSet = exitSet pm.exitSetHash = exitSetHash @@ -199,6 +211,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { if !needsUpdate { log.Trace(). Msg("Policy evaluation detected no changes - all hashes match") + return false, nil } @@ -224,6 +237,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err if err != nil { return nil, fmt.Errorf("compiling SSH policy: %w", err) } + pm.sshPolicyMap[node.ID()] = sshPol return sshPol, nil @@ -318,6 +332,7 @@ func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[typ if err != nil || len(filter) == 0 { continue } + nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter) } @@ -398,6 +413,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil reducedFilter := policyutil.ReduceFilterRules(node, pm.filter) pm.filterRulesMap[node.ID()] = reducedFilter + return reducedFilter, nil } @@ -442,7 +458,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul // This is different from FilterForNode which returns REDUCED rules for packet filtering. // // For global policies: returns the global matchers (same for all nodes) -// For autogroup:self: returns node-specific matchers from unreduced compiled rules +// For autogroup:self: returns node-specific matchers from unreduced compiled rules. 
func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) { if pm == nil { return nil, nil @@ -474,6 +490,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() + pm.users = users // Clear SSH policy map when users change to force SSH policy recomputation @@ -685,6 +702,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr if pm.exitSet == nil { return false } + if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) { return true } @@ -724,7 +742,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr } func (pm *PolicyManager) Version() int { - return 2 + return PolicyVersion } func (pm *PolicyManager) DebugString() string { @@ -748,8 +766,10 @@ func (pm *PolicyManager) DebugString() string { } fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap)) + for prefix, approveAddrs := range pm.autoApproveMap { fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range approveAddrs.Ranges() { fmt.Fprintf(&sb, "\t\t%s\n", iprange) } @@ -758,14 +778,17 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap)) + for prefix, tagOwners := range pm.tagOwnerMap { fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range tagOwners.Ranges() { fmt.Fprintf(&sb, "\t\t%s\n", iprange) } } sb.WriteString("\n\n") + if pm.filter != nil { filter, err := json.MarshalIndent(pm.filter, "", " ") if err == nil { @@ -778,6 +801,7 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") sb.WriteString("Matchers:\n") sb.WriteString("an internal structure used to filter nodes and routes\n") + for _, match := range pm.matchers { sb.WriteString(match.DebugString()) sb.WriteString("\n") @@ -785,6 +809,7 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") sb.WriteString("Nodes:\n") + for _, node := range pm.nodes.All() { sb.WriteString(node.String()) sb.WriteString("\n") @@ -841,6 +866,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S // Check if IPs changed (simple check - could be more sophisticated) oldIPs := oldNode.IPs() + newIPs := newNode.IPs() if len(oldIPs) != len(newIPs) { affectedUsers[newNode.User().ID()] = struct{}{} @@ -862,6 +888,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S for nodeID := range pm.filterRulesMap { // Find the user for this cached node var nodeUserID uint + found := false // Check in new nodes first @@ -869,6 +896,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S if node.ID() == nodeID { nodeUserID = node.User().ID() found = true + break } } @@ -879,6 +907,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S if node.ID() == nodeID { nodeUserID = node.User().ID() found = true + break } } @@ -956,14 +985,7 @@ func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types. // It will return a Owners list where all the Tag types have been resolved to their underlying Owners. 
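invalidateAutogroupSelfCache above works on the same idea as the sketch below: cached per-node results are keyed by node ID, each node maps back to a user, and every cache entry owned by an affected user is dropped so it gets recomputed on the next lookup. The types and names here (cache, nodeID) are illustrative, not the project's.

package main

import "fmt"

type nodeID uint64

// cache memoises per-node results keyed by node ID and remembers which user
// owns each node so invalidation can be done per user.
type cache struct {
	rules map[nodeID][]string // []string stands in for the real rule type
	owner map[nodeID]uint     // node ID -> user ID
}

// invalidateUsers drops every cached entry belonging to an affected user.
func (c *cache) invalidateUsers(affected map[uint]struct{}) {
	for id, userID := range c.owner {
		if _, ok := affected[userID]; ok {
			delete(c.rules, id)
		}
	}
}

func main() {
	c := &cache{
		rules: map[nodeID][]string{1: {"rule-a"}, 2: {"rule-b"}},
		owner: map[nodeID]uint{1: 10, 2: 20},
	}

	c.invalidateUsers(map[uint]struct{}{10: {}})
	fmt.Println(len(c.rules)) // 1: only user 20's node stays cached
}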
func flattenTags(tagOwners TagOwners, tag Tag, visiting map[Tag]bool, chain []Tag) (Owners, error) { if visiting[tag] { - cycleStart := 0 - - for i, t := range chain { - if t == tag { - cycleStart = i - break - } - } + cycleStart := slices.Index(chain, tag) cycleTags := make([]string, len(chain[cycleStart:])) for i, t := range chain[cycleStart:] { diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 26b0d141..f35cff0b 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -11,18 +11,16 @@ import ( "github.com/stretchr/testify/require" "gorm.io/gorm" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) -func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { +func node(name, ipv4, ipv6 string, user types.User, _ *tailcfg.Hostinfo) *types.Node { return &types.Node{ ID: 0, Hostname: name, IPv4: ap(ipv4), IPv6: ap(ipv6), - User: ptr.To(user), - UserID: ptr.To(user.ID), - Hostinfo: hostinfo, + User: new(user), + UserID: new(user.ID), } } @@ -57,6 +55,7 @@ func TestPolicyManager(t *testing.T) { if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff( tt.wantMatchers, matchers, @@ -77,6 +76,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { {Model: gorm.Model{ID: 3}, Name: "user3", Email: "user3@headscale.net"}, } + //nolint:goconst policy := `{ "acls": [ { @@ -95,7 +95,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { } for i, n := range initialNodes { - n.ID = types.NodeID(i + 1) + n.ID = types.NodeID(i + 1) //nolint:gosec } pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice()) @@ -107,7 +107,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { require.NoError(t, err) } - require.Equal(t, len(initialNodes), len(pm.filterRulesMap)) + require.Len(t, pm.filterRulesMap, len(initialNodes)) tests := []struct { name string @@ -177,15 +177,18 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { t.Run(tt.name, func(t *testing.T) { for i, n := range tt.newNodes { found := false + for _, origNode := range initialNodes { if n.Hostname == origNode.Hostname { n.ID = origNode.ID found = true + break } } + if !found { - n.ID = types.NodeID(len(initialNodes) + i + 1) + n.ID = types.NodeID(len(initialNodes) + i + 1) //nolint:gosec } } @@ -370,7 +373,7 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) { // TestAutogroupSelfReducedVsUnreducedRules verifies that: // 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships -// 2. FilterForNode returns reduced compiled rules for packet filters +// 2. FilterForNode returns reduced compiled rules for packet filters. 
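flattenTags now locates the start of a detected cycle with slices.Index instead of a manual loop. A compact sketch of that step with plain strings; note that slices.Index returns -1 when the element is absent, which flattenTags avoids by only reaching this branch once the tag is known to be in the chain.

package main

import (
	"fmt"
	"slices"
)

// describeCycle reports the suffix of chain that forms a cycle back to tag,
// the same slice that the error message above is built from.
func describeCycle(chain []string, tag string) string {
	start := slices.Index(chain, tag)
	if start < 0 {
		return "no cycle"
	}

	return fmt.Sprintf("cycle: %v", append(slices.Clone(chain[start:]), tag))
}

func main() {
	chain := []string{"tag:a", "tag:b", "tag:c"}
	fmt.Println(describeCycle(chain, "tag:b")) // cycle: [tag:b tag:c tag:b]
}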
func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"} user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"} @@ -410,6 +413,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { // FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations // For node1, destinations should only be node1's IPs node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"} + for _, rule := range filterNode1 { for _, dst := range rule.DstPorts { require.Contains(t, node1IPs, dst.IP, @@ -419,6 +423,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { // For node2, destinations should only be node2's IPs node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"} + for _, rule := range filterNode2 { for _, dst := range rule.DstPorts { require.Contains(t, node2IPs, dst.IP, @@ -457,8 +462,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) { Hostname: "test-1-device", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -468,8 +473,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) { Hostname: "test-2-router", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), Tags: []string{"tag:node-router"}, Hostinfo: &tailcfg.Hostinfo{}, } @@ -537,8 +542,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) { Hostname: "test-1-device", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -547,8 +552,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) { Hostname: "test-2-device", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -647,8 +652,8 @@ func TestTagPropagationToPeerMap(t *testing.T) { Hostname: "user1-node", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Tags: []string{"tag:web", "tag:internal"}, } @@ -658,8 +663,8 @@ func TestTagPropagationToPeerMap(t *testing.T) { Hostname: "user2-node", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), } initialNodes := types.Nodes{user1Node, user2Node} @@ -686,8 +691,8 @@ func TestTagPropagationToPeerMap(t *testing.T) { Hostname: "user1-node", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Tags: []string{"tag:internal"}, // tag:web removed! 
} @@ -749,8 +754,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) { Hostname: "admin-device", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -760,8 +765,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) { Hostname: "user1-server", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), Tags: []string{"tag:server"}, Hostinfo: &tailcfg.Hostinfo{}, } @@ -832,8 +837,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) { Hostname: "device-a", IPv4: ap("100.64.0.1"), IPv6: ap("fd7a:115c:a1e0::1"), - User: ptr.To(users[0]), - UserID: ptr.To(users[0].ID), + User: new(users[0]), + UserID: new(users[0].ID), Hostinfo: &tailcfg.Hostinfo{}, } @@ -843,8 +848,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) { Hostname: "device-b", IPv4: ap("100.64.0.2"), IPv6: ap("fd7a:115c:a1e0::2"), - User: ptr.To(users[1]), - UserID: ptr.To(users[1].ID), + User: new(users[1]), + UserID: new(users[1].ID), Tags: []string{"tag:web"}, Hostinfo: &tailcfg.Hostinfo{}, } diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 75b16bc1..48baad2d 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -16,13 +16,12 @@ import ( "go4.org/netipx" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" "tailscale.com/types/views" "tailscale.com/util/multierr" "tailscale.com/util/slicesx" ) -// Global JSON options for consistent parsing across all struct unmarshaling +// Global JSON options for consistent parsing across all struct unmarshaling. var policyJSONOpts = []json.Options{ json.DefaultOptionsV2(), json.MatchCaseInsensitiveNames(true), @@ -37,6 +36,73 @@ var ErrCircularReference = errors.New("circular reference detected") var ErrUndefinedTagReference = errors.New("references undefined tag") +// Sentinel errors for type/alias validation. +var ( + ErrUnknownAliasType = errors.New("unknown alias type") + ErrUnknownOwnerType = errors.New("unknown owner type") + ErrUnknownAutoApproverType = errors.New("unknown auto approver type") + ErrInvalidAlias = errors.New("invalid alias") + ErrInvalidAutoApprover = errors.New("invalid auto approver") + ErrInvalidOwner = errors.New("invalid owner") +) + +// Sentinel errors for format validation. +var ( + ErrUsernameMissingAt = errors.New("username must contain @") + ErrGroupMissingPrefix = errors.New("group must start with 'group:'") + ErrTagMissingPrefix = errors.New("tag must start with 'tag:'") + ErrInvalidHostname = errors.New("invalid hostname") + ErrInvalidPrefix = errors.New("invalid prefix") + ErrInvalidAutoGroup = errors.New("invalid autogroup") + ErrInvalidAction = errors.New("invalid action") + ErrInvalidSSHAction = errors.New("invalid SSH action") + ErrInvalidProtocol = errors.New("invalid protocol") + ErrProtocolOutOfRange = errors.New("protocol number out of range") + ErrLeadingZeroProtocol = errors.New("leading zero not permitted in protocol number") + ErrHostportMissingColon = errors.New("hostport must contain a colon") + ErrUnsupportedType = errors.New("unsupported type") +) + +// Sentinel errors for resolution/lookup failures. 
+var ( + ErrUserNotFound = errors.New("user not found") + ErrMultipleUsersFound = errors.New("multiple users found") + ErrHostNotResolved = errors.New("unable to resolve host") + ErrGroupNotDefined = errors.New("group not defined in policy") + ErrTagNotDefined = errors.New("tag not defined in policy") + ErrHostNotDefined = errors.New("host not defined in policy") + ErrInvalidIPAddress = errors.New("invalid IP address") + ErrNestedGroups = errors.New("nested groups not allowed") + ErrInvalidGroupMember = errors.New("invalid group member type") + ErrGroupValueNotArray = errors.New("group value must be an array") + ErrAutoApproverNotAlias = errors.New("auto approver is not an alias") +) + +// Sentinel errors for autogroup context validation. +var ( + ErrAutogroupInternetInSource = errors.New("autogroup:internet can only be used in ACL destinations") + ErrAutogroupSelfInSource = errors.New("autogroup:self can only be used in ACL destinations") + ErrAutogroupNotSupportedSource = errors.New("autogroup not supported for source") + ErrAutogroupNotSupportedDest = errors.New("autogroup not supported for destination") + ErrAutogroupNotSupportedSSH = errors.New("autogroup not supported for SSH") + ErrAutogroupNotSupported = errors.New("autogroup not supported in headscale") + ErrAliasNotSupportedSSH = errors.New("alias type not supported for SSH") +) + +// Sentinel errors for SSH aliases. +var ( + ErrAliasNotSupportedSSHSrc = errors.New("alias type not supported for SSH source") + ErrAliasNotSupportedSSHDst = errors.New("alias type not supported for SSH destination") + ErrUnknownSSHSrcAliasType = errors.New("unknown SSH source alias type") + ErrUnknownSSHDstAliasType = errors.New("unknown SSH destination alias type") +) + +// Sentinel errors for policy parsing. +var ( + ErrUnknownField = errors.New("unknown field in policy") + ErrProtocolNoSpecificPorts = errors.New("protocol does not support specific ports") +) + type Asterix int func (a Asterix) Validate() error { @@ -59,6 +125,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { } var alias string + switch v := a.Alias.(type) { case *Username: alias = string(*v) @@ -75,7 +142,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { case Asterix: alias = "*" default: - return nil, fmt.Errorf("unknown alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v) } // If no ports are specified @@ -90,6 +157,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { // Otherwise, format as "alias:ports" var ports []string + for _, port := range a.Ports { if port.First == port.Last { ports = append(ports, strconv.FormatUint(uint64(port.First), 10)) @@ -118,13 +186,16 @@ func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeV } // Username is a string that represents a username, it must contain an @. 
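The var blocks above replace ad-hoc fmt.Errorf messages with package-level sentinel errors. The payoff is that call sites wrap them with %w, so callers and tests can match the error category with errors.Is instead of comparing strings. A minimal sketch using the ErrUsernameMissingAt sentinel from the hunk; the validateUsername helper is illustrative only.

package main

import (
	"errors"
	"fmt"
	"strings"
)

// ErrUsernameMissingAt mirrors the sentinel added in this change.
var ErrUsernameMissingAt = errors.New("username must contain @")

func validateUsername(u string) error {
	if strings.Contains(u, "@") {
		return nil
	}

	// Wrapping with %w keeps the sentinel matchable by callers.
	return fmt.Errorf("%w: got %q", ErrUsernameMissingAt, u)
}

func main() {
	err := validateUsername("user1")

	// Callers can now branch on the error category instead of parsing text.
	fmt.Println(errors.Is(err, ErrUsernameMissingAt)) // true
}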
+// +//nolint:recvcheck type Username string func (u Username) Validate() error { if isUser(string(u)) { return nil } - return fmt.Errorf("Username has to contain @, got: %q", u) + + return fmt.Errorf("%w: got %q", ErrUsernameMissingAt, u) } func (u *Username) String() string { @@ -143,7 +214,9 @@ func (p Prefix) MarshalJSON() ([]byte, error) { func (u *Username) UnmarshalJSON(b []byte) error { *u = Username(strings.Trim(string(b), `"`)) - if err := u.Validate(); err != nil { + + err := u.Validate() + if err != nil { return err } @@ -184,19 +257,21 @@ func (u Username) resolveUser(users types.Users) (types.User, error) { } if len(potentialUsers) == 0 { - return types.User{}, fmt.Errorf("user with token %q not found", u.String()) + return types.User{}, fmt.Errorf("%w: token %q", ErrUserNotFound, u.String()) } if len(potentialUsers) > 1 { - return types.User{}, fmt.Errorf("multiple users with token %q found: %s", u.String(), potentialUsers.String()) + return types.User{}, fmt.Errorf("%w: token %q found %s", ErrMultipleUsersFound, u.String(), potentialUsers.String()) } return potentialUsers[0], nil } func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) user, err := u.resolveUser(users) if err != nil { @@ -223,18 +298,23 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types. } // Group is a special string which is always prefixed with `group:`. +// +//nolint:recvcheck type Group string func (g Group) Validate() error { if isGroup(string(g)) { return nil } - return fmt.Errorf(`Group has to start with "group:", got: %q`, g) + + return fmt.Errorf("%w: got %q", ErrGroupMissingPrefix, g) } func (g *Group) UnmarshalJSON(b []byte) error { *g = Group(strings.Trim(string(b), `"`)) - if err := g.Validate(); err != nil { + + err := g.Validate() + if err != nil { return err } @@ -269,8 +349,10 @@ func (g Group) MarshalJSON() ([]byte, error) { } func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) for _, user := range p.Groups[g] { uips, err := user.Resolve(nil, users, nodes) @@ -285,18 +367,23 @@ func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.Nod } // Tag is a special string which is always prefixed with `tag:`. +// +//nolint:recvcheck type Tag string func (t Tag) Validate() error { if isTag(string(t)) { return nil } - return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) + + return fmt.Errorf("%w: got %q", ErrTagMissingPrefix, t) } func (t *Tag) UnmarshalJSON(b []byte) error { *t = Tag(strings.Trim(string(b), `"`)) - if err := t.Validate(); err != nil { + + err := t.Validate() + if err != nil { return err } @@ -334,18 +421,23 @@ func (t Tag) MarshalJSON() ([]byte, error) { } // Host is a string that represents a hostname. 
+// +//nolint:recvcheck type Host string func (h Host) Validate() error { if isHost(string(h)) { return nil } - return fmt.Errorf("Hostname %q is invalid", h) + + return fmt.Errorf("%w: %q", ErrInvalidHostname, h) } func (h *Host) UnmarshalJSON(b []byte) error { *h = Host(strings.Trim(string(b), `"`)) - if err := h.Validate(); err != nil { + + err := h.Validate() + if err != nil { return err } @@ -353,13 +445,16 @@ func (h *Host) UnmarshalJSON(b []byte) error { } func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) pref, ok := p.Hosts[h] if !ok { - return nil, fmt.Errorf("unable to resolve host: %q", h) + return nil, fmt.Errorf("%w: %q", ErrHostNotResolved, h) } + err := pref.Validate() if err != nil { errs = append(errs, err) @@ -377,6 +472,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView if err != nil { errs = append(errs, err) } + for _, node := range nodes.All() { if node.InIPSet(ipsTemp) { node.AppendToIPSet(&ips) @@ -386,13 +482,15 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView return buildIPSetMultiErr(&ips, errs) } +//nolint:recvcheck type Prefix netip.Prefix func (p Prefix) Validate() error { if netip.Prefix(p).IsValid() { return nil } - return fmt.Errorf("Prefix %q is invalid", p) + + return fmt.Errorf("%w: %q", ErrInvalidPrefix, p) } func (p Prefix) String() string { @@ -405,6 +503,7 @@ func (p *Prefix) parseString(addr string) error { if err != nil { return err } + addrPref, err := addr.Prefix(addr.BitLen()) if err != nil { return err @@ -419,6 +518,7 @@ func (p *Prefix) parseString(addr string) error { if err != nil { return err } + *p = Prefix(pref) return nil @@ -429,6 +529,8 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { if err != nil { return err } + + //nolint:noinlineerr if err := p.Validate(); err != nil { return err } @@ -442,8 +544,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { // // See [Policy], [types.Users], and [types.Nodes] for more details. func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) ips.AddPrefix(netip.Prefix(p)) // If the IP is a single host, look for a node to ensure we add all the IPs of @@ -468,6 +572,8 @@ func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuild } // AutoGroup is a special string which is always prefixed with `autogroup:`. 
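Username, Group, Tag and Host above all decode the same way: trim the JSON quotes, store the raw string, then run Validate so malformed aliases fail while the policy is being parsed. A self-contained sketch of that shape with a stand-in tag type; it uses the standard encoding/json package for brevity, whereas the project wires its own JSON options.

package main

import (
	"encoding/json"
	"errors"
	"fmt"
	"strings"
)

var errTagMissingPrefix = errors.New("tag must start with 'tag:'")

// tag is an illustrative stand-in for the policy's string alias types.
type tag string

func (t tag) Validate() error {
	if strings.HasPrefix(string(t), "tag:") {
		return nil
	}

	return fmt.Errorf("%w: got %q", errTagMissingPrefix, t)
}

// UnmarshalJSON strips the surrounding quotes and validates immediately, so
// bad input is rejected at decode time.
func (t *tag) UnmarshalJSON(b []byte) error {
	*t = tag(strings.Trim(string(b), `"`))

	return t.Validate()
}

func main() {
	var tags []tag
	err := json.Unmarshal([]byte(`["tag:server", "server"]`), &tags)
	fmt.Println(err) // tag must start with 'tag:': got "server"
}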
+// +//nolint:recvcheck type AutoGroup string const ( @@ -491,12 +597,14 @@ func (ag AutoGroup) Validate() error { return nil } - return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups) + return fmt.Errorf("%w: got %q, must be one of %v", ErrInvalidAutoGroup, ag, autogroups) } func (ag *AutoGroup) UnmarshalJSON(b []byte) error { *ag = AutoGroup(strings.Trim(string(b), `"`)) - if err := ag.Validate(); err != nil { + + err := ag.Validate() + if err != nil { return err } @@ -515,6 +623,7 @@ func (ag AutoGroup) MarshalJSON() ([]byte, error) { func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { var build netipx.IPSetBuilder + //nolint:exhaustive switch ag { case AutoGroupInternet: return util.TheInternet(), nil @@ -551,7 +660,7 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[type return nil, ErrAutogroupSelfRequiresPerNodeResolution default: - return nil, fmt.Errorf("unknown autogroup %q", ag) + return nil, fmt.Errorf("%w: %q", ErrInvalidAutoGroup, ag) } } @@ -565,31 +674,36 @@ func (ag *AutoGroup) Is(c AutoGroup) bool { type Alias interface { Validate() error - UnmarshalJSON([]byte) error + UnmarshalJSON(data []byte) error // Resolve resolves the Alias to an IPSet. The IPSet will contain all the IP // addresses that the Alias represents within Headscale. It is the product // of the Alias and the Policy, Users and Nodes. // This is an interface definition and the implementation is independent of // the Alias type. - Resolve(*Policy, types.Users, views.Slice[types.NodeView]) (*netipx.IPSet, error) + Resolve(pol *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) } type AliasWithPorts struct { Alias + Ports []tailcfg.PortRange } func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { var v any - if err := json.Unmarshal(b, &v); err != nil { + + err := json.Unmarshal(b, &v) + if err != nil { return err } switch vs := v.(type) { case string: - var portsPart string - var err error + var ( + portsPart string + err error + ) if strings.Contains(vs, ":") { vs, portsPart, err = splitDestinationAndPort(vs) @@ -601,21 +715,24 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Ports = ports } else { - return errors.New(`hostport must contain a colon (":")`) + return ErrHostportMissingColon } ve.Alias, err = parseAlias(vs) if err != nil { return err } + + //nolint:noinlineerr if err := ve.Validate(); err != nil { return err } default: - return fmt.Errorf("type %T not supported", vs) + return fmt.Errorf("%w: %T", ErrUnsupportedType, vs) } return nil @@ -647,6 +764,7 @@ func isHost(str string) bool { func parseAlias(vs string) (Alias, error) { var pref Prefix + err := pref.parseString(vs) if err == nil { return &pref, nil @@ -656,28 +774,20 @@ func parseAlias(vs string) (Alias, error) { case isWildcard(vs): return Wildcard, nil case isUser(vs): - return ptr.To(Username(vs)), nil + return new(Username(vs)), nil case isGroup(vs): - return ptr.To(Group(vs)), nil + return new(Group(vs)), nil case isTag(vs): - return ptr.To(Tag(vs)), nil + return new(Tag(vs)), nil case isAutoGroup(vs): - return ptr.To(AutoGroup(vs)), nil + return new(AutoGroup(vs)), nil } if isHost(vs) { - return ptr.To(Host(vs)), nil + return new(Host(vs)), nil } - return nil, fmt.Errorf(`Invalid alias %q. 
An alias must be one of the following types: -- wildcard (*) -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") -- autogroup (starting with "autogroup:") -- host - -Please check the format and try again.`, vs) + return nil, fmt.Errorf("%w: %q", ErrInvalidAlias, vs) } // AliasEnc is used to deserialize a Alias. @@ -691,15 +801,18 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Alias = ptr return nil } +//nolint:recvcheck type Aliases []Alias func (a *Aliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -737,7 +850,7 @@ func (a Aliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAliasType, v) } } @@ -745,8 +858,10 @@ func (a Aliases) MarshalJSON() ([]byte, error) { } func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) for _, alias := range a { aips, err := alias.Resolve(p, users, nodes) @@ -771,6 +886,7 @@ func unmarshalPointer[T any]( parseFunc func(string) (T, error), ) (T, error) { var s string + err := json.Unmarshal(b, &s) if err != nil { var t T @@ -782,7 +898,7 @@ func unmarshalPointer[T any]( type AutoApprover interface { CanBeAutoApprover() bool - UnmarshalJSON([]byte) error + UnmarshalJSON(data []byte) error String() string } @@ -790,6 +906,7 @@ type AutoApprovers []AutoApprover func (aa *AutoApprovers) UnmarshalJSON(b []byte) error { var autoApprovers []AutoApproverEnc + err := json.Unmarshal(b, &autoApprovers, policyJSONOpts...) if err != nil { return err @@ -819,7 +936,7 @@ func (aa AutoApprovers) MarshalJSON() ([]byte, error) { case *Group: approvers[i] = string(*v) default: - return nil, fmt.Errorf("unknown auto approver type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownAutoApproverType, v) } } @@ -829,19 +946,14 @@ func (aa AutoApprovers) MarshalJSON() ([]byte, error) { func parseAutoApprover(s string) (AutoApprover, error) { switch { case isUser(s): - return ptr.To(Username(s)), nil + return new(Username(s)), nil case isGroup(s): - return ptr.To(Group(s)), nil + return new(Group(s)), nil case isTag(s): - return ptr.To(Tag(s)), nil + return new(Tag(s)), nil } - return nil, fmt.Errorf(`Invalid AutoApprover %q. An alias must be one of the following types: -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") - -Please check the format and try again.`, s) + return nil, fmt.Errorf("%w: %q", ErrInvalidAutoApprover, s) } // AutoApproverEnc is used to deserialize a AutoApprover. @@ -855,6 +967,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.AutoApprover = ptr return nil @@ -862,7 +975,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { type Owner interface { CanBeTagOwner() bool - UnmarshalJSON([]byte) error + UnmarshalJSON(data []byte) error String() string } @@ -877,6 +990,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Owner = ptr return nil @@ -886,6 +1000,7 @@ type Owners []Owner func (o *Owners) UnmarshalJSON(b []byte) error { var owners []OwnerEnc + err := json.Unmarshal(b, &owners, policyJSONOpts...) 
if err != nil { return err @@ -915,7 +1030,7 @@ func (o Owners) MarshalJSON() ([]byte, error) { case *Tag: owners[i] = string(*v) default: - return nil, fmt.Errorf("unknown owner type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v) } } @@ -925,24 +1040,21 @@ func (o Owners) MarshalJSON() ([]byte, error) { func parseOwner(s string) (Owner, error) { switch { case isUser(s): - return ptr.To(Username(s)), nil + return new(Username(s)), nil case isGroup(s): - return ptr.To(Group(s)), nil + return new(Group(s)), nil case isTag(s): - return ptr.To(Tag(s)), nil + return new(Tag(s)), nil } - return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: -- user (containing an "@") -- group (starting with "group:") -- tag (starting with "tag:") - -Please check the format and try again.`, s) + return nil, fmt.Errorf("%w: %q", ErrInvalidOwner, s) } type Usernames []Username // Groups are a map of Group to a list of Username. +// +//nolint:recvcheck type Groups map[Group]Usernames func (g Groups) Contains(group *Group) error { @@ -956,7 +1068,7 @@ func (g Groups) Contains(group *Group) error { } } - return fmt.Errorf(`Group %q is not defined in the Policy, please define or remove the reference to it`, group) + return fmt.Errorf("%w: %q", ErrGroupNotDefined, group) } // UnmarshalJSON overrides the default JSON unmarshalling for Groups to ensure @@ -966,41 +1078,49 @@ func (g Groups) Contains(group *Group) error { func (g *Groups) UnmarshalJSON(b []byte) error { // First unmarshal as a generic map to validate group names first var rawMap map[string]any - if err := json.Unmarshal(b, &rawMap); err != nil { + + err := json.Unmarshal(b, &rawMap) + if err != nil { return err } // Validate group names first before checking data types for key := range rawMap { group := Group(key) - if err := group.Validate(); err != nil { + + err := group.Validate() + if err != nil { return err } } // Then validate each field can be converted to []string rawGroups := make(map[string][]string) + for key, value := range rawMap { switch v := value.(type) { case []any: // Convert []interface{} to []string var stringSlice []string + for _, item := range v { if str, ok := item.(string); ok { stringSlice = append(stringSlice, str) } else { - return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item) + return fmt.Errorf("%w: group %q got %T", ErrInvalidGroupMember, key, item) } } + rawGroups[key] = stringSlice case string: - return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v) + return fmt.Errorf("%w: group %q got string %q", ErrGroupValueNotArray, key, v) default: - return fmt.Errorf(`Group "%s" value must be an array of users, got %T`, key, v) + return fmt.Errorf("%w: group %q got %T", ErrGroupValueNotArray, key, v) } } *g = make(Groups) + for key, value := range rawGroups { group := Group(key) // Group name already validated above @@ -1008,13 +1128,16 @@ func (g *Groups) UnmarshalJSON(b []byte) error { for _, u := range value { username := Username(u) - if err := username.Validate(); err != nil { + + err := username.Validate() + if err != nil { if isGroup(u) { - return fmt.Errorf("Nested groups are not allowed, found %q inside %q", u, group) + return fmt.Errorf("%w: found %q inside %q", ErrNestedGroups, u, group) } return err } + usernames = append(usernames, username) } @@ -1025,24 +1148,33 @@ func (g *Groups) UnmarshalJSON(b []byte) error { } // Hosts are alias for IP addresses or subnets. 
+// +//nolint:recvcheck type Hosts map[Host]Prefix func (h *Hosts) UnmarshalJSON(b []byte) error { var rawHosts map[string]string - if err := json.Unmarshal(b, &rawHosts, policyJSONOpts...); err != nil { + + err := json.Unmarshal(b, &rawHosts, policyJSONOpts...) + if err != nil { return err } *h = make(Hosts) + for key, value := range rawHosts { host := Host(key) - if err := host.Validate(); err != nil { + + err := host.Validate() + if err != nil { return err } var prefix Prefix - if err := prefix.parseString(value); err != nil { - return fmt.Errorf(`Hostname "%s" contains an invalid IP address: "%s"`, key, value) + + err = prefix.parseString(value) + if err != nil { + return fmt.Errorf("%w: hostname %q value %q", ErrInvalidIPAddress, key, value) } (*h)[host] = prefix @@ -1077,6 +1209,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { } rawTagOwners := make(map[string][]string) + for tag, owners := range to { tagStr := string(tag) ownerStrs := make([]string, len(owners)) @@ -1090,7 +1223,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { case *Tag: ownerStrs[i] = string(*v) default: - return nil, fmt.Errorf("unknown owner type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownOwnerType, v) } } @@ -1114,7 +1247,7 @@ func (to TagOwners) Contains(tagOwner *Tag) error { } } - return fmt.Errorf(`Tag %q is not defined in the Policy, please define or remove the reference to it`, tagOwner) + return fmt.Errorf("%w: %q", ErrTagNotDefined, tagOwner) } type AutoApproverPolicy struct { @@ -1153,6 +1286,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. if p == nil { return nil, nil, nil } + var err error routes := make(map[netip.Prefix]*netipx.IPSetBuilder) @@ -1161,11 +1295,12 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. if _, ok := routes[prefix]; !ok { routes[prefix] = new(netipx.IPSetBuilder) } + for _, autoApprover := range autoApprovers { aa, ok := autoApprover.(Alias) if !ok { // Should never happen - return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. ips, _ := aa.Resolve(p, users, nodes) @@ -1174,12 +1309,13 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } var exitNodeSetBuilder netipx.IPSetBuilder + if len(p.AutoApprovers.ExitNode) > 0 { for _, autoApprover := range p.AutoApprovers.ExitNode { aa, ok := autoApprover.(Alias) if !ok { // Should never happen - return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + return nil, nil, fmt.Errorf("%w: %v", ErrAutoApproverNotAlias, autoApprover) } // If it does not resolve, that means the autoApprover is not associated with any IP addresses. ips, _ := aa.Resolve(p, users, nodes) @@ -1188,11 +1324,13 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } ret := make(map[netip.Prefix]*netipx.IPSet) + for prefix, builder := range routes { ipSet, err := builder.IPSet() if err != nil { return nil, nil, err } + ret[prefix] = ipSet } @@ -1208,6 +1346,8 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } // Action represents the action to take for an ACL rule. +// +//nolint:recvcheck type Action string const ( @@ -1215,6 +1355,8 @@ const ( ) // SSHAction represents the action to take for an SSH rule. 
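resolveAutoApprovers above accumulates approver IPs into one netipx.IPSetBuilder per approved route and only freezes each builder into an immutable IPSet at the end. A trimmed-down sketch of that per-prefix builder pattern, assuming the go4.org/netipx module and the builder/IPSet calls already used in the hunk.

package main

import (
	"fmt"
	"net/netip"

	"go4.org/netipx"
)

func main() {
	// One builder per approved route prefix.
	routes := map[netip.Prefix]*netipx.IPSetBuilder{}

	route := netip.MustParsePrefix("10.0.0.0/8")
	routes[route] = new(netipx.IPSetBuilder)
	routes[route].Add(netip.MustParseAddr("100.64.0.1")) // an approver's node IP

	// Freeze each builder into an immutable IPSet once all approvers are added.
	resolved := make(map[netip.Prefix]*netipx.IPSet, len(routes))
	for prefix, builder := range routes {
		set, err := builder.IPSet()
		if err != nil {
			panic(err)
		}
		resolved[prefix] = set
	}

	fmt.Println(resolved[route].Contains(netip.MustParseAddr("100.64.0.1"))) // true
}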
+// +//nolint:recvcheck type SSHAction string const ( @@ -1234,8 +1376,9 @@ func (a *Action) UnmarshalJSON(b []byte) error { case "accept": *a = ActionAccept default: - return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept) + return fmt.Errorf("%w: %q, must be %q", ErrInvalidAction, str, ActionAccept) } + return nil } @@ -1258,8 +1401,9 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error { case "check": *a = SSHActionCheck default: - return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str) + return fmt.Errorf("%w: %q, must be one of: accept, check", ErrInvalidSSHAction, str) } + return nil } @@ -1269,6 +1413,8 @@ func (a SSHAction) MarshalJSON() ([]byte, error) { } // Protocol represents a network protocol with its IANA number and descriptions. +// +//nolint:recvcheck type Protocol string const ( @@ -1296,6 +1442,7 @@ func (p Protocol) String() string { // Description returns the human-readable description of the Protocol. func (p Protocol) Description() string { + //nolint:exhaustive switch p { case ProtocolICMP: return "Internet Control Message Protocol" @@ -1330,49 +1477,45 @@ func (p Protocol) Description() string { } } -// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement. +// parseProtocol converts a Protocol to its IANA protocol numbers. // Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. -func (p Protocol) parseProtocol() ([]int, bool) { +func (p Protocol) parseProtocol() []int { + //nolint:exhaustive switch p { case "": // Empty protocol applies to TCP and UDP traffic only - return []int{protocolTCP, protocolUDP}, false + return []int{protocolTCP, protocolUDP} case ProtocolWildcard: // Wildcard protocol - defensive handling (should not reach here due to validation) - return nil, false + return nil case ProtocolIGMP: - return []int{protocolIGMP}, true + return []int{protocolIGMP} case ProtocolIPv4, ProtocolIPInIP: - return []int{protocolIPv4}, true + return []int{protocolIPv4} case ProtocolTCP: - return []int{protocolTCP}, false + return []int{protocolTCP} case ProtocolEGP: - return []int{protocolEGP}, true + return []int{protocolEGP} case ProtocolIGP: - return []int{protocolIGP}, true + return []int{protocolIGP} case ProtocolUDP: - return []int{protocolUDP}, false + return []int{protocolUDP} case ProtocolGRE: - return []int{protocolGRE}, true + return []int{protocolGRE} case ProtocolESP: - return []int{protocolESP}, true + return []int{protocolESP} case ProtocolAH: - return []int{protocolAH}, true + return []int{protocolAH} case ProtocolSCTP: - return []int{protocolSCTP}, false + return []int{protocolSCTP} case ProtocolICMP: - return []int{protocolICMP, protocolIPv6ICMP}, true + return []int{protocolICMP, protocolIPv6ICMP} default: // Try to parse as a numeric protocol number // This should not fail since validation happened during unmarshaling protocolNumber, _ := strconv.Atoi(string(p)) - // Determine if wildcard is needed based on protocol number - needsWildcard := protocolNumber != protocolTCP && - protocolNumber != protocolUDP && - protocolNumber != protocolSCTP - - return []int{protocolNumber}, needsWildcard + return []int{protocolNumber} } } @@ -1384,7 +1527,8 @@ func (p *Protocol) UnmarshalJSON(b []byte) error { *p = Protocol(strings.ToLower(str)) // Validate the protocol - if err := p.validate(); err != nil { + err := p.validate() + if err != nil { return err } @@ -1393,6 +1537,7 @@ func (p *Protocol) UnmarshalJSON(b []byte) error { // validate 
checks if the Protocol is valid. func (p Protocol) validate() error { + //nolint:exhaustive switch p { case "", ProtocolICMP, ProtocolIGMP, ProtocolIPv4, ProtocolIPInIP, ProtocolTCP, ProtocolEGP, ProtocolIGP, ProtocolUDP, ProtocolGRE, @@ -1400,23 +1545,23 @@ func (p Protocol) validate() error { return nil case ProtocolWildcard: // Wildcard "*" is not allowed - Tailscale rejects it - return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") + return fmt.Errorf("%w: use protocol number 0-255 or protocol name", ErrInvalidProtocol) default: // Try to parse as a numeric protocol number str := string(p) // Check for leading zeros (not allowed by Tailscale) if str == "0" || (len(str) > 1 && str[0] == '0') { - return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str) + return fmt.Errorf("%w: %q", ErrLeadingZeroProtocol, str) } protocolNumber, err := strconv.Atoi(str) if err != nil { - return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p) + return fmt.Errorf("%w: %q must be a known protocol name or valid protocol number 0-255", ErrInvalidProtocol, p) } if protocolNumber < 0 || protocolNumber > 255 { - return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber) + return fmt.Errorf("%w: %d", ErrProtocolOutOfRange, protocolNumber) } return nil @@ -1428,7 +1573,7 @@ func (p Protocol) MarshalJSON() ([]byte, error) { return json.Marshal(string(p)) } -// Protocol constants matching the IANA numbers +// Protocol constants matching the IANA numbers. const ( protocolICMP = 1 // Internet Control Message protocolIGMP = 2 // Internet Group Management @@ -1459,12 +1604,14 @@ type ACL struct { func (a *ACL) UnmarshalJSON(b []byte) error { // First unmarshal into a map to filter out comment fields var raw map[string]any + //nolint:noinlineerr if err := json.Unmarshal(b, &raw, policyJSONOpts...); err != nil { return err } // Remove any fields that start with '#' filtered := make(map[string]any) + for key, value := range raw { if !strings.HasPrefix(key, "#") { filtered[key] = value @@ -1479,15 +1626,18 @@ func (a *ACL) UnmarshalJSON(b []byte) error { // Create a type alias to avoid infinite recursion type aclAlias ACL + var temp aclAlias // Unmarshal into the temporary struct using the v2 JSON options + //nolint:noinlineerr if err := json.Unmarshal(filteredBytes, &temp, policyJSONOpts...); err != nil { return err } // Copy the result back to the original struct *a = ACL(temp) + return nil } @@ -1531,7 +1681,7 @@ func validateAutogroupSupported(ag *AutoGroup) error { } if slices.Contains(autogroupNotSupported, *ag) { - return fmt.Errorf("autogroup %q is not supported in headscale", *ag) + return fmt.Errorf("%w: %q", ErrAutogroupNotSupported, *ag) } return nil @@ -1543,15 +1693,15 @@ func validateAutogroupForSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`) + return ErrAutogroupInternetInSource } if src.Is(AutoGroupSelf) { - return errors.New(`"autogroup:self" used in source, it can only be used in ACL destinations`) + return ErrAutogroupSelfInSource } if !slices.Contains(autogroupForSrc, *src) { - return fmt.Errorf("autogroup %q is not supported for ACL sources, can be %v", *src, autogroupForSrc) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedSource, *src, autogroupForSrc) } return nil @@ -1563,7 +1713,7 @@ func validateAutogroupForDst(dst 
*AutoGroup) error { } if !slices.Contains(autogroupForDst, *dst) { - return fmt.Errorf("autogroup %q is not supported for ACL destinations, can be %v", *dst, autogroupForDst) + return fmt.Errorf("%w: %q, can be %v", ErrAutogroupNotSupportedDest, *dst, autogroupForDst) } return nil @@ -1575,11 +1725,11 @@ func validateAutogroupForSSHSrc(src *AutoGroup) error { } if src.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`) + return fmt.Errorf("%w: autogroup:internet in SSH source", ErrAutogroupNotSupportedSSH) } if !slices.Contains(autogroupForSSHSrc, *src) { - return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *src, autogroupForSSHSrc) + return fmt.Errorf("%w: %q for SSH sources, can be %v", ErrAutogroupNotSupportedSSH, *src, autogroupForSSHSrc) } return nil @@ -1591,11 +1741,11 @@ func validateAutogroupForSSHDst(dst *AutoGroup) error { } if dst.Is(AutoGroupInternet) { - return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`) + return fmt.Errorf("%w: autogroup:internet in SSH destination", ErrAutogroupNotSupportedSSH) } if !slices.Contains(autogroupForSSHDst, *dst) { - return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *dst, autogroupForSSHDst) + return fmt.Errorf("%w: %q for SSH destinations, can be %v", ErrAutogroupNotSupportedSSH, *dst, autogroupForSSHDst) } return nil @@ -1607,7 +1757,7 @@ func validateAutogroupForSSHUser(user *AutoGroup) error { } if !slices.Contains(autogroupForSSHUser, *user) { - return fmt.Errorf("autogroup %q is not supported for SSH user, can be %v", *user, autogroupForSSHUser) + return fmt.Errorf("%w: %q for SSH user, can be %v", ErrAutogroupNotSupportedSSH, *user, autogroupForSSHUser) } return nil @@ -1617,6 +1767,8 @@ func validateAutogroupForSSHUser(user *AutoGroup) error { // the unmarshaling process. // It runs through all rules and checks if there are any inconsistencies // in the policy that needs to be addressed before it can be used. 
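Protocol.validate above accepts either a known protocol name or a numeric IANA protocol number, rejecting leading zeros and anything outside 0-255. A small sketch of just the numeric branch; the function and error names here are illustrative, and the real method also handles the named protocols.

package main

import (
	"errors"
	"fmt"
	"strconv"
)

var (
	errLeadingZeroProtocol = errors.New("leading zero not permitted in protocol number")
	errProtocolOutOfRange  = errors.New("protocol number out of range")
)

// validateProtocolNumber mirrors the numeric checks shown above: reject
// "0"-prefixed strings first, then require a value in the 0-255 range.
func validateProtocolNumber(s string) error {
	if s == "0" || (len(s) > 1 && s[0] == '0') {
		return fmt.Errorf("%w: %q", errLeadingZeroProtocol, s)
	}

	n, err := strconv.Atoi(s)
	if err != nil {
		return fmt.Errorf("invalid protocol %q: %w", s, err)
	}

	if n < 0 || n > 255 {
		return fmt.Errorf("%w: %d", errProtocolOutOfRange, n)
	}

	return nil
}

func main() {
	for _, s := range []string{"6", "017", "300"} {
		fmt.Println(s, "->", validateProtocolNumber(s))
	}
}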
+// +//nolint:gocyclo func (p *Policy) validate() error { if p == nil { panic("passed nil policy") @@ -1632,67 +1784,81 @@ func (p *Policy) validate() error { case *Host: h := src if !p.Hosts.exist(*h) { - errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + errs = append(errs, fmt.Errorf("%w: %q - please define or remove the reference", ErrHostNotDefined, *h)) } case *AutoGroup: ag := src - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForSrc(ag); err != nil { + err = validateAutogroupForSrc(ag) + if err != nil { errs = append(errs, err) continue } case *Group: g := src - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := src - if err := p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } } for _, dst := range acl.Destinations { + //nolint:gocritic switch dst.Alias.(type) { case *Host: - h := dst.Alias.(*Host) + h := dst.Alias.(*Host) //nolint:forcetypeassert if !p.Hosts.exist(*h) { - errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + errs = append(errs, fmt.Errorf("%w: %q - please define or remove the reference", ErrHostNotDefined, *h)) } case *AutoGroup: - ag := dst.Alias.(*AutoGroup) + ag := dst.Alias.(*AutoGroup) //nolint:forcetypeassert - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForDst(ag); err != nil { + err = validateAutogroupForDst(ag) + if err != nil { errs = append(errs, err) continue } case *Group: - g := dst.Alias.(*Group) - if err := p.Groups.Contains(g); err != nil { + g := dst.Alias.(*Group) //nolint:forcetypeassert + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: - tagOwner := dst.Alias.(*Tag) - if err := p.TagOwners.Contains(tagOwner); err != nil { + tagOwner := dst.Alias.(*Tag) //nolint:forcetypeassert + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } } // Validate protocol-port compatibility - if err := validateProtocolPortCompatibility(acl.Protocol, acl.Destinations); err != nil { + err := validateProtocolPortCompatibility(acl.Protocol, acl.Destinations) + if err != nil { errs = append(errs, err) } } @@ -1701,7 +1867,9 @@ func (p *Policy) validate() error { for _, user := range ssh.Users { if strings.HasPrefix(string(user), "autogroup:") { maybeAuto := AutoGroup(user) - if err := validateAutogroupForSSHUser(&maybeAuto); err != nil { + + err := validateAutogroupForSSHUser(&maybeAuto) + if err != nil { errs = append(errs, err) continue } @@ -1713,43 +1881,55 @@ func (p *Policy) validate() error { case *AutoGroup: ag := src - if err := validateAutogroupSupported(ag); err != nil { + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForSSHSrc(ag); err != nil { + err = validateAutogroupForSSHSrc(ag) + if err != nil { errs = append(errs, err) continue } case *Group: g := src - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := src - if err := 
p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } } + for _, dst := range ssh.Destinations { switch dst := dst.(type) { case *AutoGroup: ag := dst - if err := validateAutogroupSupported(ag); err != nil { + + err := validateAutogroupSupported(ag) + if err != nil { errs = append(errs, err) continue } - if err := validateAutogroupForSSHDst(ag); err != nil { + err = validateAutogroupForSSHDst(ag) + if err != nil { errs = append(errs, err) continue } case *Tag: tagOwner := dst - if err := p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1761,7 +1941,9 @@ func (p *Policy) validate() error { switch tagOwner := tagOwner.(type) { case *Group: g := tagOwner - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: @@ -1786,12 +1968,16 @@ func (p *Policy) validate() error { switch approver := approver.(type) { case *Group: g := approver - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := approver - if err := p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1802,12 +1988,16 @@ func (p *Policy) validate() error { switch approver := approver.(type) { case *Group: g := approver - if err := p.Groups.Contains(g); err != nil { + + err := p.Groups.Contains(g) + if err != nil { errs = append(errs, err) } case *Tag: tagOwner := approver - if err := p.TagOwners.Contains(tagOwner); err != nil { + + err := p.TagOwners.Contains(tagOwner) + if err != nil { errs = append(errs, err) } } @@ -1833,6 +2023,8 @@ type SSH struct { // SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule. // It can be a list of usernames, groups, tags or autogroups. +// +//nolint:recvcheck type SSHSrcAliases []Alias // MarshalJSON marshals the Groups to JSON. @@ -1847,6 +2039,7 @@ func (g Groups) MarshalJSON() ([]byte, error) { for i, username := range usernames { users[i] = string(username) } + raw[string(group)] = users } @@ -1855,6 +2048,7 @@ func (g Groups) MarshalJSON() ([]byte, error) { func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -1866,10 +2060,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { case *Username, *Group, *Tag, *AutoGroup: (*a)[i] = alias.Alias default: - return fmt.Errorf( - "alias %T is not supported for SSH source", - alias.Alias, - ) + return fmt.Errorf("%w: %T", ErrAliasNotSupportedSSHSrc, alias.Alias) } } @@ -1878,6 +2069,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) 
if err != nil { return err @@ -1896,10 +2088,7 @@ func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { Asterix: (*a)[i] = alias.Alias default: - return fmt.Errorf( - "alias %T is not supported for SSH destination", - alias.Alias, - ) + return fmt.Errorf("%w: %T", ErrAliasNotSupportedSSHDst, alias.Alias) } } @@ -1926,7 +2115,7 @@ func (a SSHDstAliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown SSH destination alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownSSHDstAliasType, v) } } @@ -1953,7 +2142,7 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { case Asterix: aliases[i] = "*" default: - return nil, fmt.Errorf("unknown SSH source alias type: %T", v) + return nil, fmt.Errorf("%w: %T", ErrUnknownSSHSrcAliasType, v) } } @@ -1961,8 +2150,10 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { } func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) for _, alias := range a { aips, err := alias.Resolve(p, users, nodes) @@ -2012,26 +2203,32 @@ func (u SSHUser) MarshalJSON() ([]byte, error) { // This is the only entrypoint of reading a policy from a file or other source. func unmarshalPolicy(b []byte) (*Policy, error) { if len(b) == 0 { - return nil, nil + return nil, nil //nolint:nilnil } var policy Policy + ast, err := hujson.Parse(b) if err != nil { return nil, fmt.Errorf("parsing HuJSON: %w", err) } ast.Standardize() + + //nolint:noinlineerr if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { - var serr *json.SemanticError - if errors.As(err, &serr) && serr.Err == json.ErrUnknownName { + //nolint:noinlineerr + if serr, ok := errors.AsType[*json.SemanticError](err); ok && errors.Is(serr.Err, json.ErrUnknownName) { ptr := serr.JSONPointer name := ptr.LastToken() - return nil, fmt.Errorf("unknown field %q", name) + + return nil, fmt.Errorf("%w: %q", ErrUnknownField, name) } + return nil, fmt.Errorf("parsing policy from bytes: %w", err) } + //nolint:noinlineerr if err := policy.validate(); err != nil { return nil, err } @@ -2053,8 +2250,8 @@ func validateProtocolPortCompatibility(protocol Protocol, destinations []AliasWi for _, dst := range destinations { for _, portRange := range dst.Ports { // Check if it's not a wildcard port (0-65535) - if !(portRange.First == 0 && portRange.Last == 65535) { - return fmt.Errorf("protocol %q does not support specific ports; only \"*\" is allowed", protocol) + if portRange.First != 0 || portRange.Last != 65535 { + return fmt.Errorf("%w: %q only allows \"*\"", ErrProtocolNoSpecificPorts, protocol) } } } @@ -2075,6 +2272,7 @@ func (p *Policy) usesAutogroupSelf() bool { return true } } + for _, dest := range acl.Destinations { if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { return true @@ -2089,6 +2287,7 @@ func (p *Policy) usesAutogroupSelf() bool { return true } } + for _, dest := range ssh.Destinations { if ag, ok := dest.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { return true diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 664f76b7..ddc32fba 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -19,7 +19,6 @@ import ( "gorm.io/gorm" "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) // TestUnmarshalPolicy tests the unmarshalling of JSON into Policy 
objects and the marshalling @@ -53,11 +52,11 @@ func TestMarshalJSON(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - ptr.To(Username("user@example.com")), + new(Username("user@example.com")), }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Username("other@example.com")), + Alias: new(Username("other@example.com")), Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, @@ -82,6 +81,7 @@ func TestMarshalJSON(t *testing.T) { // Unmarshal back to verify round trip var roundTripped Policy + err = json.Unmarshal(marshalled, &roundTripped) require.NoError(t, err) @@ -253,11 +253,11 @@ func TestUnmarshalPolicy(t *testing.T) { Action: "accept", Protocol: "tcp", Sources: Aliases{ - ptr.To(Username("testuser@headscale.net")), + new(Username("testuser@headscale.net")), }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Username("otheruser@headscale.net")), + Alias: new(Username("otheruser@headscale.net")), Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, }, }, @@ -366,7 +366,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: "alias v2.Asterix is not supported for SSH source", + wantErr: "alias type not supported for SSH source: v2.Asterix", }, { name: "invalid-username", @@ -380,7 +380,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Username has to contain @, got: "invalid"`, + wantErr: `username must contain @: got "invalid"`, }, { name: "invalid-group", @@ -393,7 +393,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group has to start with "group:", got: "grou:example"`, + wantErr: `group must start with 'group:': got "grou:example"`, }, { name: "group-in-group", @@ -408,7 +408,7 @@ func TestUnmarshalPolicy(t *testing.T) { } `, // wantErr: `Username has to contain @, got: "group:inner"`, - wantErr: `Nested groups are not allowed, found "group:inner" inside "group:example"`, + wantErr: `nested groups not allowed: found "group:inner" inside "group:example"`, }, { name: "invalid-addr", @@ -419,7 +419,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Hostname "derp" contains an invalid IP address: "10.0"`, + wantErr: `invalid IP address: hostname "derp" value "10.0"`, }, { name: "invalid-prefix", @@ -430,7 +430,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Hostname "derp" contains an invalid IP address: "10.0/42"`, + wantErr: `invalid IP address: hostname "derp" value "10.0/42"`, }, // TODO(kradalby): Figure out why this doesn't work. 
// { @@ -459,7 +459,7 @@ func TestUnmarshalPolicy(t *testing.T) { ], } `, - wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`, + wantErr: `invalid autogroup: got "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`, }, { name: "undefined-hostname-errors-2490", @@ -478,7 +478,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Host "user1" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `host not defined in policy: "user1" - please define or remove the reference`, }, { name: "defined-hostname-does-not-err-2490", @@ -546,7 +546,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(AutoGroup("autogroup:internet")), + Alias: new(AutoGroup("autogroup:internet")), Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, }, }, @@ -571,7 +571,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in source, it can only be used in ACL destinations`, + wantErr: `autogroup:internet can only be used in ACL destinations`, }, { name: "autogroup:internet-in-ssh-src-not-allowed", @@ -590,7 +590,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in SSH source, it can only be used in ACL destinations`, + wantErr: `autogroup not supported for SSH: autogroup:internet in SSH source`, }, { name: "autogroup:internet-in-ssh-dst-not-allowed", @@ -609,7 +609,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`, + wantErr: `autogroup not supported for SSH: autogroup:internet in SSH destination`, }, { name: "ssh-basic", @@ -682,7 +682,7 @@ func TestUnmarshalPolicy(t *testing.T) { `, want: &Policy{ TagOwners: TagOwners{ - Tag("tag:web"): Owners{ptr.To(Username("admin@example.com"))}, + Tag("tag:web"): Owners{new(Username("admin@example.com"))}, }, SSHs: []SSH{ { @@ -691,7 +691,7 @@ func TestUnmarshalPolicy(t *testing.T) { tp("tag:web"), }, Destinations: SSHDstAliases{ - ptr.To(Username("admin@example.com")), + new(Username("admin@example.com")), }, Users: []SSHUser{ SSHUser("*"), @@ -733,7 +733,7 @@ func TestUnmarshalPolicy(t *testing.T) { gp("group:admins"), }, Destinations: SSHDstAliases{ - ptr.To(Username("admin@example.com")), + new(Username("admin@example.com")), }, Users: []SSHUser{ SSHUser("root"), @@ -760,7 +760,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-dst", @@ -779,7 +779,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-ssh-src", @@ -798,7 +798,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-tagOwner", @@ -809,7 +809,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined 
in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-autoapprover-route", @@ -822,7 +822,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "group-must-be-defined-acl-autoapprover-exitnode", @@ -833,7 +833,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `group not defined in policy: "group:notdefined"`, }, { name: "tag-must-be-defined-acl-src", @@ -852,7 +852,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-dst", @@ -871,7 +871,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-ssh-src", @@ -890,7 +890,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-ssh-dst", @@ -912,7 +912,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-autoapprover-route", @@ -925,7 +925,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "tag-must-be-defined-acl-autoapprover-exitnode", @@ -936,7 +936,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, } `, - wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + wantErr: `tag not defined in policy: "tag:notdefined"`, }, { name: "missing-dst-port-is-err", @@ -955,7 +955,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `hostport must contain a colon (":")`, + wantErr: `hostport must contain a colon`, }, { name: "dst-port-zero-is-err", @@ -985,7 +985,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "rules"`, + wantErr: `unknown field in policy: "rules"`, }, { name: "disallow-unsupported-fields-nested", @@ -1008,7 +1008,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`, + wantErr: `group must start with 'group:': got "INVALID_GROUP_FIELD"`, }, { name: "invalid-group-datatype", @@ -1020,7 +1020,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Group "group:invalid" value must be an array of users, got string: "should fail"`, + wantErr: `group value must be an array: group "group:invalid" got string "should fail"`, }, { name: "invalid-group-name-and-datatype-fails-on-name-first", @@ -1032,7 +1032,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Group has to start with 
"group:", got: "INVALID_GROUP_FIELD"`, + wantErr: `group must start with 'group:': got "INVALID_GROUP_FIELD"`, }, { name: "disallow-unsupported-fields-hosts-level", @@ -1044,7 +1044,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `Hostname "INVALID_HOST_FIELD" contains an invalid IP address: "should fail"`, + wantErr: `invalid IP address: hostname "INVALID_HOST_FIELD" value "should fail"`, }, { name: "disallow-unsupported-fields-tagowners-level", @@ -1056,7 +1056,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `tag has to start with "tag:", got: "INVALID_TAG_FIELD"`, + wantErr: `tag must start with 'tag:': got "INVALID_TAG_FIELD"`, }, { name: "disallow-unsupported-fields-acls-level", @@ -1073,7 +1073,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "INVALID_ACL_FIELD"`, + wantErr: `unknown field in policy: "INVALID_ACL_FIELD"`, }, { name: "disallow-unsupported-fields-ssh-level", @@ -1090,7 +1090,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "INVALID_SSH_FIELD"`, + wantErr: `unknown field in policy: "INVALID_SSH_FIELD"`, }, { name: "disallow-unsupported-fields-policy-level", @@ -1107,7 +1107,7 @@ func TestUnmarshalPolicy(t *testing.T) { "INVALID_POLICY_FIELD": "should fail at policy level" } `, - wantErr: `unknown field "INVALID_POLICY_FIELD"`, + wantErr: `unknown field in policy: "INVALID_POLICY_FIELD"`, }, { name: "disallow-unsupported-fields-autoapprovers-level", @@ -1122,7 +1122,7 @@ func TestUnmarshalPolicy(t *testing.T) { } } `, - wantErr: `unknown field "INVALID_AUTO_APPROVER_FIELD"`, + wantErr: `unknown field in policy: "INVALID_AUTO_APPROVER_FIELD"`, }, // headscale-admin uses # in some field names to add metadata, so we will ignore // those to ensure it doesnt break. 
@@ -1154,7 +1154,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(AutoGroup("autogroup:internet")), + Alias: new(AutoGroup("autogroup:internet")), Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, }, }, @@ -1181,7 +1181,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `unknown field "proto"`, + wantErr: `unknown field in policy: "proto"`, }, { name: "protocol-wildcard-not-allowed", @@ -1197,7 +1197,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `proto name "*" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)`, + wantErr: `invalid protocol: use protocol number 0-255 or protocol name`, }, { name: "protocol-case-insensitive-uppercase", @@ -1277,7 +1277,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `leading 0 not permitted in protocol number "0"`, + wantErr: `leading zero not permitted in protocol number: "0"`, }, { name: "protocol-empty-applies-to-tcp-udp-only", @@ -1324,7 +1324,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `protocol "icmp" does not support specific ports; only "*" is allowed`, + wantErr: `protocol does not support specific ports: "icmp" only allows "*"`, }, { name: "protocol-icmp-with-wildcard-port-allowed", @@ -1372,7 +1372,7 @@ func TestUnmarshalPolicy(t *testing.T) { ] } `, - wantErr: `protocol "gre" does not support specific ports; only "*" is allowed`, + wantErr: `protocol does not support specific ports: "gre" only allows "*"`, }, { name: "protocol-tcp-with-specific-port-allowed", @@ -1491,7 +1491,7 @@ func TestUnmarshalPolicy(t *testing.T) { want: &Policy{ TagOwners: TagOwners{ Tag("tag:bigbrother"): {}, - Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:smallbrother"): {new(Tag("tag:bigbrother"))}, }, ACLs: []ACL{ { @@ -1502,7 +1502,7 @@ func TestUnmarshalPolicy(t *testing.T) { }, Destinations: []AliasWithPorts{ { - Alias: ptr.To(Tag("tag:smallbrother")), + Alias: new(Tag("tag:smallbrother")), Ports: []tailcfg.PortRange{{First: 9000, Last: 9000}}, }, }, @@ -1583,14 +1583,14 @@ func TestUnmarshalPolicy(t *testing.T) { } } -func gp(s string) *Group { return ptr.To(Group(s)) } -func up(s string) *Username { return ptr.To(Username(s)) } -func hp(s string) *Host { return ptr.To(Host(s)) } -func tp(s string) *Tag { return ptr.To(Tag(s)) } -func agp(s string) *AutoGroup { return ptr.To(AutoGroup(s)) } +func gp(s string) *Group { return new(Group(s)) } +func up(s string) *Username { return new(Username(s)) } +func hp(s string) *Host { return new(Host(s)) } +func tp(s string) *Tag { return new(Tag(s)) } +func agp(s string) *AutoGroup { return new(AutoGroup(s)) } func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) } -func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) } -func pp(pref string) *Prefix { return ptr.To(Prefix(mp(pref))) } +func ap(addr string) *netip.Addr { return new(netip.MustParseAddr(addr)) } +func pp(pref string) *Prefix { return new(Prefix(mp(pref))) } func p(pref string) Prefix { return Prefix(mp(pref)) } func TestResolvePolicy(t *testing.T) { @@ -1604,11 +1604,18 @@ func TestResolvePolicy(t *testing.T) { } // Extract users to variables so we can take their addresses + // The variables below are all used in new() calls in the test cases. 
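These test updates replace ptr.To(x) with new(x), passing a value to new, which presumably relies on a toolchain where new accepts an expression; in Go releases up to 1.25, new only takes a type. A tiny sketch of what both forms compute, using a hypothetical local generic helper (ptrOf is not part of this change):

package main

import "fmt"

type Username string

// ptrOf does what tailscale.com/types/ptr.To does: copy v into a fresh
// variable and return that variable's address.
func ptrOf[T any](v T) *T {
	return &v
}

func main() {
	u := ptrOf(Username("user@example.com")) // same result as the replaced ptr.To call
	fmt.Println(*u)                          // user@example.com
}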
+ //nolint:staticcheck testuser := users["testuser"] + //nolint:staticcheck groupuser := users["groupuser"] + //nolint:staticcheck groupuser1 := users["groupuser1"] + //nolint:staticcheck groupuser2 := users["groupuser2"] + //nolint:staticcheck notme := users["notme"] + //nolint:staticcheck testuser2 := users["testuser2"] tests := []struct { @@ -1636,31 +1643,31 @@ func TestResolvePolicy(t *testing.T) { }, { name: "username", - toResolve: ptr.To(Username("testuser@")), + toResolve: new(Username("testuser@")), nodes: types.Nodes{ // Not matching other user { - User: ptr.To(notme), + User: new(notme), IPv4: ap("100.100.101.1"), }, // Not matching forced tags { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.2"), }, // not matching because it's tagged (tags copied from AuthKey) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"alsotagged"}, IPv4: ap("100.100.101.3"), }, { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.103"), }, { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.104"), }, }, @@ -1668,31 +1675,31 @@ func TestResolvePolicy(t *testing.T) { }, { name: "group", - toResolve: ptr.To(Group("group:testgroup")), + toResolve: new(Group("group:testgroup")), nodes: types.Nodes{ // Not matching other user { - User: ptr.To(notme), + User: new(notme), IPv4: ap("100.100.101.4"), }, // Not matching forced tags { - User: ptr.To(groupuser), + User: new(groupuser), Tags: []string{"tag:anything"}, IPv4: ap("100.100.101.5"), }, // not matching because it's tagged (tags copied from AuthKey) { - User: ptr.To(groupuser), + User: new(groupuser), Tags: []string{"tag:alsotagged"}, IPv4: ap("100.100.101.6"), }, { - User: ptr.To(groupuser), + User: new(groupuser), IPv4: ap("100.100.101.203"), }, { - User: ptr.To(groupuser), + User: new(groupuser), IPv4: ap("100.100.101.204"), }, }, @@ -1710,7 +1717,7 @@ func TestResolvePolicy(t *testing.T) { nodes: types.Nodes{ // Not matching other user { - User: ptr.To(notme), + User: new(notme), IPv4: ap("100.100.101.9"), }, // Not matching forced tags @@ -1746,7 +1753,7 @@ func TestResolvePolicy(t *testing.T) { pol: &Policy{ TagOwners: TagOwners{ Tag("tag:bigbrother"): {}, - Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:smallbrother"): {new(Tag("tag:bigbrother"))}, }, }, nodes: types.Nodes{ @@ -1769,7 +1776,7 @@ func TestResolvePolicy(t *testing.T) { pol: &Policy{ TagOwners: TagOwners{ Tag("tag:bigbrother"): {}, - Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:smallbrother"): {new(Tag("tag:bigbrother"))}, }, }, nodes: types.Nodes{ @@ -1804,14 +1811,14 @@ func TestResolvePolicy(t *testing.T) { }, { name: "multiple-groups", - toResolve: ptr.To(Group("group:testgroup")), + toResolve: new(Group("group:testgroup")), nodes: types.Nodes{ { - User: ptr.To(groupuser1), + User: new(groupuser1), IPv4: ap("100.100.101.203"), }, { - User: ptr.To(groupuser2), + User: new(groupuser2), IPv4: ap("100.100.101.204"), }, }, @@ -1829,14 +1836,14 @@ func TestResolvePolicy(t *testing.T) { }, { name: "invalid-username", - toResolve: ptr.To(Username("invaliduser@")), + toResolve: new(Username("invaliduser@")), nodes: types.Nodes{ { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.103"), }, }, - wantErr: `user with token "invaliduser@" not found`, + wantErr: `user not found: token "invaliduser@"`, }, { name: "invalid-tag", @@ -1860,47 +1867,47 @@ func TestResolvePolicy(t *testing.T) { }, { name: 
"autogroup-member-comprehensive", - toResolve: ptr.To(AutoGroup(AutoGroupMember)), + toResolve: new(AutoGroupMember), nodes: types.Nodes{ // Node with no tags (should be included - is a member) { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.1"), }, // Node with single tag (should be excluded - tagged nodes are not members) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with multiple tags, all defined in policy (should be excluded) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test", "tag:other"}, IPv4: ap("100.100.101.3"), }, // Node with tag not defined in policy (should be excluded - still tagged) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:undefined"}, IPv4: ap("100.100.101.4"), }, // Node with mixed tags - some defined, some not (should be excluded) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test", "tag:undefined"}, IPv4: ap("100.100.101.5"), }, // Another untagged node from different user (should be included) { - User: ptr.To(testuser2), + User: new(testuser2), IPv4: ap("100.100.101.6"), }, }, pol: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("testuser@"))}, - Tag("tag:other"): Owners{ptr.To(Username("testuser@"))}, + Tag("tag:test"): Owners{new(Username("testuser@"))}, + Tag("tag:other"): Owners{new(Username("testuser@"))}, }, }, want: []netip.Prefix{ @@ -1910,54 +1917,54 @@ func TestResolvePolicy(t *testing.T) { }, { name: "autogroup-tagged", - toResolve: ptr.To(AutoGroup(AutoGroupTagged)), + toResolve: new(AutoGroupTagged), nodes: types.Nodes{ // Node with no tags (should be excluded - not tagged) { - User: ptr.To(testuser), + User: new(testuser), IPv4: ap("100.100.101.1"), }, // Node with single tag defined in policy (should be included) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.2"), }, // Node with multiple tags, all defined in policy (should be included) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test", "tag:other"}, IPv4: ap("100.100.101.3"), }, // Node with tag not defined in policy (should be included - still tagged) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:undefined"}, IPv4: ap("100.100.101.4"), }, // Node with mixed tags - some defined, some not (should be included) { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test", "tag:undefined"}, IPv4: ap("100.100.101.5"), }, // Another untagged node from different user (should be excluded) { - User: ptr.To(testuser2), + User: new(testuser2), IPv4: ap("100.100.101.6"), }, // Tagged node from different user (should be included) { - User: ptr.To(testuser2), + User: new(testuser2), Tags: []string{"tag:server"}, IPv4: ap("100.100.101.7"), }, }, pol: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("testuser@"))}, - Tag("tag:other"): Owners{ptr.To(Username("testuser@"))}, - Tag("tag:server"): Owners{ptr.To(Username("testuser2@"))}, + Tag("tag:test"): Owners{new(Username("testuser@"))}, + Tag("tag:other"): Owners{new(Username("testuser@"))}, + Tag("tag:server"): Owners{new(Username("testuser2@"))}, }, }, want: []netip.Prefix{ @@ -1968,38 +1975,38 @@ func TestResolvePolicy(t *testing.T) { }, { name: "autogroup-self", - toResolve: ptr.To(AutoGroupSelf), + toResolve: new(AutoGroupSelf), nodes: types.Nodes{ { - User: ptr.To(testuser), + User: new(testuser), IPv4: 
ap("100.100.101.1"), }, { - User: ptr.To(testuser2), + User: new(testuser2), IPv4: ap("100.100.101.2"), }, { - User: ptr.To(testuser), + User: new(testuser), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.3"), }, { - User: ptr.To(testuser2), + User: new(testuser2), Tags: []string{"tag:test"}, IPv4: ap("100.100.101.4"), }, }, pol: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("testuser@"))}, + Tag("tag:test"): Owners{new(Username("testuser@"))}, }, }, wantErr: "autogroup:self requires per-node resolution", }, { name: "autogroup-invalid", - toResolve: ptr.To(AutoGroup("autogroup:invalid")), - wantErr: "unknown autogroup", + toResolve: new(AutoGroup("autogroup:invalid")), + wantErr: "invalid autogroup", }, } @@ -2021,6 +2028,7 @@ func TestResolvePolicy(t *testing.T) { } var prefs []netip.Prefix + if ips != nil { if p := ips.Prefixes(); len(p) > 0 { prefs = p @@ -2076,7 +2084,7 @@ func TestResolveAutoApprovers(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, + mp("10.0.0.0/24"): {new(Username("user1@"))}, }, }, }, @@ -2091,8 +2099,8 @@ func TestResolveAutoApprovers(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, - mp("10.0.1.0/24"): {ptr.To(Username("user2@"))}, + mp("10.0.0.0/24"): {new(Username("user1@"))}, + mp("10.0.1.0/24"): {new(Username("user2@"))}, }, }, }, @@ -2107,7 +2115,7 @@ func TestResolveAutoApprovers(t *testing.T) { name: "exit-node", policy: &Policy{ AutoApprovers: AutoApproverPolicy{ - ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + ExitNode: AutoApprovers{new(Username("user1@"))}, }, }, want: map[netip.Prefix]*netipx.IPSet{}, @@ -2122,7 +2130,7 @@ func TestResolveAutoApprovers(t *testing.T) { }, AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, + mp("10.0.0.0/24"): {new(Group("group:testgroup"))}, }, }, }, @@ -2137,20 +2145,20 @@ func TestResolveAutoApprovers(t *testing.T) { policy: &Policy{ TagOwners: TagOwners{ "tag:testtag": Owners{ - ptr.To(Username("user1@")), - ptr.To(Username("user2@")), + new(Username("user1@")), + new(Username("user2@")), }, "tag:exittest": Owners{ - ptr.To(Group("group:exitgroup")), + new(Group("group:exitgroup")), }, }, Groups: Groups{ "group:exitgroup": Usernames{"user2@"}, }, AutoApprovers: AutoApproverPolicy{ - ExitNode: AutoApprovers{ptr.To(Tag("tag:exittest"))}, + ExitNode: AutoApprovers{new(Tag("tag:exittest"))}, Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.1.0/24"): {ptr.To(Tag("tag:testtag"))}, + mp("10.0.1.0/24"): {new(Tag("tag:testtag"))}, }, }, }, @@ -2168,10 +2176,10 @@ func TestResolveAutoApprovers(t *testing.T) { }, AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, - mp("10.0.1.0/24"): {ptr.To(Username("user3@"))}, + mp("10.0.0.0/24"): {new(Group("group:testgroup"))}, + mp("10.0.1.0/24"): {new(Username("user3@"))}, }, - ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + ExitNode: AutoApprovers{new(Username("user1@"))}, }, }, want: map[netip.Prefix]*netipx.IPSet{ @@ -2192,9 +2200,11 @@ func TestResolveAutoApprovers(t *testing.T) { t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr) return } + if diff := cmp.Diff(tt.want, got, cmps...); diff != "" { t.Errorf("resolveAutoApprovers() 
mismatch (-want +got):\n%s", diff) } + if tt.wantAllIPRoutes != nil { if gotAllIPRoutes == nil { t.Error("resolveAutoApprovers() expected non-nil allIPRoutes, got nil") @@ -2341,6 +2351,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet { for _, p := range prefixes { builder.AddPrefix(mp(p)) } + ipSet, _ := builder.IPSet() return ipSet @@ -2350,6 +2361,7 @@ func ipSetComparer(x, y *netipx.IPSet) bool { if x == nil || y == nil { return x == y } + return cmp.Equal(x.Prefixes(), y.Prefixes(), util.Comparers...) } @@ -2388,7 +2400,7 @@ func TestNodeCanApproveRoute(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, + mp("10.0.0.0/24"): {new(Username("user1@"))}, }, }, }, @@ -2401,8 +2413,8 @@ func TestNodeCanApproveRoute(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, - mp("10.0.1.0/24"): {ptr.To(Username("user2@"))}, + mp("10.0.0.0/24"): {new(Username("user1@"))}, + mp("10.0.1.0/24"): {new(Username("user2@"))}, }, }, }, @@ -2414,7 +2426,7 @@ func TestNodeCanApproveRoute(t *testing.T) { name: "exit-node-approval", policy: &Policy{ AutoApprovers: AutoApproverPolicy{ - ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + ExitNode: AutoApprovers{new(Username("user1@"))}, }, }, node: nodes[0], @@ -2429,7 +2441,7 @@ func TestNodeCanApproveRoute(t *testing.T) { }, AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, + mp("10.0.0.0/24"): {new(Group("group:testgroup"))}, }, }, }, @@ -2445,10 +2457,10 @@ func TestNodeCanApproveRoute(t *testing.T) { }, AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, - mp("10.0.1.0/24"): {ptr.To(Username("user3@"))}, + mp("10.0.0.0/24"): {new(Group("group:testgroup"))}, + mp("10.0.1.0/24"): {new(Username("user3@"))}, }, - ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + ExitNode: AutoApprovers{new(Username("user1@"))}, }, }, node: nodes[0], @@ -2460,7 +2472,7 @@ func TestNodeCanApproveRoute(t *testing.T) { policy: &Policy{ AutoApprovers: AutoApproverPolicy{ Routes: map[netip.Prefix]AutoApprovers{ - mp("10.0.0.0/24"): {ptr.To(Username("user2@"))}, + mp("10.0.0.0/24"): {new(Username("user2@"))}, }, }, }, @@ -2518,7 +2530,7 @@ func TestResolveTagOwners(t *testing.T) { name: "single-tag-owner", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:test"): Owners{new(Username("user1@"))}, }, }, want: map[Tag]*netipx.IPSet{ @@ -2530,7 +2542,7 @@ func TestResolveTagOwners(t *testing.T) { name: "multiple-tag-owners", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, + Tag("tag:test"): Owners{new(Username("user1@")), new(Username("user2@"))}, }, }, want: map[Tag]*netipx.IPSet{ @@ -2545,7 +2557,7 @@ func TestResolveTagOwners(t *testing.T) { "group:testgroup": Usernames{"user1@", "user2@"}, }, TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + Tag("tag:test"): Owners{new(Group("group:testgroup"))}, }, }, want: map[Tag]*netipx.IPSet{ @@ -2557,8 +2569,8 @@ func TestResolveTagOwners(t *testing.T) { name: "tag-owns-tag", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:bigbrother"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:smallbrother"): 
Owners{ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:bigbrother"): Owners{new(Username("user1@"))}, + Tag("tag:smallbrother"): Owners{new(Tag("tag:bigbrother"))}, }, }, want: map[Tag]*netipx.IPSet{ @@ -2578,6 +2590,7 @@ func TestResolveTagOwners(t *testing.T) { t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr) return } + if diff := cmp.Diff(tt.want, got, cmps...); diff != "" { t.Errorf("resolveTagOwners() mismatch (-want +got):\n%s", diff) } @@ -2619,7 +2632,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "single-tag-owner", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:test"): Owners{new(Username("user1@"))}, }, }, node: nodes[0], @@ -2630,7 +2643,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "multiple-tag-owners", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, + Tag("tag:test"): Owners{new(Username("user1@")), new(Username("user2@"))}, }, }, node: nodes[1], @@ -2644,7 +2657,7 @@ func TestNodeCanHaveTag(t *testing.T) { "group:testgroup": Usernames{"user1@", "user2@"}, }, TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + Tag("tag:test"): Owners{new(Group("group:testgroup"))}, }, }, node: nodes[1], @@ -2658,19 +2671,19 @@ func TestNodeCanHaveTag(t *testing.T) { "group:testgroup": Usernames{"invalid"}, }, TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + Tag("tag:test"): Owners{new(Group("group:testgroup"))}, }, }, node: nodes[0], tag: "tag:test", want: false, - wantErr: "Username has to contain @", + wantErr: "username must contain @", }, { name: "node-cannot-have-tag", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:test"): Owners{new(Username("user2@"))}, }, }, node: nodes[0], @@ -2681,7 +2694,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-with-unauthorized-tag-different-user", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:prod"): Owners{new(Username("user1@"))}, }, }, node: nodes[2], // user3's node @@ -2692,8 +2705,8 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-with-multiple-tags-one-unauthorized", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:web"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:web"): Owners{new(Username("user1@"))}, + Tag("tag:database"): Owners{new(Username("user2@"))}, }, }, node: nodes[0], // user1's node @@ -2713,7 +2726,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "tag-not-in-tagowners", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:prod"): Owners{new(Username("user1@"))}, }, }, node: nodes[0], @@ -2726,13 +2739,13 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-without-ip-user-owns-tag", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:test"): Owners{new(Username("user1@"))}, }, }, node: &types.Node{ // No IPv4 or IPv6 - simulates new node registration User: &users[0], - UserID: ptr.To(users[0].ID), + UserID: new(users[0].ID), }, tag: "tag:test", want: true, // Should succeed via user-based fallback @@ -2741,13 +2754,13 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-without-ip-user-does-not-own-tag", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user2@"))}, + 
Tag("tag:test"): Owners{new(Username("user2@"))}, }, }, node: &types.Node{ // No IPv4 or IPv6 - simulates new node registration User: &users[0], // user1, but tag owned by user2 - UserID: ptr.To(users[0].ID), + UserID: new(users[0].ID), }, tag: "tag:test", want: false, // user1 does not own tag:test @@ -2759,13 +2772,13 @@ func TestNodeCanHaveTag(t *testing.T) { "group:admins": Usernames{"user1@", "user2@"}, }, TagOwners: TagOwners{ - Tag("tag:admin"): Owners{ptr.To(Group("group:admins"))}, + Tag("tag:admin"): Owners{new(Group("group:admins"))}, }, }, node: &types.Node{ // No IPv4 or IPv6 - simulates new node registration User: &users[1], // user2 is in group:admins - UserID: ptr.To(users[1].ID), + UserID: new(users[1].ID), }, tag: "tag:admin", want: true, // Should succeed via group membership @@ -2777,13 +2790,13 @@ func TestNodeCanHaveTag(t *testing.T) { "group:admins": Usernames{"user1@"}, }, TagOwners: TagOwners{ - Tag("tag:admin"): Owners{ptr.To(Group("group:admins"))}, + Tag("tag:admin"): Owners{new(Group("group:admins"))}, }, }, node: &types.Node{ // No IPv4 or IPv6 - simulates new node registration User: &users[1], // user2 is NOT in group:admins - UserID: ptr.To(users[1].ID), + UserID: new(users[1].ID), }, tag: "tag:admin", want: false, // user2 is not in group:admins @@ -2792,7 +2805,7 @@ func TestNodeCanHaveTag(t *testing.T) { name: "node-without-ip-no-user", policy: &Policy{ TagOwners: TagOwners{ - Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:test"): Owners{new(Username("user1@"))}, }, }, node: &types.Node{ @@ -2809,14 +2822,14 @@ func TestNodeCanHaveTag(t *testing.T) { }, TagOwners: TagOwners{ Tag("tag:server"): Owners{ - ptr.To(Username("user1@")), - ptr.To(Group("group:ops")), + new(Username("user1@")), + new(Group("group:ops")), }, }, }, node: &types.Node{ User: &users[0], // user1 directly owns the tag - UserID: ptr.To(users[0].ID), + UserID: new(users[0].ID), }, tag: "tag:server", want: true, @@ -2829,14 +2842,14 @@ func TestNodeCanHaveTag(t *testing.T) { }, TagOwners: TagOwners{ Tag("tag:server"): Owners{ - ptr.To(Username("user1@")), - ptr.To(Group("group:ops")), + new(Username("user1@")), + new(Group("group:ops")), }, }, }, node: &types.Node{ User: &users[2], // user3 is in group:ops - UserID: ptr.To(users[2].ID), + UserID: new(users[2].ID), }, tag: "tag:server", want: true, @@ -2853,6 +2866,7 @@ func TestNodeCanHaveTag(t *testing.T) { require.ErrorContains(t, err, tt.wantErr) return } + require.NoError(t, err) got := pm.NodeCanHaveTag(tt.node.View(), tt.tag) @@ -2881,14 +2895,14 @@ func TestUserMatchesOwner(t *testing.T) { name: "username-match", policy: &Policy{}, user: users[0], - owner: ptr.To(Username("user1@")), + owner: new(Username("user1@")), want: true, }, { name: "username-no-match", policy: &Policy{}, user: users[0], - owner: ptr.To(Username("user2@")), + owner: new(Username("user2@")), want: false, }, { @@ -2899,7 +2913,7 @@ func TestUserMatchesOwner(t *testing.T) { }, }, user: users[1], // user2 is in group:admins - owner: ptr.To(Group("group:admins")), + owner: new(Group("group:admins")), want: true, }, { @@ -2910,7 +2924,7 @@ func TestUserMatchesOwner(t *testing.T) { }, }, user: users[1], // user2 is NOT in group:admins - owner: ptr.To(Group("group:admins")), + owner: new(Group("group:admins")), want: false, }, { @@ -2919,7 +2933,7 @@ func TestUserMatchesOwner(t *testing.T) { Groups: Groups{}, }, user: users[0], - owner: ptr.To(Group("group:undefined")), + owner: new(Group("group:undefined")), want: false, }, { @@ -3113,6 +3127,7 @@ 
func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
 	for _, tt := range tests {
 		t.Run(tt.name, func(t *testing.T) {
 			var acl ACL
+
 			err := json.Unmarshal([]byte(tt.input), &acl)
 
 			if tt.wantErr {
@@ -3123,8 +3138,8 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
 			require.NoError(t, err)
 			assert.Equal(t, tt.expected.Action, acl.Action)
 			assert.Equal(t, tt.expected.Protocol, acl.Protocol)
-			assert.Equal(t, len(tt.expected.Sources), len(acl.Sources))
-			assert.Equal(t, len(tt.expected.Destinations), len(acl.Destinations))
+			assert.Len(t, acl.Sources, len(tt.expected.Sources))
+			assert.Len(t, acl.Destinations, len(tt.expected.Destinations))
 
 			// Compare sources
 			for i, expectedSrc := range tt.expected.Sources {
@@ -3164,14 +3179,15 @@ func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) {
 
 	// Unmarshal back
 	var unmarshaled ACL
+
 	err = json.Unmarshal(jsonBytes, &unmarshaled)
 	require.NoError(t, err)
 
 	// Should be equal
 	assert.Equal(t, original.Action, unmarshaled.Action)
 	assert.Equal(t, original.Protocol, unmarshaled.Protocol)
-	assert.Equal(t, len(original.Sources), len(unmarshaled.Sources))
-	assert.Equal(t, len(original.Destinations), len(unmarshaled.Destinations))
+	assert.Len(t, unmarshaled.Sources, len(original.Sources))
+	assert.Len(t, unmarshaled.Destinations, len(original.Destinations))
 }
 
 func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) {
@@ -3239,15 +3255,17 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) {
 	_, err := unmarshalPolicy([]byte(policyJSON))
 	require.Error(t, err)
-	assert.Contains(t, err.Error(), `invalid action "deny"`)
+	assert.Contains(t, err.Error(), `invalid action`)
+	assert.Contains(t, err.Error(), `deny`)
 }
 
-// Helper function to parse aliases for testing
+// Helper function to parse aliases for testing.
func mustParseAlias(s string) Alias { alias, err := parseAlias(s) if err != nil { panic(err) } + return alias } @@ -3261,20 +3279,20 @@ func TestFlattenTagOwners(t *testing.T) { { name: "tag-owns-tag", input: TagOwners{ - Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))}, + Tag("tag:bigbrother"): Owners{new(Group("group:user1"))}, + Tag("tag:smallbrother"): Owners{new(Tag("tag:bigbrother"))}, }, want: TagOwners{ - Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:smallbrother"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:bigbrother"): Owners{new(Group("group:user1"))}, + Tag("tag:smallbrother"): Owners{new(Group("group:user1"))}, }, wantErr: "", }, { name: "circular-reference", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))}, - Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, + Tag("tag:a"): Owners{new(Tag("tag:b"))}, + Tag("tag:b"): Owners{new(Tag("tag:a"))}, }, want: nil, wantErr: "circular reference detected: tag:a -> tag:b", @@ -3282,83 +3300,83 @@ func TestFlattenTagOwners(t *testing.T) { { name: "mixed-owners", input: TagOwners{ - Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))}, - Tag("tag:y"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:x"): Owners{new(Username("user1@")), new(Tag("tag:y"))}, + Tag("tag:y"): Owners{new(Username("user2@"))}, }, want: TagOwners{ - Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, - Tag("tag:y"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:x"): Owners{new(Username("user1@")), new(Username("user2@"))}, + Tag("tag:y"): Owners{new(Username("user2@"))}, }, wantErr: "", }, { name: "mixed-dupe-owners", input: TagOwners{ - Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))}, - Tag("tag:y"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:x"): Owners{new(Username("user1@")), new(Tag("tag:y"))}, + Tag("tag:y"): Owners{new(Username("user1@"))}, }, want: TagOwners{ - Tag("tag:x"): Owners{ptr.To(Username("user1@"))}, - Tag("tag:y"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:x"): Owners{new(Username("user1@"))}, + Tag("tag:y"): Owners{new(Username("user1@"))}, }, wantErr: "", }, { name: "no-tag-owners", input: TagOwners{ - Tag("tag:solo"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:solo"): Owners{new(Username("user1@"))}, }, want: TagOwners{ - Tag("tag:solo"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:solo"): Owners{new(Username("user1@"))}, }, wantErr: "", }, { name: "tag-long-owner-chain", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, - Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))}, - Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))}, - Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))}, - Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))}, - Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))}, + Tag("tag:a"): Owners{new(Group("group:user1"))}, + Tag("tag:b"): Owners{new(Tag("tag:a"))}, + Tag("tag:c"): Owners{new(Tag("tag:b"))}, + Tag("tag:d"): Owners{new(Tag("tag:c"))}, + Tag("tag:e"): Owners{new(Tag("tag:d"))}, + Tag("tag:f"): Owners{new(Tag("tag:e"))}, + Tag("tag:g"): Owners{new(Tag("tag:f"))}, }, want: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:b"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:c"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:d"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:e"): Owners{ptr.To(Group("group:user1"))}, - Tag("tag:f"): 
Owners{ptr.To(Group("group:user1"))}, - Tag("tag:g"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:a"): Owners{new(Group("group:user1"))}, + Tag("tag:b"): Owners{new(Group("group:user1"))}, + Tag("tag:c"): Owners{new(Group("group:user1"))}, + Tag("tag:d"): Owners{new(Group("group:user1"))}, + Tag("tag:e"): Owners{new(Group("group:user1"))}, + Tag("tag:f"): Owners{new(Group("group:user1"))}, + Tag("tag:g"): Owners{new(Group("group:user1"))}, }, wantErr: "", }, { name: "tag-long-circular-chain", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Tag("tag:g"))}, - Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, - Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))}, - Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))}, - Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))}, - Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))}, - Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))}, + Tag("tag:a"): Owners{new(Tag("tag:g"))}, + Tag("tag:b"): Owners{new(Tag("tag:a"))}, + Tag("tag:c"): Owners{new(Tag("tag:b"))}, + Tag("tag:d"): Owners{new(Tag("tag:c"))}, + Tag("tag:e"): Owners{new(Tag("tag:d"))}, + Tag("tag:f"): Owners{new(Tag("tag:e"))}, + Tag("tag:g"): Owners{new(Tag("tag:f"))}, }, wantErr: "circular reference detected: tag:a -> tag:b -> tag:c -> tag:d -> tag:e -> tag:f -> tag:g", }, { name: "undefined-tag-reference", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Tag("tag:nonexistent"))}, + Tag("tag:a"): Owners{new(Tag("tag:nonexistent"))}, }, wantErr: `tag "tag:a" references undefined tag "tag:nonexistent"`, }, { name: "tag-with-empty-owners-is-valid", input: TagOwners{ - Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:a"): Owners{new(Tag("tag:b"))}, Tag("tag:b"): Owners{}, // empty owners but exists }, want: TagOwners{ diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go index a4367775..68c5984b 100644 --- a/hscontrol/policy/v2/utils.go +++ b/hscontrol/policy/v2/utils.go @@ -9,6 +9,21 @@ import ( "tailscale.com/tailcfg" ) +// portRangeParts is the expected number of parts in a port range (start-end). +const portRangeParts = 2 + +// Sentinel errors for port and destination parsing. +var ( + ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port") + ErrInputStartsWithColon = errors.New("input cannot start with a colon character") + ErrInputEndsWithColon = errors.New("input cannot end with a colon character") + ErrInvalidPortRange = errors.New("invalid port range format") + ErrPortRangeInverted = errors.New("invalid port range: first port is greater than last port") + ErrPortMustBePositive = errors.New("first port must be >0, or use '*' for wildcard") + ErrInvalidPortNumber = errors.New("invalid port number") + ErrPortOutOfRange = errors.New("port number out of range") +) + // splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid. 
func splitDestinationAndPort(input string) (string, string, error) { // Find the last occurrence of the colon character @@ -16,13 +31,15 @@ func splitDestinationAndPort(input string) (string, string, error) { // Check if the colon character is present and not at the beginning or end of the string if lastColonIndex == -1 { - return "", "", errors.New("input must contain a colon character separating destination and port") + return "", "", ErrInputMissingColon } + if lastColonIndex == 0 { - return "", "", errors.New("input cannot start with a colon character") + return "", "", ErrInputStartsWithColon } + if lastColonIndex == len(input)-1 { - return "", "", errors.New("input cannot end with a colon character") + return "", "", ErrInputEndsWithColon } // Split the string into destination and port based on the last colon @@ -45,11 +62,12 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { for part := range parts { if strings.Contains(part, "-") { rangeParts := strings.Split(part, "-") + rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool { return e == "" }) - if len(rangeParts) != 2 { - return nil, errors.New("invalid port range format") + if len(rangeParts) != portRangeParts { + return nil, ErrInvalidPortRange } first, err := parsePort(rangeParts[0]) @@ -63,7 +81,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { } if first > last { - return nil, errors.New("invalid port range: first port is greater than last port") + return nil, ErrPortRangeInverted } portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last}) @@ -74,7 +92,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { } if port < 1 { - return nil, errors.New("first port must be >0, or use '*' for wildcard") + return nil, ErrPortMustBePositive } portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port}) @@ -88,11 +106,11 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { func parsePort(portStr string) (uint16, error) { port, err := strconv.Atoi(portStr) if err != nil { - return 0, errors.New("invalid port number") + return 0, ErrInvalidPortNumber } if port < 0 || port > 65535 { - return 0, errors.New("port number out of range") + return 0, ErrPortOutOfRange } return uint16(port), nil diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go index 2084b22f..2ce95537 100644 --- a/hscontrol/policy/v2/utils_test.go +++ b/hscontrol/policy/v2/utils_test.go @@ -24,14 +24,14 @@ func TestParseDestinationAndPort(t *testing.T) { {"tag:api-server:443", "tag:api-server", "443", nil}, {"example-host-1:*", "example-host-1", "*", nil}, {"hostname:80-90", "hostname", "80-90", nil}, - {"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")}, - {":invalid", "", "", errors.New("input cannot start with a colon character")}, - {"invalid:", "", "", errors.New("input cannot end with a colon character")}, + {"invalidinput", "", "", ErrInputMissingColon}, + {":invalid", "", "", ErrInputStartsWithColon}, + {"invalid:", "", "", ErrInputEndsWithColon}, } for _, testCase := range testCases { dst, port, err := splitDestinationAndPort(testCase.input) - if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) { + if dst != testCase.expectedDst || port != testCase.expectedPort || !errors.Is(err, testCase.expectedErr) { t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)", testCase.input, 
dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr) } @@ -42,25 +42,23 @@ func TestParsePort(t *testing.T) { tests := []struct { input string expected uint16 - err string + err error }{ - {"80", 80, ""}, - {"0", 0, ""}, - {"65535", 65535, ""}, - {"-1", 0, "port number out of range"}, - {"65536", 0, "port number out of range"}, - {"abc", 0, "invalid port number"}, - {"", 0, "invalid port number"}, + {"80", 80, nil}, + {"0", 0, nil}, + {"65535", 65535, nil}, + {"-1", 0, ErrPortOutOfRange}, + {"65536", 0, ErrPortOutOfRange}, + {"abc", 0, ErrInvalidPortNumber}, + {"", 0, ErrInvalidPortNumber}, } for _, test := range tests { result, err := parsePort(test.input) - if err != nil && err.Error() != test.err { + if !errors.Is(err, test.err) { t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err) } - if err == nil && test.err != "" { - t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err) - } + if result != test.expected { t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected) } @@ -71,30 +69,28 @@ func TestParsePortRange(t *testing.T) { tests := []struct { input string expected []tailcfg.PortRange - err string + err error }{ - {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""}, - {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""}, - {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""}, - {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""}, - {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""}, - {"80-", nil, "invalid port range format"}, - {"-90", nil, "invalid port range format"}, - {"80-90,", nil, "invalid port number"}, - {"80,90-", nil, "invalid port range format"}, - {"80-90,abc", nil, "invalid port number"}, - {"80-90,65536", nil, "port number out of range"}, - {"80-90,90-80", nil, "invalid port range: first port is greater than last port"}, + {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, nil}, + {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, nil}, + {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, nil}, + {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, nil}, + {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, nil}, + {"80-", nil, ErrInvalidPortRange}, + {"-90", nil, ErrInvalidPortRange}, + {"80-90,", nil, ErrInvalidPortNumber}, + {"80,90-", nil, ErrInvalidPortRange}, + {"80-90,abc", nil, ErrInvalidPortNumber}, + {"80-90,65536", nil, ErrPortOutOfRange}, + {"80-90,90-80", nil, ErrPortRangeInverted}, } for _, test := range tests { result, err := parsePortRange(test.input) - if err != nil && err.Error() != test.err { + if !errors.Is(err, test.err) { t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err) } - if err == nil && test.err != "" { - t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err) - } + if diff := cmp.Diff(result, test.expected); diff != "" { t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff) } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 02275751..9864983a 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -152,6 +152,7 @@ func (m *mapSession) serveLongPoll() { // This is not my favourite solution, but it kind of works in our eventually consistent world. 
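// A minimal, self-contained sketch of the sentinel-error pattern the
// policy/v2 tests above switch to: matching with errors.Is (instead of
// comparing err.Error() strings) keeps working even when the sentinel is
// wrapped with fmt.Errorf("%w: ..."). Names mirror the diff, but the body
// is illustrative, not the actual headscale implementation.
package main

import (
	"errors"
	"fmt"
	"strconv"
)

var (
	ErrInvalidPortNumber = errors.New("invalid port number")
	ErrPortOutOfRange    = errors.New("port number out of range")
)

func parsePort(s string) (uint16, error) {
	p, err := strconv.Atoi(s)
	if err != nil {
		// Wrap the sentinel so callers can still match it with errors.Is.
		return 0, fmt.Errorf("%w: %q", ErrInvalidPortNumber, s)
	}
	if p < 0 || p > 65535 {
		return 0, fmt.Errorf("%w: %d", ErrPortOutOfRange, p)
	}
	return uint16(p), nil
}

func main() {
	_, err := parsePort("abc")
	fmt.Println(errors.Is(err, ErrInvalidPortNumber)) // true, even though wrapped
}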
ticker := time.NewTicker(time.Second) defer ticker.Stop() + disconnected := true // Wait up to 10 seconds for the node to reconnect. // 10 seconds was arbitrary chosen as a reasonable time to reconnect. @@ -160,6 +161,7 @@ func (m *mapSession) serveLongPoll() { disconnected = false break } + <-ticker.C } @@ -212,11 +214,14 @@ func (m *mapSession) serveLongPoll() { // adding this before connecting it to the state ensure that // it does not miss any updates that might be sent in the split // time between the node connecting and the batcher being ready. + //nolint:noinlineerr if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { m.errf(err, "failed to add node to batcher") log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session") + return } + log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher") m.h.Change(mapReqChange) @@ -245,7 +250,8 @@ func (m *mapSession) serveLongPoll() { return } - if err := m.writeMap(update); err != nil { + err := m.writeMap(update) + if err != nil { m.errf(err, "cannot write update to client") return } @@ -254,7 +260,8 @@ func (m *mapSession) serveLongPoll() { m.resetKeepAlive() case <-m.keepAliveTicker.C: - if err := m.writeMap(&keepAlive); err != nil { + err := m.writeMap(&keepAlive) + if err != nil { m.errf(err, "cannot write keep alive") return } @@ -282,8 +289,9 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error { jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression) } + //nolint:prealloc data := make([]byte, reservedResponseHeaderSize) - //nolint:gosec // G115: JSON response size will not exceed uint32 max + //nolint:gosec binary.LittleEndian.PutUint32(data, uint32(len(jsonBody))) data = append(data, jsonBody...) @@ -328,13 +336,13 @@ func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event { Str("node.name", m.node.Hostname) } -//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +//nolint:zerologlint func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) } -//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +//nolint:zerologlint func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) } -//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +//nolint:zerologlint func (m *mapSession) errf(err error, msg string, a ...any) { m.logf(log.Error().Caller()).Err(err).Msgf(msg, a...) } diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 977dc7a9..e3708a13 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -4,7 +4,6 @@ import ( "fmt" "net/netip" "slices" - "sort" "strings" "sync" @@ -57,7 +56,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { // this is important so the same node is chosen two times in a row // as the primary route. ids := types.NodeIDs(xmaps.Keys(pr.routes)) - sort.Sort(ids) + slices.Sort(ids) // Create a map of prefixes to nodes that serve them so we // can determine the primary route for each prefix. @@ -108,9 +107,11 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { Msg("Current primary no longer available") } } + if len(nodes) >= 1 { pr.primaries[prefix] = nodes[0] changed = true + log.Debug(). Caller(). Str("prefix", prefix.String()). 
@@ -127,6 +128,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { Str("prefix", prefix.String()). Msg("Cleaning up primary route that no longer has available nodes") delete(pr.primaries, prefix) + changed = true } } @@ -162,14 +164,18 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) // If no routes are being set, remove the node from the routes map. if len(prefixes) == 0 { wasPresent := false + if _, ok := pr.routes[node]; ok { delete(pr.routes, node) + wasPresent = true + log.Debug(). Caller(). Uint64("node.id", node.Uint64()). Msg("Removed node from primary routes (no prefixes)") } + changed := pr.updatePrimaryLocked() log.Debug(). Caller(). @@ -236,7 +242,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix { } } - tsaddr.SortPrefixes(routes) + slices.SortFunc(routes, netip.Prefix.Compare) return routes } @@ -254,13 +260,15 @@ func (pr *PrimaryRoutes) stringLocked() string { fmt.Fprintln(&sb, "Available routes:") ids := types.NodeIDs(xmaps.Keys(pr.routes)) - sort.Sort(ids) + slices.Sort(ids) + for _, id := range ids { prefixes := pr.routes[id] fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", ")) } fmt.Fprintln(&sb, "\n\nCurrent primary routes:") + for route, nodeID := range pr.primaries { fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID) } @@ -294,7 +302,7 @@ func (pr *PrimaryRoutes) DebugJSON() DebugRoutes { // Populate available routes for nodeID, routes := range pr.routes { prefixes := routes.Slice() - tsaddr.SortPrefixes(prefixes) + slices.SortFunc(prefixes, netip.Prefix.Compare) debug.AvailableRoutes[nodeID] = prefixes } diff --git a/hscontrol/routes/primary_test.go b/hscontrol/routes/primary_test.go index 7a9767b2..b03c8f81 100644 --- a/hscontrol/routes/primary_test.go +++ b/hscontrol/routes/primary_test.go @@ -130,6 +130,7 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) pr.SetRoutes(2, mp("192.168.2.0/24")) pr.SetRoutes(1) // Deregister by setting no routes + return pr.SetRoutes(1, mp("192.168.3.0/24")) }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -153,8 +154,9 @@ func TestPrimaryRoutes(t *testing.T) { { name: "multiple-nodes-register-same-route", operations: func(pr *PrimaryRoutes) bool { - pr.SetRoutes(1, mp("192.168.1.0/24")) // false - pr.SetRoutes(2, mp("192.168.1.0/24")) // true + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true + return pr.SetRoutes(3, mp("192.168.1.0/24")) // false }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -182,7 +184,8 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) // false pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary - return pr.SetRoutes(1) // true, 2 primary + + return pr.SetRoutes(1) // true, 2 primary }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ 2: { @@ -393,6 +396,7 @@ func TestPrimaryRoutes(t *testing.T) { operations: func(pr *PrimaryRoutes) bool { pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0")) pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0")) + return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0")) }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -413,15 +417,20 @@ func TestPrimaryRoutes(t *testing.T) { operations: func(pr *PrimaryRoutes) bool { var wg sync.WaitGroup wg.Add(2) + var change1, change2 bool + go func() { defer wg.Done() + change1 = pr.SetRoutes(1, mp("192.168.1.0/24")) }() 
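// A standalone sketch of the stdlib sorting used above in primary.go:
// slices.SortFunc with netip.Prefix.Compare replaces tsaddr.SortPrefixes,
// and slices.Sort replaces sort.Sort for ordered ID slices. Values are
// illustrative only.
package main

import (
	"fmt"
	"net/netip"
	"slices"
)

func main() {
	prefixes := []netip.Prefix{
		netip.MustParsePrefix("192.168.2.0/24"),
		netip.MustParsePrefix("10.0.0.0/8"),
		netip.MustParsePrefix("192.168.1.0/24"),
	}
	// netip.Prefix.Compare is a total order, so the method expression can be
	// passed directly as the comparison function.
	slices.SortFunc(prefixes, netip.Prefix.Compare)
	fmt.Println(prefixes)

	// For ordered element types (such as uint64-backed node IDs),
	// slices.Sort needs no sort.Interface implementation.
	ids := []uint64{3, 1, 2}
	slices.Sort(ids)
	fmt.Println(ids)
}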
go func() { defer wg.Done() + change2 = pr.SetRoutes(2, mp("192.168.2.0/24")) }() + wg.Wait() return change1 || change2 @@ -449,17 +458,21 @@ func TestPrimaryRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pr := New() + change := tt.operations(pr) if change != tt.expectedChange { t.Errorf("change = %v, want %v", change, tt.expectedChange) } + comps := append(util.Comparers, cmpopts.EquateEmpty()) if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" { t.Errorf("routes mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" { t.Errorf("primaries mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" { t.Errorf("isPrimary mismatch (-want +got):\n%s", diff) } diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 3ed1d79f..abb34eb0 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -77,6 +77,7 @@ func (s *State) DebugOverview() string { ephemeralCount := 0 now := time.Now() + for _, node := range allNodes.All() { if node.Valid() { userName := node.Owner().Name() @@ -103,17 +104,21 @@ func (s *State) DebugOverview() string { // User statistics sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users))) + for userName, nodeCount := range userNodeCounts { sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount)) } + sb.WriteString("\n") // Policy information sb.WriteString("Policy:\n") sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode)) + if s.cfg.Policy.Mode == types.PolicyModeFile { sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path)) } + sb.WriteString("\n") // DERP information @@ -123,6 +128,7 @@ func (s *State) DebugOverview() string { } else { sb.WriteString("DERP: not configured\n") } + sb.WriteString("\n") // Route information @@ -130,6 +136,7 @@ func (s *State) DebugOverview() string { if s.primaryRoutes.String() == "" { routeCount = 0 } + sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount)) sb.WriteString("\n") @@ -165,10 +172,12 @@ func (s *State) DebugDERPMap() string { for _, node := range region.Nodes { sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n", node.Name, node.HostName, node.DERPPort)) + if node.STUNPort != 0 { sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort)) } } + sb.WriteString("\n") } @@ -236,7 +245,7 @@ func (s *State) DebugPolicy() (string, error) { return string(pol), nil default: - return "", fmt.Errorf("unsupported policy mode: %s", s.cfg.Policy.Mode) + return "", fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, s.cfg.Policy.Mode) } } @@ -319,6 +328,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo { if s.primaryRoutes.String() == "" { routeCount = 0 } + info.PrimaryRoutes = routeCount return info diff --git a/hscontrol/state/ephemeral_test.go b/hscontrol/state/ephemeral_test.go index 632af13c..5c755687 100644 --- a/hscontrol/state/ephemeral_test.go +++ b/hscontrol/state/ephemeral_test.go @@ -8,7 +8,6 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "tailscale.com/types/ptr" ) // TestEphemeralNodeDeleteWithConcurrentUpdate tests the race condition where UpdateNode and DeleteNode @@ -21,6 +20,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { // Create NodeStore store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -44,20 
+44,26 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { // 6. If DELETE came after UPDATE, the returned node should be invalid done := make(chan bool, 2) - var updatedNode types.NodeView - var updateOk bool + + var ( + updatedNode types.NodeView + updateOk bool + ) // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) + go func() { updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) + done <- true }() // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) go func() { store.DeleteNode(node.ID) + done <- true }() @@ -91,6 +97,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -106,7 +113,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) { // Start UpdateNode in goroutine - it will queue and wait for batch go func() { node, ok := store.UpdateNode(node.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) resultChan <- struct { node types.NodeView @@ -148,6 +155,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) { node := createTestNode(3, 1, "test-user", "test-node-3") store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -156,7 +164,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) { // Simulate UpdateNode being called updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) require.True(t, ok, "UpdateNode should succeed") require.True(t, updatedNode.Valid(), "UpdateNode should return valid node") @@ -204,6 +212,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -214,21 +223,26 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // 1. UpdateNode (from UpdateNodeFromMapRequest during polling) // 2. DeleteNode (from handleLogout when client sends logout request) - var updatedNode types.NodeView - var updateOk bool + var ( + updatedNode types.NodeView + updateOk bool + ) + done := make(chan bool, 2) // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) go func() { updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) + done <- true }() // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) go func() { store.DeleteNode(ephemeralNode.ID) + done <- true }() @@ -267,7 +281,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // 5. UpdateNode and DeleteNode batch together // 6. UpdateNode returns a valid node (from before delete in batch) // 7. persistNodeToDB is called with the stale valid node -// 8. Node gets re-inserted into database instead of staying deleted +// 8. Node gets re-inserted into database instead of staying deleted. 
func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5") ephemeralNode.AuthKey = &types.PreAuthKey{ @@ -279,6 +293,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -294,7 +309,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { go func() { node, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) endpoint := netip.MustParseAddrPort("10.0.0.1:41641") n.Endpoints = []netip.AddrPort{endpoint} }) @@ -349,6 +364,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -363,7 +379,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { go func() { updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) updateDone <- struct { node types.NodeView @@ -399,7 +415,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { // 3. UpdateNode and DeleteNode batch together // 4. UpdateNode returns a valid node (from before delete in batch) // 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node -// 6. persistNodeToDB must detect the node is deleted and refuse to persist +// 6. persistNodeToDB must detect the node is deleted and refuse to persist. func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7") ephemeralNode.AuthKey = &types.PreAuthKey{ @@ -409,6 +425,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { } store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -417,7 +434,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { // UpdateNode returns a node updatedNode, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { - n.LastSeen = ptr.To(time.Now()) + n.LastSeen = new(time.Now()) }) require.True(t, ok, "UpdateNode should succeed") require.True(t, updatedNode.Valid(), "updated node should be valid") diff --git a/hscontrol/state/maprequest.go b/hscontrol/state/maprequest.go index e7dfc11c..d8cddaa1 100644 --- a/hscontrol/state/maprequest.go +++ b/hscontrol/state/maprequest.go @@ -29,6 +29,7 @@ func netInfoFromMapRequest( Uint64("node.id", nodeID.Uint64()). Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP). 
Msg("using NetInfo from previous Hostinfo in MapRequest") + return currentHostinfo.NetInfo } diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index 99f781d4..8a842e49 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -1,15 +1,12 @@ package state import ( - "net/netip" "testing" "github.com/juanfont/headscale/hscontrol/types" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/key" - "tailscale.com/types/ptr" ) func TestNetInfoFromMapRequest(t *testing.T) { @@ -136,26 +133,3 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node") }) } - -// Simple helper function for tests -func createTestNodeSimple(id types.NodeID) *types.Node { - user := types.User{ - Name: "test-user", - } - - machineKey := key.NewMachine() - nodeKey := key.NewNode() - - node := &types.Node{ - ID: id, - Hostname: "test-node", - UserID: ptr.To(uint(id)), - User: &user, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - IPv4: &netip.Addr{}, - IPv6: &netip.Addr{}, - } - - return node -} diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index 6327b46b..1c921d6d 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -55,8 +55,8 @@ var ( }) nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{ Namespace: prometheusNamespace, - Name: "nodestore_nodes_total", - Help: "Total number of nodes in the NodeStore", + Name: "nodestore_nodes", + Help: "Number of nodes in the NodeStore", }) nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{ Namespace: prometheusNamespace, @@ -97,6 +97,7 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batc for _, n := range allNodes { nodes[n.ID] = *n } + snap := snapshotFromNodes(nodes, peersFunc) store := &NodeStore{ @@ -165,11 +166,14 @@ func (s *NodeStore) PutNode(n types.Node) types.NodeView { } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("put").Inc() return resultNode @@ -205,11 +209,14 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node) } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("update").Inc() // Return the node and whether it exists (is valid) @@ -229,7 +236,9 @@ func (s *NodeStore) DeleteNode(id types.NodeID) { } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() @@ -262,8 +271,10 @@ func (s *NodeStore) processWrite() { if len(batch) != 0 { s.applyBatch(batch) } + return } + batch = append(batch, w) if len(batch) >= s.batchSize { s.applyBatch(batch) @@ -321,6 +332,7 @@ func (s *NodeStore) applyBatch(batch []work) { w.updateFn(&n) nodes[w.nodeID] = n } + if w.nodeResult != nil { nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) } @@ -349,12 +361,14 @@ func (s *NodeStore) applyBatch(batch []work) { nodeView := node.View() for _, w := range workItems { w.nodeResult <- nodeView + close(w.nodeResult) } } else { // Node was deleted or doesn't exist for _, w := range workItems { w.nodeResult <- types.NodeView{} // Send invalid view + close(w.nodeResult) } } @@ -400,6 +414,7 @@ func 
snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S peersByNode: func() map[types.NodeID][]types.NodeView { peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration) defer peersTimer.ObserveDuration() + return peersFunc(allNodes) }(), nodesByUser: make(map[types.UserID][]types.NodeView), @@ -417,6 +432,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S if newSnap.nodesByMachineKey[n.MachineKey] == nil { newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView) } + newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView } @@ -511,10 +527,12 @@ func (s *NodeStore) DebugString() string { // User distribution (shows internal UserID tracking, not display owner) sb.WriteString("Nodes by Internal User ID:\n") + for userID, nodes := range snapshot.nodesByUser { if len(nodes) > 0 { userName := "unknown" taggedCount := 0 + if len(nodes) > 0 && nodes[0].Valid() { userName = nodes[0].User().Name() // Count tagged nodes (which have UserID set but are owned by "tagged-devices") @@ -532,23 +550,29 @@ func (s *NodeStore) DebugString() string { } } } + sb.WriteString("\n") // Peer relationships summary sb.WriteString("Peer Relationships:\n") + totalPeers := 0 + for nodeID, peers := range snapshot.peersByNode { peerCount := len(peers) + totalPeers += peerCount if node, exists := snapshot.nodesByID[nodeID]; exists { sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n", nodeID, node.Hostname, peerCount)) } } + if len(snapshot.peersByNode) > 0 { avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode)) sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers)) } + sb.WriteString("\n") // Node key index @@ -591,6 +615,7 @@ func (s *NodeStore) RebuildPeerMaps() { } s.writeQueue <- w + <-result } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 3d6184ba..522bb64e 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -2,6 +2,7 @@ package state import ( "context" + "errors" "fmt" "net/netip" "runtime" @@ -13,7 +14,13 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/types/key" - "tailscale.com/types/ptr" +) + +// Test sentinel errors for concurrent operations. 
+var ( + errTestUpdateNodeFailed = errors.New("UpdateNode failed") + errTestGetNodeFailed = errors.New("GetNode failed") + errTestPutNodeFailed = errors.New("PutNode failed") ) func TestSnapshotFromNodes(t *testing.T) { @@ -33,6 +40,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, peersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Empty(t, snapshot.nodesByID) assert.Empty(t, snapshot.allNodes) assert.Empty(t, snapshot.peersByNode) @@ -45,9 +53,11 @@ func TestSnapshotFromNodes(t *testing.T) { nodes := map[types.NodeID]types.Node{ 1: createTestNode(1, 1, "user1", "node1"), } + return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Len(t, snapshot.nodesByID, 1) assert.Len(t, snapshot.allNodes, 1) assert.Len(t, snapshot.peersByNode, 1) @@ -71,6 +81,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Len(t, snapshot.nodesByID, 2) assert.Len(t, snapshot.allNodes, 2) assert.Len(t, snapshot.peersByNode, 2) @@ -96,6 +107,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Len(t, snapshot.nodesByID, 3) assert.Len(t, snapshot.allNodes, 3) assert.Len(t, snapshot.peersByNode, 3) @@ -125,6 +137,7 @@ func TestSnapshotFromNodes(t *testing.T) { return nodes, peersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + t.Helper() assert.Len(t, snapshot.nodesByID, 4) assert.Len(t, snapshot.allNodes, 4) assert.Len(t, snapshot.peersByNode, 4) @@ -174,7 +187,7 @@ func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) DiscoKey: discoKey.Public(), Hostname: hostname, GivenName: hostname, - UserID: ptr.To(userID), + UserID: new(userID), User: &types.User{ Name: username, DisplayName: username, @@ -193,11 +206,13 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView ret := make(map[types.NodeID][]types.NodeView, len(nodes)) for _, node := range nodes { var peers []types.NodeView + for _, n := range nodes { if n.ID() != node.ID() { peers = append(peers, n) } } + ret[node.ID()] = peers } @@ -208,6 +223,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView ret := make(map[types.NodeID][]types.NodeView, len(nodes)) for _, node := range nodes { var peers []types.NodeView + nodeIsOdd := node.ID()%2 == 1 for _, n := range nodes { @@ -222,6 +238,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView peers = append(peers, n) } } + ret[node.ID()] = peers } @@ -237,6 +254,7 @@ func TestNodeStoreOperations(t *testing.T) { { name: "create empty store and add single node", setupFunc: func(t *testing.T) *NodeStore { + t.Helper() return NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) }, steps: []testStep{ @@ -455,10 +473,13 @@ func TestNodeStoreOperations(t *testing.T) { // Add nodes in sequence n1 := store.PutNode(createTestNode(1, 1, "user1", "node1")) assert.True(t, n1.Valid()) + n2 := store.PutNode(createTestNode(2, 2, "user2", "node2")) assert.True(t, n2.Valid()) + n3 := store.PutNode(createTestNode(3, 3, "user3", "node3")) assert.True(t, n3.Valid()) + n4 := store.PutNode(createTestNode(4, 4, "user4", "node4")) 
assert.True(t, n4.Valid()) @@ -526,16 +547,20 @@ func TestNodeStoreOperations(t *testing.T) { done2 := make(chan struct{}) done3 := make(chan struct{}) - var resultNode1, resultNode2 types.NodeView - var newNode3 types.NodeView - var ok1, ok2 bool + var ( + resultNode1, resultNode2 types.NodeView + newNode3 types.NodeView + ok1, ok2 bool + ) // These should all be processed in the same batch + go func() { resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "batch-updated-node1" n.GivenName = "batch-given-1" }) + close(done1) }() @@ -544,12 +569,14 @@ func TestNodeStoreOperations(t *testing.T) { n.Hostname = "batch-updated-node2" n.GivenName = "batch-given-2" }) + close(done2) }() go func() { node3 := createTestNode(3, 1, "user1", "node3") newNode3 = store.PutNode(node3) + close(done3) }() @@ -602,20 +629,23 @@ func TestNodeStoreOperations(t *testing.T) { // This test verifies that when multiple updates to the same node // are batched together, each returned node reflects ALL changes // in the batch, not just the individual update's changes. - done1 := make(chan struct{}) done2 := make(chan struct{}) done3 := make(chan struct{}) - var resultNode1, resultNode2, resultNode3 types.NodeView - var ok1, ok2, ok3 bool + var ( + resultNode1, resultNode2, resultNode3 types.NodeView + ok1, ok2, ok3 bool + ) // These updates all modify node 1 and should be batched together // The final state should have all three modifications applied + go func() { resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "multi-update-hostname" }) + close(done1) }() @@ -623,6 +653,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) { n.GivenName = "multi-update-givenname" }) + close(done2) }() @@ -630,6 +661,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) { n.Tags = []string{"tag1", "tag2"} }) + close(done3) }() @@ -723,14 +755,18 @@ func TestNodeStoreOperations(t *testing.T) { done2 := make(chan struct{}) done3 := make(chan struct{}) - var result1, result2, result3 types.NodeView - var ok1, ok2, ok3 bool + var ( + result1, result2, result3 types.NodeView + ok1, ok2, ok3 bool + ) // Start concurrent updates + go func() { result1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "concurrent-db-hostname" }) + close(done1) }() @@ -738,6 +774,7 @@ func TestNodeStoreOperations(t *testing.T) { result2, ok2 = store.UpdateNode(1, func(n *types.Node) { n.GivenName = "concurrent-db-given" }) + close(done2) }() @@ -745,6 +782,7 @@ func TestNodeStoreOperations(t *testing.T) { result3, ok3 = store.UpdateNode(1, func(n *types.Node) { n.Tags = []string{"concurrent-tag"} }) + close(done3) }() @@ -828,6 +866,7 @@ func TestNodeStoreOperations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := tt.setupFunc(t) + store.Start() defer store.Stop() @@ -847,88 +886,107 @@ type testStep struct { // --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests --- -// Helper for concurrent test nodes +// Helper for concurrent test nodes. 
func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { machineKey := key.NewMachine() nodeKey := key.NewNode() + return types.Node{ ID: id, Hostname: hostname, MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), - UserID: ptr.To(uint(1)), + UserID: new(uint(1)), User: &types.User{ Name: "concurrent-test-user", }, } } -// --- Concurrency: concurrent PutNode operations --- +// --- Concurrency: concurrent PutNode operations ---. func TestNodeStoreConcurrentPutNode(t *testing.T) { const concurrentOps = 20 store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() var wg sync.WaitGroup + results := make(chan bool, concurrentOps) for i := range concurrentOps { wg.Add(1) + go func(nodeID int) { defer wg.Done() + + //nolint:gosec node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") + resultNode := store.PutNode(node) results <- resultNode.Valid() }(i + 1) } + wg.Wait() close(results) successCount := 0 + for success := range results { if success { successCount++ } } + require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed") } -// --- Batching: concurrent ops fit in one batch --- +// --- Batching: concurrent ops fit in one batch ---. func TestNodeStoreBatchingEfficiency(t *testing.T) { - const batchSize = 10 - const ops = 15 // more than batchSize + const ops = 15 store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() var wg sync.WaitGroup + results := make(chan bool, ops) for i := range ops { wg.Add(1) + go func(nodeID int) { defer wg.Done() + + //nolint:gosec node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") + resultNode := store.PutNode(node) results <- resultNode.Valid() }(i + 1) } + wg.Wait() close(results) successCount := 0 + for success := range results { if success { successCount++ } } + require.Equal(t, ops, successCount, "All batch PutNode operations should succeed") } -// --- Race conditions: many goroutines on same node --- +// --- Race conditions: many goroutines on same node ---. 
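// The concurrency tests above share one shape: fan out goroutines under a
// sync.WaitGroup, push each outcome into a buffered channel sized to the
// number of operations, then close the channel and tally. A generic sketch
// of that shape, with store.PutNode replaced by a stand-in op function:
package main

import (
	"fmt"
	"sync"
)

func main() {
	const ops = 20

	op := func(id int) bool { return id > 0 } // stand-in for store.PutNode(...).Valid()

	var wg sync.WaitGroup

	results := make(chan bool, ops) // buffered so goroutines never block on send

	for i := range ops {
		wg.Add(1)

		go func(id int) {
			defer wg.Done()
			results <- op(id)
		}(i + 1)
	}

	wg.Wait()
	close(results) // safe: all senders have finished

	success := 0

	for ok := range results {
		if ok {
			success++
		}
	}

	fmt.Printf("%d/%d operations succeeded\n", success, ops)
}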
func TestNodeStoreRaceConditions(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -937,13 +995,18 @@ func TestNodeStoreRaceConditions(t *testing.T) { resultNode := store.PutNode(node) require.True(t, resultNode.Valid()) - const numGoroutines = 30 - const opsPerGoroutine = 10 + const ( + numGoroutines = 30 + opsPerGoroutine = 10 + ) + var wg sync.WaitGroup + errors := make(chan error, numGoroutines*opsPerGoroutine) for i := range numGoroutines { wg.Add(1) + go func(gid int) { defer wg.Done() @@ -954,40 +1017,46 @@ func TestNodeStoreRaceConditions(t *testing.T) { n.Hostname = "race-updated" }) if !resultNode.Valid() { - errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j) + errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestUpdateNodeFailed, gid, j) } case 1: retrieved, found := store.GetNode(nodeID) if !found || !retrieved.Valid() { - errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j) + errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestGetNodeFailed, gid, j) } case 2: newNode := createConcurrentTestNode(nodeID, "race-put") + resultNode := store.PutNode(newNode) if !resultNode.Valid() { - errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j) + errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestPutNodeFailed, gid, j) } } } }(i) } + wg.Wait() close(errors) errorCount := 0 + for err := range errors { t.Error(err) + errorCount++ } + if errorCount > 0 { t.Fatalf("Race condition test failed with %d errors", errorCount) } } -// --- Resource cleanup: goroutine leak detection --- +// --- Resource cleanup: goroutine leak detection ---. func TestNodeStoreResourceCleanup(t *testing.T) { // initialGoroutines := runtime.NumGoroutine() store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1001,7 +1070,7 @@ func TestNodeStoreResourceCleanup(t *testing.T) { const ops = 100 for i := range ops { - nodeID := types.NodeID(i + 1) + nodeID := types.NodeID(i + 1) //nolint:gosec node := createConcurrentTestNode(nodeID, "cleanup-node") resultNode := store.PutNode(node) assert.True(t, resultNode.Valid()) @@ -1010,10 +1079,12 @@ func TestNodeStoreResourceCleanup(t *testing.T) { }) retrieved, found := store.GetNode(nodeID) assert.True(t, found && retrieved.Valid()) + if i%10 == 9 { store.DeleteNode(nodeID) } } + runtime.GC() // Wait for goroutines to settle and check for leaks @@ -1024,9 +1095,10 @@ func TestNodeStoreResourceCleanup(t *testing.T) { }, time.Second, 10*time.Millisecond, "goroutines should not leak") } -// --- Timeout/deadlock: operations complete within reasonable time --- +// --- Timeout/deadlock: operations complete within reasonable time ---. 
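// The timeout test that follows guards against deadlocks by racing a "done"
// channel against a context deadline. A trimmed-down sketch of the same
// structure, with the store operations replaced by a placeholder workload:
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	var wg sync.WaitGroup

	for i := 0; i < 30; i++ {
		wg.Add(1)

		go func() {
			defer wg.Done()
			time.Sleep(10 * time.Millisecond) // placeholder for PutNode/UpdateNode
		}()
	}

	done := make(chan struct{})

	go func() {
		wg.Wait()
		close(done)
	}()

	select {
	case <-done:
		fmt.Println("all operations completed within the deadline")
	case <-ctx.Done():
		fmt.Println("timed out: possible deadlock in the write queue")
	}
}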
func TestNodeStoreOperationTimeout(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1034,36 +1106,47 @@ func TestNodeStoreOperationTimeout(t *testing.T) { defer cancel() const ops = 30 + var wg sync.WaitGroup + putResults := make([]error, ops) updateResults := make([]error, ops) // Launch all PutNode operations concurrently for i := 1; i <= ops; i++ { - nodeID := types.NodeID(i) + nodeID := types.NodeID(i) //nolint:gosec + wg.Add(1) + go func(idx int, id types.NodeID) { defer wg.Done() + startPut := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id) node := createConcurrentTestNode(id, "timeout-node") resultNode := store.PutNode(node) endPut := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut)) + if !resultNode.Valid() { - putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id) + putResults[idx-1] = fmt.Errorf("%w for node %d", errTestPutNodeFailed, id) } }(i, nodeID) } + wg.Wait() // Launch all UpdateNode operations concurrently wg = sync.WaitGroup{} + for i := 1; i <= ops; i++ { - nodeID := types.NodeID(i) + nodeID := types.NodeID(i) //nolint:gosec + wg.Add(1) + go func(idx int, id types.NodeID) { defer wg.Done() + startUpdate := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id) resultNode, ok := store.UpdateNode(id, func(n *types.Node) { @@ -1071,31 +1154,40 @@ func TestNodeStoreOperationTimeout(t *testing.T) { }) endUpdate := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate)) + if !ok || !resultNode.Valid() { - updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id) + updateResults[idx-1] = fmt.Errorf("%w for node %d", errTestUpdateNodeFailed, id) } }(i, nodeID) } + done := make(chan struct{}) + go func() { wg.Wait() close(done) }() + select { case <-done: errorCount := 0 + for _, err := range putResults { if err != nil { t.Error(err) + errorCount++ } } + for _, err := range updateResults { if err != nil { t.Error(err) + errorCount++ } } + if errorCount == 0 { t.Log("All concurrent operations completed successfully within timeout") } else { @@ -1107,13 +1199,15 @@ func TestNodeStoreOperationTimeout(t *testing.T) { } } -// --- Edge case: update non-existent node --- +// --- Edge case: update non-existent node ---. func TestNodeStoreUpdateNonExistentNode(t *testing.T) { for i := range 10 { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) store.Start() - nonExistentID := types.NodeID(999 + i) + + nonExistentID := types.NodeID(999 + i) //nolint:gosec updateCallCount := 0 + fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID) resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) { updateCallCount++ @@ -1127,20 +1221,22 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) { } } -// --- Allocation benchmark --- +// --- Allocation benchmark ---. 
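// The benchmark below drives iterations with testing.B.Loop (Go 1.24+),
// which replaces the classic `for i := 0; i < b.N; i++` loop and keeps
// setup and teardown outside the loop out of the measured time. A minimal
// sketch of the same shape (package name and body are illustrative):
package store_test

import "testing"

func BenchmarkSketch(b *testing.B) {
	data := make(map[int]string) // setup outside the measured loop

	b.ReportAllocs()

	for i := 0; b.Loop(); i++ {
		data[i] = "value" // placeholder for PutNode/UpdateNode/GetNode
	}
}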
func BenchmarkNodeStoreAllocations(b *testing.B) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() for i := 0; b.Loop(); i++ { - nodeID := types.NodeID(i + 1) + nodeID := types.NodeID(i + 1) //nolint:gosec node := createConcurrentTestNode(nodeID, "bench-node") store.PutNode(node) store.UpdateNode(nodeID, func(n *types.Node) { n.Hostname = "bench-updated" }) store.GetNode(nodeID) + if i%10 == 9 { store.DeleteNode(nodeID) } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index d1401ef0..fbb4c421 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -25,10 +25,8 @@ import ( "github.com/rs/zerolog/log" "golang.org/x/sync/errgroup" "gorm.io/gorm" - "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/ptr" "tailscale.com/types/views" zcache "zgo.at/zcache/v2" ) @@ -133,7 +131,7 @@ func NewState(cfg *types.Config) (*State, error) { // On startup, all nodes should be marked as offline until they reconnect // This ensures we don't have stale online status from previous runs for _, node := range nodes { - node.IsOnline = ptr.To(false) + node.IsOnline = new(false) } users, err := db.ListUsers() if err != nil { @@ -189,7 +187,8 @@ func NewState(cfg *types.Config) (*State, error) { func (s *State) Close() error { s.nodeStore.Stop() - if err := s.db.Close(); err != nil { + err := s.db.Close() + if err != nil { return fmt.Errorf("closing database: %w", err) } @@ -225,6 +224,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) { // propagate correctly when switching between policy types. s.nodeStore.RebuildPeerMaps() + //nolint:prealloc cs := []change.Change{change.PolicyChange()} // Always call autoApproveNodes during policy reload, regardless of whether @@ -255,6 +255,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) { // CreateUser creates a new user and updates the policy manager. // Returns the created user, change set, and any error. func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) { + //nolint:noinlineerr if err := s.db.DB.Save(&user).Error; err != nil { return nil, change.Change{}, fmt.Errorf("creating user: %w", err) } @@ -289,6 +290,7 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error return nil, err } + //nolint:noinlineerr if err := updateFn(user); err != nil { return nil, err } @@ -468,10 +470,8 @@ func (s *State) Connect(id types.NodeID) []change.Change { // CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, // the NodeStore already reflects the correct online status for full map generation. - // now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { - n.IsOnline = ptr.To(true) - // n.LastSeen = ptr.To(now) + n.IsOnline = new(true) }) if !ok { return nil @@ -495,16 +495,17 @@ func (s *State) Connect(id types.NodeID) []change.Change { // Disconnect marks a node as disconnected and updates its primary routes in the state. func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) { + //nolint:staticcheck now := time.Now() node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { - n.LastSeen = ptr.To(now) + n.LastSeen = new(now) // NodeStore is the source of truth for all node state including online status. 
- n.IsOnline = ptr.To(false) + n.IsOnline = new(false) }) if !ok { - return nil, fmt.Errorf("node not found: %d", id) + return nil, fmt.Errorf("%w: %d", ErrNodeNotFound, id) } log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node disconnected") @@ -745,13 +746,14 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t // RenameNode changes the display name of a node. func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.Change, error) { - if err := util.ValidateHostname(newName); err != nil { + err := util.ValidateHostname(newName) + if err != nil { return types.NodeView{}, change.Change{}, fmt.Errorf("renaming node: %w", err) } // Check name uniqueness against NodeStore allNodes := s.nodeStore.ListNodes() - for i := 0; i < allNodes.Len(); i++ { + for i := range allNodes.Len() { node := allNodes.At(i) if node.ID() != nodeID && node.AsStruct().GivenName == newName { return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName) @@ -790,7 +792,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) { // Preserve online status and NetInfo when refreshing from database existingNode, exists := s.nodeStore.GetNode(node.ID) if exists && existingNode.Valid() { - node.IsOnline = ptr.To(existingNode.IsOnline().Get()) + node.IsOnline = new(existingNode.IsOnline().Get()) // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). @@ -818,6 +820,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha var updates []change.Change + //nolint:unqueryvet for _, node := range s.nodeStore.ListNodes().All() { if !node.Valid() { continue @@ -1117,7 +1120,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro DiscoKey: params.DiscoKey, Hostinfo: params.Hostinfo, Endpoints: params.Endpoints, - LastSeen: ptr.To(time.Now()), + LastSeen: new(time.Now()), RegisterMethod: params.RegisterMethod, Expiry: params.Expiry, } @@ -1218,7 +1221,8 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro // New node - database first to get ID, then NodeStore savedNode, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { - if err := tx.Save(&nodeToRegister).Error; err != nil { + err := tx.Save(&nodeToRegister).Error + if err != nil { return nil, fmt.Errorf("failed to save node: %w", err) } @@ -1407,8 +1411,8 @@ func (s *State) HandleNodeFromAuthPath( node.Endpoints = regEntry.Node.Endpoints node.RegisterMethod = regEntry.Node.RegisterMethod - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) + node.IsOnline = new(false) + node.LastSeen = new(time.Now()) // Tagged nodes keep their existing expiry (disabled). // User-owned nodes update expiry from the provided value or registration entry. @@ -1669,8 +1673,8 @@ func (s *State) HandleNodeFromPreAuthKey( // Only update AuthKey reference node.AuthKey = pak node.AuthKeyID = &pak.ID - node.IsOnline = ptr.To(false) - node.LastSeen = ptr.To(time.Now()) + node.IsOnline = new(false) + node.LastSeen = new(time.Now()) // Tagged nodes keep their existing expiry (disabled). // User-owned nodes update expiry from the client request. 
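// The ptr.To(...) → new(...) rewrites in this file only compile on a Go
// toolchain that accepts an arbitrary expression as the argument to new;
// on earlier toolchains the equivalent is a small generic helper like
// tailscale.com/types/ptr.To. A runnable sketch of that helper
// (illustrative, not the headscale code):
package main

import (
	"fmt"
	"time"
)

// ptrTo returns a pointer to a copy of v, mirroring ptr.To / new(expr).
func ptrTo[T any](v T) *T { return &v }

func main() {
	lastSeen := ptrTo(time.Now())
	online := ptrTo(false)
	fmt.Println(lastSeen != nil, *online)
}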
@@ -1697,7 +1701,7 @@ func (s *State) HandleNodeFromPreAuthKey( } } - return nil, nil + return nil, nil //nolint:nilnil }) if err != nil { return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err) @@ -2122,8 +2126,8 @@ func routesChanged(oldNode types.NodeView, newHI *tailcfg.Hostinfo) bool { newRoutes = []netip.Prefix{} } - tsaddr.SortPrefixes(oldRoutes) - tsaddr.SortPrefixes(newRoutes) + slices.SortFunc(oldRoutes, netip.Prefix.Compare) + slices.SortFunc(newRoutes, netip.Prefix.Compare) return !slices.Equal(oldRoutes, newRoutes) } diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go index 1a949173..60292912 100644 --- a/hscontrol/tailsql.go +++ b/hscontrol/tailsql.go @@ -13,6 +13,9 @@ import ( "tailscale.com/types/logger" ) +// ErrNoCertDomains is returned when no cert domains are available for HTTPS. +var ErrNoCertDomains = errors.New("no cert domains available for HTTPS") + func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath string) error { opts := tailsql.Options{ Hostname: "tailsql-headscale", @@ -71,7 +74,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s // When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443. certDomains := tsNode.CertDomains() if len(certDomains) == 0 { - return errors.New("no cert domains available for HTTPS") + return ErrNoCertDomains } base := "https://" + certDomains[0] go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -93,6 +96,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s mux := tsql.NewMux() tsweb.Debugger(mux) go http.Serve(lst, mux) + logf("TailSQL started") <-ctx.Done() logf("TailSQL shutting down...") diff --git a/hscontrol/templates/design.go b/hscontrol/templates/design.go index 615c0e41..2033f245 100644 --- a/hscontrol/templates/design.go +++ b/hscontrol/templates/design.go @@ -15,43 +15,43 @@ import ( // Material for MkDocs design system - exact values from official docs. const ( // Text colors - from --md-default-fg-color CSS variables. - colorTextPrimary = "#000000de" //nolint:unused // rgba(0,0,0,0.87) - Body text - colorTextSecondary = "#0000008a" //nolint:unused // rgba(0,0,0,0.54) - Headings (--md-default-fg-color--light) - colorTextTertiary = "#00000052" //nolint:unused // rgba(0,0,0,0.32) - Lighter text - colorTextLightest = "#00000012" //nolint:unused // rgba(0,0,0,0.07) - Lightest text + colorTextPrimary = "#000000de" //nolint:unused + colorTextSecondary = "#0000008a" //nolint:unused + colorTextTertiary = "#00000052" //nolint:unused + colorTextLightest = "#00000012" //nolint:unused // Code colors - from --md-code-* CSS variables. - colorCodeFg = "#36464e" //nolint:unused // Code text color (--md-code-fg-color) - colorCodeBg = "#f5f5f5" //nolint:unused // Code background (--md-code-bg-color) + colorCodeFg = "#36464e" //nolint:unused + colorCodeBg = "#f5f5f5" //nolint:unused // Border colors. - colorBorderLight = "#e5e7eb" //nolint:unused // Light borders - colorBorderMedium = "#d1d5db" //nolint:unused // Medium borders + colorBorderLight = "#e5e7eb" //nolint:unused + colorBorderMedium = "#d1d5db" //nolint:unused // Background colors. - colorBackgroundPage = "#ffffff" //nolint:unused // Page background - colorBackgroundCard = "#ffffff" //nolint:unused // Card/content background + colorBackgroundPage = "#ffffff" //nolint:unused + colorBackgroundCard = "#ffffff" //nolint:unused // Accent colors - from --md-primary/accent-fg-color. 
- colorPrimaryAccent = "#4051b5" //nolint:unused // Primary accent (links) - colorAccent = "#526cfe" //nolint:unused // Secondary accent + colorPrimaryAccent = "#4051b5" //nolint:unused + colorAccent = "#526cfe" //nolint:unused // Success colors. - colorSuccess = "#059669" //nolint:unused // Success states - colorSuccessLight = "#d1fae5" //nolint:unused // Success backgrounds + colorSuccess = "#059669" //nolint:unused + colorSuccessLight = "#d1fae5" //nolint:unused ) // Spacing System // Based on 4px/8px base unit for consistent rhythm. // Uses rem units for scalability with user font size preferences. const ( - spaceXS = "0.25rem" //nolint:unused // 4px - Tight spacing - spaceS = "0.5rem" //nolint:unused // 8px - Small spacing - spaceM = "1rem" //nolint:unused // 16px - Medium spacing (base) - spaceL = "1.5rem" //nolint:unused // 24px - Large spacing - spaceXL = "2rem" //nolint:unused // 32px - Extra large spacing - space2XL = "3rem" //nolint:unused // 48px - 2x extra large spacing - space3XL = "4rem" //nolint:unused // 64px - 3x extra large spacing + spaceXS = "0.25rem" //nolint:unused + spaceS = "0.5rem" //nolint:unused + spaceM = "1rem" //nolint:unused + spaceL = "1.5rem" //nolint:unused + spaceXL = "2rem" //nolint:unused + space2XL = "3rem" //nolint:unused + space3XL = "4rem" //nolint:unused ) // Typography System @@ -63,26 +63,26 @@ const ( fontFamilyCode = `"Roboto Mono", "SF Mono", Monaco, "Cascadia Code", Consolas, "Courier New", monospace` //nolint:unused // Font sizes - from .md-typeset CSS rules. - fontSizeBase = "0.8rem" //nolint:unused // 12.8px - Base text (.md-typeset) - fontSizeH1 = "2em" //nolint:unused // 2x base - Main headings - fontSizeH2 = "1.5625em" //nolint:unused // 1.5625x base - Section headings - fontSizeH3 = "1.25em" //nolint:unused // 1.25x base - Subsection headings - fontSizeSmall = "0.8em" //nolint:unused // 0.8x base - Small text - fontSizeCode = "0.85em" //nolint:unused // 0.85x base - Inline code + fontSizeBase = "0.8rem" //nolint:unused + fontSizeH1 = "2em" //nolint:unused + fontSizeH2 = "1.5625em" //nolint:unused + fontSizeH3 = "1.25em" //nolint:unused + fontSizeSmall = "0.8em" //nolint:unused + fontSizeCode = "0.85em" //nolint:unused // Line heights - from .md-typeset CSS rules. - lineHeightBase = "1.6" //nolint:unused // Body text (.md-typeset) - lineHeightH1 = "1.3" //nolint:unused // H1 headings - lineHeightH2 = "1.4" //nolint:unused // H2 headings - lineHeightH3 = "1.5" //nolint:unused // H3 headings - lineHeightCode = "1.4" //nolint:unused // Code blocks (pre) + lineHeightBase = "1.6" //nolint:unused + lineHeightH1 = "1.3" //nolint:unused + lineHeightH2 = "1.4" //nolint:unused + lineHeightH3 = "1.5" //nolint:unused + lineHeightCode = "1.4" //nolint:unused ) // Responsive Container Component // Creates a centered container with responsive padding and max-width. // Mobile-first approach: starts at 100% width with padding, constrains on larger screens. // -//nolint:unused // Reserved for future use in Phase 4. +//nolint:unused func responsiveContainer(children ...elem.Node) *elem.Element { return elem.Div(attrs.Props{ attrs.Style: styles.Props{ @@ -100,7 +100,7 @@ func responsiveContainer(children ...elem.Node) *elem.Element { // - title: Optional title for the card (empty string for no title) // - children: Content elements to display in the card // -//nolint:unused // Reserved for future use in Phase 4. 
+//nolint:unused func card(title string, children ...elem.Node) *elem.Element { cardContent := children if title != "" { @@ -134,7 +134,7 @@ func card(title string, children ...elem.Node) *elem.Element { // EXTRACTED FROM: .md-typeset pre CSS rules // Exact styling from Material for MkDocs documentation. // -//nolint:unused // Used across apple.go, windows.go, register_web.go templates. +//nolint:unused func codeBlock(code string) *elem.Element { return elem.Pre(attrs.Props{ attrs.Style: styles.Props{ @@ -164,7 +164,7 @@ func codeBlock(code string) *elem.Element { // Returns inline styles for the main content container that matches .md-typeset. // EXTRACTED FROM: .md-typeset CSS rule from Material for MkDocs. // -//nolint:unused // Used in general.go for mdTypesetBody. +//nolint:unused func baseTypesetStyles() styles.Props { return styles.Props{ styles.FontSize: fontSizeBase, // 0.8rem @@ -180,7 +180,7 @@ func baseTypesetStyles() styles.Props { // Returns inline styles for H1 headings that match .md-typeset h1. // EXTRACTED FROM: .md-typeset h1 CSS rule from Material for MkDocs. // -//nolint:unused // Used across templates for main headings. +//nolint:unused func h1Styles() styles.Props { return styles.Props{ styles.Color: colorTextSecondary, // rgba(0, 0, 0, 0.54) @@ -198,7 +198,7 @@ func h1Styles() styles.Props { // Returns inline styles for H2 headings that match .md-typeset h2. // EXTRACTED FROM: .md-typeset h2 CSS rule from Material for MkDocs. // -//nolint:unused // Used across templates for section headings. +//nolint:unused func h2Styles() styles.Props { return styles.Props{ styles.FontSize: fontSizeH2, // 1.5625em @@ -216,7 +216,7 @@ func h2Styles() styles.Props { // Returns inline styles for H3 headings that match .md-typeset h3. // EXTRACTED FROM: .md-typeset h3 CSS rule from Material for MkDocs. // -//nolint:unused // Used across templates for subsection headings. +//nolint:unused func h3Styles() styles.Props { return styles.Props{ styles.FontSize: fontSizeH3, // 1.25em @@ -234,7 +234,7 @@ func h3Styles() styles.Props { // Returns inline styles for paragraphs that match .md-typeset p. // EXTRACTED FROM: .md-typeset p CSS rule from Material for MkDocs. // -//nolint:unused // Used for consistent paragraph spacing. +//nolint:unused func paragraphStyles() styles.Props { return styles.Props{ styles.Margin: "1em 0", @@ -250,7 +250,7 @@ func paragraphStyles() styles.Props { // Returns inline styles for ordered lists that match .md-typeset ol. // EXTRACTED FROM: .md-typeset ol CSS rule from Material for MkDocs. // -//nolint:unused // Used for numbered instruction lists. +//nolint:unused func orderedListStyles() styles.Props { return styles.Props{ styles.MarginBottom: "1em", @@ -268,7 +268,7 @@ func orderedListStyles() styles.Props { // Returns inline styles for unordered lists that match .md-typeset ul. // EXTRACTED FROM: .md-typeset ul CSS rule from Material for MkDocs. // -//nolint:unused // Used for bullet point lists. +//nolint:unused func unorderedListStyles() styles.Props { return styles.Props{ styles.MarginBottom: "1em", @@ -287,7 +287,7 @@ func unorderedListStyles() styles.Props { // EXTRACTED FROM: .md-typeset a CSS rule from Material for MkDocs. // Note: Hover states cannot be implemented with inline styles. // -//nolint:unused // Used for text links. 
+//nolint:unused func linkStyles() styles.Props { return styles.Props{ styles.Color: colorPrimaryAccent, // #4051b5 - var(--md-primary-fg-color) @@ -301,7 +301,7 @@ func linkStyles() styles.Props { // Returns inline styles for inline code that matches .md-typeset code. // EXTRACTED FROM: .md-typeset code CSS rule from Material for MkDocs. // -//nolint:unused // Used for inline code snippets. +//nolint:unused func inlineCodeStyles() styles.Props { return styles.Props{ styles.BackgroundColor: colorCodeBg, // #f5f5f5 @@ -317,7 +317,7 @@ func inlineCodeStyles() styles.Props { // Inline Code Component // For inline code snippets within text. // -//nolint:unused // Reserved for future inline code usage. +//nolint:unused func inlineCode(code string) *elem.Element { return elem.Code(attrs.Props{ attrs.Style: inlineCodeStyles().ToInline(), @@ -327,7 +327,7 @@ func inlineCode(code string) *elem.Element { // orDivider creates a visual "or" divider between sections. // Styled with lines on either side for better visual separation. // -//nolint:unused // Used in apple.go template. +//nolint:unused func orDivider() *elem.Element { return elem.Div(attrs.Props{ attrs.Style: styles.Props{ @@ -367,7 +367,7 @@ func orDivider() *elem.Element { // warningBox creates a warning message box with icon and content. // -//nolint:unused // Used in apple.go template. +//nolint:unused func warningBox(title, message string) *elem.Element { return elem.Div(attrs.Props{ attrs.Style: styles.Props{ @@ -404,7 +404,7 @@ func warningBox(title, message string) *elem.Element { // downloadButton creates a nice button-style link for downloads. // -//nolint:unused // Used in apple.go template. +//nolint:unused func downloadButton(href, text string) *elem.Element { return elem.A(attrs.Props{ attrs.Href: href, @@ -428,7 +428,7 @@ func downloadButton(href, text string) *elem.Element { // Creates a link with proper security attributes for external URLs. // Automatically adds rel="noreferrer noopener" and target="_blank". // -//nolint:unused // Used in apple.go, oidc_callback.go templates. +//nolint:unused func externalLink(href, text string) *elem.Element { return elem.A(attrs.Props{ attrs.Href: href, @@ -444,7 +444,7 @@ func externalLink(href, text string) *elem.Element { // Instruction Step Component // For numbered instruction lists with consistent formatting. // -//nolint:unused // Reserved for future use in Phase 4. +//nolint:unused func instructionStep(_ int, text string) *elem.Element { return elem.Li(attrs.Props{ attrs.Style: styles.Props{ @@ -457,7 +457,7 @@ func instructionStep(_ int, text string) *elem.Element { // Status Message Component // For displaying success/error/info messages with appropriate styling. // -//nolint:unused // Reserved for future use in Phase 4. 
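A sketch of how these building blocks are intended to compose into a page fragment; appleProfileSection and the literal strings are illustrative only, the real call sites live in apple.go and the other templates:

    // appleProfileSection nests the helpers above: a warning box, a download
    // button, an "or" divider and a safe external link.
    func appleProfileSection(profileURL string) *elem.Element {
        return elem.Div(attrs.Props{},
            warningBox("Heads up", "Profiles must be installed manually on iOS."),
            downloadButton(profileURL, "Download profile"),
            orDivider(),
            externalLink("https://tailscale.com/download", "Get the Tailscale app"),
        )
    }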
+//nolint:unused func statusMessage(message string, isSuccess bool) *elem.Element { bgColor := colorSuccessLight textColor := colorSuccess diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go index a76fb7c4..37a63e80 100644 --- a/hscontrol/types/change/change.go +++ b/hscontrol/types/change/change.go @@ -333,7 +333,7 @@ func NodeOnline(nodeID types.NodeID) Change { PeerPatches: []*tailcfg.PeerChange{ { NodeID: nodeID.NodeID(), - Online: ptrTo(true), + Online: new(true), }, }, } @@ -346,7 +346,7 @@ func NodeOffline(nodeID types.NodeID) Change { PeerPatches: []*tailcfg.PeerChange{ { NodeID: nodeID.NodeID(), - Online: ptrTo(false), + Online: new(false), }, }, } @@ -365,11 +365,6 @@ func KeyExpiry(nodeID types.NodeID, expiry *time.Time) Change { } } -// ptrTo returns a pointer to the given value. -func ptrTo[T any](v T) *T { - return &v -} - // High-level change constructors // NodeAdded returns a Change for when a node is added or updated. diff --git a/hscontrol/types/change/change_test.go b/hscontrol/types/change/change_test.go index 9f181dd6..dc2dd0af 100644 --- a/hscontrol/types/change/change_test.go +++ b/hscontrol/types/change/change_test.go @@ -16,8 +16,8 @@ func TestChange_FieldSync(t *testing.T) { typ := reflect.TypeFor[Change]() boolCount := 0 - for i := range typ.NumField() { - if typ.Field(i).Type.Kind() == reflect.Bool { + for field := range typ.Fields() { + if field.Type.Kind() == reflect.Bool { boolCount++ } } diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index f4814519..e0f4fcdd 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -20,7 +20,10 @@ const ( DatabaseSqlite = "sqlite3" ) -var ErrCannotParsePrefix = errors.New("cannot parse prefix") +var ( + ErrCannotParsePrefix = errors.New("cannot parse prefix") + ErrInvalidRegIDLength = errors.New("registration ID has invalid length") +) type StateUpdateType int @@ -175,8 +178,9 @@ func MustRegistrationID() RegistrationID { func RegistrationIDFromString(str string) (RegistrationID, error) { if len(str) != RegistrationIDLength { - return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength) + return "", fmt.Errorf("%w: expected %d characters", ErrInvalidRegIDLength, RegistrationIDLength) } + return RegistrationID(str), nil } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4068d72e..64410dd9 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -30,13 +30,16 @@ const ( PKCEMethodS256 string = "S256" defaultNodeStoreBatchSize = 100 + defaultWALAutocheckpoint = 1000 // SQLite default ) var ( - errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") - errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") - errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") - errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") + errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") + errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") + 
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") + errNoPrefixConfigured = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + errInvalidAllocationStrategy = errors.New("invalid prefixes.allocation strategy") ) type IPAllocationStrategy string @@ -301,6 +304,7 @@ func validatePKCEMethod(method string) error { if method != PKCEMethodPlain && method != PKCEMethodS256 { return errInvalidPKCEMethod } + return nil } @@ -377,7 +381,7 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600) viper.SetDefault("database.sqlite.write_ahead_log", true) - viper.SetDefault("database.sqlite.wal_autocheckpoint", 1000) // SQLite default + viper.SetDefault("database.sqlite.wal_autocheckpoint", defaultWALAutocheckpoint) viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) viper.SetDefault("oidc.only_start_if_oidc_is_available", true) @@ -402,7 +406,8 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential)) if err := viper.ReadInConfig(); err != nil { - if _, ok := err.(viper.ConfigFileNotFoundError); ok { + var configFileNotFoundError viper.ConfigFileNotFoundError + if errors.As(err, &configFileNotFoundError) { log.Warn().Msg("No config file found, using defaults") return nil } @@ -442,7 +447,8 @@ func validateServerConfig() error { depr.fatal("oidc.map_legacy_users") if viper.GetBool("oidc.enabled") { - if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil { + err := validatePKCEMethod(viper.GetString("oidc.pkce.method")) + if err != nil { return err } } @@ -910,6 +916,7 @@ func LoadCLIConfig() (*Config, error) { // LoadServerConfig returns the full Headscale configuration to // host a Headscale server. This is called as part of `headscale serve`. func LoadServerConfig() (*Config, error) { + //nolint:noinlineerr if err := validateServerConfig(); err != nil { return nil, err } @@ -928,7 +935,7 @@ func LoadServerConfig() (*Config, error) { } if prefix4 == nil && prefix6 == nil { - return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + return nil, errNoPrefixConfigured } allocStr := viper.GetString("prefixes.allocation") @@ -940,7 +947,8 @@ func LoadServerConfig() (*Config, error) { alloc = IPAllocationStrategyRandom default: return nil, fmt.Errorf( - "config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", + "%w: %q, allowed options: %s, %s", + errInvalidAllocationStrategy, allocStr, IPAllocationStrategySequential, IPAllocationStrategyRandom, @@ -979,7 +987,8 @@ func LoadServerConfig() (*Config, error) { // - Control plane runs on login.tailscale.com/controlplane.tailscale.com // - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. 
tail-scale.ts.net) if dnsConfig.BaseDomain != "" { - if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil { + err := isSafeServerURL(serverURL, dnsConfig.BaseDomain) + if err != nil { return nil, err } } @@ -1082,6 +1091,7 @@ func LoadServerConfig() (*Config, error) { if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 { return workers } + return DefaultBatcherWorkers() }(), RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"), @@ -1117,6 +1127,7 @@ func isSafeServerURL(serverURL, baseDomain string) error { } s := len(serverDomainParts) + b := len(baseDomainParts) for i := range baseDomainParts { if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] { diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 6b9fc2ef..6ed2ef47 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -349,11 +349,8 @@ func TestReadConfigFromEnv(t *testing.T) { } func TestTLSConfigValidation(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "headscale") - if err != nil { - t.Fatal(err) - } - // defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() + configYaml := []byte(`--- tls_letsencrypt_hostname: example.com tls_letsencrypt_challenge_type: "" @@ -363,7 +360,8 @@ noise: // Populate a custom config file configFilePath := filepath.Join(tmpDir, "config.yaml") - err = os.WriteFile(configFilePath, configYaml, 0o600) + + err := os.WriteFile(configFilePath, configYaml, 0o600) if err != nil { t.Fatalf("Couldn't write file %s", configFilePath) } @@ -398,10 +396,12 @@ server_url: http://127.0.0.1:8080 tls_letsencrypt_hostname: example.com tls_letsencrypt_challenge_type: TLS-ALPN-01 `) + err = os.WriteFile(configFilePath, configYaml, 0o600) if err != nil { t.Fatalf("Couldn't write file %s", configFilePath) } + err = LoadConfig(tmpDir, false) require.NoError(t, err) } @@ -463,6 +463,7 @@ func TestSafeServerURL(t *testing.T) { return } + assert.NoError(t, err) }) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 41cd9759..d75da265 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -42,10 +42,6 @@ type ( NodeIDs []NodeID ) -func (n NodeIDs) Len() int { return len(n) } -func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] } -func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] } - func (id NodeID) StableID() tailcfg.StableNodeID { return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10)) } @@ -160,6 +156,7 @@ func (node *Node) GivenNameHasBeenChanged() bool { // Strip invalid DNS characters for givenName comparison normalised := strings.ToLower(node.Hostname) normalised = invalidDNSRegex.ReplaceAllString(normalised, "") + return node.GivenName == normalised } @@ -197,13 +194,7 @@ func (node *Node) IPs() []netip.Addr { // HasIP reports if a node has a given IP address. 
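Since NodeIDs no longer implements sort.Interface (the Len/Less/Swap methods removed above), any remaining call sites would order the slice with the standard slices package instead. A sketch, assuming NodeID keeps an integer underlying type as the uint64 conversions above suggest:

    // Ordering and deduplicating NodeIDs without sort.Interface.
    ids := NodeIDs{3, 1, 2, 2}
    slices.Sort(ids)          // NodeID satisfies cmp.Ordered
    ids = slices.Compact(ids) // [1 2 3]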
func (node *Node) HasIP(i netip.Addr) bool { - for _, ip := range node.IPs() { - if ip.Compare(i) == 0 { - return true - } - } - - return false + return slices.Contains(node.IPs(), i) } // IsTagged reports if a device is tagged and therefore should not be treated @@ -243,6 +234,7 @@ func (node *Node) RequestTags() []string { } func (node *Node) Prefixes() []netip.Prefix { + //nolint:prealloc var addrs []netip.Prefix for _, nodeAddress := range node.IPs() { ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen()) @@ -272,6 +264,7 @@ func (node *Node) IsExitNode() bool { } func (node *Node) IPsAsString() []string { + //nolint:prealloc var ret []string for _, ip := range node.IPs() { @@ -355,13 +348,9 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { } func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool { - for _, node := range nodes { - if node.NodeKey == nodeKey { - return true - } - } - - return false + return slices.ContainsFunc(nodes, func(node *Node) bool { + return node.NodeKey == nodeKey + }) } func (node *Node) Proto() *v1.Node { @@ -478,7 +467,7 @@ func (node *Node) IsSubnetRouter() bool { return len(node.SubnetRoutes()) > 0 } -// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes +// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes. func (node *Node) AllApprovedRoutes() []netip.Prefix { return append(node.SubnetRoutes(), node.ExitRoutes()...) } @@ -586,13 +575,16 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { } newHostname := strings.ToLower(hostInfo.Hostname) - if err := util.ValidateHostname(newHostname); err != nil { + + err := util.ValidateHostname(newHostname) + if err != nil { log.Warn(). Str("node.id", node.ID.String()). Str("current_hostname", node.Hostname). Str("rejected_hostname", hostInfo.Hostname). Err(err). Msg("Rejecting invalid hostname update from hostinfo") + return } @@ -684,6 +676,7 @@ func (nodes Nodes) IDMap() map[NodeID]*Node { func (nodes Nodes) DebugString() string { var sb strings.Builder sb.WriteString("Nodes:\n") + for _, node := range nodes { sb.WriteString(node.DebugString()) sb.WriteString("\n") @@ -855,7 +848,7 @@ func (nv NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.Peer // GetFQDN returns the fully qualified domain name for the node. 
func (nv NodeView) GetFQDN(baseDomain string) (string, error) { if !nv.Valid() { - return "", errors.New("failed to create valid FQDN: node view is invalid") + return "", fmt.Errorf("failed to create valid FQDN: %w", ErrInvalidNodeView) } return nv.ж.GetFQDN(baseDomain) @@ -934,11 +927,11 @@ func (nv NodeView) TailscaleUserID() tailcfg.UserID { } if nv.IsTagged() { - //nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64 + //nolint:gosec return tailcfg.UserID(int64(TaggedDevices.ID)) } - //nolint:gosec // G115: UserID values are within int64 range + //nolint:gosec return tailcfg.UserID(int64(nv.UserID().Get())) } @@ -1048,7 +1041,7 @@ func (nv NodeView) TailNode( primaryRoutes := primaryRouteFunc(nv.ID()) allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes()) - tsaddr.SortPrefixes(allowedIPs) + slices.SortFunc(allowedIPs, netip.Prefix.Compare) capMap := tailcfg.NodeCapMap{ tailcfg.CapabilityAdmin: []tailcfg.RawMessage{}, @@ -1063,7 +1056,7 @@ func (nv NodeView) TailNode( } tNode := tailcfg.Node{ - //nolint:gosec // G115: NodeID values are within int64 range + //nolint:gosec ID: tailcfg.NodeID(nv.ID()), StableID: nv.ID().StableID(), Name: hostname, diff --git a/hscontrol/types/node_tags_test.go b/hscontrol/types/node_tags_test.go index 72598b3c..97e01b2a 100644 --- a/hscontrol/types/node_tags_test.go +++ b/hscontrol/types/node_tags_test.go @@ -6,7 +6,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/stretchr/testify/assert" "gorm.io/gorm" - "tailscale.com/types/ptr" ) // TestNodeIsTagged tests the IsTagged() method for determining if a node is tagged. @@ -69,7 +68,7 @@ func TestNodeIsTagged(t *testing.T) { { name: "node with user and no tags - not tagged", node: Node{ - UserID: ptr.To(uint(42)), + UserID: new(uint(42)), Tags: []string{}, }, want: false, @@ -112,7 +111,7 @@ func TestNodeViewIsTagged(t *testing.T) { { name: "user-owned node", node: Node{ - UserID: ptr.To(uint(1)), + UserID: new(uint(1)), }, want: false, }, @@ -223,7 +222,7 @@ func TestNodeTagsImmutableAfterRegistration(t *testing.T) { // Test that a user-owned node is not tagged userNode := Node{ ID: 2, - UserID: ptr.To(uint(42)), + UserID: new(uint(42)), Tags: []string{}, RegisterMethod: util.RegisterMethodOIDC, } @@ -243,7 +242,7 @@ func TestNodeOwnershipModel(t *testing.T) { name: "tagged node has tags, UserID is informational", node: Node{ ID: 1, - UserID: ptr.To(uint(5)), // "created by" user 5 + UserID: new(uint(5)), // "created by" user 5 Tags: []string{"tag:server"}, }, wantIsTagged: true, @@ -253,7 +252,7 @@ func TestNodeOwnershipModel(t *testing.T) { name: "user-owned node has no tags", node: Node{ ID: 2, - UserID: ptr.To(uint(5)), + UserID: new(uint(5)), Tags: []string{}, }, wantIsTagged: false, @@ -265,7 +264,7 @@ func TestNodeOwnershipModel(t *testing.T) { name: "node with only authkey tags - not tagged (tags should be copied)", node: Node{ ID: 3, - UserID: ptr.To(uint(5)), // "created by" user 5 + UserID: new(uint(5)), // "created by" user 5 AuthKey: &PreAuthKey{ Tags: []string{"tag:database"}, }, diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 9518833f..5210e363 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -407,7 +407,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan }, want: Node{ GivenName: "valid-hostname", @@ -491,7 +491,7 @@ func 
TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "server-北京-01", + Hostname: "server-北京-01", //nolint:gosmopolitan }, want: Node{ GivenName: "valid-hostname", @@ -505,7 +505,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan }, want: Node{ GivenName: "valid-hostname", @@ -533,7 +533,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) { Hostname: "valid-hostname", }, change: &tailcfg.Hostinfo{ - Hostname: "测试💻机器", + Hostname: "测试💻机器", //nolint:gosmopolitan }, want: Node{ GivenName: "valid-hostname", diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 2ce02f02..18956d7a 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -114,7 +114,7 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey { return &protoKey } -// canUsePreAuthKey checks if a pre auth key can be used. +// Validate checks if a pre auth key can be used. func (pak *PreAuthKey) Validate() error { if pak == nil { return PAKError("invalid authkey") @@ -128,6 +128,7 @@ func (pak *PreAuthKey) Validate() error { if pak.Expiration != nil { return *pak.Expiration } + return time.Time{} }()). Time("now", time.Now()). diff --git a/hscontrol/types/preauth_key_test.go b/hscontrol/types/preauth_key_test.go index 4ab1c717..1b280149 100644 --- a/hscontrol/types/preauth_key_test.go +++ b/hscontrol/types/preauth_key_test.go @@ -110,8 +110,7 @@ func TestCanUsePreAuthKey(t *testing.T) { if err == nil { t.Errorf("expected error but got none") } else { - var httpErr PAKError - ok := errors.As(err, &httpErr) + httpErr, ok := errors.AsType[PAKError](err) if !ok { t.Errorf("expected HTTPError but got %T", err) } else { diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index ec40492b..f1120929 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -4,6 +4,7 @@ import ( "cmp" "database/sql" "encoding/json" + "errors" "fmt" "net/mail" "net/url" @@ -18,6 +19,9 @@ import ( "tailscale.com/tailcfg" ) +// ErrCannotParseBool is returned when a value cannot be parsed as a boolean. +var ErrCannotParseBool = errors.New("could not parse value as boolean") + type UserID uint64 type Users []User @@ -40,9 +44,11 @@ var TaggedDevices = User{ func (u Users) String() string { var sb strings.Builder sb.WriteString("[ ") + for _, user := range u { fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name) } + sb.WriteString(" ]") return sb.String() @@ -89,11 +95,12 @@ func (u *User) StringID() string { if u == nil { return "" } + return strconv.FormatUint(uint64(u.ID), 10) } // TypedID returns a pointer to the user's ID as a UserID type. -// This is a convenience method to avoid ugly casting like ptr.To(types.UserID(user.ID)). +// This is a convenience method to avoid ugly casting like new(types.UserID(user.ID)). func (u *User) TypedID() *UserID { uid := UserID(u.ID) return &uid @@ -148,7 +155,7 @@ func (u UserView) ID() uint { func (u *User) TailscaleLogin() tailcfg.Login { return tailcfg.Login{ - ID: tailcfg.LoginID(u.ID), + ID: tailcfg.LoginID(u.ID), //nolint:gosec Provider: u.Provider, LoginName: u.Username(), DisplayName: u.Display(), @@ -194,8 +201,8 @@ func (u *User) Proto() *v1.User { } } -// JumpCloud returns a JSON where email_verified is returned as a -// string "true" or "false" instead of a boolean. 
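The rewritten doc comment below describes the FlexibleBoolean type that follows it. A small usage sketch of that behaviour, with a made-up JumpCloud-style payload (not part of the patch):

    // FlexibleBoolean accepts both `"email_verified": true` and
    // `"email_verified": "true"`, as emitted by some identity providers.
    var claims struct {
        EmailVerified FlexibleBoolean `json:"email_verified"`
    }
    _ = json.Unmarshal([]byte(`{"email_verified": "true"}`), &claims)
    // bool(claims.EmailVerified) == true; a plain boolean decodes the same way.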
+// FlexibleBoolean handles JSON where email_verified is returned as a +// string "true" or "false" instead of a boolean (e.g., JumpCloud). // This maps bool to a specific type with a custom unmarshaler to // ensure we can decode it from a string. // https://github.com/juanfont/headscale/issues/2293 @@ -203,6 +210,7 @@ type FlexibleBoolean bool func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { var val any + err := json.Unmarshal(data, &val) if err != nil { return fmt.Errorf("could not unmarshal data: %w", err) @@ -216,10 +224,11 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { if err != nil { return fmt.Errorf("could not parse %s as boolean: %w", v, err) } + *bit = FlexibleBoolean(pv) default: - return fmt.Errorf("could not parse %v as boolean", v) + return fmt.Errorf("%w: %v", ErrCannotParseBool, v) } return nil @@ -253,9 +262,11 @@ func (c *OIDCClaims) Identifier() string { if c.Iss == "" && c.Sub == "" { return "" } + if c.Iss == "" { return CleanIdentifier(c.Sub) } + if c.Sub == "" { return CleanIdentifier(c.Iss) } @@ -266,8 +277,10 @@ func (c *OIDCClaims) Identifier() string { var result string // Try to parse as URL to handle URL joining correctly + //nolint:noinlineerr if u, err := url.Parse(issuer); err == nil && u.Scheme != "" { // For URLs, use proper URL path joining + //nolint:noinlineerr if joined, err := url.JoinPath(issuer, subject); err == nil { result = joined } @@ -340,6 +353,7 @@ func CleanIdentifier(identifier string) string { cleanParts = append(cleanParts, trimmed) } } + if len(cleanParts) == 0 { return "" } @@ -382,6 +396,7 @@ func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) { if claims.Iss == "" && !strings.HasPrefix(identifier, "/") { identifier = "/" + identifier } + u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true} u.DisplayName = claims.Name u.ProfilePicURL = claims.ProfilePictureURL diff --git a/hscontrol/types/users_test.go b/hscontrol/types/users_test.go index 15386553..064388eb 100644 --- a/hscontrol/types/users_test.go +++ b/hscontrol/types/users_test.go @@ -66,10 +66,13 @@ func TestUnmarshallOIDCClaims(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got OIDCClaims - if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil { + + err := json.Unmarshal([]byte(tt.jsonstr), &got) + if err != nil { t.Errorf("UnmarshallOIDCClaims() error = %v", err) return } + if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff) } @@ -190,6 +193,7 @@ func TestOIDCClaimsIdentifier(t *testing.T) { } result := claims.Identifier() assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { t.Errorf("Identifier() mismatch (-want +got):\n%s", diff) } @@ -282,6 +286,7 @@ func TestCleanIdentifier(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := CleanIdentifier(tt.identifier) assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff) } @@ -479,7 +484,9 @@ func TestOIDCClaimsJSONToUser(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var got OIDCClaims - if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil { + + err := json.Unmarshal([]byte(tt.jsonstr), &got) + if err != nil { t.Errorf("TestOIDCClaimsJSONToUser() error = %v", err) return } @@ -487,6 +494,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) { var user User 
user.FromClaim(&got, tt.emailVerifiedRequired) + if diff := cmp.Diff(user, tt.want); diff != "" { t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff) } diff --git a/hscontrol/types/version.go b/hscontrol/types/version.go index 6676c92f..96dc58a6 100644 --- a/hscontrol/types/version.go +++ b/hscontrol/types/version.go @@ -38,9 +38,7 @@ func (v *VersionInfo) String() string { return sb.String() } -var buildInfo = sync.OnceValues(func() (*debug.BuildInfo, bool) { - return debug.ReadBuildInfo() -}) +var buildInfo = sync.OnceValues(debug.ReadBuildInfo) var GetVersionInfo = sync.OnceValue(func() *VersionInfo { info := &VersionInfo{ diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index dcd58528..8ec40790 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -20,12 +20,30 @@ const ( // value related to RFC 1123 and 952. LabelHostnameLength = 63 + + // minNameLength is the minimum length for usernames and hostnames. + minNameLength = 2 ) var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") var ErrInvalidHostName = errors.New("invalid hostname") +// Sentinel errors for username validation. +var ( + ErrUsernameTooShort = errors.New("username must be at least 2 characters long") + ErrUsernameMustStartLetter = errors.New("username must start with a letter") + ErrUsernameTooManyAt = errors.New("username cannot contain more than one '@'") + ErrUsernameInvalidChar = errors.New("username contains invalid character") +) + +// Sentinel errors for hostname validation. +var ( + ErrHostnameTooShort = errors.New("hostname too short, must be at least 2 characters") + ErrHostnameHyphenEnds = errors.New("hostname cannot start or end with a hyphen") + ErrHostnameDotEnds = errors.New("hostname cannot start or end with a dot") +) + // ValidateUsername checks if a username is valid. // It must be at least 2 characters long, start with a letter, and contain // only letters, numbers, hyphens, dots, and underscores. @@ -33,13 +51,13 @@ var ErrInvalidHostName = errors.New("invalid hostname") // It cannot contain invalid characters. func ValidateUsername(username string) error { // Ensure the username meets the minimum length requirement - if len(username) < 2 { - return errors.New("username must be at least 2 characters long") + if len(username) < minNameLength { + return ErrUsernameTooShort } // Ensure the username starts with a letter if !unicode.IsLetter(rune(username[0])) { - return errors.New("username must start with a letter") + return ErrUsernameMustStartLetter } atCount := 0 @@ -55,10 +73,10 @@ func ValidateUsername(username string) error { case char == '@': atCount++ if atCount > 1 { - return errors.New("username cannot contain more than one '@'") + return ErrUsernameTooManyAt } default: - return fmt.Errorf("username contains invalid character: '%c'", char) + return fmt.Errorf("%w: '%c'", ErrUsernameInvalidChar, char) } } @@ -69,11 +87,8 @@ func ValidateUsername(username string) error { // This function does NOT modify the input - it only validates. // The hostname must already be lowercase and contain only valid characters. 
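Replacing the ad-hoc errors.New calls with the sentinel errors above lets callers branch on the failure reason instead of matching error strings. A sketch of the intended call-site pattern (hypothetical caller, same package):

    err := ValidateUsername(name)
    switch {
    case errors.Is(err, ErrUsernameTooShort):
        // ask for a longer name
    case errors.Is(err, ErrUsernameInvalidChar):
        // the wrapped error still carries the offending character
    case err != nil:
        // other validation failures
    }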
func ValidateHostname(name string) error { - if len(name) < 2 { - return fmt.Errorf( - "hostname %q is too short, must be at least 2 characters", - name, - ) + if len(name) < minNameLength { + return fmt.Errorf("%w: %q", ErrHostnameTooShort, name) } if len(name) > LabelHostnameLength { return fmt.Errorf( @@ -90,17 +105,11 @@ func ValidateHostname(name string) error { } if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") { - return fmt.Errorf( - "hostname %q cannot start or end with a hyphen", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameHyphenEnds, name) } if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") { - return fmt.Errorf( - "hostname %q cannot start or end with a dot", - name, - ) + return fmt.Errorf("%w: %q", ErrHostnameDotEnds, name) } if invalidDNSRegex.MatchString(name) { diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go index b492e4d6..4f9a338f 100644 --- a/hscontrol/util/dns_test.go +++ b/hscontrol/util/dns_test.go @@ -90,6 +90,7 @@ func TestNormaliseHostname(t *testing.T) { t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr) return } + if !tt.wantErr && got != tt.want { t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want) } @@ -172,6 +173,7 @@ func TestValidateHostname(t *testing.T) { t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.wantErr && tt.errorContains != "" { if err == nil || !strings.Contains(err.Error(), tt.errorContains) { t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains) diff --git a/hscontrol/util/prompt.go b/hscontrol/util/prompt.go index 098f1979..5f0adede 100644 --- a/hscontrol/util/prompt.go +++ b/hscontrol/util/prompt.go @@ -14,11 +14,14 @@ func YesNo(msg string) bool { fmt.Fprint(os.Stderr, msg+" [y/n] ") var resp string - fmt.Scanln(&resp) + + _, _ = fmt.Scanln(&resp) + resp = strings.ToLower(resp) switch resp { case "y", "yes", "sure": return true } + return false } diff --git a/hscontrol/util/prompt_test.go b/hscontrol/util/prompt_test.go index d726ec60..ac405f8c 100644 --- a/hscontrol/util/prompt_test.go +++ b/hscontrol/util/prompt_test.go @@ -86,7 +86,8 @@ func TestYesNo(t *testing.T) { // Write test input go func() { defer w.Close() - w.WriteString(tt.input) + + _, _ = w.WriteString(tt.input) }() // Call the function @@ -95,6 +96,7 @@ func TestYesNo(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Check the result @@ -104,10 +106,12 @@ func TestYesNo(t *testing.T) { // Check that the prompt was written to stderr var stderrBuf bytes.Buffer - io.Copy(&stderrBuf, stderrR) + + _, _ = io.Copy(&stderrBuf, stderrR) stderrR.Close() expectedPrompt := "Test question [y/n] " + actualPrompt := stderrBuf.String() if actualPrompt != expectedPrompt { t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) @@ -130,7 +134,8 @@ func TestYesNoPromptMessage(t *testing.T) { // Write test input go func() { defer w.Close() - w.WriteString("n\n") + + _, _ = w.WriteString("n\n") }() // Call the function with a custom message @@ -140,14 +145,17 @@ func TestYesNoPromptMessage(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Check that the custom message was included in the prompt var stderrBuf bytes.Buffer - io.Copy(&stderrBuf, stderrR) + + _, _ = io.Copy(&stderrBuf, stderrR) stderrR.Close() expectedPrompt := customMessage + " [y/n] " + actualPrompt := stderrBuf.String() if actualPrompt != 
expectedPrompt { t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) @@ -186,7 +194,8 @@ func TestYesNoCaseInsensitive(t *testing.T) { // Write test input go func() { defer w.Close() - w.WriteString(tc.input) + + _, _ = w.WriteString(tc.input) }() // Call the function @@ -195,10 +204,11 @@ func TestYesNoCaseInsensitive(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Drain stderr - io.Copy(io.Discard, stderrR) + _, _ = io.Copy(io.Discard, stderrR) stderrR.Close() if result != tc.expected { diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go index d1d7ece7..60f99420 100644 --- a/hscontrol/util/string.go +++ b/hscontrol/util/string.go @@ -9,6 +9,9 @@ import ( "tailscale.com/tailcfg" ) +// invalidStringRandomLength is the length of random bytes for invalid string generation. +const invalidStringRandomLength = 8 + // GenerateRandomBytes returns securely generated random bytes. // It will return an error if the system's secure random // number generator fails to function correctly, in which @@ -33,6 +36,7 @@ func GenerateRandomStringURLSafe(n int) (string, error) { b, err := GenerateRandomBytes(n) uenc := base64.RawURLEncoding.EncodeToString(b) + return uenc[:n], err } @@ -67,7 +71,7 @@ func MustGenerateRandomStringDNSSafe(size int) string { } func InvalidString() string { - hash, _ := GenerateRandomStringDNSSafe(8) + hash, _ := GenerateRandomStringDNSSafe(invalidStringRandomLength) return "invalid-" + hash } @@ -99,6 +103,7 @@ func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string { DstIPs: %v } `, rule.SrcIPs, rule.DstPorts)) + if index < len(rules)-1 { sb.WriteString(", ") } diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go index 4d828d02..7fa4b222 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -16,6 +16,27 @@ import ( "tailscale.com/util/cmpver" ) +// Sentinel errors for URL parsing. +var ( + ErrMultipleURLsFound = errors.New("multiple URLs found") + ErrNoURLFound = errors.New("no URL found") +) + +// Sentinel errors for traceroute parsing. +var ( + ErrTracerouteEmpty = errors.New("empty traceroute output") + ErrTracerouteHeader = errors.New("parsing traceroute header") + ErrTracerouteNotReached = errors.New("traceroute did not reach target") +) + +// Regex match group constants for traceroute parsing. +// The regexes capture hostname (group 1) and IP (group 2), plus the full match (group 0). +const ( + hostIPRegexGroups = 3 + nodeKeyPrefixLen = 8 + minTracerouteHeaderMatch = 2 // full match + hostname +) + func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { if cmpver.Compare(minimum, toCheck) <= 0 || toCheck == "unstable" || @@ -30,20 +51,22 @@ func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { // It returns an error if not exactly one URL is found. 
func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { lines := strings.Split(output, "\n") + var urlStr string for _, line := range lines { line = strings.TrimSpace(line) if strings.HasPrefix(line, "http://") || strings.HasPrefix(line, "https://") { if urlStr != "" { - return nil, fmt.Errorf("multiple URLs found: %s and %s", urlStr, line) + return nil, fmt.Errorf("%w: %s and %s", ErrMultipleURLsFound, urlStr, line) } + urlStr = line } } if urlStr == "" { - return nil, errors.New("no URL found") + return nil, ErrNoURLFound } loginURL, err := url.Parse(urlStr) @@ -89,14 +112,15 @@ type Traceroute struct { func ParseTraceroute(output string) (Traceroute, error) { lines := strings.Split(strings.TrimSpace(output), "\n") if len(lines) < 1 { - return Traceroute{}, errors.New("empty traceroute output") + return Traceroute{}, ErrTracerouteEmpty } // Parse the header line - handle both 'traceroute' and 'tracert' (Windows) headerRegex := regexp.MustCompile(`(?i)(?:traceroute|tracing route) to ([^ ]+) (?:\[([^\]]+)\]|\(([^)]+)\))`) + headerMatches := headerRegex.FindStringSubmatch(lines[0]) - if len(headerMatches) < 2 { - return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0]) + if len(headerMatches) < minTracerouteHeaderMatch { + return Traceroute{}, fmt.Errorf("%w: %s", ErrTracerouteHeader, lines[0]) } hostname := headerMatches[1] @@ -105,6 +129,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if ipStr == "" { ipStr = headerMatches[3] } + ip, err := netip.ParseAddr(ipStr) if err != nil { return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err) @@ -144,18 +169,23 @@ func ParseTraceroute(output string) (Traceroute, error) { } remainder := strings.TrimSpace(matches[2]) - var hopHostname string - var hopIP netip.Addr - var latencies []time.Duration + + var ( + hopHostname string + hopIP netip.Addr + latencies []time.Duration + ) // Check for Windows tracert format which has latencies before hostname // Format: " 1 <1 ms <1 ms <1 ms router.local [192.168.1.1]" latencyFirst := false + if strings.Contains(remainder, " ms ") && !strings.HasPrefix(remainder, "*") { // Check if latencies appear before any hostname/IP firstSpace := strings.Index(remainder, " ") if firstSpace > 0 { firstPart := remainder[:firstSpace] + //nolint:noinlineerr if _, err := strconv.ParseFloat(strings.TrimPrefix(firstPart, "<"), 64); err == nil { latencyFirst = true } @@ -171,12 +201,14 @@ func ParseTraceroute(output string) (Traceroute, error) { } // Extract and remove the latency from the beginning latStr := strings.TrimPrefix(remainder[latMatch[2]:latMatch[3]], "<") + ms, err := strconv.ParseFloat(latStr, 64) if err == nil { // Round to nearest microsecond to avoid floating point precision issues duration := time.Duration(ms * float64(time.Millisecond)) latencies = append(latencies, duration.Round(time.Microsecond)) } + remainder = strings.TrimSpace(remainder[latMatch[1]:]) } } @@ -187,12 +219,12 @@ func ParseTraceroute(output string) (Traceroute, error) { hopHostname = "*" // Skip any remaining asterisks remainder = strings.TrimLeft(remainder, "* ") - } else if hostMatch := hostIPRegex.FindStringSubmatch(remainder); len(hostMatch) >= 3 { + } else if hostMatch := hostIPRegex.FindStringSubmatch(remainder); len(hostMatch) >= hostIPRegexGroups { // Format: hostname (IP) hopHostname = hostMatch[1] hopIP, _ = netip.ParseAddr(hostMatch[2]) remainder = strings.TrimSpace(remainder[len(hostMatch[0]):]) - } else if hostMatch := hostIPBracketRegex.FindStringSubmatch(remainder); 
len(hostMatch) >= 3 { + } else if hostMatch := hostIPBracketRegex.FindStringSubmatch(remainder); len(hostMatch) >= hostIPRegexGroups { // Format: hostname [IP] (Windows) hopHostname = hostMatch[1] hopIP, _ = netip.ParseAddr(hostMatch[2]) @@ -202,9 +234,11 @@ func ParseTraceroute(output string) (Traceroute, error) { parts := strings.Fields(remainder) if len(parts) > 0 { hopHostname = parts[0] + //nolint:noinlineerr if ip, err := netip.ParseAddr(parts[0]); err == nil { hopIP = ip } + remainder = strings.TrimSpace(strings.Join(parts[1:], " ")) } } @@ -216,6 +250,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if len(match) > 1 { // Remove '<' prefix if present (e.g., "<1 ms") latStr := strings.TrimPrefix(match[1], "<") + ms, err := strconv.ParseFloat(latStr, 64) if err == nil { // Round to nearest microsecond to avoid floating point precision issues @@ -243,7 +278,7 @@ func ParseTraceroute(output string) (Traceroute, error) { // If we didn't reach the target, it's unsuccessful if !result.Success { - result.Err = errors.New("traceroute did not reach target") + result.Err = ErrTracerouteNotReached } return result, nil @@ -261,11 +296,10 @@ func IsCI() bool { return false } -// SafeHostname extracts a hostname from Hostinfo, providing sensible defaults +// EnsureHostname extracts a hostname from Hostinfo, providing sensible defaults // if Hostinfo is nil or Hostname is empty. This prevents nil pointer dereferences // and ensures nodes always have a valid hostname. // The hostname is truncated to 63 characters to comply with DNS label length limits (RFC 1123). -// EnsureHostname guarantees a valid hostname for node registration. // This function never fails - it always returns a valid hostname. // // Strategy: @@ -280,15 +314,19 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri if key == "" { return "unknown-node" } + keyPrefix := key - if len(key) > 8 { - keyPrefix = key[:8] + if len(key) > nodeKeyPrefixLen { + keyPrefix = key[:nodeKeyPrefixLen] } - return fmt.Sprintf("node-%s", keyPrefix) + + return "node-" + keyPrefix } lowercased := strings.ToLower(hostinfo.Hostname) - if err := ValidateHostname(lowercased); err == nil { + + err := ValidateHostname(lowercased) + if err == nil { return lowercased } @@ -300,7 +338,7 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri // it's purely for observability and correlating log entries during the registration process. 
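To make the fallback order of EnsureHostname above concrete, a few input/output pairs matching the cases exercised in util_test.go further down (sketch only):

    _ = EnsureHostname(&tailcfg.Hostinfo{Hostname: "Test-Node"}, "mkey12345678", "nkey12345678") // "test-node" (lowercased, valid)
    _ = EnsureHostname(nil, "mkey12345678", "nkey12345678")                                      // "node-nkey1234" (node key prefix)
    _ = EnsureHostname(nil, "", "")                                                              // "unknown-node" (no keys at all)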
func GenerateRegistrationKey() (string, error) { const ( - registerKeyPrefix = "hskey-reg-" //nolint:gosec // This is a vanity key for logging, not a credential + registerKeyPrefix = "hskey-reg-" //nolint:gosec registerKeyLength = 64 ) diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 33f27b7a..98692882 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -1,7 +1,6 @@ package util import ( - "errors" "net/netip" "strings" "testing" @@ -180,6 +179,7 @@ Success.`, if err != nil { t.Errorf("ParseLoginURLFromCLILogin() error = %v, wantErr %v", err, tt.wantErr) } + if gotURL.String() != tt.wantURL { t.Errorf("ParseLoginURLFromCLILogin() = %v, want %v", gotURL, tt.wantURL) } @@ -321,7 +321,7 @@ func TestParseTraceroute(t *testing.T) { }, }, Success: false, - Err: errors.New("traceroute did not reach target"), + Err: ErrTracerouteNotReached, }, wantErr: false, }, @@ -489,7 +489,7 @@ over a maximum of 30 hops: }, }, Success: false, - Err: errors.New("traceroute did not reach target"), + Err: ErrTracerouteNotReached, }, wantErr: false, }, @@ -902,7 +902,7 @@ func TestEnsureHostname(t *testing.T) { { name: "hostname_with_unicode", hostinfo: &tailcfg.Hostinfo{ - Hostname: "node-ñoño-测试", + Hostname: "node-ñoño-测试", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -983,7 +983,7 @@ func TestEnsureHostname(t *testing.T) { { name: "chinese_chars_with_dash_invalid", hostinfo: &tailcfg.Hostinfo{ - Hostname: "server-北京-01", + Hostname: "server-北京-01", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -992,7 +992,7 @@ func TestEnsureHostname(t *testing.T) { { name: "chinese_only_invalid", hostinfo: &tailcfg.Hostinfo{ - Hostname: "我的电脑", + Hostname: "我的电脑", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -1010,7 +1010,7 @@ func TestEnsureHostname(t *testing.T) { { name: "mixed_chinese_emoji_invalid", hostinfo: &tailcfg.Hostinfo{ - Hostname: "测试💻机器", + Hostname: "测试💻机器", //nolint:gosmopolitan }, machineKey: "mkey12345678", nodeKey: "nkey12345678", @@ -1066,6 +1066,7 @@ func TestEnsureHostname(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.want, "invalid-") { @@ -1100,12 +1101,16 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { nodeKey: "nkey12345678", wantHostname: "test-node", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + t.Helper() + if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } + if hi.Hostname != "test-node" { t.Errorf("hostname = %v, want test-node", hi.Hostname) } + if hi.OS != "linux" { t.Errorf("OS = %v, want linux", hi.OS) } @@ -1144,9 +1149,12 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { nodeKey: "nkey12345678", wantHostname: "node-nkey1234", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + t.Helper() + if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } + if hi.Hostname != "node-nkey1234" { t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname) } @@ -1159,9 +1167,13 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { nodeKey: "", wantHostname: "unknown-node", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + t.Helper() + if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo 
should not be nil") } + + //nolint:goconst if hi.Hostname != "unknown-node" { t.Errorf("hostname = %v, want unknown-node", hi.Hostname) } @@ -1177,8 +1189,9 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { wantHostname: "unknown-node", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } + if hi.Hostname != "unknown-node" { t.Errorf("hostname = %v, want unknown-node", hi.Hostname) } @@ -1198,20 +1211,25 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { wantHostname: "test", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } + if hi.Hostname != "test" { t.Errorf("hostname = %v, want test", hi.Hostname) } + if hi.OS != "windows" { t.Errorf("OS = %v, want windows", hi.OS) } + if hi.OSVersion != "10.0.19044" { t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion) } + if hi.DeviceModel != "test-device" { t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel) } + if hi.BackendLogID != "log123" { t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID) } @@ -1227,8 +1245,9 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { wantHostname: "123456789012345678901234567890123456789012345678901234567890123", checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { if hi == nil { - t.Error("hostinfo should not be nil") + t.Fatal("hostinfo should not be nil") } + if len(hi.Hostname) != 63 { t.Errorf("hostname length = %v, want 63", len(hi.Hostname)) } @@ -1239,6 +1258,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.wantHostname, "invalid-") { @@ -1264,7 +1284,10 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) { for i, hostname := range testCases { t.Run(cmp.Diff("", ""), func(t *testing.T) { + t.Parallel() + hostinfo := &tailcfg.Hostinfo{Hostname: hostname} + result := EnsureHostname(hostinfo, "mkey", "nkey") if len(result) > 63 { t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) diff --git a/integration/acl_test.go b/integration/acl_test.go index c746f900..e87e0587 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -20,7 +20,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) var veryLargeDestination = []policyv2.AliasWithPorts{ @@ -1284,9 +1283,9 @@ func TestACLAutogroupMember(t *testing.T) { ACLs: []policyv2.ACL{ { Action: "accept", - Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Sources: []policyv2.Alias{new(policyv2.AutoGroupMember)}, Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(ptr.To(policyv2.AutoGroupMember), tailcfg.PortRangeAny), + aliasWithPorts(new(policyv2.AutoGroupMember), tailcfg.PortRangeAny), }, }, }, @@ -1372,9 +1371,9 @@ func TestACLAutogroupTagged(t *testing.T) { ACLs: []policyv2.ACL{ { Action: "accept", - Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupTagged)}, + Sources: []policyv2.Alias{new(policyv2.AutoGroupTagged)}, Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(ptr.To(policyv2.AutoGroupTagged), tailcfg.PortRangeAny), + aliasWithPorts(new(policyv2.AutoGroupTagged), tailcfg.PortRangeAny), }, }, }, @@ -1657,9 
+1656,9 @@ func TestACLAutogroupSelf(t *testing.T) { ACLs: []policyv2.ACL{ { Action: "accept", - Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Sources: []policyv2.Alias{new(policyv2.AutoGroupMember)}, Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), + aliasWithPorts(new(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), }, }, { @@ -1877,7 +1876,7 @@ func TestACLAutogroupSelf(t *testing.T) { result, err := client.Curl(url) assert.Empty(t, result, "user1 should not be able to access user2's regular devices (autogroup:self isolation)") - assert.Error(t, err, "connection from user1 to user2 regular device should fail") + require.Error(t, err, "connection from user1 to user2 regular device should fail") } } @@ -1896,6 +1895,7 @@ func TestACLAutogroupSelf(t *testing.T) { } } +//nolint:gocyclo func TestACLPolicyPropagationOverTime(t *testing.T) { IntegrationSkip(t) @@ -1956,9 +1956,9 @@ func TestACLPolicyPropagationOverTime(t *testing.T) { ACLs: []policyv2.ACL{ { Action: "accept", - Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Sources: []policyv2.Alias{new(policyv2.AutoGroupMember)}, Destinations: []policyv2.AliasWithPorts{ - aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), + aliasWithPorts(new(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), }, }, }, diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index df5f2455..f0218592 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -1,6 +1,7 @@ package integration import ( + "context" "crypto/tls" "encoding/json" "fmt" @@ -35,6 +36,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -46,6 +48,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Create an API key using the CLI var validAPIKey string + assert.EventuallyWithT(t, func(ct *assert.CollectT) { apiKeyOutput, err := headscale.Execute( []string{ @@ -63,7 +66,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Get the API endpoint endpoint := headscale.GetEndpoint() - apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + apiURL := endpoint + "/api/v1/user" // Create HTTP client client := &http.Client{ @@ -76,11 +79,12 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_NoAuthHeader", func(t *testing.T) { // Test 1: Request without any Authorization header // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -99,6 +103,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Should NOT contain user data after "Unauthorized" // This is the security bypass - if users array is present, auth was bypassed var jsonCheck map[string]any + jsonErr := json.Unmarshal(body, &jsonCheck) // If we can unmarshal JSON and it contains "users", that's the bypass @@ -126,12 +131,13 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_InvalidAuthHeader", func(t *testing.T) { // Test 2: Request with invalid Authorization header (missing "Bearer " prefix) // Expected: Should return 401 with ONLY "Unauthorized" text, no user data - req, err := http.NewRequest("GET", apiURL, nil) + req, err := 
http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "InvalidToken") resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -159,12 +165,13 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Test 3: Request with Bearer prefix but invalid token // Expected: Should return 401 with ONLY "Unauthorized" text, no user data // Note: Both malformed and properly formatted invalid tokens should return 401 - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) req.Header.Set("Authorization", "Bearer invalid-token-12345") resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -191,12 +198,13 @@ func TestAPIAuthenticationBypass(t *testing.T) { t.Run("HTTP_ValidAPIKey", func(t *testing.T) { // Test 4: Request with valid API key // Expected: Should return 200 with user data (this is the authorized case) - req, err := http.NewRequest("GET", apiURL, nil) + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, apiURL, nil) require.NoError(t, err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", validAPIKey)) + req.Header.Set("Authorization", "Bearer "+validAPIKey) resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -208,16 +216,19 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Should be able to parse as protobuf JSON var response v1.ListUsersResponse + err = protojson.Unmarshal(body, &response) - assert.NoError(t, err, "Response should be valid protobuf JSON with valid API key") + require.NoError(t, err, "Response should be valid protobuf JSON with valid API key") // Should contain our test users users := response.GetUsers() assert.Len(t, users, 3, "Should have 3 users") + userNames := make([]string, len(users)) for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "user1") assert.Contains(t, userNames, "user2") assert.Contains(t, userNames, "user3") @@ -234,6 +245,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -254,10 +266,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) endpoint := headscale.GetEndpoint() - apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + apiURL := endpoint + "/api/v1/user" t.Run("Curl_NoAuth", func(t *testing.T) { // Execute curl from inside the headscale container without auth @@ -274,17 +287,24 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { // Parse the output lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody string + + var ( + httpCode string + responseBody string + ) + + var responseBodySb295 strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBodySb295.WriteString(line) } } + responseBody += responseBodySb295.String() + // Should return 401 assert.Equal(t, "401", httpCode, "Curl without auth should return 401") @@ -320,17 +340,24 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { require.NoError(t, err) lines := strings.Split(curlOutput, "\n") - var httpCode string - var 
responseBody string + + var ( + httpCode string + responseBody string + ) + + var responseBodySb344 strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBodySb344.WriteString(line) } } + responseBody += responseBodySb344.String() + assert.Equal(t, "401", httpCode) assert.Contains(t, responseBody, "Unauthorized") assert.NotContains(t, responseBody, "testuser1", @@ -346,7 +373,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { "curl", "-s", "-H", - fmt.Sprintf("Authorization: Bearer %s", validAPIKey), + "Authorization: Bearer " + validAPIKey, "-w", "\nHTTP_CODE:%{http_code}", apiURL, @@ -355,14 +382,17 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { require.NoError(t, err) lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody string + + var ( + httpCode string + responseBody strings.Builder + ) for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBody.WriteString(line) } } @@ -372,8 +402,10 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { // Should contain user data var response v1.ListUsersResponse - err = protojson.Unmarshal([]byte(responseBody), &response) - assert.NoError(t, err, "Response should be valid protobuf JSON") + + err = protojson.Unmarshal([]byte(responseBody.String()), &response) + require.NoError(t, err, "Response should be valid protobuf JSON") + users := response.GetUsers() assert.Len(t, users, 2, "Should have 2 users") }) @@ -391,6 +423,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -420,11 +453,12 @@ func TestGRPCAuthenticationBypass(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) // Get the gRPC endpoint // For gRPC, we need to use the hostname and port 50443 - grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + grpcAddress := headscale.GetHostname() + ":50443" t.Run("gRPC_NoAPIKey", func(t *testing.T) { // Test 1: Try to use CLI without API key (should fail) @@ -452,7 +486,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { ) // Should fail with authentication error - assert.Error(t, err, + require.Error(t, err, "gRPC connection with invalid API key should fail") // Should contain authentication error message @@ -481,20 +515,22 @@ func TestGRPCAuthenticationBypass(t *testing.T) { ) // Should succeed - assert.NoError(t, err, + require.NoError(t, err, "gRPC connection with valid API key should succeed, output: %s", output) // CLI outputs the users array directly, not wrapped in ListUsersResponse // Parse as JSON array (CLI uses json.Marshal, not protojson) var users []*v1.User + err = json.Unmarshal([]byte(output), &users) - assert.NoError(t, err, "Response should be valid JSON array") + require.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") userNames := make([]string, len(users)) for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "grpcuser1") assert.Contains(t, userNames, "grpcuser2") }) @@ -513,6 +549,7 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -540,9 +577,10 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) { }, ) 
require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) - grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + grpcAddress := headscale.GetHostname() + ":50443" // Create a config file for testing configWithoutKey := fmt.Sprintf(` @@ -602,7 +640,7 @@ cli: ) // Should fail - assert.Error(t, err, + require.Error(t, err, "CLI with invalid API key should fail") // Should indicate authentication failure @@ -637,20 +675,22 @@ cli: ) // Should succeed - assert.NoError(t, err, + require.NoError(t, err, "CLI with valid API key should succeed") // CLI outputs the users array directly, not wrapped in ListUsersResponse // Parse as JSON array (CLI uses json.Marshal, not protojson) var users []*v1.User + err = json.Unmarshal([]byte(output), &users) - assert.NoError(t, err, "Response should be valid JSON array") + require.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") userNames := make([]string, len(users)) for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "cliuser1") assert.Contains(t, userNames, "cliuser2") }) diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go index ba6a195b..abf31fec 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -17,7 +17,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { @@ -31,6 +30,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -64,23 +64,29 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 120*time.Second) // Validate that all nodes have NetInfo and DERP servers before logout - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout") // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { ips, err := client.IPs() if err != nil { t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) } + clientIPs[client] = ips } - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) @@ -111,6 +117,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after logout") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -126,7 +133,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 // https://github.com/juanfont/headscale/issues/2164 if !https { - //nolint:forbidigo // 
Intentional delay: Tailscale client requires 5 min wait before reconnecting over non-HTTPS + //nolint:forbidigo time.Sleep(5 * time.Minute) } @@ -148,6 +155,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after relogin") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -164,7 +172,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { requireNoErrSync(t, err) // Validate that all nodes have NetInfo and DERP servers after reconnection - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection") err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) @@ -201,6 +209,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, nodeCountBeforeLogout) @@ -253,12 +262,16 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { // Validate initial connection state requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login") + + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) - var listNodes []*v1.Node - var nodeCountBeforeLogout int assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) @@ -301,9 +314,11 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user1Nodes, err = headscale.ListNodes("user1") assert.NoError(ct, err, "Failed to list nodes for user1 after relogin") assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes)) @@ -317,27 +332,30 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { // Validate connection state after relogin as user1 requireAllClientsOnline(t, headscale, expectedUser1Nodes, true, "all user1 nodes should be connected after relogin", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, headscale, expectedUser1Nodes, "all user1 nodes should have NetInfo and DERP after relogin", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedUser1Nodes, "all user1 nodes should have NetInfo and DERP after relogin") // Validate that user2 still has their original nodes after user1's re-authentication // When nodes re-authenticate with a different user's pre-auth key, NEW nodes are created // for the new user. The original nodes remain with the original user. 
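The checks around this point all follow the same testify polling shape: retry until the control server reports the expected per-user node count, collecting failures via *assert.CollectT instead of failing immediately. A minimal self-contained sketch of that pattern (the helper name is illustrative, not part of this patch):

package integration

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"

	v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
)

// waitForUserNodeCount is an illustrative helper: it polls ListNodes for a
// single user until the expected number of nodes is registered, using
// assert.EventuallyWithT so intermediate failures are retried rather than fatal.
func waitForUserNodeCount(t *testing.T, headscale ControlServer, user string, want int) []*v1.Node {
	t.Helper()

	var nodes []*v1.Node

	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
		var err error

		nodes, err = headscale.ListNodes(user)
		assert.NoError(ct, err, "failed to list nodes for %s", user)
		assert.Len(ct, nodes, want, "expected %d nodes for %s, got %d", want, user, len(nodes))
	}, 30*time.Second, 2*time.Second, "waiting for %s to have %d registered nodes", user, want)

	return nodes
}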
var user2Nodes []*v1.Node + t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user2Nodes, err = headscale.ListNodes("user2") assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin") assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes)) }, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)") t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) - }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after auth key user switch", client.Hostname())) + }, 30*time.Second, 2*time.Second, "validating %s is logged in as user1 after auth key user switch", client.Hostname()) } } @@ -352,6 +370,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -377,11 +396,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { ips, err := client.IPs() if err != nil { t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) } + clientIPs[client] = ips } @@ -393,12 +414,16 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { // Validate initial connection state requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login") + + var ( listNodes []*v1.Node nodeCountBeforeLogout int ) - var listNodes []*v1.Node - var nodeCountBeforeLogout int assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) @@ -428,7 +453,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 // https://github.com/juanfont/headscale/issues/2164 if !https { - //nolint:forbidigo // Intentional delay: Tailscale client requires 5 min wait before reconnecting over non-HTTPS + //nolint:forbidigo time.Sleep(5 * time.Minute) } @@ -608,7 +633,7 @@ func TestAuthKeyLogoutAndReloginRoutesPreserved(t *testing.T) { }, AutoApprovers: policyv2.AutoApproverPolicy{ Routes: map[netip.Prefix]policyv2.AutoApprovers{ - netip.MustParsePrefix(advertiseRoute): {ptr.To(policyv2.Username(user + "@test.no"))}, + netip.MustParsePrefix(advertiseRoute): {func() *policyv2.Username { u := policyv2.Username(user + "@test.no"); return &u }()}, }, }, }, diff --git a/integration/auth_oidc_test.go 
b/integration/auth_oidc_test.go index 359dd456..076f6565 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -1,15 +1,16 @@ package integration import ( + "cmp" "maps" "net/netip" "net/url" - "sort" + "slices" "strconv" "testing" "time" - "github.com/google/go-cmp/cmp" + gocmp "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" @@ -111,11 +112,11 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { t.Fatalf("unexpected users: %s", diff) } } @@ -148,6 +149,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -175,6 +177,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { syncCompleteTime := time.Now() err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) + loginDuration := time.Since(syncCompleteTime) t.Logf("Login and sync completed in %v", loginDuration) @@ -206,6 +209,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Check each client's status individually to provide better diagnostics expiredCount := 0 + for _, client := range allClients { status, err := client.Status() if assert.NoError(ct, err, "failed to get status for client %s", client.Hostname()) { @@ -355,6 +359,7 @@ func TestOIDC024UserCreation(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -388,11 +393,11 @@ func TestOIDC024UserCreation(t *testing.T) { listUsers, err := headscale.ListUsers() require.NoError(t, err) - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { t.Errorf("unexpected users: %s", diff) } }) @@ -412,6 +417,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -469,6 +475,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { oidcMockUser("user1", true), }, }) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -507,6 +514,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during initial validation") assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -517,19 +525,22 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return 
listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("User validation failed after first login - unexpected users: %s", diff) } }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var listNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes during initial validation") assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes)) @@ -537,14 +548,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Collect expected node IDs for validation after user1 initial login expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { status := ts.MustStatus() assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) assert.NoError(ct, err, "Failed to parse node ID from status") }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) // Validate initial connection state for user1 @@ -582,6 +598,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users after user2 login") assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -599,11 +616,11 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("User validation failed after user2 login - expected both user1 and user2: %s", diff) } }, 30*time.Second, 1*time.Second, "validating both user1 and user2 exist after second OIDC login") @@ -637,10 +654,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Security validation: Only user2's node should be active after user switch var activeUser2NodeID types.NodeID + for _, node := range listNodesAfterNewUserLogin { if node.GetUser().GetId() == 2 { // user2 activeUser2NodeID = types.NodeID(node.GetId()) t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break } } @@ -654,6 +673,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Check user2 node is online if node, exists := nodeStore[activeUser2NodeID]; exists { assert.NotNil(c, node.IsOnline, "User2 node should have online status") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User2 node should be online after 
login") } @@ -746,6 +766,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during final validation") assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -763,11 +784,11 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("Final user validation failed - both users should persist after relogin cycle: %s", diff) } }, 30*time.Second, 1*time.Second, "validating user persistence after complete relogin cycle (user1->user2->user1)") @@ -815,10 +836,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Security validation: Only user1's node should be active after relogin var activeUser1NodeID types.NodeID + for _, node := range listNodesAfterLoggingBackIn { if node.GetUser().GetId() == 1 { // user1 activeUser1NodeID = types.NodeID(node.GetId()) t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break } } @@ -832,6 +855,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Check user1 node is online if node, exists := nodeStore[activeUser1NodeID]; exists { assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User1 node should be online after relogin") } @@ -902,10 +926,11 @@ func TestOIDCFollowUpUrl(t *testing.T) { // wait for the registration cache to expire // a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION (1m30s) - //nolint:forbidigo // Intentional delay: must wait for real-time cache expiration (HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION=1m30s) + //nolint:forbidigo time.Sleep(2 * time.Minute) var newUrl *url.URL + assert.EventuallyWithT(t, func(c *assert.CollectT) { st, err := ts.Status() assert.NoError(c, err) @@ -935,13 +960,11 @@ func TestOIDCFollowUpUrl(t *testing.T) { }, } - sort.Slice( - listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() - }, - ) + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) + }) - if diff := cmp.Diff( + if diff := gocmp.Diff( wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), @@ -1029,7 +1052,7 @@ func TestOIDCMultipleOpenedLoginUrls(t *testing.T) { require.NotEqual(t, redirect1.String(), redirect2.String()) // complete auth with the first opened "browser tab" - _, redirect1, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true) + _, _, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true) require.NoError(t, err) listUsers, err = headscale.ListUsers() @@ -1046,13 +1069,11 @@ func TestOIDCMultipleOpenedLoginUrls(t *testing.T) { }, } - sort.Slice( - listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() - }, - ) + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) + }) - if diff := cmp.Diff( + if diff := gocmp.Diff( 
wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), @@ -1106,6 +1127,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { oidcMockUser("user1", true), // Relogin with same user }, }) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -1145,6 +1167,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during initial validation") assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -1155,19 +1178,22 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("User validation failed after first login - unexpected users: %s", diff) } }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var initialNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + initialNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes during initial validation") assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes)) @@ -1175,14 +1201,19 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Collect expected node IDs for validation after user1 initial login expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { status := ts.MustStatus() assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) assert.NoError(ct, err, "Failed to parse node ID from status") }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) // Validate initial connection state for user1 @@ -1239,6 +1270,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during final validation") assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -1249,16 +1281,17 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, } - sort.Slice(listUsers, func(i, j int) bool { - return listUsers[i].GetId() < listUsers[j].GetId() + slices.SortFunc(listUsers, func(a, b *v1.User) int { + return cmp.Compare(a.GetId(), b.GetId()) }) - if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + if diff := gocmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { ct.Errorf("Final user validation failed - user1 should persist after same-user relogin: %s", diff) } }, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin 
cycle") var finalNodes []*v1.Node + t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { finalNodes, err = headscale.ListNodes() @@ -1282,6 +1315,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Security validation: user1's node should be active after relogin activeUser1NodeID := types.NodeID(finalNodes[0].GetId()) + t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat)) require.EventuallyWithT(t, func(c *assert.CollectT) { nodeStore, err := headscale.DebugNodeStore() @@ -1290,6 +1324,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Check user1 node is online if node, exists := nodeStore[activeUser1NodeID]; exists { assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin") } @@ -1359,6 +1394,7 @@ func TestOIDCExpiryAfterRestart(t *testing.T) { // Verify initial expiry is set var initialExpiry time.Time + assert.EventuallyWithT(t, func(ct *assert.CollectT) { nodes, err := headscale.ListNodes() assert.NoError(ct, err) diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 5dd546f3..8d596241 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -1,7 +1,6 @@ package integration import ( - "fmt" "net/netip" "slices" "testing" @@ -67,6 +66,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -106,13 +106,16 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after web authentication") assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) }, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication") + nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -152,6 +155,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after web flow logout") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -226,6 +230,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -240,7 +245,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { allClients, err := scenario.ListTailscaleClients() requireNoErrListClients(t, err) - allIps, err := scenario.ListTailscaleClientsIPs() + _, err = scenario.ListTailscaleClientsIPs() requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() 
@@ -256,13 +261,16 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after initial web authentication") assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) }, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication") + nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -299,7 +307,10 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { // Register all clients as user1 (this is where cross-user registration happens) // This simulates: headscale nodes register --user user1 --key - scenario.runHeadscaleRegister("user1", body) + err = scenario.runHeadscaleRegister("user1", body) + if err != nil { + t.Fatalf("failed to register client %s: %s", client.Hostname(), err) + } } // Wait for all clients to reach running state @@ -313,9 +324,11 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { t.Logf("all clients logged back in as user1") var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user1Nodes, err = headscale.ListNodes("user1") assert.NoError(ct, err, "Failed to list nodes for user1 after web flow relogin") assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes)) @@ -333,25 +346,28 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { // Validate that user2's old nodes still exist in database (but are expired/offline) // When CLI registration creates new nodes for user1, user2's old nodes remain var user2Nodes []*v1.Node + t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user2Nodes, err = headscale.ListNodes("user2") assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1") assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes)) }, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1") t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after web flow user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) - }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after web flow user switch", client.Hostname())) + }, 30*time.Second, 2*time.Second, "validating %s is logged in as user1 after web flow user switch", client.Hostname()) } // 
Test connectivity after user switch - allIps, err = scenario.ListTailscaleClientsIPs() + allIps, err := scenario.ListTailscaleClientsIPs() requireNoErrListClientIPs(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { diff --git a/integration/cli_test.go b/integration/cli_test.go index 65d82444..707c9992 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -203,7 +203,7 @@ func TestUserCommand(t *testing.T) { "--identifier=1", }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterIDDelete []*v1.User @@ -245,7 +245,7 @@ func TestUserCommand(t *testing.T) { "--name=newname", }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, deleteResult, "User destroyed") var listAfterNameDelete []v1.User @@ -571,8 +571,9 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { IntegrationSkip(t) + //nolint:goconst user1 := "user1" - user2 := "user2" + user2 := "user2" //nolint:goconst spec := ScenarioSpec{ NodesPerUser: 1, @@ -829,7 +830,7 @@ func TestApiKeyCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.NotEmpty(t, apiResult) keys[idx] = apiResult @@ -907,7 +908,7 @@ func TestApiKeyCommand(t *testing.T) { listedAPIKeys[idx].GetPrefix(), }, ) - assert.NoError(t, err) + require.NoError(t, err) expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true } @@ -952,7 +953,7 @@ func TestApiKeyCommand(t *testing.T) { "--prefix", listedAPIKeys[0].GetPrefix(), }) - assert.NoError(t, err) + require.NoError(t, err) var listedAPIKeysAfterDelete []v1.ApiKey @@ -1071,7 +1072,7 @@ func TestNodeCommand(t *testing.T) { } nodes := make([]*v1.Node, len(regIDs)) - assert.NoError(t, err) + require.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -1089,7 +1090,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) var node v1.Node @@ -1156,7 +1157,7 @@ func TestNodeCommand(t *testing.T) { } otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) - assert.NoError(t, err) + require.NoError(t, err) for index, regID := range otherUserRegIDs { _, err := headscale.Execute( @@ -1174,7 +1175,7 @@ func TestNodeCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) var node v1.Node @@ -1281,7 +1282,7 @@ func TestNodeCommand(t *testing.T) { "--force", }, ) - assert.NoError(t, err) + require.NoError(t, err) // Test: list main user after node is deleted var listOnlyMachineUserAfterDelete []v1.Node @@ -1348,7 +1349,7 @@ func TestNodeExpireCommand(t *testing.T) { "json", }, ) - assert.NoError(t, err) + require.NoError(t, err) var node v1.Node @@ -1411,7 +1412,7 @@ func TestNodeExpireCommand(t *testing.T) { strconv.FormatUint(listAll[idx].GetId(), 10), }, ) - assert.NoError(t, err) + require.NoError(t, err) } var listAllAfterExpiry []v1.Node @@ -1467,7 +1468,7 @@ func TestNodeRenameCommand(t *testing.T) { } nodes := make([]*v1.Node, len(regIDs)) - assert.NoError(t, err) + require.NoError(t, err) for index, regID := range regIDs { _, err := headscale.Execute( @@ -1549,7 +1550,7 @@ func TestNodeRenameCommand(t *testing.T) { fmt.Sprintf("newnode-%d", idx+1), }, ) - assert.NoError(t, err) + require.NoError(t, err) assert.Contains(t, res, "Node renamed") } @@ -1590,7 +1591,7 @@ func TestNodeRenameCommand(t *testing.T) { strings.Repeat("t", 64), }, ) - assert.ErrorContains(t, err, "must not 
exceed 63 characters") + require.ErrorContains(t, err, "must not exceed 63 characters") var listAllAfterRenameAttempt []v1.Node @@ -1763,7 +1764,7 @@ func TestPolicyBrokenConfigCommand(t *testing.T) { policyFilePath, }, ) - assert.ErrorContains(t, err, `invalid action "unknown-action"`) + require.ErrorContains(t, err, `invalid action "unknown-action"`) // The new policy was invalid, the old one should still be in place, which // is none. diff --git a/integration/control.go b/integration/control.go index 58a061e3..612f0ff3 100644 --- a/integration/control.go +++ b/integration/control.go @@ -15,8 +15,8 @@ import ( type ControlServer interface { Shutdown() (string, string, error) - SaveLog(string) (string, string, error) - SaveProfile(string) error + SaveLog(dir string) (string, string, error) + SaveProfile(dir string) error Execute(command []string) (string, error) WriteFile(path string, content []byte) error ConnectToNetwork(network *dockertest.Network) error @@ -35,12 +35,12 @@ type ControlServer interface { ListUsers() ([]*v1.User, error) MapUsers() (map[string]*v1.User, error) DeleteUser(userID uint64) error - ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error) + ApproveRoutes(nodeID uint64, routes []netip.Prefix) (*v1.Node, error) SetNodeTags(nodeID uint64, tags []string) error GetCert() []byte GetHostname() string GetIPInNetwork(network *dockertest.Network) string - SetPolicy(*policyv2.Policy) error + SetPolicy(pol *policyv2.Policy) error GetAllMapReponses() (map[types.NodeID][]tailcfg.MapResponse, error) PrimaryRoutes() (*routes.DebugRoutes, error) DebugBatcher() (*hscontrol.DebugBatcherInfo, error) diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index 60260bb1..20ea930c 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -1,6 +1,7 @@ package integration import ( + "errors" "fmt" "net" "strconv" @@ -19,12 +20,15 @@ import ( "tailscale.com/types/key" ) +var errUnexpectedRecvType = errors.New("client first Recv was unexpected type") + func TestDERPVerifyEndpoint(t *testing.T) { IntegrationSkip(t) // Generate random hostname for the headscale instance hash, err := util.GenerateRandomStringDNSSafe(6) require.NoError(t, err) + testName := "derpverify" hostname := fmt.Sprintf("hs-%s-%s", testName, hash) @@ -40,6 +44,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -104,13 +109,17 @@ func DERPVerify( defer c.Close() var result error - if err := c.Connect(t.Context()); err != nil { + + err := c.Connect(t.Context()) + if err != nil { result = fmt.Errorf("client Connect: %w", err) } + + //nolint:noinlineerr if m, err := c.Recv(); err != nil { result = fmt.Errorf("client first Recv: %w", err) } else if v, ok := m.(derp.ServerInfoMessage); !ok { - result = fmt.Errorf("client first Recv was unexpected type %T", v) + result = fmt.Errorf("%w: %T", errUnexpectedRecvType, v) } if expectSuccess && result != nil { diff --git a/integration/dns_test.go b/integration/dns_test.go index e937a421..0d3bce21 100644 --- a/integration/dns_test.go +++ b/integration/dns_test.go @@ -86,6 +86,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { const erPath = "/tmp/extra_records.json" + //nolint:prealloc extraRecords := []tailcfg.DNSRecord{ { Name: "test.myvpn.example.com", @@ -93,7 +94,8 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { Value: "6.6.6.6", }, } - b, _ := 
json.Marshal(extraRecords) + b, err := json.Marshal(extraRecords) + require.NoError(t, err) err = scenario.CreateHeadscaleEnv([]tsic.Option{ tsic.WithPackages("python3", "curl", "bind-tools"), @@ -133,13 +135,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { require.NoError(t, err) // Write the file directly into place from the docker API. - b0, _ := json.Marshal([]tailcfg.DNSRecord{ + b0, err := json.Marshal([]tailcfg.DNSRecord{ { Name: "docker.myvpn.example.com", Type: "A", Value: "2.2.2.2", }, }) + require.NoError(t, err) err = hs.WriteFile(erPath, b0) require.NoError(t, err) @@ -155,7 +158,8 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { Type: "A", Value: "7.7.7.7", }) - b2, _ := json.Marshal(extraRecords) + b2, err := json.Marshal(extraRecords) + require.NoError(t, err) err = hs.WriteFile(erPath+"2", b2) require.NoError(t, err) @@ -169,13 +173,14 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { // Write a new file and copy it to the path to ensure the reload // works when a file is copied into place. - b3, _ := json.Marshal([]tailcfg.DNSRecord{ + b3, err := json.Marshal([]tailcfg.DNSRecord{ { Name: "copy.myvpn.example.com", Type: "A", Value: "8.8.8.8", }, }) + require.NoError(t, err) err = hs.WriteFile(erPath+"3", b3) require.NoError(t, err) @@ -187,13 +192,15 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { } // Write in place to ensure pipe like behaviour works - b4, _ := json.Marshal([]tailcfg.DNSRecord{ + b4, err := json.Marshal([]tailcfg.DNSRecord{ { Name: "docker.myvpn.example.com", Type: "A", Value: "9.9.9.9", }, }) + require.NoError(t, err) + command := []string{"echo", fmt.Sprintf("'%s'", string(b4)), ">", erPath} _, err = hs.Execute([]string{"bash", "-c", strings.Join(command, " ")}) require.NoError(t, err) diff --git a/integration/dockertestutil/config.go b/integration/dockertestutil/config.go index c0c57a3e..75bc872c 100644 --- a/integration/dockertestutil/config.go +++ b/integration/dockertestutil/config.go @@ -14,6 +14,11 @@ const ( // TimestampFormatRunID is used for generating unique run identifiers // Format: "20060102-150405" provides compact date-time for file/directory names. TimestampFormatRunID = "20060102-150405" + + // runIDHashLength is the length of the random hash in run IDs. + runIDHashLength = 6 + // runIDParts is the number of parts in a run ID (YYYYMMDD-HHMMSS-HASH). + runIDParts = 3 ) // GetIntegrationRunID returns the run ID for the current integration test session. @@ -34,6 +39,7 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) { if opts.Labels == nil { opts.Labels = make(map[string]string) } + opts.Labels["hi.run-id"] = runID opts.Labels["hi.test-type"] = testType } @@ -45,7 +51,7 @@ func GenerateRunID() string { timestamp := now.Format(TimestampFormatRunID) // Add a short random hash to ensure uniqueness - randomHash := util.MustGenerateRandomStringDNSSafe(6) + randomHash := util.MustGenerateRandomStringDNSSafe(runIDHashLength) return fmt.Sprintf("%s-%s", timestamp, randomHash) } @@ -54,9 +60,9 @@ func GenerateRunID() string { // Expects format: "prefix-YYYYMMDD-HHMMSS-HASH". 
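For context on the format described above, a small sketch of how GenerateRunID and ExtractRunIDFromContainerName are expected to round-trip; the container name prefix here is an assumption for illustration only:

package dockertestutil

import "fmt"

// exampleRunIDRoundTrip is illustrative: a container named "<prefix>-<runID>"
// reduces back to the run ID because the last runIDParts dash-separated
// fields are exactly "YYYYMMDD-HHMMSS-HASH".
func exampleRunIDRoundTrip() (string, string) {
	runID := GenerateRunID() // e.g. "20240102-150405-abc123"

	containerName := fmt.Sprintf("hs-derpverify-%s", runID) // assumed naming scheme

	return runID, ExtractRunIDFromContainerName(containerName)
}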
func ExtractRunIDFromContainerName(containerName string) string { parts := strings.Split(containerName, "-") - if len(parts) >= 3 { + if len(parts) >= runIDParts { // Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH) - return strings.Join(parts[len(parts)-3:], "-") + return strings.Join(parts[len(parts)-runIDParts:], "-") } panic("unexpected container name format: " + containerName) diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index b09e0d40..0143ee53 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -38,9 +38,12 @@ type buffer struct { // Write appends the contents of p to the buffer, growing the buffer as needed. It returns // the number of bytes written. +// +//nolint:nonamedreturns func (b *buffer) Write(p []byte) (n int, err error) { b.mutex.Lock() defer b.mutex.Unlock() + return b.store.Write(p) } @@ -49,6 +52,7 @@ func (b *buffer) Write(p []byte) (n int, err error) { func (b *buffer) String() string { b.mutex.Lock() defer b.mutex.Unlock() + return b.store.String() } diff --git a/integration/dockertestutil/logs.go b/integration/dockertestutil/logs.go index 7d104e43..d5911ca7 100644 --- a/integration/dockertestutil/logs.go +++ b/integration/dockertestutil/logs.go @@ -47,6 +47,7 @@ func SaveLog( } var stdout, stderr bytes.Buffer + err = WriteLog(pool, resource, &stdout, &stderr) if err != nil { return "", "", err diff --git a/integration/dockertestutil/network.go b/integration/dockertestutil/network.go index 42483247..95b69b88 100644 --- a/integration/dockertestutil/network.go +++ b/integration/dockertestutil/network.go @@ -13,11 +13,18 @@ import ( var ErrContainerNotFound = errors.New("container not found") +// Docker memory constants. +const ( + bytesPerKB = 1024 + containerMemoryGB = 2 +) + func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (*dockertest.Network, error) { networks, err := pool.NetworksByName(name) if err != nil { return nil, fmt.Errorf("looking up network names: %w", err) } + if len(networks) == 0 { if _, err := pool.CreateNetwork(name); err == nil { // Create does not give us an updated version of the resource, so we need to @@ -90,6 +97,7 @@ func RandomFreeHostPort() (int, error) { // CleanUnreferencedNetworks removes networks that are not referenced by any containers. func CleanUnreferencedNetworks(pool *dockertest.Pool) error { filter := "name=hs-" + networks, err := pool.NetworksByName(filter) if err != nil { return fmt.Errorf("getting networks by filter %q: %w", filter, err) @@ -122,6 +130,7 @@ func CleanImagesInCI(pool *dockertest.Pool) error { } removedCount := 0 + for _, image := range images { // Only remove dangling (untagged) images to avoid forcing rebuilds // Dangling images have no RepoTags or only have ":" @@ -169,6 +178,6 @@ func DockerAllowNetworkAdministration(config *docker.HostConfig) { // DockerMemoryLimit sets memory limit and disables OOM kill for containers. 
func DockerMemoryLimit(config *docker.HostConfig) { - config.Memory = 2 * 1024 * 1024 * 1024 // 2GB in bytes + config.Memory = containerMemoryGB * bytesPerKB * bytesPerKB * bytesPerKB // 2GB in bytes config.OOMKillDisable = true } diff --git a/integration/dsic/dsic.go b/integration/dsic/dsic.go index d8a77575..344d93f7 100644 --- a/integration/dsic/dsic.go +++ b/integration/dsic/dsic.go @@ -159,10 +159,12 @@ func New( } else { hostname = fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash) } + tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname) if err != nil { return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err) } + dsic := &DERPServerInContainer{ version: version, hostname: hostname, @@ -185,6 +187,7 @@ func New( fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort) fmt.Fprintf(&cmdArgs, " --stun=true") fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort) + if dsic.withVerifyClientURL != "" { fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL) } @@ -214,11 +217,13 @@ func New( } var container *dockertest.Resource + buildOptions := &dockertest.BuildOptions{ Dockerfile: "Dockerfile.derper", ContextDir: dockerContextPath, BuildArgs: []docker.BuildArg{}, } + switch version { case "head": buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{ @@ -249,6 +254,7 @@ func New( err, ) } + log.Printf("Created %s container\n", hostname) dsic.container = container @@ -259,12 +265,14 @@ func New( return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) } } + if len(dsic.tlsCert) != 0 { err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert) if err != nil { return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) } } + if len(dsic.tlsKey) != 0 { err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey) if err != nil { diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index 89154f63..7164e113 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -99,7 +99,7 @@ func TestDERPServerWebsocketScenario(t *testing.T) { // we *want* it to show up in stacktraces, // so marking it as a test helper would be counterproductive. // -//nolint:thelper + func derpServerScenario( t *testing.T, spec ScenarioSpec, @@ -179,7 +179,7 @@ func derpServerScenario( // Let the DERP updater run a couple of times to ensure it does not // break the DERPMap. The updater runs on a 10s interval by default. - //nolint:forbidigo // Intentional delay: must wait for DERP updater to run multiple times (interval-based) + //nolint:forbidigo time.Sleep(30 * time.Second) success = pingDerpAllHelper(t, allClients, allHostnames) diff --git a/integration/helpers.go b/integration/helpers.go index 7d40c8e6..7e5fcc2f 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -3,12 +3,14 @@ package integration import ( "bufio" "bytes" + "errors" "fmt" "io" + "maps" "net/netip" + "slices" "strconv" "strings" - "sync" "testing" "time" @@ -23,10 +25,15 @@ import ( "github.com/oauth2-proxy/mockoidc" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "golang.org/x/exp/maps" - "golang.org/x/exp/slices" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" +) + +// Sentinel errors for integration test helpers. 
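The sentinel errors declared just below (like errUnexpectedRecvType above) exist so call sites can wrap with %w and still match on the sentinel. A brief hedged sketch of that usage; findUser and exampleLookup are hypothetical and not call sites from this patch:

package integration

import (
	"errors"
	"fmt"
)

// findUser is an illustrative wrapper: the %w verb keeps errUserNotFound in
// the error chain, so callers can branch with errors.Is instead of matching strings.
func findUser(users map[string]uint64, name string) (uint64, error) {
	id, ok := users[name]
	if !ok {
		return 0, fmt.Errorf("%w: %q", errUserNotFound, name)
	}

	return id, nil
}

func exampleLookup() error {
	_, err := findUser(map[string]uint64{"user1": 1}, "missing")
	if errors.Is(err, errUserNotFound) {
		// expected branch: the sentinel survives the %w wrapping
		return nil
	}

	return err
}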
+var ( + errExpectedStringNotFound = errors.New("expected string not found in output") + errUserNotFound = errors.New("user not found") + errNoNewClientFound = errors.New("no new client found") + errUnexpectedClientCount = errors.New("unexpected client count") ) const ( @@ -46,9 +53,19 @@ const ( // TimestampFormatRunID is used for generating unique run identifiers // Format: "20060102-150405" provides compact date-time for file/directory names. TimestampFormatRunID = "20060102-150405" + + // Connection validation timeouts. + connectionValidationTimeout = 120 * time.Second + onlineCheckRetryInterval = 2 * time.Second + batcherValidationTimeout = 15 * time.Second + nodestoreValidationTimeout = 20 * time.Second + mapResponseTimeout = 60 * time.Second + netInfoRetryInterval = 5 * time.Second + backoffMaxElapsedTime = 10 * time.Second + backoffRetryInterval = 500 * time.Millisecond ) -// NodeSystemStatus represents the status of a node across different systems +// NodeSystemStatus represents the status of a node across different systems. type NodeSystemStatus struct { Batcher bool BatcherConnCount int @@ -105,7 +122,7 @@ func requireNoErrLogout(t *testing.T, err error) { require.NoError(t, err, "failed to log out tailscale nodes") } -// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes +// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes. func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID { t.Helper() @@ -114,8 +131,10 @@ func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.Nod status := client.MustStatus() nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) } + return expectedNodes } @@ -125,8 +144,8 @@ func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.Nod func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { t.Helper() - requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", connectionValidationTimeout) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login") } // validateLogoutComplete performs comprehensive validation after client logout. @@ -135,7 +154,7 @@ func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNo func validateLogoutComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { t.Helper() - requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", connectionValidationTimeout) } // validateReloginComplete performs comprehensive validation after client relogin. 
@@ -144,20 +163,23 @@ func validateLogoutComplete(t *testing.T, headscale ControlServer, expectedNodes func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { t.Helper() - requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", 120*time.Second) - requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin", 3*time.Minute) + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", connectionValidationTimeout) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin") } // requireAllClientsOnline validates that all nodes are online/offline across all headscale systems -// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems +// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems. func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { t.Helper() startTime := time.Now() + + //nolint:goconst stateStr := "offline" if expectedOnline { - stateStr = "online" + stateStr = "online" //nolint:goconst } + t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message) if expectedOnline { @@ -172,15 +194,19 @@ func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNode t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message) } -// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state +// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state. 
+// +//nolint:gocyclo func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { t.Helper() var prevReport string + require.EventuallyWithT(t, func(c *assert.CollectT) { // Get batcher state debugInfo, err := headscale.DebugBatcher() assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { return } @@ -188,6 +214,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Get map responses mapResponses, err := headscale.GetAllMapReponses() assert.NoError(c, err, "Failed to get map responses") + if err != nil { return } @@ -195,6 +222,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Get nodestore state nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } @@ -265,6 +293,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if id == nodeID { continue // Skip self-references } + expectedPeerMaps++ if online, exists := peerMap[nodeID]; exists && online { @@ -279,6 +308,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer } } } + assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") // Update status with map response data @@ -302,10 +332,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Verify all systems show nodes in expected state and report failures allMatch := true + var failureReport strings.Builder - ids := types.NodeIDs(maps.Keys(nodeStatus)) + ids := slices.Collect(maps.Keys(nodeStatus)) slices.Sort(ids) + for _, nodeID := range ids { status := nodeStatus[nodeID] systemsMatch := (status.Batcher == expectedOnline) && @@ -314,10 +346,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if !systemsMatch { allMatch = false + stateStr := "offline" if expectedOnline { stateStr = "online" } + failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat))) failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline)) failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) @@ -332,6 +366,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer t.Logf("Previous report:\n%s", prevReport) t.Logf("Current report:\n%s", failureReport.String()) t.Logf("Report diff:\n%s", diff) + prevReport = failureReport.String() } @@ -345,12 +380,13 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if expectedOnline { stateStr = "online" } - assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr)) - }, timeout, 2*time.Second, message) + + assert.True(c, allMatch, "Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr) + }, timeout, onlineCheckRetryInterval, message) } -// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components -func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) { +// requireAllClientsOfflineStaged validates offline state with staged timeouts for 
different components. +func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, _ string, _ time.Duration) { t.Helper() // Stage 1: Verify batcher disconnection (should be immediate) @@ -358,48 +394,57 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { debugInfo, err := headscale.DebugBatcher() assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { return } allBatcherOffline := true + for _, nodeID := range expectedNodes { nodeIDStr := fmt.Sprintf("%d", nodeID) if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected { allBatcherOffline = false + assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID) } } + assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher") - }, 15*time.Second, 1*time.Second, "batcher disconnection validation") + }, batcherValidationTimeout, 1*time.Second, "batcher disconnection validation") // Stage 2: Verify nodestore offline status (up to 15 seconds due to disconnect detection delay) t.Logf("Stage 2: Verifying nodestore offline status for %d nodes (allowing for 10s disconnect detection delay)", len(expectedNodes)) require.EventuallyWithT(t, func(c *assert.CollectT) { nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } allNodeStoreOffline := true + for _, nodeID := range expectedNodes { if node, exists := nodeStore[nodeID]; exists { isOnline := node.IsOnline != nil && *node.IsOnline if isOnline { allNodeStoreOffline = false + assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID) } } } + assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore") - }, 20*time.Second, 1*time.Second, "nodestore offline validation") + }, nodestoreValidationTimeout, 1*time.Second, "nodestore offline validation") // Stage 3: Verify map response propagation (longest delay due to peer update timing) t.Logf("Stage 3: Verifying map response propagation for %d nodes (allowing for peer map update delays)", len(expectedNodes)) require.EventuallyWithT(t, func(c *assert.CollectT) { mapResponses, err := headscale.GetAllMapReponses() assert.NoError(c, err, "Failed to get map responses") + if err != nil { return } @@ -412,7 +457,8 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec for nodeID := range onlineMap { if slices.Contains(expectedNodes, nodeID) { allMapResponsesOffline = false - assert.False(c, true, "Node %d should not appear in map responses", nodeID) + + assert.Fail(c, "Node should not appear in map responses", "Node %d should not appear in map responses", nodeID) } } } else { @@ -422,15 +468,18 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec if id == nodeID { continue // Skip self-references } + if online, exists := peerMap[nodeID]; exists && online { allMapResponsesOffline = false + assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id) } } } } + assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses") - }, 60*time.Second, 2*time.Second, "map response propagation validation") + }, mapResponseTimeout, onlineCheckRetryInterval, "map response propagation validation") t.Logf("All stages completed: nodes are fully offline across all systems") } @@ -438,9 +487,11 @@ func 
requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec // requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database // and a valid DERP server based on the NetInfo. This function follows the pattern of // requireAllClientsOnline by using hsic.DebugNodeStore to get the database state. -func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) { +func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string) { t.Helper() + const timeout = 3 * time.Minute + startTime := time.Now() t.Logf("requireAllClientsNetInfoAndDERP: Starting NetInfo/DERP validation for %d nodes at %s - %s", len(expectedNodes), startTime.Format(TimestampFormat), message) @@ -448,6 +499,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Get nodestore state nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } @@ -462,12 +514,14 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe for _, nodeID := range expectedNodes { node, exists := nodeStore[nodeID] assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID) + if !exists { continue } // Validate that the node has Hostinfo assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname) + if node.Hostinfo == nil { t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) continue @@ -475,6 +529,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Validate that the node has NetInfo assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname) + if node.Hostinfo.NetInfo == nil { t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) continue @@ -482,11 +537,11 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Validate that the node has a valid DERP server (PreferredDERP should be > 0) preferredDERP := node.Hostinfo.NetInfo.PreferredDERP - assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP) + assert.Positive(c, preferredDERP, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP) t.Logf("Node %d (%s) has valid NetInfo with DERP server %d at %s", nodeID, node.Hostname, preferredDERP, time.Now().Format(TimestampFormat)) } - }, timeout, 5*time.Second, message) + }, timeout, netInfoRetryInterval, message) endTime := time.Now() duration := endTime.Sub(startTime) @@ -496,6 +551,8 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // assertLastSeenSet validates that a node has a non-nil LastSeen timestamp. // Critical for ensuring node activity tracking is functioning properly. 
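The validation helpers above all follow the same retry shape: testify's require.EventuallyWithT re-runs a closure against an assert.CollectT until every recorded assertion passes or the timeout elapses, and the closure returns early after a failed fetch so later code never touches a nil result. A minimal, self-contained sketch of that pattern; nodeState, fetchState and waitForOnline are illustrative stand-ins, not part of this change:

```go
package example

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// nodeState and fetchState stand in for the debug endpoints the real helpers
// query (batcher, nodestore, map responses).
type nodeState struct {
	Online bool
}

func fetchState() (*nodeState, error) {
	return &nodeState{Online: true}, nil // the real helpers hit a debug API here
}

// waitForOnline shows the retry shape: assertions are recorded on the
// CollectT, and the closure returns early on a failed fetch so the later
// assertions never dereference a nil state.
func waitForOnline(t *testing.T, timeout, tick time.Duration) {
	t.Helper()

	require.EventuallyWithT(t, func(c *assert.CollectT) {
		state, err := fetchState()
		assert.NoError(c, err, "failed to fetch node state")

		if err != nil {
			return
		}

		assert.True(c, state.Online, "node should be online")
	}, timeout, tick, "waiting for node to come online")
}
```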
func assertLastSeenSet(t *testing.T, node *v1.Node) { + t.Helper() + assert.NotNil(t, node) assert.NotNil(t, node.GetLastSeen()) } @@ -514,7 +571,7 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { for _, client := range clients { status, err := client.Status() - assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) + assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) //nolint:testifylint assert.Equal(t, "NeedsLogin", status.BackendState, "client %s should be logged out", client.Hostname()) } @@ -523,13 +580,14 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { // pingAllHelper performs ping tests between all clients and addresses, returning success count. // This is used to validate network connectivity in integration tests. // Returns the total number of successful ping operations. -func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { +func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { t.Helper() + success := 0 for _, client := range clients { for _, addr := range addrs { - err := client.Ping(addr, opts...) + err := client.Ping(addr) if err != nil { t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err) } else { @@ -546,6 +604,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts // for validating NAT traversal and relay functionality. Returns success count. func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { t.Helper() + success := 0 for _, client := range clients { @@ -593,160 +652,6 @@ func isSelfClient(client TailscaleClient, addr string) bool { return false } -// assertClientsState validates the status and netmap of a list of clients for general connectivity. -// Runs parallel validation of status, netcheck, and netmap for all clients to ensure -// they have proper network configuration for all-to-all connectivity tests. -func assertClientsState(t *testing.T, clients []TailscaleClient) { - t.Helper() - - var wg sync.WaitGroup - - for _, client := range clients { - wg.Add(1) - c := client // Avoid loop pointer - go func() { - defer wg.Done() - assertValidStatus(t, c) - assertValidNetcheck(t, c) - assertValidNetmap(t, c) - }() - } - - t.Logf("waiting for client state checks to finish") - wg.Wait() -} - -// assertValidNetmap validates that a client's netmap has all required fields for proper operation. -// Checks self node and all peers for essential networking data including hostinfo, addresses, -// endpoints, and DERP configuration. Skips validation for Tailscale versions below 1.56. -// This test is not suitable for ACL/partial connection tests. 
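The assertion helpers above now call t.Helper() first, so a failure is reported at the caller's line instead of deep inside the helper. A tiny sketch of the same idea, with assertHasPeers as an illustrative name:

```go
package example

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

// assertHasPeers marks itself as a test helper; if the assertion fails, the
// reported file:line is the caller's, not this function's.
func assertHasPeers(t *testing.T, peers []string) {
	t.Helper()

	assert.NotEmpty(t, peers, "expected at least one peer")
}
```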
-func assertValidNetmap(t *testing.T, client TailscaleClient) { - t.Helper() - - if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) { - t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version()) - - return - } - - t.Logf("Checking netmap of %q", client.Hostname()) - - assert.EventuallyWithT(t, func(c *assert.CollectT) { - netmap, err := client.Netmap() - assert.NoError(c, err, "getting netmap for %q", client.Hostname()) - - assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) - if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { - assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) - } - - assert.NotEmptyf(c, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) - assert.NotEmptyf(c, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) - - assert.Truef(c, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname()) - - assert.Falsef(c, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) - assert.Falsef(c, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) - assert.Falsef(c, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) - - for _, peer := range netmap.Peers { - assert.NotEqualf(c, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) - assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) - - assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) - if hi := peer.Hostinfo(); hi.Valid() { - assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) - - // Netinfo is not always set - // assert.Truef(c, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname()) - if ni := hi.NetInfo(); ni.Valid() { - assert.NotEqualf(c, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP()) - } - } - - assert.NotEmptyf(c, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname()) - assert.NotEmptyf(c, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname()) - assert.NotEmptyf(c, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname()) - - assert.Truef(c, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname()) - - assert.Falsef(c, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(c, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", peer.ComputedName(), client.Hostname()) - assert.Falsef(c, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname()) - } - }, 10*time.Second, 200*time.Millisecond, "Waiting for valid netmap for %q", 
client.Hostname()) -} - -// assertValidStatus validates that a client's status has all required fields for proper operation. -// Checks self and peer status for essential data including hostinfo, tailscale IPs, endpoints, -// and network map presence. This test is not suitable for ACL/partial connection tests. -func assertValidStatus(t *testing.T, client TailscaleClient) { - t.Helper() - status, err := client.Status(true) - if err != nil { - t.Fatalf("getting status for %q: %s", client.Hostname(), err) - } - - assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname()) - assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname()) - assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname()) - - assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname()) - - // This seem to not appear until version 1.56 - if status.Self.AllowedIPs != nil { - assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname()) - } - - assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname()) - - assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname()) - - assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname()) - - // This isn't really relevant for Self as it won't be in its own socket/wireguard. - // assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname()) - // assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname()) - - for _, peer := range status.Peer { - assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname()) - assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname()) - assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname()) - - assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname()) - - // This seem to not appear until version 1.56 - if peer.AllowedIPs != nil { - assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname()) - } - - // Addrs does not seem to appear in the status from peers. - // assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname()) - - assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname()) - - assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname()) - assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname()) - - // TODO(kradalby): InEngine is only true when a proper tunnel is set up, - // there might be some interesting stuff to test here in the future. - // assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname()) - } -} - -// assertValidNetcheck validates that a client has a proper DERP relay configured. -// Ensures the client has discovered and selected a DERP server for relay functionality, -// which is essential for NAT traversal and connectivity in restricted networks. 
-func assertValidNetcheck(t *testing.T, client TailscaleClient) { - t.Helper() - report, err := client.Netcheck() - if err != nil { - t.Fatalf("getting status for %q: %s", client.Hostname(), err) - } - - assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname()) -} - // assertCommandOutputContains executes a command with exponential backoff retry until the output // contains the expected string or timeout is reached (10 seconds). // This implements eventual consistency patterns and should be used instead of time.Sleep @@ -764,11 +669,11 @@ func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []stri } if !strings.Contains(stdout, contains) { - return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout) + return struct{}{}, fmt.Errorf("executing command, %w: %q not found in %q", errExpectedStringNotFound, contains, stdout) } return struct{}{}, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(backoffMaxElapsedTime)) assert.NoError(t, err) } @@ -793,6 +698,7 @@ func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool { t.Helper() buf := &bytes.Buffer{} + err := client.WriteLogs(buf, buf) if err != nil { t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err) @@ -816,6 +722,7 @@ func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) scanner := bufio.NewScanner(in) { const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB + buff := make([]byte, logBufferInitialSize) scanner.Buffer(buff, len(buff)) scanner.Split(bufio.ScanLines) @@ -839,32 +746,33 @@ func wildcard() policyv2.Alias { // usernamep returns a pointer to a Username as an Alias for policy v2 configurations. // Used in ACL rules to reference specific users in network access policies. func usernamep(name string) policyv2.Alias { - return ptr.To(policyv2.Username(name)) + return new(policyv2.Username(name)) } // hostp returns a pointer to a Host as an Alias for policy v2 configurations. // Used in ACL rules to reference specific hosts in network access policies. func hostp(name string) policyv2.Alias { - return ptr.To(policyv2.Host(name)) + return new(policyv2.Host(name)) } // groupp returns a pointer to a Group as an Alias for policy v2 configurations. // Used in ACL rules to reference user groups in network access policies. func groupp(name string) policyv2.Alias { - return ptr.To(policyv2.Group(name)) + return new(policyv2.Group(name)) } // tagp returns a pointer to a Tag as an Alias for policy v2 configurations. // Used in ACL rules to reference node tags in network access policies. func tagp(name string) policyv2.Alias { - return ptr.To(policyv2.Tag(name)) + return new(policyv2.Tag(name)) } // prefixp returns a pointer to a Prefix from a CIDR string for policy v2 configurations. // Converts CIDR notation to policy prefix format for network range specifications. func prefixp(cidr string) policyv2.Alias { + //nolint:staticcheck prefix := netip.MustParsePrefix(cidr) - return ptr.To(policyv2.Prefix(prefix)) + return new(policyv2.Prefix(prefix)) } // aliasWithPorts creates an AliasWithPorts structure from an alias and port ranges. @@ -880,31 +788,25 @@ func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.A // usernameOwner returns a Username as an Owner for use in TagOwners policies. 
// Specifies which users can assign and manage specific tags in ACL configurations. func usernameOwner(name string) policyv2.Owner { - return ptr.To(policyv2.Username(name)) -} - -// groupOwner returns a Group as an Owner for use in TagOwners policies. -// Specifies which groups can assign and manage specific tags in ACL configurations. -func groupOwner(name string) policyv2.Owner { - return ptr.To(policyv2.Group(name)) + return new(policyv2.Username(name)) } // usernameApprover returns a Username as an AutoApprover for subnet route policies. // Specifies which users can automatically approve subnet route advertisements. func usernameApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Username(name)) + return new(policyv2.Username(name)) } // groupApprover returns a Group as an AutoApprover for subnet route policies. // Specifies which groups can automatically approve subnet route advertisements. func groupApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Group(name)) + return new(policyv2.Group(name)) } // tagApprover returns a Tag as an AutoApprover for subnet route policies. // Specifies which tagged nodes can automatically approve subnet route advertisements. func tagApprover(name string) policyv2.AutoApprover { - return ptr.To(policyv2.Tag(name)) + return new(policyv2.Tag(name)) } // oidcMockUser creates a MockUser for OIDC authentication testing. @@ -933,7 +835,7 @@ func GetUserByName(headscale ControlServer, username string) (*v1.User, error) { } } - return nil, fmt.Errorf("user %s not found", username) + return nil, fmt.Errorf("%w: %s", errUserNotFound, username) } // FindNewClient finds a client that is in the new list but not in the original list. @@ -942,17 +844,20 @@ func GetUserByName(headscale ControlServer, username string) (*v1.User, error) { func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) { for _, client := range updated { isOriginal := false + for _, origClient := range original { if client.Hostname() == origClient.Hostname() { isOriginal = true break } } + if !isOriginal { return client, nil } } - return nil, fmt.Errorf("no new client found") + + return nil, errNoNewClientFound } // AddAndLoginClient adds a new tailscale client to a user and logs it in. @@ -960,7 +865,7 @@ func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) // 1. Creating a new node // 2. Finding the new node in the client list // 3. Getting the user to create a preauth key -// 4. Logging in the new node +// 4. Logging in the new node. 
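The fmt.Errorf calls above are rewritten to wrap package-level sentinel errors (errUserNotFound, errNoNewClientFound, errExpectedStringNotFound) rather than build one-off error strings. A minimal sketch of that style, assuming nothing beyond the standard library:

```go
package example

import (
	"errors"
	"fmt"
)

// A sentinel error is a fixed value callers can match on; per-call detail is
// attached by wrapping it with %w.
var errUserNotFound = errors.New("user not found")

func findUser(users []string, name string) (string, error) {
	for _, u := range users {
		if u == name {
			return u, nil
		}
	}

	return "", fmt.Errorf("%w: %s", errUserNotFound, name)
}

// Callers branch on the sentinel regardless of the wrapped detail:
//
//	if errors.Is(err, errUserNotFound) { /* handle missing user */ }
```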
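The alias, owner and approver helpers above drop ptr.To in favour of constructing the pointer with new(...) directly. Where that form is not available in the Go toolchain in use, a one-line generic helper gives the same behaviour; this sketch uses illustrative names (ptrTo, Username) rather than the policyv2 types:

```go
package example

// ptrTo returns a pointer to a copy of v, which is what the previous ptr.To
// calls provided.
func ptrTo[T any](v T) *T {
	return &v
}

type Username string

// usernameAlias mirrors the shape of the usernamep-style helpers above.
func usernameAlias(name string) *Username {
	return ptrTo(Username(name))
}
```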
func (s *Scenario) AddAndLoginClient( t *testing.T, username string, @@ -992,7 +897,7 @@ func (s *Scenario) AddAndLoginClient( } if len(updatedClients) != len(originalClients)+1 { - return struct{}{}, fmt.Errorf("expected %d clients, got %d", len(originalClients)+1, len(updatedClients)) + return struct{}{}, fmt.Errorf("%w: expected %d clients, got %d", errUnexpectedClientCount, len(originalClients)+1, len(updatedClients)) } newClient, err = FindNewClient(originalClients, updatedClients) @@ -1001,7 +906,7 @@ func (s *Scenario) AddAndLoginClient( } return struct{}{}, nil - }, backoff.WithBackOff(backoff.NewConstantBackOff(500*time.Millisecond)), backoff.WithMaxElapsedTime(10*time.Second)) + }, backoff.WithBackOff(backoff.NewConstantBackOff(backoffRetryInterval)), backoff.WithMaxElapsedTime(backoffMaxElapsedTime)) if err != nil { return nil, fmt.Errorf("timeout waiting for new client: %w", err) } @@ -1038,5 +943,6 @@ func (s *Scenario) MustAddAndLoginClient( client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...) require.NoError(t, err) + return client } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 42bb8e93..cfe89dfc 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -16,7 +16,7 @@ import ( "os" "path" "path/filepath" - "sort" + "slices" "strconv" "strings" "time" @@ -46,6 +46,7 @@ const ( tlsKeyPath = "/etc/headscale/tls.key" headscaleDefaultPort = 8080 IntegrationTestDockerFileName = "Dockerfile.integration" + dirPermissions = 0o755 ) var ( @@ -53,6 +54,9 @@ var ( errInvalidHeadscaleImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_HEADSCALE_IMAGE format, expected repository:tag") errHeadscaleImageRequiredInCI = errors.New("HEADSCALE_INTEGRATION_HEADSCALE_IMAGE must be set in CI") errInvalidPostgresImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_POSTGRES_IMAGE format, expected repository:tag") + errDatabaseEmptySchema = errors.New("database file exists but has no schema") + errDatabaseFileEmpty = errors.New("database file is empty") + errNoRegularFileInTar = errors.New("no regular file found in database tar archive") ) type fileInContainer struct { @@ -198,7 +202,7 @@ func WithPostgres() Option { } } -// WithPolicy sets the policy mode for headscale. +// WithPolicyMode sets the policy mode for headscale. func WithPolicyMode(mode types.PolicyMode) Option { return func(hsic *HeadscaleInContainer) { hsic.policyMode = mode @@ -717,20 +721,21 @@ func (t *HeadscaleInContainer) SaveMetrics(savePath string) error { // extractTarToDirectory extracts a tar archive to a directory. 
func extractTarToDirectory(tarData []byte, targetDir string) error { - if err := os.MkdirAll(targetDir, 0o755); err != nil { + err := os.MkdirAll(targetDir, dirPermissions) + if err != nil { return fmt.Errorf("failed to create directory %s: %w", targetDir, err) } - tarReader := tar.NewReader(bytes.NewReader(tarData)) - // Find the top-level directory to strip var topLevelDir string + firstPass := tar.NewReader(bytes.NewReader(tarData)) for { header, err := firstPass.Next() if err == io.EOF { break } + if err != nil { return fmt.Errorf("failed to read tar header: %w", err) } @@ -741,12 +746,13 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { } } - tarReader = tar.NewReader(bytes.NewReader(tarData)) + tarReader := tar.NewReader(bytes.NewReader(tarData)) for { header, err := tarReader.Next() if err == io.EOF { break } + if err != nil { return fmt.Errorf("failed to read tar header: %w", err) } @@ -775,12 +781,15 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { switch header.Typeflag { case tar.TypeDir: // Create directory - if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil { + //nolint:gosec + err := os.MkdirAll(targetPath, os.FileMode(header.Mode)) + if err != nil { return fmt.Errorf("failed to create directory %s: %w", targetPath, err) } case tar.TypeReg: // Ensure parent directories exist - if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + //nolint:noinlineerr + if err := os.MkdirAll(filepath.Dir(targetPath), dirPermissions); err != nil { return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err) } @@ -790,13 +799,16 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { return fmt.Errorf("failed to create file %s: %w", targetPath, err) } + //nolint:gosec,noinlineerr if _, err := io.Copy(outFile, tarReader); err != nil { outFile.Close() return fmt.Errorf("failed to copy file contents: %w", err) } + outFile.Close() // Set file permissions + //nolint:gosec,noinlineerr if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil { return fmt.Errorf("failed to set file permissions: %w", err) } @@ -844,10 +856,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Check if the database file exists and has a schema dbPath := "/tmp/integration_test_db.sqlite3" + fileInfo, err := t.Execute([]string{"ls", "-la", dbPath}) if err != nil { return fmt.Errorf("database file does not exist at %s: %w", dbPath, err) } + log.Printf("Database file info: %s", fileInfo) // Check if the database has any tables (schema) @@ -857,7 +871,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { } if strings.TrimSpace(schemaCheck) == "" { - return errors.New("database file exists but has no schema (empty database)") + return errDatabaseEmptySchema } tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3") @@ -872,6 +886,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { if err == io.EOF { break } + if err != nil { return fmt.Errorf("failed to read tar header: %w", err) } @@ -886,13 +901,16 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Extract the first regular file we find if header.Typeflag == tar.TypeReg { dbPath := path.Join(savePath, t.hostname+".db") + outFile, err := os.Create(dbPath) if err != nil { return fmt.Errorf("failed to create database file: %w", err) } + //nolint:gosec written, err := io.Copy(outFile, tarReader) outFile.Close() + if err != nil { return 
fmt.Errorf("failed to copy database file: %w", err) } @@ -907,7 +925,8 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Check if we actually wrote something if written == 0 { return fmt.Errorf( - "database file is empty (size: %d, header size: %d)", + "%w (size: %d, header size: %d)", + errDatabaseFileEmpty, written, header.Size, ) @@ -917,7 +936,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { } } - return errors.New("no regular file found in database tar archive") + return errNoRegularFileInTar } // Execute runs a command inside the Headscale container and returns the @@ -1059,6 +1078,7 @@ func (t *HeadscaleInContainer) CreateUser( } var u v1.User + err = json.Unmarshal([]byte(result), &u) if err != nil { return nil, fmt.Errorf("failed to unmarshal user: %w", err) @@ -1195,6 +1215,7 @@ func (t *HeadscaleInContainer) ListNodes( users ...string, ) ([]*v1.Node, error) { var ret []*v1.Node + execUnmarshal := func(command []string) error { result, _, err := dockertestutil.ExecuteCommand( t.container, @@ -1206,6 +1227,7 @@ func (t *HeadscaleInContainer) ListNodes( } var nodes []*v1.Node + err = json.Unmarshal([]byte(result), &nodes) if err != nil { return fmt.Errorf("failed to unmarshal nodes: %w", err) @@ -1232,8 +1254,8 @@ func (t *HeadscaleInContainer) ListNodes( } } - sort.Slice(ret, func(i, j int) bool { - return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1 + slices.SortFunc(ret, func(a, b *v1.Node) int { + return cmp.Compare(a.GetId(), b.GetId()) }) return ret, nil @@ -1245,7 +1267,7 @@ func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error { "nodes", "delete", "--identifier", - fmt.Sprintf("%d", nodeID), + strconv.FormatUint(nodeID, 10), "--output", "json", "--force", @@ -1309,6 +1331,7 @@ func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) { } var users []*v1.User + err = json.Unmarshal([]byte(result), &users) if err != nil { return nil, fmt.Errorf("failed to unmarshal nodes: %w", err) @@ -1439,6 +1462,7 @@ func (h *HeadscaleInContainer) PID() (int, error) { if pidInt == 1 { continue } + pids = append(pids, pidInt) } @@ -1494,6 +1518,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) ( } var node *v1.Node + err = json.Unmarshal([]byte(result), &node) if err != nil { return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err) @@ -1569,6 +1594,7 @@ func (t *HeadscaleInContainer) GetAllMapReponses() (map[types.NodeID][]tailcfg.M } var res map[types.NodeID][]tailcfg.MapResponse + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &res); err != nil { return nil, fmt.Errorf("decoding routes response: %w", err) } @@ -1589,6 +1615,7 @@ func (t *HeadscaleInContainer) PrimaryRoutes() (*routes.DebugRoutes, error) { } var debugRoutes routes.DebugRoutes + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &debugRoutes); err != nil { return nil, fmt.Errorf("decoding routes response: %w", err) } @@ -1609,6 +1636,7 @@ func (t *HeadscaleInContainer) DebugBatcher() (*hscontrol.DebugBatcherInfo, erro } var debugInfo hscontrol.DebugBatcherInfo + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &debugInfo); err != nil { return nil, fmt.Errorf("decoding batcher debug response: %w", err) } @@ -1629,6 +1657,7 @@ func (t *HeadscaleInContainer) DebugNodeStore() (map[types.NodeID]types.Node, er } var nodeStore map[types.NodeID]types.Node + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &nodeStore); err != nil { return nil, 
fmt.Errorf("decoding nodestore debug response: %w", err) } @@ -1649,6 +1678,7 @@ func (t *HeadscaleInContainer) DebugFilter() ([]tailcfg.FilterRule, error) { } var filterRules []tailcfg.FilterRule + //nolint:noinlineerr if err := json.Unmarshal([]byte(result), &filterRules); err != nil { return nil, fmt.Errorf("decoding filter response: %w", err) } diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index 4ddc7ae9..0563999e 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -22,18 +22,29 @@ import ( "tailscale.com/tailcfg" ) +// Integration test timing constants. +const ( + // peerSyncTimeoutCI is the peer sync timeout for CI environments. + peerSyncTimeoutCI = 120 * time.Second + // peerSyncTimeoutDev is the peer sync timeout for development environments. + peerSyncTimeoutDev = 60 * time.Second + // peerSyncRetryIntervalMs is the retry interval for peer sync checks. + peerSyncRetryIntervalMs = 100 +) + // PeerSyncTimeout returns the timeout for peer synchronization based on environment: // 60s for dev, 120s for CI. func PeerSyncTimeout() time.Duration { if util.IsCI() { - return 120 * time.Second + return peerSyncTimeoutCI } - return 60 * time.Second + + return peerSyncTimeoutDev } // PeerSyncRetryInterval returns the retry interval for peer synchronization checks. func PeerSyncRetryInterval() time.Duration { - return 100 * time.Millisecond + return peerSyncRetryIntervalMs * time.Millisecond } func WriteFileToContainer( @@ -205,25 +216,27 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type res := make(map[types.NodeID]map[types.NodeID]bool) for nid, mrs := range all { res[nid] = make(map[types.NodeID]bool) + for _, mr := range mrs { for _, peer := range mr.Peers { if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online + res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec } } for _, peer := range mr.PeersChanged { if peer.Online != nil { - res[nid][types.NodeID(peer.ID)] = *peer.Online + res[nid][types.NodeID(peer.ID)] = *peer.Online //nolint:gosec } } for _, peer := range mr.PeersChangedPatch { if peer.Online != nil { - res[nid][types.NodeID(peer.NodeID)] = *peer.Online + res[nid][types.NodeID(peer.NodeID)] = *peer.Online //nolint:gosec } } } } + return res } diff --git a/integration/route_test.go b/integration/route_test.go index 0460b5ef..ea1cf3b8 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -7,7 +7,6 @@ import ( "maps" "net/netip" "slices" - "sort" "strconv" "strings" "testing" @@ -49,6 +48,7 @@ func TestEnablingRoutes(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -91,6 +91,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) @@ -127,6 +128,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route approvals to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) @@ -149,9 +151,11 @@ func TestEnablingRoutes(t *testing.T) { assert.NotNil(c, peerStatus.PrimaryRoutes) assert.NotNil(c, peerStatus.AllowedIPs) + if peerStatus.AllowedIPs != nil { assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3) } + requirePeerSubnetRoutesWithCollect(c, 
peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) } } @@ -172,6 +176,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route state changes to propagate to nodes assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) @@ -271,6 +276,7 @@ func TestHASubnetRouterFailover(t *testing.T) { prefp, err := scenario.SubnetOfNetwork("usernet1") require.NoError(t, err) + pref := *prefp t.Logf("usernet1 prefix: %s", pref.String()) @@ -287,11 +293,11 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf("webservice: %s, %s", webip.String(), weburl) // Sort nodes by ID - sort.SliceStable(allClients, func(i, j int) bool { - statusI := allClients[i].MustStatus() - statusJ := allClients[j].MustStatus() + slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { + statusA := a.MustStatus() + statusB := b.MustStatus() - return statusI.Self.ID < statusJ.Self.ID + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) // This is ok because the scenario makes users in order, so the three first @@ -310,6 +316,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" - Router 2 (%s): Advertising route %s - will be STANDBY when approved", subRouter2.Hostname(), pref.String()) t.Logf(" - Router 3 (%s): Advertising route %s - will be STANDBY when approved", subRouter3.Hostname(), pref.String()) t.Logf(" Expected: All 3 routers advertise the same route for redundancy, but only one will be primary at a time") + for _, client := range allClients[:3] { command := []string{ "tailscale", @@ -325,6 +332,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // Wait for route configuration changes after advertising routes var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { nodes, err = headscale.ListNodes() assert.NoError(c, err) @@ -364,10 +372,12 @@ func TestHASubnetRouterFailover(t *testing.T) { checkFailureAndPrintRoutes := func(t *testing.T, client TailscaleClient) { if t.Failed() { t.Logf("[%s] Test failed at this checkpoint", time.Now().Format(TimestampFormat)) + status, err := client.Status() if err == nil { printCurrentRouteMap(t, xmaps.Values(status.Peer)...) 
} + t.FailNow() } } @@ -386,6 +396,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 1 becomes PRIMARY with route %s active", pref.String()) t.Logf(" Expected: Routers 2 & 3 remain with advertised but unapproved routes") t.Logf(" Expected: Client can access webservice through router 1 only") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -456,10 +467,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1") @@ -483,6 +496,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 becomes STANDBY (approved but not primary)") t.Logf(" Expected: Router 1 remains PRIMARY (no flapping - stability preferred)") t.Logf(" Expected: HA is now active - if router 1 fails, router 2 can take over") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter2.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -494,6 +508,7 @@ func TestHASubnetRouterFailover(t *testing.T) { nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 6) + if len(nodes) >= 3 { requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) @@ -569,10 +584,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 1 in HA mode") @@ -598,6 +615,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 3 becomes second STANDBY (approved but not primary)") t.Logf(" Expected: Router 1 remains PRIMARY, Router 2 remains first STANDBY") t.Logf(" Expected: Full HA configuration with 1 PRIMARY + 2 STANDBY routers") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -672,12 +690,14 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.NotEmpty(c, ips, "subRouter1 should have IP addresses") var expectedIP netip.Addr + for _, ip := range ips { if ip.Is4() { expectedIP = ip break } } + assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address") assertTracerouteViaIPWithCollect(c, tr, expectedIP) @@ -754,10 +774,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after failover") @@ -825,10 +847,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter3.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, 
propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after second failover") @@ -853,6 +877,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 3 remains PRIMARY (stability - no unnecessary failover)") t.Logf(" Expected: Router 1 becomes STANDBY (ready for HA)") t.Logf(" Expected: HA is restored with 2 routers available") + err = subRouter1.Up() require.NoError(t, err) @@ -902,10 +927,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter3.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 3 after router 1 recovery") @@ -932,6 +959,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 1 (%s) remains first STANDBY", subRouter1.Hostname()) t.Logf(" Expected: Router 2 (%s) becomes second STANDBY", subRouter2.Hostname()) t.Logf(" Expected: Full HA restored with all 3 routers online") + err = subRouter2.Up() require.NoError(t, err) @@ -982,10 +1010,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter3.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after full recovery") @@ -1067,10 +1097,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1 after route disable") @@ -1153,10 +1185,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after second route disable") @@ -1182,6 +1216,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability - no unnecessary flapping)", subRouter2.Hostname()) t.Logf(" Expected: Router 1 (%s) becomes STANDBY (approved but not primary)", subRouter1.Hostname()) t.Logf(" Expected: HA fully restored with Router 2 PRIMARY and Router 1 STANDBY") + r1Node := MustFindNode(subRouter1.Hostname(), nodes) _, err = headscale.ApproveRoutes( r1Node.GetId(), @@ -1237,10 +1272,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 2 after route re-enable") @@ -1266,6 +1303,7 @@ func TestHASubnetRouterFailover(t *testing.T) { 
t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability preferred)", subRouter2.Hostname()) t.Logf(" Expected: Routers 1 & 3 are both STANDBY") t.Logf(" Expected: Full HA restored with all 3 routers available") + r3Node := MustFindNode(subRouter3.Hostname(), nodes) _, err = headscale.ApproveRoutes( r3Node.GetId(), @@ -1315,6 +1353,7 @@ func TestSubnetRouteACL(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -1359,10 +1398,11 @@ func TestSubnetRouteACL(t *testing.T) { } // Sort nodes by ID - sort.SliceStable(allClients, func(i, j int) bool { - statusI := allClients[i].MustStatus() - statusJ := allClients[j].MustStatus() - return statusI.Self.ID < statusJ.Self.ID + slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { + statusA := a.MustStatus() + statusB := b.MustStatus() + + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) subRouter1 := allClients[0] @@ -1391,15 +1431,20 @@ func TestSubnetRouteACL(t *testing.T) { // Wait for route advertisements to propagate to the server var nodes []*v1.Node + require.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 2) // Find the node that should have the route by checking node IDs - var routeNode *v1.Node - var otherNode *v1.Node + var ( + routeNode *v1.Node + otherNode *v1.Node + ) + for _, node := range nodes { nodeIDStr := strconv.FormatUint(node.GetId(), 10) if _, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute { @@ -1462,6 +1507,7 @@ func TestSubnetRouteACL(t *testing.T) { srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + if srs1PeerStatus == nil { return } @@ -1564,7 +1610,7 @@ func TestSubnetRouteACL(t *testing.T) { func TestEnablingExitRoutes(t *testing.T) { IntegrationSkip(t) - user := "user2" + user := "user2" //nolint:goconst spec := ScenarioSpec{ NodesPerUser: 2, @@ -1572,6 +1618,7 @@ func TestEnablingExitRoutes(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario") defer scenario.ShutdownAssertNoPanics(t) @@ -1593,8 +1640,10 @@ func TestEnablingExitRoutes(t *testing.T) { requireNoErrSync(t, err) var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 2) @@ -1652,6 +1701,7 @@ func TestEnablingExitRoutes(t *testing.T) { peerStatus := status.Peer[peerKey] assert.NotNil(c, peerStatus.AllowedIPs) + if peerStatus.AllowedIPs != nil { assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4) assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4()) @@ -1682,6 +1732,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -1712,10 +1763,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { if s.User[s.Self.UserID].LoginName == "user1@test.no" { user1c = c } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { user2c = c } } + require.NotNil(t, user1c) require.NotNil(t, user2c) @@ -1732,6 +1785,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() 
assert.NoError(ct, err) assert.Len(ct, nodes, 2) @@ -1762,6 +1816,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { // Wait for route state changes to propagate to nodes assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 2) @@ -1779,6 +1834,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref}) } }, 10*time.Second, 500*time.Millisecond, "routes should be visible to client") @@ -1805,10 +1861,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := user2c.Traceroute(webip) assert.NoError(c, err) + ip, err := user1c.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for user1c") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, 5*time.Second, 200*time.Millisecond, "Verifying traceroute goes through subnet router") } @@ -1829,6 +1887,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -1856,10 +1915,12 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { if s.User[s.Self.UserID].LoginName == "user1@test.no" { user1c = c } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { user2c = c } } + require.NotNil(t, user1c) require.NotNil(t, user2c) @@ -1876,6 +1937,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) assert.Len(ct, nodes, 2) @@ -1958,6 +2020,7 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node { return node } } + panic("node not found") } @@ -1977,6 +2040,8 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node { // - Verify that peers can no longer use node // - Policy is changed back to auto approve route, check that routes already existing is approved. // - Verify that routes can now be seen by peers. 
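The route tests above swap sort.SliceStable for slices.SortStableFunc with a cmp.Compare comparator. A self-contained sketch of that comparator shape, using an illustrative client type rather than TailscaleClient:

```go
package example

import (
	"cmp"
	"slices"
)

type client struct {
	id       uint64
	hostname string
}

// sortClientsByID sorts stably by ID: the comparator returns a negative,
// zero, or positive int, and equal elements keep their original order.
func sortClientsByID(clients []client) {
	slices.SortStableFunc(clients, func(a, b client) int {
		return cmp.Compare(a.id, b.id)
	})
}
```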
+// +//nolint:gocyclo func TestAutoApproveMultiNetwork(t *testing.T) { IntegrationSkip(t) @@ -2241,10 +2306,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } scenario, err := NewScenario(tt.spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) var nodes []*v1.Node + opts := []hsic.Option{ hsic.WithTestName("autoapprovemulti"), hsic.WithEmbeddedDERPServerOnly(), @@ -2300,6 +2367,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // Add the Docker network route to the auto-approvers // Keep existing auto-approvers (like bigRoute) in place var approvers policyv2.AutoApprovers + switch { case strings.HasPrefix(tt.approver, "tag:"): approvers = append(approvers, tagApprover(tt.approver)) @@ -2368,6 +2436,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } else { pak, err = scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false) } + require.NoError(t, err) err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey()) @@ -2403,11 +2472,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { t.Logf("webservice: %s, %s", webip.String(), weburl) // Sort nodes by ID - sort.SliceStable(allClients, func(i, j int) bool { - statusI := allClients[i].MustStatus() - statusJ := allClients[j].MustStatus() + slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { + statusA := a.MustStatus() + statusB := b.MustStatus() - return statusI.Self.ID < statusJ.Self.ID + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) // This is ok because the scenario makes users in order, so the three first @@ -2459,11 +2528,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) { t.Logf("Client %s sees %d peers", client.Hostname(), len(status.Peers())) routerPeerFound := false + for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] if peerStatus.ID == routerUsernet1ID.StableID() { routerPeerFound = true + t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v", peerStatus.HostName, peerStatus.ID, @@ -2471,9 +2542,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { peerStatus.PrimaryRoutes) assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2510,10 +2583,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through auto-approved router") @@ -2550,9 +2625,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.ID == routerUsernet1ID.StableID() { assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2572,10 +2649,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for 
routerUsernet1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, assertTimeout, 200*time.Millisecond, "Verifying traceroute still goes through router after policy change") @@ -2609,6 +2688,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { // Add the route back to the auto approver in the policy, the route should // now become available again. var newApprovers policyv2.AutoApprovers + switch { case strings.HasPrefix(tt.approver, "tag:"): newApprovers = append(newApprovers, tagApprover(tt.approver)) @@ -2642,9 +2722,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.ID == routerUsernet1ID.StableID() { assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2664,10 +2746,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through router after re-approval") @@ -2703,11 +2787,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else if peerStatus.ID == "2" { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), subRoute) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{subRoute}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2745,9 +2831,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.ID == routerUsernet1ID.StableID() { assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else { requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) @@ -2785,6 +2873,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) { if peerStatus.PrimaryRoutes != nil { assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) } else if peerStatus.ID == "3" { requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}) @@ -2800,6 +2889,8 @@ func TestAutoApproveMultiNetwork(t *testing.T) { } // assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT. 
+// +//nolint:testifylint func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) { assert.NotNil(c, tr) assert.True(c, tr.Success) @@ -2817,12 +2908,16 @@ func SortPeerStatus(a, b *ipnstate.PeerStatus) int { } func printCurrentRouteMap(t *testing.T, routers ...*ipnstate.PeerStatus) { + t.Helper() + t.Logf("== Current routing map ==") slices.SortFunc(routers, SortPeerStatus) + for _, router := range routers { got := filterNonRoutes(router) t.Logf(" Router %s (%s) is serving:", router.HostName, router.ID) t.Logf(" AllowedIPs: %v", got) + if router.PrimaryRoutes != nil { t.Logf(" PrimaryRoutes: %v", router.PrimaryRoutes.AsSlice()) } @@ -2835,6 +2930,7 @@ func filterNonRoutes(status *ipnstate.PeerStatus) []netip.Prefix { if tsaddr.IsExitRoute(p) { return true } + return !slices.ContainsFunc(status.TailscaleIPs, p.Contains) }) } @@ -2886,6 +2982,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -3026,6 +3123,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { // List nodes and verify the router has 3 available routes var err error + nodes, err := headscale.NodesByUser() assert.NoError(c, err) assert.Len(c, nodes, 2) @@ -3061,10 +3159,12 @@ func TestSubnetRouteACLFiltering(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := nodeClient.Traceroute(webip) assert.NoError(c, err) + ip, err := routerClient.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for routerClient") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, 60*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router") } diff --git a/integration/scenario.go b/integration/scenario.go index 35fee73e..743c5830 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -46,6 +46,8 @@ import ( const ( scenarioHashLength = 6 + // expectedHTMLSplitParts is the expected number of parts when splitting HTML for key extraction. + expectedHTMLSplitParts = 2 ) var usePostgresForTest = envknob.Bool("HEADSCALE_INTEGRATION_POSTGRES") @@ -54,6 +56,13 @@ var ( errNoHeadscaleAvailable = errors.New("no headscale available") errNoUserAvailable = errors.New("no user available") errNoClientFound = errors.New("client not found") + errUserAlreadyInNetwork = errors.New("users can only have nodes placed in one network") + errNoNetworkNamed = errors.New("no network named") + errNoIPAMConfig = errors.New("no IPAM config found in network") + errHTTPClientNil = errors.New("http client is nil") + errLoginURLNil = errors.New("login url is nil") + errUnexpectedStatusCode = errors.New("unexpected status code") + errNetworkDoesNotExist = errors.New("network does not exist") // AllVersions represents a list of Tailscale versions the suite // uses to test compatibility with the ControlServer. @@ -96,7 +105,7 @@ type User struct { type Scenario struct { // TODO(kradalby): support multiple headcales for later, currently only // use one. - controlServers *xsync.MapOf[string, ControlServer] + controlServers *xsync.Map[string, ControlServer] derpServers []*dsic.DERPServerInContainer users map[string]*User @@ -169,8 +178,8 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { // Opportunity to clean up unreferenced networks. // This might be a no op, but it is worth a try as we sometime // dont clean up nicely after ourselves. 
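// scenario.go above introduces package-level sentinel errors and wraps them
// with %w instead of building ad-hoc strings, so callers can branch with
// errors.Is while keeping the contextual message. A small stdlib-only sketch
// of the same idea (the lookup map and network name here are invented):
package main

import (
	"errors"
	"fmt"
)

var errNoNetworkNamed = errors.New("no network named")

func lookupNetwork(networks map[string]string, name string) (string, error) {
	subnet, ok := networks[name]
	if !ok {
		// %w keeps errNoNetworkNamed in the error chain; the name is
		// appended purely as context.
		return "", fmt.Errorf("%w: %s", errNoNetworkNamed, name)
	}

	return subnet, nil
}

func main() {
	_, err := lookupNetwork(map[string]string{"usernet1": "10.33.0.0/16"}, "missing")
	if errors.Is(err, errNoNetworkNamed) {
		fmt.Println("expected failure:", err)
	}
}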
- dockertestutil.CleanUnreferencedNetworks(pool) - dockertestutil.CleanImagesInCI(pool) + _ = dockertestutil.CleanUnreferencedNetworks(pool) + _ = dockertestutil.CleanImagesInCI(pool) if spec.MaxWait == 0 { pool.MaxWait = dockertestMaxWait() @@ -180,7 +189,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { testHashPrefix := "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength) s := &Scenario{ - controlServers: xsync.NewMapOf[string, ControlServer](), + controlServers: xsync.NewMap[string, ControlServer](), users: make(map[string]*User), pool: pool, @@ -191,9 +200,11 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { } var userToNetwork map[string]*dockertest.Network + if spec.Networks != nil || len(spec.Networks) != 0 { for name, users := range s.spec.Networks { networkName := testHashPrefix + "-" + name + network, err := s.AddNetwork(networkName) if err != nil { return nil, err @@ -201,8 +212,9 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { for _, user := range users { if n2, ok := userToNetwork[user]; ok { - return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name) + return nil, fmt.Errorf("%w: %s into %s but already in %s", errUserAlreadyInNetwork, user, network.Network.Name, n2.Network.Name) } + mak.Set(&userToNetwork, user, network) } } @@ -219,6 +231,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { if err != nil { return nil, err } + mak.Set(&s.extraServices, s.prefixedNetworkName(network), append(s.extraServices[s.prefixedNetworkName(network)], svc)) } } @@ -230,6 +243,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) { if spec.OIDCAccessTTL != 0 { ttl = spec.OIDCAccessTTL } + err = s.runMockOIDC(ttl, spec.OIDCUsers) if err != nil { return nil, err @@ -268,13 +282,14 @@ func (s *Scenario) Networks() []*dockertest.Network { if len(s.networks) == 0 { panic("Scenario.Networks called with empty network list") } + return xmaps.Values(s.networks) } func (s *Scenario) Network(name string) (*dockertest.Network, error) { net, ok := s.networks[s.prefixedNetworkName(name)] if !ok { - return nil, fmt.Errorf("no network named: %s", name) + return nil, fmt.Errorf("%w: %s", errNoNetworkNamed, name) } return net, nil @@ -283,11 +298,11 @@ func (s *Scenario) Network(name string) (*dockertest.Network, error) { func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { net, ok := s.networks[s.prefixedNetworkName(name)] if !ok { - return nil, fmt.Errorf("no network named: %s", name) + return nil, fmt.Errorf("%w: %s", errNoNetworkNamed, name) } if len(net.Network.IPAM.Config) == 0 { - return nil, fmt.Errorf("no IPAM config found in network: %s", name) + return nil, fmt.Errorf("%w: %s", errNoIPAMConfig, name) } pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet) @@ -301,15 +316,17 @@ func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { res, ok := s.extraServices[s.prefixedNetworkName(name)] if !ok { - return nil, fmt.Errorf("no network named: %s", name) + return nil, fmt.Errorf("%w: %s", errNoNetworkNamed, name) } return res, nil } func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { - defer dockertestutil.CleanUnreferencedNetworks(s.pool) - defer dockertestutil.CleanImagesInCI(s.pool) + t.Helper() + + defer func() { _ = dockertestutil.CleanUnreferencedNetworks(s.pool) }() + defer func() { _ = 
dockertestutil.CleanImagesInCI(s.pool) }() s.controlServers.Range(func(_ string, control ControlServer) bool { stdoutPath, stderrPath, err := control.Shutdown() @@ -337,6 +354,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { for userName, user := range s.users { for _, client := range user.Clients { log.Printf("removing client %s in user %s", client.Hostname(), userName) + stdoutPath, stderrPath, err := client.Shutdown() if err != nil { log.Printf("failed to tear down client: %s", err) @@ -353,6 +371,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { } } } + s.mu.Unlock() for _, derp := range s.derpServers { @@ -373,13 +392,16 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { if s.mockOIDC.r != nil { s.mockOIDC.r.Close() - if err := s.mockOIDC.r.Close(); err != nil { + + err := s.mockOIDC.r.Close() + if err != nil { log.Printf("failed to tear down oidc server: %s", err) } } for _, network := range s.networks { - if err := network.Close(); err != nil { + err := network.Close() + if err != nil { log.Printf("failed to tear down network: %s", err) } } @@ -552,6 +574,7 @@ func (s *Scenario) CreateTailscaleNode( s.mu.Lock() defer s.mu.Unlock() + opts = append(opts, tsic.WithCACert(cert), tsic.WithHeadscaleName(hostname), @@ -591,6 +614,7 @@ func (s *Scenario) CreateTailscaleNodesInUser( ) error { if user, ok := s.users[userStr]; ok { var versions []string + for i := range count { version := requestedVersion if requestedVersion == "all" { @@ -749,11 +773,14 @@ func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Durat for _, client := range user.Clients { c := client expectedCount := expectedPeers + user.syncWaitGroup.Go(func() error { return c.WaitForPeers(expectedCount, timeout, retryInterval) }) } - if err := user.syncWaitGroup.Wait(); err != nil { + + err := user.syncWaitGroup.Wait() + if err != nil { allErrors = append(allErrors, err) } } @@ -871,6 +898,7 @@ func (s *Scenario) createHeadscaleEnvWithTags( } else { key, err = s.CreatePreAuthKey(u.GetId(), true, false) } + if err != nil { return err } @@ -887,9 +915,11 @@ func (s *Scenario) createHeadscaleEnvWithTags( func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { log.Printf("running tailscale up for user %s", userStr) + if user, ok := s.users[userStr]; ok { for _, client := range user.Clients { tsc := client + user.joinWaitGroup.Go(func() error { loginURL, err := tsc.LoginWithURL(loginServer) if err != nil { @@ -904,7 +934,7 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { // If the URL is not a OIDC URL, then we need to // run the register command to fully log in the client. 
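// The scenario's per-user wait groups (syncWaitGroup, joinWaitGroup) fan work
// out per client via a Go/Wait pair that collects the first error. A minimal
// sketch of that pattern using golang.org/x/sync/errgroup, which is assumed
// here to be what those fields are backed by; the client names and the
// waitForPeers stand-in are invented for illustration:
package main

import (
	"fmt"

	"golang.org/x/sync/errgroup"
)

// waitForPeers is a stand-in for TailscaleClient.WaitForPeers.
func waitForPeers(client string, expected int) error {
	fmt.Printf("%s waiting for %d peers\n", client, expected)
	return nil
}

func main() {
	clients := []string{"ts-user1-1", "ts-user1-2", "ts-user2-1"}

	var g errgroup.Group
	for _, client := range clients {
		c := client // mirror the per-iteration copy used in scenario.go

		g.Go(func() error {
			return waitForPeers(c, len(clients)-1)
		})
	}

	// Wait blocks until every goroutine returns and yields the first
	// non-nil error, if any.
	if err := g.Wait(); err != nil {
		fmt.Println("sync failed:", err)
	}
}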
if !strings.Contains(loginURL.String(), "/oidc/") { - s.runHeadscaleRegister(userStr, body) + _ = s.runHeadscaleRegister(userStr, body) } return nil @@ -913,7 +943,8 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { log.Printf("client %s is ready", client.Hostname()) } - if err := user.joinWaitGroup.Wait(); err != nil { + err := user.joinWaitGroup.Wait() + if err != nil { return err } @@ -945,6 +976,7 @@ func newDebugJar() (*debugJar, error) { if err != nil { return nil, err } + return &debugJar{ inner: jar, store: make(map[string]map[string]map[string]*http.Cookie), @@ -961,20 +993,25 @@ func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) { if c == nil || c.Name == "" { continue } + domain := c.Domain if domain == "" { domain = u.Hostname() } + path := c.Path if path == "" { path = "/" } + if _, ok := j.store[domain]; !ok { j.store[domain] = make(map[string]map[string]*http.Cookie) } + if _, ok := j.store[domain][path]; !ok { j.store[domain][path] = make(map[string]*http.Cookie) } + j.store[domain][path][c.Name] = copyCookie(c) } } @@ -989,8 +1026,10 @@ func (j *debugJar) Dump(w io.Writer) { for domain, paths := range j.store { fmt.Fprintf(w, "Domain: %s\n", domain) + for path, byName := range paths { fmt.Fprintf(w, " Path: %s\n", path) + for _, c := range byName { fmt.Fprintf( w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n", @@ -1046,15 +1085,17 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f error, ) { if hc == nil { - return "", nil, fmt.Errorf("%s http client is nil", hostname) + return "", nil, fmt.Errorf("%s %w", hostname, errHTTPClientNil) } if loginURL == nil { - return "", nil, fmt.Errorf("%s login url is nil", hostname) + return "", nil, fmt.Errorf("%s %w", hostname, errLoginURLNil) } log.Printf("%s logging in with url: %s", hostname, loginURL.String()) + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) if err != nil { return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err) @@ -1066,6 +1107,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f return http.ErrUseLastResponse } } + defer func() { hc.CheckRedirect = originalRedirect }() @@ -1080,6 +1122,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f if err != nil { return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err) } + body := string(bodyBytes) var redirectURL *url.URL @@ -1093,13 +1136,13 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f if followRedirects && resp.StatusCode != http.StatusOK { log.Printf("body: %s", body) - return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + return body, redirectURL, fmt.Errorf("%s %w %d", hostname, errUnexpectedStatusCode, resp.StatusCode) } if resp.StatusCode >= http.StatusBadRequest { log.Printf("body: %s", body) - return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + return body, redirectURL, fmt.Errorf("%s %w %d", hostname, errUnexpectedStatusCode, resp.StatusCode) } if hc.Jar != nil { @@ -1117,19 +1160,21 @@ var errParseAuthPage = errors.New("failed to parse auth page") func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { // see api.go HTML template - codeSep := strings.Split(string(body), "") - if len(codeSep) != 2 { + codeSep := strings.Split(body, "") + if 
len(codeSep) != expectedHTMLSplitParts { return errParseAuthPage } keySep := strings.Split(codeSep[0], "key ") - if len(keySep) != 2 { + if len(keySep) != expectedHTMLSplitParts { return errParseAuthPage } + key := keySep[1] - key = strings.SplitN(key, " ", 2)[0] + key = strings.SplitN(key, " ", expectedHTMLSplitParts)[0] log.Printf("registering node %s", key) + //nolint:noinlineerr if headscale, err := s.Headscale(); err == nil { _, err = headscale.Execute( []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, @@ -1154,6 +1199,7 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error noTls := &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint } + resp, err := noTls.RoundTrip(req) if err != nil { return nil, err @@ -1361,6 +1407,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse if err != nil { log.Fatalf("could not find an open port: %s", err) } + portNotation := fmt.Sprintf("%d/tcp", port) hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) @@ -1405,6 +1452,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse // Add integration test labels if running under hi tool dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc") + //nolint:noinlineerr if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( headscaleBuildOptions, mockOidcOptions, @@ -1421,8 +1469,10 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse ipAddr := s.mockOIDC.r.GetIPInNetwork(network) log.Println("Waiting for headscale mock oidc to be ready for tests") + hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) + //nolint:noinlineerr if err := s.pool.Retry(func() error { oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) httpClient := &http.Client{} @@ -1468,14 +1518,13 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) { // log.Fatalf("could not find an open port: %s", err) // } // portNotation := fmt.Sprintf("%d/tcp", port) - hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength) hostname := "hs-webservice-" + hash network, ok := s.networks[s.prefixedNetworkName(networkName)] if !ok { - return nil, fmt.Errorf("network does not exist: %s", networkName) + return nil, fmt.Errorf("%w: %s", errNetworkDoesNotExist, networkName) } webOpts := &dockertest.RunOptions{ diff --git a/integration/scenario_test.go b/integration/scenario_test.go index 1e2a151a..71998fca 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -35,6 +35,7 @@ func TestHeadscale(t *testing.T) { user := "test-space" scenario, err := NewScenario(ScenarioSpec{}) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -83,6 +84,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { count := 1 scenario, err := NewScenario(ScenarioSpec{}) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 1ca291c0..8329155f 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) func isSSHNoAccessStdError(stderr string) bool { @@ -453,7 +452,7 @@ func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient func assertSSHNoAccessStdError(t *testing.T, err error, stderr 
string) { t.Helper() - assert.Error(t, err) + require.Error(t, err) if !isSSHNoAccessStdError(stderr) { t.Errorf("expected stderr output suggesting access denied, got: %s", stderr) @@ -482,10 +481,10 @@ func TestSSHAutogroupSelf(t *testing.T) { { Action: "accept", Sources: policyv2.SSHSrcAliases{ - ptr.To(policyv2.AutoGroupMember), + new(policyv2.AutoGroupMember), }, Destinations: policyv2.SSHDstAliases{ - ptr.To(policyv2.AutoGroupSelf), + new(policyv2.AutoGroupSelf), }, Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, diff --git a/integration/tags_test.go b/integration/tags_test.go index 5dad36e5..9c8391c2 100644 --- a/integration/tags_test.go +++ b/integration/tags_test.go @@ -1,7 +1,7 @@ package integration import ( - "sort" + "slices" "testing" "time" @@ -13,7 +13,6 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "tailscale.com/tailcfg" - "tailscale.com/types/ptr" ) const tagTestUser = "taguser" @@ -30,9 +29,9 @@ const tagTestUser = "taguser" func tagsTestPolicy() *policyv2.Policy { return &policyv2.Policy{ TagOwners: policyv2.TagOwners{ - "tag:valid-owned": policyv2.Owners{ptr.To(policyv2.Username(tagTestUser + "@"))}, - "tag:second": policyv2.Owners{ptr.To(policyv2.Username(tagTestUser + "@"))}, - "tag:valid-unowned": policyv2.Owners{ptr.To(policyv2.Username("other-user@"))}, + "tag:valid-owned": policyv2.Owners{new(policyv2.Username(tagTestUser + "@"))}, + "tag:second": policyv2.Owners{new(policyv2.Username(tagTestUser + "@"))}, + "tag:valid-unowned": policyv2.Owners{new(policyv2.Username("other-user@"))}, // Note: tag:nonexistent deliberately NOT defined }, ACLs: []policyv2.ACL{ @@ -51,11 +50,11 @@ func tagsEqual(actual, expected []string) bool { return false } - sortedActual := append([]string{}, actual...) - sortedExpected := append([]string{}, expected...) + sortedActual := slices.Clone(actual) + sortedExpected := slices.Clone(expected) - sort.Strings(sortedActual) - sort.Strings(sortedExpected) + slices.Sort(sortedActual) + slices.Sort(sortedExpected) for i := range sortedActual { if sortedActual[i] != sortedExpected[i] { @@ -69,11 +68,11 @@ func tagsEqual(actual, expected []string) bool { // assertNodeHasTagsWithCollect asserts that a node has exactly the expected tags (order-independent). func assertNodeHasTagsWithCollect(c *assert.CollectT, node *v1.Node, expectedTags []string) { actualTags := node.GetTags() - sortedActual := append([]string{}, actualTags...) - sortedExpected := append([]string{}, expectedTags...) + sortedActual := slices.Clone(actualTags) + sortedExpected := slices.Clone(expectedTags) - sort.Strings(sortedActual) - sort.Strings(sortedExpected) + slices.Sort(sortedActual) + slices.Sort(sortedExpected) assert.Equal(c, sortedExpected, sortedActual, "Node %s tags mismatch", node.GetName()) } @@ -86,7 +85,7 @@ func assertNodeHasNoTagsWithCollect(c *assert.CollectT, node *v1.Node) { // This validates that tag updates have propagated to the node's own status (issue #2978). func assertNodeSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClient, expectedTags []string) { status, err := client.Status() - //nolint:testifylint // must use assert with CollectT in EventuallyWithT + //nolint:testifylint assert.NoError(c, err, "failed to get client status") if status == nil || status.Self == nil { @@ -102,11 +101,11 @@ func assertNodeSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClient } } - sortedActual := append([]string{}, actualTagsSlice...) 
- sortedExpected := append([]string{}, expectedTags...) + sortedActual := slices.Clone(actualTagsSlice) + sortedExpected := slices.Clone(expectedTags) - sort.Strings(sortedActual) - sort.Strings(sortedExpected) + slices.Sort(sortedActual) + slices.Sort(sortedExpected) assert.Equal(c, sortedExpected, sortedActual, "Client %s self tags mismatch", client.Hostname()) } @@ -557,7 +556,7 @@ func TestTagsAuthKeyWithTagAdminOverrideReauthPreserves(t *testing.T) { "--authkey=" + authKey.GetKey(), "--force-reauth", } - //nolint:errcheck // Intentionally ignoring error - we check results below + //nolint:errcheck client.Execute(command) // Verify admin tags are preserved even after reauth - admin decisions are authoritative (server-side) @@ -2491,7 +2490,7 @@ func TestTagsAdminAPICannotRemoveAllTags(t *testing.T) { // This validates at a deeper level than status - directly from tailscale debug netmap. func assertNetmapSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClient, expectedTags []string) { nm, err := client.Netmap() - //nolint:testifylint // must use assert with CollectT in EventuallyWithT + //nolint:testifylint assert.NoError(c, err, "failed to get client netmap") if nm == nil { @@ -2502,16 +2501,17 @@ func assertNetmapSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClie var actualTagsSlice []string if nm.SelfNode.Valid() { + //nolint:unqueryvet for _, tag := range nm.SelfNode.Tags().All() { actualTagsSlice = append(actualTagsSlice, tag) } } - sortedActual := append([]string{}, actualTagsSlice...) - sortedExpected := append([]string{}, expectedTags...) + sortedActual := slices.Clone(actualTagsSlice) + sortedExpected := slices.Clone(expectedTags) - sort.Strings(sortedActual) - sort.Strings(sortedExpected) + slices.Sort(sortedActual) + slices.Sort(sortedExpected) assert.Equal(c, sortedExpected, sortedActual, "Client %s netmap self tags mismatch", client.Hostname()) } @@ -2624,7 +2624,7 @@ func TestTagsIssue2978ReproTagReplacement(t *testing.T) { // We wait 10 seconds and check - if the client STILL shows the OLD tag, // that demonstrates the bug. If the client shows the NEW tag, the bug is fixed. t.Log("Step 2b: Waiting 10 seconds to see if client self view updates (bug: it should NOT)") - //nolint:forbidigo // intentional sleep to demonstrate bug timing - client should get update immediately, not after waiting + //nolint:forbidigo time.Sleep(10 * time.Second) // Check client status after waiting @@ -2647,6 +2647,7 @@ func TestTagsIssue2978ReproTagReplacement(t *testing.T) { var netmapTagsAfterFirstCall []string if nmErr == nil && nm != nil && nm.SelfNode.Valid() { + //nolint:unqueryvet for _, tag := range nm.SelfNode.Tags().All() { netmapTagsAfterFirstCall = append(netmapTagsAfterFirstCall, tag) } @@ -2693,7 +2694,7 @@ func TestTagsIssue2978ReproTagReplacement(t *testing.T) { // Wait and check - bug means client still shows old tag t.Log("Step 4b: Waiting 10 seconds to see if client self view updates (bug: it should NOT)") - //nolint:forbidigo // intentional sleep to demonstrate bug timing - client should get update immediately, not after waiting + //nolint:forbidigo time.Sleep(10 * time.Second) status, err = client.Status() diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index fb07896b..6c7228a0 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -44,6 +44,13 @@ const ( dockerContextPath = "../." caCertRoot = "/usr/local/share/ca-certificates" dockerExecuteTimeout = 60 * time.Second + + // Container restart and backoff timeouts. 
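// The tags tests above replace append([]string{}, s...) + sort.Strings with
// slices.Clone + slices.Sort before comparing, so the callers' slices are
// never mutated and ordering does not matter. A standalone sketch of that
// comparison; slices.Equal stands in for the manual length check and loop
// used in the test, and the tag values are invented:
package main

import (
	"fmt"
	"slices"
)

// tagsEqual reports whether two tag sets contain the same elements,
// ignoring order, without touching the caller's slices.
func tagsEqual(actual, expected []string) bool {
	sortedActual := slices.Clone(actual)
	sortedExpected := slices.Clone(expected)

	slices.Sort(sortedActual)
	slices.Sort(sortedExpected)

	return slices.Equal(sortedActual, sortedExpected)
}

func main() {
	fmt.Println(tagsEqual(
		[]string{"tag:second", "tag:valid-owned"},
		[]string{"tag:valid-owned", "tag:second"},
	)) // true
}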
+ containerRestartTimeout = 30 // seconds, used by Docker API + tailscaleVersionTimeout = 5 * time.Second + containerRestartBackoff = 30 * time.Second + backoffMaxElapsedTime = 10 * time.Second + curlFailFastMaxTime = 2 * time.Second ) var ( @@ -59,6 +66,17 @@ var ( errTailscaleImageRequiredInCI = errors.New("HEADSCALE_INTEGRATION_TAILSCALE_IMAGE must be set in CI for HEAD version") errContainerNotInitialized = errors.New("container not initialized") errFQDNNotYetAvailable = errors.New("FQDN not yet available") + errNoNetworkSet = errors.New("no network set") + errLogoutFailed = errors.New("failed to logout") + errNoIPsReturned = errors.New("no IPs returned yet") + errNoIPv4AddressFound = errors.New("no IPv4 address found") + errBackendStateTimeout = errors.New("timeout waiting for backend state") + errPeerWaitTimeout = errors.New("timeout waiting for peers") + errPeerNotOnline = errors.New("peer is not online") + errPeerNoHostname = errors.New("peer does not have a hostname") + errPeerNoDERP = errors.New("peer does not have a DERP relay") + errFileEmpty = errors.New("file is empty") + errTailscaleVersionRequired = errors.New("tailscale version requirement not met") ) const ( @@ -338,7 +356,7 @@ func New( } if tsic.network == nil { - return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack())) + return nil, fmt.Errorf("%w, called from: \n%s", errNoNetworkSet, string(debug.Stack())) } tailscaleOptions := &dockertest.RunOptions{ @@ -621,7 +639,7 @@ func (t *TailscaleInContainer) Execute( return stdout, stderr, nil } -// Retrieve container logs. +// Logs retrieves container logs. func (t *TailscaleInContainer) Logs(stdout, stderr io.Writer) error { return dockertestutil.WriteLog( t.pool, @@ -713,14 +731,14 @@ func (t *TailscaleInContainer) LoginWithURL( // Logout runs the logout routine on the given Tailscale instance. 
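// Restart below re-polls the container with the generics-based backoff.Retry
// API (context first, an operation returning (T, error), and options for the
// retry policy). A standalone sketch of that call shape, assuming the
// cenkalti/backoff/v5 module headscale appears to use and an invented
// readiness check in place of the real "tailscale version" probe:
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v5"
)

func main() {
	attempts := 0

	// Retry the operation with exponential backoff until it succeeds or the
	// maximum elapsed time is reached.
	_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
		attempts++
		if attempts < 3 {
			return struct{}{}, errors.New("container not ready")
		}

		return struct{}{}, nil
	}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
	if err != nil {
		fmt.Println("gave up:", err)
		return
	}

	fmt.Println("ready after", attempts, "attempts")
}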
func (t *TailscaleInContainer) Logout() error { - stdout, stderr, err := t.Execute([]string{"tailscale", "logout"}) + _, _, err := t.Execute([]string{"tailscale", "logout"}) if err != nil { return err } - stdout, stderr, _ = t.Execute([]string{"tailscale", "status"}) + stdout, stderr, _ := t.Execute([]string{"tailscale", "status"}) if !strings.Contains(stdout+stderr, "Logged out.") { - return fmt.Errorf("failed to logout, stdout: %s, stderr: %s", stdout, stderr) + return fmt.Errorf("%w: stdout: %s, stderr: %s", errLogoutFailed, stdout, stderr) } return t.waitForBackendState("NeedsLogin", integrationutil.PeerSyncTimeout()) @@ -736,7 +754,7 @@ func (t *TailscaleInContainer) Restart() error { } // Use Docker API to restart the container - err := t.pool.Client.RestartContainer(t.container.Container.ID, 30) + err := t.pool.Client.RestartContainer(t.container.Container.ID, containerRestartTimeout) if err != nil { return fmt.Errorf("failed to restart container %s: %w", t.hostname, err) } @@ -745,13 +763,13 @@ func (t *TailscaleInContainer) Restart() error { // We use exponential backoff to poll until we can successfully execute a command _, err = backoff.Retry(context.Background(), func() (struct{}, error) { // Try to execute a simple command to verify the container is responsive - _, _, err := t.Execute([]string{"tailscale", "version"}, dockertestutil.ExecuteCommandTimeout(5*time.Second)) + _, _, err := t.Execute([]string{"tailscale", "version"}, dockertestutil.ExecuteCommandTimeout(tailscaleVersionTimeout)) if err != nil { return struct{}{}, fmt.Errorf("container not ready: %w", err) } return struct{}{}, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(30*time.Second)) + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(containerRestartBackoff)) if err != nil { return fmt.Errorf("timeout waiting for container %s to restart and become ready: %w", t.hostname, err) } @@ -832,11 +850,11 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) { } if len(ips) == 0 { - return nil, fmt.Errorf("no IPs returned yet for %s", t.hostname) + return nil, fmt.Errorf("%w for %s", errNoIPsReturned, t.hostname) } return ips, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(backoffMaxElapsedTime)) if err != nil { return nil, fmt.Errorf("failed to get IPs for %s after retries: %w", t.hostname, err) } @@ -866,7 +884,7 @@ func (t *TailscaleInContainer) IPv4() (netip.Addr, error) { } } - return netip.Addr{}, fmt.Errorf("no IPv4 address found for %s", t.hostname) + return netip.Addr{}, fmt.Errorf("%w for %s", errNoIPv4AddressFound, t.hostname) } func (t *TailscaleInContainer) MustIPv4() netip.Addr { @@ -1140,7 +1158,7 @@ func (t *TailscaleInContainer) FQDN() (string, error) { } return status.Self.DNSName, nil - }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(backoffMaxElapsedTime)) if err != nil { return "", fmt.Errorf("failed to get FQDN for %s after retries: %w", t.hostname, err) } @@ -1211,7 +1229,7 @@ func (t *TailscaleInContainer) waitForBackendState(state string, timeout time.Du for { select { case <-ctx.Done(): - return fmt.Errorf("timeout waiting for backend state %s on %s after %v", state, t.hostname, timeout) + return fmt.Errorf("%w %s on %s after %v", 
errBackendStateTimeout, state, t.hostname, timeout) case <-ticker.C: status, err := t.Status() if err != nil { @@ -1253,10 +1271,10 @@ func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval select { case <-ctx.Done(): if len(lastErrs) > 0 { - return fmt.Errorf("timeout waiting for %d peers on %s after %v, errors: %w", expected, t.hostname, timeout, multierr.New(lastErrs...)) + return fmt.Errorf("%w for %d peers on %s after %v, errors: %w", errPeerWaitTimeout, expected, t.hostname, timeout, multierr.New(lastErrs...)) } - return fmt.Errorf("timeout waiting for %d peers on %s after %v", expected, t.hostname, timeout) + return fmt.Errorf("%w for %d peers on %s after %v", errPeerWaitTimeout, expected, t.hostname, timeout) case <-ticker.C: status, err := t.Status() if err != nil { @@ -1284,15 +1302,15 @@ func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval peer := status.Peer[peerKey] if !peer.Online { - peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %w: %s", t.hostname, errPeerNotOnline, peer.HostName)) } if peer.HostName == "" { - peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName)) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %w: %s", t.hostname, errPeerNoHostname, peer.HostName)) } if peer.Relay == "" { - peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName)) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %w: %s", t.hostname, errPeerNoDERP, peer.HostName)) } } @@ -1496,7 +1514,7 @@ func (t *TailscaleInContainer) CurlFailFast(url string) (string, error) { // Use aggressive timeouts for fast failure detection return t.Curl(url, WithCurlConnectionTimeout(1*time.Second), - WithCurlMaxTime(2*time.Second), + WithCurlMaxTime(curlFailFastMaxTime), WithCurlRetry(1)) } @@ -1578,7 +1596,7 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) { } if out.Len() == 0 { - return nil, errors.New("file is empty") + return nil, errFileEmpty } return out.Bytes(), nil @@ -1591,6 +1609,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { } store := &mem.Store{} + //nolint:noinlineerr if err = store.LoadFromJSON(state); err != nil { return nil, fmt.Errorf("failed to unmarshal state file: %w", err) } @@ -1606,6 +1625,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { } p := &ipn.Prefs{} + //nolint:noinlineerr if err = json.Unmarshal(currentProfile, &p); err != nil { return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err) } @@ -1617,7 +1637,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) { // This is useful for verifying that policy changes have propagated to the client. func (t *TailscaleInContainer) PacketFilter() ([]filter.Match, error) { if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { - return nil, fmt.Errorf("tsic.PacketFilter() requires Tailscale 1.56+, current version: %s", t.version) + return nil, fmt.Errorf("%w: PacketFilter() requires Tailscale 1.56+, current version: %s", errTailscaleVersionRequired, t.version) } nm, err := t.Netmap()
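// WaitForPeers above wraps both a sentinel (errPeerWaitTimeout) and the joined
// per-peer errors in a single fmt.Errorf. Since Go 1.20 a format string may
// contain more than one %w, and errors.Is matches any of the wrapped errors.
// A small stdlib-only sketch; errors.Join stands in for tailscale's multierr,
// and the peer messages are invented:
package main

import (
	"errors"
	"fmt"
	"time"
)

var errPeerWaitTimeout = errors.New("timeout waiting")

func main() {
	peerErrs := errors.Join(
		errors.New("peer a is not online"),
		errors.New("peer b does not have a DERP relay"),
	)

	err := fmt.Errorf("%w for %d peers on %s after %v, errors: %w",
		errPeerWaitTimeout, 2, "ts-user1-1", 30*time.Second, peerErrs)

	fmt.Println(errors.Is(err, errPeerWaitTimeout)) // true
	fmt.Println(err)
}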