all: apply golangci-lint auto-fixes

Apply auto-fixes from golangci-lint for the following linters:
- wsl_v5: whitespace formatting and blank line adjustments
- godot: add periods to comment sentences
- nlreturn: add newlines before return statements
- perfsprint: replace fmt.Sprintf and fmt.Errorf calls with more efficient alternatives

Also add missing imports (errors, encoding/hex) where the auto-fixes
introduced new code patterns that require them.
Kristoffer Dalby, 2026-01-20 14:37:24 +00:00
commit ad7669a2d4 (parent 3675b65504)
93 changed files with 1262 additions and 155 deletions
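
For orientation, the sketch below illustrates the shape of these rewrites in one self-contained Go file. The function names (validateDirBefore, validateDirAfter, collectFiles) are invented for the example and do not appear in the changed files; only the before/after patterns mirror what the linters produce.

// Illustrative before/after sketch of the auto-fixes in this commit.
// All names here are made up for the example; comments end with periods,
// which is the style godot enforces.
package main

import (
	"errors"
	"fmt"
)

// validateDirBefore mirrors the pre-fix style: fmt.Errorf with a constant
// message and no blank line before the final return.
func validateDirBefore(dir string) error {
	if dir == "" {
		return fmt.Errorf("directory is required")
	}
	return nil
}

// validateDirAfter mirrors the post-fix style: perfsprint swaps the constant
// fmt.Errorf for errors.New (hence the added "errors" import), and nlreturn
// inserts a blank line before the return statement.
func validateDirAfter(dir string) error {
	if dir == "" {
		return errors.New("directory is required")
	}

	return nil
}

// collectFiles shows the wsl_v5 style of grouping adjacent var declarations
// into a single block instead of separate var statements.
func collectFiles() ([]string, []string) {
	var (
		logFiles  []string
		dataFiles []string
	)

	return logFiles, dataFiles
}

func main() {
	fmt.Println(validateDirBefore(""), validateDirAfter(""))
	fmt.Println(collectFiles())
}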


@ -73,6 +73,7 @@ func mockOIDC() error {
}
var users []mockoidc.MockUser
err := json.Unmarshal([]byte(userStr), &users)
if err != nil {
return fmt.Errorf("unmarshalling users: %w", err)
@ -137,6 +138,7 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Info().Msgf("Request: %+v", r)
h.ServeHTTP(w, r)
if r.Response != nil {
log.Info().Msgf("Response: %+v", r.Response)
}


@ -29,13 +29,16 @@ func init() {
if err := setPolicy.MarkFlagRequired("file"); err != nil {
log.Fatal().Err(err).Msg("")
}
setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
policyCmd.AddCommand(setPolicy)
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
if err := checkPolicy.MarkFlagRequired("file"); err != nil {
log.Fatal().Err(err).Msg("")
}
policyCmd.AddCommand(checkPolicy)
}


@ -80,6 +80,7 @@ func initConfig() {
Repository: "headscale",
TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
}
res, err := latest.Check(githubTag, versionInfo.Version)
if err == nil && res.Outdated {
//nolint
@ -101,6 +102,7 @@ func isPreReleaseVersion(version string) bool {
return true
}
}
return false
}


@ -23,6 +23,7 @@ func usernameAndIDFlag(cmd *cobra.Command) {
// If both are empty, it will exit the program with an error.
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
username, _ := cmd.Flags().GetString("name")
identifier, _ := cmd.Flags().GetInt64("identifier")
if username == "" && identifier < 0 {
err := errors.New("--name or --identifier flag is required")


@ -69,8 +69,10 @@ func killTestContainers(ctx context.Context) error {
}
removed := 0
for _, cont := range containers {
shouldRemove := false
for _, name := range cont.Names {
if strings.Contains(name, "headscale-test-suite") ||
strings.Contains(name, "hs-") ||
@ -259,8 +261,10 @@ func cleanOldImages(ctx context.Context) error {
}
removed := 0
for _, img := range images {
shouldRemove := false
for _, tag := range img.RepoTags {
if strings.Contains(tag, "hs-") ||
strings.Contains(tag, "headscale-integration") ||
@ -302,6 +306,7 @@ func cleanCacheVolume(ctx context.Context) error {
defer cli.Close()
volumeName := "hs-integration-go-cache"
err = cli.VolumeRemove(ctx, volumeName, true)
if err != nil {
if errdefs.IsNotFound(err) {


@ -60,6 +60,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
if config.Verbose {
log.Printf("Running pre-test cleanup...")
}
if err := cleanupBeforeTest(ctx); err != nil && config.Verbose {
log.Printf("Warning: pre-test cleanup failed: %v", err)
}
@ -95,13 +96,16 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
// Start stats collection for container resource monitoring (if enabled)
var statsCollector *StatsCollector
if config.Stats {
var err error
statsCollector, err = NewStatsCollector()
if err != nil {
if config.Verbose {
log.Printf("Warning: failed to create stats collector: %v", err)
}
statsCollector = nil
}
@ -140,6 +144,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
if len(violations) > 0 {
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
log.Printf("=================================")
for _, violation := range violations {
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
@ -347,6 +352,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
maxWaitTime := 10 * time.Second
checkInterval := 500 * time.Millisecond
timeout := time.After(maxWaitTime)
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
@ -356,6 +362,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
if verbose {
log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction")
}
return nil
case <-ticker.C:
allFinalized := true
@ -366,12 +373,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
if verbose {
log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err)
}
continue
}
// Check if container is in a final state
if !isContainerFinalized(inspect.State) {
allFinalized = false
if verbose {
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
}
@ -384,6 +393,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
if verbose {
log.Printf("All test containers finalized, ready for artifact extraction")
}
return nil
}
}
@ -403,10 +413,12 @@ func findProjectRoot(startPath string) string {
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil {
return current
}
parent := filepath.Dir(current)
if parent == current {
return startPath
}
current = parent
}
}
@ -416,6 +428,7 @@ func boolToInt(b bool) int {
if b {
return 1
}
return 0
}
@ -435,6 +448,7 @@ func createDockerClient() (*client.Client, error) {
}
var clientOpts []client.Opt
clientOpts = append(clientOpts, client.WithAPIVersionNegotiation())
if contextInfo != nil {
@ -444,6 +458,7 @@ func createDockerClient() (*client.Client, error) {
if runConfig.Verbose {
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
}
clientOpts = append(clientOpts, client.WithHost(host))
}
}
@ -460,6 +475,7 @@ func createDockerClient() (*client.Client, error) {
// getCurrentDockerContext retrieves the current Docker context information.
func getCurrentDockerContext() (*DockerContext, error) {
cmd := exec.Command("docker", "context", "inspect")
output, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("failed to get docker context: %w", err)
@ -491,6 +507,7 @@ func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageNa
if client.IsErrNotFound(err) {
return false, nil
}
return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
}
@ -509,6 +526,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
if verbose {
log.Printf("Image %s is available locally", imageName)
}
return nil
}
@ -533,6 +551,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
if err != nil {
return fmt.Errorf("failed to read pull output: %w", err)
}
log.Printf("Image %s pulled successfully", imageName)
}
@ -547,9 +566,11 @@ func listControlFiles(logsDir string) {
return
}
var logFiles []string
var dataFiles []string
var dataDirs []string
var (
logFiles []string
dataFiles []string
dataDirs []string
)
for _, entry := range entries {
name := entry.Name()
@ -578,6 +599,7 @@ func listControlFiles(logsDir string) {
if len(logFiles) > 0 {
log.Printf("Headscale logs:")
for _, file := range logFiles {
log.Printf(" %s", file)
}
@ -585,9 +607,11 @@ func listControlFiles(logsDir string) {
if len(dataFiles) > 0 || len(dataDirs) > 0 {
log.Printf("Headscale data:")
for _, file := range dataFiles {
log.Printf(" %s", file)
}
for _, dir := range dataDirs {
log.Printf(" %s/", dir)
}
@ -612,6 +636,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose)
extractedCount := 0
for _, cont := range currentTestContainers {
// Extract container logs and tar files
if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil {
@ -622,6 +647,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
if verbose {
log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12])
}
extractedCount++
}
}
@ -645,11 +671,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
// Find the test container to get its run ID label
var runID string
for _, cont := range containers {
if cont.ID == testContainerID {
if cont.Labels != nil {
runID = cont.Labels["hi.run-id"]
}
break
}
}


@ -266,6 +266,7 @@ func checkGoInstallation() DoctorResult {
}
cmd := exec.Command("go", "version")
output, err := cmd.Output()
if err != nil {
return DoctorResult{
@ -287,6 +288,7 @@ func checkGoInstallation() DoctorResult {
// checkGitRepository verifies we're in a git repository.
func checkGitRepository() DoctorResult {
cmd := exec.Command("git", "rev-parse", "--git-dir")
err := cmd.Run()
if err != nil {
return DoctorResult{
@ -316,6 +318,7 @@ func checkRequiredFiles() DoctorResult {
}
var missingFiles []string
for _, file := range requiredFiles {
cmd := exec.Command("test", "-e", file)
if err := cmd.Run(); err != nil {
@ -350,6 +353,7 @@ func displayDoctorResults(results []DoctorResult) {
for _, result := range results {
var icon string
switch result.Status {
case "PASS":
icon = "✅"


@ -82,9 +82,11 @@ func cleanAll(ctx context.Context) error {
if err := killTestContainers(ctx); err != nil {
return err
}
if err := pruneDockerNetworks(ctx); err != nil {
return err
}
if err := cleanOldImages(ctx); err != nil {
return err
}


@ -48,6 +48,7 @@ func runIntegrationTest(env *command.Env) error {
if runConfig.Verbose {
log.Printf("Running pre-flight system checks...")
}
if err := runDoctorCheck(env.Context()); err != nil {
return fmt.Errorf("pre-flight checks failed: %w", err)
}
@ -94,8 +95,10 @@ func detectGoVersion() string {
// splitLines splits a string into lines without using strings.Split.
func splitLines(s string) []string {
var lines []string
var current string
var (
lines []string
current string
)
for _, char := range s {
if char == '\n' {


@ -71,10 +71,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
// Start monitoring existing containers
sc.wg.Add(1)
go sc.monitorExistingContainers(ctx, runID, verbose)
// Start Docker events monitoring for new containers
sc.wg.Add(1)
go sc.monitorDockerEvents(ctx, runID, verbose)
if verbose {
@ -88,10 +90,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
func (sc *StatsCollector) StopCollection() {
// Check if already stopped without holding lock
sc.mutex.RLock()
if !sc.collectionStarted {
sc.mutex.RUnlock()
return
}
sc.mutex.RUnlock()
// Signal stop to all goroutines
@ -115,6 +119,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s
if verbose {
log.Printf("Failed to list existing containers: %v", err)
}
return
}
@ -168,6 +173,7 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
if verbose {
log.Printf("Error in Docker events stream: %v", err)
}
return
}
}
@ -214,6 +220,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI
}
sc.wg.Add(1)
go sc.collectStatsForContainer(ctx, containerID, verbose)
}
@ -227,11 +234,13 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
if verbose {
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
}
return
}
defer statsResponse.Body.Close()
decoder := json.NewDecoder(statsResponse.Body)
var prevStats *container.Stats
for {
@ -247,6 +256,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
if err.Error() != "EOF" && verbose {
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
}
return
}
@ -262,8 +272,10 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
// Store the sample (skip first sample since CPU calculation needs previous stats)
if prevStats != nil {
// Get container stats reference without holding the main mutex
var containerStats *ContainerStats
var exists bool
var (
containerStats *ContainerStats
exists bool
)
sc.mutex.RLock()
containerStats, exists = sc.containers[containerID]
@ -332,10 +344,12 @@ type StatsSummary struct {
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
// Take snapshot of container references without holding main lock long
sc.mutex.RLock()
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
for _, containerStats := range sc.containers {
containerRefs = append(containerRefs, containerStats)
}
sc.mutex.RUnlock()
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
@ -393,9 +407,11 @@ func calculateStatsSummary(values []float64) StatsSummary {
if value < min {
min = value
}
if value > max {
max = value
}
sum += value
}
@ -435,6 +451,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo
}
summaries := sc.GetSummary()
var violations []MemoryViolation
for _, summary := range summaries {


@ -2,6 +2,7 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"os"
@ -40,7 +41,7 @@ func main() {
// runIntegrationTest executes the integration test workflow.
func runOnline(env *command.Env) error {
if mapConfig.Directory == "" {
return fmt.Errorf("directory is required")
return errors.New("directory is required")
}
resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory)
@ -57,5 +58,6 @@ func runOnline(env *command.Env) error {
os.Stderr.Write(out)
os.Stderr.Write([]byte("\n"))
return nil
}


@ -142,6 +142,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
if !ok {
log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
return
}
@ -157,10 +158,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
app.ephemeralGC = ephemeralGC
var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL)
if cfg.OIDC.Issuer != "" {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
oidcProvider, err := NewAuthProviderOIDC(
ctx,
&app,
@ -177,6 +180,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
authProvider = oidcProvider
}
}
app.authProvider = authProvider
if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
@ -251,9 +255,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
lastExpiryCheck := time.Unix(0, 0)
derpTickerChan := make(<-chan time.Time)
if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
defer derpTicker.Stop()
derpTickerChan = derpTicker.C
}
@ -271,8 +277,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
return
case <-expireTicker.C:
var expiredNodeChanges []change.Change
var changed bool
var (
expiredNodeChanges []change.Change
changed bool
)
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
@ -287,11 +295,13 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
case <-derpTickerChan:
log.Info().Msg("Fetching DERPMap updates")
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) {
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
if err != nil {
return nil, err
}
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
region, _ := h.DERPServer.GenerateRegion()
derpMap.Regions[region.RegionID] = &region
@ -303,6 +313,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
continue
}
h.state.SetDERPMap(derpMap)
h.Change(change.DERPMap())
@ -311,6 +322,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
if !ok {
continue
}
h.cfg.TailcfgDNSConfig.ExtraRecords = records
h.Change(change.ExtraRecords())
@ -390,6 +402,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
writeUnauthorized := func(statusCode int) {
writer.WriteHeader(statusCode)
if _, err := writer.Write([]byte("Unauthorized")); err != nil {
log.Error().Err(err).Msg("writing HTTP response failed")
}
@ -486,6 +499,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
var err error
capver.CanOldCodeBeCleanedUp()
if profilingEnabled {
@ -512,6 +526,7 @@ func (h *Headscale) Serve() error {
Msg("Clients with a lower minimum version will be rejected")
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
h.mapBatcher.Start()
defer h.mapBatcher.Close()
@ -545,6 +560,7 @@ func (h *Headscale) Serve() error {
// around between restarts, they will reconnect and the GC will
// be cancelled.
go h.ephemeralGC.Start()
ephmNodes := h.state.ListEphemeralNodes()
for _, node := range ephmNodes.All() {
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
@ -555,7 +571,9 @@ func (h *Headscale) Serve() error {
if err != nil {
return fmt.Errorf("setting up extrarecord manager: %w", err)
}
h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
go h.extraRecordMan.Run()
defer h.extraRecordMan.Close()
}
@ -564,6 +582,7 @@ func (h *Headscale) Serve() error {
// records updates
scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
defer scheduleCancel()
go h.scheduledTasks(scheduleCtx)
if zl.GlobalLevel() == zl.TraceLevel {
@ -751,7 +770,6 @@ func (h *Headscale) Serve() error {
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
}
var tailsqlContext context.Context
if tailsqlEnabled {
if h.cfg.Database.Type != types.DatabaseSqlite {
@ -863,6 +881,7 @@ func (h *Headscale) Serve() error {
// Close state connections
info("closing state and database")
err = h.state.Close()
if err != nil {
log.Error().Err(err).Msg("failed to close state")


@ -51,6 +51,7 @@ func (h *Headscale) handleRegister(
if err != nil {
return nil, fmt.Errorf("handling logout: %w", err)
}
if resp != nil {
return resp, nil
}
@ -131,7 +132,7 @@ func (h *Headscale) handleRegister(
}
// handleLogout checks if the [tailcfg.RegisterRequest] is a
// logout attempt from a node. If the node is not attempting to
// logout attempt from a node. If the node is not attempting to.
func (h *Headscale) handleLogout(
node types.NodeView,
req tailcfg.RegisterRequest,
@ -158,6 +159,7 @@ func (h *Headscale) handleLogout(
Interface("reg.req", req).
Bool("unexpected", true).
Msg("Node key expired, forcing re-authentication")
return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
@ -277,6 +279,7 @@ func (h *Headscale) waitForFollowup(
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(node.View()), nil
}
}
@ -342,6 +345,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
}
if perr, ok := errors.AsType[types.PAKError](err); ok {
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
}
@ -432,6 +436,7 @@ func (h *Headscale) handleRegisterInteractive(
Str("generated.hostname", hostname).
Msg("Received registration request with empty hostname, generated default")
}
hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode(


@ -2,6 +2,7 @@ package hscontrol
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
@ -16,14 +17,14 @@ import (
"tailscale.com/types/key"
)
// Interactive step type constants
// Interactive step type constants.
const (
stepTypeInitialRequest = "initial_request"
stepTypeAuthCompletion = "auth_completion"
stepTypeFollowupRequest = "followup_request"
)
// interactiveStep defines a step in the interactive authentication workflow
// interactiveStep defines a step in the interactive authentication workflow.
type interactiveStep struct {
stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest
expectAuthURL bool
@ -75,6 +76,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -129,6 +131,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public())
if err != nil {
return "", err
@ -163,6 +166,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify both nodes exist
node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public())
assert.True(t, found1)
assert.True(t, found2)
assert.Equal(t, "reusable-node-1", node1.Hostname())
@ -196,6 +200,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public())
if err != nil {
return "", err
@ -227,6 +232,7 @@ func TestAuthenticationFlows(t *testing.T) {
// First node should exist, second should not
_, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
_, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public())
assert.True(t, found1)
assert.False(t, found2)
},
@ -272,6 +278,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -391,6 +398,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -400,8 +408,10 @@ func TestAuthenticationFlows(t *testing.T) {
// Wait for node to be available in NodeStore with debug info
var attemptCount int
require.EventuallyWithT(t, func(c *assert.CollectT) {
attemptCount++
_, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
if assert.True(c, found, "node should be available in NodeStore") {
t.Logf("Node found in NodeStore after %d attempts", attemptCount)
@ -451,6 +461,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -500,6 +511,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -549,25 +561,31 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
}
// Wait for node to be available in NodeStore
var node types.NodeView
var found bool
var (
node types.NodeView
found bool
)
require.EventuallyWithT(t, func(c *assert.CollectT) {
node, found = app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(c, found, "node should be available in NodeStore")
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
if !found {
return "", fmt.Errorf("node not found after setup")
return "", errors.New("node not found after setup")
}
// Expire the node
expiredTime := time.Now().Add(-1 * time.Hour)
_, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime)
return "", err
},
request: func(_ string) tailcfg.RegisterRequest {
@ -610,6 +628,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -673,6 +692,7 @@ func TestAuthenticationFlows(t *testing.T) {
// and handleRegister will receive the value when it starts waiting
go func() {
user := app.state.CreateUserForTest("followup-user")
node := app.state.CreateNodeForTest(user, "followup-success-node")
registered <- node
}()
@ -782,6 +802,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -821,6 +842,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -865,6 +887,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -898,6 +921,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -922,6 +946,7 @@ func TestAuthenticationFlows(t *testing.T) {
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found)
assert.Equal(t, "tagged-pak-node", node.Hostname())
if node.AuthKey().Valid() {
assert.NotEmpty(t, node.AuthKey().Tags())
}
@ -1031,6 +1056,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1047,6 +1073,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak2.Key, nil
},
request: func(newAuthKey string) tailcfg.RegisterRequest {
@ -1099,6 +1126,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1161,6 +1189,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1177,6 +1206,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pakRotation.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1226,6 +1256,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1265,6 +1296,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1429,6 +1461,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify custom hostinfo was preserved through interactive workflow
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found, "node should be found after interactive registration")
if found {
assert.Equal(t, "custom-interactive-node", node.Hostname())
assert.Equal(t, "linux", node.Hostinfo().OS())
@ -1455,6 +1488,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1520,6 +1554,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify registration ID was properly generated and used
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found, "node should be registered after interactive workflow")
if found {
assert.Equal(t, "registration-id-test-node", node.Hostname())
assert.Equal(t, "test-os", node.Hostinfo().OS())
@ -1535,6 +1570,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1577,6 +1613,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1632,6 +1669,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1648,6 +1686,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak2.Key, nil
},
request: func(user2AuthKey string) tailcfg.RegisterRequest {
@ -1712,6 +1751,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public())
if err != nil {
return "", err
@ -1838,6 +1878,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1932,6 +1973,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public())
if err != nil {
return "", err
@ -2097,6 +2139,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Collect results - at least one should succeed
successCount := 0
for range numConcurrent {
select {
case err := <-results:
@ -2217,6 +2260,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Should handle nil hostinfo gracefully
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found, "node should be registered despite nil hostinfo")
if found {
// Should have some default hostname or handle nil gracefully
hostname := node.Hostname()
@ -2315,12 +2359,14 @@ func TestAuthenticationFlows(t *testing.T) {
resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public())
require.NoError(t, err)
authURL2 := resp2.AuthURL
assert.Contains(t, authURL2, "/register/")
// Both should have different registration IDs
regID1, err1 := extractRegistrationIDFromAuthURL(authURL1)
regID2, err2 := extractRegistrationIDFromAuthURL(authURL2)
require.NoError(t, err1)
require.NoError(t, err2)
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
@ -2328,6 +2374,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Both cache entries should exist simultaneously
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
assert.True(t, found1, "first registration cache entry should exist")
assert.True(t, found2, "second registration cache entry should exist")
@ -2371,6 +2418,7 @@ func TestAuthenticationFlows(t *testing.T) {
resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public())
require.NoError(t, err)
authURL2 := resp2.AuthURL
regID2, err := extractRegistrationIDFromAuthURL(authURL2)
require.NoError(t, err)
@ -2378,6 +2426,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify both exist
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
assert.True(t, found1, "first cache entry should exist")
assert.True(t, found2, "second cache entry should exist")
@ -2403,6 +2452,7 @@ func TestAuthenticationFlows(t *testing.T) {
errorChan <- err
return
}
responseChan <- resp
}()
@ -2430,6 +2480,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify the node was created with the second registration's data
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found, "node should be registered")
if found {
assert.Equal(t, "pending-node-2", node.Hostname())
assert.Equal(t, "second-registration-user", node.User().Name())
@ -2463,8 +2514,10 @@ func TestAuthenticationFlows(t *testing.T) {
// Set up context with timeout for followup tests
ctx := context.Background()
if req.Followup != "" {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
}
@ -2516,7 +2569,7 @@ func TestAuthenticationFlows(t *testing.T) {
}
}
// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow
// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow.
func runInteractiveWorkflowTest(t *testing.T, tt struct {
name string
setupFunc func(*testing.T, *Headscale) (string, error)
@ -2597,6 +2650,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
errorChan <- err
return
}
responseChan <- resp
}()
@ -2650,24 +2704,27 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
if responseToValidate == nil {
responseToValidate = initialResp
}
tt.validate(t, responseToValidate, app)
}
}
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
// AuthURL format: "http://localhost/register/abc123"
const registerPrefix = "/register/"
idx := strings.LastIndex(authURL, registerPrefix)
if idx == -1 {
return "", fmt.Errorf("invalid AuthURL format: %s", authURL)
}
idStr := authURL[idx+len(registerPrefix):]
return types.RegistrationIDFromString(idStr)
}
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) {
// Basic response validation
require.NotNil(t, resp, "response should not be nil")
@ -2681,7 +2738,7 @@ func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterRe
// Additional validation can be added here as needed
}
// Simple test to validate basic node creation and lookup
// Simple test to validate basic node creation and lookup.
func TestNodeStoreLookup(t *testing.T) {
app := createTestApp(t)
@ -2713,8 +2770,10 @@ func TestNodeStoreLookup(t *testing.T) {
// Wait for node to be available in NodeStore
var node types.NodeView
require.EventuallyWithT(t, func(c *assert.CollectT) {
var found bool
node, found = app.state.GetNodeByNodeKey(nodeKey.Public())
assert.True(c, found, "Node should be found in NodeStore")
}, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore")
@ -2783,8 +2842,10 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Get the node ID
var registeredNode types.NodeView
require.EventuallyWithT(t, func(c *assert.CollectT) {
var found bool
registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public())
assert.True(c, found, "Node should be found in NodeStore")
}, 1*time.Second, 100*time.Millisecond, "waiting for node to be available")
@ -2796,6 +2857,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Verify initial state: user1 has 2 nodes, user2 has 2 nodes
user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID))
user2Nodes := app.state.ListNodesByUser(types.UserID(user2.ID))
require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially")
require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially")
@ -2876,6 +2938,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Verify new nodes were created for user1 with the same machine keys
t.Logf("Verifying new nodes created for user1 from user2's machine keys...")
for i := 2; i < 4; i++ {
node := nodes[i]
// Should be able to find a node with user1 and this machine key (the new one)
@ -2899,7 +2962,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Expected behavior:
// - User1's original node should STILL EXIST (expired)
// - User2 should get a NEW node created (NOT transfer)
// - Both nodes share the same machine key (same physical device)
// - Both nodes share the same machine key (same physical device).
func TestWebFlowReauthDifferentUser(t *testing.T) {
machineKey := key.NewMachine()
nodeKey1 := key.NewNode()
@ -3043,6 +3106,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
// Count nodes per user
user1Nodes := 0
user2Nodes := 0
for i := 0; i < allNodesSlice.Len(); i++ {
n := allNodesSlice.At(i)
if n.UserID().Get() == user1.ID {
@ -3060,7 +3124,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
})
}
// Helper function to create test app
// Helper function to create test app.
func createTestApp(t *testing.T) *Headscale {
t.Helper()
@ -3147,6 +3211,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
}
t.Log("Step 1: Initial registration with pre-auth key")
initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public())
require.NoError(t, err, "initial registration should succeed")
require.NotNil(t, initialResp)
@ -3172,6 +3237,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
// - System reboots
// The Tailscale client persists the pre-auth key in its state and sends it on every registration
t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key")
restartReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pakNew.Key, // Same key, now marked as Used=true
@ -3189,9 +3255,11 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
// This is the assertion that currently FAILS in v0.27.0
assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed")
if err != nil {
t.Logf("Error received (this is the bug): %v", err)
t.Logf("Expected behavior: Node should be able to re-register with the same pre-auth key it used initially")
return // Stop here to show the bug clearly
}


@ -155,6 +155,7 @@ AND auth_key_id NOT IN (
nodeRoutes := map[uint64][]netip.Prefix{}
var routes []types.Route
err = tx.Find(&routes).Error
if err != nil {
return fmt.Errorf("fetching routes: %w", err)
@ -255,9 +256,11 @@ AND auth_key_id NOT IN (
// Check if routes table exists and drop it (should have been migrated already)
var routesExists bool
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
if err == nil && routesExists {
log.Info().Msg("Dropping leftover routes table")
if err := tx.Exec("DROP TABLE routes").Error; err != nil {
return fmt.Errorf("dropping routes table: %w", err)
}
@ -280,6 +283,7 @@ AND auth_key_id NOT IN (
for _, table := range tablesToRename {
// Check if table exists before renaming
var exists bool
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
if err != nil {
return fmt.Errorf("checking if table %s exists: %w", table, err)
@ -761,6 +765,7 @@ AND auth_key_id NOT IN (
// or else it blocks...
sqlConn.SetMaxIdleConns(maxIdleConns)
sqlConn.SetMaxOpenConns(maxOpenConns)
defer sqlConn.SetMaxIdleConns(1)
defer sqlConn.SetMaxOpenConns(1)


@ -44,6 +44,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
// Verify api_keys data preservation
var apiKeyCount int
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
require.NoError(t, err)
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
@ -186,6 +187,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
func requireConstraintFailed(t *testing.T, err error) {
t.Helper()
require.Error(t, err)
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
}
@ -401,6 +403,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
// skip already-applied migrations and only run new ones.
func TestSQLiteAllTestdataMigrations(t *testing.T) {
t.Parallel()
schemas, err := os.ReadDir("testdata/sqlite")
require.NoError(t, err)


@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)
// Basic deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var deletionWg sync.WaitGroup
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
deletionWg sync.WaitGroup
)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionWg.Done()
}
@ -43,10 +47,13 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
go gc.Start()
// Schedule several nodes for deletion with short expiry
const expiry = fifty
const numNodes = 100
const (
expiry = fifty
numNodes = 100
)
// Set up wait group for expected deletions
deletionWg.Add(numNodes)
for i := 1; i <= numNodes; i++ {
@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// Deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// Start GC
gc := NewEphemeralGarbageCollector(deleteFunc)
go gc.Start()
defer gc.Close()
const shortExpiry = fifty
const longExpiry = 1 * time.Hour
const (
shortExpiry = fifty
longExpiry = 1 * time.Hour
)
nodeID := types.NodeID(1)
@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// and verifies that the node is deleted only once.
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
// Deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
}
// Start the GC
gc := NewEphemeralGarbageCollector(deleteFunc)
go gc.Start()
defer gc.Close()
nodeID := types.NodeID(1)
const expiry = fifty
// Schedule node for deletion
@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
// Deletion tracking
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)
// Deletion tracking
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
nodeDeleted := make(chan struct{})
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
close(nodeDeleted) // Signal that deletion happened
}
@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
// Use a WaitGroup to ensure the GC has started
var startWg sync.WaitGroup
startWg.Add(1)
go func() {
startWg.Done() // Signal that the goroutine has started
gc.Start()
}()
startWg.Wait() // Wait for the GC to start
// Close GC right away
@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
// Check no node was deleted
deleteMutex.Lock()
nodesDeleted := len(deletedIDs)
deleteMutex.Unlock()
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)
// Deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
}
@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
go gc.Start()
// Number of concurrent scheduling goroutines
const numSchedulers = 10
const nodesPerScheduler = 50
const (
numSchedulers = 10
nodesPerScheduler = 50
)
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler


@ -483,6 +483,7 @@ func TestBackfillIPAddresses(t *testing.T) {
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
defer db.Close()
alloc, err := NewIPAllocator(


@ -206,6 +206,7 @@ func SetTags(
slices.Sort(tags)
tags = slices.Compact(tags)
b, err := json.Marshal(tags)
if err != nil {
return err
@ -378,6 +379,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
if ipv4 == nil {
ipv4 = oldNode.IPv4
}
if ipv6 == nil {
ipv6 = oldNode.IPv6
}
@ -406,6 +408,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
node.IPv6 = ipv6
var err error
node.Hostname, err = util.NormaliseHostname(node.Hostname)
if err != nil {
newHostname := util.InvalidString()
@ -693,9 +696,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
}
var registeredNode *types.Node
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
var err error
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
return err
})
if err != nil {


@ -497,6 +497,7 @@ func TestAutoApproveRoutes(t *testing.T) {
if len(expectedRoutes1) == 0 {
expectedRoutes1 = nil
}
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
}
@ -508,6 +509,7 @@ func TestAutoApproveRoutes(t *testing.T) {
if len(expectedRoutes2) == 0 {
expectedRoutes2 = nil
}
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
}
@ -745,12 +747,15 @@ func TestNodeNaming(t *testing.T) {
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil)
_, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil)
return err
})
require.NoError(t, err)
@ -999,6 +1004,7 @@ func TestListPeers(t *testing.T) {
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
return err
@ -1084,6 +1090,7 @@ func TestListNodes(t *testing.T) {
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
return err


@ -372,18 +372,23 @@ func (c *Config) ToURL() (string, error) {
if c.BusyTimeout > 0 {
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
}
if c.JournalMode != "" {
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
}
if c.AutoVacuum != "" {
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
}
if c.WALAutocheckpoint >= 0 {
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
}
if c.Synchronous != "" {
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
}
if c.ForeignKeys {
pragmas = append(pragmas, "foreign_keys=ON")
}


@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
t.Errorf("Config.ToURL() error = %v", err)
return
}
if got != tt.want {
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
}
@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
Path: "",
BusyTimeout: -1,
}
_, err := config.ToURL()
if err == nil {
t.Error("Config.ToURL() with invalid config should return error")


@ -109,7 +109,9 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
for pragma, expectedValue := range tt.expected {
t.Run("pragma_"+pragma, func(t *testing.T) {
var actualValue any
query := "PRAGMA " + pragma
err := db.QueryRow(query).Scan(&actualValue)
if err != nil {
t.Fatalf("Failed to query %s: %v", query, err)
@ -249,6 +251,7 @@ func TestJournalModeValidation(t *testing.T) {
defer db.Close()
var actualMode string
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
if err != nil {
t.Fatalf("Failed to query journal_mode: %v", err)


@ -42,6 +42,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
@ -55,6 +56,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
maybeInstantiatePtr(fieldValue)
f := fieldValue.MethodByName("UnmarshalText")
args := []reflect.Value{reflect.ValueOf(bytes)}
ret := f.Call(args)
if !ret[0].IsNil() {
return decodingError(field.Name, ret[0].Interface().(error))
@ -89,6 +91,7 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) {
return nil, nil
}
b, err := v.MarshalText()
if err != nil {
return nil, err


@ -88,10 +88,12 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
// not exist or if another User exists with the new name.
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
var err error
oldUser, err := GetUserByID(tx, uid)
if err != nil {
return err
}
if err = util.ValidateHostname(newName); err != nil {
return err
}


@ -25,17 +25,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
overview := h.state.DebugOverviewJSON()
overviewJSON, err := json.MarshalIndent(overview, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(overviewJSON)
} else {
// Default to text/plain for backward compatibility
overview := h.state.DebugOverview()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(overview))
@ -45,11 +48,13 @@ func (h *Headscale) debugHTTPServer() *http.Server {
// Configuration endpoint
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
config := h.state.DebugConfig()
configJSON, err := json.MarshalIndent(config, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(configJSON)
@ -70,6 +75,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
} else {
w.Header().Set("Content-Type", "text/plain")
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(policy))
}))
@ -81,11 +87,13 @@ func (h *Headscale) debugHTTPServer() *http.Server {
httpError(w, err)
return
}
filterJSON, err := json.MarshalIndent(filter, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(filterJSON)
@ -94,11 +102,13 @@ func (h *Headscale) debugHTTPServer() *http.Server {
// SSH policies endpoint
debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sshPolicies := h.state.DebugSSHPolicies()
sshJSON, err := json.MarshalIndent(sshPolicies, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(sshJSON)
@ -112,17 +122,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
derpInfo := h.state.DebugDERPJSON()
derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(derpJSON)
} else {
// Default to text/plain for backward compatibility
derpInfo := h.state.DebugDERPMap()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(derpInfo))
@ -137,17 +150,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
nodeStoreNodes := h.state.DebugNodeStoreJSON()
nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(nodeStoreJSON)
} else {
// Default to text/plain for backward compatibility
nodeStoreInfo := h.state.DebugNodeStore()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(nodeStoreInfo))
@ -157,11 +173,13 @@ func (h *Headscale) debugHTTPServer() *http.Server {
// Registration cache endpoint
debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cacheInfo := h.state.DebugRegistrationCache()
cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(cacheJSON)
@ -175,17 +193,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
routes := h.state.DebugRoutes()
routesJSON, err := json.MarshalIndent(routes, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(routesJSON)
} else {
// Default to text/plain for backward compatibility
routes := h.state.DebugRoutesString()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(routes))
@ -200,17 +221,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
policyManagerInfo := h.state.DebugPolicyManagerJSON()
policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(policyManagerJSON)
} else {
// Default to text/plain for backward compatibility
policyManagerInfo := h.state.DebugPolicyManager()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(policyManagerInfo))
@ -227,6 +251,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if res == nil {
w.WriteHeader(http.StatusOK)
w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
return
}
@ -235,6 +260,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(resJSON)
@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string {
activeConnections: info.ActiveConnections,
})
totalNodes++
if info.Connected {
connectedCount++
}
@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string {
activeConnections: 0,
})
totalNodes++
if connected {
connectedCount++
}
return true
})
}
@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
ActiveConnections: 0,
}
info.TotalNodes++
return true
})
}


@ -134,6 +134,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) {
for id := range dm.Regions {
ids = append(ids, id)
}
slices.Sort(ids)
for _, id := range ids {
@ -164,12 +165,14 @@ func derpRandom() *rand.Rand {
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table)))
derpRandomInst = rnd
})
return derpRandomInst
}
func resetDerpRandomForTesting() {
derpRandomMu.Lock()
defer derpRandomMu.Unlock()
derpRandomOnce = sync.Once{}
derpRandomInst = nil
}


@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
viper.Set("dns.base_domain", tt.baseDomain)
defer viper.Reset()
resetDerpRandomForTesting()
testMap := tt.derpMap.View().AsStruct()


@ -74,9 +74,11 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
if err != nil {
return tailcfg.DERPRegion{}, err
}
var host string
var port int
var portStr string
var (
host string
port int
portStr string
)
// Extract hostname and port from URL
host, portStr, err = net.SplitHostPort(serverURL.Host)
@ -205,6 +207,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
return
}
defer websocketConn.Close(websocket.StatusInternalError, "closing")
if websocketConn.Subprotocol() != "derp" {
websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
@ -309,6 +312,7 @@ func DERPBootstrapDNSHandler(
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
defer cancel()
var resolver net.Resolver
for _, region := range derpMap.Regions().All() {
for _, node := range region.Nodes().All() { // we don't care if we override some nodes
addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName())
@ -320,6 +324,7 @@ func DERPBootstrapDNSHandler(
continue
}
dnsEntries[node.HostName()] = addrs
}
}

View file

@ -85,12 +85,15 @@ func (e *ExtraRecordsMan) Run() {
log.Error().Caller().Msgf("file watcher event channel closing")
return
}
switch event.Op {
case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
if event.Name != e.path {
continue
}
e.updateRecords()
// If a file is removed or renamed, fsnotify will lose track of it
@ -123,6 +126,7 @@ func (e *ExtraRecordsMan) Run() {
log.Error().Caller().Msgf("file watcher error channel closing")
return
}
log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
}
}
@ -165,6 +169,7 @@ func (e *ExtraRecordsMan) updateRecords() {
e.hashes[e.path] = newHash
log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len())
e.updateCh <- e.records.Slice()
}
@ -183,6 +188,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error
}
var records []tailcfg.DNSRecord
err = json.Unmarshal(b, &records)
if err != nil {
return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err)
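The extra-records manager above reloads its file on Create, Write, and Chmod events and logs watcher errors. A self-contained sketch of that fsnotify watch loop, with the path and reload callback as placeholders:

package recordsdemo

import (
    "log"

    "github.com/fsnotify/fsnotify"
)

// watchFile reloads path whenever it is created, written to, or chmodded,
// echoing the event handling in the Run loop shown above.
func watchFile(path string, reload func()) error {
    watcher, err := fsnotify.NewWatcher()
    if err != nil {
        return err
    }
    defer watcher.Close()

    if err := watcher.Add(path); err != nil {
        return err
    }

    for {
        select {
        case event, ok := <-watcher.Events:
            if !ok {
                return nil // event channel closed
            }

            switch event.Op {
            case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
                if event.Name != path {
                    continue
                }

                reload()
            }
        case err, ok := <-watcher.Errors:
            if !ok {
                return nil // error channel closed
            }

            log.Printf("filewatcher error: %v", err)
        }
    }
}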

View file

@ -181,6 +181,7 @@ func (h *Headscale) HealthHandler(
json.NewEncoder(writer).Encode(res)
}
err := h.state.PingDB(req.Context())
if err != nil {
respond(err)
@ -217,6 +218,7 @@ func (h *Headscale) VersionHandler(
writer.WriteHeader(http.StatusOK)
versionInfo := types.GetVersionInfo()
err := json.NewEncoder(writer).Encode(versionInfo)
if err != nil {
log.Error().

View file

@ -2,6 +2,7 @@ package mapper
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sync"
@ -77,6 +78,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
if err != nil {
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
nodeConn.removeConnectionByChannel(c)
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
}
@ -86,10 +88,11 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
case c <- initialMap:
// Success
case <-time.After(5 * time.Second):
log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
log.Error().Uint64("node.id", id.Uint64()).Err(errors.New("timeout")).Msg("Initial map send timeout")
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
Msg("Initial map send timed out because channel was blocked or receiver not ready")
nodeConn.removeConnectionByChannel(c)
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
}
@ -129,6 +132,7 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
log.Debug().Caller().Uint64("node.id", id.Uint64()).
Int("active.connections", nodeConn.getActiveConnectionCount()).
Msg("Node connection removed but keeping online because other connections remain")
return true // Node still has active connections
}
@ -211,10 +215,12 @@ func (b *LockFreeBatcher) worker(workerID int) {
// This is used for synchronous map generation.
if w.resultCh != nil {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists {
var err error
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
result.err = err
if result.err != nil {
b.workErrors.Add(1)
@ -397,6 +403,7 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
}
}
}
return true
})
@ -449,6 +456,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
if nodeConn.hasActiveConnections() {
ret.Store(id, true)
}
return true
})
@ -464,6 +472,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
ret.Store(id, false)
}
}
return true
})
@ -518,7 +527,8 @@ type multiChannelNodeConn struct {
func generateConnectionID() string {
bytes := make([]byte, 8)
rand.Read(bytes)
return fmt.Sprintf("%x", bytes)
return hex.EncodeToString(bytes)
}
// newMultiChannelNodeConn creates a new multi-channel node connection.
@ -545,11 +555,14 @@ func (mc *multiChannelNodeConn) close() {
// addConnection adds a new connection.
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
mutexWaitStart := time.Now()
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
mc.mutex.Lock()
mutexWaitDur := time.Since(mutexWaitStart)
defer mc.mutex.Unlock()
mc.connections = append(mc.connections, entry)
@ -571,9 +584,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
Int("remaining_connections", len(mc.connections)).
Msg("Successfully removed connection")
return true
}
}
return false
}
@ -607,6 +622,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
// This is not an error - the node will receive a full map when it reconnects
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
return nil // Return success instead of error
}
@ -615,7 +631,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
Msg("send: broadcasting to all connections")
var lastErr error
successCount := 0
var failedConnections []int // Track failed connections for removal
// Send to all connections
@ -626,6 +644,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
if err := conn.send(data); err != nil {
lastErr = err
failedConnections = append(failedConnections, i)
log.Warn().Err(err).
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
@ -633,6 +652,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
Msg("send: connection send failed")
} else {
successCount++
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
Str("conn.id", conn.id).Int("connection_index", i).
Msg("send: successfully sent to connection")
@ -797,6 +817,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
Connected: connected,
ActiveConnections: activeConnCount,
}
return true
})
@ -811,6 +832,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
ActiveConnections: 0,
}
}
return true
})
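Several hunks in this file are perfsprint rewrites: fmt.Errorf with no formatting verbs becomes errors.New, and fmt.Sprintf("%x", ...) becomes hex.EncodeToString, which is why the errors and encoding/hex imports were added above. A small standalone example showing the two forms produce the same output:

package main

import (
    "crypto/rand"
    "encoding/hex"
    "errors"
    "fmt"
)

func main() {
    buf := make([]byte, 8)
    if _, err := rand.Read(buf); err != nil {
        panic(err)
    }

    // Both produce the same lowercase hex string; hex.EncodeToString skips
    // the reflection and verb parsing that fmt.Sprintf has to do.
    fmt.Println(fmt.Sprintf("%x", buf) == hex.EncodeToString(buf)) // true

    // For constant messages, errors.New is the cheaper equivalent of
    // fmt.Errorf without formatting verbs.
    errA := fmt.Errorf("timeout")
    errB := errors.New("timeout")
    fmt.Println(errA.Error() == errB.Error()) // true
}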

View file

@ -677,6 +677,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
connectedCount := 0
for i := range allNodes {
node := &allNodes[i]
@ -694,6 +695,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
t.Logf("✅ All nodes achieved full connectivity!")
totalTime := time.Since(startTime)
// Disconnect all nodes
@ -1309,6 +1311,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
for range i % 3 {
runtime.Gosched() // Introduce timing variability
}
batcher.RemoveNode(testNode.n.ID, ch)
// Yield to allow workers to process and close channels
@ -1392,6 +1395,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Channel was closed, exit gracefully
return
}
if valid, reason := validateUpdateContent(data); valid {
tracker.recordUpdate(
nodeID,
@ -1449,7 +1453,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
churningChannelsMutex.Lock()
churningChannels[nodeID] = ch
churningChannelsMutex.Unlock()
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
@ -1463,6 +1469,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Channel was closed, exit gracefully
return
}
if valid, _ := validateUpdateContent(data); valid {
tracker.recordUpdate(
nodeID,
@ -1495,6 +1502,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
for range i % 5 {
runtime.Gosched() // Introduce timing variability
}
churningChannelsMutex.Lock()
ch, exists := churningChannels[nodeID]
@ -1879,6 +1887,7 @@ func XTestBatcherScalability(t *testing.T) {
channel,
tailcfg.CapabilityVersion(100),
)
connectedNodesMutex.Lock()
connectedNodes[nodeID] = true
@ -2287,6 +2296,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
// Phase 1: Connect all nodes initially
t.Logf("Phase 1: Connecting all nodes...")
for i, node := range allNodes {
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
if err != nil {
@ -2303,6 +2313,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
t.Logf("Phase 2: Rapid disconnect all nodes...")
for i, node := range allNodes {
removed := batcher.RemoveNode(node.n.ID, node.ch)
t.Logf("Node %d RemoveNode result: %t", i, removed)
@ -2310,9 +2321,11 @@ func TestBatcherRapidReconnection(t *testing.T) {
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
t.Logf("Phase 3: Rapid reconnect with new channels...")
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
for i, node := range allNodes {
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
if err != nil {
t.Errorf("Failed to reconnect node %d: %v", i, err)
@ -2343,11 +2356,13 @@ func TestBatcherRapidReconnection(t *testing.T) {
if infoMap, ok := info.(map[string]any); ok {
if connected, ok := infoMap["connected"].(bool); ok && !connected {
disconnectedCount++
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
}
}
} else {
disconnectedCount++
t.Logf("Node %d missing from debug info entirely", i)
}
@ -2382,6 +2397,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
case update := <-newChannels[i]:
if update != nil {
receivedCount++
t.Logf("Node %d received update successfully", i)
}
case <-timeout:
@ -2414,6 +2430,7 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 1: Connect first node with initial connection
t.Logf("Phase 1: Connecting node 1 with first connection...")
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node1: %v", err)
@ -2433,7 +2450,9 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 2: Add second connection for node1 (multi-connection scenario)
t.Logf("Phase 2: Adding second connection for node 1...")
secondChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add second connection for node1: %v", err)
@ -2444,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 3: Add third connection for node1
t.Logf("Phase 3: Adding third connection for node 1...")
thirdChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add third connection for node1: %v", err)
@ -2455,6 +2476,7 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 4: Verify debug status shows correct connection count
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
@ -2462,6 +2484,7 @@ func TestBatcherMultiConnection(t *testing.T) {
if info, exists := debugInfo[node1.n.ID]; exists {
t.Logf("Node1 debug info: %+v", info)
if infoMap, ok := info.(map[string]any); ok {
if activeConnections, ok := infoMap["active_connections"].(int); ok {
if activeConnections != 3 {
@ -2470,6 +2493,7 @@ func TestBatcherMultiConnection(t *testing.T) {
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
}
}
if connected, ok := infoMap["connected"].(bool); ok && !connected {
t.Errorf("Node1 should show as connected with 3 active connections")
}

View file

@ -37,6 +37,7 @@ const (
// NewMapResponseBuilder creates a new builder with basic fields set.
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
now := time.Now()
return &MapResponseBuilder{
resp: &tailcfg.MapResponse{
KeepAlive: false,
@ -124,6 +125,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
b.resp.Debug = &tailcfg.Debug{
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
}
return b
}
@ -281,16 +283,18 @@ func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapRe
for _, id := range removedIDs {
tailscaleIDs = append(tailscaleIDs, id.NodeID())
}
b.resp.PeersRemoved = tailscaleIDs
return b
}
// Build finalizes the response and returns marshaled bytes
// Build finalizes the response and returns marshaled bytes.
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
if len(b.errs) > 0 {
return nil, multierr.New(b.errs...)
}
if debugDumpMapResponsePath != "" {
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
}

View file

@ -60,7 +60,6 @@ func newMapper(
state *state.State,
) *mapper {
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &mapper{
state: state,
cfg: cfg,
@ -80,6 +79,7 @@ func generateUserProfiles(
userID := user.Model().ID
userMap[userID] = &user
ids = append(ids, userID)
for _, peer := range peers.All() {
peerUser := peer.Owner()
peerUserID := peerUser.Model().ID
@ -90,6 +90,7 @@ func generateUserProfiles(
slices.Sort(ids)
ids = slices.Compact(ids)
var profiles []tailcfg.UserProfile
for _, id := range ids {
if userMap[id] != nil {
profiles = append(profiles, userMap[id].TailscaleUserProfile())
@ -306,6 +307,7 @@ func writeDebugMapResponse(
perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
@ -319,6 +321,7 @@ func writeDebugMapResponse(
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
err = os.WriteFile(mapResponsePath, body, perms)
if err != nil {
panic(err)
@ -375,6 +378,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
}
var resp tailcfg.MapResponse
err = json.Unmarshal(body, &resp)
if err != nil {
log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name())

View file

@ -98,6 +98,7 @@ func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
if m.polMan == nil {
return tailcfg.FilterAllowAll, nil
}
return m.polMan.Filter()
}
@ -105,6 +106,7 @@ func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
if m.polMan == nil {
return nil, nil
}
return m.polMan.SSHPolicy(node)
}
@ -112,6 +114,7 @@ func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
if m.polMan == nil {
return false
}
return m.polMan.NodeCanHaveTag(node, tag)
}
@ -119,6 +122,7 @@ func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
if m.primary == nil {
return nil
}
return m.primary.PrimaryRoutes(nodeID)
}
@ -126,6 +130,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
if len(peerIDs) > 0 {
// Filter peers by the provided IDs
var filtered types.Nodes
for _, peer := range m.peers {
if slices.Contains(peerIDs, peer.ID) {
filtered = append(filtered, peer)
@ -136,6 +141,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
}
// Return all peers except the node itself
var filtered types.Nodes
for _, peer := range m.peers {
if peer.ID != nodeID {
filtered = append(filtered, peer)
@ -149,6 +155,7 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
if len(nodeIDs) > 0 {
// Filter nodes by the provided IDs
var filtered types.Nodes
for _, node := range m.nodes {
if slices.Contains(nodeIDs, node.ID) {
filtered = append(filtered, node)

View file

@ -243,10 +243,12 @@ func (ns *noiseServer) NoiseRegistrationHandler(
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) {
var resp *tailcfg.RegisterResponse
body, err := io.ReadAll(req.Body)
if err != nil {
return &tailcfg.RegisterRequest{}, regErr(err)
}
var regReq tailcfg.RegisterRequest
if err := json.Unmarshal(body, &regReq); err != nil {
return &regReq, regErr(err)
@ -260,6 +262,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
resp = &tailcfg.RegisterResponse{
Error: httpErr.Msg,
}
return &regReq, resp
}

View file

@ -163,6 +163,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
for k, v := range a.cfg.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v))
}
extras = append(extras, oidc.Nonce(nonce))
// Cache the registration info
@ -190,6 +191,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
}
stateCookieName := getCookieName("state", state)
cookieState, err := req.Cookie(stateCookieName)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
@ -212,17 +214,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
httpError(writer, err)
return
}
if idToken.Nonce == "" {
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
return
}
nonceCookieName := getCookieName("nonce", idToken.Nonce)
nonce, err := req.Cookie(nonceCookieName)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
return
}
if idToken.Nonce != nonce.Value {
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
return
@ -239,6 +244,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Fetch user information (email, groups, name, etc) from the userinfo endpoint
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
var userinfo *oidc.UserInfo
userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token))
if err != nil {
util.LogErr(err, "could not get userinfo; only using claims from id token")
@ -255,6 +261,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified)
claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username)
claims.Name = cmp.Or(userinfo2.Name, claims.Name)
claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL)
if userinfo2.Groups != nil {
claims.Groups = userinfo2.Groups
@ -279,6 +286,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
Msgf("could not create or update user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not create or update user"))
if werr != nil {
log.Error().
@ -299,6 +307,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Register the node if it does not exist.
if registrationId != nil {
verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
@ -307,7 +316,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
httpError(writer, err)
return
}
@ -324,6 +335,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if _, err := writer.Write(content.Bytes()); err != nil {
util.LogErr(err, "Failed to write HTTP response")
}
@ -370,6 +382,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
if !ok {
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
}
if regInfo.Verifier != nil {
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
}
@ -516,6 +529,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
newUser bool
c change.Change
)
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
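The callback handler above merges userinfo fields into the ID-token claims with cmp.Or, which returns its first non-zero argument. A small illustration of that merge pattern with made-up field values:

package main

import (
    "cmp"
    "fmt"
)

type claims struct {
    Email    string
    Username string
    Name     string
}

func main() {
    fromIDToken := claims{Email: "user@example.com", Username: "user1"}
    fromUserinfo := claims{Username: "user1-preferred", Name: "User One"}

    // cmp.Or picks the first non-zero value, so the userinfo field wins
    // when it is set and the ID-token claim is kept otherwise.
    merged := claims{
        Email:    cmp.Or(fromUserinfo.Email, fromIDToken.Email),
        Username: cmp.Or(fromUserinfo.Username, fromIDToken.Username),
        Name:     cmp.Or(fromUserinfo.Name, fromIDToken.Name),
    }

    fmt.Printf("%+v\n", merged)
}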

View file

@ -21,10 +21,13 @@ func (m Match) DebugString() string {
sb.WriteString("Match:\n")
sb.WriteString(" Sources:\n")
for _, prefix := range m.srcs.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}
sb.WriteString(" Destinations:\n")
for _, prefix := range m.dests.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}

View file

@ -38,8 +38,11 @@ type PolicyManager interface {
// NewPolicyManager returns a new policy manager.
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
var polMan PolicyManager
var err error
var (
polMan PolicyManager
err error
)
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
if err != nil {
return nil, err
@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
if err != nil {
return nil, err
}
polMans = append(polMans, pm)
}

View file

@ -125,6 +125,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
if !slices.Equal(sortedCurrent, newApproved) {
// Log what changed
var added, kept []netip.Prefix
for _, route := range newApproved {
if !slices.Contains(sortedCurrent, route) {
added = append(added, route)

View file

@ -312,8 +312,11 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
nodes := types.Nodes{&node}
// Create policy manager or use nil if specified
var pm PolicyManager
var err error
var (
pm PolicyManager
err error
)
if tt.name != "nil_policy_manager" {
pm, err = pmf(users, nodes.ViewSlice())
assert.NoError(t, err)

View file

@ -32,6 +32,7 @@ func TestReduceNodes(t *testing.T) {
rules []tailcfg.FilterRule
node *types.Node
}
tests := []struct {
name string
args args
@ -782,9 +783,11 @@ func TestReduceNodes(t *testing.T) {
for _, v := range gotViews.All() {
got = append(got, v.AsStruct())
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff)
t.Log("Matchers: ")
for _, m := range matchers {
t.Log("\t+", m.DebugString())
}
@ -1031,8 +1034,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
for _, tt := range tests {
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm PolicyManager
var err error
var (
pm PolicyManager
err error
)
pm, err = pmf(nil, tt.nodes.ViewSlice())
require.NoError(t, err)
@ -1050,9 +1056,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
for _, v := range gotViews.All() {
got = append(got, v.AsStruct())
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff)
t.Log("Matchers: ")
for _, m := range matchers {
t.Log("\t+", m.DebugString())
}
@ -1405,13 +1413,17 @@ func TestSSHPolicyRules(t *testing.T) {
for _, tt := range tests {
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm PolicyManager
var err error
var (
pm PolicyManager
err error
)
pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
if tt.expectErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMessage)
return
}
@ -1434,6 +1446,7 @@ func TestReduceRoutes(t *testing.T) {
routes []netip.Prefix
rules []tailcfg.FilterRule
}
tests := []struct {
name string
args args
@ -2055,6 +2068,7 @@ func TestReduceRoutes(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
got := ReduceRoutes(
tt.args.node.View(),
tt.args.routes,

View file

@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
for _, rule := range rules {
// record if the rule is actually relevant for the given node.
var dests []tailcfg.NetPortRange
DEST_LOOP:
for _, dest := range rule.DstPorts {
expanded, err := util.ParseIPSet(dest.IP, nil)

View file

@ -823,10 +823,14 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests {
for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm policy.PolicyManager
var err error
var (
pm policy.PolicyManager
err error
)
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
require.NoError(t, err)
got, _ := pm.Filter()
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
got = policyutil.ReduceFilterRules(tt.node.View(), got)

View file

@ -829,6 +829,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
if tt.name == "empty policy" {
// We expect this one to have a valid but empty policy
require.NoError(t, err)
if err != nil {
return
}
@ -843,6 +844,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
if diff := cmp.Diff(tt.canApprove, result); diff != "" {
t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff)
}
assert.Equal(t, tt.canApprove, result, "Unexpected route approval result")
})
}

View file

@ -45,6 +45,7 @@ func (pol *Policy) compileFilterRules(
protocols, _ := acl.Protocol.parseProtocol()
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Resolve(pol, users, nodes)
if err != nil {
@ -127,8 +128,10 @@ func (pol *Policy) compileACLWithAutogroupSelf(
node types.NodeView,
nodes views.Slice[types.NodeView],
) ([]*tailcfg.FilterRule, error) {
var autogroupSelfDests []AliasWithPorts
var otherDests []AliasWithPorts
var (
autogroupSelfDests []AliasWithPorts
otherDests []AliasWithPorts
)
for _, dest := range acl.Destinations {
if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
@ -139,13 +142,14 @@ func (pol *Policy) compileACLWithAutogroupSelf(
}
protocols, _ := acl.Protocol.parseProtocol()
var rules []*tailcfg.FilterRule
var resolvedSrcIPs []*netipx.IPSet
for _, src := range acl.Sources {
if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
return nil, fmt.Errorf("autogroup:self cannot be used in sources")
return nil, errors.New("autogroup:self cannot be used in sources")
}
ips, err := src.Resolve(pol, users, nodes)
@ -167,6 +171,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
if len(autogroupSelfDests) > 0 {
// Pre-filter to same-user untagged devices once - reuse for both sources and destinations
sameUserNodes := make([]types.NodeView, 0)
for _, n := range nodes.All() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
sameUserNodes = append(sameUserNodes, n)
@ -176,6 +181,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
if len(sameUserNodes) > 0 {
// Filter sources to only same-user untagged devices
var srcIPs netipx.IPSetBuilder
for _, ips := range resolvedSrcIPs {
for _, n := range sameUserNodes {
// Check if any of this node's IPs are in the source set
@ -192,6 +198,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
if srcSet != nil && len(srcSet.Prefixes()) > 0 {
var destPorts []tailcfg.NetPortRange
for _, dest := range autogroupSelfDests {
for _, n := range sameUserNodes {
for _, port := range dest.Ports {
@ -297,8 +304,10 @@ func (pol *Policy) compileSSHPolicy(
// Separate destinations into autogroup:self and others
// This is needed because autogroup:self requires filtering sources to same-user only,
// while other destinations should use all resolved sources
var autogroupSelfDests []Alias
var otherDests []Alias
var (
autogroupSelfDests []Alias
otherDests []Alias
)
for _, dst := range rule.Destinations {
if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
@ -321,6 +330,7 @@ func (pol *Policy) compileSSHPolicy(
}
var action tailcfg.SSHAction
switch rule.Action {
case SSHActionAccept:
action = sshAction(true, 0)
@ -336,9 +346,11 @@ func (pol *Policy) compileSSHPolicy(
// by default, we do not allow root unless explicitly stated
userMap["root"] = ""
}
if rule.Users.ContainsRoot() {
userMap["root"] = "root"
}
for _, u := range rule.Users.NormalUsers() {
userMap[u.String()] = u.String()
}
@ -348,6 +360,7 @@ func (pol *Policy) compileSSHPolicy(
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
// Build destination set for autogroup:self (same-user untagged devices only)
var dest netipx.IPSetBuilder
for _, n := range nodes.All() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
n.AppendToIPSet(&dest)
@ -364,6 +377,7 @@ func (pol *Policy) compileSSHPolicy(
// Filter sources to only same-user untagged devices
// Pre-filter to same-user untagged devices for efficiency
sameUserNodes := make([]types.NodeView, 0)
for _, n := range nodes.All() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
sameUserNodes = append(sameUserNodes, n)
@ -371,6 +385,7 @@ func (pol *Policy) compileSSHPolicy(
}
var filteredSrcIPs netipx.IPSetBuilder
for _, n := range sameUserNodes {
// Check if any of this node's IPs are in the source set
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
@ -406,12 +421,14 @@ func (pol *Policy) compileSSHPolicy(
if len(otherDests) > 0 {
// Build destination set for other destinations
var dest netipx.IPSetBuilder
for _, dst := range otherDests {
ips, err := dst.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
continue
}
if ips != nil {
dest.AddSet(ips)
}
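The autogroup:self handling above repeatedly narrows a resolved source set down to same-user untagged devices using netipx.IPSetBuilder. A compact sketch of that build-then-filter pattern with invented addresses:

package main

import (
    "fmt"
    "net/netip"

    "go4.org/netipx"
)

func main() {
    // Resolved source set (for example, everything a group expands to).
    var srcB netipx.IPSetBuilder
    srcB.AddPrefix(netip.MustParsePrefix("100.64.0.0/24"))
    srcSet, _ := srcB.IPSet()

    // Addresses belonging to the same user's untagged devices.
    sameUser := []netip.Addr{
        netip.MustParseAddr("100.64.0.1"),
        netip.MustParseAddr("100.64.0.2"),
    }

    // Keep only the same-user addresses present in the source set,
    // mirroring the pre-filtering done for autogroup:self destinations.
    var filtered netipx.IPSetBuilder
    for _, addr := range sameUser {
        if srcSet.Contains(addr) {
            filtered.Add(addr)
        }
    }

    filteredSet, _ := filtered.IPSet()
    fmt.Println(filteredSet.Prefixes())
}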

View file

@ -589,7 +589,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
if sshPolicy == nil {
return // Expected empty result
}
assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match")
return
}
@ -670,7 +672,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
}
// TestSSHIntegrationReproduction reproduces the exact scenario from the integration test
// TestSSHOneUserToAll that was failing with empty sshUsers
// TestSSHOneUserToAll that was failing with empty sshUsers.
func TestSSHIntegrationReproduction(t *testing.T) {
// Create users matching the integration test
users := types.Users{
@ -735,7 +737,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
}
// TestSSHJSONSerialization verifies that the SSH policy can be properly serialized
// to JSON and that the sshUsers field is not empty
// to JSON and that the sshUsers field is not empty.
func TestSSHJSONSerialization(t *testing.T) {
users := types.Users{
{Name: "user1", Model: gorm.Model{ID: 1}},
@ -775,6 +777,7 @@ func TestSSHJSONSerialization(t *testing.T) {
// Parse back to verify structure
var parsed tailcfg.SSHPolicy
err = json.Unmarshal(jsonData, &parsed)
require.NoError(t, err)
@ -859,6 +862,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(rules) != 1 {
t.Fatalf("expected 1 rule, got %d", len(rules))
}
@ -875,6 +879,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
found := false
addr := netip.MustParseAddr(expectedIP)
for _, prefix := range rule.SrcIPs {
pref := netip.MustParsePrefix(prefix)
if pref.Contains(addr) {
@ -892,6 +897,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"}
for _, excludedIP := range excludedSourceIPs {
addr := netip.MustParseAddr(excludedIP)
for _, prefix := range rule.SrcIPs {
pref := netip.MustParsePrefix(prefix)
if pref.Contains(addr) {
@ -1325,14 +1331,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
assert.Empty(t, rules3, "user3 should have no rules")
}
// Helper function to create IP addresses for testing
// Helper function to create IP addresses for testing.
func createAddr(ip string) *netip.Addr {
addr, _ := netip.ParseAddr(ip)
return &addr
}
// TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly
// with autogroup:self in destinations
// with autogroup:self in destinations.
func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
@ -1380,6 +1386,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
for i, p := range rule.Principals {
principalIPs[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
// Test for user2's first node
@ -1398,12 +1405,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
for i, p := range rule2.Principals {
principalIPs2[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2)
// Test for tagged node (should have no SSH rules)
node5 := nodes[4].View()
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy3 != nil {
assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self")
}
@ -1411,7 +1420,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user
// is in the source and autogroup:self in destination, only that user's devices
// can SSH (and only if they match the target user)
// can SSH (and only if they match the target user).
func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
@ -1453,18 +1462,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
for i, p := range rule.Principals {
principalIPs[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
// For user2's node: should have no rules (user1's devices can't match user2's self)
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1")
}
}
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations.
func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
@ -1511,19 +1522,21 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
for i, p := range rule.Principals {
principalIPs[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
// For user3's node: should have no rules (not in group:admins)
node5 := nodes[4].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)")
}
}
// TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices
// are excluded from both sources and destinations when autogroup:self is used
// are excluded from both sources and destinations when autogroup:self is used.
func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
@ -1568,6 +1581,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
for i, p := range rule.Principals {
principalIPs[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs,
"should only include untagged devices")
@ -1575,6 +1589,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self")
}
@ -1623,10 +1638,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
// Verify autogroup:self rule has filtered sources (only same-user devices)
selfRule := sshPolicy1.Rules[0]
require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices")
selfPrincipals := make([]string, len(selfRule.Principals))
for i, p := range selfRule.Principals {
selfPrincipals[i] = p.NodeIP
}
require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals,
"autogroup:self rule should only include same-user untagged devices")
@ -1638,10 +1655,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")
routerRule := sshPolicyRouter.Rules[0]
routerPrincipals := make([]string, len(routerRule.Principals))
for i, p := range routerRule.Principals {
routerPrincipals[i] = p.NodeIP
}
require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)")
require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)")
require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)")

View file

@ -111,6 +111,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Filter: filter,
Policy: pm.pol,
})
filterChanged := filterHash != pm.filterHash
if filterChanged {
log.Debug().
@ -120,7 +121,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Int("filter.rules.new", len(filter)).
Msg("Policy filter hash changed")
}
pm.filter = filter
pm.filterHash = filterHash
if filterChanged {
pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
@ -135,6 +138,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
}
tagOwnerMapHash := deephash.Hash(&tagMap)
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
if tagOwnerChanged {
log.Debug().
@ -144,6 +148,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Int("tagOwners.new", len(tagMap)).
Msg("Tag owner hash changed")
}
pm.tagOwnerMap = tagMap
pm.tagOwnerMapHash = tagOwnerMapHash
@ -153,6 +158,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
}
autoApproveMapHash := deephash.Hash(&autoMap)
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
if autoApproveChanged {
log.Debug().
@ -162,10 +168,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Int("autoApprovers.new", len(autoMap)).
Msg("Auto-approvers hash changed")
}
pm.autoApproveMap = autoMap
pm.autoApproveMapHash = autoApproveMapHash
exitSetHash := deephash.Hash(&exitSet)
exitSetChanged := exitSetHash != pm.exitSetHash
if exitSetChanged {
log.Debug().
@ -173,6 +181,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Str("exitSet.hash.new", exitSetHash.String()[:8]).
Msg("Exit node set hash changed")
}
pm.exitSet = exitSet
pm.exitSetHash = exitSetHash
@ -199,6 +208,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
if !needsUpdate {
log.Trace().
Msg("Policy evaluation detected no changes - all hashes match")
return false, nil
}
@ -224,6 +234,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err
if err != nil {
return nil, fmt.Errorf("compiling SSH policy: %w", err)
}
pm.sshPolicyMap[node.ID()] = sshPol
return sshPol, nil
@ -318,6 +329,7 @@ func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[typ
if err != nil || len(filter) == 0 {
continue
}
nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter)
}
@ -398,6 +410,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil
reducedFilter := policyutil.ReduceFilterRules(node, pm.filter)
pm.filterRulesMap[node.ID()] = reducedFilter
return reducedFilter, nil
}
@ -442,7 +455,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul
// This is different from FilterForNode which returns REDUCED rules for packet filtering.
//
// For global policies: returns the global matchers (same for all nodes)
// For autogroup:self: returns node-specific matchers from unreduced compiled rules
// For autogroup:self: returns node-specific matchers from unreduced compiled rules.
func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
if pm == nil {
return nil, nil
@ -474,6 +487,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
// Clear SSH policy map when users change to force SSH policy recomputation
@ -685,6 +699,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
if pm.exitSet == nil {
return false
}
if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) {
return true
}
@ -748,8 +763,10 @@ func (pm *PolicyManager) DebugString() string {
}
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
for prefix, approveAddrs := range pm.autoApproveMap {
fmt.Fprintf(&sb, "\t%s:\n", prefix)
for _, iprange := range approveAddrs.Ranges() {
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
}
@ -758,14 +775,17 @@ func (pm *PolicyManager) DebugString() string {
sb.WriteString("\n\n")
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
for prefix, tagOwners := range pm.tagOwnerMap {
fmt.Fprintf(&sb, "\t%s:\n", prefix)
for _, iprange := range tagOwners.Ranges() {
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
}
}
sb.WriteString("\n\n")
if pm.filter != nil {
filter, err := json.MarshalIndent(pm.filter, "", " ")
if err == nil {
@ -778,6 +798,7 @@ func (pm *PolicyManager) DebugString() string {
sb.WriteString("\n\n")
sb.WriteString("Matchers:\n")
sb.WriteString("an internal structure used to filter nodes and routes\n")
for _, match := range pm.matchers {
sb.WriteString(match.DebugString())
sb.WriteString("\n")
@ -785,6 +806,7 @@ func (pm *PolicyManager) DebugString() string {
sb.WriteString("\n\n")
sb.WriteString("Nodes:\n")
for _, node := range pm.nodes.All() {
sb.WriteString(node.String())
sb.WriteString("\n")
@ -841,6 +863,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
// Check if IPs changed (simple check - could be more sophisticated)
oldIPs := oldNode.IPs()
newIPs := newNode.IPs()
if len(oldIPs) != len(newIPs) {
affectedUsers[newNode.User().ID()] = struct{}{}
@ -862,6 +885,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
for nodeID := range pm.filterRulesMap {
// Find the user for this cached node
var nodeUserID uint
found := false
// Check in new nodes first
@ -869,6 +893,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
if node.ID() == nodeID {
nodeUserID = node.User().ID()
found = true
break
}
}
@ -879,6 +904,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
if node.ID() == nodeID {
nodeUserID = node.User().ID()
found = true
break
}
}
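updateLocked above detects policy changes by hashing each derived structure with tailscale.com/util/deephash and comparing the sum against the previously stored one. A minimal sketch of that hash-and-compare idiom; the map being hashed is purely illustrative:

package main

import (
    "fmt"

    "tailscale.com/util/deephash"
)

func main() {
    tagOwners := map[string][]string{
        "tag:router": {"user1@", "user2@"},
    }

    // Store the hash of the current state...
    prev := deephash.Hash(&tagOwners)

    // ...mutate the state...
    tagOwners["tag:router"] = append(tagOwners["tag:router"], "user3@")

    // ...and recompute to see whether anything actually changed.
    next := deephash.Hash(&tagOwners)
    fmt.Println("changed:", next != prev, "hash:", next.String()[:8])
}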

View file

@ -56,6 +56,7 @@ func TestPolicyManager(t *testing.T) {
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(
tt.wantMatchers,
matchers,
@ -176,13 +177,16 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
for i, n := range tt.newNodes {
found := false
for _, origNode := range initialNodes {
if n.Hostname == origNode.Hostname {
n.ID = origNode.ID
found = true
break
}
}
if !found {
n.ID = types.NodeID(len(initialNodes) + i + 1)
}
@ -369,7 +373,7 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) {
// TestAutogroupSelfReducedVsUnreducedRules verifies that:
// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships
// 2. FilterForNode returns reduced compiled rules for packet filters
// 2. FilterForNode returns reduced compiled rules for packet filters.
func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}
@ -409,6 +413,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
// FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations
// For node1, destinations should only be node1's IPs
node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"}
for _, rule := range filterNode1 {
for _, dst := range rule.DstPorts {
require.Contains(t, node1IPs, dst.IP,
@ -418,6 +423,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
// For node2, destinations should only be node2's IPs
node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"}
for _, rule := range filterNode2 {
for _, dst := range rule.DstPorts {
require.Contains(t, node2IPs, dst.IP,

View file

@ -21,7 +21,7 @@ import (
"tailscale.com/util/slicesx"
)
// Global JSON options for consistent parsing across all struct unmarshaling
// Global JSON options for consistent parsing across all struct unmarshaling.
var policyJSONOpts = []json.Options{
json.DefaultOptionsV2(),
json.MatchCaseInsensitiveNames(true),
@ -58,6 +58,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
}
var alias string
switch v := a.Alias.(type) {
case *Username:
alias = string(*v)
@ -89,6 +90,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
// Otherwise, format as "alias:ports"
var ports []string
for _, port := range a.Ports {
if port.First == port.Last {
ports = append(ports, strconv.FormatUint(uint64(port.First), 10))
@ -123,6 +125,7 @@ func (u Username) Validate() error {
if isUser(string(u)) {
return nil
}
return fmt.Errorf("Username has to contain @, got: %q", u)
}
@ -194,8 +197,10 @@ func (u Username) resolveUser(users types.Users) (types.User, error) {
}
func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
var (
ips netipx.IPSetBuilder
errs []error
)
user, err := u.resolveUser(users)
if err != nil {
@ -228,6 +233,7 @@ func (g Group) Validate() error {
if isGroup(string(g)) {
return nil
}
return fmt.Errorf(`Group has to start with "group:", got: %q`, g)
}
@ -268,8 +274,10 @@ func (g Group) MarshalJSON() ([]byte, error) {
}
func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
var (
ips netipx.IPSetBuilder
errs []error
)
for _, user := range p.Groups[g] {
uips, err := user.Resolve(nil, users, nodes)
@ -290,6 +298,7 @@ func (t Tag) Validate() error {
if isTag(string(t)) {
return nil
}
return fmt.Errorf(`tag has to start with "tag:", got: %q`, t)
}
@ -339,6 +348,7 @@ func (h Host) Validate() error {
if isHost(string(h)) {
return nil
}
return fmt.Errorf("Hostname %q is invalid", h)
}
@ -352,13 +362,16 @@ func (h *Host) UnmarshalJSON(b []byte) error {
}
func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
var (
ips netipx.IPSetBuilder
errs []error
)
pref, ok := p.Hosts[h]
if !ok {
return nil, fmt.Errorf("unable to resolve host: %q", h)
}
err := pref.Validate()
if err != nil {
errs = append(errs, err)
@ -376,6 +389,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView
if err != nil {
errs = append(errs, err)
}
for _, node := range nodes.All() {
if node.InIPSet(ipsTemp) {
node.AppendToIPSet(&ips)
@ -391,6 +405,7 @@ func (p Prefix) Validate() error {
if netip.Prefix(p).IsValid() {
return nil
}
return fmt.Errorf("Prefix %q is invalid", p)
}
@ -404,6 +419,7 @@ func (p *Prefix) parseString(addr string) error {
if err != nil {
return err
}
addrPref, err := addr.Prefix(addr.BitLen())
if err != nil {
return err
@ -418,6 +434,7 @@ func (p *Prefix) parseString(addr string) error {
if err != nil {
return err
}
*p = Prefix(pref)
return nil
@ -428,6 +445,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
if err := p.Validate(); err != nil {
return err
}
@ -441,8 +459,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
//
// See [Policy], [types.Users], and [types.Nodes] for more details.
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
var (
ips netipx.IPSetBuilder
errs []error
)
ips.AddPrefix(netip.Prefix(p))
// If the IP is a single host, look for a node to ensure we add all the IPs of
@ -587,8 +607,10 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
switch vs := v.(type) {
case string:
var portsPart string
var err error
var (
portsPart string
err error
)
if strings.Contains(vs, ":") {
vs, portsPart, err = splitDestinationAndPort(vs)
@ -600,6 +622,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
ve.Ports = ports
} else {
return errors.New(`hostport must contain a colon (":")`)
@ -609,6 +632,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
if err := ve.Validate(); err != nil {
return err
}
@ -646,6 +670,7 @@ func isHost(str string) bool {
func parseAlias(vs string) (Alias, error) {
var pref Prefix
err := pref.parseString(vs)
if err == nil {
return &pref, nil
@ -690,6 +715,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
ve.Alias = ptr
return nil
@ -699,6 +725,7 @@ type Aliases []Alias
func (a *Aliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
if err != nil {
return err
@ -744,8 +771,10 @@ func (a Aliases) MarshalJSON() ([]byte, error) {
}
func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
var (
ips netipx.IPSetBuilder
errs []error
)
for _, alias := range a {
aips, err := alias.Resolve(p, users, nodes)
@ -770,6 +799,7 @@ func unmarshalPointer[T any](
parseFunc func(string) (T, error),
) (T, error) {
var s string
err := json.Unmarshal(b, &s)
if err != nil {
var t T
@ -789,6 +819,7 @@ type AutoApprovers []AutoApprover
func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
var autoApprovers []AutoApproverEnc
err := json.Unmarshal(b, &autoApprovers, policyJSONOpts...)
if err != nil {
return err
@ -854,6 +885,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
ve.AutoApprover = ptr
return nil
@ -876,6 +908,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
if err != nil {
return err
}
ve.Owner = ptr
return nil
@ -885,6 +918,7 @@ type Owners []Owner
func (o *Owners) UnmarshalJSON(b []byte) error {
var owners []OwnerEnc
err := json.Unmarshal(b, &owners, policyJSONOpts...)
if err != nil {
return err
@ -979,11 +1013,13 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
// Then validate each field can be converted to []string
rawGroups := make(map[string][]string)
for key, value := range rawMap {
switch v := value.(type) {
case []any:
// Convert []interface{} to []string
var stringSlice []string
for _, item := range v {
if str, ok := item.(string); ok {
stringSlice = append(stringSlice, str)
@ -991,6 +1027,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item)
}
}
rawGroups[key] = stringSlice
case string:
return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v)
@ -1000,6 +1037,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
}
*g = make(Groups)
for key, value := range rawGroups {
group := Group(key)
// Group name already validated above
@ -1014,6 +1052,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
return err
}
usernames = append(usernames, username)
}
@ -1033,6 +1072,7 @@ func (h *Hosts) UnmarshalJSON(b []byte) error {
}
*h = make(Hosts)
for key, value := range rawHosts {
host := Host(key)
if err := host.Validate(); err != nil {
@ -1076,6 +1116,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) {
}
rawTagOwners := make(map[string][]string)
for tag, owners := range to {
tagStr := string(tag)
ownerStrs := make([]string, len(owners))
@ -1152,6 +1193,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
if p == nil {
return nil, nil, nil
}
var err error
routes := make(map[netip.Prefix]*netipx.IPSetBuilder)
@ -1160,6 +1202,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
if _, ok := routes[prefix]; !ok {
routes[prefix] = new(netipx.IPSetBuilder)
}
for _, autoApprover := range autoApprovers {
aa, ok := autoApprover.(Alias)
if !ok {
@ -1173,6 +1216,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
}
var exitNodeSetBuilder netipx.IPSetBuilder
if len(p.AutoApprovers.ExitNode) > 0 {
for _, autoApprover := range p.AutoApprovers.ExitNode {
aa, ok := autoApprover.(Alias)
@ -1187,11 +1231,13 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
}
ret := make(map[netip.Prefix]*netipx.IPSet)
for prefix, builder := range routes {
ipSet, err := builder.IPSet()
if err != nil {
return nil, nil, err
}
ret[prefix] = ipSet
}
@ -1235,6 +1281,7 @@ func (a *Action) UnmarshalJSON(b []byte) error {
default:
return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept)
}
return nil
}
@ -1259,6 +1306,7 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error {
default:
return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str)
}
return nil
}
@ -1399,7 +1447,7 @@ func (p Protocol) validate() error {
return nil
case ProtocolWildcard:
// Wildcard "*" is not allowed - Tailscale rejects it
return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)")
return errors.New("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)")
default:
// Try to parse as a numeric protocol number
str := string(p)
@ -1427,7 +1475,7 @@ func (p Protocol) MarshalJSON() ([]byte, error) {
return json.Marshal(string(p))
}
// Protocol constants matching the IANA numbers
// Protocol constants matching the IANA numbers.
const (
protocolICMP = 1 // Internet Control Message
protocolIGMP = 2 // Internet Group Management
@ -1464,6 +1512,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error {
// Remove any fields that start with '#'
filtered := make(map[string]any)
for key, value := range raw {
if !strings.HasPrefix(key, "#") {
filtered[key] = value
@ -1478,6 +1527,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error {
// Create a type alias to avoid infinite recursion
type aclAlias ACL
var temp aclAlias
// Unmarshal into the temporary struct using the v2 JSON options
@ -1487,6 +1537,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error {
// Copy the result back to the original struct
*a = ACL(temp)
return nil
}
@ -1733,6 +1784,7 @@ func (p *Policy) validate() error {
}
}
}
for _, dst := range ssh.Destinations {
switch dst := dst.(type) {
case *AutoGroup:
@ -1846,6 +1898,7 @@ func (g Groups) MarshalJSON() ([]byte, error) {
for i, username := range usernames {
users[i] = string(username)
}
raw[string(group)] = users
}
@ -1854,6 +1907,7 @@ func (g Groups) MarshalJSON() ([]byte, error) {
func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
if err != nil {
return err
@ -1877,6 +1931,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
var aliases []AliasEnc
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
if err != nil {
return err
@ -1960,8 +2015,10 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) {
}
func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
var errs []error
var (
ips netipx.IPSetBuilder
errs []error
)
for _, alias := range a {
aips, err := alias.Resolve(p, users, nodes)
@ -2015,18 +2072,22 @@ func unmarshalPolicy(b []byte) (*Policy, error) {
}
var policy Policy
ast, err := hujson.Parse(b)
if err != nil {
return nil, fmt.Errorf("parsing HuJSON: %w", err)
}
ast.Standardize()
if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil {
if serr, ok := errors.AsType[*json.SemanticError](err); ok && serr.Err == json.ErrUnknownName {
ptr := serr.JSONPointer
name := ptr.LastToken()
return nil, fmt.Errorf("unknown field %q", name)
}
return nil, fmt.Errorf("parsing policy from bytes: %w", err)
}
@ -2073,6 +2134,7 @@ func (p *Policy) usesAutogroupSelf() bool {
return true
}
}
for _, dest := range acl.Destinations {
if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
return true
@ -2087,6 +2149,7 @@ func (p *Policy) usesAutogroupSelf() bool {
return true
}
}
for _, dest := range ssh.Destinations {
if ag, ok := dest.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
return true
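unmarshalPolicy above standardizes HuJSON before decoding it. A tiny standalone sketch of that parse path; the stub struct is a placeholder, and the standard encoding/json package stands in for the v2 decoder and its case-insensitive options used in the real code:

package main

import (
    "encoding/json"
    "fmt"

    "github.com/tailscale/hujson"
)

type policyStub struct {
    Hosts map[string]string `json:"hosts"`
}

func main() {
    src := []byte(`{
        // comments and trailing commas are allowed in HuJSON
        "hosts": {
            "exit": "100.64.0.99/32",
        },
    }`)

    ast, err := hujson.Parse(src)
    if err != nil {
        panic(fmt.Errorf("parsing HuJSON: %w", err))
    }

    // Standardize strips comments and trailing commas so the result is
    // plain JSON that any decoder can handle.
    ast.Standardize()

    var pol policyStub
    if err := json.Unmarshal(ast.Pack(), &pol); err != nil {
        panic(fmt.Errorf("parsing policy from bytes: %w", err))
    }

    fmt.Printf("%+v\n", pol)
}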

View file

@ -81,6 +81,7 @@ func TestMarshalJSON(t *testing.T) {
// Unmarshal back to verify round trip
var roundTripped Policy
err = json.Unmarshal(marshalled, &roundTripped)
require.NoError(t, err)
@ -2020,6 +2021,7 @@ func TestResolvePolicy(t *testing.T) {
}
var prefs []netip.Prefix
if ips != nil {
if p := ips.Prefixes(); len(p) > 0 {
prefs = p
@@ -2191,9 +2193,11 @@ func TestResolveAutoApprovers(t *testing.T) {
t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(tt.want, got, cmps...); diff != "" {
t.Errorf("resolveAutoApprovers() mismatch (-want +got):\n%s", diff)
}
if tt.wantAllIPRoutes != nil {
if gotAllIPRoutes == nil {
t.Error("resolveAutoApprovers() expected non-nil allIPRoutes, got nil")
@@ -2340,6 +2344,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
for _, p := range prefixes {
builder.AddPrefix(mp(p))
}
ipSet, _ := builder.IPSet()
return ipSet
@@ -2349,6 +2354,7 @@ func ipSetComparer(x, y *netipx.IPSet) bool {
if x == nil || y == nil {
return x == y
}
return cmp.Equal(x.Prefixes(), y.Prefixes(), util.Comparers...)
}
@@ -2577,6 +2583,7 @@ func TestResolveTagOwners(t *testing.T) {
t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(tt.want, got, cmps...); diff != "" {
t.Errorf("resolveTagOwners() mismatch (-want +got):\n%s", diff)
}
@@ -2852,6 +2859,7 @@ func TestNodeCanHaveTag(t *testing.T) {
require.ErrorContains(t, err, tt.wantErr)
return
}
require.NoError(t, err)
got := pm.NodeCanHaveTag(tt.node.View(), tt.tag)
@@ -3112,6 +3120,7 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var acl ACL
err := json.Unmarshal([]byte(tt.input), &acl)
if tt.wantErr {
@@ -3163,6 +3172,7 @@ func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) {
// Unmarshal back
var unmarshaled ACL
err = json.Unmarshal(jsonBytes, &unmarshaled)
require.NoError(t, err)
@@ -3241,12 +3251,13 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) {
assert.Contains(t, err.Error(), `invalid action "deny"`)
}
// Helper function to parse aliases for testing
// Helper function to parse aliases for testing.
func mustParseAlias(s string) Alias {
alias, err := parseAlias(s)
if err != nil {
panic(err)
}
return alias
}

View file

@@ -18,9 +18,11 @@ func splitDestinationAndPort(input string) (string, string, error) {
if lastColonIndex == -1 {
return "", "", errors.New("input must contain a colon character separating destination and port")
}
if lastColonIndex == 0 {
return "", "", errors.New("input cannot start with a colon character")
}
if lastColonIndex == len(input)-1 {
return "", "", errors.New("input cannot end with a colon character")
}
@@ -45,6 +47,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
for part := range parts {
if strings.Contains(part, "-") {
rangeParts := strings.Split(part, "-")
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
return e == ""
})

View file

@@ -58,9 +58,11 @@ func TestParsePort(t *testing.T) {
if err != nil && err.Error() != test.err {
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
}
if err == nil && test.err != "" {
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
}
if result != test.expected {
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
}
@@ -92,9 +94,11 @@ func TestParsePortRange(t *testing.T) {
if err != nil && err.Error() != test.err {
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
}
if err == nil && test.err != "" {
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
}
if diff := cmp.Diff(result, test.expected); diff != "" {
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
}

View file

@@ -152,6 +152,7 @@ func (m *mapSession) serveLongPoll() {
// This is not my favourite solution, but it kind of works in our eventually consistent world.
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
disconnected := true
// Wait up to 10 seconds for the node to reconnect.
// 10 seconds was arbitrary chosen as a reasonable time to reconnect.
@@ -160,6 +161,7 @@ func (m *mapSession) serveLongPoll() {
disconnected = false
break
}
<-ticker.C
}
@@ -215,8 +217,10 @@ func (m *mapSession) serveLongPoll() {
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session")
return
}
log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher")
m.h.Change(mapReqChange)

View file

@@ -107,9 +107,11 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
Msg("Current primary no longer available")
}
}
if len(nodes) >= 1 {
pr.primaries[prefix] = nodes[0]
changed = true
log.Debug().
Caller().
Str("prefix", prefix.String()).
@@ -126,6 +128,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
Str("prefix", prefix.String()).
Msg("Cleaning up primary route that no longer has available nodes")
delete(pr.primaries, prefix)
changed = true
}
}
@@ -161,14 +164,18 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix)
// If no routes are being set, remove the node from the routes map.
if len(prefixes) == 0 {
wasPresent := false
if _, ok := pr.routes[node]; ok {
delete(pr.routes, node)
wasPresent = true
log.Debug().
Caller().
Uint64("node.id", node.Uint64()).
Msg("Removed node from primary routes (no prefixes)")
}
changed := pr.updatePrimaryLocked()
log.Debug().
Caller().
@@ -254,12 +261,14 @@ func (pr *PrimaryRoutes) stringLocked() string {
ids := types.NodeIDs(xmaps.Keys(pr.routes))
slices.Sort(ids)
for _, id := range ids {
prefixes := pr.routes[id]
fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", "))
}
fmt.Fprintln(&sb, "\n\nCurrent primary routes:")
for route, nodeID := range pr.primaries {
fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID)
}

View file

@@ -130,6 +130,7 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24"))
pr.SetRoutes(2, mp("192.168.2.0/24"))
pr.SetRoutes(1) // Deregister by setting no routes
return pr.SetRoutes(1, mp("192.168.3.0/24"))
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
@@ -153,8 +154,9 @@ func TestPrimaryRoutes(t *testing.T) {
{
name: "multiple-nodes-register-same-route",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
return pr.SetRoutes(3, mp("192.168.1.0/24")) // false
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
@@ -182,7 +184,8 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary
return pr.SetRoutes(1) // true, 2 primary
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
2: {
@@ -393,6 +396,7 @@ func TestPrimaryRoutes(t *testing.T) {
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0"))
pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0"))
return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0"))
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
@@ -413,15 +417,20 @@ func TestPrimaryRoutes(t *testing.T) {
operations: func(pr *PrimaryRoutes) bool {
var wg sync.WaitGroup
wg.Add(2)
var change1, change2 bool
go func() {
defer wg.Done()
change1 = pr.SetRoutes(1, mp("192.168.1.0/24"))
}()
go func() {
defer wg.Done()
change2 = pr.SetRoutes(2, mp("192.168.2.0/24"))
}()
wg.Wait()
return change1 || change2
@@ -449,17 +458,21 @@ func TestPrimaryRoutes(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pr := New()
change := tt.operations(pr)
if change != tt.expectedChange {
t.Errorf("change = %v, want %v", change, tt.expectedChange)
}
comps := append(util.Comparers, cmpopts.EquateEmpty())
if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" {
t.Errorf("routes mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" {
t.Errorf("primaries mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" {
t.Errorf("isPrimary mismatch (-want +got):\n%s", diff)
}

View file

@@ -77,6 +77,7 @@ func (s *State) DebugOverview() string {
ephemeralCount := 0
now := time.Now()
for _, node := range allNodes.All() {
if node.Valid() {
userName := node.Owner().Name()
@@ -103,17 +104,21 @@ func (s *State) DebugOverview() string {
// User statistics
sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users)))
for userName, nodeCount := range userNodeCounts {
sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount))
}
sb.WriteString("\n")
// Policy information
sb.WriteString("Policy:\n")
sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode))
if s.cfg.Policy.Mode == types.PolicyModeFile {
sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path))
}
sb.WriteString("\n")
// DERP information
@@ -123,6 +128,7 @@ func (s *State) DebugOverview() string {
} else {
sb.WriteString("DERP: not configured\n")
}
sb.WriteString("\n")
// Route information
@@ -130,6 +136,7 @@ func (s *State) DebugOverview() string {
if s.primaryRoutes.String() == "" {
routeCount = 0
}
sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount))
sb.WriteString("\n")
@@ -165,10 +172,12 @@ func (s *State) DebugDERPMap() string {
for _, node := range region.Nodes {
sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n",
node.Name, node.HostName, node.DERPPort))
if node.STUNPort != 0 {
sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort))
}
}
sb.WriteString("\n")
}
@@ -319,6 +328,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo {
if s.primaryRoutes.String() == "" {
routeCount = 0
}
info.PrimaryRoutes = routeCount
return info

View file

@@ -20,6 +20,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
// Create NodeStore
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -43,20 +44,26 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
// 6. If DELETE came after UPDATE, the returned node should be invalid
done := make(chan bool, 2)
var updatedNode types.NodeView
var updateOk bool
var (
updatedNode types.NodeView
updateOk bool
)
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
go func() {
updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) {
n.LastSeen = new(time.Now())
})
done <- true
}()
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
go func() {
store.DeleteNode(node.ID)
done <- true
}()
@@ -90,6 +97,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) {
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -147,6 +155,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) {
node := createTestNode(3, 1, "test-user", "test-node-3")
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -203,6 +212,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -213,8 +223,11 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
// 1. UpdateNode (from UpdateNodeFromMapRequest during polling)
// 2. DeleteNode (from handleLogout when client sends logout request)
var updatedNode types.NodeView
var updateOk bool
var (
updatedNode types.NodeView
updateOk bool
)
done := make(chan bool, 2)
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
@@ -222,12 +235,14 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
n.LastSeen = new(time.Now())
})
done <- true
}()
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
go func() {
store.DeleteNode(ephemeralNode.ID)
done <- true
}()
@@ -266,7 +281,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
// 5. UpdateNode and DeleteNode batch together
// 6. UpdateNode returns a valid node (from before delete in batch)
// 7. persistNodeToDB is called with the stale valid node
// 8. Node gets re-inserted into database instead of staying deleted
// 8. Node gets re-inserted into database instead of staying deleted.
func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5")
ephemeralNode.AuthKey = &types.PreAuthKey{
@@ -278,6 +293,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -348,6 +364,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -398,7 +415,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
// 3. UpdateNode and DeleteNode batch together
// 4. UpdateNode returns a valid node (from before delete in batch)
// 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node
// 6. persistNodeToDB must detect the node is deleted and refuse to persist
// 6. persistNodeToDB must detect the node is deleted and refuse to persist.
func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7")
ephemeralNode.AuthKey = &types.PreAuthKey{
@@ -408,6 +425,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
}
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()

View file

@@ -29,6 +29,7 @@ func netInfoFromMapRequest(
Uint64("node.id", nodeID.Uint64()).
Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP).
Msg("using NetInfo from previous Hostinfo in MapRequest")
return currentHostinfo.NetInfo
}

View file

@@ -136,7 +136,7 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) {
})
}
// Simple helper function for tests
// Simple helper function for tests.
func createTestNodeSimple(id types.NodeID) *types.Node {
user := types.User{
Name: "test-user",

View file

@@ -97,6 +97,7 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batc
for _, n := range allNodes {
nodes[n.ID] = *n
}
snap := snapshotFromNodes(nodes, peersFunc)
store := &NodeStore{
@@ -165,11 +166,14 @@ func (s *NodeStore) PutNode(n types.Node) types.NodeView {
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
resultNode := <-work.nodeResult
nodeStoreOperations.WithLabelValues("put").Inc()
return resultNode
@@ -205,11 +209,14 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
resultNode := <-work.nodeResult
nodeStoreOperations.WithLabelValues("update").Inc()
// Return the node and whether it exists (is valid)
@@ -229,7 +236,9 @@ func (s *NodeStore) DeleteNode(id types.NodeID) {
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
@@ -262,8 +271,10 @@ func (s *NodeStore) processWrite() {
if len(batch) != 0 {
s.applyBatch(batch)
}
return
}
batch = append(batch, w)
if len(batch) >= s.batchSize {
s.applyBatch(batch)
@@ -321,6 +332,7 @@ func (s *NodeStore) applyBatch(batch []work) {
w.updateFn(&n)
nodes[w.nodeID] = n
}
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
@@ -349,12 +361,14 @@ func (s *NodeStore) applyBatch(batch []work) {
nodeView := node.View()
for _, w := range workItems {
w.nodeResult <- nodeView
close(w.nodeResult)
}
} else {
// Node was deleted or doesn't exist
for _, w := range workItems {
w.nodeResult <- types.NodeView{} // Send invalid view
close(w.nodeResult)
}
}
@@ -400,6 +414,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
peersByNode: func() map[types.NodeID][]types.NodeView {
peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
defer peersTimer.ObserveDuration()
return peersFunc(allNodes)
}(),
nodesByUser: make(map[types.UserID][]types.NodeView),
@@ -417,6 +432,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
if newSnap.nodesByMachineKey[n.MachineKey] == nil {
newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView)
}
newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView
}
@@ -511,10 +527,12 @@ func (s *NodeStore) DebugString() string {
// User distribution (shows internal UserID tracking, not display owner)
sb.WriteString("Nodes by Internal User ID:\n")
for userID, nodes := range snapshot.nodesByUser {
if len(nodes) > 0 {
userName := "unknown"
taggedCount := 0
if len(nodes) > 0 && nodes[0].Valid() {
userName = nodes[0].User().Name()
// Count tagged nodes (which have UserID set but are owned by "tagged-devices")
@@ -532,23 +550,29 @@ func (s *NodeStore) DebugString() string {
}
}
}
sb.WriteString("\n")
// Peer relationships summary
sb.WriteString("Peer Relationships:\n")
totalPeers := 0
for nodeID, peers := range snapshot.peersByNode {
peerCount := len(peers)
totalPeers += peerCount
if node, exists := snapshot.nodesByID[nodeID]; exists {
sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n",
nodeID, node.Hostname, peerCount))
}
}
if len(snapshot.peersByNode) > 0 {
avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode))
sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers))
}
sb.WriteString("\n")
// Node key index
@@ -591,6 +615,7 @@ func (s *NodeStore) RebuildPeerMaps() {
}
s.writeQueue <- w
<-result
}

View file

@@ -44,6 +44,7 @@ func TestSnapshotFromNodes(t *testing.T) {
nodes := map[types.NodeID]types.Node{
1: createTestNode(1, 1, "user1", "node1"),
}
return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
@@ -192,11 +193,13 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
for _, node := range nodes {
var peers []types.NodeView
for _, n := range nodes {
if n.ID() != node.ID() {
peers = append(peers, n)
}
}
ret[node.ID()] = peers
}
@@ -207,6 +210,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
for _, node := range nodes {
var peers []types.NodeView
nodeIsOdd := node.ID()%2 == 1
for _, n := range nodes {
@@ -221,6 +225,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
peers = append(peers, n)
}
}
ret[node.ID()] = peers
}
@@ -454,10 +459,13 @@ func TestNodeStoreOperations(t *testing.T) {
// Add nodes in sequence
n1 := store.PutNode(createTestNode(1, 1, "user1", "node1"))
assert.True(t, n1.Valid())
n2 := store.PutNode(createTestNode(2, 2, "user2", "node2"))
assert.True(t, n2.Valid())
n3 := store.PutNode(createTestNode(3, 3, "user3", "node3"))
assert.True(t, n3.Valid())
n4 := store.PutNode(createTestNode(4, 4, "user4", "node4"))
assert.True(t, n4.Valid())
@@ -525,16 +533,20 @@ func TestNodeStoreOperations(t *testing.T) {
done2 := make(chan struct{})
done3 := make(chan struct{})
var resultNode1, resultNode2 types.NodeView
var newNode3 types.NodeView
var ok1, ok2 bool
var (
resultNode1, resultNode2 types.NodeView
newNode3 types.NodeView
ok1, ok2 bool
)
// These should all be processed in the same batch
go func() {
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "batch-updated-node1"
n.GivenName = "batch-given-1"
})
close(done1)
}()
@@ -543,12 +555,14 @@ func TestNodeStoreOperations(t *testing.T) {
n.Hostname = "batch-updated-node2"
n.GivenName = "batch-given-2"
})
close(done2)
}()
go func() {
node3 := createTestNode(3, 1, "user1", "node3")
newNode3 = store.PutNode(node3)
close(done3)
}()
@@ -601,20 +615,23 @@ func TestNodeStoreOperations(t *testing.T) {
// This test verifies that when multiple updates to the same node
// are batched together, each returned node reflects ALL changes
// in the batch, not just the individual update's changes.
done1 := make(chan struct{})
done2 := make(chan struct{})
done3 := make(chan struct{})
var resultNode1, resultNode2, resultNode3 types.NodeView
var ok1, ok2, ok3 bool
var (
resultNode1, resultNode2, resultNode3 types.NodeView
ok1, ok2, ok3 bool
)
// These updates all modify node 1 and should be batched together
// The final state should have all three modifications applied
go func() {
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "multi-update-hostname"
})
close(done1)
}()
@@ -622,6 +639,7 @@ func TestNodeStoreOperations(t *testing.T) {
resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) {
n.GivenName = "multi-update-givenname"
})
close(done2)
}()
@@ -629,6 +647,7 @@ func TestNodeStoreOperations(t *testing.T) {
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.Tags = []string{"tag1", "tag2"}
})
close(done3)
}()
@@ -722,14 +741,18 @@ func TestNodeStoreOperations(t *testing.T) {
done2 := make(chan struct{})
done3 := make(chan struct{})
var result1, result2, result3 types.NodeView
var ok1, ok2, ok3 bool
var (
result1, result2, result3 types.NodeView
ok1, ok2, ok3 bool
)
// Start concurrent updates
go func() {
result1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "concurrent-db-hostname"
})
close(done1)
}()
@@ -737,6 +760,7 @@ func TestNodeStoreOperations(t *testing.T) {
result2, ok2 = store.UpdateNode(1, func(n *types.Node) {
n.GivenName = "concurrent-db-given"
})
close(done2)
}()
@@ -744,6 +768,7 @@ func TestNodeStoreOperations(t *testing.T) {
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.Tags = []string{"concurrent-tag"}
})
close(done3)
}()
@@ -827,6 +852,7 @@ func TestNodeStoreOperations(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := tt.setupFunc(t)
store.Start()
defer store.Stop()
@@ -846,10 +872,11 @@ type testStep struct {
// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests ---
// Helper for concurrent test nodes
// Helper for concurrent test nodes.
func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
machineKey := key.NewMachine()
nodeKey := key.NewNode()
return types.Node{
ID: id,
Hostname: hostname,
@@ -862,72 +889,90 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
}
}
// --- Concurrency: concurrent PutNode operations ---
// --- Concurrency: concurrent PutNode operations ---.
func TestNodeStoreConcurrentPutNode(t *testing.T) {
const concurrentOps = 20
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
var wg sync.WaitGroup
results := make(chan bool, concurrentOps)
for i := range concurrentOps {
wg.Add(1)
go func(nodeID int) {
defer wg.Done()
node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node")
resultNode := store.PutNode(node)
results <- resultNode.Valid()
}(i + 1)
}
wg.Wait()
close(results)
successCount := 0
for success := range results {
if success {
successCount++
}
}
require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed")
}
// --- Batching: concurrent ops fit in one batch ---
// --- Batching: concurrent ops fit in one batch ---.
func TestNodeStoreBatchingEfficiency(t *testing.T) {
const batchSize = 10
const ops = 15 // more than batchSize
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
var wg sync.WaitGroup
results := make(chan bool, ops)
for i := range ops {
wg.Add(1)
go func(nodeID int) {
defer wg.Done()
node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node")
resultNode := store.PutNode(node)
results <- resultNode.Valid()
}(i + 1)
}
wg.Wait()
close(results)
successCount := 0
for success := range results {
if success {
successCount++
}
}
require.Equal(t, ops, successCount, "All batch PutNode operations should succeed")
}
// --- Race conditions: many goroutines on same node ---
// --- Race conditions: many goroutines on same node ---.
func TestNodeStoreRaceConditions(t *testing.T) {
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -936,13 +981,18 @@ func TestNodeStoreRaceConditions(t *testing.T) {
resultNode := store.PutNode(node)
require.True(t, resultNode.Valid())
const numGoroutines = 30
const opsPerGoroutine = 10
const (
numGoroutines = 30
opsPerGoroutine = 10
)
var wg sync.WaitGroup
errors := make(chan error, numGoroutines*opsPerGoroutine)
for i := range numGoroutines {
wg.Add(1)
go func(gid int) {
defer wg.Done()
@@ -962,6 +1012,7 @@ func TestNodeStoreRaceConditions(t *testing.T) {
}
case 2:
newNode := createConcurrentTestNode(nodeID, "race-put")
resultNode := store.PutNode(newNode)
if !resultNode.Valid() {
errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j)
@@ -970,23 +1021,28 @@ func TestNodeStoreRaceConditions(t *testing.T) {
}
}(i)
}
wg.Wait()
close(errors)
errorCount := 0
for err := range errors {
t.Error(err)
errorCount++
}
if errorCount > 0 {
t.Fatalf("Race condition test failed with %d errors", errorCount)
}
}
// --- Resource cleanup: goroutine leak detection ---
// --- Resource cleanup: goroutine leak detection ---.
func TestNodeStoreResourceCleanup(t *testing.T) {
// initialGoroutines := runtime.NumGoroutine()
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -1009,10 +1065,12 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
})
retrieved, found := store.GetNode(nodeID)
assert.True(t, found && retrieved.Valid())
if i%10 == 9 {
store.DeleteNode(nodeID)
}
}
runtime.GC()
// Wait for goroutines to settle and check for leaks
@@ -1023,9 +1081,10 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
}, time.Second, 10*time.Millisecond, "goroutines should not leak")
}
// --- Timeout/deadlock: operations complete within reasonable time ---
// --- Timeout/deadlock: operations complete within reasonable time ---.
func TestNodeStoreOperationTimeout(t *testing.T) {
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -1033,36 +1092,47 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
defer cancel()
const ops = 30
var wg sync.WaitGroup
putResults := make([]error, ops)
updateResults := make([]error, ops)
// Launch all PutNode operations concurrently
for i := 1; i <= ops; i++ {
nodeID := types.NodeID(i)
wg.Add(1)
go func(idx int, id types.NodeID) {
defer wg.Done()
startPut := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id)
node := createConcurrentTestNode(id, "timeout-node")
resultNode := store.PutNode(node)
endPut := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut))
if !resultNode.Valid() {
putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id)
}
}(i, nodeID)
}
wg.Wait()
// Launch all UpdateNode operations concurrently
wg = sync.WaitGroup{}
for i := 1; i <= ops; i++ {
nodeID := types.NodeID(i)
wg.Add(1)
go func(idx int, id types.NodeID) {
defer wg.Done()
startUpdate := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id)
resultNode, ok := store.UpdateNode(id, func(n *types.Node) {
@@ -1070,31 +1140,40 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
})
endUpdate := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate))
if !ok || !resultNode.Valid() {
updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id)
}
}(i, nodeID)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
errorCount := 0
for _, err := range putResults {
if err != nil {
t.Error(err)
errorCount++
}
}
for _, err := range updateResults {
if err != nil {
t.Error(err)
errorCount++
}
}
if errorCount == 0 {
t.Log("All concurrent operations completed successfully within timeout")
} else {
@@ -1106,13 +1185,15 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
}
}
// --- Edge case: update non-existent node ---
// --- Edge case: update non-existent node ---.
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
for i := range 10 {
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
nonExistentID := types.NodeID(999 + i)
updateCallCount := 0
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID)
resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) {
updateCallCount++
@@ -1126,9 +1207,10 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
}
}
// --- Allocation benchmark ---
// --- Allocation benchmark ---.
func BenchmarkNodeStoreAllocations(b *testing.B) {
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@@ -1140,6 +1222,7 @@ func BenchmarkNodeStoreAllocations(b *testing.B) {
n.Hostname = "bench-updated"
})
store.GetNode(nodeID)
if i%10 == 9 {
store.DeleteNode(nodeID)
}

View file

@@ -93,6 +93,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
mux := tsql.NewMux()
tsweb.Debugger(mux)
go http.Serve(lst, mux)
logf("TailSQL started")
<-ctx.Done()
logf("TailSQL shutting down...")

View file

@@ -177,6 +177,7 @@ func RegistrationIDFromString(str string) (RegistrationID, error) {
if len(str) != RegistrationIDLength {
return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength)
}
return RegistrationID(str), nil
}

View file

@@ -301,6 +301,7 @@ func validatePKCEMethod(method string) error {
if method != PKCEMethodPlain && method != PKCEMethodS256 {
return errInvalidPKCEMethod
}
return nil
}
@@ -1082,6 +1083,7 @@ func LoadServerConfig() (*Config, error) {
if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
return workers
}
return DefaultBatcherWorkers()
}(),
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
@@ -1117,6 +1119,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
}
s := len(serverDomainParts)
b := len(baseDomainParts)
for i := range baseDomainParts {
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {

View file

@@ -363,6 +363,7 @@ noise:
// Populate a custom config file
configFilePath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
@@ -398,10 +399,12 @@ server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
err = LoadConfig(tmpDir, false)
require.NoError(t, err)
}
@@ -463,6 +466,7 @@ func TestSafeServerURL(t *testing.T) {
return
}
assert.NoError(t, err)
})
}

View file

@@ -156,6 +156,7 @@ func (node *Node) GivenNameHasBeenChanged() bool {
// Strip invalid DNS characters for givenName comparison
normalised := strings.ToLower(node.Hostname)
normalised = invalidDNSRegex.ReplaceAllString(normalised, "")
return node.GivenName == normalised
}
@@ -464,7 +465,7 @@ func (node *Node) IsSubnetRouter() bool {
return len(node.SubnetRoutes()) > 0
}
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes.
func (node *Node) AllApprovedRoutes() []netip.Prefix {
return append(node.SubnetRoutes(), node.ExitRoutes()...)
}
@@ -579,6 +580,7 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
Str("rejected_hostname", hostInfo.Hostname).
Err(err).
Msg("Rejecting invalid hostname update from hostinfo")
return
}
@@ -670,6 +672,7 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
func (nodes Nodes) DebugString() string {
var sb strings.Builder
sb.WriteString("Nodes:\n")
for _, node := range nodes {
sb.WriteString(node.DebugString())
sb.WriteString("\n")

View file

@@ -128,6 +128,7 @@ func (pak *PreAuthKey) Validate() error {
if pak.Expiration != nil {
return *pak.Expiration
}
return time.Time{}
}()).
Time("now", time.Now()).

View file

@@ -40,9 +40,11 @@ var TaggedDevices = User{
func (u Users) String() string {
var sb strings.Builder
sb.WriteString("[ ")
for _, user := range u {
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
}
sb.WriteString(" ]")
return sb.String()
@@ -89,6 +91,7 @@ func (u *User) StringID() string {
if u == nil {
return ""
}
return strconv.FormatUint(uint64(u.ID), 10)
}
@@ -203,6 +206,7 @@ type FlexibleBoolean bool
func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
var val any
err := json.Unmarshal(data, &val)
if err != nil {
return fmt.Errorf("could not unmarshal data: %w", err)
@@ -216,6 +220,7 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
if err != nil {
return fmt.Errorf("could not parse %s as boolean: %w", v, err)
}
*bit = FlexibleBoolean(pv)
default:
@@ -253,9 +258,11 @@ func (c *OIDCClaims) Identifier() string {
if c.Iss == "" && c.Sub == "" {
return ""
}
if c.Iss == "" {
return CleanIdentifier(c.Sub)
}
if c.Sub == "" {
return CleanIdentifier(c.Iss)
}
@@ -340,6 +347,7 @@ func CleanIdentifier(identifier string) string {
cleanParts = append(cleanParts, trimmed)
}
}
if len(cleanParts) == 0 {
return ""
}
@@ -382,6 +390,7 @@ func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) {
if claims.Iss == "" && !strings.HasPrefix(identifier, "/") {
identifier = "/" + identifier
}
u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true}
u.DisplayName = claims.Name
u.ProfilePicURL = claims.ProfilePictureURL

View file

@@ -70,6 +70,7 @@ func TestUnmarshallOIDCClaims(t *testing.T) {
t.Errorf("UnmarshallOIDCClaims() error = %v", err)
return
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff)
}
@@ -190,6 +191,7 @@ func TestOIDCClaimsIdentifier(t *testing.T) {
}
result := claims.Identifier()
assert.Equal(t, tt.expected, result)
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("Identifier() mismatch (-want +got):\n%s", diff)
}
@@ -282,6 +284,7 @@ func TestCleanIdentifier(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result := CleanIdentifier(tt.identifier)
assert.Equal(t, tt.expected, result)
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff)
}
@@ -487,6 +490,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
var user User
user.FromClaim(&got, tt.emailVerifiedRequired)
if diff := cmp.Diff(user, tt.want); diff != "" {
t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff)
}

View file

@@ -90,6 +90,7 @@ func TestNormaliseHostname(t *testing.T) {
t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && got != tt.want {
t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want)
}
@@ -172,6 +173,7 @@ func TestValidateHostname(t *testing.T) {
t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr && tt.errorContains != "" {
if err == nil || !strings.Contains(err.Error(), tt.errorContains) {
t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains)

View file

@@ -15,10 +15,12 @@ func YesNo(msg string) bool {
var resp string
fmt.Scanln(&resp)
resp = strings.ToLower(resp)
switch resp {
case "y", "yes", "sure":
return true
}
return false
}

View file

@@ -86,6 +86,7 @@ func TestYesNo(t *testing.T) {
// Write test input
go func() {
defer w.Close()
w.WriteString(tt.input)
}()
@@ -95,6 +96,7 @@ func TestYesNo(t *testing.T) {
// Restore stdin and stderr
os.Stdin = oldStdin
os.Stderr = oldStderr
stderrW.Close()
// Check the result
@@ -108,6 +110,7 @@ func TestYesNo(t *testing.T) {
stderrR.Close()
expectedPrompt := "Test question [y/n] "
actualPrompt := stderrBuf.String()
if actualPrompt != expectedPrompt {
t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt)
@@ -130,6 +133,7 @@ func TestYesNoPromptMessage(t *testing.T) {
// Write test input
go func() {
defer w.Close()
w.WriteString("n\n")
}()
@@ -140,6 +144,7 @@ func TestYesNoPromptMessage(t *testing.T) {
// Restore stdin and stderr
os.Stdin = oldStdin
os.Stderr = oldStderr
stderrW.Close()
// Check that the custom message was included in the prompt
@@ -148,6 +153,7 @@ func TestYesNoPromptMessage(t *testing.T) {
stderrR.Close()
expectedPrompt := customMessage + " [y/n] "
actualPrompt := stderrBuf.String()
if actualPrompt != expectedPrompt {
t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt)
@@ -186,6 +192,7 @@ func TestYesNoCaseInsensitive(t *testing.T) {
// Write test input
go func() {
defer w.Close()
w.WriteString(tc.input)
}()
@@ -195,6 +202,7 @@ func TestYesNoCaseInsensitive(t *testing.T) {
// Restore stdin and stderr
os.Stdin = oldStdin
os.Stderr = oldStderr
stderrW.Close()
// Drain stderr

View file

@@ -33,6 +33,7 @@ func GenerateRandomStringURLSafe(n int) (string, error) {
b, err := GenerateRandomBytes(n)
uenc := base64.RawURLEncoding.EncodeToString(b)
return uenc[:n], err
}
@@ -99,6 +100,7 @@ func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string {
DstIPs: %v
}
`, rule.SrcIPs, rule.DstPorts))
if index < len(rules)-1 {
sb.WriteString(", ")
}

View file

@@ -30,6 +30,7 @@ func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool {
// It returns an error if not exactly one URL is found.
func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
lines := strings.Split(output, "\n")
var urlStr string
for _, line := range lines {
@@ -38,6 +39,7 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
if urlStr != "" {
return nil, fmt.Errorf("multiple URLs found: %s and %s", urlStr, line)
}
urlStr = line
}
}
@@ -94,6 +96,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
// Parse the header line - handle both 'traceroute' and 'tracert' (Windows)
headerRegex := regexp.MustCompile(`(?i)(?:traceroute|tracing route) to ([^ ]+) (?:\[([^\]]+)\]|\(([^)]+)\))`)
headerMatches := headerRegex.FindStringSubmatch(lines[0])
if len(headerMatches) < 2 {
return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0])
@@ -105,6 +108,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
if ipStr == "" {
ipStr = headerMatches[3]
}
ip, err := netip.ParseAddr(ipStr)
if err != nil {
return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err)
@@ -144,13 +148,17 @@ func ParseTraceroute(output string) (Traceroute, error) {
}
remainder := strings.TrimSpace(matches[2])
var hopHostname string
var hopIP netip.Addr
var latencies []time.Duration
var (
hopHostname string
hopIP netip.Addr
latencies []time.Duration
)
// Check for Windows tracert format which has latencies before hostname
// Format: " 1 <1 ms <1 ms <1 ms router.local [192.168.1.1]"
latencyFirst := false
if strings.Contains(remainder, " ms ") && !strings.HasPrefix(remainder, "*") {
// Check if latencies appear before any hostname/IP
firstSpace := strings.Index(remainder, " ")
@@ -171,12 +179,14 @@ func ParseTraceroute(output string) (Traceroute, error) {
}
// Extract and remove the latency from the beginning
latStr := strings.TrimPrefix(remainder[latMatch[2]:latMatch[3]], "<")
ms, err := strconv.ParseFloat(latStr, 64)
if err == nil {
// Round to nearest microsecond to avoid floating point precision issues
duration := time.Duration(ms * float64(time.Millisecond))
latencies = append(latencies, duration.Round(time.Microsecond))
}
remainder = strings.TrimSpace(remainder[latMatch[1]:])
}
}
@@ -205,6 +215,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
if ip, err := netip.ParseAddr(parts[0]); err == nil {
hopIP = ip
}
remainder = strings.TrimSpace(strings.Join(parts[1:], " "))
}
}
@@ -216,6 +227,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
if len(match) > 1 {
// Remove '<' prefix if present (e.g., "<1 ms")
latStr := strings.TrimPrefix(match[1], "<")
ms, err := strconv.ParseFloat(latStr, 64)
if err == nil {
// Round to nearest microsecond to avoid floating point precision issues
@@ -280,11 +292,13 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri
if key == "" {
return "unknown-node"
}
keyPrefix := key
if len(key) > 8 {
keyPrefix = key[:8]
}
return fmt.Sprintf("node-%s", keyPrefix)
return "node-" + keyPrefix
}
lowercased := strings.ToLower(hostinfo.Hostname)

View file

@@ -180,6 +180,7 @@ Success.`,
if err != nil {
t.Errorf("ParseLoginURLFromCLILogin() error = %v, wantErr %v", err, tt.wantErr)
}
if gotURL.String() != tt.wantURL {
t.Errorf("ParseLoginURLFromCLILogin() = %v, want %v", gotURL, tt.wantURL)
}
@@ -1066,6 +1067,7 @@ func TestEnsureHostname(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.want, "invalid-") {
@@ -1103,9 +1105,11 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "test-node" {
t.Errorf("hostname = %v, want test-node", hi.Hostname)
}
if hi.OS != "linux" {
t.Errorf("OS = %v, want linux", hi.OS)
}
@@ -1147,6 +1151,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "node-nkey1234" {
t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname)
}
@@ -1162,6 +1167,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "unknown-node" {
t.Errorf("hostname = %v, want unknown-node", hi.Hostname)
}
@@ -1179,6 +1185,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "unknown-node" {
t.Errorf("hostname = %v, want unknown-node", hi.Hostname)
}
@@ -1200,18 +1207,23 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if hi.Hostname != "test" {
t.Errorf("hostname = %v, want test", hi.Hostname)
}
if hi.OS != "windows" {
t.Errorf("OS = %v, want windows", hi.OS)
}
if hi.OSVersion != "10.0.19044" {
t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion)
}
if hi.DeviceModel != "test-device" {
t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel)
}
if hi.BackendLogID != "log123" {
t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID)
}
@@ -1229,6 +1241,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
if hi == nil {
t.Error("hostinfo should not be nil")
}
if len(hi.Hostname) != 63 {
t.Errorf("hostname length = %v, want 63", len(hi.Hostname))
}
@@ -1239,6 +1252,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
// For invalid hostnames, we just check the prefix since the random part varies
if strings.HasPrefix(tt.wantHostname, "invalid-") {
@@ -1265,6 +1279,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) {
for i, hostname := range testCases {
t.Run(cmp.Diff("", ""), func(t *testing.T) {
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
result := EnsureHostname(hostinfo, "mkey", "nkey")
if len(result) > 63 {
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))

View file

@@ -35,6 +35,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@@ -46,6 +47,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
// Create an API key using the CLI
var validAPIKey string
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
apiKeyOutput, err := headscale.Execute(
[]string{
@@ -63,7 +65,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
// Get the API endpoint
endpoint := headscale.GetEndpoint()
apiURL := fmt.Sprintf("%s/api/v1/user", endpoint)
apiURL := endpoint + "/api/v1/user"
// Create HTTP client
client := &http.Client{
@@ -81,6 +83,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
@@ -99,6 +102,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
// Should NOT contain user data after "Unauthorized"
// This is the security bypass - if users array is present, auth was bypassed
var jsonCheck map[string]any
jsonErr := json.Unmarshal(body, &jsonCheck)
// If we can unmarshal JSON and it contains "users", that's the bypass
@@ -132,6 +136,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
@@ -165,6 +170,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
@@ -193,10 +199,11 @@ func TestAPIAuthenticationBypass(t *testing.T) {
// Expected: Should return 200 with user data (this is the authorized case)
req, err := http.NewRequest("GET", apiURL, nil)
require.NoError(t, err)
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", validAPIKey))
req.Header.Set("Authorization", "Bearer "+validAPIKey)
resp, err := client.Do(req)
require.NoError(t, err)
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
@@ -208,16 +215,19 @@ func TestAPIAuthenticationBypass(t *testing.T) {
// Should be able to parse as protobuf JSON
var response v1.ListUsersResponse
err = protojson.Unmarshal(body, &response)
assert.NoError(t, err, "Response should be valid protobuf JSON with valid API key")
// Should contain our test users
users := response.GetUsers()
assert.Len(t, users, 3, "Should have 3 users")
userNames := make([]string, len(users))
for i, u := range users {
userNames[i] = u.GetName()
}
assert.Contains(t, userNames, "user1")
assert.Contains(t, userNames, "user2")
assert.Contains(t, userNames, "user3")
@@ -234,6 +244,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@@ -254,10 +265,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
},
)
require.NoError(t, err)
validAPIKey := strings.TrimSpace(apiKeyOutput)
endpoint := headscale.GetEndpoint()
apiURL := fmt.Sprintf("%s/api/v1/user", endpoint)
apiURL := endpoint + "/api/v1/user"
t.Run("Curl_NoAuth", func(t *testing.T) {
// Execute curl from inside the headscale container without auth
@@ -274,16 +286,21 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
// Parse the output
lines := strings.Split(curlOutput, "\n")
var httpCode string
var responseBody string
var (
httpCode string
responseBody string
)
var responseBodySb295 strings.Builder
for _, line := range lines {
if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
httpCode = after
} else {
responseBody += line
responseBodySb295.WriteString(line)
}
}
responseBody += responseBodySb295.String()
// Should return 401
assert.Equal(t, "401", httpCode,
@@ -320,16 +337,21 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
require.NoError(t, err)
lines := strings.Split(curlOutput, "\n")
var httpCode string
var responseBody string
var (
httpCode string
responseBody string
)
var responseBodySb344 strings.Builder
for _, line := range lines {
if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
httpCode = after
} else {
responseBody += line
responseBodySb344.WriteString(line)
}
}
responseBody += responseBodySb344.String()
assert.Equal(t, "401", httpCode)
assert.Contains(t, responseBody, "Unauthorized")
@@ -346,7 +368,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
"curl",
"-s",
"-H",
fmt.Sprintf("Authorization: Bearer %s", validAPIKey),
"Authorization: Bearer " + validAPIKey,
"-w",
"\nHTTP_CODE:%{http_code}",
apiURL,
@@ -355,8 +377,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
require.NoError(t, err)
lines := strings.Split(curlOutput, "\n")
var httpCode string
var responseBody strings.Builder
var (
httpCode string
responseBody strings.Builder
)
for _, line := range lines {
if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
@@ -372,8 +397,10 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
// Should contain user data
var response v1.ListUsersResponse
err = protojson.Unmarshal([]byte(responseBody.String()), &response)
assert.NoError(t, err, "Response should be valid protobuf JSON")
users := response.GetUsers()
assert.Len(t, users, 2, "Should have 2 users")
})
@@ -391,6 +418,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@@ -420,11 +448,12 @@ func TestGRPCAuthenticationBypass(t *testing.T) {
},
)
require.NoError(t, err)
validAPIKey := strings.TrimSpace(apiKeyOutput)
// Get the gRPC endpoint
// For gRPC, we need to use the hostname and port 50443
grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname())
grpcAddress := headscale.GetHostname() + ":50443"
t.Run("gRPC_NoAPIKey", func(t *testing.T) {
// Test 1: Try to use CLI without API key (should fail)
@@ -487,6 +516,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) {
// CLI outputs the users array directly, not wrapped in ListUsersResponse
// Parse as JSON array (CLI uses json.Marshal, not protojson)
var users []*v1.User
err = json.Unmarshal([]byte(output), &users)
assert.NoError(t, err, "Response should be valid JSON array")
assert.Len(t, users, 2, "Should have 2 users")
@@ -495,6 +525,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) {
for i, u := range users {
userNames[i] = u.GetName()
}
assert.Contains(t, userNames, "grpcuser1")
assert.Contains(t, userNames, "grpcuser2")
})
@@ -513,6 +544,7 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@@ -540,9 +572,10 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) {
},
)
require.NoError(t, err)
validAPIKey := strings.TrimSpace(apiKeyOutput)
grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname())
grpcAddress := headscale.GetHostname() + ":50443"
// Create a config file for testing
configWithoutKey := fmt.Sprintf(`
@@ -643,6 +676,7 @@ cli:
// CLI outputs the users array directly, not wrapped in ListUsersResponse
// Parse as JSON array (CLI uses json.Marshal, not protojson)
var users []*v1.User
err = json.Unmarshal([]byte(output), &users)
assert.NoError(t, err, "Response should be valid JSON array")
assert.Len(t, users, 2, "Should have 2 users")
@@ -651,6 +685,7 @@ cli:
for i, u := range users {
userNames[i] = u.GetName()
}
assert.Contains(t, userNames, "cliuser1")
assert.Contains(t, userNames, "cliuser2")
})

View file

@@ -30,6 +30,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@@ -68,18 +69,24 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
// assertClientsState(t, allClients)
clientIPs := make(map[TailscaleClient][]netip.Addr)
for _, client := range allClients {
ips, err := client.IPs()
if err != nil {
t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err)
}
clientIPs[client] = ips
}
var listNodes []*v1.Node
var nodeCountBeforeLogout int
var (
listNodes []*v1.Node
nodeCountBeforeLogout int
)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, len(allClients))
@@ -110,6 +117,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes after logout")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
@@ -147,6 +155,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes after relogin")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
@@ -200,6 +209,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, nodeCountBeforeLogout)
@@ -254,10 +264,14 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
var listNodes []*v1.Node
var nodeCountBeforeLogout int
var (
listNodes []*v1.Node
nodeCountBeforeLogout int
)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, len(allClients))
@@ -300,9 +314,11 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
}
var user1Nodes []*v1.Node
t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
user1Nodes, err = headscale.ListNodes("user1")
assert.NoError(ct, err, "Failed to list nodes for user1 after relogin")
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes))
@@ -322,15 +338,18 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
// When nodes re-authenticate with a different user's pre-auth key, NEW nodes are created
// for the new user. The original nodes remain with the original user.
var user2Nodes []*v1.Node
t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
user2Nodes, err = headscale.ListNodes("user2")
assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin")
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes))
}, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)")
t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat))
for _, client := range allClients {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()
@@ -351,6 +370,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@@ -376,11 +396,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
// assertClientsState(t, allClients)
clientIPs := make(map[TailscaleClient][]netip.Addr)
for _, client := range allClients {
ips, err := client.IPs()
if err != nil {
t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err)
}
clientIPs[client] = ips
}
@@ -394,10 +416,14 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
var listNodes []*v1.Node
var nodeCountBeforeLogout int
var (
listNodes []*v1.Node
nodeCountBeforeLogout int
)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, listNodes, len(allClients))
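For orientation, the blank-line and declaration-grouping changes in this file all follow the same wsl_v5 shape: consecutive single-line var statements become one var (...) block, and a blank line separates the declarations from what follows. A compilable sketch with placeholder names and types (the real tests use []*v1.Node and testify assertions):

package lintexample

// Grouped declaration, as produced by the auto-fix; identifiers here are
// placeholders, not the actual test variables.
var (
	listNodes             []string
	nodeCountBeforeLogout int
)

func countBeforeLogout() int {
	nodeCountBeforeLogout = len(listNodes)

	return nodeCountBeforeLogout
}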

View file

@ -149,6 +149,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -176,6 +177,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
syncCompleteTime := time.Now()
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
loginDuration := time.Since(syncCompleteTime)
t.Logf("Login and sync completed in %v", loginDuration)
@ -207,6 +209,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
// Check each client's status individually to provide better diagnostics
expiredCount := 0
for _, client := range allClients {
status, err := client.Status()
if assert.NoError(ct, err, "failed to get status for client %s", client.Hostname()) {
@ -356,6 +359,7 @@ func TestOIDC024UserCreation(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -413,6 +417,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -470,6 +475,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
oidcMockUser("user1", true),
},
})
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -508,6 +514,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err, "Failed to list users during initial validation")
assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
@ -528,9 +535,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
}, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login")
t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat))
var listNodes []*v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes during initial validation")
assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes))
@ -538,14 +548,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Collect expected node IDs for validation after user1 initial login
expectedNodes := make([]types.NodeID, 0, 1)
var nodeID uint64
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status := ts.MustStatus()
assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status")
var err error
nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64)
assert.NoError(ct, err, "Failed to parse node ID from status")
}, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login")
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
// Validate initial connection state for user1
@ -583,6 +598,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err, "Failed to list users after user2 login")
assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
@ -638,10 +654,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Security validation: Only user2's node should be active after user switch
var activeUser2NodeID types.NodeID
for _, node := range listNodesAfterNewUserLogin {
if node.GetUser().GetId() == 2 { // user2
activeUser2NodeID = types.NodeID(node.GetId())
t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName())
break
}
}
@ -655,6 +673,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Check user2 node is online
if node, exists := nodeStore[activeUser2NodeID]; exists {
assert.NotNil(c, node.IsOnline, "User2 node should have online status")
if node.IsOnline != nil {
assert.True(c, *node.IsOnline, "User2 node should be online after login")
}
@ -747,6 +766,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err, "Failed to list users during final validation")
assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
@ -816,10 +836,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Security validation: Only user1's node should be active after relogin
var activeUser1NodeID types.NodeID
for _, node := range listNodesAfterLoggingBackIn {
if node.GetUser().GetId() == 1 { // user1
activeUser1NodeID = types.NodeID(node.GetId())
t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), node.GetUser().GetName())
break
}
}
@ -833,6 +855,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
// Check user1 node is online
if node, exists := nodeStore[activeUser1NodeID]; exists {
assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin")
if node.IsOnline != nil {
assert.True(c, *node.IsOnline, "User1 node should be online after relogin")
}
@ -907,6 +930,7 @@ func TestOIDCFollowUpUrl(t *testing.T) {
time.Sleep(2 * time.Minute)
var newUrl *url.URL
assert.EventuallyWithT(t, func(c *assert.CollectT) {
st, err := ts.Status()
assert.NoError(c, err)
@ -1103,6 +1127,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
oidcMockUser("user1", true), // Relogin with same user
},
})
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -1142,6 +1167,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err, "Failed to list users during initial validation")
assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
@ -1162,9 +1188,12 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
}, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login")
t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat))
var initialNodes []*v1.Node
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
initialNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes during initial validation")
assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes))
@ -1172,14 +1201,19 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
// Collect expected node IDs for validation after user1 initial login
expectedNodes := make([]types.NodeID, 0, 1)
var nodeID uint64
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status := ts.MustStatus()
assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status")
var err error
nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64)
assert.NoError(ct, err, "Failed to parse node ID from status")
}, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login")
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
// Validate initial connection state for user1
@ -1236,6 +1270,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
listUsers, err := headscale.ListUsers()
assert.NoError(ct, err, "Failed to list users during final validation")
assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers))
wantUsers := []*v1.User{
{
Id: 1,
@ -1256,6 +1291,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
}, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin cycle")
var finalNodes []*v1.Node
t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
finalNodes, err = headscale.ListNodes()
@ -1279,6 +1315,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
// Security validation: user1's node should be active after relogin
activeUser1NodeID := types.NodeID(finalNodes[0].GetId())
t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat))
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodeStore, err := headscale.DebugNodeStore()
@ -1287,6 +1324,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
// Check user1 node is online
if node, exists := nodeStore[activeUser1NodeID]; exists {
assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin")
if node.IsOnline != nil {
assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin")
}
@ -1356,6 +1394,7 @@ func TestOIDCExpiryAfterRestart(t *testing.T) {
// Verify initial expiry is set
var initialExpiry time.Time
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(ct, err)

View file

@ -67,6 +67,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -106,13 +107,16 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) {
validateInitialConnection(t, headscale, expectedNodes)
var listNodes []*v1.Node
t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes after web authentication")
assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes))
}, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication")
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@ -152,6 +156,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) {
t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes after web flow logout")
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
@ -226,6 +231,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -256,13 +262,16 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
validateInitialConnection(t, headscale, expectedNodes)
var listNodes []*v1.Node
t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
listNodes, err = headscale.ListNodes()
assert.NoError(ct, err, "Failed to list nodes after initial web authentication")
assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes))
}, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication")
nodeCountBeforeLogout := len(listNodes)
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
@ -313,9 +322,11 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
t.Logf("all clients logged back in as user1")
var user1Nodes []*v1.Node
t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
user1Nodes, err = headscale.ListNodes("user1")
assert.NoError(ct, err, "Failed to list nodes for user1 after web flow relogin")
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes))
@ -333,15 +344,18 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
// Validate that user2's old nodes still exist in database (but are expired/offline)
// When CLI registration creates new nodes for user1, user2's old nodes remain
var user2Nodes []*v1.Node
t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat))
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
user2Nodes, err = headscale.ListNodes("user2")
assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1")
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes))
}, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1")
t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat))
for _, client := range allClients {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
status, err := client.Status()

View file

@ -25,6 +25,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
// Generate random hostname for the headscale instance
hash, err := util.GenerateRandomStringDNSSafe(6)
require.NoError(t, err)
testName := "derpverify"
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
@ -40,6 +41,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -107,6 +109,7 @@ func DERPVerify(
if err := c.Connect(t.Context()); err != nil {
result = fmt.Errorf("client Connect: %w", err)
}
if m, err := c.Recv(); err != nil {
result = fmt.Errorf("client first Recv: %w", err)
} else if v, ok := m.(derp.ServerInfoMessage); !ok {

View file

@ -34,6 +34,7 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) {
if opts.Labels == nil {
opts.Labels = make(map[string]string)
}
opts.Labels["hi.run-id"] = runID
opts.Labels["hi.test-type"] = testType
}

View file

@ -41,6 +41,7 @@ type buffer struct {
func (b *buffer) Write(p []byte) (n int, err error) {
b.mutex.Lock()
defer b.mutex.Unlock()
return b.store.Write(p)
}
@ -49,6 +50,7 @@ func (b *buffer) Write(p []byte) (n int, err error) {
func (b *buffer) String() string {
b.mutex.Lock()
defer b.mutex.Unlock()
return b.store.String()
}

View file

@ -47,6 +47,7 @@ func SaveLog(
}
var stdout, stderr bytes.Buffer
err = WriteLog(pool, resource, &stdout, &stderr)
if err != nil {
return "", "", err

View file

@ -18,6 +18,7 @@ func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (*dockertest.Ne
if err != nil {
return nil, fmt.Errorf("looking up network names: %w", err)
}
if len(networks) == 0 {
if _, err := pool.CreateNetwork(name); err == nil {
// Create does not give us an updated version of the resource, so we need to
@ -90,6 +91,7 @@ func RandomFreeHostPort() (int, error) {
// CleanUnreferencedNetworks removes networks that are not referenced by any containers.
func CleanUnreferencedNetworks(pool *dockertest.Pool) error {
filter := "name=hs-"
networks, err := pool.NetworksByName(filter)
if err != nil {
return fmt.Errorf("getting networks by filter %q: %w", filter, err)
@ -122,6 +124,7 @@ func CleanImagesInCI(pool *dockertest.Pool) error {
}
removedCount := 0
for _, image := range images {
// Only remove dangling (untagged) images to avoid forcing rebuilds
// Dangling images have no RepoTags or only have "<none>:<none>"

View file

@ -159,10 +159,12 @@ func New(
} else {
hostname = fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
}
tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname)
if err != nil {
return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err)
}
dsic := &DERPServerInContainer{
version: version,
hostname: hostname,
@ -185,6 +187,7 @@ func New(
fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort)
fmt.Fprintf(&cmdArgs, " --stun=true")
fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort)
if dsic.withVerifyClientURL != "" {
fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL)
}
@ -214,11 +217,13 @@ func New(
}
var container *dockertest.Resource
buildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.derper",
ContextDir: dockerContextPath,
BuildArgs: []docker.BuildArg{},
}
switch version {
case "head":
buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
@ -249,6 +254,7 @@ func New(
err,
)
}
log.Printf("Created %s container\n", hostname)
dsic.container = container
@ -259,12 +265,14 @@ func New(
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
}
if len(dsic.tlsCert) != 0 {
err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert)
if err != nil {
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
}
}
if len(dsic.tlsKey) != 0 {
err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey)
if err != nil {

View file

@ -3,6 +3,7 @@ package integration
import (
"bufio"
"bytes"
"errors"
"fmt"
"io"
"net/netip"
@ -47,7 +48,7 @@ const (
TimestampFormatRunID = "20060102-150405"
)
// NodeSystemStatus represents the status of a node across different systems
// NodeSystemStatus represents the status of a node across different systems.
type NodeSystemStatus struct {
Batcher bool
BatcherConnCount int
@ -104,7 +105,7 @@ func requireNoErrLogout(t *testing.T, err error) {
require.NoError(t, err, "failed to log out tailscale nodes")
}
// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes
// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes.
func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID {
t.Helper()
@ -113,8 +114,10 @@ func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.Nod
status := client.MustStatus()
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
require.NoError(t, err)
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
}
return expectedNodes
}
@ -148,15 +151,17 @@ func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNode
}
// requireAllClientsOnline validates that all nodes are online/offline across all headscale systems
// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems
// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems.
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
t.Helper()
startTime := time.Now()
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message)
if expectedOnline {
@ -171,15 +176,17 @@ func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNode
t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message)
}
// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state
// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state.
func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
t.Helper()
var prevReport string
require.EventuallyWithT(t, func(c *assert.CollectT) {
// Get batcher state
debugInfo, err := headscale.DebugBatcher()
assert.NoError(c, err, "Failed to get batcher debug info")
if err != nil {
return
}
@ -187,6 +194,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
// Get map responses
mapResponses, err := headscale.GetAllMapReponses()
assert.NoError(c, err, "Failed to get map responses")
if err != nil {
return
}
@ -194,6 +202,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
@ -264,6 +273,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
if id == nodeID {
continue // Skip self-references
}
expectedPeerMaps++
if online, exists := peerMap[nodeID]; exists && online {
@ -278,6 +288,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
}
}
}
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")
// Update status with map response data
@ -301,10 +312,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
// Verify all systems show nodes in expected state and report failures
allMatch := true
var failureReport strings.Builder
ids := types.NodeIDs(maps.Keys(nodeStatus))
slices.Sort(ids)
for _, nodeID := range ids {
status := nodeStatus[nodeID]
systemsMatch := (status.Batcher == expectedOnline) &&
@ -313,10 +326,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
if !systemsMatch {
allMatch = false
stateStr := "offline"
if expectedOnline {
stateStr = "online"
}
failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat)))
failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline))
failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
@ -331,6 +346,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
t.Logf("Previous report:\n%s", prevReport)
t.Logf("Current report:\n%s", failureReport.String())
t.Logf("Report diff:\n%s", diff)
prevReport = failureReport.String()
}
@ -344,11 +360,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
if expectedOnline {
stateStr = "online"
}
assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr))
}, timeout, 2*time.Second, message)
}
// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components
// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components.
func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) {
t.Helper()
@ -357,18 +374,22 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
require.EventuallyWithT(t, func(c *assert.CollectT) {
debugInfo, err := headscale.DebugBatcher()
assert.NoError(c, err, "Failed to get batcher debug info")
if err != nil {
return
}
allBatcherOffline := true
for _, nodeID := range expectedNodes {
nodeIDStr := fmt.Sprintf("%d", nodeID)
if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected {
allBatcherOffline = false
assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID)
}
}
assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher")
}, 15*time.Second, 1*time.Second, "batcher disconnection validation")
@ -377,20 +398,24 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
require.EventuallyWithT(t, func(c *assert.CollectT) {
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
allNodeStoreOffline := true
for _, nodeID := range expectedNodes {
if node, exists := nodeStore[nodeID]; exists {
isOnline := node.IsOnline != nil && *node.IsOnline
if isOnline {
allNodeStoreOffline = false
assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID)
}
}
}
assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore")
}, 20*time.Second, 1*time.Second, "nodestore offline validation")
@ -399,6 +424,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
require.EventuallyWithT(t, func(c *assert.CollectT) {
mapResponses, err := headscale.GetAllMapReponses()
assert.NoError(c, err, "Failed to get map responses")
if err != nil {
return
}
@ -411,6 +437,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
for nodeID := range onlineMap {
if slices.Contains(expectedNodes, nodeID) {
allMapResponsesOffline = false
assert.False(c, true, "Node %d should not appear in map responses", nodeID)
}
}
@ -421,13 +448,16 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
if id == nodeID {
continue // Skip self-references
}
if online, exists := peerMap[nodeID]; exists && online {
allMapResponsesOffline = false
assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id)
}
}
}
}
assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses")
}, 60*time.Second, 2*time.Second, "map response propagation validation")
@ -447,6 +477,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe
// Get nodestore state
nodeStore, err := headscale.DebugNodeStore()
assert.NoError(c, err, "Failed to get nodestore debug info")
if err != nil {
return
}
@ -461,12 +492,14 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe
for _, nodeID := range expectedNodes {
node, exists := nodeStore[nodeID]
assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID)
if !exists {
continue
}
// Validate that the node has Hostinfo
assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname)
if node.Hostinfo == nil {
t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat))
continue
@ -474,6 +507,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe
// Validate that the node has NetInfo
assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname)
if node.Hostinfo.NetInfo == nil {
t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat))
continue
@ -524,6 +558,7 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
// Returns the total number of successful ping operations.
func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int {
t.Helper()
success := 0
for _, client := range clients {
@ -545,6 +580,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts
// for validating NAT traversal and relay functionality. Returns success count.
func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {
t.Helper()
success := 0
for _, client := range clients {
@ -602,9 +638,12 @@ func assertClientsState(t *testing.T, clients []TailscaleClient) {
for _, client := range clients {
wg.Add(1)
c := client // Avoid loop pointer
go func() {
defer wg.Done()
assertValidStatus(t, c)
assertValidNetcheck(t, c)
assertValidNetmap(t, c)
@ -635,6 +674,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
assert.NoError(c, err, "getting netmap for %q", client.Hostname())
assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
}
@ -653,6 +693,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
if hi := peer.Hostinfo(); hi.Valid() {
assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
@ -681,6 +722,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
// and network map presence. This test is not suitable for ACL/partial connection tests.
func assertValidStatus(t *testing.T, client TailscaleClient) {
t.Helper()
status, err := client.Status(true)
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
@ -738,6 +780,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) {
// which is essential for NAT traversal and connectivity in restricted networks.
func assertValidNetcheck(t *testing.T, client TailscaleClient) {
t.Helper()
report, err := client.Netcheck()
if err != nil {
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
@ -792,6 +835,7 @@ func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool {
t.Helper()
buf := &bytes.Buffer{}
err := client.WriteLogs(buf, buf)
if err != nil {
t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err)
@ -815,6 +859,7 @@ func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error)
scanner := bufio.NewScanner(in)
{
const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB
buff := make([]byte, logBufferInitialSize)
scanner.Buffer(buff, len(buff))
scanner.Split(bufio.ScanLines)
@ -941,17 +986,20 @@ func GetUserByName(headscale ControlServer, username string) (*v1.User, error) {
func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) {
for _, client := range updated {
isOriginal := false
for _, origClient := range original {
if client.Hostname() == origClient.Hostname() {
isOriginal = true
break
}
}
if !isOriginal {
return client, nil
}
}
return nil, fmt.Errorf("no new client found")
return nil, errors.New("no new client found")
}
// AddAndLoginClient adds a new tailscale client to a user and logs it in.
@ -959,7 +1007,7 @@ func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error)
// 1. Creating a new node
// 2. Finding the new node in the client list
// 3. Getting the user to create a preauth key
// 4. Logging in the new node
// 4. Logging in the new node.
func (s *Scenario) AddAndLoginClient(
t *testing.T,
username string,
@ -1037,5 +1085,6 @@ func (s *Scenario) MustAddAndLoginClient(
client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...)
require.NoError(t, err)
return client
}
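This file also picks up two other fix patterns visible above: godot appends a terminating period to doc comments, and a fmt.Errorf call with no formatting verbs becomes errors.New, which is why the errors import is added at the top of the file. A compilable sketch of both, with the client list reduced to plain strings:

package lintexample

import "errors"

// findNewClient returns the first entry present in updated but not in
// original. (godot: the comment sentence now ends with a period.)
func findNewClient(original, updated []string) (string, error) {
	for _, client := range updated {
		isOriginal := false

		for _, orig := range original {
			if client == orig {
				isOriginal = true
				break
			}
		}

		if !isOriginal {
			return client, nil
		}
	}

	// A constant error string needs no formatting, so errors.New replaces fmt.Errorf.
	return "", errors.New("no new client found")
}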

View file

@ -725,12 +725,14 @@ func extractTarToDirectory(tarData []byte, targetDir string) error {
// Find the top-level directory to strip
var topLevelDir string
firstPass := tar.NewReader(bytes.NewReader(tarData))
for {
header, err := firstPass.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar header: %w", err)
}
@ -747,6 +749,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error {
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar header: %w", err)
}
@ -794,6 +797,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error {
outFile.Close()
return fmt.Errorf("failed to copy file contents: %w", err)
}
outFile.Close()
// Set file permissions
@ -844,10 +848,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
// Check if the database file exists and has a schema
dbPath := "/tmp/integration_test_db.sqlite3"
fileInfo, err := t.Execute([]string{"ls", "-la", dbPath})
if err != nil {
return fmt.Errorf("database file does not exist at %s: %w", dbPath, err)
}
log.Printf("Database file info: %s", fileInfo)
// Check if the database has any tables (schema)
@ -872,6 +878,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar header: %w", err)
}
@ -886,6 +893,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
// Extract the first regular file we find
if header.Typeflag == tar.TypeReg {
dbPath := path.Join(savePath, t.hostname+".db")
outFile, err := os.Create(dbPath)
if err != nil {
return fmt.Errorf("failed to create database file: %w", err)
@ -893,6 +901,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
written, err := io.Copy(outFile, tarReader)
outFile.Close()
if err != nil {
return fmt.Errorf("failed to copy database file: %w", err)
}
@ -1059,6 +1068,7 @@ func (t *HeadscaleInContainer) CreateUser(
}
var u v1.User
err = json.Unmarshal([]byte(result), &u)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
@ -1195,6 +1205,7 @@ func (t *HeadscaleInContainer) ListNodes(
users ...string,
) ([]*v1.Node, error) {
var ret []*v1.Node
execUnmarshal := func(command []string) error {
result, _, err := dockertestutil.ExecuteCommand(
t.container,
@ -1206,6 +1217,7 @@ func (t *HeadscaleInContainer) ListNodes(
}
var nodes []*v1.Node
err = json.Unmarshal([]byte(result), &nodes)
if err != nil {
return fmt.Errorf("failed to unmarshal nodes: %w", err)
@ -1245,7 +1257,7 @@ func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error {
"nodes",
"delete",
"--identifier",
fmt.Sprintf("%d", nodeID),
strconv.FormatUint(nodeID, 10),
"--output",
"json",
"--force",
@ -1309,6 +1321,7 @@ func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) {
}
var users []*v1.User
err = json.Unmarshal([]byte(result), &users)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal nodes: %w", err)
@ -1439,6 +1452,7 @@ func (h *HeadscaleInContainer) PID() (int, error) {
if pidInt == 1 {
continue
}
pids = append(pids, pidInt)
}
@ -1494,6 +1508,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
}
var node *v1.Node
err = json.Unmarshal([]byte(result), &node)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err)

View file

@ -28,6 +28,7 @@ func PeerSyncTimeout() time.Duration {
if util.IsCI() {
return 120 * time.Second
}
return 60 * time.Second
}
@ -205,6 +206,7 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type
res := make(map[types.NodeID]map[types.NodeID]bool)
for nid, mrs := range all {
res[nid] = make(map[types.NodeID]bool)
for _, mr := range mrs {
for _, peer := range mr.Peers {
if peer.Online != nil {
@ -225,5 +227,6 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type
}
}
}
return res
}
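The blank line added before each closing return above is the nlreturn fix. A compilable sketch of the same function shape, with util.IsCI swapped for a local stand-in since that package is outside this diff:

package lintexample

import (
	"os"
	"time"
)

// isCI stands in for util.IsCI(), which is not part of this diff.
func isCI() bool { return os.Getenv("CI") != "" }

func peerSyncTimeout() time.Duration {
	if isCI() {
		return 120 * time.Second
	}

	// nlreturn: a blank line now precedes the return that ends the function.
	return 60 * time.Second
}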

View file

@ -48,6 +48,7 @@ func TestEnablingRoutes(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
@ -90,6 +91,7 @@ func TestEnablingRoutes(t *testing.T) {
// Wait for route advertisements to propagate to NodeStore
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(ct, err)
@ -126,6 +128,7 @@ func TestEnablingRoutes(t *testing.T) {
// Wait for route approvals to propagate to NodeStore
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(ct, err)
@ -148,9 +151,11 @@ func TestEnablingRoutes(t *testing.T) {
assert.NotNil(c, peerStatus.PrimaryRoutes)
assert.NotNil(c, peerStatus.AllowedIPs)
if peerStatus.AllowedIPs != nil {
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3)
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
}
}
@ -171,6 +176,7 @@ func TestEnablingRoutes(t *testing.T) {
// Wait for route state changes to propagate to nodes
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
@ -270,6 +276,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
prefp, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err)
pref := *prefp
t.Logf("usernet1 prefix: %s", pref.String())
@ -289,6 +296,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
slices.SortStableFunc(allClients, func(a, b TailscaleClient) int {
statusA := a.MustStatus()
statusB := b.MustStatus()
return cmp.Compare(statusA.Self.ID, statusB.Self.ID)
})
@ -308,6 +316,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf(" - Router 2 (%s): Advertising route %s - will be STANDBY when approved", subRouter2.Hostname(), pref.String())
t.Logf(" - Router 3 (%s): Advertising route %s - will be STANDBY when approved", subRouter3.Hostname(), pref.String())
t.Logf(" Expected: All 3 routers advertise the same route for redundancy, but only one will be primary at a time")
for _, client := range allClients[:3] {
command := []string{
"tailscale",
@ -323,6 +332,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
// Wait for route configuration changes after advertising routes
var nodes []*v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
@ -362,10 +372,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
checkFailureAndPrintRoutes := func(t *testing.T, client TailscaleClient) {
if t.Failed() {
t.Logf("[%s] Test failed at this checkpoint", time.Now().Format(TimestampFormat))
status, err := client.Status()
if err == nil {
printCurrentRouteMap(t, xmaps.Values(status.Peer)...)
}
t.FailNow()
}
}
@ -384,6 +396,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf(" Expected: Router 1 becomes PRIMARY with route %s active", pref.String())
t.Logf(" Expected: Routers 2 & 3 remain with advertised but unapproved routes")
t.Logf(" Expected: Client can access webservice through router 1 only")
_, err = headscale.ApproveRoutes(
MustFindNode(subRouter1.Hostname(), nodes).GetId(),
[]netip.Prefix{pref},
@ -454,10 +467,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := subRouter1.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1")
@ -481,6 +496,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf(" Expected: Router 2 becomes STANDBY (approved but not primary)")
t.Logf(" Expected: Router 1 remains PRIMARY (no flapping - stability preferred)")
t.Logf(" Expected: HA is now active - if router 1 fails, router 2 can take over")
_, err = headscale.ApproveRoutes(
MustFindNode(subRouter2.Hostname(), nodes).GetId(),
[]netip.Prefix{pref},
@ -492,6 +508,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 6)
if len(nodes) >= 3 {
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
@ -567,10 +584,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := subRouter1.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 1 in HA mode")
@ -596,6 +615,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf(" Expected: Router 3 becomes second STANDBY (approved but not primary)")
t.Logf(" Expected: Router 1 remains PRIMARY, Router 2 remains first STANDBY")
t.Logf(" Expected: Full HA configuration with 1 PRIMARY + 2 STANDBY routers")
_, err = headscale.ApproveRoutes(
MustFindNode(subRouter3.Hostname(), nodes).GetId(),
[]netip.Prefix{pref},
@ -670,12 +690,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.NotEmpty(c, ips, "subRouter1 should have IP addresses")
var expectedIP netip.Addr
for _, ip := range ips {
if ip.Is4() {
expectedIP = ip
break
}
}
assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address")
assertTracerouteViaIPWithCollect(c, tr, expectedIP)
@ -752,10 +774,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := subRouter2.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after failover")
@ -823,10 +847,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := subRouter3.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after second failover")
@ -851,6 +877,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf(" Expected: Router 3 remains PRIMARY (stability - no unnecessary failover)")
t.Logf(" Expected: Router 1 becomes STANDBY (ready for HA)")
t.Logf(" Expected: HA is restored with 2 routers available")
err = subRouter1.Up()
require.NoError(t, err)
@ -900,10 +927,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := subRouter3.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 3 after router 1 recovery")
@ -930,6 +959,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf(" Expected: Router 1 (%s) remains first STANDBY", subRouter1.Hostname())
t.Logf(" Expected: Router 2 (%s) becomes second STANDBY", subRouter2.Hostname())
t.Logf(" Expected: Full HA restored with all 3 routers online")
err = subRouter2.Up()
require.NoError(t, err)
@ -980,10 +1010,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := subRouter3.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after full recovery")
@ -1065,10 +1097,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := subRouter1.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1 after route disable")
@ -1151,10 +1185,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := subRouter2.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after second route disable")
@ -1180,6 +1216,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability - no unnecessary flapping)", subRouter2.Hostname())
t.Logf(" Expected: Router 1 (%s) becomes STANDBY (approved but not primary)", subRouter1.Hostname())
t.Logf(" Expected: HA fully restored with Router 2 PRIMARY and Router 1 STANDBY")
r1Node := MustFindNode(subRouter1.Hostname(), nodes)
_, err = headscale.ApproveRoutes(
r1Node.GetId(),
@ -1235,10 +1272,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := subRouter2.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 2 after route re-enable")
@ -1264,6 +1303,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability preferred)", subRouter2.Hostname())
t.Logf(" Expected: Routers 1 & 3 are both STANDBY")
t.Logf(" Expected: Full HA restored with all 3 routers available")
r3Node := MustFindNode(subRouter3.Hostname(), nodes)
_, err = headscale.ApproveRoutes(
r3Node.GetId(),
@ -1313,6 +1353,7 @@ func TestSubnetRouteACL(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
@ -1360,6 +1401,7 @@ func TestSubnetRouteACL(t *testing.T) {
slices.SortStableFunc(allClients, func(a, b TailscaleClient) int {
statusA := a.MustStatus()
statusB := b.MustStatus()
return cmp.Compare(statusA.Self.ID, statusB.Self.ID)
})
@ -1389,15 +1431,20 @@ func TestSubnetRouteACL(t *testing.T) {
// Wait for route advertisements to propagate to the server
var nodes []*v1.Node
require.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
// Find the node that should have the route by checking node IDs
var routeNode *v1.Node
var otherNode *v1.Node
var (
routeNode *v1.Node
otherNode *v1.Node
)
for _, node := range nodes {
nodeIDStr := strconv.FormatUint(node.GetId(), 10)
if _, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute {
@ -1460,6 +1507,7 @@ func TestSubnetRouteACL(t *testing.T) {
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist")
if srs1PeerStatus == nil {
return
}
@ -1570,6 +1618,7 @@ func TestEnablingExitRoutes(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoErrorf(t, err, "failed to create scenario")
defer scenario.ShutdownAssertNoPanics(t)
@ -1591,8 +1640,10 @@ func TestEnablingExitRoutes(t *testing.T) {
requireNoErrSync(t, err)
var nodes []*v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
@ -1650,6 +1701,7 @@ func TestEnablingExitRoutes(t *testing.T) {
peerStatus := status.Peer[peerKey]
assert.NotNil(c, peerStatus.AllowedIPs)
if peerStatus.AllowedIPs != nil {
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4)
assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
@ -1680,6 +1732,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
@ -1710,10 +1763,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
if s.User[s.Self.UserID].LoginName == "user1@test.no" {
user1c = c
}
if s.User[s.Self.UserID].LoginName == "user2@test.no" {
user2c = c
}
}
require.NotNil(t, user1c)
require.NotNil(t, user2c)
@ -1730,6 +1785,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
// Wait for route advertisements to propagate to NodeStore
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, nodes, 2)
@ -1760,6 +1816,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
// Wait for route state changes to propagate to nodes
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
@ -1777,6 +1834,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
if peerStatus.PrimaryRoutes != nil {
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
}
}, 10*time.Second, 500*time.Millisecond, "routes should be visible to client")
@ -1803,10 +1861,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := user2c.Traceroute(webip)
assert.NoError(c, err)
ip, err := user1c.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for user1c") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, 5*time.Second, 200*time.Millisecond, "Verifying traceroute goes through subnet router")
}
@ -1827,6 +1887,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
@ -1854,10 +1915,12 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
if s.User[s.Self.UserID].LoginName == "user1@test.no" {
user1c = c
}
if s.User[s.Self.UserID].LoginName == "user2@test.no" {
user2c = c
}
}
require.NotNil(t, user1c)
require.NotNil(t, user2c)
@ -1874,6 +1937,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
// Wait for route advertisements to propagate to NodeStore
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
var err error
nodes, err = headscale.ListNodes()
assert.NoError(ct, err)
assert.Len(ct, nodes, 2)
@ -1956,6 +2020,7 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node {
return node
}
}
panic("node not found")
}
@ -2239,10 +2304,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
}
scenario, err := NewScenario(tt.spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
var nodes []*v1.Node
opts := []hsic.Option{
hsic.WithTestName("autoapprovemulti"),
hsic.WithEmbeddedDERPServerOnly(),
@ -2298,6 +2365,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// Add the Docker network route to the auto-approvers
// Keep existing auto-approvers (like bigRoute) in place
var approvers policyv2.AutoApprovers
switch {
case strings.HasPrefix(tt.approver, "tag:"):
approvers = append(approvers, tagApprover(tt.approver))
@ -2366,6 +2434,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
} else {
pak, err = scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
}
require.NoError(t, err)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
@ -2404,6 +2473,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
slices.SortStableFunc(allClients, func(a, b TailscaleClient) int {
statusA := a.MustStatus()
statusB := b.MustStatus()
return cmp.Compare(statusA.Self.ID, statusB.Self.ID)
})
@ -2456,11 +2526,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
t.Logf("Client %s sees %d peers", client.Hostname(), len(status.Peers()))
routerPeerFound := false
for _, peerKey := range status.Peers() {
peerStatus := status.Peer[peerKey]
if peerStatus.ID == routerUsernet1ID.StableID() {
routerPeerFound = true
t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v",
peerStatus.HostName,
peerStatus.ID,
@ -2468,9 +2540,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
peerStatus.PrimaryRoutes)
assert.NotNil(c, peerStatus.PrimaryRoutes)
if peerStatus.PrimaryRoutes != nil {
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
} else {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@ -2507,10 +2581,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := routerUsernet1.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through auto-approved router")
@ -2547,9 +2623,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
if peerStatus.ID == routerUsernet1ID.StableID() {
assert.NotNil(c, peerStatus.PrimaryRoutes)
if peerStatus.PrimaryRoutes != nil {
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
} else {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@ -2569,10 +2647,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := routerUsernet1.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, assertTimeout, 200*time.Millisecond, "Verifying traceroute still goes through router after policy change")
@ -2606,6 +2686,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
// Add the route back to the auto approver in the policy, the route should
// now become available again.
var newApprovers policyv2.AutoApprovers
switch {
case strings.HasPrefix(tt.approver, "tag:"):
newApprovers = append(newApprovers, tagApprover(tt.approver))
@ -2639,9 +2720,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
if peerStatus.ID == routerUsernet1ID.StableID() {
assert.NotNil(c, peerStatus.PrimaryRoutes)
if peerStatus.PrimaryRoutes != nil {
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
} else {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@ -2661,10 +2744,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
assert.NoError(c, err)
ip, err := routerUsernet1.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through router after re-approval")
@ -2700,11 +2785,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
if peerStatus.PrimaryRoutes != nil {
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
} else if peerStatus.ID == "2" {
if peerStatus.PrimaryRoutes != nil {
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), subRoute)
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{subRoute})
} else {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@ -2742,9 +2829,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
if peerStatus.ID == routerUsernet1ID.StableID() {
assert.NotNil(c, peerStatus.PrimaryRoutes)
if peerStatus.PrimaryRoutes != nil {
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
} else {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
@ -2782,6 +2871,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
if peerStatus.PrimaryRoutes != nil {
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
}
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
} else if peerStatus.ID == "3" {
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
@ -2816,10 +2906,12 @@ func SortPeerStatus(a, b *ipnstate.PeerStatus) int {
func printCurrentRouteMap(t *testing.T, routers ...*ipnstate.PeerStatus) {
t.Logf("== Current routing map ==")
slices.SortFunc(routers, SortPeerStatus)
for _, router := range routers {
got := filterNonRoutes(router)
t.Logf(" Router %s (%s) is serving:", router.HostName, router.ID)
t.Logf(" AllowedIPs: %v", got)
if router.PrimaryRoutes != nil {
t.Logf(" PrimaryRoutes: %v", router.PrimaryRoutes.AsSlice())
}
@ -2832,6 +2924,7 @@ func filterNonRoutes(status *ipnstate.PeerStatus) []netip.Prefix {
if tsaddr.IsExitRoute(p) {
return true
}
return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
})
}
@ -2883,6 +2976,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
@ -3023,6 +3117,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// List nodes and verify the router has 3 available routes
var err error
nodes, err := headscale.NodesByUser()
assert.NoError(c, err)
assert.Len(c, nodes, 2)
@@ -3058,10 +3153,12 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := nodeClient.Traceroute(webip)
assert.NoError(c, err)
ip, err := routerClient.IPv4()
if !assert.NoError(c, err, "failed to get IPv4 for routerClient") {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, 60*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router")
}

View file

@@ -191,9 +191,11 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
}
var userToNetwork map[string]*dockertest.Network
if spec.Networks != nil || len(spec.Networks) != 0 {
for name, users := range s.spec.Networks {
networkName := testHashPrefix + "-" + name
network, err := s.AddNetwork(networkName)
if err != nil {
return nil, err
@@ -203,6 +205,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
if n2, ok := userToNetwork[user]; ok {
return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name)
}
mak.Set(&userToNetwork, user, network)
}
}
@@ -219,6 +222,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
if err != nil {
return nil, err
}
mak.Set(&s.extraServices, s.prefixedNetworkName(network), append(s.extraServices[s.prefixedNetworkName(network)], svc))
}
}
@@ -230,6 +234,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
if spec.OIDCAccessTTL != 0 {
ttl = spec.OIDCAccessTTL
}
err = s.runMockOIDC(ttl, spec.OIDCUsers)
if err != nil {
return nil, err
@@ -268,6 +273,7 @@ func (s *Scenario) Networks() []*dockertest.Network {
if len(s.networks) == 0 {
panic("Scenario.Networks called with empty network list")
}
return xmaps.Values(s.networks)
}
@@ -337,6 +343,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
for userName, user := range s.users {
for _, client := range user.Clients {
log.Printf("removing client %s in user %s", client.Hostname(), userName)
stdoutPath, stderrPath, err := client.Shutdown()
if err != nil {
log.Printf("failed to tear down client: %s", err)
@@ -353,6 +360,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
}
}
}
s.mu.Unlock()
for _, derp := range s.derpServers {
@@ -373,6 +381,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
if s.mockOIDC.r != nil {
s.mockOIDC.r.Close()
if err := s.mockOIDC.r.Close(); err != nil {
log.Printf("failed to tear down oidc server: %s", err)
}
@@ -552,6 +561,7 @@ func (s *Scenario) CreateTailscaleNode(
s.mu.Lock()
defer s.mu.Unlock()
opts = append(opts,
tsic.WithCACert(cert),
tsic.WithHeadscaleName(hostname),
@@ -591,6 +601,7 @@ func (s *Scenario) CreateTailscaleNodesInUser(
) error {
if user, ok := s.users[userStr]; ok {
var versions []string
for i := range count {
version := requestedVersion
if requestedVersion == "all" {
@@ -749,10 +760,12 @@ func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Durat
for _, client := range user.Clients {
c := client
expectedCount := expectedPeers
user.syncWaitGroup.Go(func() error {
return c.WaitForPeers(expectedCount, timeout, retryInterval)
})
}
if err := user.syncWaitGroup.Wait(); err != nil {
allErrors = append(allErrors, err)
}
@@ -871,6 +884,7 @@ func (s *Scenario) createHeadscaleEnvWithTags(
} else {
key, err = s.CreatePreAuthKey(u.GetId(), true, false)
}
if err != nil {
return err
}
@@ -887,9 +901,11 @@ func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error {
func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error {
log.Printf("running tailscale up for user %s", userStr)
if user, ok := s.users[userStr]; ok {
for _, client := range user.Clients {
tsc := client
user.joinWaitGroup.Go(func() error {
loginURL, err := tsc.LoginWithURL(loginServer)
if err != nil {
@@ -945,6 +961,7 @@ func newDebugJar() (*debugJar, error) {
if err != nil {
return nil, err
}
return &debugJar{
inner: jar,
store: make(map[string]map[string]map[string]*http.Cookie),
@@ -961,20 +978,25 @@ func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) {
if c == nil || c.Name == "" {
continue
}
domain := c.Domain
if domain == "" {
domain = u.Hostname()
}
path := c.Path
if path == "" {
path = "/"
}
if _, ok := j.store[domain]; !ok {
j.store[domain] = make(map[string]map[string]*http.Cookie)
}
if _, ok := j.store[domain][path]; !ok {
j.store[domain][path] = make(map[string]*http.Cookie)
}
j.store[domain][path][c.Name] = copyCookie(c)
}
}
@@ -989,8 +1011,10 @@ func (j *debugJar) Dump(w io.Writer) {
for domain, paths := range j.store {
fmt.Fprintf(w, "Domain: %s\n", domain)
for path, byName := range paths {
fmt.Fprintf(w, " Path: %s\n", path)
for _, c := range byName {
fmt.Fprintf(
w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n",
@@ -1054,7 +1078,9 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f
}
log.Printf("%s logging in with url: %s", hostname, loginURL.String())
ctx := context.Background()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
if err != nil {
return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err)
@@ -1066,6 +1092,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f
return http.ErrUseLastResponse
}
}
defer func() {
hc.CheckRedirect = originalRedirect
}()
@@ -1080,6 +1107,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f
if err != nil {
return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err)
}
body := string(bodyBytes)
var redirectURL *url.URL
@@ -1126,6 +1154,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
if len(keySep) != 2 {
return errParseAuthPage
}
key := keySep[1]
key = strings.SplitN(key, " ", 2)[0]
log.Printf("registering node %s", key)
@@ -1154,6 +1183,7 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
noTls := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
}
resp, err := noTls.RoundTrip(req)
if err != nil {
return nil, err
@@ -1361,6 +1391,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
if err != nil {
log.Fatalf("could not find an open port: %s", err)
}
portNotation := fmt.Sprintf("%d/tcp", port)
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
@@ -1421,6 +1452,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
ipAddr := s.mockOIDC.r.GetIPInNetwork(network)
log.Println("Waiting for headscale mock oidc to be ready for tests")
hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port))
if err := s.pool.Retry(func() error {
@@ -1468,7 +1500,6 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
// log.Fatalf("could not find an open port: %s", err)
// }
// portNotation := fmt.Sprintf("%d/tcp", port)
hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
hostname := "hs-webservice-" + hash

View file

@@ -35,6 +35,7 @@ func TestHeadscale(t *testing.T) {
user := "test-space"
scenario, err := NewScenario(ScenarioSpec{})
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@@ -83,6 +84,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
count := 1
scenario, err := NewScenario(ScenarioSpec{})
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)