From ad7669a2d40d6631311e6498f014546dd78d4d6f Mon Sep 17 00:00:00 2001
From: Kristoffer Dalby
Date: Tue, 20 Jan 2026 14:37:24 +0000
Subject: [PATCH] all: apply golangci-lint auto-fixes

Apply auto-fixes from golangci-lint for the following linters:

- wsl_v5: whitespace formatting and blank line adjustments
- godot: add periods to comment sentences
- nlreturn: add newlines before return statements
- perfsprint: replace fmt.Sprintf calls with more efficient alternatives

Also add missing imports (errors, encoding/hex) where the auto-fixes
introduced code that requires them.
---
 cmd/headscale/cli/mockoidc.go | 2 +
 cmd/headscale/cli/policy.go | 3 +
 cmd/headscale/cli/root.go | 2 +
 cmd/headscale/cli/users.go | 1 +
 cmd/hi/cleanup.go | 5 +
 cmd/hi/docker.go | 34 ++++-
 cmd/hi/doctor.go | 4 +
 cmd/hi/main.go | 2 +
 cmd/hi/run.go | 7 +-
 cmd/hi/stats.go | 21 +++-
 cmd/mapresponses/main.go | 4 +-
 hscontrol/app.go | 25 +++-
 hscontrol/auth.go | 7 +-
 hscontrol/auth_test.go | 90 +++++++++++--
 hscontrol/db/db.go | 5 +
 hscontrol/db/db_test.go | 3 +
 .../db/ephemeral_garbage_collector_test.go | 79 +++++++++---
 hscontrol/db/ip_test.go | 1 +
 hscontrol/db/node.go | 6 +
 hscontrol/db/node_test.go | 7 ++
 hscontrol/db/sqliteconfig/config.go | 5 +
 hscontrol/db/sqliteconfig/config_test.go | 2 +
 hscontrol/db/sqliteconfig/integration_test.go | 3 +
 hscontrol/db/text_serialiser.go | 3 +
 hscontrol/db/users.go | 2 +
 hscontrol/debug.go | 30 +++++
 hscontrol/derp/derp.go | 3 +
 hscontrol/derp/derp_test.go | 2 +
 hscontrol/derp/server/derp_server.go | 11 +-
 hscontrol/dns/extrarecords.go | 6 +
 hscontrol/handlers.go | 2 +
 hscontrol/mapper/batcher_lockfree.go | 26 +++-
 hscontrol/mapper/batcher_test.go | 24 ++++
 hscontrol/mapper/builder.go | 6 +-
 hscontrol/mapper/mapper.go | 6 +-
 hscontrol/mapper/mapper_test.go | 7 ++
 hscontrol/noise.go | 3 +
 hscontrol/oidc.go | 14 +++
 hscontrol/policy/matcher/matcher.go | 3 +
 hscontrol/policy/pm.go | 8 +-
 hscontrol/policy/policy.go | 1 +
 hscontrol/policy/policy_autoapprove_test.go | 7 +-
 hscontrol/policy/policy_test.go | 22 +++-
 hscontrol/policy/policyutil/reduce.go | 1 +
 hscontrol/policy/policyutil/reduce_test.go | 8 +-
 hscontrol/policy/route_approval_test.go | 2 +
 hscontrol/policy/v2/filter.go | 27 +++-
 hscontrol/policy/v2/filter_test.go | 33 +++--
 hscontrol/policy/v2/policy.go | 28 ++++-
 hscontrol/policy/v2/policy_test.go | 8 +-
 hscontrol/policy/v2/types.go | 97 +++++++++++---
 hscontrol/policy/v2/types_test.go | 13 +-
 hscontrol/policy/v2/utils.go | 3 +
 hscontrol/policy/v2/utils_test.go | 4 +
 hscontrol/poll.go | 4 +
 hscontrol/routes/primary.go | 9 ++
 hscontrol/routes/primary_test.go | 19 ++-
 hscontrol/state/debug.go | 10 ++
 hscontrol/state/ephemeral_test.go | 30 ++++-
 hscontrol/state/maprequest.go | 1 +
 hscontrol/state/maprequest_test.go | 2 +-
 hscontrol/state/node_store.go | 25 ++++
 hscontrol/state/node_store_test.go | 119 +++++++++++++++---
 hscontrol/tailsql.go | 1 +
 hscontrol/types/common.go | 1 +
 hscontrol/types/config.go | 3 +
 hscontrol/types/config_test.go | 4 +
 hscontrol/types/node.go | 5 +-
 hscontrol/types/preauth_key.go | 1 +
 hscontrol/types/users.go | 9 ++
 hscontrol/types/users_test.go | 4 +
 hscontrol/util/dns_test.go | 2 +
 hscontrol/util/prompt.go | 2 +
 hscontrol/util/prompt_test.go | 8 ++
 hscontrol/util/string.go | 2 +
 hscontrol/util/util.go | 22 +++-
 hscontrol/util/util_test.go | 15 +++
 integration/api_auth_test.go | 63 +++++++--
 integration/auth_key_test.go | 38 +++++-
 integration/auth_oidc_test.go | 39 ++++++
 integration/auth_web_flow_test.go | 14 +++
 integration/derp_verify_endpoint_test.go | 
3 + integration/dockertestutil/config.go | 1 + integration/dockertestutil/execute.go | 2 + integration/dockertestutil/logs.go | 1 + integration/dockertestutil/network.go | 3 + integration/dsic/dsic.go | 8 ++ integration/helpers.go | 63 ++++++++-- integration/hsic/hsic.go | 17 ++- integration/integrationutil/util.go | 3 + integration/route_test.go | 101 ++++++++++++++- integration/scenario.go | 33 ++++- integration/scenario_test.go | 2 + 93 files changed, 1262 insertions(+), 155 deletions(-) diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 9969f7c6..af28ce9f 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -73,6 +73,7 @@ func mockOIDC() error { } var users []mockoidc.MockUser + err := json.Unmarshal([]byte(userStr), &users) if err != nil { return fmt.Errorf("unmarshalling users: %w", err) @@ -137,6 +138,7 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { log.Info().Msgf("Request: %+v", r) h.ServeHTTP(w, r) + if r.Response != nil { log.Info().Msgf("Response: %+v", r.Response) } diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index 2aaebcfa..f3921a64 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -29,13 +29,16 @@ func init() { if err := setPolicy.MarkFlagRequired("file"); err != nil { log.Fatal().Err(err).Msg("") } + setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running") policyCmd.AddCommand(setPolicy) checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") + if err := checkPolicy.MarkFlagRequired("file"); err != nil { log.Fatal().Err(err).Msg("") } + policyCmd.AddCommand(checkPolicy) } diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index d7cdabb6..d67c2df8 100644 --- a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -80,6 +80,7 @@ func initConfig() { Repository: "headscale", TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }), } + res, err := latest.Check(githubTag, versionInfo.Version) if err == nil && res.Outdated { //nolint @@ -101,6 +102,7 @@ func isPreReleaseVersion(version string) bool { return true } } + return false } diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index 9a816c78..6e4bdd02 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -23,6 +23,7 @@ func usernameAndIDFlag(cmd *cobra.Command) { // If both are empty, it will exit the program with an error. 
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) { username, _ := cmd.Flags().GetString("name") + identifier, _ := cmd.Flags().GetInt64("identifier") if username == "" && identifier < 0 { err := errors.New("--name or --identifier flag is required") diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go index 7c5b5214..e0268fd8 100644 --- a/cmd/hi/cleanup.go +++ b/cmd/hi/cleanup.go @@ -69,8 +69,10 @@ func killTestContainers(ctx context.Context) error { } removed := 0 + for _, cont := range containers { shouldRemove := false + for _, name := range cont.Names { if strings.Contains(name, "headscale-test-suite") || strings.Contains(name, "hs-") || @@ -259,8 +261,10 @@ func cleanOldImages(ctx context.Context) error { } removed := 0 + for _, img := range images { shouldRemove := false + for _, tag := range img.RepoTags { if strings.Contains(tag, "hs-") || strings.Contains(tag, "headscale-integration") || @@ -302,6 +306,7 @@ func cleanCacheVolume(ctx context.Context) error { defer cli.Close() volumeName := "hs-integration-go-cache" + err = cli.VolumeRemove(ctx, volumeName, true) if err != nil { if errdefs.IsNotFound(err) { diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go index a6b94b25..3ad70173 100644 --- a/cmd/hi/docker.go +++ b/cmd/hi/docker.go @@ -60,6 +60,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if config.Verbose { log.Printf("Running pre-test cleanup...") } + if err := cleanupBeforeTest(ctx); err != nil && config.Verbose { log.Printf("Warning: pre-test cleanup failed: %v", err) } @@ -95,13 +96,16 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { // Start stats collection for container resource monitoring (if enabled) var statsCollector *StatsCollector + if config.Stats { var err error + statsCollector, err = NewStatsCollector() if err != nil { if config.Verbose { log.Printf("Warning: failed to create stats collector: %v", err) } + statsCollector = nil } @@ -140,6 +144,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error { if len(violations) > 0 { log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:") log.Printf("=================================") + for _, violation := range violations { log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB", violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB) @@ -347,6 +352,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC maxWaitTime := 10 * time.Second checkInterval := 500 * time.Millisecond timeout := time.After(maxWaitTime) + ticker := time.NewTicker(checkInterval) defer ticker.Stop() @@ -356,6 +362,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction") } + return nil case <-ticker.C: allFinalized := true @@ -366,12 +373,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err) } + continue } // Check if container is in a final state if !isContainerFinalized(inspect.State) { allFinalized = false + if verbose { log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status) } @@ -384,6 +393,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC if verbose { log.Printf("All test containers finalized, ready for artifact extraction") } + return nil } } @@ -403,10 +413,12 @@ func findProjectRoot(startPath 
string) string { if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { return current } + parent := filepath.Dir(current) if parent == current { return startPath } + current = parent } } @@ -416,6 +428,7 @@ func boolToInt(b bool) int { if b { return 1 } + return 0 } @@ -435,6 +448,7 @@ func createDockerClient() (*client.Client, error) { } var clientOpts []client.Opt + clientOpts = append(clientOpts, client.WithAPIVersionNegotiation()) if contextInfo != nil { @@ -444,6 +458,7 @@ func createDockerClient() (*client.Client, error) { if runConfig.Verbose { log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host) } + clientOpts = append(clientOpts, client.WithHost(host)) } } @@ -460,6 +475,7 @@ func createDockerClient() (*client.Client, error) { // getCurrentDockerContext retrieves the current Docker context information. func getCurrentDockerContext() (*DockerContext, error) { cmd := exec.Command("docker", "context", "inspect") + output, err := cmd.Output() if err != nil { return nil, fmt.Errorf("failed to get docker context: %w", err) @@ -491,6 +507,7 @@ func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageNa if client.IsErrNotFound(err) { return false, nil } + return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err) } @@ -509,6 +526,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str if verbose { log.Printf("Image %s is available locally", imageName) } + return nil } @@ -533,6 +551,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str if err != nil { return fmt.Errorf("failed to read pull output: %w", err) } + log.Printf("Image %s pulled successfully", imageName) } @@ -547,9 +566,11 @@ func listControlFiles(logsDir string) { return } - var logFiles []string - var dataFiles []string - var dataDirs []string + var ( + logFiles []string + dataFiles []string + dataDirs []string + ) for _, entry := range entries { name := entry.Name() @@ -578,6 +599,7 @@ func listControlFiles(logsDir string) { if len(logFiles) > 0 { log.Printf("Headscale logs:") + for _, file := range logFiles { log.Printf(" %s", file) } @@ -585,9 +607,11 @@ func listControlFiles(logsDir string) { if len(dataFiles) > 0 || len(dataDirs) > 0 { log.Printf("Headscale data:") + for _, file := range dataFiles { log.Printf(" %s", file) } + for _, dir := range dataDirs { log.Printf(" %s/", dir) } @@ -612,6 +636,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose) extractedCount := 0 + for _, cont := range currentTestContainers { // Extract container logs and tar files if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil { @@ -622,6 +647,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi if verbose { log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12]) } + extractedCount++ } } @@ -645,11 +671,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st // Find the test container to get its run ID label var runID string + for _, cont := range containers { if cont.ID == testContainerID { if cont.Labels != nil { runID = cont.Labels["hi.run-id"] } + break } } diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go index 8af6051f..8ebda159 100644 --- a/cmd/hi/doctor.go +++ b/cmd/hi/doctor.go @@ -266,6 +266,7 @@ func checkGoInstallation() DoctorResult { } cmd := 
exec.Command("go", "version") + output, err := cmd.Output() if err != nil { return DoctorResult{ @@ -287,6 +288,7 @@ func checkGoInstallation() DoctorResult { // checkGitRepository verifies we're in a git repository. func checkGitRepository() DoctorResult { cmd := exec.Command("git", "rev-parse", "--git-dir") + err := cmd.Run() if err != nil { return DoctorResult{ @@ -316,6 +318,7 @@ func checkRequiredFiles() DoctorResult { } var missingFiles []string + for _, file := range requiredFiles { cmd := exec.Command("test", "-e", file) if err := cmd.Run(); err != nil { @@ -350,6 +353,7 @@ func displayDoctorResults(results []DoctorResult) { for _, result := range results { var icon string + switch result.Status { case "PASS": icon = "✅" diff --git a/cmd/hi/main.go b/cmd/hi/main.go index baecc6f3..0c9adc30 100644 --- a/cmd/hi/main.go +++ b/cmd/hi/main.go @@ -82,9 +82,11 @@ func cleanAll(ctx context.Context) error { if err := killTestContainers(ctx); err != nil { return err } + if err := pruneDockerNetworks(ctx); err != nil { return err } + if err := cleanOldImages(ctx); err != nil { return err } diff --git a/cmd/hi/run.go b/cmd/hi/run.go index 1694399d..e6c52634 100644 --- a/cmd/hi/run.go +++ b/cmd/hi/run.go @@ -48,6 +48,7 @@ func runIntegrationTest(env *command.Env) error { if runConfig.Verbose { log.Printf("Running pre-flight system checks...") } + if err := runDoctorCheck(env.Context()); err != nil { return fmt.Errorf("pre-flight checks failed: %w", err) } @@ -94,8 +95,10 @@ func detectGoVersion() string { // splitLines splits a string into lines without using strings.Split. func splitLines(s string) []string { - var lines []string - var current string + var ( + lines []string + current string + ) for _, char := range s { if char == '\n' { diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go index c1bb9cfe..1c17df84 100644 --- a/cmd/hi/stats.go +++ b/cmd/hi/stats.go @@ -71,10 +71,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver // Start monitoring existing containers sc.wg.Add(1) + go sc.monitorExistingContainers(ctx, runID, verbose) // Start Docker events monitoring for new containers sc.wg.Add(1) + go sc.monitorDockerEvents(ctx, runID, verbose) if verbose { @@ -88,10 +90,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver func (sc *StatsCollector) StopCollection() { // Check if already stopped without holding lock sc.mutex.RLock() + if !sc.collectionStarted { sc.mutex.RUnlock() return } + sc.mutex.RUnlock() // Signal stop to all goroutines @@ -115,6 +119,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s if verbose { log.Printf("Failed to list existing containers: %v", err) } + return } @@ -168,6 +173,7 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, if verbose { log.Printf("Error in Docker events stream: %v", err) } + return } } @@ -214,6 +220,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI } sc.wg.Add(1) + go sc.collectStatsForContainer(ctx, containerID, verbose) } @@ -227,11 +234,13 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe if verbose { log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err) } + return } defer statsResponse.Body.Close() decoder := json.NewDecoder(statsResponse.Body) + var prevStats *container.Stats for { @@ -247,6 +256,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe if err.Error() != "EOF" && 
verbose { log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err) } + return } @@ -262,8 +272,10 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe // Store the sample (skip first sample since CPU calculation needs previous stats) if prevStats != nil { // Get container stats reference without holding the main mutex - var containerStats *ContainerStats - var exists bool + var ( + containerStats *ContainerStats + exists bool + ) sc.mutex.RLock() containerStats, exists = sc.containers[containerID] @@ -332,10 +344,12 @@ type StatsSummary struct { func (sc *StatsCollector) GetSummary() []ContainerStatsSummary { // Take snapshot of container references without holding main lock long sc.mutex.RLock() + containerRefs := make([]*ContainerStats, 0, len(sc.containers)) for _, containerStats := range sc.containers { containerRefs = append(containerRefs, containerStats) } + sc.mutex.RUnlock() summaries := make([]ContainerStatsSummary, 0, len(containerRefs)) @@ -393,9 +407,11 @@ func calculateStatsSummary(values []float64) StatsSummary { if value < min { min = value } + if value > max { max = value } + sum += value } @@ -435,6 +451,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo } summaries := sc.GetSummary() + var violations []MemoryViolation for _, summary := range summaries { diff --git a/cmd/mapresponses/main.go b/cmd/mapresponses/main.go index 5d7ad07d..af35bc48 100644 --- a/cmd/mapresponses/main.go +++ b/cmd/mapresponses/main.go @@ -2,6 +2,7 @@ package main import ( "encoding/json" + "errors" "fmt" "os" @@ -40,7 +41,7 @@ func main() { // runIntegrationTest executes the integration test workflow. func runOnline(env *command.Env) error { if mapConfig.Directory == "" { - return fmt.Errorf("directory is required") + return errors.New("directory is required") } resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory) @@ -57,5 +58,6 @@ func runOnline(env *command.Env) error { os.Stderr.Write(out) os.Stderr.Write([]byte("\n")) + return nil } diff --git a/hscontrol/app.go b/hscontrol/app.go index aa011503..8ce1066f 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -142,6 +142,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { if !ok { log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed") log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore") + return } @@ -157,10 +158,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { app.ephemeralGC = ephemeralGC var authProvider AuthProvider + authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) defer cancel() + oidcProvider, err := NewAuthProviderOIDC( ctx, &app, @@ -177,6 +180,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { authProvider = oidcProvider } } + app.authProvider = authProvider if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS @@ -251,9 +255,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { lastExpiryCheck := time.Unix(0, 0) derpTickerChan := make(<-chan time.Time) + if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 { derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency) defer derpTicker.Stop() + derpTickerChan = derpTicker.C } @@ -271,8 +277,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { return case <-expireTicker.C: - var 
expiredNodeChanges []change.Change - var changed bool + var ( + expiredNodeChanges []change.Change + changed bool + ) lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck) @@ -287,11 +295,13 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { case <-derpTickerChan: log.Info().Msg("Fetching DERPMap updates") + derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { derpMap, err := derp.GetDERPMap(h.cfg.DERP) if err != nil { return nil, err } + if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { region, _ := h.DERPServer.GenerateRegion() derpMap.Regions[region.RegionID] = ®ion @@ -303,6 +313,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { log.Error().Err(err).Msg("failed to build new DERPMap, retrying later") continue } + h.state.SetDERPMap(derpMap) h.Change(change.DERPMap()) @@ -311,6 +322,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) { if !ok { continue } + h.cfg.TailcfgDNSConfig.ExtraRecords = records h.Change(change.ExtraRecords()) @@ -390,6 +402,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler writeUnauthorized := func(statusCode int) { writer.WriteHeader(statusCode) + if _, err := writer.Write([]byte("Unauthorized")); err != nil { log.Error().Err(err).Msg("writing HTTP response failed") } @@ -486,6 +499,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { // Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { var err error + capver.CanOldCodeBeCleanedUp() if profilingEnabled { @@ -512,6 +526,7 @@ func (h *Headscale) Serve() error { Msg("Clients with a lower minimum version will be rejected") h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state) + h.mapBatcher.Start() defer h.mapBatcher.Close() @@ -545,6 +560,7 @@ func (h *Headscale) Serve() error { // around between restarts, they will reconnect and the GC will // be cancelled. 
go h.ephemeralGC.Start() + ephmNodes := h.state.ListEphemeralNodes() for _, node := range ephmNodes.All() { h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout) @@ -555,7 +571,9 @@ func (h *Headscale) Serve() error { if err != nil { return fmt.Errorf("setting up extrarecord manager: %w", err) } + h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records() + go h.extraRecordMan.Run() defer h.extraRecordMan.Close() } @@ -564,6 +582,7 @@ func (h *Headscale) Serve() error { // records updates scheduleCtx, scheduleCancel := context.WithCancel(context.Background()) defer scheduleCancel() + go h.scheduledTasks(scheduleCtx) if zl.GlobalLevel() == zl.TraceLevel { @@ -751,7 +770,6 @@ func (h *Headscale) Serve() error { log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)") } - var tailsqlContext context.Context if tailsqlEnabled { if h.cfg.Database.Type != types.DatabaseSqlite { @@ -863,6 +881,7 @@ func (h *Headscale) Serve() error { // Close state connections info("closing state and database") + err = h.state.Close() if err != nil { log.Error().Err(err).Msg("failed to close state") diff --git a/hscontrol/auth.go b/hscontrol/auth.go index aa7088d7..c5fa91c2 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -51,6 +51,7 @@ func (h *Headscale) handleRegister( if err != nil { return nil, fmt.Errorf("handling logout: %w", err) } + if resp != nil { return resp, nil } @@ -131,7 +132,7 @@ func (h *Headscale) handleRegister( } // handleLogout checks if the [tailcfg.RegisterRequest] is a -// logout attempt from a node. If the node is not attempting to +// logout attempt from a node. If the node is not attempting to. func (h *Headscale) handleLogout( node types.NodeView, req tailcfg.RegisterRequest, @@ -158,6 +159,7 @@ func (h *Headscale) handleLogout( Interface("reg.req", req). Bool("unexpected", true). Msg("Node key expired, forcing re-authentication") + return &tailcfg.RegisterResponse{ NodeKeyExpired: true, MachineAuthorized: false, @@ -277,6 +279,7 @@ func (h *Headscale) waitForFollowup( // registration is expired in the cache, instruct the client to try a new registration return h.reqToNewRegisterResponse(req, machineKey) } + return nodeToRegisterResponse(node.View()), nil } } @@ -342,6 +345,7 @@ func (h *Headscale) handleRegisterWithAuthKey( if errors.Is(err, gorm.ErrRecordNotFound) { return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) } + if perr, ok := errors.AsType[types.PAKError](err); ok { return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil) } @@ -432,6 +436,7 @@ func (h *Headscale) handleRegisterInteractive( Str("generated.hostname", hostname). Msg("Received registration request with empty hostname, generated default") } + hostinfo.Hostname = hostname nodeToRegister := types.NewRegisterNode( diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go index 1677642f..8a012ff6 100644 --- a/hscontrol/auth_test.go +++ b/hscontrol/auth_test.go @@ -2,6 +2,7 @@ package hscontrol import ( "context" + "errors" "fmt" "net/url" "strings" @@ -16,14 +17,14 @@ import ( "tailscale.com/types/key" ) -// Interactive step type constants +// Interactive step type constants. const ( stepTypeInitialRequest = "initial_request" stepTypeAuthCompletion = "auth_completion" stepTypeFollowupRequest = "followup_request" ) -// interactiveStep defines a step in the interactive authentication workflow +// interactiveStep defines a step in the interactive authentication workflow. 
type interactiveStep struct { stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest expectAuthURL bool @@ -75,6 +76,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -129,6 +131,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) if err != nil { return "", err @@ -163,6 +166,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify both nodes exist node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) assert.True(t, found2) assert.Equal(t, "reusable-node-1", node1.Hostname()) @@ -196,6 +200,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) if err != nil { return "", err @@ -227,6 +232,7 @@ func TestAuthenticationFlows(t *testing.T) { // First node should exist, second should not _, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) _, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) assert.False(t, found2) }, @@ -272,6 +278,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -391,6 +398,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -400,8 +408,10 @@ func TestAuthenticationFlows(t *testing.T) { // Wait for node to be available in NodeStore with debug info var attemptCount int + require.EventuallyWithT(t, func(c *assert.CollectT) { attemptCount++ + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) if assert.True(c, found, "node should be available in NodeStore") { t.Logf("Node found in NodeStore after %d attempts", attemptCount) @@ -451,6 +461,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -500,6 +511,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -549,25 +561,31 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err } // Wait for node to be available in NodeStore - var node types.NodeView - var found bool + var ( + node types.NodeView + found bool + ) + require.EventuallyWithT(t, func(c *assert.CollectT) { node, found = app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(c, found, "node should be available in NodeStore") }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + if !found { - return "", fmt.Errorf("node not found after setup") + return "", errors.New("node not found after setup") } // Expire the node expiredTime := time.Now().Add(-1 * time.Hour) _, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime) + return "", err }, request: func(_ string) tailcfg.RegisterRequest { @@ -610,6 +628,7 @@ func 
TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -673,6 +692,7 @@ func TestAuthenticationFlows(t *testing.T) { // and handleRegister will receive the value when it starts waiting go func() { user := app.state.CreateUserForTest("followup-user") + node := app.state.CreateNodeForTest(user, "followup-success-node") registered <- node }() @@ -782,6 +802,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -821,6 +842,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -865,6 +887,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -898,6 +921,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -922,6 +946,7 @@ func TestAuthenticationFlows(t *testing.T) { node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found) assert.Equal(t, "tagged-pak-node", node.Hostname()) + if node.AuthKey().Valid() { assert.NotEmpty(t, node.AuthKey().Tags()) } @@ -1031,6 +1056,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1047,6 +1073,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak2.Key, nil }, request: func(newAuthKey string) tailcfg.RegisterRequest { @@ -1099,6 +1126,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1161,6 +1189,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1177,6 +1206,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pakRotation.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1226,6 +1256,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1265,6 +1296,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1429,6 +1461,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify custom hostinfo was preserved through interactive workflow node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be found after interactive registration") + if found { assert.Equal(t, "custom-interactive-node", node.Hostname()) assert.Equal(t, "linux", node.Hostinfo().OS()) @@ -1455,6 +1488,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1520,6 +1554,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify registration ID was properly generated and used node, 
found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered after interactive workflow") + if found { assert.Equal(t, "registration-id-test-node", node.Hostname()) assert.Equal(t, "test-os", node.Hostinfo().OS()) @@ -1535,6 +1570,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1577,6 +1613,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak.Key, nil }, request: func(authKey string) tailcfg.RegisterRequest { @@ -1632,6 +1669,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1648,6 +1686,7 @@ func TestAuthenticationFlows(t *testing.T) { if err != nil { return "", err } + return pak2.Key, nil }, request: func(user2AuthKey string) tailcfg.RegisterRequest { @@ -1712,6 +1751,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) if err != nil { return "", err @@ -1838,6 +1878,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) if err != nil { return "", err @@ -1932,6 +1973,7 @@ func TestAuthenticationFlows(t *testing.T) { }, Expiry: time.Now().Add(24 * time.Hour), } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) if err != nil { return "", err @@ -2097,6 +2139,7 @@ func TestAuthenticationFlows(t *testing.T) { // Collect results - at least one should succeed successCount := 0 + for range numConcurrent { select { case err := <-results: @@ -2217,6 +2260,7 @@ func TestAuthenticationFlows(t *testing.T) { // Should handle nil hostinfo gracefully node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered despite nil hostinfo") + if found { // Should have some default hostname or handle nil gracefully hostname := node.Hostname() @@ -2315,12 +2359,14 @@ func TestAuthenticationFlows(t *testing.T) { resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) require.NoError(t, err) + authURL2 := resp2.AuthURL assert.Contains(t, authURL2, "/register/") // Both should have different registration IDs regID1, err1 := extractRegistrationIDFromAuthURL(authURL1) regID2, err2 := extractRegistrationIDFromAuthURL(authURL2) + require.NoError(t, err1) require.NoError(t, err2) assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") @@ -2328,6 +2374,7 @@ func TestAuthenticationFlows(t *testing.T) { // Both cache entries should exist simultaneously _, found1 := app.state.GetRegistrationCacheEntry(regID1) _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first registration cache entry should exist") assert.True(t, found2, "second registration cache entry should exist") @@ -2371,6 +2418,7 @@ func TestAuthenticationFlows(t *testing.T) { resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) require.NoError(t, err) + authURL2 := resp2.AuthURL regID2, err := extractRegistrationIDFromAuthURL(authURL2) require.NoError(t, err) @@ -2378,6 +2426,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify both exist _, 
found1 := app.state.GetRegistrationCacheEntry(regID1) _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first cache entry should exist") assert.True(t, found2, "second cache entry should exist") @@ -2403,6 +2452,7 @@ func TestAuthenticationFlows(t *testing.T) { errorChan <- err return } + responseChan <- resp }() @@ -2430,6 +2480,7 @@ func TestAuthenticationFlows(t *testing.T) { // Verify the node was created with the second registration's data node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) assert.True(t, found, "node should be registered") + if found { assert.Equal(t, "pending-node-2", node.Hostname()) assert.Equal(t, "second-registration-user", node.User().Name()) @@ -2463,8 +2514,10 @@ func TestAuthenticationFlows(t *testing.T) { // Set up context with timeout for followup tests ctx := context.Background() + if req.Followup != "" { var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) defer cancel() } @@ -2516,7 +2569,7 @@ func TestAuthenticationFlows(t *testing.T) { } } -// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow +// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow. func runInteractiveWorkflowTest(t *testing.T, tt struct { name string setupFunc func(*testing.T, *Headscale) (string, error) @@ -2597,6 +2650,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { errorChan <- err return } + responseChan <- resp }() @@ -2650,24 +2704,27 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct { if responseToValidate == nil { responseToValidate = initialResp } + tt.validate(t, responseToValidate, app) } } -// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL +// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL. func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { // AuthURL format: "http://localhost/register/abc123" const registerPrefix = "/register/" + idx := strings.LastIndex(authURL, registerPrefix) if idx == -1 { return "", fmt.Errorf("invalid AuthURL format: %s", authURL) } idStr := authURL[idx+len(registerPrefix):] + return types.RegistrationIDFromString(idStr) } -// validateCompleteRegistrationResponse performs comprehensive validation of a registration response +// validateCompleteRegistrationResponse performs comprehensive validation of a registration response. func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) { // Basic response validation require.NotNil(t, resp, "response should not be nil") @@ -2681,7 +2738,7 @@ func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterRe // Additional validation can be added here as needed } -// Simple test to validate basic node creation and lookup +// Simple test to validate basic node creation and lookup. 
func TestNodeStoreLookup(t *testing.T) { app := createTestApp(t) @@ -2713,8 +2770,10 @@ func TestNodeStoreLookup(t *testing.T) { // Wait for node to be available in NodeStore var node types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { var found bool + node, found = app.state.GetNodeByNodeKey(nodeKey.Public()) assert.True(c, found, "Node should be found in NodeStore") }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore") @@ -2783,8 +2842,10 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Get the node ID var registeredNode types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { var found bool + registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public()) assert.True(c, found, "Node should be found in NodeStore") }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available") @@ -2796,6 +2857,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Verify initial state: user1 has 2 nodes, user2 has 2 nodes user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID)) user2Nodes := app.state.ListNodesByUser(types.UserID(user2.ID)) + require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially") require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially") @@ -2876,6 +2938,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Verify new nodes were created for user1 with the same machine keys t.Logf("Verifying new nodes created for user1 from user2's machine keys...") + for i := 2; i < 4; i++ { node := nodes[i] // Should be able to find a node with user1 and this machine key (the new one) @@ -2899,7 +2962,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { // Expected behavior: // - User1's original node should STILL EXIST (expired) // - User2 should get a NEW node created (NOT transfer) -// - Both nodes share the same machine key (same physical device) +// - Both nodes share the same machine key (same physical device). func TestWebFlowReauthDifferentUser(t *testing.T) { machineKey := key.NewMachine() nodeKey1 := key.NewNode() @@ -3043,6 +3106,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { // Count nodes per user user1Nodes := 0 user2Nodes := 0 + for i := 0; i < allNodesSlice.Len(); i++ { n := allNodesSlice.At(i) if n.UserID().Get() == user1.ID { @@ -3060,7 +3124,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) { }) } -// Helper function to create test app +// Helper function to create test app. 
func createTestApp(t *testing.T) *Headscale { t.Helper() @@ -3147,6 +3211,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { } t.Log("Step 1: Initial registration with pre-auth key") + initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public()) require.NoError(t, err, "initial registration should succeed") require.NotNil(t, initialResp) @@ -3172,6 +3237,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // - System reboots // The Tailscale client persists the pre-auth key in its state and sends it on every registration t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key") + restartReq := tailcfg.RegisterRequest{ Auth: &tailcfg.RegisterResponseAuth{ AuthKey: pakNew.Key, // Same key, now marked as Used=true @@ -3189,9 +3255,11 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { // This is the assertion that currently FAILS in v0.27.0 assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed") + if err != nil { t.Logf("Error received (this is the bug): %v", err) t.Logf("Expected behavior: Node should be able to re-register with the same pre-auth key it used initially") + return // Stop here to show the bug clearly } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 05a4c7c8..988675b9 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -155,6 +155,7 @@ AND auth_key_id NOT IN ( nodeRoutes := map[uint64][]netip.Prefix{} var routes []types.Route + err = tx.Find(&routes).Error if err != nil { return fmt.Errorf("fetching routes: %w", err) @@ -255,9 +256,11 @@ AND auth_key_id NOT IN ( // Check if routes table exists and drop it (should have been migrated already) var routesExists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists) if err == nil && routesExists { log.Info().Msg("Dropping leftover routes table") + if err := tx.Exec("DROP TABLE routes").Error; err != nil { return fmt.Errorf("dropping routes table: %w", err) } @@ -280,6 +283,7 @@ AND auth_key_id NOT IN ( for _, table := range tablesToRename { // Check if table exists before renaming var exists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists) if err != nil { return fmt.Errorf("checking if table %s exists: %w", table, err) @@ -761,6 +765,7 @@ AND auth_key_id NOT IN ( // or else it blocks... 
sqlConn.SetMaxIdleConns(maxIdleConns) + sqlConn.SetMaxOpenConns(maxOpenConns) defer sqlConn.SetMaxIdleConns(1) defer sqlConn.SetMaxOpenConns(1) diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 3cd0d14e..47a527b9 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -44,6 +44,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) { // Verify api_keys data preservation var apiKeyCount int + err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error require.NoError(t, err) assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema") @@ -186,6 +187,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { func requireConstraintFailed(t *testing.T, err error) { t.Helper() require.Error(t, err) + if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") { require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error()) } @@ -401,6 +403,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase { // skip already-applied migrations and only run new ones. func TestSQLiteAllTestdataMigrations(t *testing.T) { t.Parallel() + schemas, err := os.ReadDir("testdata/sqlite") require.NoError(t, err) diff --git a/hscontrol/db/ephemeral_garbage_collector_test.go b/hscontrol/db/ephemeral_garbage_collector_test.go index d118b7fd..2ad50885 100644 --- a/hscontrol/db/ephemeral_garbage_collector_test.go +++ b/hscontrol/db/ephemeral_garbage_collector_test.go @@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Basic deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex - var deletionWg sync.WaitGroup + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + deletionWg sync.WaitGroup + ) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionWg.Done() } @@ -43,10 +47,13 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { go gc.Start() // Schedule several nodes for deletion with short expiry - const expiry = fifty - const numNodes = 100 + const ( + expiry = fifty + numNodes = 100 + ) // Set up wait group for expected deletions + deletionWg.Add(numNodes) for i := 1; i <= numNodes; i++ { @@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { // and then reschedules it with a shorter expiry, and verifies that the node is deleted only once. func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionNotifier <- nodeID @@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // Start GC gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() defer gc.Close() - const shortExpiry = fifty - const longExpiry = 1 * time.Hour + const ( + shortExpiry = fifty + longExpiry = 1 * time.Hour + ) nodeID := types.NodeID(1) @@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) { // and verifies that the node is deleted only once. 
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) { // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) + deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + deletionNotifier <- nodeID } // Start the GC gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() defer gc.Close() nodeID := types.NodeID(1) + const expiry = fifty // Schedule node for deletion @@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) { // It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted. func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) { // Deletion tracking - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deletionNotifier := make(chan types.NodeID, 1) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() deletionNotifier <- nodeID @@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Deletion tracking - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) + nodeDeleted := make(chan struct{}) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() close(nodeDeleted) // Signal that deletion happened } @@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { // Use a WaitGroup to ensure the GC has started var startWg sync.WaitGroup startWg.Add(1) + go func() { startWg.Done() // Signal that the goroutine has started gc.Start() }() + startWg.Wait() // Wait for the GC to start // Close GC right away @@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { // Check no node was deleted deleteMutex.Lock() + nodesDeleted := len(deletedIDs) + deleteMutex.Unlock() assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close") @@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { t.Logf("Initial number of goroutines: %d", initialGoroutines) // Deletion tracking mechanism - var deletedIDs []types.NodeID - var deleteMutex sync.Mutex + var ( + deletedIDs []types.NodeID + deleteMutex sync.Mutex + ) deleteFunc := func(nodeID types.NodeID) { deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() } @@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { go gc.Start() // Number of concurrent scheduling goroutines - const numSchedulers = 10 - const nodesPerScheduler = 50 + const ( + numSchedulers = 10 + nodesPerScheduler = 50 + ) const closeAfterNodes = 25 // Close GC after this many nodes per scheduler diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go index 73895876..7827e002 100644 --- a/hscontrol/db/ip_test.go +++ b/hscontrol/db/ip_test.go @@ -483,6 +483,7 @@ func TestBackfillIPAddresses(t *testing.T) { func TestIPAllocatorNextNoReservedIPs(t *testing.T) { db, err := newSQLiteTestDB() require.NoError(t, err) + defer db.Close() alloc, err 
:= NewIPAllocator( diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 3887350b..7c818a75 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -206,6 +206,7 @@ func SetTags( slices.Sort(tags) tags = slices.Compact(tags) + b, err := json.Marshal(tags) if err != nil { return err @@ -378,6 +379,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n if ipv4 == nil { ipv4 = oldNode.IPv4 } + if ipv6 == nil { ipv6 = oldNode.IPv6 } @@ -406,6 +408,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n node.IPv6 = ipv6 var err error + node.Hostname, err = util.NormaliseHostname(node.Hostname) if err != nil { newHostname := util.InvalidString() @@ -693,9 +696,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname . } var registeredNode *types.Node + err = hsdb.DB.Transaction(func(tx *gorm.DB) error { var err error + registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6) + return err }) if err != nil { diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index e82cdb62..3696aa2e 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -497,6 +497,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes1) == 0 { expectedRoutes1 = nil } + if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } @@ -508,6 +509,7 @@ func TestAutoApproveRoutes(t *testing.T) { if len(expectedRoutes2) == 0 { expectedRoutes2 = nil } + if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) } @@ -745,12 +747,15 @@ func TestNodeNaming(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) if err != nil { return err } + _, err = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil) _, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil) + return err }) require.NoError(t, err) @@ -999,6 +1004,7 @@ func TestListPeers(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err @@ -1084,6 +1090,7 @@ func TestListNodes(t *testing.T) { if err != nil { return err } + _, err = RegisterNodeForTest(tx, node2, nil, nil) return err diff --git a/hscontrol/db/sqliteconfig/config.go b/hscontrol/db/sqliteconfig/config.go index d27977a4..23cb4b50 100644 --- a/hscontrol/db/sqliteconfig/config.go +++ b/hscontrol/db/sqliteconfig/config.go @@ -372,18 +372,23 @@ func (c *Config) ToURL() (string, error) { if c.BusyTimeout > 0 { pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout)) } + if c.JournalMode != "" { pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode)) } + if c.AutoVacuum != "" { pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum)) } + if c.WALAutocheckpoint >= 0 { pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint)) } + if c.Synchronous != "" { pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous)) } + if c.ForeignKeys { pragmas = append(pragmas, "foreign_keys=ON") } diff --git a/hscontrol/db/sqliteconfig/config_test.go b/hscontrol/db/sqliteconfig/config_test.go index 66955bb9..7829d9e9 100644 --- a/hscontrol/db/sqliteconfig/config_test.go +++ b/hscontrol/db/sqliteconfig/config_test.go @@ -294,6 +294,7 @@ 
func TestConfigToURL(t *testing.T) { t.Errorf("Config.ToURL() error = %v", err) return } + if got != tt.want { t.Errorf("Config.ToURL() = %q, want %q", got, tt.want) } @@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) { Path: "", BusyTimeout: -1, } + _, err := config.ToURL() if err == nil { t.Error("Config.ToURL() with invalid config should return error") diff --git a/hscontrol/db/sqliteconfig/integration_test.go b/hscontrol/db/sqliteconfig/integration_test.go index bb54ea1e..b411daeb 100644 --- a/hscontrol/db/sqliteconfig/integration_test.go +++ b/hscontrol/db/sqliteconfig/integration_test.go @@ -109,7 +109,9 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) { for pragma, expectedValue := range tt.expected { t.Run("pragma_"+pragma, func(t *testing.T) { var actualValue any + query := "PRAGMA " + pragma + err := db.QueryRow(query).Scan(&actualValue) if err != nil { t.Fatalf("Failed to query %s: %v", query, err) @@ -249,6 +251,7 @@ func TestJournalModeValidation(t *testing.T) { defer db.Close() var actualMode string + err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode) if err != nil { t.Fatalf("Failed to query journal_mode: %v", err) diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go index 46bd154f..102c0e9c 100644 --- a/hscontrol/db/text_serialiser.go +++ b/hscontrol/db/text_serialiser.go @@ -42,6 +42,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect if dbValue != nil { var bytes []byte + switch v := dbValue.(type) { case []byte: bytes = v @@ -55,6 +56,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect maybeInstantiatePtr(fieldValue) f := fieldValue.MethodByName("UnmarshalText") args := []reflect.Value{reflect.ValueOf(bytes)} + ret := f.Call(args) if !ret[0].IsNil() { return decodingError(field.Name, ret[0].Interface().(error)) @@ -89,6 +91,7 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) { return nil, nil } + b, err := v.MarshalText() if err != nil { return nil, err diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 6aff9ed1..650dbd49 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -88,10 +88,12 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user") // not exist or if another User exists with the new name. 
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { var err error + oldUser, err := GetUserByID(tx, uid) if err != nil { return err } + if err = util.ValidateHostname(newName); err != nil { return err } diff --git a/hscontrol/debug.go b/hscontrol/debug.go index 629b7be1..4fdcac11 100644 --- a/hscontrol/debug.go +++ b/hscontrol/debug.go @@ -25,17 +25,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { overview := h.state.DebugOverviewJSON() + overviewJSON, err := json.MarshalIndent(overview, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(overviewJSON) } else { // Default to text/plain for backward compatibility overview := h.state.DebugOverview() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(overview)) @@ -45,11 +48,13 @@ func (h *Headscale) debugHTTPServer() *http.Server { // Configuration endpoint debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { config := h.state.DebugConfig() + configJSON, err := json.MarshalIndent(config, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(configJSON) @@ -70,6 +75,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { } else { w.Header().Set("Content-Type", "text/plain") } + w.WriteHeader(http.StatusOK) w.Write([]byte(policy)) })) @@ -81,11 +87,13 @@ func (h *Headscale) debugHTTPServer() *http.Server { httpError(w, err) return } + filterJSON, err := json.MarshalIndent(filter, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(filterJSON) @@ -94,11 +102,13 @@ func (h *Headscale) debugHTTPServer() *http.Server { // SSH policies endpoint debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { sshPolicies := h.state.DebugSSHPolicies() + sshJSON, err := json.MarshalIndent(sshPolicies, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(sshJSON) @@ -112,17 +122,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { derpInfo := h.state.DebugDERPJSON() + derpJSON, err := json.MarshalIndent(derpInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(derpJSON) } else { // Default to text/plain for backward compatibility derpInfo := h.state.DebugDERPMap() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(derpInfo)) @@ -137,17 +150,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { nodeStoreNodes := h.state.DebugNodeStoreJSON() + nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(nodeStoreJSON) } else { // Default to text/plain for backward compatibility nodeStoreInfo := h.state.DebugNodeStore() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(nodeStoreInfo)) @@ -157,11 +173,13 @@ func (h *Headscale) debugHTTPServer() *http.Server { // Registration cache endpoint debug.Handle("registration-cache", "Registration cache information", 
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cacheInfo := h.state.DebugRegistrationCache() + cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(cacheJSON) @@ -175,17 +193,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { routes := h.state.DebugRoutes() + routesJSON, err := json.MarshalIndent(routes, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(routesJSON) } else { // Default to text/plain for backward compatibility routes := h.state.DebugRoutesString() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(routes)) @@ -200,17 +221,20 @@ func (h *Headscale) debugHTTPServer() *http.Server { if wantsJSON { policyManagerInfo := h.state.DebugPolicyManagerJSON() + policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ") if err != nil { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(policyManagerJSON) } else { // Default to text/plain for backward compatibility policyManagerInfo := h.state.DebugPolicyManager() + w.Header().Set("Content-Type", "text/plain") w.WriteHeader(http.StatusOK) w.Write([]byte(policyManagerInfo)) @@ -227,6 +251,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { if res == nil { w.WriteHeader(http.StatusOK) w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) + return } @@ -235,6 +260,7 @@ func (h *Headscale) debugHTTPServer() *http.Server { httpError(w, err) return } + w.Header().Set("Content-Type", "application/json") w.WriteHeader(http.StatusOK) w.Write(resJSON) @@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string { activeConnections: info.ActiveConnections, }) totalNodes++ + if info.Connected { connectedCount++ } @@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string { activeConnections: 0, }) totalNodes++ + if connected { connectedCount++ } + return true }) } @@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo { ActiveConnections: 0, } info.TotalNodes++ + return true }) } diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index 42d74abe..f3807e21 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -134,6 +134,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) { for id := range dm.Regions { ids = append(ids, id) } + slices.Sort(ids) for _, id := range ids { @@ -164,12 +165,14 @@ func derpRandom() *rand.Rand { rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) derpRandomInst = rnd }) + return derpRandomInst } func resetDerpRandomForTesting() { derpRandomMu.Lock() defer derpRandomMu.Unlock() + derpRandomOnce = sync.Once{} derpRandomInst = nil } diff --git a/hscontrol/derp/derp_test.go b/hscontrol/derp/derp_test.go index 91d605a6..445c1044 100644 --- a/hscontrol/derp/derp_test.go +++ b/hscontrol/derp/derp_test.go @@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { viper.Set("dns.base_domain", tt.baseDomain) + defer viper.Reset() + resetDerpRandomForTesting() testMap := tt.derpMap.View().AsStruct() diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index c736da28..bf292d03 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -74,9 
+74,11 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { if err != nil { return tailcfg.DERPRegion{}, err } - var host string - var port int - var portStr string + var ( + host string + port int + portStr string + ) // Extract hostname and port from URL host, portStr, err = net.SplitHostPort(serverURL.Host) @@ -205,6 +207,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques return } defer websocketConn.Close(websocket.StatusInternalError, "closing") + if websocketConn.Subprotocol() != "derp" { websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol") @@ -309,6 +312,7 @@ func DERPBootstrapDNSHandler( resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute) defer cancel() var resolver net.Resolver + for _, region := range derpMap.Regions().All() { for _, node := range region.Nodes().All() { // we don't care if we override some nodes addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName()) @@ -320,6 +324,7 @@ func DERPBootstrapDNSHandler( continue } + dnsEntries[node.HostName()] = addrs } } diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go index 82b3078b..5d16c675 100644 --- a/hscontrol/dns/extrarecords.go +++ b/hscontrol/dns/extrarecords.go @@ -85,12 +85,15 @@ func (e *ExtraRecordsMan) Run() { log.Error().Caller().Msgf("file watcher event channel closing") return } + switch event.Op { case fsnotify.Create, fsnotify.Write, fsnotify.Chmod: log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event") + if event.Name != e.path { continue } + e.updateRecords() // If a file is removed or renamed, fsnotify will loose track of it @@ -123,6 +126,7 @@ func (e *ExtraRecordsMan) Run() { log.Error().Caller().Msgf("file watcher error channel closing") return } + log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err) } } @@ -165,6 +169,7 @@ func (e *ExtraRecordsMan) updateRecords() { e.hashes[e.path] = newHash log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len()) + e.updateCh <- e.records.Slice() } @@ -183,6 +188,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error } var records []tailcfg.DNSRecord + err = json.Unmarshal(b, &records) if err != nil { return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err) diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index 2aee3cb2..7ec26994 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -181,6 +181,7 @@ func (h *Headscale) HealthHandler( json.NewEncoder(writer).Encode(res) } + err := h.state.PingDB(req.Context()) if err != nil { respond(err) @@ -217,6 +218,7 @@ func (h *Headscale) VersionHandler( writer.WriteHeader(http.StatusOK) versionInfo := types.GetVersionInfo() + err := json.NewEncoder(writer).Encode(versionInfo) if err != nil { log.Error(). 
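The GenerateRegion hunk above replaces three consecutive var statements with one parenthesised block; the same grouping recurs in NewPolicyManager and the policy v2 resolvers later in this patch. A small, self-contained sketch of the pattern with illustrative names (not the patch's code):

package main

import (
	"fmt"
	"net"
	"strconv"
)

// hostAndPort shows the grouped declaration style the auto-fix prefers over
// a run of single-line var statements.
func hostAndPort(hostport string) (string, int, error) {
	var (
		host    string
		port    int
		portStr string
	)

	host, portStr, err := net.SplitHostPort(hostport)
	if err != nil {
		return "", 0, err
	}

	port, err = strconv.Atoi(portStr)
	if err != nil {
		return "", 0, err
	}

	return host, port, nil
}

func main() {
	host, port, err := hostAndPort("derp.example.com:443")
	if err != nil {
		fmt.Println("error:", err)
		return
	}

	fmt.Println(host, port)
}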
diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go index 1d9c2c32..918b7049 100644 --- a/hscontrol/mapper/batcher_lockfree.go +++ b/hscontrol/mapper/batcher_lockfree.go @@ -2,6 +2,7 @@ package mapper import ( "crypto/rand" + "encoding/hex" "errors" "fmt" "sync" @@ -77,6 +78,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse if err != nil { log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed") nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) } @@ -86,10 +88,11 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse case c <- initialMap: // Success case <-time.After(5 * time.Second): - log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout") + log.Error().Uint64("node.id", id.Uint64()).Err(errors.New("timeout")).Msg("Initial map send timeout") log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second). Msg("Initial map send timed out because channel was blocked or receiver not ready") nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to send initial map to node %d: timeout", id) } @@ -129,6 +132,7 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo log.Debug().Caller().Uint64("node.id", id.Uint64()). Int("active.connections", nodeConn.getActiveConnectionCount()). Msg("Node connection removed but keeping online because other connections remain") + return true // Node still has active connections } @@ -211,10 +215,12 @@ func (b *LockFreeBatcher) worker(workerID int) { // This is used for synchronous map generation. if w.resultCh != nil { var result workResult + if nc, exists := b.nodes.Load(w.nodeID); exists { var err error result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) + result.err = err if result.err != nil { b.workErrors.Add(1) @@ -397,6 +403,7 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() { } } } + return true }) @@ -449,6 +456,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { if nodeConn.hasActiveConnections() { ret.Store(id, true) } + return true }) @@ -464,6 +472,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { ret.Store(id, false) } } + return true }) @@ -518,7 +527,8 @@ type multiChannelNodeConn struct { func generateConnectionID() string { bytes := make([]byte, 8) rand.Read(bytes) - return fmt.Sprintf("%x", bytes) + + return hex.EncodeToString(bytes) } // newMultiChannelNodeConn creates a new multi-channel node connection. @@ -545,11 +555,14 @@ func (mc *multiChannelNodeConn) close() { // addConnection adds a new connection. func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { mutexWaitStart := time.Now() + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id). Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT") mc.mutex.Lock() + mutexWaitDur := time.Since(mutexWaitStart) + defer mc.mutex.Unlock() mc.connections = append(mc.connections, entry) @@ -571,9 +584,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)). Int("remaining_connections", len(mc.connections)). 
Msg("Successfully removed connection") + return true } } + return false } @@ -607,6 +622,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { // This is not an error - the node will receive a full map when it reconnects log.Debug().Caller().Uint64("node.id", mc.id.Uint64()). Msg("send: skipping send to node with no active connections (likely rapid reconnection)") + return nil // Return success instead of error } @@ -615,7 +631,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Msg("send: broadcasting to all connections") var lastErr error + successCount := 0 + var failedConnections []int // Track failed connections for removal // Send to all connections @@ -626,6 +644,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { if err := conn.send(data); err != nil { lastErr = err + failedConnections = append(failedConnections, i) log.Warn().Err(err). Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). @@ -633,6 +652,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { Msg("send: connection send failed") } else { successCount++ + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). Str("conn.id", conn.id).Int("connection_index", i). Msg("send: successfully sent to connection") @@ -797,6 +817,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { Connected: connected, ActiveConnections: activeConnCount, } + return true }) @@ -811,6 +832,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { ActiveConnections: 0, } } + return true }) diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 00053892..595fb252 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -677,6 +677,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { connectedCount := 0 + for i := range allNodes { node := &allNodes[i] @@ -694,6 +695,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) { }, 5*time.Minute, 5*time.Second, "waiting for full connectivity") t.Logf("✅ All nodes achieved full connectivity!") + totalTime := time.Since(startTime) // Disconnect all nodes @@ -1309,6 +1311,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) { for range i % 3 { runtime.Gosched() // Introduce timing variability } + batcher.RemoveNode(testNode.n.ID, ch) // Yield to allow workers to process and close channels @@ -1392,6 +1395,7 @@ func TestBatcherConcurrentClients(t *testing.T) { // Channel was closed, exit gracefully return } + if valid, reason := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1449,7 +1453,9 @@ func TestBatcherConcurrentClients(t *testing.T) { ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) churningChannelsMutex.Lock() + churningChannels[nodeID] = ch + churningChannelsMutex.Unlock() batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) @@ -1463,6 +1469,7 @@ func TestBatcherConcurrentClients(t *testing.T) { // Channel was closed, exit gracefully return } + if valid, _ := validateUpdateContent(data); valid { tracker.recordUpdate( nodeID, @@ -1495,6 +1502,7 @@ func TestBatcherConcurrentClients(t *testing.T) { for range i % 5 { runtime.Gosched() // Introduce timing variability } + churningChannelsMutex.Lock() ch, exists := churningChannels[nodeID] @@ -1879,6 +1887,7 @@ func XTestBatcherScalability(t *testing.T) { channel, tailcfg.CapabilityVersion(100), ) + 
connectedNodesMutex.Lock() connectedNodes[nodeID] = true @@ -2287,6 +2296,7 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 1: Connect all nodes initially t.Logf("Phase 1: Connecting all nodes...") + for i, node := range allNodes { err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) if err != nil { @@ -2303,6 +2313,7 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 2: Rapid disconnect ALL nodes (simulating nodes going down) t.Logf("Phase 2: Rapid disconnect all nodes...") + for i, node := range allNodes { removed := batcher.RemoveNode(node.n.ID, node.ch) t.Logf("Node %d RemoveNode result: %t", i, removed) @@ -2310,9 +2321,11 @@ func TestBatcherRapidReconnection(t *testing.T) { // Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up) t.Logf("Phase 3: Rapid reconnect with new channels...") + newChannels := make([]chan *tailcfg.MapResponse, len(allNodes)) for i, node := range allNodes { newChannels[i] = make(chan *tailcfg.MapResponse, 10) + err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100)) if err != nil { t.Errorf("Failed to reconnect node %d: %v", i, err) @@ -2343,11 +2356,13 @@ func TestBatcherRapidReconnection(t *testing.T) { if infoMap, ok := info.(map[string]any); ok { if connected, ok := infoMap["connected"].(bool); ok && !connected { disconnectedCount++ + t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i) } } } else { disconnectedCount++ + t.Logf("Node %d missing from debug info entirely", i) } @@ -2382,6 +2397,7 @@ func TestBatcherRapidReconnection(t *testing.T) { case update := <-newChannels[i]: if update != nil { receivedCount++ + t.Logf("Node %d received update successfully", i) } case <-timeout: @@ -2414,6 +2430,7 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 1: Connect first node with initial connection t.Logf("Phase 1: Connecting node 1 with first connection...") + err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add node1: %v", err) @@ -2433,7 +2450,9 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 2: Add second connection for node1 (multi-connection scenario) t.Logf("Phase 2: Adding second connection for node 1...") + secondChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add second connection for node1: %v", err) @@ -2444,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 3: Add third connection for node1 t.Logf("Phase 3: Adding third connection for node 1...") + thirdChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100)) if err != nil { t.Fatalf("Failed to add third connection for node1: %v", err) @@ -2455,6 +2476,7 @@ func TestBatcherMultiConnection(t *testing.T) { // Phase 4: Verify debug status shows correct connection count t.Logf("Phase 4: Verifying debug status shows multiple connections...") + if debugBatcher, ok := batcher.(interface { Debug() map[types.NodeID]any }); ok { @@ -2462,6 +2484,7 @@ func TestBatcherMultiConnection(t *testing.T) { if info, exists := debugInfo[node1.n.ID]; exists { t.Logf("Node1 debug info: %+v", info) + if infoMap, ok := info.(map[string]any); ok { if activeConnections, ok := infoMap["active_connections"].(int); ok { if activeConnections != 3 { @@ -2470,6 +2493,7 @@ func TestBatcherMultiConnection(t 
*testing.T) { t.Logf("SUCCESS: Node1 correctly shows 3 active connections") } } + if connected, ok := infoMap["connected"].(bool); ok && !connected { t.Errorf("Node1 should show as connected with 3 active connections") } diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go index b6f0b534..df0693e3 100644 --- a/hscontrol/mapper/builder.go +++ b/hscontrol/mapper/builder.go @@ -37,6 +37,7 @@ const ( // NewMapResponseBuilder creates a new builder with basic fields set. func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { now := time.Now() + return &MapResponseBuilder{ resp: &tailcfg.MapResponse{ KeepAlive: false, @@ -124,6 +125,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { b.resp.Debug = &tailcfg.Debug{ DisableLogTail: !b.mapper.cfg.LogTail.Enabled, } + return b } @@ -281,16 +283,18 @@ func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapRe for _, id := range removedIDs { tailscaleIDs = append(tailscaleIDs, id.NodeID()) } + b.resp.PeersRemoved = tailscaleIDs return b } -// Build finalizes the response and returns marshaled bytes +// Build finalizes the response and returns marshaled bytes. func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) { if len(b.errs) > 0 { return nil, multierr.New(b.errs...) } + if debugDumpMapResponsePath != "" { writeDebugMapResponse(b.resp, b.debugType, b.nodeID) } diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 616d470f..843729c7 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -60,7 +60,6 @@ func newMapper( state *state.State, ) *mapper { // uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) - return &mapper{ state: state, cfg: cfg, @@ -80,6 +79,7 @@ func generateUserProfiles( userID := user.Model().ID userMap[userID] = &user ids = append(ids, userID) + for _, peer := range peers.All() { peerUser := peer.Owner() peerUserID := peerUser.Model().ID @@ -90,6 +90,7 @@ func generateUserProfiles( slices.Sort(ids) ids = slices.Compact(ids) var profiles []tailcfg.UserProfile + for _, id := range ids { if userMap[id] != nil { profiles = append(profiles, userMap[id].TailscaleUserProfile()) @@ -306,6 +307,7 @@ func writeDebugMapResponse( perms := fs.FileMode(debugMapResponsePerm) mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID)) + err = os.MkdirAll(mPath, perms) if err != nil { panic(err) @@ -319,6 +321,7 @@ func writeDebugMapResponse( ) log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) + err = os.WriteFile(mapResponsePath, body, perms) if err != nil { panic(err) @@ -375,6 +378,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe } var resp tailcfg.MapResponse + err = json.Unmarshal(body, &resp) if err != nil { log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name()) diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index a503c08c..4852ce04 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -98,6 +98,7 @@ func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) { if m.polMan == nil { return tailcfg.FilterAllowAll, nil } + return m.polMan.Filter() } @@ -105,6 +106,7 @@ func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { if m.polMan == nil { return nil, nil } + return m.polMan.SSHPolicy(node) } @@ -112,6 +114,7 @@ func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool { if m.polMan == nil { return false } + 
return m.polMan.NodeCanHaveTag(node, tag) } @@ -119,6 +122,7 @@ func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix { if m.primary == nil { return nil } + return m.primary.PrimaryRoutes(nodeID) } @@ -126,6 +130,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ if len(peerIDs) > 0 { // Filter peers by the provided IDs var filtered types.Nodes + for _, peer := range m.peers { if slices.Contains(peerIDs, peer.ID) { filtered = append(filtered, peer) @@ -136,6 +141,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ } // Return all peers except the node itself var filtered types.Nodes + for _, peer := range m.peers { if peer.ID != nodeID { filtered = append(filtered, peer) @@ -149,6 +155,7 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { if len(nodeIDs) > 0 { // Filter nodes by the provided IDs var filtered types.Nodes + for _, node := range m.nodes { if slices.Contains(nodeIDs, node.ID) { filtered = append(filtered, node) diff --git a/hscontrol/noise.go b/hscontrol/noise.go index f0e2fefa..869fe3f3 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -243,10 +243,12 @@ func (ns *noiseServer) NoiseRegistrationHandler( registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { var resp *tailcfg.RegisterResponse + body, err := io.ReadAll(req.Body) if err != nil { return &tailcfg.RegisterRequest{}, regErr(err) } + var regReq tailcfg.RegisterRequest if err := json.Unmarshal(body, ®Req); err != nil { return ®Req, regErr(err) @@ -260,6 +262,7 @@ func (ns *noiseServer) NoiseRegistrationHandler( resp = &tailcfg.RegisterResponse{ Error: httpErr.Msg, } + return ®Req, resp } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 7013b8ed..836e8763 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -163,6 +163,7 @@ func (a *AuthProviderOIDC) RegisterHandler( for k, v := range a.cfg.ExtraParams { extras = append(extras, oauth2.SetAuthURLParam(k, v)) } + extras = append(extras, oidc.Nonce(nonce)) // Cache the registration info @@ -190,6 +191,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( } stateCookieName := getCookieName("state", state) + cookieState, err := req.Cookie(stateCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) @@ -212,17 +214,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( httpError(writer, err) return } + if idToken.Nonce == "" { httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err)) return } nonceCookieName := getCookieName("nonce", idToken.Nonce) + nonce, err := req.Cookie(nonceCookieName) if err != nil { httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) return } + if idToken.Nonce != nonce.Value { httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil)) return @@ -239,6 +244,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Fetch user information (email, groups, name, etc) from the userinfo endpoint // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo var userinfo *oidc.UserInfo + userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token)) if err != nil { util.LogErr(err, "could not get userinfo; only using claims from id token") @@ -255,6 +261,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified) claims.Username = 
cmp.Or(userinfo2.PreferredUsername, claims.Username) claims.Name = cmp.Or(userinfo2.Name, claims.Name) + claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL) if userinfo2.Groups != nil { claims.Groups = userinfo2.Groups @@ -279,6 +286,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( Msgf("could not create or update user") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) + _, werr := writer.Write([]byte("Could not create or update user")) if werr != nil { log.Error(). @@ -299,6 +307,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( // Register the node if it does not exist. if registrationId != nil { verb := "Reauthenticated" + newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) if err != nil { if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { @@ -307,7 +316,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( return } + httpError(writer, err) + return } @@ -324,6 +335,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) + if _, err := writer.Write(content.Bytes()); err != nil { util.LogErr(err, "Failed to write HTTP response") } @@ -370,6 +382,7 @@ func (a *AuthProviderOIDC) getOauth2Token( if !ok { return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) } + if regInfo.Verifier != nil { exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} } @@ -516,6 +529,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( newUser bool c change.Change ) + user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) if err != nil && !errors.Is(err, db.ErrUserNotFound) { return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err) diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index afc3cf68..0c84bae0 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -21,10 +21,13 @@ func (m Match) DebugString() string { sb.WriteString("Match:\n") sb.WriteString(" Sources:\n") + for _, prefix := range m.srcs.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } + sb.WriteString(" Destinations:\n") + for _, prefix := range m.dests.Prefixes() { sb.WriteString(" " + prefix.String() + "\n") } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index f4db88a4..ee112609 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -38,8 +38,11 @@ type PolicyManager interface { // NewPolicyManager returns a new policy manager. 
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) { - var polMan PolicyManager - var err error + var ( + polMan PolicyManager + err error + ) + polMan, err = policyv2.NewPolicyManager(pol, users, nodes) if err != nil { return nil, err @@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ if err != nil { return nil, err } + polMans = append(polMans, pm) } diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go index 24d2865e..42942f61 100644 --- a/hscontrol/policy/policy.go +++ b/hscontrol/policy/policy.go @@ -125,6 +125,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove if !slices.Equal(sortedCurrent, newApproved) { // Log what changed var added, kept []netip.Prefix + for _, route := range newApproved { if !slices.Contains(sortedCurrent, route) { added = append(added, route) diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go index b7a758e6..21c2a66e 100644 --- a/hscontrol/policy/policy_autoapprove_test.go +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -312,8 +312,11 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { nodes := types.Nodes{&node} // Create policy manager or use nil if specified - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + if tt.name != "nil_policy_manager" { pm, err = pmf(users, nodes.ViewSlice()) assert.NoError(t, err) diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go index eb3d85b6..ee4818aa 100644 --- a/hscontrol/policy/policy_test.go +++ b/hscontrol/policy/policy_test.go @@ -32,6 +32,7 @@ func TestReduceNodes(t *testing.T) { rules []tailcfg.FilterRule node *types.Node } + tests := []struct { name string args args @@ -782,9 +783,11 @@ func TestReduceNodes(t *testing.T) { for _, v := range gotViews.All() { got = append(got, v.AsStruct()) } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff) t.Log("Matchers: ") + for _, m := range matchers { t.Log("\t+", m.DebugString()) } @@ -1031,8 +1034,11 @@ func TestReduceNodesFromPolicy(t *testing.T) { for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + pm, err = pmf(nil, tt.nodes.ViewSlice()) require.NoError(t, err) @@ -1050,9 +1056,11 @@ func TestReduceNodesFromPolicy(t *testing.T) { for _, v := range gotViews.All() { got = append(got, v.AsStruct()) } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff) t.Log("Matchers: ") + for _, m := range matchers { t.Log("\t+", m.DebugString()) } @@ -1405,13 +1413,17 @@ func TestSSHPolicyRules(t *testing.T) { for _, tt := range tests { for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm PolicyManager - var err error + var ( + pm PolicyManager + err error + ) + pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice()) if tt.expectErr { require.Error(t, err) require.Contains(t, err.Error(), tt.errorMessage) + return } @@ -1434,6 +1446,7 @@ func TestReduceRoutes(t *testing.T) { routes []netip.Prefix rules []tailcfg.FilterRule } + 
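Many of the surrounding test hunks only add a blank line before a return statement; that is the nlreturn fixer. A hypothetical sketch of the rule, not taken from the patch:

package main

import (
	"errors"
	"fmt"
)

// register is an illustrative function: nlreturn requires the blank line
// before the final return because other statements precede it in the block.
func register(name string) error {
	if name == "" {
		return errors.New("empty name")
	}

	fmt.Println("registering", name)

	return nil
}

func main() {
	if err := register("node1"); err != nil {
		fmt.Println("error:", err)
	}
}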
tests := []struct { name string args args @@ -2055,6 +2068,7 @@ func TestReduceRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { matchers := matcher.MatchesFromFilterRules(tt.args.rules) + got := ReduceRoutes( tt.args.node.View(), tt.args.routes, diff --git a/hscontrol/policy/policyutil/reduce.go b/hscontrol/policy/policyutil/reduce.go index e4549c10..6d95a297 100644 --- a/hscontrol/policy/policyutil/reduce.go +++ b/hscontrol/policy/policyutil/reduce.go @@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf for _, rule := range rules { // record if the rule is actually relevant for the given node. var dests []tailcfg.NetPortRange + DEST_LOOP: for _, dest := range rule.DstPorts { expanded, err := util.ParseIPSet(dest.IP, nil) diff --git a/hscontrol/policy/policyutil/reduce_test.go b/hscontrol/policy/policyutil/reduce_test.go index bd975d23..0b674981 100644 --- a/hscontrol/policy/policyutil/reduce_test.go +++ b/hscontrol/policy/policyutil/reduce_test.go @@ -823,10 +823,14 @@ func TestReduceFilterRules(t *testing.T) { for _, tt := range tests { for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) { t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { - var pm policy.PolicyManager - var err error + var ( + pm policy.PolicyManager + err error + ) + pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice()) require.NoError(t, err) + got, _ := pm.Filter() t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) got = policyutil.ReduceFilterRules(tt.node.View(), got) diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go index 5aa5e28c..3d070a25 100644 --- a/hscontrol/policy/route_approval_test.go +++ b/hscontrol/policy/route_approval_test.go @@ -829,6 +829,7 @@ func TestNodeCanApproveRoute(t *testing.T) { if tt.name == "empty policy" { // We expect this one to have a valid but empty policy require.NoError(t, err) + if err != nil { return } @@ -843,6 +844,7 @@ func TestNodeCanApproveRoute(t *testing.T) { if diff := cmp.Diff(tt.canApprove, result); diff != "" { t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff) } + assert.Equal(t, tt.canApprove, result, "Unexpected route approval result") }) } diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go index 78c6ebc5..3f72cdda 100644 --- a/hscontrol/policy/v2/filter.go +++ b/hscontrol/policy/v2/filter.go @@ -45,6 +45,7 @@ func (pol *Policy) compileFilterRules( protocols, _ := acl.Protocol.parseProtocol() var destPorts []tailcfg.NetPortRange + for _, dest := range acl.Destinations { ips, err := dest.Resolve(pol, users, nodes) if err != nil { @@ -127,8 +128,10 @@ func (pol *Policy) compileACLWithAutogroupSelf( node types.NodeView, nodes views.Slice[types.NodeView], ) ([]*tailcfg.FilterRule, error) { - var autogroupSelfDests []AliasWithPorts - var otherDests []AliasWithPorts + var ( + autogroupSelfDests []AliasWithPorts + otherDests []AliasWithPorts + ) for _, dest := range acl.Destinations { if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -139,13 +142,14 @@ func (pol *Policy) compileACLWithAutogroupSelf( } protocols, _ := acl.Protocol.parseProtocol() + var rules []*tailcfg.FilterRule var resolvedSrcIPs []*netipx.IPSet for _, src := range acl.Sources { if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { - return nil, fmt.Errorf("autogroup:self cannot be used in sources") + return nil, errors.New("autogroup:self cannot be used 
in sources") } ips, err := src.Resolve(pol, users, nodes) @@ -167,6 +171,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if len(autogroupSelfDests) > 0 { // Pre-filter to same-user untagged devices once - reuse for both sources and destinations sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { if n.User().ID() == node.User().ID() && !n.IsTagged() { sameUserNodes = append(sameUserNodes, n) @@ -176,6 +181,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if len(sameUserNodes) > 0 { // Filter sources to only same-user untagged devices var srcIPs netipx.IPSetBuilder + for _, ips := range resolvedSrcIPs { for _, n := range sameUserNodes { // Check if any of this node's IPs are in the source set @@ -192,6 +198,7 @@ func (pol *Policy) compileACLWithAutogroupSelf( if srcSet != nil && len(srcSet.Prefixes()) > 0 { var destPorts []tailcfg.NetPortRange + for _, dest := range autogroupSelfDests { for _, n := range sameUserNodes { for _, port := range dest.Ports { @@ -297,8 +304,10 @@ func (pol *Policy) compileSSHPolicy( // Separate destinations into autogroup:self and others // This is needed because autogroup:self requires filtering sources to same-user only, // while other destinations should use all resolved sources - var autogroupSelfDests []Alias - var otherDests []Alias + var ( + autogroupSelfDests []Alias + otherDests []Alias + ) for _, dst := range rule.Destinations { if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { @@ -321,6 +330,7 @@ func (pol *Policy) compileSSHPolicy( } var action tailcfg.SSHAction + switch rule.Action { case SSHActionAccept: action = sshAction(true, 0) @@ -336,9 +346,11 @@ func (pol *Policy) compileSSHPolicy( // by default, we do not allow root unless explicitly stated userMap["root"] = "" } + if rule.Users.ContainsRoot() { userMap["root"] = "root" } + for _, u := range rule.Users.NormalUsers() { userMap[u.String()] = u.String() } @@ -348,6 +360,7 @@ func (pol *Policy) compileSSHPolicy( if len(autogroupSelfDests) > 0 && !node.IsTagged() { // Build destination set for autogroup:self (same-user untagged devices only) var dest netipx.IPSetBuilder + for _, n := range nodes.All() { if n.User().ID() == node.User().ID() && !n.IsTagged() { n.AppendToIPSet(&dest) @@ -364,6 +377,7 @@ func (pol *Policy) compileSSHPolicy( // Filter sources to only same-user untagged devices // Pre-filter to same-user untagged devices for efficiency sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { if n.User().ID() == node.User().ID() && !n.IsTagged() { sameUserNodes = append(sameUserNodes, n) @@ -371,6 +385,7 @@ func (pol *Policy) compileSSHPolicy( } var filteredSrcIPs netipx.IPSetBuilder + for _, n := range sameUserNodes { // Check if any of this node's IPs are in the source set if slices.ContainsFunc(n.IPs(), srcIPs.Contains) { @@ -406,12 +421,14 @@ func (pol *Policy) compileSSHPolicy( if len(otherDests) > 0 { // Build destination set for other destinations var dest netipx.IPSetBuilder + for _, dst := range otherDests { ips, err := dst.Resolve(pol, users, nodes) if err != nil { log.Trace().Caller().Err(err).Msgf("resolving destination ips") continue } + if ips != nil { dest.AddSet(ips) } diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go index d798b5f7..663e3d6b 100644 --- a/hscontrol/policy/v2/filter_test.go +++ b/hscontrol/policy/v2/filter_test.go @@ -589,7 +589,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) { if sshPolicy == nil { return // Expected empty result } + assert.Empty(t, 
sshPolicy.Rules, "SSH policy should be empty when no rules match") + return } @@ -670,7 +672,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) { } // TestSSHIntegrationReproduction reproduces the exact scenario from the integration test -// TestSSHOneUserToAll that was failing with empty sshUsers +// TestSSHOneUserToAll that was failing with empty sshUsers. func TestSSHIntegrationReproduction(t *testing.T) { // Create users matching the integration test users := types.Users{ @@ -735,7 +737,7 @@ func TestSSHIntegrationReproduction(t *testing.T) { } // TestSSHJSONSerialization verifies that the SSH policy can be properly serialized -// to JSON and that the sshUsers field is not empty +// to JSON and that the sshUsers field is not empty. func TestSSHJSONSerialization(t *testing.T) { users := types.Users{ {Name: "user1", Model: gorm.Model{ID: 1}}, @@ -775,6 +777,7 @@ func TestSSHJSONSerialization(t *testing.T) { // Parse back to verify structure var parsed tailcfg.SSHPolicy + err = json.Unmarshal(jsonData, &parsed) require.NoError(t, err) @@ -859,6 +862,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { if err != nil { t.Fatalf("unexpected error: %v", err) } + if len(rules) != 1 { t.Fatalf("expected 1 rule, got %d", len(rules)) } @@ -875,6 +879,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { found := false addr := netip.MustParseAddr(expectedIP) + for _, prefix := range rule.SrcIPs { pref := netip.MustParsePrefix(prefix) if pref.Contains(addr) { @@ -892,6 +897,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"} for _, excludedIP := range excludedSourceIPs { addr := netip.MustParseAddr(excludedIP) + for _, prefix := range rule.SrcIPs { pref := netip.MustParsePrefix(prefix) if pref.Contains(addr) { @@ -1325,14 +1331,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) { assert.Empty(t, rules3, "user3 should have no rules") } -// Helper function to create IP addresses for testing +// Helper function to create IP addresses for testing. func createAddr(ip string) *netip.Addr { addr, _ := netip.ParseAddr(ip) return &addr } // TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly -// with autogroup:self in destinations +// with autogroup:self in destinations. 
func TestSSHWithAutogroupSelfInDestination(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1380,6 +1386,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // Test for user2's first node @@ -1398,12 +1405,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { for i, p := range rule2.Principals { principalIPs2[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2) // Test for tagged node (should have no SSH rules) node5 := nodes[4].View() sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy3 != nil { assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self") } @@ -1411,7 +1420,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) { // TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user // is in the source and autogroup:self in destination, only that user's devices -// can SSH (and only if they match the target user) +// can SSH (and only if they match the target user). func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1453,18 +1462,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // For user2's node: should have no rules (user1's devices can't match user2's self) node3 := nodes[2].View() sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1") } } -// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations +// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations. func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1511,19 +1522,21 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) // For user3's node: should have no rules (not in group:admins) node5 := nodes[4].View() sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)") } } // TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices -// are excluded from both sources and destinations when autogroup:self is used +// are excluded from both sources and destinations when autogroup:self is used. 
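The comment edits in these test files, such as the trailing periods added to the doc comments above, come from godot, which checks that declaration comments end as full sentences. A tiny illustrative example, not code from the patch:

package main

import "fmt"

// Greet prints a greeting for the given name.
// Before the godot auto-fix this comment ended without the period; the fixer
// appends it so the comment reads as a complete sentence.
func Greet(name string) {
	fmt.Println("Hello,", name)
}

func main() {
	Greet("headscale")
}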
func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { users := types.Users{ {Model: gorm.Model{ID: 1}, Name: "user1"}, @@ -1568,6 +1581,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { for i, p := range rule.Principals { principalIPs[i] = p.NodeIP } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs, "should only include untagged devices") @@ -1575,6 +1589,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { node3 := nodes[2].View() sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) require.NoError(t, err) + if sshPolicy2 != nil { assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self") } @@ -1623,10 +1638,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { // Verify autogroup:self rule has filtered sources (only same-user devices) selfRule := sshPolicy1.Rules[0] require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices") + selfPrincipals := make([]string, len(selfRule.Principals)) for i, p := range selfRule.Principals { selfPrincipals[i] = p.NodeIP } + require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals, "autogroup:self rule should only include same-user untagged devices") @@ -1638,10 +1655,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") routerRule := sshPolicyRouter.Rules[0] + routerPrincipals := make([]string, len(routerRule.Principals)) for i, p := range routerRule.Principals { routerPrincipals[i] = p.NodeIP } + require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)") require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)") require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)") diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index 042c2723..8c07e6cc 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -111,6 +111,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Filter: filter, Policy: pm.pol, }) + filterChanged := filterHash != pm.filterHash if filterChanged { log.Debug(). @@ -120,7 +121,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("filter.rules.new", len(filter)). Msg("Policy filter hash changed") } + pm.filter = filter + pm.filterHash = filterHash if filterChanged { pm.matchers = matcher.MatchesFromFilterRules(pm.filter) @@ -135,6 +138,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { } tagOwnerMapHash := deephash.Hash(&tagMap) + tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash if tagOwnerChanged { log.Debug(). @@ -144,6 +148,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("tagOwners.new", len(tagMap)). Msg("Tag owner hash changed") } + pm.tagOwnerMap = tagMap pm.tagOwnerMapHash = tagOwnerMapHash @@ -153,6 +158,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { } autoApproveMapHash := deephash.Hash(&autoMap) + autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash if autoApproveChanged { log.Debug(). @@ -162,10 +168,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Int("autoApprovers.new", len(autoMap)). 
Msg("Auto-approvers hash changed") } + pm.autoApproveMap = autoMap pm.autoApproveMapHash = autoApproveMapHash exitSetHash := deephash.Hash(&exitSet) + exitSetChanged := exitSetHash != pm.exitSetHash if exitSetChanged { log.Debug(). @@ -173,6 +181,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { Str("exitSet.hash.new", exitSetHash.String()[:8]). Msg("Exit node set hash changed") } + pm.exitSet = exitSet pm.exitSetHash = exitSetHash @@ -199,6 +208,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) { if !needsUpdate { log.Trace(). Msg("Policy evaluation detected no changes - all hashes match") + return false, nil } @@ -224,6 +234,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err if err != nil { return nil, fmt.Errorf("compiling SSH policy: %w", err) } + pm.sshPolicyMap[node.ID()] = sshPol return sshPol, nil @@ -318,6 +329,7 @@ func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[typ if err != nil || len(filter) == 0 { continue } + nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter) } @@ -398,6 +410,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil reducedFilter := policyutil.ReduceFilterRules(node, pm.filter) pm.filterRulesMap[node.ID()] = reducedFilter + return reducedFilter, nil } @@ -442,7 +455,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul // This is different from FilterForNode which returns REDUCED rules for packet filtering. // // For global policies: returns the global matchers (same for all nodes) -// For autogroup:self: returns node-specific matchers from unreduced compiled rules +// For autogroup:self: returns node-specific matchers from unreduced compiled rules. func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) { if pm == nil { return nil, nil @@ -474,6 +487,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() defer pm.mu.Unlock() + pm.users = users // Clear SSH policy map when users change to force SSH policy recomputation @@ -685,6 +699,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr if pm.exitSet == nil { return false } + if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) { return true } @@ -748,8 +763,10 @@ func (pm *PolicyManager) DebugString() string { } fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap)) + for prefix, approveAddrs := range pm.autoApproveMap { fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range approveAddrs.Ranges() { fmt.Fprintf(&sb, "\t\t%s\n", iprange) } @@ -758,14 +775,17 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap)) + for prefix, tagOwners := range pm.tagOwnerMap { fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range tagOwners.Ranges() { fmt.Fprintf(&sb, "\t\t%s\n", iprange) } } sb.WriteString("\n\n") + if pm.filter != nil { filter, err := json.MarshalIndent(pm.filter, "", " ") if err == nil { @@ -778,6 +798,7 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") sb.WriteString("Matchers:\n") sb.WriteString("an internal structure used to filter nodes and routes\n") + for _, match := range pm.matchers { sb.WriteString(match.DebugString()) sb.WriteString("\n") @@ -785,6 +806,7 @@ func (pm *PolicyManager) DebugString() string { sb.WriteString("\n\n") sb.WriteString("Nodes:\n") + for _, node := range pm.nodes.All() { 
sb.WriteString(node.String()) sb.WriteString("\n") @@ -841,6 +863,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S // Check if IPs changed (simple check - could be more sophisticated) oldIPs := oldNode.IPs() + newIPs := newNode.IPs() if len(oldIPs) != len(newIPs) { affectedUsers[newNode.User().ID()] = struct{}{} @@ -862,6 +885,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S for nodeID := range pm.filterRulesMap { // Find the user for this cached node var nodeUserID uint + found := false // Check in new nodes first @@ -869,6 +893,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S if node.ID() == nodeID { nodeUserID = node.User().ID() found = true + break } } @@ -879,6 +904,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S if node.ID() == nodeID { nodeUserID = node.User().ID() found = true + break } } diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go index 80c08eed..4477e8b1 100644 --- a/hscontrol/policy/v2/policy_test.go +++ b/hscontrol/policy/v2/policy_test.go @@ -56,6 +56,7 @@ func TestPolicyManager(t *testing.T) { if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff( tt.wantMatchers, matchers, @@ -176,13 +177,16 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) { t.Run(tt.name, func(t *testing.T) { for i, n := range tt.newNodes { found := false + for _, origNode := range initialNodes { if n.Hostname == origNode.Hostname { n.ID = origNode.ID found = true + break } } + if !found { n.ID = types.NodeID(len(initialNodes) + i + 1) } @@ -369,7 +373,7 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) { // TestAutogroupSelfReducedVsUnreducedRules verifies that: // 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships -// 2. FilterForNode returns reduced compiled rules for packet filters +// 2. FilterForNode returns reduced compiled rules for packet filters. func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"} user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"} @@ -409,6 +413,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { // FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations // For node1, destinations should only be node1's IPs node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"} + for _, rule := range filterNode1 { for _, dst := range rule.DstPorts { require.Contains(t, node1IPs, dst.IP, @@ -418,6 +423,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { // For node2, destinations should only be node2's IPs node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"} + for _, rule := range filterNode2 { for _, dst := range rule.DstPorts { require.Contains(t, node2IPs, dst.IP, diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index fbce8a2b..3fe5a0d4 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -21,7 +21,7 @@ import ( "tailscale.com/util/slicesx" ) -// Global JSON options for consistent parsing across all struct unmarshaling +// Global JSON options for consistent parsing across all struct unmarshaling. 
var policyJSONOpts = []json.Options{ json.DefaultOptionsV2(), json.MatchCaseInsensitiveNames(true), @@ -58,6 +58,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { } var alias string + switch v := a.Alias.(type) { case *Username: alias = string(*v) @@ -89,6 +90,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) { // Otherwise, format as "alias:ports" var ports []string + for _, port := range a.Ports { if port.First == port.Last { ports = append(ports, strconv.FormatUint(uint64(port.First), 10)) @@ -123,6 +125,7 @@ func (u Username) Validate() error { if isUser(string(u)) { return nil } + return fmt.Errorf("Username has to contain @, got: %q", u) } @@ -194,8 +197,10 @@ func (u Username) resolveUser(users types.Users) (types.User, error) { } func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) user, err := u.resolveUser(users) if err != nil { @@ -228,6 +233,7 @@ func (g Group) Validate() error { if isGroup(string(g)) { return nil } + return fmt.Errorf(`Group has to start with "group:", got: %q`, g) } @@ -268,8 +274,10 @@ func (g Group) MarshalJSON() ([]byte, error) { } func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) for _, user := range p.Groups[g] { uips, err := user.Resolve(nil, users, nodes) @@ -290,6 +298,7 @@ func (t Tag) Validate() error { if isTag(string(t)) { return nil } + return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) } @@ -339,6 +348,7 @@ func (h Host) Validate() error { if isHost(string(h)) { return nil } + return fmt.Errorf("Hostname %q is invalid", h) } @@ -352,13 +362,16 @@ func (h *Host) UnmarshalJSON(b []byte) error { } func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) pref, ok := p.Hosts[h] if !ok { return nil, fmt.Errorf("unable to resolve host: %q", h) } + err := pref.Validate() if err != nil { errs = append(errs, err) @@ -376,6 +389,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView if err != nil { errs = append(errs, err) } + for _, node := range nodes.All() { if node.InIPSet(ipsTemp) { node.AppendToIPSet(&ips) @@ -391,6 +405,7 @@ func (p Prefix) Validate() error { if netip.Prefix(p).IsValid() { return nil } + return fmt.Errorf("Prefix %q is invalid", p) } @@ -404,6 +419,7 @@ func (p *Prefix) parseString(addr string) error { if err != nil { return err } + addrPref, err := addr.Prefix(addr.BitLen()) if err != nil { return err @@ -418,6 +434,7 @@ func (p *Prefix) parseString(addr string) error { if err != nil { return err } + *p = Prefix(pref) return nil @@ -428,6 +445,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { if err != nil { return err } + if err := p.Validate(); err != nil { return err } @@ -441,8 +459,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error { // // See [Policy], [types.Users], and [types.Nodes] for more details. 
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) ips.AddPrefix(netip.Prefix(p)) // If the IP is a single host, look for a node to ensure we add all the IPs of @@ -587,8 +607,10 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { switch vs := v.(type) { case string: - var portsPart string - var err error + var ( + portsPart string + err error + ) if strings.Contains(vs, ":") { vs, portsPart, err = splitDestinationAndPort(vs) @@ -600,6 +622,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Ports = ports } else { return errors.New(`hostport must contain a colon (":")`) @@ -609,6 +632,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { if err != nil { return err } + if err := ve.Validate(); err != nil { return err } @@ -646,6 +670,7 @@ func isHost(str string) bool { func parseAlias(vs string) (Alias, error) { var pref Prefix + err := pref.parseString(vs) if err == nil { return &pref, nil @@ -690,6 +715,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Alias = ptr return nil @@ -699,6 +725,7 @@ type Aliases []Alias func (a *Aliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -744,8 +771,10 @@ func (a Aliases) MarshalJSON() ([]byte, error) { } func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) for _, alias := range a { aips, err := alias.Resolve(p, users, nodes) @@ -770,6 +799,7 @@ func unmarshalPointer[T any]( parseFunc func(string) (T, error), ) (T, error) { var s string + err := json.Unmarshal(b, &s) if err != nil { var t T @@ -789,6 +819,7 @@ type AutoApprovers []AutoApprover func (aa *AutoApprovers) UnmarshalJSON(b []byte) error { var autoApprovers []AutoApproverEnc + err := json.Unmarshal(b, &autoApprovers, policyJSONOpts...) if err != nil { return err @@ -854,6 +885,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.AutoApprover = ptr return nil @@ -876,6 +908,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { if err != nil { return err } + ve.Owner = ptr return nil @@ -885,6 +918,7 @@ type Owners []Owner func (o *Owners) UnmarshalJSON(b []byte) error { var owners []OwnerEnc + err := json.Unmarshal(b, &owners, policyJSONOpts...) 
if err != nil { return err @@ -979,11 +1013,13 @@ func (g *Groups) UnmarshalJSON(b []byte) error { // Then validate each field can be converted to []string rawGroups := make(map[string][]string) + for key, value := range rawMap { switch v := value.(type) { case []any: // Convert []interface{} to []string var stringSlice []string + for _, item := range v { if str, ok := item.(string); ok { stringSlice = append(stringSlice, str) @@ -991,6 +1027,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item) } } + rawGroups[key] = stringSlice case string: return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v) @@ -1000,6 +1037,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { } *g = make(Groups) + for key, value := range rawGroups { group := Group(key) // Group name already validated above @@ -1014,6 +1052,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error { return err } + usernames = append(usernames, username) } @@ -1033,6 +1072,7 @@ func (h *Hosts) UnmarshalJSON(b []byte) error { } *h = make(Hosts) + for key, value := range rawHosts { host := Host(key) if err := host.Validate(); err != nil { @@ -1076,6 +1116,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) { } rawTagOwners := make(map[string][]string) + for tag, owners := range to { tagStr := string(tag) ownerStrs := make([]string, len(owners)) @@ -1152,6 +1193,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. if p == nil { return nil, nil, nil } + var err error routes := make(map[netip.Prefix]*netipx.IPSetBuilder) @@ -1160,6 +1202,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. if _, ok := routes[prefix]; !ok { routes[prefix] = new(netipx.IPSetBuilder) } + for _, autoApprover := range autoApprovers { aa, ok := autoApprover.(Alias) if !ok { @@ -1173,6 +1216,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } var exitNodeSetBuilder netipx.IPSetBuilder + if len(p.AutoApprovers.ExitNode) > 0 { for _, autoApprover := range p.AutoApprovers.ExitNode { aa, ok := autoApprover.(Alias) @@ -1187,11 +1231,13 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types. } ret := make(map[netip.Prefix]*netipx.IPSet) + for prefix, builder := range routes { ipSet, err := builder.IPSet() if err != nil { return nil, nil, err } + ret[prefix] = ipSet } @@ -1235,6 +1281,7 @@ func (a *Action) UnmarshalJSON(b []byte) error { default: return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept) } + return nil } @@ -1259,6 +1306,7 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error { default: return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str) } + return nil } @@ -1399,7 +1447,7 @@ func (p Protocol) validate() error { return nil case ProtocolWildcard: // Wildcard "*" is not allowed - Tailscale rejects it - return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") + return errors.New("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") default: // Try to parse as a numeric protocol number str := string(p) @@ -1427,7 +1475,7 @@ func (p Protocol) MarshalJSON() ([]byte, error) { return json.Marshal(string(p)) } -// Protocol constants matching the IANA numbers +// Protocol constants matching the IANA numbers. 
const ( protocolICMP = 1 // Internet Control Message protocolIGMP = 2 // Internet Group Management @@ -1464,6 +1512,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error { // Remove any fields that start with '#' filtered := make(map[string]any) + for key, value := range raw { if !strings.HasPrefix(key, "#") { filtered[key] = value @@ -1478,6 +1527,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error { // Create a type alias to avoid infinite recursion type aclAlias ACL + var temp aclAlias // Unmarshal into the temporary struct using the v2 JSON options @@ -1487,6 +1537,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error { // Copy the result back to the original struct *a = ACL(temp) + return nil } @@ -1733,6 +1784,7 @@ func (p *Policy) validate() error { } } } + for _, dst := range ssh.Destinations { switch dst := dst.(type) { case *AutoGroup: @@ -1846,6 +1898,7 @@ func (g Groups) MarshalJSON() ([]byte, error) { for i, username := range usernames { users[i] = string(username) } + raw[string(group)] = users } @@ -1854,6 +1907,7 @@ func (g Groups) MarshalJSON() ([]byte, error) { func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -1877,6 +1931,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) if err != nil { return err @@ -1960,8 +2015,10 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { } func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { - var ips netipx.IPSetBuilder - var errs []error + var ( + ips netipx.IPSetBuilder + errs []error + ) for _, alias := range a { aips, err := alias.Resolve(p, users, nodes) @@ -2015,18 +2072,22 @@ func unmarshalPolicy(b []byte) (*Policy, error) { } var policy Policy + ast, err := hujson.Parse(b) if err != nil { return nil, fmt.Errorf("parsing HuJSON: %w", err) } ast.Standardize() + if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { if serr, ok := errors.AsType[*json.SemanticError](err); ok && serr.Err == json.ErrUnknownName { ptr := serr.JSONPointer name := ptr.LastToken() + return nil, fmt.Errorf("unknown field %q", name) } + return nil, fmt.Errorf("parsing policy from bytes: %w", err) } @@ -2073,6 +2134,7 @@ func (p *Policy) usesAutogroupSelf() bool { return true } } + for _, dest := range acl.Destinations { if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { return true @@ -2087,6 +2149,7 @@ func (p *Policy) usesAutogroupSelf() bool { return true } } + for _, dest := range ssh.Destinations { if ag, ok := dest.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { return true diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go index 8f4f7a85..79d005a3 100644 --- a/hscontrol/policy/v2/types_test.go +++ b/hscontrol/policy/v2/types_test.go @@ -81,6 +81,7 @@ func TestMarshalJSON(t *testing.T) { // Unmarshal back to verify round trip var roundTripped Policy + err = json.Unmarshal(marshalled, &roundTripped) require.NoError(t, err) @@ -2020,6 +2021,7 @@ func TestResolvePolicy(t *testing.T) { } var prefs []netip.Prefix + if ips != nil { if p := ips.Prefixes(); len(p) > 0 { prefs = p @@ -2191,9 +2193,11 @@ func TestResolveAutoApprovers(t *testing.T) { t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr) return } + if diff := cmp.Diff(tt.want, got, cmps...); diff 
!= "" { t.Errorf("resolveAutoApprovers() mismatch (-want +got):\n%s", diff) } + if tt.wantAllIPRoutes != nil { if gotAllIPRoutes == nil { t.Error("resolveAutoApprovers() expected non-nil allIPRoutes, got nil") @@ -2340,6 +2344,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet { for _, p := range prefixes { builder.AddPrefix(mp(p)) } + ipSet, _ := builder.IPSet() return ipSet @@ -2349,6 +2354,7 @@ func ipSetComparer(x, y *netipx.IPSet) bool { if x == nil || y == nil { return x == y } + return cmp.Equal(x.Prefixes(), y.Prefixes(), util.Comparers...) } @@ -2577,6 +2583,7 @@ func TestResolveTagOwners(t *testing.T) { t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr) return } + if diff := cmp.Diff(tt.want, got, cmps...); diff != "" { t.Errorf("resolveTagOwners() mismatch (-want +got):\n%s", diff) } @@ -2852,6 +2859,7 @@ func TestNodeCanHaveTag(t *testing.T) { require.ErrorContains(t, err, tt.wantErr) return } + require.NoError(t, err) got := pm.NodeCanHaveTag(tt.node.View(), tt.tag) @@ -3112,6 +3120,7 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { var acl ACL + err := json.Unmarshal([]byte(tt.input), &acl) if tt.wantErr { @@ -3163,6 +3172,7 @@ func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) { // Unmarshal back var unmarshaled ACL + err = json.Unmarshal(jsonBytes, &unmarshaled) require.NoError(t, err) @@ -3241,12 +3251,13 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) { assert.Contains(t, err.Error(), `invalid action "deny"`) } -// Helper function to parse aliases for testing +// Helper function to parse aliases for testing. func mustParseAlias(s string) Alias { alias, err := parseAlias(s) if err != nil { panic(err) } + return alias } diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go index a4367775..80de52bc 100644 --- a/hscontrol/policy/v2/utils.go +++ b/hscontrol/policy/v2/utils.go @@ -18,9 +18,11 @@ func splitDestinationAndPort(input string) (string, string, error) { if lastColonIndex == -1 { return "", "", errors.New("input must contain a colon character separating destination and port") } + if lastColonIndex == 0 { return "", "", errors.New("input cannot start with a colon character") } + if lastColonIndex == len(input)-1 { return "", "", errors.New("input cannot end with a colon character") } @@ -45,6 +47,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { for part := range parts { if strings.Contains(part, "-") { rangeParts := strings.Split(part, "-") + rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool { return e == "" }) diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go index 2084b22f..a845e7a9 100644 --- a/hscontrol/policy/v2/utils_test.go +++ b/hscontrol/policy/v2/utils_test.go @@ -58,9 +58,11 @@ func TestParsePort(t *testing.T) { if err != nil && err.Error() != test.err { t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err) } + if err == nil && test.err != "" { t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err) } + if result != test.expected { t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected) } @@ -92,9 +94,11 @@ func TestParsePortRange(t *testing.T) { if err != nil && err.Error() != test.err { t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err) } + if err == nil && test.err != "" { t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, 
test.err) } + if diff := cmp.Diff(result, test.expected); diff != "" { t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff) } diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 02275751..d3c9f1ef 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -152,6 +152,7 @@ func (m *mapSession) serveLongPoll() { // This is not my favourite solution, but it kind of works in our eventually consistent world. ticker := time.NewTicker(time.Second) defer ticker.Stop() + disconnected := true // Wait up to 10 seconds for the node to reconnect. // 10 seconds was arbitrary chosen as a reasonable time to reconnect. @@ -160,6 +161,7 @@ func (m *mapSession) serveLongPoll() { disconnected = false break } + <-ticker.C } @@ -215,8 +217,10 @@ func (m *mapSession) serveLongPoll() { if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { m.errf(err, "failed to add node to batcher") log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session") + return } + log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher") m.h.Change(mapReqChange) diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go index 72eb2a5b..e3708a13 100644 --- a/hscontrol/routes/primary.go +++ b/hscontrol/routes/primary.go @@ -107,9 +107,11 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { Msg("Current primary no longer available") } } + if len(nodes) >= 1 { pr.primaries[prefix] = nodes[0] changed = true + log.Debug(). Caller(). Str("prefix", prefix.String()). @@ -126,6 +128,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool { Str("prefix", prefix.String()). Msg("Cleaning up primary route that no longer has available nodes") delete(pr.primaries, prefix) + changed = true } } @@ -161,14 +164,18 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) // If no routes are being set, remove the node from the routes map. if len(prefixes) == 0 { wasPresent := false + if _, ok := pr.routes[node]; ok { delete(pr.routes, node) + wasPresent = true + log.Debug(). Caller(). Uint64("node.id", node.Uint64()). Msg("Removed node from primary routes (no prefixes)") } + changed := pr.updatePrimaryLocked() log.Debug(). Caller(). 
@@ -254,12 +261,14 @@ func (pr *PrimaryRoutes) stringLocked() string { ids := types.NodeIDs(xmaps.Keys(pr.routes)) slices.Sort(ids) + for _, id := range ids { prefixes := pr.routes[id] fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", ")) } fmt.Fprintln(&sb, "\n\nCurrent primary routes:") + for route, nodeID := range pr.primaries { fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID) } diff --git a/hscontrol/routes/primary_test.go b/hscontrol/routes/primary_test.go index 7a9767b2..b03c8f81 100644 --- a/hscontrol/routes/primary_test.go +++ b/hscontrol/routes/primary_test.go @@ -130,6 +130,7 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) pr.SetRoutes(2, mp("192.168.2.0/24")) pr.SetRoutes(1) // Deregister by setting no routes + return pr.SetRoutes(1, mp("192.168.3.0/24")) }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -153,8 +154,9 @@ func TestPrimaryRoutes(t *testing.T) { { name: "multiple-nodes-register-same-route", operations: func(pr *PrimaryRoutes) bool { - pr.SetRoutes(1, mp("192.168.1.0/24")) // false - pr.SetRoutes(2, mp("192.168.1.0/24")) // true + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true + return pr.SetRoutes(3, mp("192.168.1.0/24")) // false }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -182,7 +184,8 @@ func TestPrimaryRoutes(t *testing.T) { pr.SetRoutes(1, mp("192.168.1.0/24")) // false pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary - return pr.SetRoutes(1) // true, 2 primary + + return pr.SetRoutes(1) // true, 2 primary }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ 2: { @@ -393,6 +396,7 @@ func TestPrimaryRoutes(t *testing.T) { operations: func(pr *PrimaryRoutes) bool { pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0")) pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0")) + return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0")) }, expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ @@ -413,15 +417,20 @@ func TestPrimaryRoutes(t *testing.T) { operations: func(pr *PrimaryRoutes) bool { var wg sync.WaitGroup wg.Add(2) + var change1, change2 bool + go func() { defer wg.Done() + change1 = pr.SetRoutes(1, mp("192.168.1.0/24")) }() go func() { defer wg.Done() + change2 = pr.SetRoutes(2, mp("192.168.2.0/24")) }() + wg.Wait() return change1 || change2 @@ -449,17 +458,21 @@ func TestPrimaryRoutes(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { pr := New() + change := tt.operations(pr) if change != tt.expectedChange { t.Errorf("change = %v, want %v", change, tt.expectedChange) } + comps := append(util.Comparers, cmpopts.EquateEmpty()) if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" { t.Errorf("routes mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" { t.Errorf("primaries mismatch (-want +got):\n%s", diff) } + if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" { t.Errorf("isPrimary mismatch (-want +got):\n%s", diff) } diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 3ed1d79f..9cad1c04 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -77,6 +77,7 @@ func (s *State) DebugOverview() string { ephemeralCount := 0 now := time.Now() + for _, node := range allNodes.All() { if node.Valid() { userName := node.Owner().Name() @@ -103,17 +104,21 @@ func (s *State) 
DebugOverview() string { // User statistics sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users))) + for userName, nodeCount := range userNodeCounts { sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount)) } + sb.WriteString("\n") // Policy information sb.WriteString("Policy:\n") sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode)) + if s.cfg.Policy.Mode == types.PolicyModeFile { sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path)) } + sb.WriteString("\n") // DERP information @@ -123,6 +128,7 @@ func (s *State) DebugOverview() string { } else { sb.WriteString("DERP: not configured\n") } + sb.WriteString("\n") // Route information @@ -130,6 +136,7 @@ func (s *State) DebugOverview() string { if s.primaryRoutes.String() == "" { routeCount = 0 } + sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount)) sb.WriteString("\n") @@ -165,10 +172,12 @@ func (s *State) DebugDERPMap() string { for _, node := range region.Nodes { sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n", node.Name, node.HostName, node.DERPPort)) + if node.STUNPort != 0 { sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort)) } } + sb.WriteString("\n") } @@ -319,6 +328,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo { if s.primaryRoutes.String() == "" { routeCount = 0 } + info.PrimaryRoutes = routeCount return info diff --git a/hscontrol/state/ephemeral_test.go b/hscontrol/state/ephemeral_test.go index 9f713b3d..5c755687 100644 --- a/hscontrol/state/ephemeral_test.go +++ b/hscontrol/state/ephemeral_test.go @@ -20,6 +20,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { // Create NodeStore store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -43,20 +44,26 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { // 6. If DELETE came after UPDATE, the returned node should be invalid done := make(chan bool, 2) - var updatedNode types.NodeView - var updateOk bool + + var ( + updatedNode types.NodeView + updateOk bool + ) // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) + go func() { updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) { n.LastSeen = new(time.Now()) }) + done <- true }() // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) go func() { store.DeleteNode(node.ID) + done <- true }() @@ -90,6 +97,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -147,6 +155,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) { node := createTestNode(3, 1, "test-user", "test-node-3") store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -203,6 +212,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -213,8 +223,11 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // 1. UpdateNode (from UpdateNodeFromMapRequest during polling) // 2. 
DeleteNode (from handleLogout when client sends logout request) - var updatedNode types.NodeView - var updateOk bool + var ( + updatedNode types.NodeView + updateOk bool + ) + done := make(chan bool, 2) // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) @@ -222,12 +235,14 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { n.LastSeen = new(time.Now()) }) + done <- true }() // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) go func() { store.DeleteNode(ephemeralNode.ID) + done <- true }() @@ -266,7 +281,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { // 5. UpdateNode and DeleteNode batch together // 6. UpdateNode returns a valid node (from before delete in batch) // 7. persistNodeToDB is called with the stale valid node -// 8. Node gets re-inserted into database instead of staying deleted +// 8. Node gets re-inserted into database instead of staying deleted. func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5") ephemeralNode.AuthKey = &types.PreAuthKey{ @@ -278,6 +293,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -348,6 +364,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() defer store.Stop() @@ -398,7 +415,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { // 3. UpdateNode and DeleteNode batch together // 4. UpdateNode returns a valid node (from before delete in batch) // 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node -// 6. persistNodeToDB must detect the node is deleted and refuse to persist +// 6. persistNodeToDB must detect the node is deleted and refuse to persist. func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7") ephemeralNode.AuthKey = &types.PreAuthKey{ @@ -408,6 +425,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { } store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() diff --git a/hscontrol/state/maprequest.go b/hscontrol/state/maprequest.go index e7dfc11c..d8cddaa1 100644 --- a/hscontrol/state/maprequest.go +++ b/hscontrol/state/maprequest.go @@ -29,6 +29,7 @@ func netInfoFromMapRequest( Uint64("node.id", nodeID.Uint64()). Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP). Msg("using NetInfo from previous Hostinfo in MapRequest") + return currentHostinfo.NetInfo } diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go index 0fa81318..a7d50a07 100644 --- a/hscontrol/state/maprequest_test.go +++ b/hscontrol/state/maprequest_test.go @@ -136,7 +136,7 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { }) } -// Simple helper function for tests +// Simple helper function for tests. 
func createTestNodeSimple(id types.NodeID) *types.Node { user := types.User{ Name: "test-user", diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go index 6327b46b..5d8d6e85 100644 --- a/hscontrol/state/node_store.go +++ b/hscontrol/state/node_store.go @@ -97,6 +97,7 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batc for _, n := range allNodes { nodes[n.ID] = *n } + snap := snapshotFromNodes(nodes, peersFunc) store := &NodeStore{ @@ -165,11 +166,14 @@ func (s *NodeStore) PutNode(n types.Node) types.NodeView { } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("put").Inc() return resultNode @@ -205,11 +209,14 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node) } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("update").Inc() // Return the node and whether it exists (is valid) @@ -229,7 +236,9 @@ func (s *NodeStore) DeleteNode(id types.NodeID) { } nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result nodeStoreQueueDepth.Dec() @@ -262,8 +271,10 @@ func (s *NodeStore) processWrite() { if len(batch) != 0 { s.applyBatch(batch) } + return } + batch = append(batch, w) if len(batch) >= s.batchSize { s.applyBatch(batch) @@ -321,6 +332,7 @@ func (s *NodeStore) applyBatch(batch []work) { w.updateFn(&n) nodes[w.nodeID] = n } + if w.nodeResult != nil { nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) } @@ -349,12 +361,14 @@ func (s *NodeStore) applyBatch(batch []work) { nodeView := node.View() for _, w := range workItems { w.nodeResult <- nodeView + close(w.nodeResult) } } else { // Node was deleted or doesn't exist for _, w := range workItems { w.nodeResult <- types.NodeView{} // Send invalid view + close(w.nodeResult) } } @@ -400,6 +414,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S peersByNode: func() map[types.NodeID][]types.NodeView { peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration) defer peersTimer.ObserveDuration() + return peersFunc(allNodes) }(), nodesByUser: make(map[types.UserID][]types.NodeView), @@ -417,6 +432,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S if newSnap.nodesByMachineKey[n.MachineKey] == nil { newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView) } + newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView } @@ -511,10 +527,12 @@ func (s *NodeStore) DebugString() string { // User distribution (shows internal UserID tracking, not display owner) sb.WriteString("Nodes by Internal User ID:\n") + for userID, nodes := range snapshot.nodesByUser { if len(nodes) > 0 { userName := "unknown" taggedCount := 0 + if len(nodes) > 0 && nodes[0].Valid() { userName = nodes[0].User().Name() // Count tagged nodes (which have UserID set but are owned by "tagged-devices") @@ -532,23 +550,29 @@ func (s *NodeStore) DebugString() string { } } } + sb.WriteString("\n") // Peer relationships summary sb.WriteString("Peer Relationships:\n") + totalPeers := 0 + for nodeID, peers := range snapshot.peersByNode { peerCount := len(peers) + totalPeers += peerCount if node, exists := snapshot.nodesByID[nodeID]; exists { sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n", nodeID, node.Hostname, peerCount)) } } + if len(snapshot.peersByNode) > 0 
{ avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode)) sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers)) } + sb.WriteString("\n") // Node key index @@ -591,6 +615,7 @@ func (s *NodeStore) RebuildPeerMaps() { } s.writeQueue <- w + <-result } diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go index 745850cc..23068b97 100644 --- a/hscontrol/state/node_store_test.go +++ b/hscontrol/state/node_store_test.go @@ -44,6 +44,7 @@ func TestSnapshotFromNodes(t *testing.T) { nodes := map[types.NodeID]types.Node{ 1: createTestNode(1, 1, "user1", "node1"), } + return nodes, allowAllPeersFunc }, validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { @@ -192,11 +193,13 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView ret := make(map[types.NodeID][]types.NodeView, len(nodes)) for _, node := range nodes { var peers []types.NodeView + for _, n := range nodes { if n.ID() != node.ID() { peers = append(peers, n) } } + ret[node.ID()] = peers } @@ -207,6 +210,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView ret := make(map[types.NodeID][]types.NodeView, len(nodes)) for _, node := range nodes { var peers []types.NodeView + nodeIsOdd := node.ID()%2 == 1 for _, n := range nodes { @@ -221,6 +225,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView peers = append(peers, n) } } + ret[node.ID()] = peers } @@ -454,10 +459,13 @@ func TestNodeStoreOperations(t *testing.T) { // Add nodes in sequence n1 := store.PutNode(createTestNode(1, 1, "user1", "node1")) assert.True(t, n1.Valid()) + n2 := store.PutNode(createTestNode(2, 2, "user2", "node2")) assert.True(t, n2.Valid()) + n3 := store.PutNode(createTestNode(3, 3, "user3", "node3")) assert.True(t, n3.Valid()) + n4 := store.PutNode(createTestNode(4, 4, "user4", "node4")) assert.True(t, n4.Valid()) @@ -525,16 +533,20 @@ func TestNodeStoreOperations(t *testing.T) { done2 := make(chan struct{}) done3 := make(chan struct{}) - var resultNode1, resultNode2 types.NodeView - var newNode3 types.NodeView - var ok1, ok2 bool + var ( + resultNode1, resultNode2 types.NodeView + newNode3 types.NodeView + ok1, ok2 bool + ) // These should all be processed in the same batch + go func() { resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "batch-updated-node1" n.GivenName = "batch-given-1" }) + close(done1) }() @@ -543,12 +555,14 @@ func TestNodeStoreOperations(t *testing.T) { n.Hostname = "batch-updated-node2" n.GivenName = "batch-given-2" }) + close(done2) }() go func() { node3 := createTestNode(3, 1, "user1", "node3") newNode3 = store.PutNode(node3) + close(done3) }() @@ -601,20 +615,23 @@ func TestNodeStoreOperations(t *testing.T) { // This test verifies that when multiple updates to the same node // are batched together, each returned node reflects ALL changes // in the batch, not just the individual update's changes. 
- done1 := make(chan struct{}) done2 := make(chan struct{}) done3 := make(chan struct{}) - var resultNode1, resultNode2, resultNode3 types.NodeView - var ok1, ok2, ok3 bool + var ( + resultNode1, resultNode2, resultNode3 types.NodeView + ok1, ok2, ok3 bool + ) // These updates all modify node 1 and should be batched together // The final state should have all three modifications applied + go func() { resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "multi-update-hostname" }) + close(done1) }() @@ -622,6 +639,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) { n.GivenName = "multi-update-givenname" }) + close(done2) }() @@ -629,6 +647,7 @@ func TestNodeStoreOperations(t *testing.T) { resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) { n.Tags = []string{"tag1", "tag2"} }) + close(done3) }() @@ -722,14 +741,18 @@ func TestNodeStoreOperations(t *testing.T) { done2 := make(chan struct{}) done3 := make(chan struct{}) - var result1, result2, result3 types.NodeView - var ok1, ok2, ok3 bool + var ( + result1, result2, result3 types.NodeView + ok1, ok2, ok3 bool + ) // Start concurrent updates + go func() { result1, ok1 = store.UpdateNode(1, func(n *types.Node) { n.Hostname = "concurrent-db-hostname" }) + close(done1) }() @@ -737,6 +760,7 @@ func TestNodeStoreOperations(t *testing.T) { result2, ok2 = store.UpdateNode(1, func(n *types.Node) { n.GivenName = "concurrent-db-given" }) + close(done2) }() @@ -744,6 +768,7 @@ func TestNodeStoreOperations(t *testing.T) { result3, ok3 = store.UpdateNode(1, func(n *types.Node) { n.Tags = []string{"concurrent-tag"} }) + close(done3) }() @@ -827,6 +852,7 @@ func TestNodeStoreOperations(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { store := tt.setupFunc(t) + store.Start() defer store.Stop() @@ -846,10 +872,11 @@ type testStep struct { // --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests --- -// Helper for concurrent test nodes +// Helper for concurrent test nodes. func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { machineKey := key.NewMachine() nodeKey := key.NewNode() + return types.Node{ ID: id, Hostname: hostname, @@ -862,72 +889,90 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { } } -// --- Concurrency: concurrent PutNode operations --- +// --- Concurrency: concurrent PutNode operations ---. func TestNodeStoreConcurrentPutNode(t *testing.T) { const concurrentOps = 20 store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() var wg sync.WaitGroup + results := make(chan bool, concurrentOps) for i := range concurrentOps { wg.Add(1) + go func(nodeID int) { defer wg.Done() + node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") + resultNode := store.PutNode(node) results <- resultNode.Valid() }(i + 1) } + wg.Wait() close(results) successCount := 0 + for success := range results { if success { successCount++ } } + require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed") } -// --- Batching: concurrent ops fit in one batch --- +// --- Batching: concurrent ops fit in one batch ---. 
func TestNodeStoreBatchingEfficiency(t *testing.T) { const batchSize = 10 + const ops = 15 // more than batchSize store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() var wg sync.WaitGroup + results := make(chan bool, ops) for i := range ops { wg.Add(1) + go func(nodeID int) { defer wg.Done() + node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") + resultNode := store.PutNode(node) results <- resultNode.Valid() }(i + 1) } + wg.Wait() close(results) successCount := 0 + for success := range results { if success { successCount++ } } + require.Equal(t, ops, successCount, "All batch PutNode operations should succeed") } -// --- Race conditions: many goroutines on same node --- +// --- Race conditions: many goroutines on same node ---. func TestNodeStoreRaceConditions(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -936,13 +981,18 @@ func TestNodeStoreRaceConditions(t *testing.T) { resultNode := store.PutNode(node) require.True(t, resultNode.Valid()) - const numGoroutines = 30 - const opsPerGoroutine = 10 + const ( + numGoroutines = 30 + opsPerGoroutine = 10 + ) + var wg sync.WaitGroup + errors := make(chan error, numGoroutines*opsPerGoroutine) for i := range numGoroutines { wg.Add(1) + go func(gid int) { defer wg.Done() @@ -962,6 +1012,7 @@ func TestNodeStoreRaceConditions(t *testing.T) { } case 2: newNode := createConcurrentTestNode(nodeID, "race-put") + resultNode := store.PutNode(newNode) if !resultNode.Valid() { errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j) @@ -970,23 +1021,28 @@ func TestNodeStoreRaceConditions(t *testing.T) { } }(i) } + wg.Wait() close(errors) errorCount := 0 + for err := range errors { t.Error(err) + errorCount++ } + if errorCount > 0 { t.Fatalf("Race condition test failed with %d errors", errorCount) } } -// --- Resource cleanup: goroutine leak detection --- +// --- Resource cleanup: goroutine leak detection ---. func TestNodeStoreResourceCleanup(t *testing.T) { // initialGoroutines := runtime.NumGoroutine() store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1009,10 +1065,12 @@ func TestNodeStoreResourceCleanup(t *testing.T) { }) retrieved, found := store.GetNode(nodeID) assert.True(t, found && retrieved.Valid()) + if i%10 == 9 { store.DeleteNode(nodeID) } } + runtime.GC() // Wait for goroutines to settle and check for leaks @@ -1023,9 +1081,10 @@ func TestNodeStoreResourceCleanup(t *testing.T) { }, time.Second, 10*time.Millisecond, "goroutines should not leak") } -// --- Timeout/deadlock: operations complete within reasonable time --- +// --- Timeout/deadlock: operations complete within reasonable time ---. 
func TestNodeStoreOperationTimeout(t *testing.T) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1033,36 +1092,47 @@ func TestNodeStoreOperationTimeout(t *testing.T) { defer cancel() const ops = 30 + var wg sync.WaitGroup + putResults := make([]error, ops) updateResults := make([]error, ops) // Launch all PutNode operations concurrently for i := 1; i <= ops; i++ { nodeID := types.NodeID(i) + wg.Add(1) + go func(idx int, id types.NodeID) { defer wg.Done() + startPut := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id) node := createConcurrentTestNode(id, "timeout-node") resultNode := store.PutNode(node) endPut := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut)) + if !resultNode.Valid() { putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id) } }(i, nodeID) } + wg.Wait() // Launch all UpdateNode operations concurrently wg = sync.WaitGroup{} + for i := 1; i <= ops; i++ { nodeID := types.NodeID(i) + wg.Add(1) + go func(idx int, id types.NodeID) { defer wg.Done() + startUpdate := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id) resultNode, ok := store.UpdateNode(id, func(n *types.Node) { @@ -1070,31 +1140,40 @@ func TestNodeStoreOperationTimeout(t *testing.T) { }) endUpdate := time.Now() fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate)) + if !ok || !resultNode.Valid() { updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id) } }(i, nodeID) } + done := make(chan struct{}) + go func() { wg.Wait() close(done) }() + select { case <-done: errorCount := 0 + for _, err := range putResults { if err != nil { t.Error(err) + errorCount++ } } + for _, err := range updateResults { if err != nil { t.Error(err) + errorCount++ } } + if errorCount == 0 { t.Log("All concurrent operations completed successfully within timeout") } else { @@ -1106,13 +1185,15 @@ func TestNodeStoreOperationTimeout(t *testing.T) { } } -// --- Edge case: update non-existent node --- +// --- Edge case: update non-existent node ---. func TestNodeStoreUpdateNonExistentNode(t *testing.T) { for i := range 10 { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) store.Start() + nonExistentID := types.NodeID(999 + i) updateCallCount := 0 + fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID) resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) { updateCallCount++ @@ -1126,9 +1207,10 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) { } } -// --- Allocation benchmark --- +// --- Allocation benchmark ---. 
func BenchmarkNodeStoreAllocations(b *testing.B) { store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() defer store.Stop() @@ -1140,6 +1222,7 @@ func BenchmarkNodeStoreAllocations(b *testing.B) { n.Hostname = "bench-updated" }) store.GetNode(nodeID) + if i%10 == 9 { store.DeleteNode(nodeID) } diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go index 1a949173..efce647d 100644 --- a/hscontrol/tailsql.go +++ b/hscontrol/tailsql.go @@ -93,6 +93,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s mux := tsql.NewMux() tsweb.Debugger(mux) go http.Serve(lst, mux) + logf("TailSQL started") <-ctx.Done() logf("TailSQL shutting down...") diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index f4814519..be3756a0 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -177,6 +177,7 @@ func RegistrationIDFromString(str string) (RegistrationID, error) { if len(str) != RegistrationIDLength { return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength) } + return RegistrationID(str), nil } diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4068d72e..fffe166d 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -301,6 +301,7 @@ func validatePKCEMethod(method string) error { if method != PKCEMethodPlain && method != PKCEMethodS256 { return errInvalidPKCEMethod } + return nil } @@ -1082,6 +1083,7 @@ func LoadServerConfig() (*Config, error) { if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 { return workers } + return DefaultBatcherWorkers() }(), RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"), @@ -1117,6 +1119,7 @@ func isSafeServerURL(serverURL, baseDomain string) error { } s := len(serverDomainParts) + b := len(baseDomainParts) for i := range baseDomainParts { if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] { diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go index 6b9fc2ef..13a3a418 100644 --- a/hscontrol/types/config_test.go +++ b/hscontrol/types/config_test.go @@ -363,6 +363,7 @@ noise: // Populate a custom config file configFilePath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configFilePath, configYaml, 0o600) if err != nil { t.Fatalf("Couldn't write file %s", configFilePath) @@ -398,10 +399,12 @@ server_url: http://127.0.0.1:8080 tls_letsencrypt_hostname: example.com tls_letsencrypt_challenge_type: TLS-ALPN-01 `) + err = os.WriteFile(configFilePath, configYaml, 0o600) if err != nil { t.Fatalf("Couldn't write file %s", configFilePath) } + err = LoadConfig(tmpDir, false) require.NoError(t, err) } @@ -463,6 +466,7 @@ func TestSafeServerURL(t *testing.T) { return } + assert.NoError(t, err) }) } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 1a66341d..5140bc44 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -156,6 +156,7 @@ func (node *Node) GivenNameHasBeenChanged() bool { // Strip invalid DNS characters for givenName comparison normalised := strings.ToLower(node.Hostname) normalised = invalidDNSRegex.ReplaceAllString(normalised, "") + return node.GivenName == normalised } @@ -464,7 +465,7 @@ func (node *Node) IsSubnetRouter() bool { return len(node.SubnetRoutes()) > 0 } -// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes +// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes. 
func (node *Node) AllApprovedRoutes() []netip.Prefix { return append(node.SubnetRoutes(), node.ExitRoutes()...) } @@ -579,6 +580,7 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { Str("rejected_hostname", hostInfo.Hostname). Err(err). Msg("Rejecting invalid hostname update from hostinfo") + return } @@ -670,6 +672,7 @@ func (nodes Nodes) IDMap() map[NodeID]*Node { func (nodes Nodes) DebugString() string { var sb strings.Builder sb.WriteString("Nodes:\n") + for _, node := range nodes { sb.WriteString(node.DebugString()) sb.WriteString("\n") diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 2ce02f02..3b3e59e2 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -128,6 +128,7 @@ func (pak *PreAuthKey) Validate() error { if pak.Expiration != nil { return *pak.Expiration } + return time.Time{} }()). Time("now", time.Now()). diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 27aff519..dbcf4f44 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -40,9 +40,11 @@ var TaggedDevices = User{ func (u Users) String() string { var sb strings.Builder sb.WriteString("[ ") + for _, user := range u { fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name) } + sb.WriteString(" ]") return sb.String() @@ -89,6 +91,7 @@ func (u *User) StringID() string { if u == nil { return "" } + return strconv.FormatUint(uint64(u.ID), 10) } @@ -203,6 +206,7 @@ type FlexibleBoolean bool func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { var val any + err := json.Unmarshal(data, &val) if err != nil { return fmt.Errorf("could not unmarshal data: %w", err) @@ -216,6 +220,7 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { if err != nil { return fmt.Errorf("could not parse %s as boolean: %w", v, err) } + *bit = FlexibleBoolean(pv) default: @@ -253,9 +258,11 @@ func (c *OIDCClaims) Identifier() string { if c.Iss == "" && c.Sub == "" { return "" } + if c.Iss == "" { return CleanIdentifier(c.Sub) } + if c.Sub == "" { return CleanIdentifier(c.Iss) } @@ -340,6 +347,7 @@ func CleanIdentifier(identifier string) string { cleanParts = append(cleanParts, trimmed) } } + if len(cleanParts) == 0 { return "" } @@ -382,6 +390,7 @@ func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) { if claims.Iss == "" && !strings.HasPrefix(identifier, "/") { identifier = "/" + identifier } + u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true} u.DisplayName = claims.Name u.ProfilePicURL = claims.ProfilePictureURL diff --git a/hscontrol/types/users_test.go b/hscontrol/types/users_test.go index 15386553..acd88434 100644 --- a/hscontrol/types/users_test.go +++ b/hscontrol/types/users_test.go @@ -70,6 +70,7 @@ func TestUnmarshallOIDCClaims(t *testing.T) { t.Errorf("UnmarshallOIDCClaims() error = %v", err) return } + if diff := cmp.Diff(got, tt.want); diff != "" { t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff) } @@ -190,6 +191,7 @@ func TestOIDCClaimsIdentifier(t *testing.T) { } result := claims.Identifier() assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { t.Errorf("Identifier() mismatch (-want +got):\n%s", diff) } @@ -282,6 +284,7 @@ func TestCleanIdentifier(t *testing.T) { t.Run(tt.name, func(t *testing.T) { result := CleanIdentifier(tt.identifier) assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff) } @@ 
-487,6 +490,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) { var user User user.FromClaim(&got, tt.emailVerifiedRequired) + if diff := cmp.Diff(user, tt.want); diff != "" { t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff) } diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go index b492e4d6..4f9a338f 100644 --- a/hscontrol/util/dns_test.go +++ b/hscontrol/util/dns_test.go @@ -90,6 +90,7 @@ func TestNormaliseHostname(t *testing.T) { t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr) return } + if !tt.wantErr && got != tt.want { t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want) } @@ -172,6 +173,7 @@ func TestValidateHostname(t *testing.T) { t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr) return } + if tt.wantErr && tt.errorContains != "" { if err == nil || !strings.Contains(err.Error(), tt.errorContains) { t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains) diff --git a/hscontrol/util/prompt.go b/hscontrol/util/prompt.go index 098f1979..7d9cdbdf 100644 --- a/hscontrol/util/prompt.go +++ b/hscontrol/util/prompt.go @@ -15,10 +15,12 @@ func YesNo(msg string) bool { var resp string fmt.Scanln(&resp) + resp = strings.ToLower(resp) switch resp { case "y", "yes", "sure": return true } + return false } diff --git a/hscontrol/util/prompt_test.go b/hscontrol/util/prompt_test.go index d726ec60..fbed2ff8 100644 --- a/hscontrol/util/prompt_test.go +++ b/hscontrol/util/prompt_test.go @@ -86,6 +86,7 @@ func TestYesNo(t *testing.T) { // Write test input go func() { defer w.Close() + w.WriteString(tt.input) }() @@ -95,6 +96,7 @@ func TestYesNo(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Check the result @@ -108,6 +110,7 @@ func TestYesNo(t *testing.T) { stderrR.Close() expectedPrompt := "Test question [y/n] " + actualPrompt := stderrBuf.String() if actualPrompt != expectedPrompt { t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) @@ -130,6 +133,7 @@ func TestYesNoPromptMessage(t *testing.T) { // Write test input go func() { defer w.Close() + w.WriteString("n\n") }() @@ -140,6 +144,7 @@ func TestYesNoPromptMessage(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Check that the custom message was included in the prompt @@ -148,6 +153,7 @@ func TestYesNoPromptMessage(t *testing.T) { stderrR.Close() expectedPrompt := customMessage + " [y/n] " + actualPrompt := stderrBuf.String() if actualPrompt != expectedPrompt { t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) @@ -186,6 +192,7 @@ func TestYesNoCaseInsensitive(t *testing.T) { // Write test input go func() { defer w.Close() + w.WriteString(tc.input) }() @@ -195,6 +202,7 @@ func TestYesNoCaseInsensitive(t *testing.T) { // Restore stdin and stderr os.Stdin = oldStdin os.Stderr = oldStderr + stderrW.Close() // Drain stderr diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go index d1d7ece7..0a37ec87 100644 --- a/hscontrol/util/string.go +++ b/hscontrol/util/string.go @@ -33,6 +33,7 @@ func GenerateRandomStringURLSafe(n int) (string, error) { b, err := GenerateRandomBytes(n) uenc := base64.RawURLEncoding.EncodeToString(b) + return uenc[:n], err } @@ -99,6 +100,7 @@ func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string { DstIPs: %v } `, rule.SrcIPs, rule.DstPorts)) + if index < len(rules)-1 { sb.WriteString(", ") } diff --git a/hscontrol/util/util.go 
b/hscontrol/util/util.go index 4d828d02..53189656 100644 --- a/hscontrol/util/util.go +++ b/hscontrol/util/util.go @@ -30,6 +30,7 @@ func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { // It returns an error if not exactly one URL is found. func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { lines := strings.Split(output, "\n") + var urlStr string for _, line := range lines { @@ -38,6 +39,7 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { if urlStr != "" { return nil, fmt.Errorf("multiple URLs found: %s and %s", urlStr, line) } + urlStr = line } } @@ -94,6 +96,7 @@ func ParseTraceroute(output string) (Traceroute, error) { // Parse the header line - handle both 'traceroute' and 'tracert' (Windows) headerRegex := regexp.MustCompile(`(?i)(?:traceroute|tracing route) to ([^ ]+) (?:\[([^\]]+)\]|\(([^)]+)\))`) + headerMatches := headerRegex.FindStringSubmatch(lines[0]) if len(headerMatches) < 2 { return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0]) @@ -105,6 +108,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if ipStr == "" { ipStr = headerMatches[3] } + ip, err := netip.ParseAddr(ipStr) if err != nil { return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err) @@ -144,13 +148,17 @@ func ParseTraceroute(output string) (Traceroute, error) { } remainder := strings.TrimSpace(matches[2]) - var hopHostname string - var hopIP netip.Addr - var latencies []time.Duration + + var ( + hopHostname string + hopIP netip.Addr + latencies []time.Duration + ) // Check for Windows tracert format which has latencies before hostname // Format: " 1 <1 ms <1 ms <1 ms router.local [192.168.1.1]" latencyFirst := false + if strings.Contains(remainder, " ms ") && !strings.HasPrefix(remainder, "*") { // Check if latencies appear before any hostname/IP firstSpace := strings.Index(remainder, " ") @@ -171,12 +179,14 @@ func ParseTraceroute(output string) (Traceroute, error) { } // Extract and remove the latency from the beginning latStr := strings.TrimPrefix(remainder[latMatch[2]:latMatch[3]], "<") + ms, err := strconv.ParseFloat(latStr, 64) if err == nil { // Round to nearest microsecond to avoid floating point precision issues duration := time.Duration(ms * float64(time.Millisecond)) latencies = append(latencies, duration.Round(time.Microsecond)) } + remainder = strings.TrimSpace(remainder[latMatch[1]:]) } } @@ -205,6 +215,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if ip, err := netip.ParseAddr(parts[0]); err == nil { hopIP = ip } + remainder = strings.TrimSpace(strings.Join(parts[1:], " ")) } } @@ -216,6 +227,7 @@ func ParseTraceroute(output string) (Traceroute, error) { if len(match) > 1 { // Remove '<' prefix if present (e.g., "<1 ms") latStr := strings.TrimPrefix(match[1], "<") + ms, err := strconv.ParseFloat(latStr, 64) if err == nil { // Round to nearest microsecond to avoid floating point precision issues @@ -280,11 +292,13 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri if key == "" { return "unknown-node" } + keyPrefix := key if len(key) > 8 { keyPrefix = key[:8] } - return fmt.Sprintf("node-%s", keyPrefix) + + return "node-" + keyPrefix } lowercased := strings.ToLower(hostinfo.Hostname) diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go index 33f27b7a..a064a852 100644 --- a/hscontrol/util/util_test.go +++ b/hscontrol/util/util_test.go @@ -180,6 +180,7 @@ Success.`, if err != nil { t.Errorf("ParseLoginURLFromCLILogin() error = %v, wantErr 
%v", err, tt.wantErr) } + if gotURL.String() != tt.wantURL { t.Errorf("ParseLoginURLFromCLILogin() = %v, want %v", gotURL, tt.wantURL) } @@ -1066,6 +1067,7 @@ func TestEnsureHostname(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.want, "invalid-") { @@ -1103,9 +1105,11 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "test-node" { t.Errorf("hostname = %v, want test-node", hi.Hostname) } + if hi.OS != "linux" { t.Errorf("OS = %v, want linux", hi.OS) } @@ -1147,6 +1151,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "node-nkey1234" { t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname) } @@ -1162,6 +1167,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "unknown-node" { t.Errorf("hostname = %v, want unknown-node", hi.Hostname) } @@ -1179,6 +1185,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "unknown-node" { t.Errorf("hostname = %v, want unknown-node", hi.Hostname) } @@ -1200,18 +1207,23 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if hi.Hostname != "test" { t.Errorf("hostname = %v, want test", hi.Hostname) } + if hi.OS != "windows" { t.Errorf("OS = %v, want windows", hi.OS) } + if hi.OSVersion != "10.0.19044" { t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion) } + if hi.DeviceModel != "test-device" { t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel) } + if hi.BackendLogID != "log123" { t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID) } @@ -1229,6 +1241,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { if hi == nil { t.Error("hostinfo should not be nil") } + if len(hi.Hostname) != 63 { t.Errorf("hostname length = %v, want 63", len(hi.Hostname)) } @@ -1239,6 +1252,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() + gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) // For invalid hostnames, we just check the prefix since the random part varies if strings.HasPrefix(tt.wantHostname, "invalid-") { @@ -1265,6 +1279,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) { for i, hostname := range testCases { t.Run(cmp.Diff("", ""), func(t *testing.T) { hostinfo := &tailcfg.Hostinfo{Hostname: hostname} + result := EnsureHostname(hostinfo, "mkey", "nkey") if len(result) > 63 { t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go index 223e4c8b..825f3d17 100644 --- a/integration/api_auth_test.go +++ b/integration/api_auth_test.go @@ -35,6 +35,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -46,6 +47,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Create an API key using the CLI var validAPIKey string + assert.EventuallyWithT(t, func(ct *assert.CollectT) { apiKeyOutput, err := headscale.Execute( []string{ @@ -63,7 +65,7 @@ func 
TestAPIAuthenticationBypass(t *testing.T) { // Get the API endpoint endpoint := headscale.GetEndpoint() - apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + apiURL := endpoint + "/api/v1/user" // Create HTTP client client := &http.Client{ @@ -81,6 +83,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -99,6 +102,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Should NOT contain user data after "Unauthorized" // This is the security bypass - if users array is present, auth was bypassed var jsonCheck map[string]any + jsonErr := json.Unmarshal(body, &jsonCheck) // If we can unmarshal JSON and it contains "users", that's the bypass @@ -132,6 +136,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -165,6 +170,7 @@ func TestAPIAuthenticationBypass(t *testing.T) { resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -193,10 +199,11 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Expected: Should return 200 with user data (this is the authorized case) req, err := http.NewRequest("GET", apiURL, nil) require.NoError(t, err) - req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", validAPIKey)) + req.Header.Set("Authorization", "Bearer "+validAPIKey) resp, err := client.Do(req) require.NoError(t, err) + defer resp.Body.Close() body, err := io.ReadAll(resp.Body) @@ -208,16 +215,19 @@ func TestAPIAuthenticationBypass(t *testing.T) { // Should be able to parse as protobuf JSON var response v1.ListUsersResponse + err = protojson.Unmarshal(body, &response) assert.NoError(t, err, "Response should be valid protobuf JSON with valid API key") // Should contain our test users users := response.GetUsers() assert.Len(t, users, 3, "Should have 3 users") + userNames := make([]string, len(users)) for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "user1") assert.Contains(t, userNames, "user2") assert.Contains(t, userNames, "user3") @@ -234,6 +244,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -254,10 +265,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) endpoint := headscale.GetEndpoint() - apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + apiURL := endpoint + "/api/v1/user" t.Run("Curl_NoAuth", func(t *testing.T) { // Execute curl from inside the headscale container without auth @@ -274,16 +286,21 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { // Parse the output lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody string + var ( + httpCode string + responseBody string + ) + + var responseBodySb295 strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBodySb295.WriteString(line) } } + responseBody += responseBodySb295.String() // Should return 401 assert.Equal(t, "401", httpCode, @@ -320,16 +337,21 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { require.NoError(t, err) lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody string + var ( + httpCode string + responseBody string + ) 
+ + var responseBodySb344 strings.Builder for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { httpCode = after } else { - responseBody += line + responseBodySb344.WriteString(line) } } + responseBody += responseBodySb344.String() assert.Equal(t, "401", httpCode) assert.Contains(t, responseBody, "Unauthorized") @@ -346,7 +368,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { "curl", "-s", "-H", - fmt.Sprintf("Authorization: Bearer %s", validAPIKey), + "Authorization: Bearer " + validAPIKey, "-w", "\nHTTP_CODE:%{http_code}", apiURL, @@ -355,8 +377,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { require.NoError(t, err) lines := strings.Split(curlOutput, "\n") - var httpCode string - var responseBody strings.Builder + + var ( + httpCode string + responseBody strings.Builder + ) for _, line := range lines { if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { @@ -372,8 +397,10 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) { // Should contain user data var response v1.ListUsersResponse + err = protojson.Unmarshal([]byte(responseBody.String()), &response) assert.NoError(t, err, "Response should be valid protobuf JSON") + users := response.GetUsers() assert.Len(t, users, 2, "Should have 2 users") }) @@ -391,6 +418,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -420,11 +448,12 @@ func TestGRPCAuthenticationBypass(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) // Get the gRPC endpoint // For gRPC, we need to use the hostname and port 50443 - grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + grpcAddress := headscale.GetHostname() + ":50443" t.Run("gRPC_NoAPIKey", func(t *testing.T) { // Test 1: Try to use CLI without API key (should fail) @@ -487,6 +516,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { // CLI outputs the users array directly, not wrapped in ListUsersResponse // Parse as JSON array (CLI uses json.Marshal, not protojson) var users []*v1.User + err = json.Unmarshal([]byte(output), &users) assert.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") @@ -495,6 +525,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) { for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "grpcuser1") assert.Contains(t, userNames, "grpcuser2") }) @@ -513,6 +544,7 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -540,9 +572,10 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) { }, ) require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) - grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + grpcAddress := headscale.GetHostname() + ":50443" // Create a config file for testing configWithoutKey := fmt.Sprintf(` @@ -643,6 +676,7 @@ cli: // CLI outputs the users array directly, not wrapped in ListUsersResponse // Parse as JSON array (CLI uses json.Marshal, not protojson) var users []*v1.User + err = json.Unmarshal([]byte(output), &users) assert.NoError(t, err, "Response should be valid JSON array") assert.Len(t, users, 2, "Should have 2 users") @@ -651,6 +685,7 @@ cli: for i, u := range users { userNames[i] = u.GetName() } + assert.Contains(t, userNames, "cliuser1") assert.Contains(t, userNames, "cliuser2") }) diff --git 
a/integration/auth_key_test.go b/integration/auth_key_test.go index 9cf352bb..47c55a37 100644 --- a/integration/auth_key_test.go +++ b/integration/auth_key_test.go @@ -30,6 +30,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -68,18 +69,24 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { ips, err := client.IPs() if err != nil { t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) } + clientIPs[client] = ips } - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) @@ -110,6 +117,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after logout") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -147,6 +155,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after relogin") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -200,6 +209,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, nodeCountBeforeLogout) @@ -254,10 +264,14 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) @@ -300,9 +314,11 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { } var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user1Nodes, err = headscale.ListNodes("user1") assert.NoError(ct, err, "Failed to list nodes for user1 after relogin") assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes)) @@ -322,15 +338,18 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { // When nodes re-authenticate with a 
different user's pre-auth key, NEW nodes are created // for the new user. The original nodes remain with the original user. var user2Nodes []*v1.Node + t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user2Nodes, err = headscale.ListNodes("user2") assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin") assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes)) }, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)") t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() @@ -351,6 +370,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -376,11 +396,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { // assertClientsState(t, allClients) clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { ips, err := client.IPs() if err != nil { t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) } + clientIPs[client] = ips } @@ -394,10 +416,14 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) - var listNodes []*v1.Node - var nodeCountBeforeLogout int + var ( + listNodes []*v1.Node + nodeCountBeforeLogout int + ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, listNodes, len(allClients)) diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index c1d066f8..18c5c3a9 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -149,6 +149,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -176,6 +177,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { syncCompleteTime := time.Now() err = scenario.WaitForTailscaleSync() requireNoErrSync(t, err) + loginDuration := time.Since(syncCompleteTime) t.Logf("Login and sync completed in %v", loginDuration) @@ -207,6 +209,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { assert.EventuallyWithT(t, func(ct *assert.CollectT) { // Check each client's status individually to provide better diagnostics expiredCount := 0 + for _, client := range allClients { status, err := client.Status() if assert.NoError(ct, err, "failed to get status for client %s", client.Hostname()) { @@ -356,6 +359,7 @@ func TestOIDC024UserCreation(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -413,6 +417,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -470,6 +475,7 @@ func 
TestOIDCReloginSameNodeNewUser(t *testing.T) { oidcMockUser("user1", true), }, }) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -508,6 +514,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during initial validation") assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -528,9 +535,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var listNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes during initial validation") assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes)) @@ -538,14 +548,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Collect expected node IDs for validation after user1 initial login expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { status := ts.MustStatus() assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) assert.NoError(ct, err, "Failed to parse node ID from status") }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) // Validate initial connection state for user1 @@ -583,6 +598,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users after user2 login") assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -638,10 +654,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Security validation: Only user2's node should be active after user switch var activeUser2NodeID types.NodeID + for _, node := range listNodesAfterNewUserLogin { if node.GetUser().GetId() == 2 { // user2 activeUser2NodeID = types.NodeID(node.GetId()) t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break } } @@ -655,6 +673,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Check user2 node is online if node, exists := nodeStore[activeUser2NodeID]; exists { assert.NotNil(c, node.IsOnline, "User2 node should have online status") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User2 node should be online after login") } @@ -747,6 +766,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during final validation") assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -816,10 +836,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Security validation: Only user1's node should be active after relogin var activeUser1NodeID types.NodeID + for _, node := range listNodesAfterLoggingBackIn { if node.GetUser().GetId() == 1 { // user1 activeUser1NodeID = types.NodeID(node.GetId()) t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), 
node.GetUser().GetName()) + break } } @@ -833,6 +855,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) { // Check user1 node is online if node, exists := nodeStore[activeUser1NodeID]; exists { assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User1 node should be online after relogin") } @@ -907,6 +930,7 @@ func TestOIDCFollowUpUrl(t *testing.T) { time.Sleep(2 * time.Minute) var newUrl *url.URL + assert.EventuallyWithT(t, func(c *assert.CollectT) { st, err := ts.Status() assert.NoError(c, err) @@ -1103,6 +1127,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { oidcMockUser("user1", true), // Relogin with same user }, }) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -1142,6 +1167,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during initial validation") assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -1162,9 +1188,12 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var initialNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + initialNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes during initial validation") assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes)) @@ -1172,14 +1201,19 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Collect expected node IDs for validation after user1 initial login expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { status := ts.MustStatus() assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) assert.NoError(ct, err, "Failed to parse node ID from status") }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) // Validate initial connection state for user1 @@ -1236,6 +1270,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { listUsers, err := headscale.ListUsers() assert.NoError(ct, err, "Failed to list users during final validation") assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ { Id: 1, @@ -1256,6 +1291,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { }, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin cycle") var finalNodes []*v1.Node + t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { finalNodes, err = headscale.ListNodes() @@ -1279,6 +1315,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Security validation: user1's node should be active after relogin activeUser1NodeID := types.NodeID(finalNodes[0].GetId()) + t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat)) require.EventuallyWithT(t, func(c *assert.CollectT) { 
nodeStore, err := headscale.DebugNodeStore() @@ -1287,6 +1324,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) { // Check user1 node is online if node, exists := nodeStore[activeUser1NodeID]; exists { assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin") + if node.IsOnline != nil { assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin") } @@ -1356,6 +1394,7 @@ func TestOIDCExpiryAfterRestart(t *testing.T) { // Verify initial expiry is set var initialExpiry time.Time + assert.EventuallyWithT(t, func(ct *assert.CollectT) { nodes, err := headscale.ListNodes() assert.NoError(ct, err) diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 5dd546f3..a102b493 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -67,6 +67,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -106,13 +107,16 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after web authentication") assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) }, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication") + nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -152,6 +156,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after web flow logout") assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) @@ -226,6 +231,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -256,13 +262,16 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { validateInitialConnection(t, headscale, expectedNodes) var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + listNodes, err = headscale.ListNodes() assert.NoError(ct, err, "Failed to list nodes after initial web authentication") assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) }, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication") + nodeCountBeforeLogout := len(listNodes) t.Logf("node count before logout: %d", nodeCountBeforeLogout) @@ -313,9 +322,11 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { t.Logf("all clients logged back in as user1") var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", 
time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user1Nodes, err = headscale.ListNodes("user1") assert.NoError(ct, err, "Failed to list nodes for user1 after web flow relogin") assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes)) @@ -333,15 +344,18 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { // Validate that user2's old nodes still exist in database (but are expired/offline) // When CLI registration creates new nodes for user1, user2's old nodes remain var user2Nodes []*v1.Node + t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat)) assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + user2Nodes, err = headscale.ListNodes("user2") assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1") assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes)) }, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1") t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { assert.EventuallyWithT(t, func(ct *assert.CollectT) { status, err := client.Status() diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go index 60260bb1..d2aec30f 100644 --- a/integration/derp_verify_endpoint_test.go +++ b/integration/derp_verify_endpoint_test.go @@ -25,6 +25,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { // Generate random hostname for the headscale instance hash, err := util.GenerateRandomStringDNSSafe(6) require.NoError(t, err) + testName := "derpverify" hostname := fmt.Sprintf("hs-%s-%s", testName, hash) @@ -40,6 +41,7 @@ func TestDERPVerifyEndpoint(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoError(t, err) defer scenario.ShutdownAssertNoPanics(t) @@ -107,6 +109,7 @@ func DERPVerify( if err := c.Connect(t.Context()); err != nil { result = fmt.Errorf("client Connect: %w", err) } + if m, err := c.Recv(); err != nil { result = fmt.Errorf("client first Recv: %w", err) } else if v, ok := m.(derp.ServerInfoMessage); !ok { diff --git a/integration/dockertestutil/config.go b/integration/dockertestutil/config.go index c0c57a3e..88b2712c 100644 --- a/integration/dockertestutil/config.go +++ b/integration/dockertestutil/config.go @@ -34,6 +34,7 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) { if opts.Labels == nil { opts.Labels = make(map[string]string) } + opts.Labels["hi.run-id"] = runID opts.Labels["hi.test-type"] = testType } diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index b09e0d40..4a172471 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -41,6 +41,7 @@ type buffer struct { func (b *buffer) Write(p []byte) (n int, err error) { b.mutex.Lock() defer b.mutex.Unlock() + return b.store.Write(p) } @@ -49,6 +50,7 @@ func (b *buffer) Write(p []byte) (n int, err error) { func (b *buffer) String() string { b.mutex.Lock() defer b.mutex.Unlock() + return b.store.String() } diff --git a/integration/dockertestutil/logs.go b/integration/dockertestutil/logs.go index 
7d104e43..d5911ca7 100644 --- a/integration/dockertestutil/logs.go +++ b/integration/dockertestutil/logs.go @@ -47,6 +47,7 @@ func SaveLog( } var stdout, stderr bytes.Buffer + err = WriteLog(pool, resource, &stdout, &stderr) if err != nil { return "", "", err diff --git a/integration/dockertestutil/network.go b/integration/dockertestutil/network.go index 42483247..d07841f1 100644 --- a/integration/dockertestutil/network.go +++ b/integration/dockertestutil/network.go @@ -18,6 +18,7 @@ func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (*dockertest.Ne if err != nil { return nil, fmt.Errorf("looking up network names: %w", err) } + if len(networks) == 0 { if _, err := pool.CreateNetwork(name); err == nil { // Create does not give us an updated version of the resource, so we need to @@ -90,6 +91,7 @@ func RandomFreeHostPort() (int, error) { // CleanUnreferencedNetworks removes networks that are not referenced by any containers. func CleanUnreferencedNetworks(pool *dockertest.Pool) error { filter := "name=hs-" + networks, err := pool.NetworksByName(filter) if err != nil { return fmt.Errorf("getting networks by filter %q: %w", filter, err) @@ -122,6 +124,7 @@ func CleanImagesInCI(pool *dockertest.Pool) error { } removedCount := 0 + for _, image := range images { // Only remove dangling (untagged) images to avoid forcing rebuilds // Dangling images have no RepoTags or only have ":" diff --git a/integration/dsic/dsic.go b/integration/dsic/dsic.go index d8a77575..344d93f7 100644 --- a/integration/dsic/dsic.go +++ b/integration/dsic/dsic.go @@ -159,10 +159,12 @@ func New( } else { hostname = fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash) } + tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname) if err != nil { return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err) } + dsic := &DERPServerInContainer{ version: version, hostname: hostname, @@ -185,6 +187,7 @@ func New( fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort) fmt.Fprintf(&cmdArgs, " --stun=true") fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort) + if dsic.withVerifyClientURL != "" { fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL) } @@ -214,11 +217,13 @@ func New( } var container *dockertest.Resource + buildOptions := &dockertest.BuildOptions{ Dockerfile: "Dockerfile.derper", ContextDir: dockerContextPath, BuildArgs: []docker.BuildArg{}, } + switch version { case "head": buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{ @@ -249,6 +254,7 @@ func New( err, ) } + log.Printf("Created %s container\n", hostname) dsic.container = container @@ -259,12 +265,14 @@ func New( return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) } } + if len(dsic.tlsCert) != 0 { err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert) if err != nil { return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) } } + if len(dsic.tlsKey) != 0 { err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey) if err != nil { diff --git a/integration/helpers.go b/integration/helpers.go index 5acf4729..4a00342c 100644 --- a/integration/helpers.go +++ b/integration/helpers.go @@ -3,6 +3,7 @@ package integration import ( "bufio" "bytes" + "errors" "fmt" "io" "net/netip" @@ -47,7 +48,7 @@ const ( TimestampFormatRunID = "20060102-150405" ) -// NodeSystemStatus represents the status of a node across different systems +// NodeSystemStatus 
represents the status of a node across different systems. type NodeSystemStatus struct { Batcher bool BatcherConnCount int @@ -104,7 +105,7 @@ func requireNoErrLogout(t *testing.T, err error) { require.NoError(t, err, "failed to log out tailscale nodes") } -// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes +// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes. func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID { t.Helper() @@ -113,8 +114,10 @@ func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.Nod status := client.MustStatus() nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) } + return expectedNodes } @@ -148,15 +151,17 @@ func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNode } // requireAllClientsOnline validates that all nodes are online/offline across all headscale systems -// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems +// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems. func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { t.Helper() startTime := time.Now() + stateStr := "offline" if expectedOnline { stateStr = "online" } + t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message) if expectedOnline { @@ -171,15 +176,17 @@ func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNode t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message) } -// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state +// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state. 
func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { t.Helper() var prevReport string + require.EventuallyWithT(t, func(c *assert.CollectT) { // Get batcher state debugInfo, err := headscale.DebugBatcher() assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { return } @@ -187,6 +194,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Get map responses mapResponses, err := headscale.GetAllMapReponses() assert.NoError(c, err, "Failed to get map responses") + if err != nil { return } @@ -194,6 +202,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Get nodestore state nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } @@ -264,6 +273,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if id == nodeID { continue // Skip self-references } + expectedPeerMaps++ if online, exists := peerMap[nodeID]; exists && online { @@ -278,6 +288,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer } } } + assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") // Update status with map response data @@ -301,10 +312,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer // Verify all systems show nodes in expected state and report failures allMatch := true + var failureReport strings.Builder ids := types.NodeIDs(maps.Keys(nodeStatus)) slices.Sort(ids) + for _, nodeID := range ids { status := nodeStatus[nodeID] systemsMatch := (status.Batcher == expectedOnline) && @@ -313,10 +326,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if !systemsMatch { allMatch = false + stateStr := "offline" if expectedOnline { stateStr = "online" } + failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat))) failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline)) failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) @@ -331,6 +346,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer t.Logf("Previous report:\n%s", prevReport) t.Logf("Current report:\n%s", failureReport.String()) t.Logf("Report diff:\n%s", diff) + prevReport = failureReport.String() } @@ -344,11 +360,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer if expectedOnline { stateStr = "online" } + assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr)) }, timeout, 2*time.Second, message) } -// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components +// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components. 
func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) { t.Helper() @@ -357,18 +374,22 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { debugInfo, err := headscale.DebugBatcher() assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { return } allBatcherOffline := true + for _, nodeID := range expectedNodes { nodeIDStr := fmt.Sprintf("%d", nodeID) if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected { allBatcherOffline = false + assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID) } } + assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher") }, 15*time.Second, 1*time.Second, "batcher disconnection validation") @@ -377,20 +398,24 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } allNodeStoreOffline := true + for _, nodeID := range expectedNodes { if node, exists := nodeStore[nodeID]; exists { isOnline := node.IsOnline != nil && *node.IsOnline if isOnline { allNodeStoreOffline = false + assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID) } } } + assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore") }, 20*time.Second, 1*time.Second, "nodestore offline validation") @@ -399,6 +424,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec require.EventuallyWithT(t, func(c *assert.CollectT) { mapResponses, err := headscale.GetAllMapReponses() assert.NoError(c, err, "Failed to get map responses") + if err != nil { return } @@ -411,6 +437,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec for nodeID := range onlineMap { if slices.Contains(expectedNodes, nodeID) { allMapResponsesOffline = false + assert.False(c, true, "Node %d should not appear in map responses", nodeID) } } @@ -421,13 +448,16 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec if id == nodeID { continue // Skip self-references } + if online, exists := peerMap[nodeID]; exists && online { allMapResponsesOffline = false + assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id) } } } } + assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses") }, 60*time.Second, 2*time.Second, "map response propagation validation") @@ -447,6 +477,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Get nodestore state nodeStore, err := headscale.DebugNodeStore() assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { return } @@ -461,12 +492,14 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe for _, nodeID := range expectedNodes { node, exists := nodeStore[nodeID] assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID) + if !exists { continue } // Validate that the node has Hostinfo assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname) + if node.Hostinfo == nil { t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, 
time.Now().Format(TimestampFormat)) continue @@ -474,6 +507,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe // Validate that the node has NetInfo assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname) + if node.Hostinfo.NetInfo == nil { t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) continue @@ -524,6 +558,7 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { // Returns the total number of successful ping operations. func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { t.Helper() + success := 0 for _, client := range clients { @@ -545,6 +580,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts // for validating NAT traversal and relay functionality. Returns success count. func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { t.Helper() + success := 0 for _, client := range clients { @@ -602,9 +638,12 @@ func assertClientsState(t *testing.T, clients []TailscaleClient) { for _, client := range clients { wg.Add(1) + c := client // Avoid loop pointer + go func() { defer wg.Done() + assertValidStatus(t, c) assertValidNetcheck(t, c) assertValidNetmap(t, c) @@ -635,6 +674,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { assert.NoError(c, err, "getting netmap for %q", client.Hostname()) assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) + if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) } @@ -653,6 +693,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) + if hi := peer.Hostinfo(); hi.Valid() { assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) @@ -681,6 +722,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) { // and network map presence. This test is not suitable for ACL/partial connection tests. func assertValidStatus(t *testing.T, client TailscaleClient) { t.Helper() + status, err := client.Status(true) if err != nil { t.Fatalf("getting status for %q: %s", client.Hostname(), err) @@ -738,6 +780,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) { // which is essential for NAT traversal and connectivity in restricted networks. 
func assertValidNetcheck(t *testing.T, client TailscaleClient) { t.Helper() + report, err := client.Netcheck() if err != nil { t.Fatalf("getting status for %q: %s", client.Hostname(), err) @@ -792,6 +835,7 @@ func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool { t.Helper() buf := &bytes.Buffer{} + err := client.WriteLogs(buf, buf) if err != nil { t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err) @@ -815,6 +859,7 @@ func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) scanner := bufio.NewScanner(in) { const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB + buff := make([]byte, logBufferInitialSize) scanner.Buffer(buff, len(buff)) scanner.Split(bufio.ScanLines) @@ -941,17 +986,20 @@ func GetUserByName(headscale ControlServer, username string) (*v1.User, error) { func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) { for _, client := range updated { isOriginal := false + for _, origClient := range original { if client.Hostname() == origClient.Hostname() { isOriginal = true break } } + if !isOriginal { return client, nil } } - return nil, fmt.Errorf("no new client found") + + return nil, errors.New("no new client found") } // AddAndLoginClient adds a new tailscale client to a user and logs it in. @@ -959,7 +1007,7 @@ func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) // 1. Creating a new node // 2. Finding the new node in the client list // 3. Getting the user to create a preauth key -// 4. Logging in the new node +// 4. Logging in the new node. func (s *Scenario) AddAndLoginClient( t *testing.T, username string, @@ -1037,5 +1085,6 @@ func (s *Scenario) MustAddAndLoginClient( client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...) 
require.NoError(t, err) + return client } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 202f2014..a08ee7af 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -725,12 +725,14 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { // Find the top-level directory to strip var topLevelDir string + firstPass := tar.NewReader(bytes.NewReader(tarData)) for { header, err := firstPass.Next() if err == io.EOF { break } + if err != nil { return fmt.Errorf("failed to read tar header: %w", err) } @@ -747,6 +749,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { if err == io.EOF { break } + if err != nil { return fmt.Errorf("failed to read tar header: %w", err) } @@ -794,6 +797,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error { outFile.Close() return fmt.Errorf("failed to copy file contents: %w", err) } + outFile.Close() // Set file permissions @@ -844,10 +848,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Check if the database file exists and has a schema dbPath := "/tmp/integration_test_db.sqlite3" + fileInfo, err := t.Execute([]string{"ls", "-la", dbPath}) if err != nil { return fmt.Errorf("database file does not exist at %s: %w", dbPath, err) } + log.Printf("Database file info: %s", fileInfo) // Check if the database has any tables (schema) @@ -872,6 +878,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { if err == io.EOF { break } + if err != nil { return fmt.Errorf("failed to read tar header: %w", err) } @@ -886,6 +893,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { // Extract the first regular file we find if header.Typeflag == tar.TypeReg { dbPath := path.Join(savePath, t.hostname+".db") + outFile, err := os.Create(dbPath) if err != nil { return fmt.Errorf("failed to create database file: %w", err) @@ -893,6 +901,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { written, err := io.Copy(outFile, tarReader) outFile.Close() + if err != nil { return fmt.Errorf("failed to copy database file: %w", err) } @@ -1059,6 +1068,7 @@ func (t *HeadscaleInContainer) CreateUser( } var u v1.User + err = json.Unmarshal([]byte(result), &u) if err != nil { return nil, fmt.Errorf("failed to unmarshal user: %w", err) @@ -1195,6 +1205,7 @@ func (t *HeadscaleInContainer) ListNodes( users ...string, ) ([]*v1.Node, error) { var ret []*v1.Node + execUnmarshal := func(command []string) error { result, _, err := dockertestutil.ExecuteCommand( t.container, @@ -1206,6 +1217,7 @@ func (t *HeadscaleInContainer) ListNodes( } var nodes []*v1.Node + err = json.Unmarshal([]byte(result), &nodes) if err != nil { return fmt.Errorf("failed to unmarshal nodes: %w", err) @@ -1245,7 +1257,7 @@ func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error { "nodes", "delete", "--identifier", - fmt.Sprintf("%d", nodeID), + strconv.FormatUint(nodeID, 10), "--output", "json", "--force", @@ -1309,6 +1321,7 @@ func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) { } var users []*v1.User + err = json.Unmarshal([]byte(result), &users) if err != nil { return nil, fmt.Errorf("failed to unmarshal nodes: %w", err) @@ -1439,6 +1452,7 @@ func (h *HeadscaleInContainer) PID() (int, error) { if pidInt == 1 { continue } + pids = append(pids, pidInt) } @@ -1494,6 +1508,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) ( } var node *v1.Node + err = json.Unmarshal([]byte(result), &node) if err != nil { 
return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err) diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index 4ddc7ae9..5604af32 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -28,6 +28,7 @@ func PeerSyncTimeout() time.Duration { if util.IsCI() { return 120 * time.Second } + return 60 * time.Second } @@ -205,6 +206,7 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type res := make(map[types.NodeID]map[types.NodeID]bool) for nid, mrs := range all { res[nid] = make(map[types.NodeID]bool) + for _, mr := range mrs { for _, peer := range mr.Peers { if peer.Online != nil { @@ -225,5 +227,6 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type } } } + return res } diff --git a/integration/route_test.go b/integration/route_test.go index 6d0a1be2..b6fc8d85 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -48,6 +48,7 @@ func TestEnablingRoutes(t *testing.T) { } scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) defer scenario.ShutdownAssertNoPanics(t) @@ -90,6 +91,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route advertisements to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) @@ -126,6 +128,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route approvals to propagate to NodeStore assert.EventuallyWithT(t, func(ct *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(ct, err) @@ -148,9 +151,11 @@ func TestEnablingRoutes(t *testing.T) { assert.NotNil(c, peerStatus.PrimaryRoutes) assert.NotNil(c, peerStatus.AllowedIPs) + if peerStatus.AllowedIPs != nil { assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3) } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) } } @@ -171,6 +176,7 @@ func TestEnablingRoutes(t *testing.T) { // Wait for route state changes to propagate to nodes assert.EventuallyWithT(t, func(c *assert.CollectT) { var err error + nodes, err = headscale.ListNodes() assert.NoError(c, err) @@ -270,6 +276,7 @@ func TestHASubnetRouterFailover(t *testing.T) { prefp, err := scenario.SubnetOfNetwork("usernet1") require.NoError(t, err) + pref := *prefp t.Logf("usernet1 prefix: %s", pref.String()) @@ -289,6 +296,7 @@ func TestHASubnetRouterFailover(t *testing.T) { slices.SortStableFunc(allClients, func(a, b TailscaleClient) int { statusA := a.MustStatus() statusB := b.MustStatus() + return cmp.Compare(statusA.Self.ID, statusB.Self.ID) }) @@ -308,6 +316,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" - Router 2 (%s): Advertising route %s - will be STANDBY when approved", subRouter2.Hostname(), pref.String()) t.Logf(" - Router 3 (%s): Advertising route %s - will be STANDBY when approved", subRouter3.Hostname(), pref.String()) t.Logf(" Expected: All 3 routers advertise the same route for redundancy, but only one will be primary at a time") + for _, client := range allClients[:3] { command := []string{ "tailscale", @@ -323,6 +332,7 @@ func TestHASubnetRouterFailover(t *testing.T) { // Wait for route configuration changes after advertising routes var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { nodes, err = headscale.ListNodes() assert.NoError(c, err) @@ -362,10 +372,12 @@ func 
TestHASubnetRouterFailover(t *testing.T) { checkFailureAndPrintRoutes := func(t *testing.T, client TailscaleClient) { if t.Failed() { t.Logf("[%s] Test failed at this checkpoint", time.Now().Format(TimestampFormat)) + status, err := client.Status() if err == nil { printCurrentRouteMap(t, xmaps.Values(status.Peer)...) } + t.FailNow() } } @@ -384,6 +396,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 1 becomes PRIMARY with route %s active", pref.String()) t.Logf(" Expected: Routers 2 & 3 remain with advertised but unapproved routes") t.Logf(" Expected: Client can access webservice through router 1 only") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -454,10 +467,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1") @@ -481,6 +496,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 2 becomes STANDBY (approved but not primary)") t.Logf(" Expected: Router 1 remains PRIMARY (no flapping - stability preferred)") t.Logf(" Expected: HA is now active - if router 1 fails, router 2 can take over") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter2.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -492,6 +508,7 @@ func TestHASubnetRouterFailover(t *testing.T) { nodes, err = headscale.ListNodes() assert.NoError(c, err) assert.Len(c, nodes, 6) + if len(nodes) >= 3 { requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) @@ -567,10 +584,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter1.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 1 in HA mode") @@ -596,6 +615,7 @@ func TestHASubnetRouterFailover(t *testing.T) { t.Logf(" Expected: Router 3 becomes second STANDBY (approved but not primary)") t.Logf(" Expected: Router 1 remains PRIMARY, Router 2 remains first STANDBY") t.Logf(" Expected: Full HA configuration with 1 PRIMARY + 2 STANDBY routers") + _, err = headscale.ApproveRoutes( MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{pref}, @@ -670,12 +690,14 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.NotEmpty(c, ips, "subRouter1 should have IP addresses") var expectedIP netip.Addr + for _, ip := range ips { if ip.Is4() { expectedIP = ip break } } + assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address") assertTracerouteViaIPWithCollect(c, tr, expectedIP) @@ -752,10 +774,12 @@ func TestHASubnetRouterFailover(t *testing.T) { assert.EventuallyWithT(t, func(c *assert.CollectT) { tr, err := client.Traceroute(webip) assert.NoError(c, err) + ip, err := subRouter2.IPv4() if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { return } + assertTracerouteViaIPWithCollect(c, tr, ip) }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after failover") @@ -823,10 +847,12 @@ func 
@@ -823,10 +847,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		tr, err := client.Traceroute(webip)
 		assert.NoError(c, err)
+
 		ip, err := subRouter3.IPv4()
 		if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") {
 			return
 		}
+
 		assertTracerouteViaIPWithCollect(c, tr, ip)
 	}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after second failover")
@@ -851,6 +877,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	t.Logf(" Expected: Router 3 remains PRIMARY (stability - no unnecessary failover)")
 	t.Logf(" Expected: Router 1 becomes STANDBY (ready for HA)")
 	t.Logf(" Expected: HA is restored with 2 routers available")
+
 	err = subRouter1.Up()
 	require.NoError(t, err)
@@ -900,10 +927,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		tr, err := client.Traceroute(webip)
 		assert.NoError(c, err)
+
 		ip, err := subRouter3.IPv4()
 		if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") {
 			return
 		}
+
 		assertTracerouteViaIPWithCollect(c, tr, ip)
 	}, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 3 after router 1 recovery")
@@ -930,6 +959,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	t.Logf(" Expected: Router 1 (%s) remains first STANDBY", subRouter1.Hostname())
 	t.Logf(" Expected: Router 2 (%s) becomes second STANDBY", subRouter2.Hostname())
 	t.Logf(" Expected: Full HA restored with all 3 routers online")
+
 	err = subRouter2.Up()
 	require.NoError(t, err)
@@ -980,10 +1010,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		tr, err := client.Traceroute(webip)
 		assert.NoError(c, err)
+
 		ip, err := subRouter3.IPv4()
 		if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") {
 			return
 		}
+
 		assertTracerouteViaIPWithCollect(c, tr, ip)
 	}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after full recovery")
@@ -1065,10 +1097,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		tr, err := client.Traceroute(webip)
 		assert.NoError(c, err)
+
 		ip, err := subRouter1.IPv4()
 		if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") {
 			return
 		}
+
 		assertTracerouteViaIPWithCollect(c, tr, ip)
 	}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1 after route disable")
@@ -1151,10 +1185,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		tr, err := client.Traceroute(webip)
 		assert.NoError(c, err)
+
 		ip, err := subRouter2.IPv4()
 		if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") {
 			return
 		}
+
 		assertTracerouteViaIPWithCollect(c, tr, ip)
 	}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after second route disable")
@@ -1180,6 +1216,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability - no unnecessary flapping)", subRouter2.Hostname())
 	t.Logf(" Expected: Router 1 (%s) becomes STANDBY (approved but not primary)", subRouter1.Hostname())
 	t.Logf(" Expected: HA fully restored with Router 2 PRIMARY and Router 1 STANDBY")
+
 	r1Node := MustFindNode(subRouter1.Hostname(), nodes)
 	_, err = headscale.ApproveRoutes(
 		r1Node.GetId(),
@@ -1235,10 +1272,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		tr, err := client.Traceroute(webip)
 		assert.NoError(c, err)
+
 		ip, err := subRouter2.IPv4()
 		if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") {
 			return
 		}
+
 		assertTracerouteViaIPWithCollect(c, tr, ip)
 	}, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 2 after route re-enable")
@@ -1264,6 +1303,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
 	t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability preferred)", subRouter2.Hostname())
 	t.Logf(" Expected: Routers 1 & 3 are both STANDBY")
 	t.Logf(" Expected: Full HA restored with all 3 routers available")
+
 	r3Node := MustFindNode(subRouter3.Hostname(), nodes)
 	_, err = headscale.ApproveRoutes(
 		r3Node.GetId(),
@@ -1313,6 +1353,7 @@ func TestSubnetRouteACL(t *testing.T) {
 	}
 	scenario, err := NewScenario(spec)
+
 	require.NoErrorf(t, err, "failed to create scenario: %s", err)
 	defer scenario.ShutdownAssertNoPanics(t)
@@ -1360,6 +1401,7 @@ func TestSubnetRouteACL(t *testing.T) {
 	slices.SortStableFunc(allClients, func(a, b TailscaleClient) int {
 		statusA := a.MustStatus()
 		statusB := b.MustStatus()
+
 		return cmp.Compare(statusA.Self.ID, statusB.Self.ID)
 	})
@@ -1389,15 +1431,20 @@ func TestSubnetRouteACL(t *testing.T) {
 	// Wait for route advertisements to propagate to the server
 	var nodes []*v1.Node
+
 	require.EventuallyWithT(t, func(c *assert.CollectT) {
 		var err error
+
 		nodes, err = headscale.ListNodes()
 		assert.NoError(c, err)
 		assert.Len(c, nodes, 2)
 		// Find the node that should have the route by checking node IDs
-		var routeNode *v1.Node
-		var otherNode *v1.Node
+		var (
+			routeNode *v1.Node
+			otherNode *v1.Node
+		)
+
 		for _, node := range nodes {
 			nodeIDStr := strconv.FormatUint(node.GetId(), 10)
 			if _, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute {
@@ -1460,6 +1507,7 @@ func TestSubnetRouteACL(t *testing.T) {
 		srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
 		assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist")
+
 		if srs1PeerStatus == nil {
 			return
 		}
@@ -1570,6 +1618,7 @@ func TestEnablingExitRoutes(t *testing.T) {
 	}
 	scenario, err := NewScenario(spec)
+
 	require.NoErrorf(t, err, "failed to create scenario")
 	defer scenario.ShutdownAssertNoPanics(t)
@@ -1591,8 +1640,10 @@ func TestEnablingExitRoutes(t *testing.T) {
 	requireNoErrSync(t, err)
 	var nodes []*v1.Node
+
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		var err error
+
 		nodes, err = headscale.ListNodes()
 		assert.NoError(c, err)
 		assert.Len(c, nodes, 2)
@@ -1650,6 +1701,7 @@ func TestEnablingExitRoutes(t *testing.T) {
 		peerStatus := status.Peer[peerKey]
 		assert.NotNil(c, peerStatus.AllowedIPs)
+
 		if peerStatus.AllowedIPs != nil {
 			assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4)
 			assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
@@ -1680,6 +1732,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
 	}
 	scenario, err := NewScenario(spec)
+
 	require.NoErrorf(t, err, "failed to create scenario: %s", err)
 	defer scenario.ShutdownAssertNoPanics(t)
@@ -1710,10 +1763,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
 		if s.User[s.Self.UserID].LoginName == "user1@test.no" {
 			user1c = c
 		}
+
 		if s.User[s.Self.UserID].LoginName == "user2@test.no" {
 			user2c = c
 		}
 	}
+
 	require.NotNil(t, user1c)
 	require.NotNil(t, user2c)
@@ -1730,6 +1785,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
 	// Wait for route advertisements to propagate to NodeStore
 	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
 		var err error
+
 		nodes, err = headscale.ListNodes()
 		assert.NoError(ct, err)
 		assert.Len(ct, nodes, 2)
@@ -1760,6 +1816,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
 	// Wait for route state changes to propagate to nodes
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		var err error
+
 		nodes, err = headscale.ListNodes()
 		assert.NoError(c, err)
 		assert.Len(c, nodes, 2)
@@ -1777,6 +1834,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
 			if peerStatus.PrimaryRoutes != nil {
 				assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
 			}
+
 			requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
 		}
 	}, 10*time.Second, 500*time.Millisecond, "routes should be visible to client")
@@ -1803,10 +1861,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		tr, err := user2c.Traceroute(webip)
 		assert.NoError(c, err)
+
 		ip, err := user1c.IPv4()
 		if !assert.NoError(c, err, "failed to get IPv4 for user1c") {
 			return
 		}
+
 		assertTracerouteViaIPWithCollect(c, tr, ip)
 	}, 5*time.Second, 200*time.Millisecond, "Verifying traceroute goes through subnet router")
 }
@@ -1827,6 +1887,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
 	}
 	scenario, err := NewScenario(spec)
+
 	require.NoErrorf(t, err, "failed to create scenario: %s", err)
 	defer scenario.ShutdownAssertNoPanics(t)
@@ -1854,10 +1915,12 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
 		if s.User[s.Self.UserID].LoginName == "user1@test.no" {
 			user1c = c
 		}
+
 		if s.User[s.Self.UserID].LoginName == "user2@test.no" {
 			user2c = c
 		}
 	}
+
 	require.NotNil(t, user1c)
 	require.NotNil(t, user2c)
@@ -1874,6 +1937,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
 	// Wait for route advertisements to propagate to NodeStore
 	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
 		var err error
+
 		nodes, err = headscale.ListNodes()
 		assert.NoError(ct, err)
 		assert.Len(ct, nodes, 2)
@@ -1956,6 +2020,7 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node {
 			return node
 		}
 	}
+
 	panic("node not found")
 }
@@ -2239,10 +2304,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 			}
 			scenario, err := NewScenario(tt.spec)
+
 			require.NoErrorf(t, err, "failed to create scenario: %s", err)
 			defer scenario.ShutdownAssertNoPanics(t)
 			var nodes []*v1.Node
+
 			opts := []hsic.Option{
 				hsic.WithTestName("autoapprovemulti"),
 				hsic.WithEmbeddedDERPServerOnly(),
@@ -2298,6 +2365,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 			// Add the Docker network route to the auto-approvers
 			// Keep existing auto-approvers (like bigRoute) in place
 			var approvers policyv2.AutoApprovers
+
 			switch {
 			case strings.HasPrefix(tt.approver, "tag:"):
 				approvers = append(approvers, tagApprover(tt.approver))
@@ -2366,6 +2434,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 			} else {
 				pak, err = scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
 			}
+
 			require.NoError(t, err)
 			err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
@@ -2404,6 +2473,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 			slices.SortStableFunc(allClients, func(a, b TailscaleClient) int {
 				statusA := a.MustStatus()
 				statusB := b.MustStatus()
+
 				return cmp.Compare(statusA.Self.ID, statusB.Self.ID)
 			})
@@ -2456,11 +2526,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 				t.Logf("Client %s sees %d peers", client.Hostname(), len(status.Peers()))
 				routerPeerFound := false
+
 				for _, peerKey := range status.Peers() {
 					peerStatus := status.Peer[peerKey]
 					if peerStatus.ID == routerUsernet1ID.StableID() {
 						routerPeerFound = true
+
 						t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v",
 							peerStatus.HostName,
 							peerStatus.ID,
@@ -2468,9 +2540,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 							peerStatus.PrimaryRoutes)
 						assert.NotNil(c, peerStatus.PrimaryRoutes)
+
 						if peerStatus.PrimaryRoutes != nil {
 							assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
 						}
+
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
 					} else {
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@@ -2507,10 +2581,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 			assert.EventuallyWithT(t, func(c *assert.CollectT) {
 				tr, err := client.Traceroute(webip)
 				assert.NoError(c, err)
+
 				ip, err := routerUsernet1.IPv4()
 				if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
 					return
 				}
+
 				assertTracerouteViaIPWithCollect(c, tr, ip)
 			}, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through auto-approved router")
@@ -2547,9 +2623,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 					if peerStatus.ID == routerUsernet1ID.StableID() {
 						assert.NotNil(c, peerStatus.PrimaryRoutes)
+
 						if peerStatus.PrimaryRoutes != nil {
 							assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
 						}
+
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
 					} else {
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@@ -2569,10 +2647,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 			assert.EventuallyWithT(t, func(c *assert.CollectT) {
 				tr, err := client.Traceroute(webip)
 				assert.NoError(c, err)
+
 				ip, err := routerUsernet1.IPv4()
 				if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
 					return
 				}
+
 				assertTracerouteViaIPWithCollect(c, tr, ip)
 			}, assertTimeout, 200*time.Millisecond, "Verifying traceroute still goes through router after policy change")
@@ -2606,6 +2686,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 			// Add the route back to the auto approver in the policy, the route should
 			// now become available again.
 			var newApprovers policyv2.AutoApprovers
+
 			switch {
 			case strings.HasPrefix(tt.approver, "tag:"):
 				newApprovers = append(newApprovers, tagApprover(tt.approver))
@@ -2639,9 +2720,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 					if peerStatus.ID == routerUsernet1ID.StableID() {
 						assert.NotNil(c, peerStatus.PrimaryRoutes)
+
 						if peerStatus.PrimaryRoutes != nil {
 							assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
 						}
+
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
 					} else {
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@@ -2661,10 +2744,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 			assert.EventuallyWithT(t, func(c *assert.CollectT) {
 				tr, err := client.Traceroute(webip)
 				assert.NoError(c, err)
+
 				ip, err := routerUsernet1.IPv4()
 				if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
 					return
 				}
+
 				assertTracerouteViaIPWithCollect(c, tr, ip)
 			}, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through router after re-approval")
@@ -2700,11 +2785,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 						if peerStatus.PrimaryRoutes != nil {
 							assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
 						}
+
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
 					} else if peerStatus.ID == "2" {
 						if peerStatus.PrimaryRoutes != nil {
 							assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), subRoute)
 						}
+
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{subRoute})
 					} else {
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@@ -2742,9 +2829,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 					if peerStatus.ID == routerUsernet1ID.StableID() {
 						assert.NotNil(c, peerStatus.PrimaryRoutes)
+
 						if peerStatus.PrimaryRoutes != nil {
 							assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
 						}
+
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
 					} else {
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@@ -2782,6 +2871,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
 						if peerStatus.PrimaryRoutes != nil {
 							assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
 						}
+
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
 					} else if peerStatus.ID == "3" {
 						requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
@@ -2816,10 +2906,12 @@ func SortPeerStatus(a, b *ipnstate.PeerStatus) int {
 func printCurrentRouteMap(t *testing.T, routers ...*ipnstate.PeerStatus) {
 	t.Logf("== Current routing map ==")
 	slices.SortFunc(routers, SortPeerStatus)
+
 	for _, router := range routers {
 		got := filterNonRoutes(router)
 		t.Logf(" Router %s (%s) is serving:", router.HostName, router.ID)
 		t.Logf(" AllowedIPs: %v", got)
+
 		if router.PrimaryRoutes != nil {
 			t.Logf(" PrimaryRoutes: %v", router.PrimaryRoutes.AsSlice())
 		}
@@ -2832,6 +2924,7 @@ func filterNonRoutes(status *ipnstate.PeerStatus) []netip.Prefix {
 		if tsaddr.IsExitRoute(p) {
 			return true
 		}
+
 		return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
 	})
 }
@@ -2883,6 +2976,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
 	}
 	scenario, err := NewScenario(spec)
+
 	require.NoErrorf(t, err, "failed to create scenario: %s", err)
 	defer scenario.ShutdownAssertNoPanics(t)
@@ -3023,6 +3117,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		// List nodes and verify the router has 3 available routes
 		var err error
+
 		nodes, err := headscale.NodesByUser()
 		assert.NoError(c, err)
 		assert.Len(c, nodes, 2)
@@ -3058,10 +3153,12 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
 	assert.EventuallyWithT(t, func(c *assert.CollectT) {
 		tr, err := nodeClient.Traceroute(webip)
 		assert.NoError(c, err)
+
 		ip, err := routerClient.IPv4()
 		if !assert.NoError(c, err, "failed to get IPv4 for routerClient") {
 			return
 		}
+
 		assertTracerouteViaIPWithCollect(c, tr, ip)
 	}, 60*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router")
 }
diff --git a/integration/scenario.go b/integration/scenario.go
index 35fee73e..0108a1de 100644
--- a/integration/scenario.go
+++ b/integration/scenario.go
@@ -191,9 +191,11 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
 	}
 	var userToNetwork map[string]*dockertest.Network
+
 	if spec.Networks != nil || len(spec.Networks) != 0 {
 		for name, users := range s.spec.Networks {
 			networkName := testHashPrefix + "-" + name
+
 			network, err := s.AddNetwork(networkName)
 			if err != nil {
 				return nil, err
@@ -203,6 +205,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
 				if n2, ok := userToNetwork[user]; ok {
 					return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name)
 				}
+
 				mak.Set(&userToNetwork, user, network)
 			}
 		}
@@ -219,6 +222,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
 			if err != nil {
 				return nil, err
 			}
+
 			mak.Set(&s.extraServices, s.prefixedNetworkName(network), append(s.extraServices[s.prefixedNetworkName(network)], svc))
 		}
 	}
@@ -230,6 +234,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
 		if spec.OIDCAccessTTL != 0 {
 			ttl = spec.OIDCAccessTTL
 		}
+
 		err = s.runMockOIDC(ttl, spec.OIDCUsers)
 		if err != nil {
 			return nil, err
@@ -268,6 +273,7 @@ func (s *Scenario) Networks() []*dockertest.Network {
 	if len(s.networks) == 0 {
panic("Scenario.Networks called with empty network list") } + return xmaps.Values(s.networks) } @@ -337,6 +343,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { for userName, user := range s.users { for _, client := range user.Clients { log.Printf("removing client %s in user %s", client.Hostname(), userName) + stdoutPath, stderrPath, err := client.Shutdown() if err != nil { log.Printf("failed to tear down client: %s", err) @@ -353,6 +360,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { } } } + s.mu.Unlock() for _, derp := range s.derpServers { @@ -373,6 +381,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { if s.mockOIDC.r != nil { s.mockOIDC.r.Close() + if err := s.mockOIDC.r.Close(); err != nil { log.Printf("failed to tear down oidc server: %s", err) } @@ -552,6 +561,7 @@ func (s *Scenario) CreateTailscaleNode( s.mu.Lock() defer s.mu.Unlock() + opts = append(opts, tsic.WithCACert(cert), tsic.WithHeadscaleName(hostname), @@ -591,6 +601,7 @@ func (s *Scenario) CreateTailscaleNodesInUser( ) error { if user, ok := s.users[userStr]; ok { var versions []string + for i := range count { version := requestedVersion if requestedVersion == "all" { @@ -749,10 +760,12 @@ func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Durat for _, client := range user.Clients { c := client expectedCount := expectedPeers + user.syncWaitGroup.Go(func() error { return c.WaitForPeers(expectedCount, timeout, retryInterval) }) } + if err := user.syncWaitGroup.Wait(); err != nil { allErrors = append(allErrors, err) } @@ -871,6 +884,7 @@ func (s *Scenario) createHeadscaleEnvWithTags( } else { key, err = s.CreatePreAuthKey(u.GetId(), true, false) } + if err != nil { return err } @@ -887,9 +901,11 @@ func (s *Scenario) createHeadscaleEnvWithTags( func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { log.Printf("running tailscale up for user %s", userStr) + if user, ok := s.users[userStr]; ok { for _, client := range user.Clients { tsc := client + user.joinWaitGroup.Go(func() error { loginURL, err := tsc.LoginWithURL(loginServer) if err != nil { @@ -945,6 +961,7 @@ func newDebugJar() (*debugJar, error) { if err != nil { return nil, err } + return &debugJar{ inner: jar, store: make(map[string]map[string]map[string]*http.Cookie), @@ -961,20 +978,25 @@ func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) { if c == nil || c.Name == "" { continue } + domain := c.Domain if domain == "" { domain = u.Hostname() } + path := c.Path if path == "" { path = "/" } + if _, ok := j.store[domain]; !ok { j.store[domain] = make(map[string]map[string]*http.Cookie) } + if _, ok := j.store[domain][path]; !ok { j.store[domain][path] = make(map[string]*http.Cookie) } + j.store[domain][path][c.Name] = copyCookie(c) } } @@ -989,8 +1011,10 @@ func (j *debugJar) Dump(w io.Writer) { for domain, paths := range j.store { fmt.Fprintf(w, "Domain: %s\n", domain) + for path, byName := range paths { fmt.Fprintf(w, " Path: %s\n", path) + for _, c := range byName { fmt.Fprintf( w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n", @@ -1054,7 +1078,9 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f } log.Printf("%s logging in with url: %s", hostname, loginURL.String()) + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) if err != nil { return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err) @@ -1066,6 +1092,7 @@ func 
@@ -1066,6 +1092,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f
 			return http.ErrUseLastResponse
 		}
 	}
+
 	defer func() {
 		hc.CheckRedirect = originalRedirect
 	}()
@@ -1080,6 +1107,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f
 	if err != nil {
 		return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err)
 	}
+
 	body := string(bodyBytes)
 	var redirectURL *url.URL
@@ -1126,6 +1154,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
 	if len(keySep) != 2 {
 		return errParseAuthPage
 	}
+
 	key := keySep[1]
 	key = strings.SplitN(key, " ", 2)[0]
 	log.Printf("registering node %s", key)
@@ -1154,6 +1183,7 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
 	noTls := &http.Transport{
 		TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
 	}
+
 	resp, err := noTls.RoundTrip(req)
 	if err != nil {
 		return nil, err
@@ -1361,6 +1391,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
 	if err != nil {
 		log.Fatalf("could not find an open port: %s", err)
 	}
+
 	portNotation := fmt.Sprintf("%d/tcp", port)
 	hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
@@ -1421,6 +1452,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
 	ipAddr := s.mockOIDC.r.GetIPInNetwork(network)
 	log.Println("Waiting for headscale mock oidc to be ready for tests")
+
 	hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port))
 	if err := s.pool.Retry(func() error {
@@ -1468,7 +1500,6 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
 	// 	log.Fatalf("could not find an open port: %s", err)
 	// }
 	// portNotation := fmt.Sprintf("%d/tcp", port)
-
 	hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
 	hostname := "hs-webservice-" + hash
diff --git a/integration/scenario_test.go b/integration/scenario_test.go
index 1e2a151a..71998fca 100644
--- a/integration/scenario_test.go
+++ b/integration/scenario_test.go
@@ -35,6 +35,7 @@ func TestHeadscale(t *testing.T) {
 	user := "test-space"
 	scenario, err := NewScenario(ScenarioSpec{})
+
 	require.NoError(t, err)
 	defer scenario.ShutdownAssertNoPanics(t)
@@ -83,6 +84,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
 	count := 1
 	scenario, err := NewScenario(ScenarioSpec{})
+
 	require.NoError(t, err)
 	defer scenario.ShutdownAssertNoPanics(t)