This commit is contained in:
Kristoffer Dalby 2026-01-21 16:21:54 +00:00 committed by GitHub
commit 487716f645
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
126 changed files with 3264 additions and 1938 deletions

View file

@ -25,6 +25,7 @@ linters:
- revive
- tagliatelle
- testpackage
- thelper
- varnamelen
- wrapcheck
- wsl

View file

@ -2,7 +2,6 @@ package cli
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
@ -19,6 +18,7 @@ const (
errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined")
errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined")
errMockOidcUsersNotDefined = Error("MOCKOIDC_USERS not defined")
refreshTTL = 60 * time.Minute
)
@ -69,10 +69,11 @@ func mockOIDC() error {
userStr := os.Getenv("MOCKOIDC_USERS")
if userStr == "" {
return errors.New("MOCKOIDC_USERS not defined")
return errMockOidcUsersNotDefined
}
var users []mockoidc.MockUser
err := json.Unmarshal([]byte(userStr), &users)
if err != nil {
return fmt.Errorf("unmarshalling users: %w", err)
@ -133,10 +134,11 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser
ErrorQueue: &mockoidc.ErrorQueue{},
}
mock.AddMiddleware(func(h http.Handler) http.Handler {
_ = mock.AddMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Info().Msgf("Request: %+v", r)
h.ServeHTTP(w, r)
if r.Response != nil {
log.Info().Msgf("Response: %+v", r.Response)
}

View file

@ -72,12 +72,12 @@ func init() {
nodeCmd.AddCommand(deleteNodeCmd)
tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
tagCmd.MarkFlagRequired("identifier")
_ = tagCmd.MarkFlagRequired("identifier")
tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
nodeCmd.AddCommand(tagCmd)
approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
approveRoutesCmd.MarkFlagRequired("identifier")
_ = approveRoutesCmd.MarkFlagRequired("identifier")
approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
nodeCmd.AddCommand(approveRoutesCmd)
@ -233,10 +233,7 @@ var listNodeRoutesCmd = &cobra.Command{
return
}
tableData, err := nodeRoutesToPtables(nodes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
}
tableData := nodeRoutesToPtables(nodes)
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
@ -601,9 +598,7 @@ func nodesToPtables(
return tableData, nil
}
func nodeRoutesToPtables(
nodes []*v1.Node,
) (pterm.TableData, error) {
func nodeRoutesToPtables(nodes []*v1.Node) pterm.TableData {
tableHeader := []string{
"ID",
"Hostname",
@ -627,7 +622,7 @@ func nodeRoutesToPtables(
)
}
return tableData, nil
return tableData
}
var tagCmd = &cobra.Command{

View file

@ -16,6 +16,7 @@ import (
)
const (
//nolint:gosec
bypassFlag = "bypass-grpc-and-access-database-directly"
)
@ -29,13 +30,17 @@ func init() {
if err := setPolicy.MarkFlagRequired("file"); err != nil {
log.Fatal().Err(err).Msg("")
}
setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
policyCmd.AddCommand(setPolicy)
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
if err := checkPolicy.MarkFlagRequired("file"); err != nil {
err := checkPolicy.MarkFlagRequired("file")
if err != nil {
log.Fatal().Err(err).Msg("")
}
policyCmd.AddCommand(checkPolicy)
}
@ -173,6 +178,7 @@ var setPolicy = &cobra.Command{
defer cancel()
defer conn.Close()
//nolint:noinlineerr
if _, err := client.SetPolicy(ctx, request); err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
}

View file

@ -80,6 +80,7 @@ func initConfig() {
Repository: "headscale",
TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
}
res, err := latest.Check(githubTag, versionInfo.Version)
if err == nil && res.Outdated {
//nolint
@ -101,6 +102,7 @@ func isPreReleaseVersion(version string) bool {
return true
}
}
return false
}

View file

@ -23,8 +23,7 @@ var serveCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
app, err := newHeadscaleServerWithConfig()
if err != nil {
var squibbleErr squibble.ValidationError
if errors.As(err, &squibbleErr) {
if squibbleErr, ok := errors.AsType[squibble.ValidationError](err); ok {
fmt.Printf("SQLite schema failed to validate:\n")
fmt.Println(squibbleErr.Diff)
}

View file

@ -14,6 +14,12 @@ import (
"google.golang.org/grpc/status"
)
// Sentinel errors for CLI commands.
var (
ErrNameOrIDRequired = errors.New("--name or --identifier flag is required")
ErrMultipleUsersFoundUseID = errors.New("unable to determine user, query returned multiple users, use ID")
)
func usernameAndIDFlag(cmd *cobra.Command) {
cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)")
cmd.Flags().StringP("name", "n", "", "Username")
@ -23,12 +29,12 @@ func usernameAndIDFlag(cmd *cobra.Command) {
// If both are empty, it will exit the program with an error.
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
username, _ := cmd.Flags().GetString("name")
identifier, _ := cmd.Flags().GetInt64("identifier")
if username == "" && identifier < 0 {
err := errors.New("--name or --identifier flag is required")
ErrorOutput(
err,
"Cannot rename user: "+status.Convert(err).Message(),
ErrNameOrIDRequired,
"Cannot rename user: "+status.Convert(ErrNameOrIDRequired).Message(),
"",
)
}
@ -50,7 +56,7 @@ func init() {
userCmd.AddCommand(renameUserCmd)
usernameAndIDFlag(renameUserCmd)
renameUserCmd.Flags().StringP("new-name", "r", "", "New username")
renameNodeCmd.MarkFlagRequired("new-name")
_ = renameUserCmd.MarkFlagRequired("new-name")
}
var errMissingParameter = errors.New("missing parameters")
@ -94,6 +100,7 @@ var createUserCmd = &cobra.Command{
}
if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
//nolint:noinlineerr
if _, err := url.Parse(pictureURL); err != nil {
ErrorOutput(
err,
@ -148,7 +155,7 @@ var destroyUserCmd = &cobra.Command{
}
if len(users.GetUsers()) != 1 {
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
err := ErrMultipleUsersFoundUseID
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
@ -276,7 +283,7 @@ var renameUserCmd = &cobra.Command{
}
if len(users.GetUsers()) != 1 {
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
err := ErrMultipleUsersFoundUseID
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),

View file

@ -10,11 +10,11 @@ import (
"time"
"github.com/cenkalti/backoff/v5"
cerrdefs "github.com/containerd/errdefs"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/client"
"github.com/docker/docker/errdefs"
)
// cleanupBeforeTest performs cleanup operations before running tests.
@ -25,6 +25,7 @@ func cleanupBeforeTest(ctx context.Context) error {
return fmt.Errorf("failed to clean stale test containers: %w", err)
}
//nolint:noinlineerr
if err := pruneDockerNetworks(ctx); err != nil {
return fmt.Errorf("failed to prune networks: %w", err)
}
@ -55,6 +56,7 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI
// killTestContainers terminates and removes all test containers.
func killTestContainers(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@ -69,8 +71,10 @@ func killTestContainers(ctx context.Context) error {
}
removed := 0
for _, cont := range containers {
shouldRemove := false
for _, name := range cont.Names {
if strings.Contains(name, "headscale-test-suite") ||
strings.Contains(name, "hs-") ||
@ -107,6 +111,7 @@ func killTestContainers(ctx context.Context) error {
// This function filters containers by the hi.run-id label to only affect containers
// belonging to the specified test run, leaving other concurrent test runs untouched.
func killTestContainersByRunID(ctx context.Context, runID string) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@ -149,6 +154,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error {
// This is useful for cleaning up leftover containers from previous crashed or interrupted test runs
// without interfering with currently running concurrent tests.
func cleanupStaleTestContainers(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@ -223,6 +229,7 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container
// pruneDockerNetworks removes unused Docker networks.
func pruneDockerNetworks(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@ -245,6 +252,7 @@ func pruneDockerNetworks(ctx context.Context) error {
// cleanOldImages removes test-related and old dangling Docker images.
func cleanOldImages(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@ -259,8 +267,10 @@ func cleanOldImages(ctx context.Context) error {
}
removed := 0
for _, img := range images {
shouldRemove := false
for _, tag := range img.RepoTags {
if strings.Contains(tag, "hs-") ||
strings.Contains(tag, "headscale-integration") ||
@ -295,6 +305,7 @@ func cleanOldImages(ctx context.Context) error {
// cleanCacheVolume removes the Docker volume used for Go module cache.
func cleanCacheVolume(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@ -302,11 +313,12 @@ func cleanCacheVolume(ctx context.Context) error {
defer cli.Close()
volumeName := "hs-integration-go-cache"
err = cli.VolumeRemove(ctx, volumeName, true)
if err != nil {
if errdefs.IsNotFound(err) {
if cerrdefs.IsNotFound(err) {
fmt.Printf("Go module cache volume not found: %s\n", volumeName)
} else if errdefs.IsConflict(err) {
} else if cerrdefs.IsConflict(err) {
fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName)
} else {
fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err)

View file

@ -14,6 +14,7 @@ import (
"strings"
"time"
cerrdefs "github.com/containerd/errdefs"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/mount"
@ -26,10 +27,21 @@ var (
ErrTestFailed = errors.New("test failed")
ErrUnexpectedContainerWait = errors.New("unexpected end of container wait")
ErrNoDockerContext = errors.New("no docker context found")
ErrMemoryLimitExceeded = errors.New("container exceeded memory limits")
)
// Docker container constants.
const (
containerFinalStateWait = 10 * time.Second
containerStateCheckInterval = 500 * time.Millisecond
dirPermissions = 0o755
)
// runTestContainer executes integration tests in a Docker container.
//
//nolint:gocyclo
func runTestContainer(ctx context.Context, config *RunConfig) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@ -52,6 +64,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
}
const dirPerm = 0o755
//nolint:noinlineerr
if err := os.MkdirAll(absLogsDir, dirPerm); err != nil {
return fmt.Errorf("failed to create logs directory: %w", err)
}
@ -60,7 +73,9 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
if config.Verbose {
log.Printf("Running pre-test cleanup...")
}
if err := cleanupBeforeTest(ctx); err != nil && config.Verbose {
err := cleanupBeforeTest(ctx)
if err != nil && config.Verbose {
log.Printf("Warning: pre-test cleanup failed: %v", err)
}
}
@ -71,6 +86,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
}
imageName := "golang:" + config.GoVersion
//nolint:noinlineerr
if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil {
return fmt.Errorf("failed to ensure image availability: %w", err)
}
@ -84,6 +100,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
log.Printf("Created container: %s", resp.ID)
}
//nolint:noinlineerr
if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
return fmt.Errorf("failed to start container: %w", err)
}
@ -95,13 +112,17 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
// Start stats collection for container resource monitoring (if enabled)
var statsCollector *StatsCollector
if config.Stats {
var err error
//nolint:contextcheck
statsCollector, err = NewStatsCollector()
if err != nil {
if config.Verbose {
log.Printf("Warning: failed to create stats collector: %v", err)
}
statsCollector = nil
}
@ -110,7 +131,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
// Start stats collection immediately - no need for complex retry logic
// The new implementation monitors Docker events and will catch containers as they start
if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
err := statsCollector.StartCollection(ctx, runID, config.Verbose)
if err != nil {
if config.Verbose {
log.Printf("Warning: failed to start stats collection: %v", err)
}
@ -122,11 +144,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
exitCode, err := streamAndWait(ctx, cli, resp.ID)
// Ensure all containers have finished and logs are flushed before extracting artifacts
if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose {
waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose)
if waitErr != nil && config.Verbose {
log.Printf("Warning: failed to wait for container finalization: %v", waitErr)
}
// Extract artifacts from test containers before cleanup
//nolint:noinlineerr
if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose {
log.Printf("Warning: failed to extract artifacts from containers: %v", err)
}
@ -140,12 +164,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
if len(violations) > 0 {
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
log.Printf("=================================")
for _, violation := range violations {
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
}
return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
return fmt.Errorf("%w: %d container(s)", ErrMemoryLimitExceeded, len(violations))
}
}
@ -344,9 +369,10 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
// Wait for all test containers to reach a final state
maxWaitTime := 10 * time.Second
checkInterval := 500 * time.Millisecond
maxWaitTime := containerFinalStateWait
checkInterval := containerStateCheckInterval
timeout := time.After(maxWaitTime)
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
@ -356,6 +382,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
if verbose {
log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction")
}
return nil
case <-ticker.C:
allFinalized := true
@ -366,12 +393,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
if verbose {
log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err)
}
continue
}
// Check if container is in a final state
if !isContainerFinalized(inspect.State) {
allFinalized = false
if verbose {
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
}
@ -384,6 +413,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
if verbose {
log.Printf("All test containers finalized, ready for artifact extraction")
}
return nil
}
}
@ -400,13 +430,16 @@ func isContainerFinalized(state *container.State) bool {
func findProjectRoot(startPath string) string {
current := startPath
for {
//nolint:noinlineerr
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil {
return current
}
parent := filepath.Dir(current)
if parent == current {
return startPath
}
current = parent
}
}
@ -416,6 +449,7 @@ func boolToInt(b bool) int {
if b {
return 1
}
return 0
}
@ -435,6 +469,7 @@ func createDockerClient() (*client.Client, error) {
}
var clientOpts []client.Opt
clientOpts = append(clientOpts, client.WithAPIVersionNegotiation())
if contextInfo != nil {
@ -444,6 +479,7 @@ func createDockerClient() (*client.Client, error) {
if runConfig.Verbose {
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
}
clientOpts = append(clientOpts, client.WithHost(host))
}
}
@ -459,13 +495,15 @@ func createDockerClient() (*client.Client, error) {
// getCurrentDockerContext retrieves the current Docker context information.
func getCurrentDockerContext() (*DockerContext, error) {
cmd := exec.Command("docker", "context", "inspect")
cmd := exec.CommandContext(context.Background(), "docker", "context", "inspect")
output, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("failed to get docker context: %w", err)
}
var contexts []DockerContext
//nolint:noinlineerr
if err := json.Unmarshal(output, &contexts); err != nil {
return nil, fmt.Errorf("failed to parse docker context: %w", err)
}
@ -486,11 +524,12 @@ func getDockerSocketPath() string {
// checkImageAvailableLocally checks if the specified Docker image is available locally.
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
_, err := cli.ImageInspect(ctx, imageName)
if err != nil {
if client.IsErrNotFound(err) {
if cerrdefs.IsNotFound(err) {
return false, nil
}
return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
}
@ -509,6 +548,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
if verbose {
log.Printf("Image %s is available locally", imageName)
}
return nil
}
@ -533,6 +573,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
if err != nil {
return fmt.Errorf("failed to read pull output: %w", err)
}
log.Printf("Image %s pulled successfully", imageName)
}
@ -547,9 +588,11 @@ func listControlFiles(logsDir string) {
return
}
var logFiles []string
var dataFiles []string
var dataDirs []string
var (
logFiles []string
dataFiles []string
dataDirs []string
)
for _, entry := range entries {
name := entry.Name()
@ -578,6 +621,7 @@ func listControlFiles(logsDir string) {
if len(logFiles) > 0 {
log.Printf("Headscale logs:")
for _, file := range logFiles {
log.Printf(" %s", file)
}
@ -585,9 +629,11 @@ func listControlFiles(logsDir string) {
if len(dataFiles) > 0 || len(dataDirs) > 0 {
log.Printf("Headscale data:")
for _, file := range dataFiles {
log.Printf(" %s", file)
}
for _, dir := range dataDirs {
log.Printf(" %s/", dir)
}
@ -596,6 +642,7 @@ func listControlFiles(logsDir string) {
// extractArtifactsFromContainers collects container logs and files from the specific test run.
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@ -612,9 +659,11 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose)
extractedCount := 0
for _, cont := range currentTestContainers {
// Extract container logs and tar files
if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil {
err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose)
if err != nil {
if verbose {
log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err)
}
@ -622,6 +671,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
if verbose {
log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12])
}
extractedCount++
}
}
@ -645,11 +695,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
// Find the test container to get its run ID label
var runID string
for _, cont := range containers {
if cont.ID == testContainerID {
if cont.Labels != nil {
runID = cont.Labels["hi.run-id"]
}
break
}
}
@ -690,18 +742,21 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
// extractContainerArtifacts saves logs and tar files from a container.
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
// Ensure the logs directory exists
if err := os.MkdirAll(logsDir, 0o755); err != nil {
err := os.MkdirAll(logsDir, dirPermissions)
if err != nil {
return fmt.Errorf("failed to create logs directory: %w", err)
}
// Extract container logs
//nolint:noinlineerr
if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
return fmt.Errorf("failed to extract logs: %w", err)
}
// Extract tar files for headscale containers only
if strings.HasPrefix(containerName, "hs-") {
if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose)
if err != nil {
if verbose {
log.Printf("Warning: failed to extract files from %s: %v", containerName, err)
}
@ -741,11 +796,13 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
}
// Write stdout logs
//nolint:gosec,mnd,noinlineerr
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stdout log: %w", err)
}
// Write stderr logs
//nolint:gosec,mnd,noinlineerr
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stderr log: %w", err)
}

View file

@ -38,12 +38,15 @@ func runDoctorCheck(ctx context.Context) error {
}
// Check 3: Go installation
//nolint:contextcheck
results = append(results, checkGoInstallation())
// Check 4: Git repository
//nolint:contextcheck
results = append(results, checkGitRepository())
// Check 5: Required files
//nolint:contextcheck
results = append(results, checkRequiredFiles())
// Display results
@ -86,6 +89,7 @@ func checkDockerBinary() DoctorResult {
// checkDockerDaemon verifies Docker daemon is running and accessible.
func checkDockerDaemon(ctx context.Context) DoctorResult {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return DoctorResult{
@ -125,6 +129,7 @@ func checkDockerDaemon(ctx context.Context) DoctorResult {
// checkDockerContext verifies Docker context configuration.
func checkDockerContext(_ context.Context) DoctorResult {
//nolint:contextcheck
contextInfo, err := getCurrentDockerContext()
if err != nil {
return DoctorResult{
@ -155,6 +160,7 @@ func checkDockerContext(_ context.Context) DoctorResult {
// checkDockerSocket verifies Docker socket accessibility.
func checkDockerSocket(ctx context.Context) DoctorResult {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return DoctorResult{
@ -192,6 +198,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
func checkGolangImage(ctx context.Context) DoctorResult {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return DoctorResult{
@ -265,7 +272,8 @@ func checkGoInstallation() DoctorResult {
}
}
cmd := exec.Command("go", "version")
cmd := exec.CommandContext(context.Background(), "go", "version")
output, err := cmd.Output()
if err != nil {
return DoctorResult{
@ -286,7 +294,8 @@ func checkGoInstallation() DoctorResult {
// checkGitRepository verifies we're in a git repository.
func checkGitRepository() DoctorResult {
cmd := exec.Command("git", "rev-parse", "--git-dir")
cmd := exec.CommandContext(context.Background(), "git", "rev-parse", "--git-dir")
err := cmd.Run()
if err != nil {
return DoctorResult{
@ -316,9 +325,12 @@ func checkRequiredFiles() DoctorResult {
}
var missingFiles []string
for _, file := range requiredFiles {
cmd := exec.Command("test", "-e", file)
if err := cmd.Run(); err != nil {
cmd := exec.CommandContext(context.Background(), "test", "-e", file)
err := cmd.Run()
if err != nil {
missingFiles = append(missingFiles, file)
}
}
@ -350,6 +362,7 @@ func displayDoctorResults(results []DoctorResult) {
for _, result := range results {
var icon string
switch result.Status {
case "PASS":
icon = "✅"

View file

@ -79,13 +79,18 @@ func main() {
}
func cleanAll(ctx context.Context) error {
if err := killTestContainers(ctx); err != nil {
err := killTestContainers(ctx)
if err != nil {
return err
}
if err := pruneDockerNetworks(ctx); err != nil {
err = pruneDockerNetworks(ctx)
if err != nil {
return err
}
if err := cleanOldImages(ctx); err != nil {
err = cleanOldImages(ctx)
if err != nil {
return err
}

View file

@ -48,7 +48,9 @@ func runIntegrationTest(env *command.Env) error {
if runConfig.Verbose {
log.Printf("Running pre-flight system checks...")
}
if err := runDoctorCheck(env.Context()); err != nil {
err := runDoctorCheck(env.Context())
if err != nil {
return fmt.Errorf("pre-flight checks failed: %w", err)
}
@ -66,8 +68,10 @@ func runIntegrationTest(env *command.Env) error {
func detectGoVersion() string {
goModPath := filepath.Join("..", "..", "go.mod")
//nolint:noinlineerr
if _, err := os.Stat("go.mod"); err == nil {
goModPath = "go.mod"
//nolint:noinlineerr
} else if _, err := os.Stat("../../go.mod"); err == nil {
goModPath = "../../go.mod"
}
@ -94,8 +98,10 @@ func detectGoVersion() string {
// splitLines splits a string into lines without using strings.Split.
func splitLines(s string) []string {
var lines []string
var current string
var (
lines []string
current string
)
for _, char := range s {
if char == '\n' {

View file

@ -1,23 +1,32 @@
package main
import (
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"sort"
"slices"
"strings"
"sync"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client"
)
// ErrStatsCollectionAlreadyStarted is returned when stats collection is already running.
var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started")
// Stats calculation constants.
const (
bytesPerKB = 1024
percentageMultiplier = 100.0
)
// ContainerStats represents statistics for a single container.
type ContainerStats struct {
ContainerID string
@ -63,17 +72,19 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
defer sc.mutex.Unlock()
if sc.collectionStarted {
return errors.New("stats collection already started")
return ErrStatsCollectionAlreadyStarted
}
sc.collectionStarted = true
// Start monitoring existing containers
sc.wg.Add(1)
go sc.monitorExistingContainers(ctx, runID, verbose)
// Start Docker events monitoring for new containers
sc.wg.Add(1)
go sc.monitorDockerEvents(ctx, runID, verbose)
if verbose {
@ -87,10 +98,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
func (sc *StatsCollector) StopCollection() {
// Check if already stopped without holding lock
sc.mutex.RLock()
if !sc.collectionStarted {
sc.mutex.RUnlock()
return
}
sc.mutex.RUnlock()
// Signal stop to all goroutines
@ -114,6 +127,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s
if verbose {
log.Printf("Failed to list existing containers: %v", err)
}
return
}
@ -147,13 +161,13 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
case event := <-events:
if event.Type == "container" && event.Action == "start" {
// Get container details
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
containerInfo, err := sc.client.ContainerInspect(ctx, event.Actor.ID)
if err != nil {
continue
}
// Convert to types.Container format for consistency
cont := types.Container{
// Convert to container.Summary format for consistency
cont := container.Summary{
ID: containerInfo.ID,
Names: []string{containerInfo.Name},
Labels: containerInfo.Config.Labels,
@ -167,13 +181,14 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
if verbose {
log.Printf("Error in Docker events stream: %v", err)
}
return
}
}
}
// shouldMonitorContainer determines if a container should be monitored.
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
func (sc *StatsCollector) shouldMonitorContainer(cont container.Summary, runID string) bool {
// Check if it has the correct run ID label
if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
return false
@ -213,6 +228,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI
}
sc.wg.Add(1)
go sc.collectStatsForContainer(ctx, containerID, verbose)
}
@ -226,12 +242,14 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
if verbose {
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
}
return
}
defer statsResponse.Body.Close()
decoder := json.NewDecoder(statsResponse.Body)
var prevStats *container.Stats
var prevStats *container.StatsResponse
for {
select {
@ -240,12 +258,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
case <-ctx.Done():
return
default:
var stats container.Stats
if err := decoder.Decode(&stats); err != nil {
var stats container.StatsResponse
err := decoder.Decode(&stats)
if err != nil {
// EOF is expected when container stops or stream ends
if err.Error() != "EOF" && verbose {
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
}
return
}
@ -256,13 +277,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
}
// Calculate memory usage in MB
memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024)
memoryMB := float64(stats.MemoryStats.Usage) / (bytesPerKB * bytesPerKB)
// Store the sample (skip first sample since CPU calculation needs previous stats)
if prevStats != nil {
// Get container stats reference without holding the main mutex
var containerStats *ContainerStats
var exists bool
var (
containerStats *ContainerStats
exists bool
)
sc.mutex.RLock()
containerStats, exists = sc.containers[containerID]
@ -286,7 +309,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
}
// calculateCPUPercent calculates CPU usage percentage from Docker stats.
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
func calculateCPUPercent(prevStats, stats *container.StatsResponse) float64 {
// CPU calculation based on Docker's implementation
cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
@ -299,7 +322,7 @@ func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
numCPUs = 1.0
}
return (cpuDelta / systemDelta) * numCPUs * 100.0
return (cpuDelta / systemDelta) * numCPUs * percentageMultiplier
}
return 0.0
@ -331,10 +354,12 @@ type StatsSummary struct {
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
// Take snapshot of container references without holding main lock long
sc.mutex.RLock()
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
for _, containerStats := range sc.containers {
containerRefs = append(containerRefs, containerStats)
}
sc.mutex.RUnlock()
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
@ -371,8 +396,8 @@ func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
}
// Sort by container name for consistent output
sort.Slice(summaries, func(i, j int) bool {
return summaries[i].ContainerName < summaries[j].ContainerName
slices.SortFunc(summaries, func(a, b ContainerStatsSummary) int {
return cmp.Compare(a.ContainerName, b.ContainerName)
})
return summaries
@ -384,23 +409,25 @@ func calculateStatsSummary(values []float64) StatsSummary {
return StatsSummary{}
}
min := values[0]
max := values[0]
minVal := values[0]
maxVal := values[0]
sum := 0.0
for _, value := range values {
if value < min {
min = value
if value < minVal {
minVal = value
}
if value > max {
max = value
if value > maxVal {
maxVal = value
}
sum += value
}
return StatsSummary{
Min: min,
Max: max,
Min: minVal,
Max: maxVal,
Average: sum / float64(len(values)),
}
}
@ -434,6 +461,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo
}
summaries := sc.GetSummary()
var violations []MemoryViolation
for _, summary := range summaries {

View file

@ -2,6 +2,7 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"os"
@ -11,6 +12,8 @@ import (
"github.com/juanfont/headscale/integration/integrationutil"
)
var errDirectoryRequired = errors.New("directory is required")
type MapConfig struct {
Directory string `flag:"directory,Directory to read map responses from"`
}
@ -40,7 +43,7 @@ func main() {
// runIntegrationTest executes the integration test workflow.
func runOnline(env *command.Env) error {
if mapConfig.Directory == "" {
return fmt.Errorf("directory is required")
return errDirectoryRequired
}
resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory)
@ -57,5 +60,6 @@ func runOnline(env *command.Env) error {
os.Stderr.Write(out)
os.Stderr.Write([]byte("\n"))
return nil
}

8
flake.lock generated
View file

@ -20,16 +20,16 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1766840161,
"narHash": "sha256-Ss/LHpJJsng8vz1Pe33RSGIWUOcqM1fjrehjUkdrWio=",
"lastModified": 1769011238,
"narHash": "sha256-WPiOcgZv7GQ/AVd9giOrlZjzXHwBNM4yQ+JzLrgI3Xk=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "3edc4a30ed3903fdf6f90c837f961fa6b49582d1",
"rev": "a895ec2c048eba3bceab06d5dfee5026a6b1c875",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixpkgs-unstable",
"ref": "master",
"repo": "nixpkgs",
"type": "github"
}

View file

@ -2,7 +2,8 @@
description = "headscale - Open Source Tailscale Control server";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
# TODO: Move back to nixpkgs-unstable once Go 1.26 is available there
nixpkgs.url = "github:NixOS/nixpkgs/master";
flake-utils.url = "github:numtide/flake-utils";
};
@ -23,11 +24,11 @@
default = headscale;
};
overlay = _: prev:
overlays.default = _: prev:
let
pkgs = nixpkgs.legacyPackages.${prev.system};
buildGo = pkgs.buildGo125Module;
vendorHash = "sha256-escboufgbk+lEitw48eWEIltXbaCPdysb/g4YR+extg=";
buildGo = pkgs.buildGo126Module;
vendorHash = "sha256-hL9vHunaxodGt3g/CIVirXy4OjZKTI3XwbVPPRb34OY=";
in
{
headscale = buildGo {
@ -129,10 +130,10 @@
(system:
let
pkgs = import nixpkgs {
overlays = [ self.overlay ];
overlays = [ self.overlays.default ];
inherit system;
};
buildDeps = with pkgs; [ git go_1_25 gnumake ];
buildDeps = with pkgs; [ git go_1_26 gnumake ];
devDeps = with pkgs;
buildDeps
++ [
@ -167,7 +168,7 @@
clang-tools # clang-format
protobuf-language-server
]
++ lib.optional pkgs.stdenv.isLinux [ traceroute ];
++ lib.optional pkgs.stdenv.hostPlatform.isLinux [ traceroute ];
# Add entry to build a docker image with headscale
# caveat: only works on Linux
@ -184,7 +185,7 @@
in
rec {
# `nix develop`
devShell = pkgs.mkShell {
devShells.default = pkgs.mkShell {
buildInputs =
devDeps
++ [
@ -219,8 +220,8 @@
packages = with pkgs; {
inherit headscale;
inherit headscale-docker;
default = headscale;
};
defaultPackage = pkgs.headscale;
# `nix run`
apps.headscale = flake-utils.lib.mkApp {

2
go.mod
View file

@ -1,6 +1,6 @@
module github.com/juanfont/headscale
go 1.25
go 1.26rc2
require (
github.com/arl/statsviz v0.7.2

View file

@ -67,6 +67,9 @@ var (
)
)
// oidcProviderInitTimeout is the timeout for initializing the OIDC provider.
const oidcProviderInitTimeout = 30 * time.Second
var (
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
@ -142,6 +145,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
if !ok {
log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
return
}
@ -157,10 +161,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
app.ephemeralGC = ephemeralGC
var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL)
if cfg.OIDC.Issuer != "" {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), oidcProviderInitTimeout)
defer cancel()
oidcProvider, err := NewAuthProviderOIDC(
ctx,
&app,
@ -177,6 +183,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
authProvider = oidcProvider
}
}
app.authProvider = authProvider
if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
@ -251,9 +258,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
lastExpiryCheck := time.Unix(0, 0)
derpTickerChan := make(<-chan time.Time)
if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
defer derpTicker.Stop()
derpTickerChan = derpTicker.C
}
@ -271,8 +280,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
return
case <-expireTicker.C:
var expiredNodeChanges []change.Change
var changed bool
var (
expiredNodeChanges []change.Change
changed bool
)
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
@ -287,11 +298,14 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
case <-derpTickerChan:
log.Info().Msg("Fetching DERPMap updates")
//nolint:contextcheck
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) {
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
if err != nil {
return nil, err
}
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
region, _ := h.DERPServer.GenerateRegion()
derpMap.Regions[region.RegionID] = &region
@ -303,6 +317,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
continue
}
h.state.SetDERPMap(derpMap)
h.Change(change.DERPMap())
@ -311,6 +326,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
if !ok {
continue
}
h.cfg.TailcfgDNSConfig.ExtraRecords = records
h.Change(change.ExtraRecords())
@ -390,6 +406,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
writeUnauthorized := func(statusCode int) {
writer.WriteHeader(statusCode)
//nolint:noinlineerr
if _, err := writer.Write([]byte("Unauthorized")); err != nil {
log.Error().Err(err).Msg("writing HTTP response failed")
}
@ -486,6 +504,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
var err error
capver.CanOldCodeBeCleanedUp()
if profilingEnabled {
@ -512,6 +531,7 @@ func (h *Headscale) Serve() error {
Msg("Clients with a lower minimum version will be rejected")
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
h.mapBatcher.Start()
defer h.mapBatcher.Close()
@ -545,6 +565,7 @@ func (h *Headscale) Serve() error {
// around between restarts, they will reconnect and the GC will
// be cancelled.
go h.ephemeralGC.Start()
ephmNodes := h.state.ListEphemeralNodes()
for _, node := range ephmNodes.All() {
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
@ -555,7 +576,9 @@ func (h *Headscale) Serve() error {
if err != nil {
return fmt.Errorf("setting up extrarecord manager: %w", err)
}
h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
go h.extraRecordMan.Run()
defer h.extraRecordMan.Close()
}
@ -564,6 +587,7 @@ func (h *Headscale) Serve() error {
// records updates
scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
defer scheduleCancel()
go h.scheduledTasks(scheduleCtx)
if zl.GlobalLevel() == zl.TraceLevel {
@ -751,7 +775,6 @@ func (h *Headscale) Serve() error {
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
}
var tailsqlContext context.Context
if tailsqlEnabled {
if h.cfg.Database.Type != types.DatabaseSqlite {
@ -863,6 +886,8 @@ func (h *Headscale) Serve() error {
// Close state connections
info("closing state and database")
//nolint:contextcheck
err = h.state.Close()
if err != nil {
log.Error().Err(err).Msg("failed to close state")

View file

@ -16,12 +16,11 @@ import (
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(types.RegistrationID) string
RegisterHandler(w http.ResponseWriter, r *http.Request)
AuthURL(regID types.RegistrationID) string
}
func (h *Headscale) handleRegister(
@ -52,6 +51,7 @@ func (h *Headscale) handleRegister(
if err != nil {
return nil, fmt.Errorf("handling logout: %w", err)
}
if resp != nil {
return resp, nil
}
@ -113,8 +113,7 @@ func (h *Headscale) handleRegister(
resp, err := h.handleRegisterWithAuthKey(req, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
if errors.As(err, &httpErr) {
if httpErr, ok := errors.AsType[HTTPError](err); ok {
return nil, httpErr
}
@ -133,7 +132,7 @@ func (h *Headscale) handleRegister(
}
// handleLogout checks if the [tailcfg.RegisterRequest] is a
// logout attempt from a node. If the node is not attempting to
// logout attempt from a node. If the node is not attempting to.
func (h *Headscale) handleLogout(
node types.NodeView,
req tailcfg.RegisterRequest,
@ -160,6 +159,7 @@ func (h *Headscale) handleLogout(
Interface("reg.req", req).
Bool("unexpected", true).
Msg("Node key expired, forcing re-authentication")
return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
@ -279,6 +279,7 @@ func (h *Headscale) waitForFollowup(
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(node.View()), nil
}
}
@ -316,7 +317,7 @@ func (h *Headscale) reqToNewRegisterResponse(
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: ptr.To(time.Now()),
LastSeen: new(time.Now()),
},
)
@ -344,8 +345,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
if errors.Is(err, gorm.ErrRecordNotFound) {
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
}
var perr types.PAKError
if errors.As(err, &perr) {
if perr, ok := errors.AsType[types.PAKError](err); ok {
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
}
@ -355,7 +356,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
// If node is not valid, it means an ephemeral node was deleted during logout
if !node.Valid() {
h.Change(changed)
return nil, nil
return nil, nil //nolint:nilnil
}
// This is a bit of a back and forth, but we have a bit of a chicken and egg
@ -435,6 +436,7 @@ func (h *Headscale) handleRegisterInteractive(
Str("generated.hostname", hostname).
Msg("Received registration request with empty hostname, generated default")
}
hostinfo.Hostname = hostname
nodeToRegister := types.NewRegisterNode(
@ -443,7 +445,7 @@ func (h *Headscale) handleRegisterInteractive(
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: ptr.To(time.Now()),
LastSeen: new(time.Now()),
},
)

View file

@ -2,6 +2,7 @@ package hscontrol
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
@ -16,14 +17,20 @@ import (
"tailscale.com/types/key"
)
// Interactive step type constants
// Test sentinel errors.
var (
errNodeNotFoundAfterSetup = errors.New("node not found after setup")
errInvalidAuthURLFormat = errors.New("invalid AuthURL format")
)
// Interactive step type constants.
const (
stepTypeInitialRequest = "initial_request"
stepTypeAuthCompletion = "auth_completion"
stepTypeFollowupRequest = "followup_request"
)
// interactiveStep defines a step in the interactive authentication workflow
// interactiveStep defines a step in the interactive authentication workflow.
type interactiveStep struct {
stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest
expectAuthURL bool
@ -31,6 +38,7 @@ type interactiveStep struct {
callAuthPath bool // Real call to HandleNodeFromAuthPath, not mocked
}
//nolint:gocyclo
func TestAuthenticationFlows(t *testing.T) {
// Shared test keys for consistent behavior across test cases
machineKey1 := key.NewMachine()
@ -69,12 +77,15 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "preauth_key_valid_new_node",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
t.Helper()
user := app.state.CreateUserForTest("preauth-user")
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -89,7 +100,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -111,6 +122,8 @@ func TestAuthenticationFlows(t *testing.T) {
{
name: "preauth_key_reusable_multiple_nodes",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
t.Helper()
user := app.state.CreateUserForTest("reusable-user")
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
@ -129,6 +142,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public())
if err != nil {
return "", err
@ -154,7 +168,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey2.Public() },
machineKey: machineKey2.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -163,6 +177,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify both nodes exist
node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public())
assert.True(t, found1)
assert.True(t, found2)
assert.Equal(t, "reusable-node-1", node1.Hostname())
@ -196,6 +211,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public())
if err != nil {
return "", err
@ -221,12 +237,13 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey2.Public() },
machineKey: machineKey2.Public,
wantError: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
// First node should exist, second should not
_, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
_, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public())
assert.True(t, found1)
assert.False(t, found2)
},
@ -254,7 +271,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantError: true,
},
@ -272,6 +289,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -286,7 +304,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -323,7 +341,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -352,7 +370,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -391,6 +409,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -400,8 +419,10 @@ func TestAuthenticationFlows(t *testing.T) {
// Wait for node to be available in NodeStore with debug info
var attemptCount int
require.EventuallyWithT(t, func(c *assert.CollectT) {
attemptCount++
_, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
if assert.True(c, found, "node should be available in NodeStore") {
t.Logf("Node found in NodeStore after %d attempts", attemptCount)
@ -417,7 +438,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(-1 * time.Hour), // Past expiry = logout
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
wantExpired: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
@ -451,6 +472,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -471,7 +493,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(-1 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey2.Public() }, // Different machine key
machineKey: machineKey2.Public, // Different machine key
wantError: true,
},
// TEST: Existing node cannot extend expiry without re-auth
@ -500,6 +522,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -520,7 +543,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(48 * time.Hour), // Future time = extend attempt
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantError: true,
},
// TEST: Expired node must re-authenticate
@ -549,25 +572,31 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
}
// Wait for node to be available in NodeStore
var node types.NodeView
var found bool
var (
node types.NodeView
found bool
)
require.EventuallyWithT(t, func(c *assert.CollectT) {
node, found = app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(c, found, "node should be available in NodeStore")
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
if !found {
return "", fmt.Errorf("node not found after setup")
return "", errNodeNotFoundAfterSetup
}
// Expire the node
expiredTime := time.Now().Add(-1 * time.Hour)
_, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime)
return "", err
},
request: func(_ string) tailcfg.RegisterRequest {
@ -577,7 +606,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour), // Future expiry
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantExpired: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.NodeKeyExpired)
@ -610,6 +639,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -630,7 +660,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(-1 * time.Hour), // Logout
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantExpired: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.NodeKeyExpired)
@ -673,6 +703,7 @@ func TestAuthenticationFlows(t *testing.T) {
// and handleRegister will receive the value when it starts waiting
go func() {
user := app.state.CreateUserForTest("followup-user")
node := app.state.CreateNodeForTest(user, "followup-success-node")
registered <- node
}()
@ -685,7 +716,7 @@ func TestAuthenticationFlows(t *testing.T) {
NodeKey: nodeKey1.Public(),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -723,7 +754,7 @@ func TestAuthenticationFlows(t *testing.T) {
NodeKey: nodeKey1.Public(),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantError: true,
},
// TEST: Invalid followup URL is rejected
@ -742,7 +773,7 @@ func TestAuthenticationFlows(t *testing.T) {
NodeKey: nodeKey1.Public(),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantError: true,
},
// TEST: Non-existent registration ID is rejected
@ -761,7 +792,7 @@ func TestAuthenticationFlows(t *testing.T) {
NodeKey: nodeKey1.Public(),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantError: true,
},
@ -782,6 +813,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -796,7 +828,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -821,6 +853,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -833,7 +866,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -865,6 +898,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -879,7 +913,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantError: true,
},
@ -898,6 +932,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -912,7 +947,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -922,6 +957,7 @@ func TestAuthenticationFlows(t *testing.T) {
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found)
assert.Equal(t, "tagged-pak-node", node.Hostname())
if node.AuthKey().Valid() {
assert.NotEmpty(t, node.AuthKey().Tags())
}
@ -1031,6 +1067,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1047,6 +1084,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak2.Key, nil
},
request: func(newAuthKey string) tailcfg.RegisterRequest {
@ -1061,7 +1099,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(48 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -1099,6 +1137,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1124,7 +1163,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(48 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuthURL: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.Contains(t, resp.AuthURL, "register/")
@ -1161,6 +1200,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1177,6 +1217,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pakRotation.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1191,7 +1232,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -1226,6 +1267,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1240,7 +1282,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Time{}, // Zero time
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -1265,6 +1307,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1286,7 +1329,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -1342,7 +1385,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: false, // Should not be authorized yet - needs to use new AuthURL
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
// Should get a new AuthURL, not an error
@ -1367,7 +1410,7 @@ func TestAuthenticationFlows(t *testing.T) {
NodeKey: nodeKey1.Public(),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantError: true,
},
// TEST: Wrong followup path format is rejected
@ -1386,7 +1429,7 @@ func TestAuthenticationFlows(t *testing.T) {
NodeKey: nodeKey1.Public(),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantError: true,
},
@ -1417,7 +1460,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -1429,6 +1472,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify custom hostinfo was preserved through interactive workflow
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found, "node should be found after interactive registration")
if found {
assert.Equal(t, "custom-interactive-node", node.Hostname())
assert.Equal(t, "linux", node.Hostinfo().OS())
@ -1455,6 +1499,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1469,7 +1514,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -1508,7 +1553,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -1520,6 +1565,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify registration ID was properly generated and used
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found, "node should be registered after interactive workflow")
if found {
assert.Equal(t, "registration-id-test-node", node.Hostname())
assert.Equal(t, "test-os", node.Hostinfo().OS())
@ -1535,6 +1581,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1549,7 +1596,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -1577,6 +1624,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
@ -1592,7 +1640,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(12 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -1632,6 +1680,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1648,6 +1697,7 @@ func TestAuthenticationFlows(t *testing.T) {
if err != nil {
return "", err
}
return pak2.Key, nil
},
request: func(user2AuthKey string) tailcfg.RegisterRequest {
@ -1662,7 +1712,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
assert.True(t, resp.MachineAuthorized)
@ -1712,6 +1762,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public())
if err != nil {
return "", err
@ -1735,7 +1786,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() }, // Same machine key
machineKey: machineKey1.Public, // Same machine key
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -1789,7 +1840,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: false, // Should not be authorized yet - needs to use new AuthURL
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
// Should get a new AuthURL, not an error
@ -1799,13 +1850,13 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify the response contains a valid registration URL
authURL, err := url.Parse(resp.AuthURL)
assert.NoError(t, err, "AuthURL should be a valid URL")
require.NoError(t, err, "AuthURL should be a valid URL")
assert.True(t, strings.HasPrefix(authURL.Path, "/register/"), "AuthURL path should start with /register/")
// Extract and validate the new registration ID exists in cache
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
newRegID, err := types.RegistrationIDFromString(newRegIDStr)
assert.NoError(t, err, "should be able to parse new registration ID")
require.NoError(t, err, "should be able to parse new registration ID")
// Verify new registration entry exists in cache
_, found := app.state.GetRegistrationCacheEntry(newRegID)
@ -1838,6 +1889,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
if err != nil {
return "", err
@ -1858,7 +1910,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now(), // Exactly now (edge case between past and future)
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
wantAuth: true,
wantExpired: true,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
@ -1890,7 +1942,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey2.Public() },
machineKey: machineKey2.Public,
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -1932,6 +1984,7 @@ func TestAuthenticationFlows(t *testing.T) {
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public())
if err != nil {
return "", err
@ -1955,7 +2008,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -2001,7 +2054,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -2055,7 +2108,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
// This test validates concurrent interactive registration attempts
assert.Contains(t, resp.AuthURL, "/register/")
@ -2097,6 +2150,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Collect results - at least one should succeed
successCount := 0
for range numConcurrent {
select {
case err := <-results:
@ -2162,7 +2216,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -2206,7 +2260,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
requiresInteractiveFlow: true,
interactiveSteps: []interactiveStep{
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
@ -2217,6 +2271,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Should handle nil hostinfo gracefully
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found, "node should be registered despite nil hostinfo")
if found {
// Should have some default hostname or handle nil gracefully
hostname := node.Hostname()
@ -2243,7 +2298,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
// Get initial AuthURL and extract registration ID
authURL := resp.AuthURL
@ -2265,7 +2320,7 @@ func TestAuthenticationFlows(t *testing.T) {
nil,
"error-test-method",
)
assert.Error(t, err, "should fail with invalid user ID")
require.Error(t, err, "should fail with invalid user ID")
// Cache entry should still exist after auth error (for retry scenarios)
_, stillFound := app.state.GetRegistrationCacheEntry(registrationID)
@ -2297,7 +2352,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
// Test multiple interactive registration attempts for the same node can coexist
authURL1 := resp.AuthURL
@ -2315,12 +2370,14 @@ func TestAuthenticationFlows(t *testing.T) {
resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public())
require.NoError(t, err)
authURL2 := resp2.AuthURL
assert.Contains(t, authURL2, "/register/")
// Both should have different registration IDs
regID1, err1 := extractRegistrationIDFromAuthURL(authURL1)
regID2, err2 := extractRegistrationIDFromAuthURL(authURL2)
require.NoError(t, err1)
require.NoError(t, err2)
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
@ -2328,6 +2385,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Both cache entries should exist simultaneously
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
assert.True(t, found1, "first registration cache entry should exist")
assert.True(t, found2, "second registration cache entry should exist")
@ -2354,7 +2412,7 @@ func TestAuthenticationFlows(t *testing.T) {
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: func() key.MachinePublic { return machineKey1.Public() },
machineKey: machineKey1.Public,
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
authURL1 := resp.AuthURL
regID1, err := extractRegistrationIDFromAuthURL(authURL1)
@ -2371,6 +2429,7 @@ func TestAuthenticationFlows(t *testing.T) {
resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public())
require.NoError(t, err)
authURL2 := resp2.AuthURL
regID2, err := extractRegistrationIDFromAuthURL(authURL2)
require.NoError(t, err)
@ -2378,6 +2437,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify both exist
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
assert.True(t, found1, "first cache entry should exist")
assert.True(t, found2, "second cache entry should exist")
@ -2403,6 +2463,7 @@ func TestAuthenticationFlows(t *testing.T) {
errorChan <- err
return
}
responseChan <- resp
}()
@ -2430,6 +2491,7 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify the node was created with the second registration's data
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
assert.True(t, found, "node should be registered")
if found {
assert.Equal(t, "pending-node-2", node.Hostname())
assert.Equal(t, "second-registration-user", node.User().Name())
@ -2463,8 +2525,10 @@ func TestAuthenticationFlows(t *testing.T) {
// Set up context with timeout for followup tests
ctx := context.Background()
if req.Followup != "" {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
}
@ -2516,7 +2580,7 @@ func TestAuthenticationFlows(t *testing.T) {
}
}
// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow
// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow.
func runInteractiveWorkflowTest(t *testing.T, tt struct {
name string
setupFunc func(*testing.T, *Headscale) (string, error)
@ -2535,6 +2599,8 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
validateCompleteResponse bool
}, app *Headscale, dynamicValue string,
) {
t.Helper()
// Build initial request
req := tt.request(dynamicValue)
machineKey := tt.machineKey()
@ -2597,6 +2663,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
errorChan <- err
return
}
responseChan <- resp
}()
@ -2650,25 +2717,30 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
if responseToValidate == nil {
responseToValidate = initialResp
}
tt.validate(t, responseToValidate, app)
}
}
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
// AuthURL format: "http://localhost/register/abc123"
const registerPrefix = "/register/"
idx := strings.LastIndex(authURL, registerPrefix)
if idx == -1 {
return "", fmt.Errorf("invalid AuthURL format: %s", authURL)
return "", fmt.Errorf("%w: %s", errInvalidAuthURLFormat, authURL)
}
idStr := authURL[idx+len(registerPrefix):]
return types.RegistrationIDFromString(idStr)
}
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response
func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) {
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, _ tailcfg.RegisterRequest) {
t.Helper()
// Basic response validation
require.NotNil(t, resp, "response should not be nil")
require.True(t, resp.MachineAuthorized, "machine should be authorized")
@ -2681,7 +2753,7 @@ func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterRe
// Additional validation can be added here as needed
}
// Simple test to validate basic node creation and lookup
// Simple test to validate basic node creation and lookup.
func TestNodeStoreLookup(t *testing.T) {
app := createTestApp(t)
@ -2713,8 +2785,10 @@ func TestNodeStoreLookup(t *testing.T) {
// Wait for node to be available in NodeStore
var node types.NodeView
require.EventuallyWithT(t, func(c *assert.CollectT) {
var found bool
node, found = app.state.GetNodeByNodeKey(nodeKey.Public())
assert.True(c, found, "Node should be found in NodeStore")
}, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore")
@ -2783,8 +2857,10 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Get the node ID
var registeredNode types.NodeView
require.EventuallyWithT(t, func(c *assert.CollectT) {
var found bool
registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public())
assert.True(c, found, "Node should be found in NodeStore")
}, 1*time.Second, 100*time.Millisecond, "waiting for node to be available")
@ -2796,6 +2872,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Verify initial state: user1 has 2 nodes, user2 has 2 nodes
user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID))
user2Nodes := app.state.ListNodesByUser(types.UserID(user2.ID))
require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially")
require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially")
@ -2876,6 +2953,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Verify new nodes were created for user1 with the same machine keys
t.Logf("Verifying new nodes created for user1 from user2's machine keys...")
for i := 2; i < 4; i++ {
node := nodes[i]
// Should be able to find a node with user1 and this machine key (the new one)
@ -2899,7 +2977,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Expected behavior:
// - User1's original node should STILL EXIST (expired)
// - User2 should get a NEW node created (NOT transfer)
// - Both nodes share the same machine key (same physical device)
// - Both nodes share the same machine key (same physical device).
func TestWebFlowReauthDifferentUser(t *testing.T) {
machineKey := key.NewMachine()
nodeKey1 := key.NewNode()
@ -3043,7 +3121,8 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
// Count nodes per user
user1Nodes := 0
user2Nodes := 0
for i := 0; i < allNodesSlice.Len(); i++ {
for i := range allNodesSlice.Len() {
n := allNodesSlice.At(i)
if n.UserID().Get() == user1.ID {
user1Nodes++
@ -3060,7 +3139,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
})
}
// Helper function to create test app
// Helper function to create test app.
func createTestApp(t *testing.T) *Headscale {
t.Helper()
@ -3147,6 +3226,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
}
t.Log("Step 1: Initial registration with pre-auth key")
initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public())
require.NoError(t, err, "initial registration should succeed")
require.NotNil(t, initialResp)
@ -3172,6 +3252,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
// - System reboots
// The Tailscale client persists the pre-auth key in its state and sends it on every registration
t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key")
restartReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pakNew.Key, // Same key, now marked as Used=true
@ -3188,10 +3269,12 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
restartResp, err := app.handleRegister(context.Background(), restartReq, machineKey.Public())
// This is the assertion that currently FAILS in v0.27.0
assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed")
require.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed")
if err != nil {
t.Logf("Error received (this is the bug): %v", err)
t.Logf("Expected behavior: Node should be able to re-register with the same pre-auth key it used initially")
return // Stop here to show the bug clearly
}
@ -3289,7 +3372,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) {
}
_, err = app.handleRegister(context.Background(), req, machineKey.Public())
assert.Error(t, err, "expired pre-auth key should be rejected")
require.Error(t, err, "expired pre-auth key should be rejected")
assert.Contains(t, err.Error(), "authkey expired", "error should mention key expiration")
}

View file

@ -13,7 +13,7 @@ import (
)
const (
apiKeyPrefix = "hskey-api-" //nolint:gosec // This is a prefix, not a credential
apiKeyPrefix = "hskey-api-" //nolint:gosec
apiKeyPrefixLength = 12
apiKeyHashLength = 64

View file

@ -24,7 +24,6 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/gorm/schema"
"tailscale.com/net/tsaddr"
"zgo.at/zcache/v2"
)
@ -76,6 +75,7 @@ func NewHeadscaleDatabase(
ID: "202501221827",
Migrate: func(tx *gorm.DB) error {
// Remove any invalid routes associated with a node that does not exist.
//nolint:staticcheck
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
if err != nil {
@ -84,6 +84,7 @@ func NewHeadscaleDatabase(
}
// Remove any invalid routes without a node_id.
//nolint:staticcheck
if tx.Migrator().HasTable(&types.Route{}) {
err := tx.Exec("delete from routes where node_id is null").Error
if err != nil {
@ -91,6 +92,7 @@ func NewHeadscaleDatabase(
}
}
//nolint:staticcheck
err := tx.AutoMigrate(&types.Route{})
if err != nil {
return fmt.Errorf("automigrating types.Route: %w", err)
@ -109,6 +111,7 @@ func NewHeadscaleDatabase(
if err != nil {
return fmt.Errorf("automigrating types.PreAuthKey: %w", err)
}
err = tx.AutoMigrate(&types.Node{})
if err != nil {
return fmt.Errorf("automigrating types.Node: %w", err)
@ -155,7 +158,9 @@ AND auth_key_id NOT IN (
nodeRoutes := map[uint64][]netip.Prefix{}
//nolint:staticcheck
var routes []types.Route
err = tx.Find(&routes).Error
if err != nil {
return fmt.Errorf("fetching routes: %w", err)
@ -168,10 +173,13 @@ AND auth_key_id NOT IN (
}
for nodeID, routes := range nodeRoutes {
tsaddr.SortPrefixes(routes)
slices.SortFunc(routes, netip.Prefix.Compare)
routes = slices.Compact(routes)
data, err := json.Marshal(routes)
if err != nil {
return fmt.Errorf("marshaling routes for node %d: %w", nodeID, err)
}
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
if err != nil {
@ -180,6 +188,7 @@ AND auth_key_id NOT IN (
}
// Drop the old table.
//nolint:staticcheck
_ = tx.Migrator().DropTable(&types.Route{})
return nil
@ -256,10 +265,13 @@ AND auth_key_id NOT IN (
// Check if routes table exists and drop it (should have been migrated already)
var routesExists bool
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
if err == nil && routesExists {
log.Info().Msg("Dropping leftover routes table")
if err := tx.Exec("DROP TABLE routes").Error; err != nil {
err := tx.Exec("DROP TABLE routes").Error
if err != nil {
return fmt.Errorf("dropping routes table: %w", err)
}
}
@ -281,6 +293,7 @@ AND auth_key_id NOT IN (
for _, table := range tablesToRename {
// Check if table exists before renaming
var exists bool
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
if err != nil {
return fmt.Errorf("checking if table %s exists: %w", table, err)
@ -291,7 +304,8 @@ AND auth_key_id NOT IN (
_ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
// Rename current table to _old
if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil {
err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error
if err != nil {
return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err)
}
}
@ -365,7 +379,8 @@ AND auth_key_id NOT IN (
}
for _, createSQL := range tableCreationSQL {
if err := tx.Exec(createSQL).Error; err != nil {
err := tx.Exec(createSQL).Error
if err != nil {
return fmt.Errorf("creating new table: %w", err)
}
}
@ -394,7 +409,8 @@ AND auth_key_id NOT IN (
}
for _, copySQL := range dataCopySQL {
if err := tx.Exec(copySQL).Error; err != nil {
err := tx.Exec(copySQL).Error
if err != nil {
return fmt.Errorf("copying data: %w", err)
}
}
@ -417,14 +433,16 @@ AND auth_key_id NOT IN (
}
for _, indexSQL := range indexes {
if err := tx.Exec(indexSQL).Error; err != nil {
err := tx.Exec(indexSQL).Error
if err != nil {
return fmt.Errorf("creating index: %w", err)
}
}
// Drop old tables only after everything succeeds
for _, table := range tablesToRename {
if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil {
err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
if err != nil {
log.Warn().Str("table", table+"_old").Err(err).Msg("Failed to drop old table, but migration succeeded")
}
}
@ -762,6 +780,7 @@ AND auth_key_id NOT IN (
// or else it blocks...
sqlConn.SetMaxIdleConns(maxIdleConns)
sqlConn.SetMaxOpenConns(maxOpenConns)
defer sqlConn.SetMaxIdleConns(1)
defer sqlConn.SetMaxOpenConns(1)
@ -779,6 +798,7 @@ AND auth_key_id NOT IN (
},
}
//nolint:noinlineerr
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil {
return nil, fmt.Errorf("validating schema: %w", err)
}
@ -913,6 +933,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
// Get the current foreign key status
var fkOriginallyEnabled int
//nolint:noinlineerr
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil {
return fmt.Errorf("checking foreign key status: %w", err)
}
@ -942,27 +963,32 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
if needsFKDisabled {
// Disable foreign keys for this migration
if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil {
err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error
if err != nil {
return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err)
}
} else {
// Ensure foreign keys are enabled for this migration
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
err := dbConn.Exec("PRAGMA foreign_keys = ON").Error
if err != nil {
return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err)
}
}
// Run up to this specific migration (will only run the next pending migration)
if err := migrations.MigrateTo(migrationID); err != nil {
err := migrations.MigrateTo(migrationID)
if err != nil {
return fmt.Errorf("running migration %s: %w", migrationID, err)
}
}
//nolint:noinlineerr
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
return fmt.Errorf("restoring foreign keys: %w", err)
}
// Run the rest of the migrations
//nolint:noinlineerr
if err := migrations.Migrate(); err != nil {
return err
}
@ -1005,7 +1031,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
}
} else {
// PostgreSQL can run all migrations in one block - no foreign key issues
if err := migrations.Migrate(); err != nil {
err := migrations.Migrate()
if err != nil {
return err
}
}
@ -1031,7 +1058,7 @@ func (hsdb *HSDatabase) Close() error {
}
if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
db.Exec("VACUUM")
_, _ = db.ExecContext(context.Background(), "VACUUM")
}
return db.Close()

View file

@ -1,6 +1,7 @@
package db
import (
"context"
"database/sql"
"os"
"os/exec"
@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
// Verify api_keys data preservation
var apiKeyCount int
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
require.NoError(t, err)
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
return err
}
_, err = db.Exec(string(schemaContent))
_, err = db.ExecContext(context.Background(), string(schemaContent))
return err
}
@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
func requireConstraintFailed(t *testing.T, err error) {
t.Helper()
require.Error(t, err)
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
}
@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
}
// Construct the pg_restore command
cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
// Set the output streams
cmd.Stdout = os.Stdout
@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
// skip already-applied migrations and only run new ones.
func TestSQLiteAllTestdataMigrations(t *testing.T) {
t.Parallel()
schemas, err := os.ReadDir("testdata/sqlite")
require.NoError(t, err)

View file

@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)
// Basic deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var deletionWg sync.WaitGroup
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
deletionWg sync.WaitGroup
)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionWg.Done()
}
@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
go gc.Start()
// Schedule several nodes for deletion with short expiry
const expiry = fifty
const numNodes = 100
const (
expiry = fifty
numNodes = 100
)
// Set up wait group for expected deletions
deletionWg.Add(numNodes)
for i := 1; i <= numNodes; i++ {
gc.Schedule(types.NodeID(i), expiry)
gc.Schedule(types.NodeID(i), expiry) //nolint:gosec
}
// Wait for all scheduled deletions to complete
@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
// Schedule and immediately cancel to test that part of the code
for i := numNodes + 1; i <= numNodes*2; i++ {
nodeID := types.NodeID(i)
nodeID := types.NodeID(i) //nolint:gosec
gc.Schedule(nodeID, time.Hour)
gc.Cancel(nodeID)
}
@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// Deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// Start GC
gc := NewEphemeralGarbageCollector(deleteFunc)
go gc.Start()
defer gc.Close()
const shortExpiry = fifty
const longExpiry = 1 * time.Hour
const (
shortExpiry = fifty
longExpiry = 1 * time.Hour
)
nodeID := types.NodeID(1)
@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// and verifies that the node is deleted only once.
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
// Deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
}
// Start the GC
gc := NewEphemeralGarbageCollector(deleteFunc)
go gc.Start()
defer gc.Close()
nodeID := types.NodeID(1)
const expiry = fifty
// Schedule node for deletion
@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
// Deletion tracking
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)
// Deletion tracking
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
nodeDeleted := make(chan struct{})
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
close(nodeDeleted) // Signal that deletion happened
}
@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
// Use a WaitGroup to ensure the GC has started
var startWg sync.WaitGroup
startWg.Add(1)
go func() {
startWg.Done() // Signal that the goroutine has started
gc.Start()
}()
startWg.Wait() // Wait for the GC to start
// Close GC right away
@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
// Check no node was deleted
deleteMutex.Lock()
nodesDeleted := len(deletedIDs)
deleteMutex.Unlock()
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
t.Logf("Initial number of goroutines: %d", initialGoroutines)
// Deletion tracking mechanism
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
var (
deletedIDs []types.NodeID
deleteMutex sync.Mutex
)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
}
@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
go gc.Start()
// Number of concurrent scheduling goroutines
const numSchedulers = 10
const nodesPerScheduler = 50
const (
numSchedulers = 10
nodesPerScheduler = 50
)
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
case <-stopScheduling:
return
default:
nodeID := types.NodeID(baseNodeID + j + 1)
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
atomic.AddInt64(&scheduledCount, 1)
// Yield to other goroutines to introduce variability

View file

@ -13,7 +13,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/net/tsaddr"
"tailscale.com/types/ptr"
)
var mpp = func(pref string) *netip.Prefix {
@ -484,12 +483,13 @@ func TestBackfillIPAddresses(t *testing.T) {
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
db, err := newSQLiteTestDB()
require.NoError(t, err)
defer db.Close()
alloc, err := NewIPAllocator(
db,
ptr.To(tsaddr.CGNATRange()),
ptr.To(tsaddr.TailscaleULARange()),
new(tsaddr.CGNATRange()),
new(tsaddr.TailscaleULARange()),
types.IPAllocationStrategySequential,
)
if err != nil {
@ -497,17 +497,17 @@ func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
}
// Validate that we do not give out 100.100.100.100
nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange()))
nextQuad100, err := alloc.next(na("100.100.100.99"), new(tsaddr.CGNATRange()))
require.NoError(t, err)
assert.Equal(t, na("100.100.100.101"), *nextQuad100)
// Validate that we do not give out fd7a:115c:a1e0::53
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange()))
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), new(tsaddr.TailscaleULARange()))
require.NoError(t, err)
assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
// Validate that we do not give out fd7a:115c:a1e0::53
nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange()))
nextChrome, err := alloc.next(na("100.115.91.255"), new(tsaddr.CGNATRange()))
t.Logf("chrome: %s", nextChrome.String())
require.NoError(t, err)
assert.Equal(t, na("100.115.94.0"), *nextChrome)

View file

@ -1,13 +1,13 @@
package db
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"net/netip"
"regexp"
"slices"
"sort"
"strconv"
"strings"
"sync"
@ -20,7 +20,6 @@ import (
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
const (
@ -37,6 +36,7 @@ var (
"node not found in registration cache",
)
ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface")
ErrNameNotUnique = errors.New("name is not unique")
)
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
@ -60,7 +60,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types
return types.Nodes{}, err
}
sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID })
slices.SortFunc(nodes, func(a, b *types.Node) int { return cmp.Compare(a.ID, b.ID) })
return nodes, nil
}
@ -207,6 +207,7 @@ func SetTags(
slices.Sort(tags)
tags = slices.Compact(tags)
b, err := json.Marshal(tags)
if err != nil {
return err
@ -220,7 +221,7 @@ func SetTags(
return nil
}
// SetTags takes a Node struct pointer and update the forced tags.
// SetApprovedRoutes updates the approved routes for a node.
func SetApprovedRoutes(
tx *gorm.DB,
nodeID types.NodeID,
@ -228,7 +229,8 @@ func SetApprovedRoutes(
) error {
if len(routes) == 0 {
// if no routes are provided, we remove all
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil {
err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error
if err != nil {
return fmt.Errorf("removing approved routes: %w", err)
}
@ -251,6 +253,7 @@ func SetApprovedRoutes(
return err
}
//nolint:noinlineerr
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
return fmt.Errorf("updating approved routes: %w", err)
}
@ -277,18 +280,21 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
func RenameNode(tx *gorm.DB,
nodeID types.NodeID, newName string,
) error {
if err := util.ValidateHostname(newName); err != nil {
err := util.ValidateHostname(newName)
if err != nil {
return fmt.Errorf("renaming node: %w", err)
}
// Check if the new name is unique
var count int64
if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
err = tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error
if err != nil {
return fmt.Errorf("failed to check name uniqueness: %w", err)
}
if count > 0 {
return errors.New("name is not unique")
return ErrNameNotUnique
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
@ -379,6 +385,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
if ipv4 == nil {
ipv4 = oldNode.IPv4
}
if ipv6 == nil {
ipv6 = oldNode.IPv6
}
@ -407,6 +414,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
node.IPv6 = ipv6
var err error
node.Hostname, err = util.NormaliseHostname(node.Hostname)
if err != nil {
newHostname := util.InvalidString()
@ -491,8 +499,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
func isUniqueName(tx *gorm.DB, name string) (bool, error) {
nodes := types.Nodes{}
if err := tx.
Where("given_name = ?", name).Find(&nodes).Error; err != nil {
err := tx.
Where("given_name = ?", name).Find(&nodes).Error
if err != nil {
return false, err
}
@ -646,7 +656,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
panic("CreateNodeForTest requires a valid user")
}
nodeName := "testnode"
nodeName := "testnode" //nolint:goconst
if len(hostname) > 0 && hostname[0] != "" {
nodeName = hostname[0]
}
@ -668,7 +678,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
Hostname: nodeName,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
AuthKeyID: &pak.ID,
}
err = hsdb.DB.Save(node).Error
@ -694,9 +704,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
}
var registeredNode *types.Node
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
var err error
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
return err
})
if err != nil {

View file

@ -22,7 +22,6 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestGetNode(t *testing.T) {
@ -99,6 +98,7 @@ func TestExpireNode(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
//nolint:staticcheck
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
@ -115,7 +115,7 @@ func TestExpireNode(t *testing.T) {
Hostname: "testnode",
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
AuthKeyID: new(pak.ID),
Expiry: &time.Time{},
}
db.DB.Save(node)
@ -143,6 +143,7 @@ func TestSetTags(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
//nolint:staticcheck
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
@ -159,7 +160,7 @@ func TestSetTags(t *testing.T) {
Hostname: "testnode",
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
AuthKeyID: new(pak.ID),
}
trx := db.DB.Save(node)
@ -443,7 +444,7 @@ func TestAutoApproveRoutes(t *testing.T) {
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.routes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
}
err = adb.DB.Save(&node).Error
@ -460,17 +461,17 @@ func TestAutoApproveRoutes(t *testing.T) {
RoutableIPs: tt.routes,
},
Tags: []string{"tag:exit"},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
IPv4: new(netip.MustParseAddr("100.64.0.2")),
}
err = adb.DB.Save(&nodeTagged).Error
require.NoError(t, err)
users, err := adb.ListUsers()
assert.NoError(t, err)
require.NoError(t, err)
nodes, err := adb.ListNodes()
assert.NoError(t, err)
require.NoError(t, err)
pm, err := pmf(users, nodes.ViewSlice())
require.NoError(t, err)
@ -498,6 +499,7 @@ func TestAutoApproveRoutes(t *testing.T) {
if len(expectedRoutes1) == 0 {
expectedRoutes1 = nil
}
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
}
@ -509,6 +511,7 @@ func TestAutoApproveRoutes(t *testing.T) {
if len(expectedRoutes2) == 0 {
expectedRoutes2 = nil
}
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
}
@ -597,7 +600,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
// Use shorter expiry for faster tests
for i := range want {
go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec // test code, no overflow risk
go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec
}
// Wait for all deletions to complete
@ -636,9 +639,11 @@ func TestListEphemeralNodes(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
//nolint:staticcheck
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
//nolint:staticcheck
pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
require.NoError(t, err)
@ -649,7 +654,7 @@ func TestListEphemeralNodes(t *testing.T) {
Hostname: "test",
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
AuthKeyID: new(pak.ID),
}
nodeEph := types.Node{
@ -659,7 +664,7 @@ func TestListEphemeralNodes(t *testing.T) {
Hostname: "ephemeral",
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pakEph.ID),
AuthKeyID: new(pakEph.ID),
}
err = db.DB.Save(&node).Error
@ -719,6 +724,7 @@ func TestNodeNaming(t *testing.T) {
// break your network, so they should be replaced when registering
// a node.
// https://github.com/juanfont/headscale/issues/2343
//nolint:gosmopolitan
nodeInvalidHostname := types.Node{
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
@ -746,12 +752,19 @@ func TestNodeNaming(t *testing.T) {
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil)
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil)
return err
})
require.NoError(t, err)
@ -810,25 +823,26 @@ func TestNodeNaming(t *testing.T) {
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "test")
})
assert.ErrorContains(t, err, "name is not unique")
require.ErrorContains(t, err, "name is not unique")
// Rename invalid chars
//nolint:gosmopolitan
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[2].ID, "我的电脑")
})
assert.ErrorContains(t, err, "invalid characters")
require.ErrorContains(t, err, "invalid characters")
// Rename too short
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[3].ID, "a")
})
assert.ErrorContains(t, err, "at least 2 characters")
require.ErrorContains(t, err, "at least 2 characters")
// Rename with emoji
err = db.Write(func(tx *gorm.DB) error {
return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
})
assert.ErrorContains(t, err, "invalid characters")
require.ErrorContains(t, err, "invalid characters")
// Rename with only emoji
err = db.Write(func(tx *gorm.DB) error {
@ -896,12 +910,12 @@ func TestRenameNodeComprehensive(t *testing.T) {
},
{
name: "chinese_chars_with_dash_rejected",
newName: "server-北京-01",
newName: "server-北京-01", //nolint:gosmopolitan
wantErr: "invalid characters",
},
{
name: "chinese_only_rejected",
newName: "我的电脑",
newName: "我的电脑", //nolint:gosmopolitan
wantErr: "invalid characters",
},
{
@ -911,7 +925,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
},
{
name: "mixed_chinese_emoji_rejected",
newName: "测试💻机器",
newName: "测试💻机器", //nolint:gosmopolitan
wantErr: "invalid characters",
},
{
@ -1000,6 +1014,7 @@ func TestListPeers(t *testing.T) {
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
return err
@ -1085,6 +1100,7 @@ func TestListNodes(t *testing.T) {
if err != nil {
return err
}
_, err = RegisterNodeForTest(tx, node2, nil, nil)
return err

View file

@ -332,7 +332,7 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
return nil
}
// MarkExpirePreAuthKey marks a PreAuthKey as expired.
// ExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
now := time.Now()
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error

View file

@ -11,7 +11,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/ptr"
)
func TestCreatePreAuthKey(t *testing.T) {
@ -24,7 +23,7 @@ func TestCreatePreAuthKey(t *testing.T) {
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
_, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil)
_, err := db.CreatePreAuthKey(new(types.UserID(12345)), true, false, nil, nil)
assert.Error(t, err)
},
},
@ -127,7 +126,7 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) {
Hostname: "testest",
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(key.ID),
AuthKeyID: new(key.ID),
}
db.DB.Save(&node)

View file

@ -22,6 +22,9 @@ var (
const (
// DefaultBusyTimeout is the default busy timeout in milliseconds.
DefaultBusyTimeout = 10000
// DefaultWALAutocheckpoint is the default WAL autocheckpoint value (number of pages).
// SQLite default is 1000 pages.
DefaultWALAutocheckpoint = 1000
)
// JournalMode represents SQLite journal_mode pragma values.
@ -310,7 +313,7 @@ func Default(path string) *Config {
BusyTimeout: DefaultBusyTimeout,
JournalMode: JournalModeWAL,
AutoVacuum: AutoVacuumIncremental,
WALAutocheckpoint: 1000,
WALAutocheckpoint: DefaultWALAutocheckpoint,
Synchronous: SynchronousNormal,
ForeignKeys: true,
TxLock: TxLockImmediate,
@ -362,7 +365,8 @@ func (c *Config) Validate() error {
// ToURL builds a properly encoded SQLite connection string using _pragma parameters
// compatible with modernc.org/sqlite driver.
func (c *Config) ToURL() (string, error) {
if err := c.Validate(); err != nil {
err := c.Validate()
if err != nil {
return "", fmt.Errorf("invalid config: %w", err)
}
@ -372,18 +376,23 @@ func (c *Config) ToURL() (string, error) {
if c.BusyTimeout > 0 {
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
}
if c.JournalMode != "" {
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
}
if c.AutoVacuum != "" {
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
}
if c.WALAutocheckpoint >= 0 {
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
}
if c.Synchronous != "" {
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
}
if c.ForeignKeys {
pragmas = append(pragmas, "foreign_keys=ON")
}

View file

@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
t.Errorf("Config.ToURL() error = %v", err)
return
}
if got != tt.want {
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
}
@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
Path: "",
BusyTimeout: -1,
}
_, err := config.ToURL()
if err == nil {
t.Error("Config.ToURL() with invalid config should return error")

View file

@ -1,6 +1,7 @@
package sqliteconfig
import (
"context"
"database/sql"
"path/filepath"
"strings"
@ -101,7 +102,8 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
defer db.Close()
// Test connection
if err := db.Ping(); err != nil {
//nolint:noinlineerr
if err := db.PingContext(context.Background()); err != nil {
t.Fatalf("Failed to ping database: %v", err)
}
@ -109,8 +111,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
for pragma, expectedValue := range tt.expected {
t.Run("pragma_"+pragma, func(t *testing.T) {
var actualValue any
query := "PRAGMA " + pragma
err := db.QueryRow(query).Scan(&actualValue)
err := db.QueryRowContext(context.Background(), query).Scan(&actualValue)
if err != nil {
t.Fatalf("Failed to query %s: %v", query, err)
}
@ -178,23 +182,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
);
`
if _, err := db.Exec(schema); err != nil {
//nolint:noinlineerr
if _, err := db.ExecContext(context.Background(), schema); err != nil {
t.Fatalf("Failed to create schema: %v", err)
}
// Insert parent record
if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
//nolint:noinlineerr
if _, err := db.ExecContext(context.Background(), "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
t.Fatalf("Failed to insert parent: %v", err)
}
// Test 1: Valid foreign key should work
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
_, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
if err != nil {
t.Fatalf("Valid foreign key insert failed: %v", err)
}
// Test 2: Invalid foreign key should fail
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
_, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
if err == nil {
t.Error("Expected foreign key constraint violation, but insert succeeded")
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
@ -204,7 +210,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
}
// Test 3: Deleting referenced parent should fail
_, err = db.Exec("DELETE FROM parent WHERE id = 1")
_, err = db.ExecContext(context.Background(), "DELETE FROM parent WHERE id = 1")
if err == nil {
t.Error("Expected foreign key constraint violation when deleting referenced parent")
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
@ -249,7 +255,8 @@ func TestJournalModeValidation(t *testing.T) {
defer db.Close()
var actualMode string
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode)
if err != nil {
t.Fatalf("Failed to query journal_mode: %v", err)
}

View file

@ -3,12 +3,20 @@ package db
import (
"context"
"encoding"
"errors"
"fmt"
"reflect"
"gorm.io/gorm/schema"
)
// Sentinel errors for text serialisation.
var (
ErrTextUnmarshalFailed = errors.New("failed to unmarshal text value")
ErrUnsupportedType = errors.New("unsupported type")
ErrTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported")
)
// Got from https://github.com/xdg-go/strum/blob/main/types.go
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
@ -17,7 +25,7 @@ func isTextUnmarshaler(rv reflect.Value) bool {
}
func maybeInstantiatePtr(rv reflect.Value) {
if rv.Kind() == reflect.Ptr && rv.IsNil() {
if rv.Kind() == reflect.Pointer && rv.IsNil() {
np := reflect.New(rv.Type().Elem())
rv.Set(np)
}
@ -36,27 +44,30 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
// If the field is a pointer, we need to dereference it to get the actual type
// so we do not end with a second pointer.
if fieldValue.Elem().Kind() == reflect.Ptr {
if fieldValue.Elem().Kind() == reflect.Pointer {
fieldValue = fieldValue.Elem()
}
if dbValue != nil {
var bytes []byte
switch v := dbValue.(type) {
case []byte:
bytes = v
case string:
bytes = []byte(v)
default:
return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
return fmt.Errorf("%w: %#v", ErrTextUnmarshalFailed, dbValue)
}
if isTextUnmarshaler(fieldValue) {
maybeInstantiatePtr(fieldValue)
f := fieldValue.MethodByName("UnmarshalText")
args := []reflect.Value{reflect.ValueOf(bytes)}
ret := f.Call(args)
if !ret[0].IsNil() {
//nolint:forcetypeassert
return decodingError(field.Name, ret[0].Interface().(error))
}
@ -65,7 +76,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
// If it is not a pointer, we need to assign the value to the
// field.
dstField := field.ReflectValueOf(ctx, dst)
if dstField.Kind() == reflect.Ptr {
if dstField.Kind() == reflect.Pointer {
dstField.Set(fieldValue)
} else {
dstField.Set(fieldValue.Elem())
@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
return nil
} else {
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
return fmt.Errorf("%w: %T", ErrUnsupportedType, fieldValue.Interface())
}
}
@ -86,9 +97,10 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
// If the value is nil, we return nil, however, go nil values are not
// always comparable, particularly when reflection is involved:
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
return nil, nil
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) {
return nil, nil //nolint:nilnil
}
b, err := v.MarshalText()
if err != nil {
return nil, err
@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
return string(b), nil
default:
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
return nil, fmt.Errorf("%w, got %T", ErrTextMarshalerOnly, v)
}
}

View file

@ -15,6 +15,8 @@ var (
ErrUserExists = errors.New("user already exists")
ErrUserNotFound = errors.New("user not found")
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
ErrTooManyWhereArgs = errors.New("expect 0 or 1 where User structs")
ErrMultipleUsers = errors.New("expected exactly one user")
)
func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
@ -26,7 +28,8 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
// CreateUser creates a new User. Returns error if could not be created
// or another user already exists.
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
if err := util.ValidateHostname(user.Name); err != nil {
err := util.ValidateHostname(user.Name)
if err != nil {
return nil, err
}
if err := tx.Create(&user).Error; err != nil {
@ -88,10 +91,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
// not exist or if another User exists with the new name.
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
var err error
oldUser, err := GetUserByID(tx, uid)
if err != nil {
return err
}
//nolint:noinlineerr
if err = util.ValidateHostname(newName); err != nil {
return err
}
@ -151,7 +157,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
// ListUsers gets all the existing users.
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
if len(where) > 1 {
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
return nil, fmt.Errorf("%w, got %d", ErrTooManyWhereArgs, len(where))
}
var user *types.User
@ -160,7 +166,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
}
users := []types.User{}
if err := tx.Where(user).Find(&users).Error; err != nil {
err := tx.Where(user).Find(&users).Error
if err != nil {
return nil, err
}
@ -180,7 +188,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
}
if len(users) != 1 {
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
return nil, fmt.Errorf("%w, found %d", ErrMultipleUsers, len(users))
}
return &users[0], nil

View file

@ -8,7 +8,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/types/ptr"
)
func TestCreateAndDestroyUser(t *testing.T) {
@ -71,6 +70,7 @@ func TestDestroyUserErrors(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
//nolint:staticcheck
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
@ -79,7 +79,7 @@ func TestDestroyUserErrors(t *testing.T) {
Hostname: "testnode",
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
AuthKeyID: new(pak.ID),
}
trx := db.DB.Save(&node)
require.NoError(t, trx.Error)

View file

@ -25,34 +25,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
overview := h.state.DebugOverviewJSON()
overviewJSON, err := json.MarshalIndent(overview, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(overviewJSON)
_, _ = w.Write(overviewJSON)
} else {
// Default to text/plain for backward compatibility
overview := h.state.DebugOverview()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(overview))
_, _ = w.Write([]byte(overview))
}
}))
// Configuration endpoint
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
config := h.state.DebugConfig()
configJSON, err := json.MarshalIndent(config, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(configJSON)
_, _ = w.Write(configJSON)
}))
// Policy endpoint
@ -70,8 +75,9 @@ func (h *Headscale) debugHTTPServer() *http.Server {
} else {
w.Header().Set("Content-Type", "text/plain")
}
w.WriteHeader(http.StatusOK)
w.Write([]byte(policy))
_, _ = w.Write([]byte(policy))
}))
// Filter rules endpoint
@ -81,27 +87,31 @@ func (h *Headscale) debugHTTPServer() *http.Server {
httpError(w, err)
return
}
filterJSON, err := json.MarshalIndent(filter, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(filterJSON)
_, _ = w.Write(filterJSON)
}))
// SSH policies endpoint
debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
sshPolicies := h.state.DebugSSHPolicies()
sshJSON, err := json.MarshalIndent(sshPolicies, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(sshJSON)
_, _ = w.Write(sshJSON)
}))
// DERP map endpoint
@ -112,20 +122,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
derpInfo := h.state.DebugDERPJSON()
derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(derpJSON)
_, _ = w.Write(derpJSON)
} else {
// Default to text/plain for backward compatibility
derpInfo := h.state.DebugDERPMap()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(derpInfo))
_, _ = w.Write([]byte(derpInfo))
}
}))
@ -137,34 +150,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
nodeStoreNodes := h.state.DebugNodeStoreJSON()
nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(nodeStoreJSON)
_, _ = w.Write(nodeStoreJSON)
} else {
// Default to text/plain for backward compatibility
nodeStoreInfo := h.state.DebugNodeStore()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(nodeStoreInfo))
_, _ = w.Write([]byte(nodeStoreInfo))
}
}))
// Registration cache endpoint
debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
cacheInfo := h.state.DebugRegistrationCache()
cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(cacheJSON)
_, _ = w.Write(cacheJSON)
}))
// Routes endpoint
@ -175,20 +193,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
routes := h.state.DebugRoutes()
routesJSON, err := json.MarshalIndent(routes, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(routesJSON)
_, _ = w.Write(routesJSON)
} else {
// Default to text/plain for backward compatibility
routes := h.state.DebugRoutesString()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(routes))
_, _ = w.Write([]byte(routes))
}
}))
@ -200,20 +221,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if wantsJSON {
policyManagerInfo := h.state.DebugPolicyManagerJSON()
policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")
if err != nil {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(policyManagerJSON)
_, _ = w.Write(policyManagerJSON)
} else {
// Default to text/plain for backward compatibility
policyManagerInfo := h.state.DebugPolicyManager()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(policyManagerInfo))
_, _ = w.Write([]byte(policyManagerInfo))
}
}))
@ -226,7 +250,8 @@ func (h *Headscale) debugHTTPServer() *http.Server {
if res == nil {
w.WriteHeader(http.StatusOK)
w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
_, _ = w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
return
}
@ -235,9 +260,10 @@ func (h *Headscale) debugHTTPServer() *http.Server {
httpError(w, err)
return
}
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(resJSON)
_, _ = w.Write(resJSON)
}))
// Batcher endpoint
@ -257,14 +283,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
w.Write(batcherJSON)
_, _ = w.Write(batcherJSON)
} else {
// Default to text/plain for backward compatibility
batcherInfo := h.debugBatcher()
w.Header().Set("Content-Type", "text/plain")
w.WriteHeader(http.StatusOK)
w.Write([]byte(batcherInfo))
_, _ = w.Write([]byte(batcherInfo))
}
}))
@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string {
activeConnections: info.ActiveConnections,
})
totalNodes++
if info.Connected {
connectedCount++
}
@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string {
activeConnections: 0,
})
totalNodes++
if connected {
connectedCount++
}
return true
})
}
@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
ActiveConnections: 0,
}
info.TotalNodes++
return true
})
}

View file

@ -134,6 +134,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) {
for id := range dm.Regions {
ids = append(ids, id)
}
slices.Sort(ids)
for _, id := range ids {
@ -160,16 +161,20 @@ func derpRandom() *rand.Rand {
derpRandomOnce.Do(func() {
seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String())
//nolint:gosec
rnd := rand.New(rand.NewSource(0))
//nolint:gosec
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table)))
derpRandomInst = rnd
})
return derpRandomInst
}
func resetDerpRandomForTesting() {
derpRandomMu.Lock()
defer derpRandomMu.Unlock()
derpRandomOnce = sync.Once{}
derpRandomInst = nil
}

View file

@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
viper.Set("dns.base_domain", tt.baseDomain)
defer viper.Reset()
resetDerpRandomForTesting()
testMap := tt.derpMap.View().AsStruct()

View file

@ -74,9 +74,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
if err != nil {
return tailcfg.DERPRegion{}, err
}
var host string
var port int
var portStr string
var (
host string
port int
portStr string
)
// Extract hostname and port from URL
host, portStr, err = net.SplitHostPort(serverURL.Host)
@ -97,12 +100,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
// If debug flag is set, resolve hostname to IP address
if debugUseDERPIP {
ips, err := net.LookupIP(host)
addrs, err := net.DefaultResolver.LookupIPAddr(context.Background(), host)
if err != nil {
log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host)
} else if len(ips) > 0 {
} else if len(addrs) > 0 {
// Use the first IP address
ipStr := ips[0].String()
ipStr := addrs[0].IP.String()
log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr)
host = ipStr
}
@ -205,6 +208,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
return
}
defer websocketConn.Close(websocket.StatusInternalError, "closing")
if websocketConn.Subprotocol() != "derp" {
websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
@ -309,6 +313,8 @@ func DERPBootstrapDNSHandler(
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
defer cancel()
var resolver net.Resolver
//nolint:unqueryvet
for _, region := range derpMap.Regions().All() {
for _, node := range region.Nodes().All() { // we don't care if we override some nodes
addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName())
@ -320,6 +326,7 @@ func DERPBootstrapDNSHandler(
continue
}
dnsEntries[node.HostName()] = addrs
}
}
@ -411,7 +418,9 @@ type DERPVerifyTransport struct {
func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
buf := new(bytes.Buffer)
if err := t.handleVerifyRequest(req, buf); err != nil {
err := t.handleVerifyRequest(req, buf)
if err != nil {
log.Error().Caller().Err(err).Msg("Failed to handle client verify request: ")
return nil, err

View file

@ -4,6 +4,7 @@ import (
"context"
"crypto/sha256"
"encoding/json"
"errors"
"fmt"
"os"
"sync"
@ -15,6 +16,9 @@ import (
"tailscale.com/util/set"
)
// ErrPathIsDirectory is returned when a path is a directory instead of a file.
var ErrPathIsDirectory = errors.New("path is a directory, only file is supported")
type ExtraRecordsMan struct {
mu sync.RWMutex
records set.Set[tailcfg.DNSRecord]
@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) {
}
if fi.IsDir() {
return nil, fmt.Errorf("path is a directory, only file is supported: %s", path)
return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path)
}
records, hash, err := readExtraRecordsFromPath(path)
@ -85,18 +89,22 @@ func (e *ExtraRecordsMan) Run() {
log.Error().Caller().Msgf("file watcher event channel closing")
return
}
switch event.Op {
case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
if event.Name != e.path {
continue
}
e.updateRecords()
// If a file is removed or renamed, fsnotify will loose track of it
// and not watch it. We will therefore attempt to re-add it with a backoff.
case fsnotify.Remove, fsnotify.Rename:
_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
//nolint:noinlineerr
if _, err := os.Stat(e.path); err != nil {
return struct{}{}, err
}
@ -123,6 +131,7 @@ func (e *ExtraRecordsMan) Run() {
log.Error().Caller().Msgf("file watcher error channel closing")
return
}
log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
}
}
@ -165,6 +174,7 @@ func (e *ExtraRecordsMan) updateRecords() {
e.hashes[e.path] = newHash
log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len())
e.updateCh <- e.records.Slice()
}
@ -183,6 +193,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error
}
var records []tailcfg.DNSRecord
err = json.Unmarshal(b, &records)
if err != nil {
return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err)

View file

@ -4,6 +4,7 @@
package hscontrol
import (
"cmp"
"context"
"errors"
"fmt"
@ -11,7 +12,6 @@ import (
"net/netip"
"os"
"slices"
"sort"
"strings"
"time"
@ -135,8 +135,8 @@ func (api headscaleV1APIServer) ListUsers(
response[index] = user.Proto()
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
slices.SortFunc(response, func(a, b *v1.User) int {
return cmp.Compare(a.Id, b.Id)
})
return &v1.ListUsersResponse{Users: response}, nil
@ -221,8 +221,8 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
response[index] = key.Proto()
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
slices.SortFunc(response, func(a, b *v1.PreAuthKey) int {
return cmp.Compare(a.Id, b.Id)
})
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
newApproved = append(newApproved, prefix)
}
}
tsaddr.SortPrefixes(newApproved)
slices.SortFunc(newApproved, netip.Prefix.Compare)
newApproved = slices.Compact(newApproved)
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
@ -535,8 +535,8 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N
response[index] = resp
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
slices.SortFunc(response, func(a, b *v1.Node) int {
return cmp.Compare(a.Id, b.Id)
})
return response
@ -632,8 +632,8 @@ func (api headscaleV1APIServer) ListApiKeys(
response[index] = key.Proto()
}
sort.Slice(response, func(i, j int) bool {
return response[i].Id < response[j].Id
slices.SortFunc(response, func(a, b *v1.ApiKey) int {
return cmp.Compare(a.Id, b.Id)
})
return &v1.ListApiKeysResponse{ApiKeys: response}, nil

View file

@ -36,8 +36,7 @@ const (
// httpError logs an error and sends an HTTP error response with the given.
func httpError(w http.ResponseWriter, err error) {
var herr HTTPError
if errors.As(err, &herr) {
if herr, ok := errors.AsType[HTTPError](err); ok {
http.Error(w, herr.Msg, herr.Code)
log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg)
} else {
@ -56,7 +55,7 @@ type HTTPError struct {
func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
func (e HTTPError) Unwrap() error { return e.Err }
// Error returns an HTTPError containing the given information.
// NewHTTPError returns an HTTPError containing the given information.
func NewHTTPError(code int, msg string, err error) HTTPError {
return HTTPError{Code: code, Msg: msg, Err: err}
}
@ -92,6 +91,7 @@ func (h *Headscale) handleVerifyRequest(
}
var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
//nolint:noinlineerr
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err))
}
@ -155,7 +155,11 @@ func (h *Headscale) KeyHandler(
}
writer.Header().Set("Content-Type", "application/json")
json.NewEncoder(writer).Encode(resp)
err := json.NewEncoder(writer).Encode(resp)
if err != nil {
log.Error().Err(err).Msg("failed to encode key response")
}
return
}
@ -180,8 +184,12 @@ func (h *Headscale) HealthHandler(
res.Status = "fail"
}
json.NewEncoder(writer).Encode(res)
//nolint:noinlineerr
if err := json.NewEncoder(writer).Encode(res); err != nil {
log.Error().Err(err).Msg("failed to encode health response")
}
}
err := h.state.PingDB(req.Context())
if err != nil {
respond(err)
@ -218,6 +226,7 @@ func (h *Headscale) VersionHandler(
writer.WriteHeader(http.StatusOK)
versionInfo := types.GetVersionInfo()
err := json.NewEncoder(writer).Encode(versionInfo)
if err != nil {
log.Error().
@ -267,7 +276,7 @@ func (a *AuthProviderWeb) RegisterHandler(
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
_, _ = writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
}
func FaviconHandler(writer http.ResponseWriter, req *http.Request) {

View file

@ -15,6 +15,17 @@ import (
"tailscale.com/tailcfg"
)
// Sentinel errors for batcher operations.
var (
ErrInvalidNodeID = errors.New("invalid nodeID")
ErrMapperNil = errors.New("mapper is nil")
ErrNodeConnectionNil = errors.New("nodeConnection is nil")
)
// workChannelMultiplier is the multiplier for work channel capacity based on worker count.
// The size is arbitrary chosen, the sizing should be revisited.
const workChannelMultiplier = 200
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
@ -42,8 +53,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
workers: workers,
tick: time.NewTicker(batchTime),
// The size of this channel is arbitrary chosen, the sizing should be revisited.
workCh: make(chan work, workers*200),
workCh: make(chan work, workers*workChannelMultiplier),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
@ -76,20 +86,20 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
version := nc.version()
if r.IsEmpty() {
return nil, nil //nolint:nilnil // Empty response means nothing to send
return nil, nil //nolint:nilnil
}
if nodeID == 0 {
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
}
if mapper == nil {
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
}
// Handle self-only responses
if r.IsSelfOnly() && r.TargetNode != nodeID {
return nil, nil //nolint:nilnil // No response needed for other nodes when self-only
return nil, nil //nolint:nilnil
}
// Check if this is a self-update (the changed node is the receiving node).
@ -135,7 +145,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
if nc == nil {
return errors.New("nodeConnection is nil")
return ErrNodeConnectionNil
}
nodeID := nc.nodeID()

View file

@ -2,6 +2,7 @@ package mapper
import (
"crypto/rand"
"encoding/hex"
"errors"
"fmt"
"sync"
@ -13,10 +14,35 @@ import (
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var errConnectionClosed = errors.New("connection channel already closed")
// Sentinel errors for lock-free batcher operations.
var (
errConnectionClosed = errors.New("connection channel already closed")
ErrInitialMapTimeout = errors.New("failed to send initial map: timeout")
ErrNodeNotFound = errors.New("node not found")
ErrBatcherShutdown = errors.New("batcher shutting down")
ErrConnectionTimeout = errors.New("connection timeout sending to channel (likely stale connection)")
)
// Batcher configuration constants.
const (
// initialMapSendTimeout is the timeout for sending the initial map response to a new connection.
initialMapSendTimeout = 5 * time.Second
// offlineNodeCleanupThreshold is how long a node must be offline before it's cleaned up.
offlineNodeCleanupThreshold = 15 * time.Minute
// offlineNodeCleanupInterval is the interval between cleanup runs.
offlineNodeCleanupInterval = 5 * time.Minute
// connectionSendTimeout is the timeout for detecting stale connections.
// Kept short to quickly detect Docker containers that are forcefully terminated.
connectionSendTimeout = 50 * time.Millisecond
// connectionIDBytes is the number of random bytes used for connection IDs.
connectionIDBytes = 8
)
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
type LockFreeBatcher struct {
@ -78,6 +104,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
if err != nil {
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
nodeConn.removeConnectionByChannel(c)
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
}
@ -86,12 +113,13 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
select {
case c <- initialMap:
// Success
case <-time.After(5 * time.Second):
log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
case <-time.After(initialMapSendTimeout):
log.Error().Uint64("node.id", id.Uint64()).Err(ErrInitialMapTimeout).Msg("Initial map send timeout")
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", initialMapSendTimeout).
Msg("Initial map send timed out because channel was blocked or receiver not ready")
nodeConn.removeConnectionByChannel(c)
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
return fmt.Errorf("%w for node %d", ErrInitialMapTimeout, id)
}
// Update connection status
@ -130,13 +158,14 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
log.Debug().Caller().Uint64("node.id", id.Uint64()).
Int("active.connections", nodeConn.getActiveConnectionCount()).
Msg("Node connection removed but keeping online because other connections remain")
return true // Node still has active connections
}
// No active connections - keep the node entry alive for rapid reconnections
// The node will get a fresh full map when it reconnects
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
b.connected.Store(id, ptr.To(time.Now()))
b.connected.Store(id, new(time.Now()))
return false
}
@ -177,7 +206,7 @@ func (b *LockFreeBatcher) doWork() {
}
// Create a cleanup ticker for removing truly disconnected nodes
cleanupTicker := time.NewTicker(5 * time.Minute)
cleanupTicker := time.NewTicker(offlineNodeCleanupInterval)
defer cleanupTicker.Stop()
for {
@ -212,10 +241,12 @@ func (b *LockFreeBatcher) worker(workerID int) {
// This is used for synchronous map generation.
if w.resultCh != nil {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists {
var err error
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
result.err = err
if result.err != nil {
b.workErrors.Add(1)
@ -229,7 +260,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
nc.updateSentPeers(result.mapResponse)
}
} else {
result.err = fmt.Errorf("node %d not found", w.nodeID)
result.err = fmt.Errorf("%w: %d", ErrNodeNotFound, w.nodeID)
b.workErrors.Add(1)
log.Error().Err(result.err).
@ -383,14 +414,13 @@ func (b *LockFreeBatcher) processBatchedChanges() {
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
// TODO(kradalby): reevaluate if we want to keep this.
func (b *LockFreeBatcher) cleanupOfflineNodes() {
cleanupThreshold := 15 * time.Minute
now := time.Now()
var nodesToCleanup []types.NodeID
// Find nodes that have been offline for too long
b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
if disconnectTime != nil && now.Sub(*disconnectTime) > offlineNodeCleanupThreshold {
// Double-check the node doesn't have active connections
if nodeConn, exists := b.nodes.Load(nodeID); exists {
if !nodeConn.hasActiveConnections() {
@ -398,13 +428,14 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
}
}
}
return true
})
// Clean up the identified nodes
for _, nodeID := range nodesToCleanup {
log.Info().Uint64("node.id", nodeID.Uint64()).
Dur("offline_duration", cleanupThreshold).
Dur("offline_duration", offlineNodeCleanupThreshold).
Msg("Cleaning up node that has been offline for too long")
b.nodes.Delete(nodeID)
@ -450,6 +481,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
if nodeConn.hasActiveConnections() {
ret.Store(id, true)
}
return true
})
@ -465,6 +497,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
ret.Store(id, false)
}
}
return true
})
@ -484,7 +517,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang
case result := <-resultCh:
return result.mapResponse, result.err
case <-b.done:
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
return nil, fmt.Errorf("%w: generating map response for node %d", ErrBatcherShutdown, id)
}
}
@ -517,9 +550,10 @@ type multiChannelNodeConn struct {
// generateConnectionID generates a unique connection identifier.
func generateConnectionID() string {
bytes := make([]byte, 8)
rand.Read(bytes)
return fmt.Sprintf("%x", bytes)
bytes := make([]byte, connectionIDBytes)
_, _ = rand.Read(bytes)
return hex.EncodeToString(bytes)
}
// newMultiChannelNodeConn creates a new multi-channel node connection.
@ -546,11 +580,14 @@ func (mc *multiChannelNodeConn) close() {
// addConnection adds a new connection.
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
mutexWaitStart := time.Now()
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
mc.mutex.Lock()
mutexWaitDur := time.Since(mutexWaitStart)
defer mc.mutex.Unlock()
mc.connections = append(mc.connections, entry)
@ -572,9 +609,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
Int("remaining_connections", len(mc.connections)).
Msg("Successfully removed connection")
return true
}
}
return false
}
@ -608,6 +647,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
// This is not an error - the node will receive a full map when it reconnects
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
return nil // Return success instead of error
}
@ -616,7 +656,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
Msg("send: broadcasting to all connections")
var lastErr error
successCount := 0
var failedConnections []int // Track failed connections for removal
// Send to all connections
@ -625,8 +667,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
Str("conn.id", conn.id).Int("connection_index", i).
Msg("send: attempting to send to connection")
if err := conn.send(data); err != nil {
err := conn.send(data)
if err != nil {
lastErr = err
failedConnections = append(failedConnections, i)
log.Warn().Err(err).
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
@ -634,6 +678,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
Msg("send: connection send failed")
} else {
successCount++
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
Str("conn.id", conn.id).Int("connection_index", i).
Msg("send: successfully sent to connection")
@ -685,10 +730,10 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
// Update last used timestamp on successful send
entry.lastUsed.Store(time.Now().Unix())
return nil
case <-time.After(50 * time.Millisecond):
case <-time.After(connectionSendTimeout):
// Connection is likely stale - client isn't reading from channel
// This catches the case where Docker containers are killed but channels remain open
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
return fmt.Errorf("%w: connection %s", ErrConnectionTimeout, entry.id)
}
}
@ -798,6 +843,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
Connected: connected,
ActiveConnections: activeConnCount,
}
return true
})
@ -812,6 +858,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
ActiveConnections: 0,
}
}
return true
})

View file

@ -5,6 +5,7 @@ import (
"fmt"
"net/netip"
"runtime"
"slices"
"strings"
"sync"
"sync/atomic"
@ -35,6 +36,7 @@ type batcherTestCase struct {
// that would normally be sent by poll.go in production.
type testBatcherWrapper struct {
Batcher
state *state.State
}
@ -80,12 +82,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
}
// Finally remove from the real batcher
removed := t.Batcher.RemoveNode(id, c)
if !removed {
return false
}
return true
return t.Batcher.RemoveNode(id, c)
}
// wrapBatcherForTest wraps a batcher with test-specific behavior.
@ -129,15 +126,13 @@ const (
SMALL_BUFFER_SIZE = 3
TINY_BUFFER_SIZE = 1 // For maximum contention
LARGE_BUFFER_SIZE = 200
reservedResponseHeaderSize = 4
)
// TestData contains all test entities created for a test scenario.
type TestData struct {
Database *db.HSDatabase
Users []*types.User
Nodes []node
Nodes []*node
State *state.State
Config *types.Config
Batcher Batcher
@ -223,11 +218,11 @@ func setupBatcherWithTestData(
// Create test users and nodes in the database
users := database.CreateUsersForTest(userCount, "testuser")
allNodes := make([]node, 0, userCount*nodesPerUser)
allNodes := make([]*node, 0, userCount*nodesPerUser)
for _, user := range users {
dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node")
for i := range dbNodes {
allNodes = append(allNodes, node{
allNodes = append(allNodes, &node{
n: dbNodes[i],
ch: make(chan *tailcfg.MapResponse, bufferSize),
})
@ -241,8 +236,8 @@ func setupBatcherWithTestData(
}
derpMap, err := derp.GetDERPMap(cfg.DERP)
assert.NoError(t, err)
assert.NotNil(t, derpMap)
require.NoError(t, err)
require.NotNil(t, derpMap)
state.SetDERPMap(derpMap)
@ -318,23 +313,6 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) {
stats.LastUpdate = time.Now()
}
// getStats returns a copy of the statistics for a node.
func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
ut.mu.RLock()
defer ut.mu.RUnlock()
if stats, exists := ut.stats[nodeID]; exists {
// Return a copy to avoid race conditions
return UpdateStats{
TotalUpdates: stats.TotalUpdates,
UpdateSizes: append([]int{}, stats.UpdateSizes...),
LastUpdate: stats.LastUpdate,
}
}
return UpdateStats{}
}
// getAllStats returns a copy of all statistics.
func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats {
ut.mu.RLock()
@ -344,7 +322,7 @@ func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats {
for nodeID, stats := range ut.stats {
result[nodeID] = UpdateStats{
TotalUpdates: stats.TotalUpdates,
UpdateSizes: append([]int{}, stats.UpdateSizes...),
UpdateSizes: slices.Clone(stats.UpdateSizes),
LastUpdate: stats.LastUpdate,
}
}
@ -386,16 +364,14 @@ type UpdateInfo struct {
}
// parseUpdateAndAnalyze parses an update and returns detailed information.
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) {
info := UpdateInfo{
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo {
return UpdateInfo{
PeerCount: len(resp.Peers),
PatchCount: len(resp.PeersChangedPatch),
IsFull: len(resp.Peers) > 0,
IsPatch: len(resp.PeersChangedPatch) > 0,
IsDERP: resp.DERPMap != nil,
}
return info, nil
}
// start begins consuming updates from the node's channel and tracking stats.
@ -417,36 +393,36 @@ func (n *node) start() {
atomic.AddInt64(&n.updateCount, 1)
// Parse update and track detailed stats
if info, err := parseUpdateAndAnalyze(data); err == nil {
// Track update types
if info.IsFull {
atomic.AddInt64(&n.fullCount, 1)
n.lastPeerCount.Store(int64(info.PeerCount))
// Update max peers seen using compare-and-swap for thread safety
for {
current := n.maxPeersCount.Load()
if int64(info.PeerCount) <= current {
break
}
info := parseUpdateAndAnalyze(data)
if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) {
break
}
// Track update types
if info.IsFull {
atomic.AddInt64(&n.fullCount, 1)
n.lastPeerCount.Store(int64(info.PeerCount))
// Update max peers seen using compare-and-swap for thread safety
for {
current := n.maxPeersCount.Load()
if int64(info.PeerCount) <= current {
break
}
if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) {
break
}
}
}
if info.IsPatch {
atomic.AddInt64(&n.patchCount, 1)
// For patches, we track how many patch items using compare-and-swap
for {
current := n.maxPeersCount.Load()
if int64(info.PatchCount) <= current {
break
}
if info.IsPatch {
atomic.AddInt64(&n.patchCount, 1)
// For patches, we track how many patch items using compare-and-swap
for {
current := n.maxPeersCount.Load()
if int64(info.PatchCount) <= current {
break
}
if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) {
break
}
if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) {
break
}
}
}
@ -540,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
defer cleanup()
batcher := testData.Batcher
testNode := &testData.Nodes[0]
testNode := testData.Nodes[0]
t.Logf("Testing enhanced tracking with node ID %d", testNode.n.ID)
@ -548,7 +524,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
testNode.start()
// Connect the node to the batcher
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
// Wait for connection to be established
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@ -656,8 +632,8 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
for i := range allNodes {
node := &allNodes[i]
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
node := allNodes[i]
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
// Issue full update after each join to ensure connectivity
batcher.AddWork(change.FullUpdate())
@ -676,8 +652,9 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
assert.EventuallyWithT(t, func(c *assert.CollectT) {
connectedCount := 0
for i := range allNodes {
node := &allNodes[i]
node := allNodes[i]
currentMaxPeers := int(node.maxPeersCount.Load())
if currentMaxPeers >= expectedPeers {
@ -693,11 +670,12 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
t.Logf("✅ All nodes achieved full connectivity!")
totalTime := time.Since(startTime)
// Disconnect all nodes
for i := range allNodes {
node := &allNodes[i]
node := allNodes[i]
batcher.RemoveNode(node.n.ID, node.ch)
}
@ -718,7 +696,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
nodeDetails := make([]string, 0, min(10, len(allNodes)))
for i := range allNodes {
node := &allNodes[i]
node := allNodes[i]
stats := node.cleanup()
totalUpdates += stats.TotalUpdates
@ -824,7 +802,7 @@ func TestBatcherBasicOperations(t *testing.T) {
tn2 := testData.Nodes[1]
// Test AddNode with real node ID
batcher.AddNode(tn.n.ID, tn.ch, 100)
_ = batcher.AddNode(tn.n.ID, tn.ch, 100)
if !batcher.IsConnected(tn.n.ID) {
t.Error("Node should be connected after AddNode")
@ -845,7 +823,7 @@ func TestBatcherBasicOperations(t *testing.T) {
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
// Add the second node and verify update message
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
_ = batcher.AddNode(tn2.n.ID, tn2.ch, 100)
assert.True(t, batcher.IsConnected(tn2.n.ID))
// First node should get an update that second node has connected.
@ -911,7 +889,7 @@ func TestBatcherBasicOperations(t *testing.T) {
}
}
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, _ string, timeout time.Duration) {
count := 0
timer := time.NewTimer(timeout)
@ -1050,7 +1028,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
testNodes := testData.Nodes
ch := make(chan *tailcfg.MapResponse, 10)
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
// Track update content for validation
var receivedUpdates []*tailcfg.MapResponse
@ -1130,6 +1108,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
// The test verifies that channels are closed synchronously and deterministically
// even when real node updates are being processed, ensuring no race conditions
// occur during channel replacement with actual workload.
//
func XTestBatcherChannelClosingRace(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
@ -1154,7 +1134,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
ch1 := make(chan *tailcfg.MapResponse, 1)
wg.Go(func() {
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
})
// Add real work during connection chaos
@ -1167,7 +1147,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
wg.Go(func() {
runtime.Gosched() // Yield to introduce timing variability
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
})
// Remove second connection
@ -1258,7 +1239,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, 5)
// Add node and immediately queue real work
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
batcher.AddWork(change.DERPMap())
// Consumer goroutine to validate data and detect channel issues
@ -1308,6 +1289,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
for range i % 3 {
runtime.Gosched() // Introduce timing variability
}
batcher.RemoveNode(testNode.n.ID, ch)
// Yield to allow workers to process and close channels
@ -1350,6 +1332,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// real node data. The test validates that stable clients continue to function
// normally and receive proper updates despite the connection churn from other clients,
// ensuring system stability under concurrent load.
//
//nolint:gocyclo
func TestBatcherConcurrentClients(t *testing.T) {
if testing.Short() {
t.Skip("Skipping concurrent client test in short mode")
@ -1380,7 +1364,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
for _, node := range stableNodes {
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
stableChannels[node.n.ID] = ch
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
// Monitor updates for each stable client
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
@ -1391,6 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Channel was closed, exit gracefully
return
}
if valid, reason := validateUpdateContent(data); valid {
tracker.recordUpdate(
nodeID,
@ -1448,10 +1433,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
churningChannelsMutex.Lock()
churningChannels[nodeID] = ch
churningChannelsMutex.Unlock()
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
// Consume updates to prevent blocking
go func() {
@ -1462,6 +1449,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Channel was closed, exit gracefully
return
}
if valid, _ := validateUpdateContent(data); valid {
tracker.recordUpdate(
nodeID,
@ -1494,6 +1482,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
for range i % 5 {
runtime.Gosched() // Introduce timing variability
}
churningChannelsMutex.Lock()
ch, exists := churningChannels[nodeID]
@ -1623,6 +1612,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
// It validates that the system remains stable with no deadlocks, panics, or
// missed updates under sustained high load. The test uses real node data to
// generate authentic update scenarios and tracks comprehensive statistics.
//
//nolint:gocyclo,thelper
func XTestBatcherScalability(t *testing.T) {
if testing.Short() {
t.Skip("Skipping scalability test in short mode")
@ -1651,6 +1642,7 @@ func XTestBatcherScalability(t *testing.T) {
description string
}
//nolint:prealloc
var testCases []testCase
// Generate all combinations of the test matrix
@ -1761,8 +1753,9 @@ func XTestBatcherScalability(t *testing.T) {
var connectedNodesMutex sync.RWMutex
for i := range testNodes {
node := &testNodes[i]
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
node := testNodes[i]
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
connectedNodesMutex.Lock()
connectedNodes[node.n.ID] = true
@ -1878,6 +1871,7 @@ func XTestBatcherScalability(t *testing.T) {
channel,
tailcfg.CapabilityVersion(100),
)
connectedNodesMutex.Lock()
connectedNodes[nodeID] = true
@ -1991,7 +1985,7 @@ func XTestBatcherScalability(t *testing.T) {
// Now disconnect all nodes from batcher to stop new updates
for i := range testNodes {
node := &testNodes[i]
node := testNodes[i]
batcher.RemoveNode(node.n.ID, node.ch)
}
@ -2010,7 +2004,7 @@ func XTestBatcherScalability(t *testing.T) {
nodeStatsReport := make([]string, 0, len(testNodes))
for i := range testNodes {
node := &testNodes[i]
node := testNodes[i]
stats := node.cleanup()
totalUpdates += stats.TotalUpdates
totalPatches += stats.PatchUpdates
@ -2139,7 +2133,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Connect nodes one at a time and wait for each to be connected
for i, node := range allNodes {
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
// Wait for node to be connected
@ -2286,6 +2280,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
// Phase 1: Connect all nodes initially
t.Logf("Phase 1: Connecting all nodes...")
for i, node := range allNodes {
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
if err != nil {
@ -2302,6 +2297,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
t.Logf("Phase 2: Rapid disconnect all nodes...")
for i, node := range allNodes {
removed := batcher.RemoveNode(node.n.ID, node.ch)
t.Logf("Node %d RemoveNode result: %t", i, removed)
@ -2309,9 +2305,11 @@ func TestBatcherRapidReconnection(t *testing.T) {
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
t.Logf("Phase 3: Rapid reconnect with new channels...")
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
for i, node := range allNodes {
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
if err != nil {
t.Errorf("Failed to reconnect node %d: %v", i, err)
@ -2342,11 +2340,13 @@ func TestBatcherRapidReconnection(t *testing.T) {
if infoMap, ok := info.(map[string]any); ok {
if connected, ok := infoMap["connected"].(bool); ok && !connected {
disconnectedCount++
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
}
}
} else {
disconnectedCount++
t.Logf("Node %d missing from debug info entirely", i)
}
@ -2381,6 +2381,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
case update := <-newChannels[i]:
if update != nil {
receivedCount++
t.Logf("Node %d received update successfully", i)
}
case <-timeout:
@ -2399,6 +2400,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
}
}
//nolint:gocyclo
func TestBatcherMultiConnection(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
@ -2413,6 +2415,7 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 1: Connect first node with initial connection
t.Logf("Phase 1: Connecting node 1 with first connection...")
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add node1: %v", err)
@ -2432,7 +2435,9 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 2: Add second connection for node1 (multi-connection scenario)
t.Logf("Phase 2: Adding second connection for node 1...")
secondChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add second connection for node1: %v", err)
@ -2443,7 +2448,9 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 3: Add third connection for node1
t.Logf("Phase 3: Adding third connection for node 1...")
thirdChannel := make(chan *tailcfg.MapResponse, 10)
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
if err != nil {
t.Fatalf("Failed to add third connection for node1: %v", err)
@ -2454,6 +2461,7 @@ func TestBatcherMultiConnection(t *testing.T) {
// Phase 4: Verify debug status shows correct connection count
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
if debugBatcher, ok := batcher.(interface {
Debug() map[types.NodeID]any
}); ok {
@ -2461,6 +2469,7 @@ func TestBatcherMultiConnection(t *testing.T) {
if info, exists := debugInfo[node1.n.ID]; exists {
t.Logf("Node1 debug info: %+v", info)
if infoMap, ok := info.(map[string]any); ok {
if activeConnections, ok := infoMap["active_connections"].(int); ok {
if activeConnections != 3 {
@ -2469,6 +2478,7 @@ func TestBatcherMultiConnection(t *testing.T) {
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
}
}
if connected, ok := infoMap["connected"].(bool); ok && !connected {
t.Errorf("Node1 should show as connected with 3 active connections")
}
@ -2651,9 +2661,9 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) {
batcher := testData.Batcher
st := testData.State
node1 := &testData.Nodes[0]
node2 := &testData.Nodes[1]
node3 := &testData.Nodes[2]
node1 := testData.Nodes[0]
node2 := testData.Nodes[1]
node3 := testData.Nodes[2]
t.Logf("Testing issue #2924: Node1=%d, Node2=%d, Node3=%d",
node1.n.ID, node2.n.ID, node3.n.ID)

View file

@ -1,9 +1,9 @@
package mapper
import (
"errors"
"cmp"
"net/netip"
"sort"
"slices"
"time"
"github.com/juanfont/headscale/hscontrol/policy"
@ -36,6 +36,7 @@ const (
// NewMapResponseBuilder creates a new builder with basic fields set.
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
now := time.Now()
return &MapResponseBuilder{
resp: &tailcfg.MapResponse{
KeepAlive: false,
@ -69,7 +70,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFound)
return b
}
@ -123,6 +124,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
b.resp.Debug = &tailcfg.Debug{
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
}
return b
}
@ -130,7 +132,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFound)
return b
}
@ -149,7 +151,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFound)
return b
}
@ -162,7 +164,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFound)
return b
}
@ -175,7 +177,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView])
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
b.addError(errors.New("node not found"))
b.addError(ErrNodeNotFound)
return b
}
@ -229,7 +231,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
if !ok {
return nil, errors.New("node not found")
return nil, ErrNodeNotFound
}
// Get unreduced matchers for peer relationship determination.
@ -261,8 +263,8 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) (
}
// Peers is always returned sorted by Node.ID.
sort.SliceStable(tailPeers, func(x, y int) bool {
return tailPeers[x].ID < tailPeers[y].ID
slices.SortStableFunc(tailPeers, func(a, b *tailcfg.Node) int {
return cmp.Compare(a.ID, b.ID)
})
return tailPeers, nil
@ -276,20 +278,23 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)
// WithPeersRemoved adds removed peer IDs.
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
//nolint:prealloc
var tailscaleIDs []tailcfg.NodeID
for _, id := range removedIDs {
tailscaleIDs = append(tailscaleIDs, id.NodeID())
}
b.resp.PeersRemoved = tailscaleIDs
return b
}
// Build finalizes the response and returns marshaled bytes
// Build finalizes the response and returns marshaled bytes.
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
if len(b.errs) > 0 {
return nil, multierr.New(b.errs...)
}
if debugDumpMapResponsePath != "" {
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
}

View file

@ -340,7 +340,7 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
// Build should return a multierr
data, err := result.Build()
assert.Nil(t, data)
assert.Error(t, err)
require.Error(t, err)
// The error should contain information about multiple errors
assert.Contains(t, err.Error(), "multiple errors")

View file

@ -60,7 +60,6 @@ func newMapper(
state *state.State,
) *mapper {
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
return &mapper{
state: state,
cfg: cfg,
@ -80,6 +79,7 @@ func generateUserProfiles(
userID := user.Model().ID
userMap[userID] = &user
ids = append(ids, userID)
for _, peer := range peers.All() {
peerUser := peer.Owner()
peerUserID := peerUser.Model().ID
@ -90,6 +90,7 @@ func generateUserProfiles(
slices.Sort(ids)
ids = slices.Compact(ids)
var profiles []tailcfg.UserProfile
for _, id := range ids {
if userMap[id] != nil {
profiles = append(profiles, userMap[id].TailscaleUserProfile())
@ -138,29 +139,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
}
}
// fullMapResponse returns a MapResponse for the given node.
func (m *mapper) fullMapResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
peers := m.state.ListPeers(nodeID)
return m.NewMapResponseBuilder(nodeID).
WithDebugType(fullResponseDebug).
WithCapabilityVersion(capVer).
WithSelfNode().
WithDERPMap().
WithDomain().
WithCollectServicesDisabled().
WithDebugConfig().
WithSSHPolicy().
WithDNSConfig().
WithUserProfiles(peers).
WithPacketFilters().
WithPeers(peers).
Build()
}
func (m *mapper) selfMapResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
@ -214,7 +192,7 @@ func (m *mapper) policyChangeResponse(
// Convert tailcfg.NodeID to types.NodeID for WithPeersRemoved
removedIDs := make([]types.NodeID, len(removedPeers))
for i, id := range removedPeers {
removedIDs[i] = types.NodeID(id) //nolint:gosec // NodeID types are equivalent
removedIDs[i] = types.NodeID(id) //nolint:gosec
}
builder.WithPeersRemoved(removedIDs...)
@ -237,7 +215,7 @@ func (m *mapper) buildFromChange(
resp *change.Change,
) (*tailcfg.MapResponse, error) {
if resp.IsEmpty() {
return nil, nil //nolint:nilnil // Empty response means nothing to send, not an error
return nil, nil //nolint:nilnil
}
// If this is a self-update (the changed node is the receiving node),
@ -306,6 +284,7 @@ func writeDebugMapResponse(
perms := fs.FileMode(debugMapResponsePerm)
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
err = os.MkdirAll(mPath, perms)
if err != nil {
panic(err)
@ -319,6 +298,7 @@ func writeDebugMapResponse(
)
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
err = os.WriteFile(mapResponsePath, body, perms)
if err != nil {
panic(err)
@ -327,7 +307,7 @@ func writeDebugMapResponse(
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
if debugDumpMapResponsePath == "" {
return nil, nil
return nil, nil //nolint:nilnil
}
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
@ -375,6 +355,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
}
var resp tailcfg.MapResponse
err = json.Unmarshal(body, &resp)
if err != nil {
log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name())

View file

@ -3,18 +3,13 @@ package mapper
import (
"fmt"
"net/netip"
"slices"
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/policy/matcher"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/ptr"
)
var iap = func(ipStr string) *netip.Addr {
@ -51,7 +46,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
mach := func(hostname, username string, userid uint) *types.Node {
return &types.Node{
Hostname: hostname,
UserID: ptr.To(userid),
UserID: new(userid),
User: &types.User{
Name: username,
},
@ -82,86 +77,6 @@ func TestDNSConfigMapResponse(t *testing.T) {
}
}
// mockState is a mock implementation that provides the required methods.
type mockState struct {
polMan policy.PolicyManager
derpMap *tailcfg.DERPMap
primary *routes.PrimaryRoutes
nodes types.Nodes
peers types.Nodes
}
func (m *mockState) DERPMap() *tailcfg.DERPMap {
return m.derpMap
}
func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
if m.polMan == nil {
return tailcfg.FilterAllowAll, nil
}
return m.polMan.Filter()
}
func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
if m.polMan == nil {
return nil, nil
}
return m.polMan.SSHPolicy(node)
}
func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
if m.polMan == nil {
return false
}
return m.polMan.NodeCanHaveTag(node, tag)
}
func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
if m.primary == nil {
return nil
}
return m.primary.PrimaryRoutes(nodeID)
}
func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
if len(peerIDs) > 0 {
// Filter peers by the provided IDs
var filtered types.Nodes
for _, peer := range m.peers {
if slices.Contains(peerIDs, peer.ID) {
filtered = append(filtered, peer)
}
}
return filtered, nil
}
// Return all peers except the node itself
var filtered types.Nodes
for _, peer := range m.peers {
if peer.ID != nodeID {
filtered = append(filtered, peer)
}
}
return filtered, nil
}
func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
if len(nodeIDs) > 0 {
// Filter nodes by the provided IDs
var filtered types.Nodes
for _, node := range m.nodes {
if slices.Contains(nodeIDs, node.ID) {
filtered = append(filtered, node)
}
}
return filtered, nil
}
return m.nodes, nil
}
func Test_fullMapResponse(t *testing.T) {
t.Skip("Test needs to be refactored for new state-based architecture")
// TODO: Refactor this test to work with the new state-based mapper

View file

@ -13,7 +13,6 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestTailNode(t *testing.T) {
@ -95,7 +94,7 @@ func TestTailNode(t *testing.T) {
IPv4: iap("100.64.0.1"),
Hostname: "mini",
GivenName: "mini",
UserID: ptr.To(uint(0)),
UserID: new(uint(0)),
User: &types.User{
Name: "mini",
},
@ -136,10 +135,10 @@ func TestTailNode(t *testing.T) {
),
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
AllowedIPs: []netip.Prefix{
tsaddr.AllIPv4(),
netip.MustParsePrefix("192.168.0.0/24"),
netip.MustParsePrefix("100.64.0.1/32"),
tsaddr.AllIPv6(),
tsaddr.AllIPv4(), // 0.0.0.0/0
netip.MustParsePrefix("100.64.0.1/32"), // lower IPv4
netip.MustParsePrefix("192.168.0.0/24"), // higher IPv4
tsaddr.AllIPv6(), // ::/0 (IPv6)
},
PrimaryRoutes: []netip.Prefix{
netip.MustParsePrefix("192.168.0.0/24"),

View file

@ -31,6 +31,9 @@ const (
earlyPayloadMagic = "\xff\xff\xffTS"
)
// ErrUnsupportedClientVersion is returned when a client version is not supported.
var ErrUnsupportedClientVersion = errors.New("unsupported client version")
type noiseServer struct {
headscale *Headscale
@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler(
}
func unsupportedClientError(version tailcfg.CapabilityVersion) error {
return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version)
return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version)
}
func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
@ -241,13 +244,17 @@ func (ns *noiseServer) NoiseRegistrationHandler(
return
}
//nolint:contextcheck
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) {
var resp *tailcfg.RegisterResponse
body, err := io.ReadAll(req.Body)
if err != nil {
return &tailcfg.RegisterRequest{}, regErr(err)
}
var regReq tailcfg.RegisterRequest
//nolint:noinlineerr
if err := json.Unmarshal(body, &regReq); err != nil {
return &regReq, regErr(err)
}
@ -256,11 +263,11 @@ func (ns *noiseServer) NoiseRegistrationHandler(
resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer())
if err != nil {
var httpErr HTTPError
if errors.As(err, &httpErr) {
if httpErr, ok := errors.AsType[HTTPError](err); ok {
resp = &tailcfg.RegisterResponse{
Error: httpErr.Msg,
}
return &regReq, resp
}
@ -278,7 +285,8 @@ func (ns *noiseServer) NoiseRegistrationHandler(
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
writer.WriteHeader(http.StatusOK)
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
err := json.NewEncoder(writer).Encode(registerResponse)
if err != nil {
log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
return
}

View file

@ -28,6 +28,7 @@ const (
defaultOAuthOptionsCount = 3
registerCacheExpiration = time.Minute * 15
registerCacheCleanup = time.Minute * 20
csrfTokenLength = 64
)
var (
@ -68,6 +69,7 @@ func NewAuthProviderOIDC(
) (*AuthProviderOIDC, error) {
var err error
// grab oidc config if it hasn't been already
//nolint:contextcheck
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
if err != nil {
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
@ -163,6 +165,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
for k, v := range a.cfg.ExtraParams {
extras = append(extras, oauth2.SetAuthURLParam(k, v))
}
extras = append(extras, oidc.Nonce(nonce))
// Cache the registration info
@ -190,6 +193,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
}
stateCookieName := getCookieName("state", state)
cookieState, err := req.Cookie(stateCookieName)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
@ -212,17 +216,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
httpError(writer, err)
return
}
if idToken.Nonce == "" {
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
return
}
nonceCookieName := getCookieName("nonce", idToken.Nonce)
nonce, err := req.Cookie(nonceCookieName)
if err != nil {
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
return
}
if idToken.Nonce != nonce.Value {
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
return
@ -231,6 +238,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
var claims types.OIDCClaims
//nolint:noinlineerr
if err := idToken.Claims(&claims); err != nil {
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
return
@ -239,6 +247,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Fetch user information (email, groups, name, etc) from the userinfo endpoint
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
var userinfo *oidc.UserInfo
userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token))
if err != nil {
util.LogErr(err, "could not get userinfo; only using claims from id token")
@ -255,6 +264,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified)
claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username)
claims.Name = cmp.Or(userinfo2.Name, claims.Name)
claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL)
if userinfo2.Groups != nil {
claims.Groups = userinfo2.Groups
@ -279,6 +289,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
Msgf("could not create or update user")
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
writer.WriteHeader(http.StatusInternalServerError)
_, werr := writer.Write([]byte("Could not create or update user"))
if werr != nil {
log.Error().
@ -299,6 +310,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// Register the node if it does not exist.
if registrationId != nil {
verb := "Reauthenticated"
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
if err != nil {
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
@ -307,7 +319,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
return
}
httpError(writer, err)
return
}
@ -324,6 +338,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
//nolint:noinlineerr
if _, err := writer.Write(content.Bytes()); err != nil {
util.LogErr(err, "Failed to write HTTP response")
}
@ -370,6 +386,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
if !ok {
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
}
if regInfo.Verifier != nil {
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
}
@ -516,6 +533,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
newUser bool
c change.Change
)
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
@ -600,7 +618,7 @@ func getCookieName(baseName, value string) string {
}
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
val, err := util.GenerateRandomStringURLSafe(64)
val, err := util.GenerateRandomStringURLSafe(csrfTokenLength)
if err != nil {
return val, err
}

View file

@ -19,7 +19,7 @@ func (h *Headscale) WindowsConfigMessage(
) {
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
_, _ = writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
}
// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
@ -29,7 +29,7 @@ func (h *Headscale) AppleConfigMessage(
) {
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
_, _ = writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
}
func (h *Headscale) ApplePlatformConfig(
@ -98,7 +98,7 @@ func (h *Headscale) ApplePlatformConfig(
writer.Header().
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
writer.WriteHeader(http.StatusOK)
writer.Write(content.Bytes())
_, _ = writer.Write(content.Bytes())
}
type AppleMobileConfig struct {

View file

@ -21,10 +21,13 @@ func (m Match) DebugString() string {
sb.WriteString("Match:\n")
sb.WriteString(" Sources:\n")
for _, prefix := range m.srcs.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}
sb.WriteString(" Destinations:\n")
for _, prefix := range m.dests.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}

View file

@ -19,18 +19,18 @@ type PolicyManager interface {
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
// BuildPeerMap constructs peer relationship maps for the given nodes
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy([]byte) (bool, error)
SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy(data []byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)
// NodeCanHaveTag reports whether the given node can have the given tag.
NodeCanHaveTag(types.NodeView, string) bool
NodeCanHaveTag(node types.NodeView, tag string) bool
// TagExists reports whether the given tag is defined in the policy.
TagExists(tag string) bool
// NodeCanApproveRoute reports whether the given node can approve the given route.
NodeCanApproveRoute(types.NodeView, netip.Prefix) bool
NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool
Version() int
DebugString() string
@ -38,8 +38,11 @@ type PolicyManager interface {
// NewPolicyManager returns a new policy manager.
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
var polMan PolicyManager
var err error
var (
polMan PolicyManager
err error
)
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
if err != nil {
return nil, err
@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
if err != nil {
return nil, err
}
polMans = append(polMans, pm)
}
@ -66,6 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
}
func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) {
//nolint:prealloc
var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error)
polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) {

View file

@ -9,7 +9,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/types/views"
)
@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
}
// Sort and deduplicate
tsaddr.SortPrefixes(newApproved)
slices.SortFunc(newApproved, netip.Prefix.Compare)
newApproved = slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
return route.IsValid()
@ -120,12 +119,13 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
// Sort the current approved for comparison
sortedCurrent := make([]netip.Prefix, len(currentApproved))
copy(sortedCurrent, currentApproved)
tsaddr.SortPrefixes(sortedCurrent)
slices.SortFunc(sortedCurrent, netip.Prefix.Compare)
// Only update if the routes actually changed
if !slices.Equal(sortedCurrent, newApproved) {
// Log what changed
var added, kept []netip.Prefix
for _, route := range newApproved {
if !slices.Contains(sortedCurrent, route) {
added = append(added, route)

View file

@ -3,16 +3,16 @@ package policy
import (
"fmt"
"net/netip"
"slices"
"testing"
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
)
@ -32,10 +32,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test-node",
UserID: ptr.To(user1.ID),
User: ptr.To(user1),
UserID: new(user1.ID),
User: new(user1),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
Tags: []string{"tag:test"},
}
@ -44,10 +44,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "other-node",
UserID: ptr.To(user2.ID),
User: ptr.To(user2),
UserID: new(user2.ID),
User: new(user2),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
IPv4: new(netip.MustParseAddr("100.64.0.2")),
}
// Create a policy that auto-approves specific routes
@ -76,7 +76,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
}`
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
assert.NoError(t, err)
require.NoError(t, err)
tests := []struct {
name string
@ -194,7 +194,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
tsaddr.SortPrefixes(tt.wantApproved)
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
// Verify that all previously approved routes are still present
@ -304,20 +304,23 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: ptr.To(user.ID),
User: ptr.To(user),
UserID: new(user.ID),
User: new(user),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
}
nodes := types.Nodes{&node}
// Create policy manager or use nil if specified
var pm PolicyManager
var err error
var (
pm PolicyManager
err error
)
if tt.name != "nil_policy_manager" {
pm, err = pmf(users, nodes.ViewSlice())
assert.NoError(t, err)
require.NoError(t, err)
} else {
pm = nil
}
@ -330,7 +333,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
if tt.wantApproved == nil {
assert.Nil(t, gotApproved, "expected nil approved routes")
} else {
tsaddr.SortPrefixes(tt.wantApproved)
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
}
})

View file

@ -13,7 +13,6 @@ import (
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
@ -91,9 +90,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
},
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
nodeUser: "test",
// Sorted by netip.Prefix.Compare: by IP address then by prefix length
wantApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("10.0.0.0/24"),
netip.MustParsePrefix("172.16.0.0/16"),
netip.MustParsePrefix("192.168.0.0/24"),
},
wantChanged: false,
@ -123,9 +123,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
},
nodeUser: "test",
nodeTags: []string{"tag:approved"},
// Sorted by netip.Prefix.Compare: by IP address then by prefix length
wantApproved: []netip.Prefix{
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
},
wantChanged: true,
},
@ -168,13 +169,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: tt.nodeHostname,
UserID: ptr.To(user.ID),
User: ptr.To(user),
UserID: new(user.ID),
User: new(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
Tags: tt.nodeTags,
}
@ -294,13 +295,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: ptr.To(user.ID),
User: ptr.To(user),
UserID: new(user.ID),
User: new(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
}
nodes := types.Nodes{&node}
@ -326,6 +327,7 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
}
func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
//nolint:staticcheck
user := types.User{
Model: gorm.Model{ID: 1},
Name: "test",
@ -343,13 +345,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: ptr.To(user.ID),
User: ptr.To(user),
UserID: new(user.ID),
User: new(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
IPv4: new(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: currentApproved,
}

View file

@ -14,7 +14,6 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var ap = func(ipStr string) *netip.Addr {
@ -33,6 +32,7 @@ func TestReduceNodes(t *testing.T) {
rules []tailcfg.FilterRule
node *types.Node
}
tests := []struct {
name string
args args
@ -783,9 +783,11 @@ func TestReduceNodes(t *testing.T) {
for _, v := range gotViews.All() {
got = append(got, v.AsStruct())
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff)
t.Log("Matchers: ")
for _, m := range matchers {
t.Log("\t+", m.DebugString())
}
@ -796,6 +798,7 @@ func TestReduceNodes(t *testing.T) {
func TestReduceNodesFromPolicy(t *testing.T) {
n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node {
//nolint:prealloc
var routes []netip.Prefix
for _, route := range routess {
routes = append(routes, netip.MustParsePrefix(route))
@ -1032,8 +1035,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
for _, tt := range tests {
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm PolicyManager
var err error
var (
pm PolicyManager
err error
)
pm, err = pmf(nil, tt.nodes.ViewSlice())
require.NoError(t, err)
@ -1051,9 +1057,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
for _, v := range gotViews.All() {
got = append(got, v.AsStruct())
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff)
t.Log("Matchers: ")
for _, m := range matchers {
t.Log("\t+", m.DebugString())
}
@ -1074,21 +1082,21 @@ func TestSSHPolicyRules(t *testing.T) {
nodeUser1 := types.Node{
Hostname: "user1-device",
IPv4: ap("100.64.0.1"),
UserID: ptr.To(uint(1)),
User: ptr.To(users[0]),
UserID: new(uint(1)),
User: new(users[0]),
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: ap("100.64.0.2"),
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
UserID: new(uint(2)),
User: new(users[1]),
}
taggedClient := types.Node{
Hostname: "tagged-client",
IPv4: ap("100.64.0.4"),
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
UserID: new(uint(2)),
User: new(users[1]),
Tags: []string{"tag:client"},
}
@ -1207,7 +1215,7 @@ func TestSSHPolicyRules(t *testing.T) {
]
}`,
expectErr: true,
errorMessage: `invalid SSH action "invalid", must be one of: accept, check`,
errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`,
},
{
name: "invalid-check-period",
@ -1242,7 +1250,7 @@ func TestSSHPolicyRules(t *testing.T) {
]
}`,
expectErr: true,
errorMessage: "autogroup \"autogroup:invalid\" is not supported",
errorMessage: `autogroup not supported for SSH: "autogroup:invalid" for SSH user`,
},
{
name: "autogroup-nonroot-should-use-wildcard-with-root-excluded",
@ -1406,13 +1414,17 @@ func TestSSHPolicyRules(t *testing.T) {
for _, tt := range tests {
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm PolicyManager
var err error
var (
pm PolicyManager
err error
)
pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
if tt.expectErr {
require.Error(t, err)
require.Contains(t, err.Error(), tt.errorMessage)
return
}
@ -1435,6 +1447,7 @@ func TestReduceRoutes(t *testing.T) {
routes []netip.Prefix
rules []tailcfg.FilterRule
}
tests := []struct {
name string
args args
@ -2056,6 +2069,7 @@ func TestReduceRoutes(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
got := ReduceRoutes(
tt.args.node.View(),
tt.args.routes,

View file

@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
for _, rule := range rules {
// record if the rule is actually relevant for the given node.
var dests []tailcfg.NetPortRange
DEST_LOOP:
for _, dest := range rule.DstPorts {
expanded, err := util.ParseIPSet(dest.IP, nil)

View file

@ -16,7 +16,6 @@ import (
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
"tailscale.com/util/must"
)
@ -144,13 +143,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: ptr.To(users[0]),
User: new(users[0]),
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: ptr.To(users[0]),
User: new(users[0]),
},
},
want: []tailcfg.FilterRule{},
@ -191,7 +190,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[1]),
User: new(users[1]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
@ -202,7 +201,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
User: new(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -283,19 +282,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[1]),
User: new(users[1]),
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[2]),
User: new(users[2]),
},
// "internal" exit node
&types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: ptr.To(users[3]),
User: new(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -344,7 +343,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: ptr.To(users[3]),
User: new(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -353,12 +352,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[2]),
User: new(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[1]),
User: new(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -453,7 +452,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: ptr.To(users[3]),
User: new(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -462,12 +461,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[2]),
User: new(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[1]),
User: new(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -565,7 +564,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: ptr.To(users[3]),
User: new(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
},
@ -574,12 +573,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[2]),
User: new(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[1]),
User: new(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -655,7 +654,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: ptr.To(users[3]),
User: new(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
},
@ -664,12 +663,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[2]),
User: new(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[1]),
User: new(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -737,7 +736,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: ptr.To(users[3]),
User: new(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
@ -747,7 +746,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[1]),
User: new(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -804,13 +803,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[3]),
User: new(users[3]),
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[1]),
User: new(users[1]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},
@ -824,10 +823,14 @@ func TestReduceFilterRules(t *testing.T) {
for _, tt := range tests {
for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) {
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
var pm policy.PolicyManager
var err error
var (
pm policy.PolicyManager
err error
)
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
require.NoError(t, err)
got, _ := pm.Filter()
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
got = policyutil.ReduceFilterRules(tt.node.View(), got)

View file

@ -10,7 +10,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/types/ptr"
)
func TestNodeCanApproveRoute(t *testing.T) {
@ -25,24 +24,24 @@ func TestNodeCanApproveRoute(t *testing.T) {
ID: 1,
Hostname: "user1-device",
IPv4: ap("100.64.0.1"),
UserID: ptr.To(uint(1)),
User: ptr.To(users[0]),
UserID: new(uint(1)),
User: new(users[0]),
}
exitNode := types.Node{
ID: 2,
Hostname: "user2-device",
IPv4: ap("100.64.0.2"),
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
UserID: new(uint(2)),
User: new(users[1]),
}
taggedNode := types.Node{
ID: 3,
Hostname: "tagged-server",
IPv4: ap("100.64.0.3"),
UserID: ptr.To(uint(3)),
User: ptr.To(users[2]),
UserID: new(uint(3)),
User: new(users[2]),
Tags: []string{"tag:router"},
}
@ -50,8 +49,8 @@ func TestNodeCanApproveRoute(t *testing.T) {
ID: 4,
Hostname: "multi-tag-node",
IPv4: ap("100.64.0.4"),
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
UserID: new(uint(2)),
User: new(users[1]),
Tags: []string{"tag:router", "tag:server"},
}
@ -830,6 +829,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
if tt.name == "empty policy" {
// We expect this one to have a valid but empty policy
require.NoError(t, err)
if err != nil {
return
}
@ -844,6 +844,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
if diff := cmp.Diff(tt.canApprove, result); diff != "" {
t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff)
}
assert.Equal(t, tt.canApprove, result, "Unexpected route approval result")
})
}

View file

@ -1,7 +1,6 @@
package v2
import (
"errors"
"fmt"
"slices"
"time"
@ -14,8 +13,6 @@ import (
"tailscale.com/types/views"
)
var ErrInvalidAction = errors.New("invalid action")
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
// set of Tailscale compatible FilterRules used to allow traffic on clients.
func (pol *Policy) compileFilterRules(
@ -42,9 +39,10 @@ func (pol *Policy) compileFilterRules(
continue
}
protocols, _ := acl.Protocol.parseProtocol()
protocols := acl.Protocol.parseProtocol()
var destPorts []tailcfg.NetPortRange
for _, dest := range acl.Destinations {
ips, err := dest.Resolve(pol, users, nodes)
if err != nil {
@ -121,14 +119,18 @@ func (pol *Policy) compileFilterRulesForNode(
// It returns a slice of filter rules because when an ACL has both autogroup:self
// and other destinations, they need to be split into separate rules with different
// source filtering logic.
//
//nolint:gocyclo
func (pol *Policy) compileACLWithAutogroupSelf(
acl ACL,
users types.Users,
node types.NodeView,
nodes views.Slice[types.NodeView],
) ([]*tailcfg.FilterRule, error) {
var autogroupSelfDests []AliasWithPorts
var otherDests []AliasWithPorts
var (
autogroupSelfDests []AliasWithPorts
otherDests []AliasWithPorts
)
for _, dest := range acl.Destinations {
if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
@ -138,14 +140,15 @@ func (pol *Policy) compileACLWithAutogroupSelf(
}
}
protocols, _ := acl.Protocol.parseProtocol()
protocols := acl.Protocol.parseProtocol()
var rules []*tailcfg.FilterRule
var resolvedSrcIPs []*netipx.IPSet
for _, src := range acl.Sources {
if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
return nil, fmt.Errorf("autogroup:self cannot be used in sources")
return nil, ErrAutogroupSelfInSource
}
ips, err := src.Resolve(pol, users, nodes)
@ -167,6 +170,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
if len(autogroupSelfDests) > 0 {
// Pre-filter to same-user untagged devices once - reuse for both sources and destinations
sameUserNodes := make([]types.NodeView, 0)
for _, n := range nodes.All() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
sameUserNodes = append(sameUserNodes, n)
@ -176,6 +180,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
if len(sameUserNodes) > 0 {
// Filter sources to only same-user untagged devices
var srcIPs netipx.IPSetBuilder
for _, ips := range resolvedSrcIPs {
for _, n := range sameUserNodes {
// Check if any of this node's IPs are in the source set
@ -192,6 +197,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
if srcSet != nil && len(srcSet.Prefixes()) > 0 {
var destPorts []tailcfg.NetPortRange
for _, dest := range autogroupSelfDests {
for _, n := range sameUserNodes {
for _, port := range dest.Ports {
@ -280,13 +286,14 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
}
}
//nolint:gocyclo
func (pol *Policy) compileSSHPolicy(
users types.Users,
node types.NodeView,
nodes views.Slice[types.NodeView],
) (*tailcfg.SSHPolicy, error) {
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
return nil, nil
return nil, nil //nolint:nilnil
}
log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname())
@ -297,8 +304,10 @@ func (pol *Policy) compileSSHPolicy(
// Separate destinations into autogroup:self and others
// This is needed because autogroup:self requires filtering sources to same-user only,
// while other destinations should use all resolved sources
var autogroupSelfDests []Alias
var otherDests []Alias
var (
autogroupSelfDests []Alias
otherDests []Alias
)
for _, dst := range rule.Destinations {
if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
@ -321,6 +330,7 @@ func (pol *Policy) compileSSHPolicy(
}
var action tailcfg.SSHAction
switch rule.Action {
case SSHActionAccept:
action = sshAction(true, 0)
@ -336,9 +346,11 @@ func (pol *Policy) compileSSHPolicy(
// by default, we do not allow root unless explicitly stated
userMap["root"] = ""
}
if rule.Users.ContainsRoot() {
userMap["root"] = "root"
}
for _, u := range rule.Users.NormalUsers() {
userMap[u.String()] = u.String()
}
@ -348,6 +360,7 @@ func (pol *Policy) compileSSHPolicy(
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
// Build destination set for autogroup:self (same-user untagged devices only)
var dest netipx.IPSetBuilder
for _, n := range nodes.All() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
n.AppendToIPSet(&dest)
@ -364,6 +377,7 @@ func (pol *Policy) compileSSHPolicy(
// Filter sources to only same-user untagged devices
// Pre-filter to same-user untagged devices for efficiency
sameUserNodes := make([]types.NodeView, 0)
for _, n := range nodes.All() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
sameUserNodes = append(sameUserNodes, n)
@ -371,6 +385,7 @@ func (pol *Policy) compileSSHPolicy(
}
var filteredSrcIPs netipx.IPSetBuilder
for _, n := range sameUserNodes {
// Check if any of this node's IPs are in the source set
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
@ -406,12 +421,14 @@ func (pol *Policy) compileSSHPolicy(
if len(otherDests) > 0 {
// Build destination set for other destinations
var dest netipx.IPSetBuilder
for _, dst := range otherDests {
ips, err := dst.Resolve(pol, users, nodes)
if err != nil {
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
continue
}
if ips != nil {
dest.AddSet(ips)
}

View file

@ -15,7 +15,6 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
// aliasWithPorts creates an AliasWithPorts structure from an alias and ports.
@ -410,14 +409,14 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
nodeUser1 := types.Node{
Hostname: "user1-device",
IPv4: createAddr("100.64.0.1"),
UserID: ptr.To(users[0].ID),
User: ptr.To(users[0]),
UserID: new(users[0].ID),
User: new(users[0]),
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: createAddr("100.64.0.2"),
UserID: ptr.To(users[1].ID),
User: ptr.To(users[1]),
UserID: new(users[1].ID),
User: new(users[1]),
}
nodes := types.Nodes{&nodeUser1, &nodeUser2}
@ -590,7 +589,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
if sshPolicy == nil {
return // Expected empty result
}
assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match")
return
}
@ -622,14 +623,14 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
nodeUser1 := types.Node{
Hostname: "user1-device",
IPv4: createAddr("100.64.0.1"),
UserID: ptr.To(users[0].ID),
User: ptr.To(users[0]),
UserID: new(users[0].ID),
User: new(users[0]),
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: createAddr("100.64.0.2"),
UserID: ptr.To(users[1].ID),
User: ptr.To(users[1]),
UserID: new(users[1].ID),
User: new(users[1]),
}
nodes := types.Nodes{&nodeUser1, &nodeUser2}
@ -671,7 +672,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
}
// TestSSHIntegrationReproduction reproduces the exact scenario from the integration test
// TestSSHOneUserToAll that was failing with empty sshUsers
// TestSSHOneUserToAll that was failing with empty sshUsers.
func TestSSHIntegrationReproduction(t *testing.T) {
// Create users matching the integration test
users := types.Users{
@ -683,15 +684,15 @@ func TestSSHIntegrationReproduction(t *testing.T) {
node1 := &types.Node{
Hostname: "user1-node",
IPv4: createAddr("100.64.0.1"),
UserID: ptr.To(users[0].ID),
User: ptr.To(users[0]),
UserID: new(users[0].ID),
User: new(users[0]),
}
node2 := &types.Node{
Hostname: "user2-node",
IPv4: createAddr("100.64.0.2"),
UserID: ptr.To(users[1].ID),
User: ptr.To(users[1]),
UserID: new(users[1].ID),
User: new(users[1]),
}
nodes := types.Nodes{node1, node2}
@ -736,7 +737,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
}
// TestSSHJSONSerialization verifies that the SSH policy can be properly serialized
// to JSON and that the sshUsers field is not empty
// to JSON and that the sshUsers field is not empty.
func TestSSHJSONSerialization(t *testing.T) {
users := types.Users{
{Name: "user1", Model: gorm.Model{ID: 1}},
@ -776,6 +777,7 @@ func TestSSHJSONSerialization(t *testing.T) {
// Parse back to verify structure
var parsed tailcfg.SSHPolicy
err = json.Unmarshal(jsonData, &parsed)
require.NoError(t, err)
@ -806,19 +808,19 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
nodes := types.Nodes{
{
User: ptr.To(users[0]),
User: new(users[0]),
IPv4: ap("100.64.0.1"),
},
{
User: ptr.To(users[0]),
User: new(users[0]),
IPv4: ap("100.64.0.2"),
},
{
User: ptr.To(users[1]),
User: new(users[1]),
IPv4: ap("100.64.0.3"),
},
{
User: ptr.To(users[1]),
User: new(users[1]),
IPv4: ap("100.64.0.4"),
},
// Tagged device for user1
@ -860,6 +862,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if len(rules) != 1 {
t.Fatalf("expected 1 rule, got %d", len(rules))
}
@ -876,6 +879,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
found := false
addr := netip.MustParseAddr(expectedIP)
for _, prefix := range rule.SrcIPs {
pref := netip.MustParsePrefix(prefix)
if pref.Contains(addr) {
@ -893,6 +897,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"}
for _, excludedIP := range excludedSourceIPs {
addr := netip.MustParseAddr(excludedIP)
for _, prefix := range rule.SrcIPs {
pref := netip.MustParsePrefix(prefix)
if pref.Contains(addr) {
@ -938,11 +943,11 @@ func TestTagUserMutualExclusivity(t *testing.T) {
nodes := types.Nodes{
// User-owned nodes
{
User: ptr.To(users[0]),
User: new(users[0]),
IPv4: ap("100.64.0.1"),
},
{
User: ptr.To(users[1]),
User: new(users[1]),
IPv4: ap("100.64.0.2"),
},
// Tagged nodes
@ -960,8 +965,8 @@ func TestTagUserMutualExclusivity(t *testing.T) {
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
Tag("tag:server"): Owners{new(Username("user1@"))},
Tag("tag:database"): Owners{new(Username("user2@"))},
},
ACLs: []ACL{
// Rule 1: user1 (user-owned) should NOT be able to reach tagged nodes
@ -1056,11 +1061,11 @@ func TestAutogroupTagged(t *testing.T) {
nodes := types.Nodes{
// User-owned nodes (not tagged)
{
User: ptr.To(users[0]),
User: new(users[0]),
IPv4: ap("100.64.0.1"),
},
{
User: ptr.To(users[1]),
User: new(users[1]),
IPv4: ap("100.64.0.2"),
},
// Tagged nodes
@ -1083,10 +1088,10 @@ func TestAutogroupTagged(t *testing.T) {
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
Tag("tag:web"): Owners{ptr.To(Username("user1@"))},
Tag("tag:prod"): Owners{ptr.To(Username("user1@"))},
Tag("tag:server"): Owners{new(Username("user1@"))},
Tag("tag:database"): Owners{new(Username("user2@"))},
Tag("tag:web"): Owners{new(Username("user1@"))},
Tag("tag:prod"): Owners{new(Username("user1@"))},
},
ACLs: []ACL{
// Rule: autogroup:tagged can reach user-owned nodes
@ -1206,10 +1211,10 @@ func TestAutogroupSelfWithSpecificUserSource(t *testing.T) {
}
nodes := types.Nodes{
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
{User: new(users[0]), IPv4: ap("100.64.0.1")},
{User: new(users[0]), IPv4: ap("100.64.0.2")},
{User: new(users[1]), IPv4: ap("100.64.0.3")},
{User: new(users[1]), IPv4: ap("100.64.0.4")},
}
policy := &Policy{
@ -1273,11 +1278,11 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
}
nodes := types.Nodes{
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
{User: new(users[0]), IPv4: ap("100.64.0.1")},
{User: new(users[0]), IPv4: ap("100.64.0.2")},
{User: new(users[1]), IPv4: ap("100.64.0.3")},
{User: new(users[1]), IPv4: ap("100.64.0.4")},
{User: new(users[2]), IPv4: ap("100.64.0.5")},
}
policy := &Policy{
@ -1326,14 +1331,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
assert.Empty(t, rules3, "user3 should have no rules")
}
// Helper function to create IP addresses for testing
// Helper function to create IP addresses for testing.
func createAddr(ip string) *netip.Addr {
addr, _ := netip.ParseAddr(ip)
return &addr
}
// TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly
// with autogroup:self in destinations
// with autogroup:self in destinations.
func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
@ -1342,13 +1347,13 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
nodes := types.Nodes{
// User1's nodes
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
{User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
{User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
// User2's nodes
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
{User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
{User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
// Tagged node for user1 (should be excluded)
{User: ptr.To(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}},
{User: new(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}},
}
policy := &Policy{
@ -1381,6 +1386,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
for i, p := range rule.Principals {
principalIPs[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
// Test for user2's first node
@ -1399,12 +1405,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
for i, p := range rule2.Principals {
principalIPs2[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2)
// Test for tagged node (should have no SSH rules)
node5 := nodes[4].View()
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy3 != nil {
assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self")
}
@ -1412,7 +1420,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
// TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user
// is in the source and autogroup:self in destination, only that user's devices
// can SSH (and only if they match the target user)
// can SSH (and only if they match the target user).
func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
@ -1420,10 +1428,10 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
}
nodes := types.Nodes{
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
{User: new(users[0]), IPv4: ap("100.64.0.1")},
{User: new(users[0]), IPv4: ap("100.64.0.2")},
{User: new(users[1]), IPv4: ap("100.64.0.3")},
{User: new(users[1]), IPv4: ap("100.64.0.4")},
}
policy := &Policy{
@ -1454,18 +1462,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
for i, p := range rule.Principals {
principalIPs[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
// For user2's node: should have no rules (user1's devices can't match user2's self)
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1")
}
}
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations.
func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
@ -1474,11 +1484,11 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
}
nodes := types.Nodes{
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
{User: new(users[0]), IPv4: ap("100.64.0.1")},
{User: new(users[0]), IPv4: ap("100.64.0.2")},
{User: new(users[1]), IPv4: ap("100.64.0.3")},
{User: new(users[1]), IPv4: ap("100.64.0.4")},
{User: new(users[2]), IPv4: ap("100.64.0.5")},
}
policy := &Policy{
@ -1512,29 +1522,31 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
for i, p := range rule.Principals {
principalIPs[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
// For user3's node: should have no rules (not in group:admins)
node5 := nodes[4].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)")
}
}
// TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices
// are excluded from both sources and destinations when autogroup:self is used
// are excluded from both sources and destinations when autogroup:self is used.
func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
}
nodes := types.Nodes{
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}},
{User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
{User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
{User: new(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}},
{User: new(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}},
}
policy := &Policy{
@ -1569,6 +1581,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
for i, p := range rule.Principals {
principalIPs[i] = p.NodeIP
}
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs,
"should only include untagged devices")
@ -1576,6 +1589,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
node3 := nodes[2].View()
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
require.NoError(t, err)
if sshPolicy2 != nil {
assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self")
}
@ -1591,10 +1605,10 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
}
nodes := types.Nodes{
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}},
{User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
{User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
{User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
{User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}},
}
policy := &Policy{
@ -1624,10 +1638,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
// Verify autogroup:self rule has filtered sources (only same-user devices)
selfRule := sshPolicy1.Rules[0]
require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices")
selfPrincipals := make([]string, len(selfRule.Principals))
for i, p := range selfRule.Principals {
selfPrincipals[i] = p.NodeIP
}
require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals,
"autogroup:self rule should only include same-user untagged devices")
@ -1639,10 +1655,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")
routerRule := sshPolicyRouter.Rules[0]
routerPrincipals := make([]string, len(routerRule.Principals))
for i, p := range routerRule.Principals {
routerPrincipals[i] = p.NodeIP
}
require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)")
require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)")
require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)")

View file

@ -21,6 +21,9 @@ import (
"tailscale.com/util/deephash"
)
// PolicyVersion is the version number of this policy implementation.
const PolicyVersion = 2
// ErrInvalidTagOwner is returned when a tag owner is not an Alias type.
var ErrInvalidTagOwner = errors.New("tag owner is not an Alias")
@ -111,6 +114,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Filter: filter,
Policy: pm.pol,
})
filterChanged := filterHash != pm.filterHash
if filterChanged {
log.Debug().
@ -120,7 +124,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Int("filter.rules.new", len(filter)).
Msg("Policy filter hash changed")
}
pm.filter = filter
pm.filterHash = filterHash
if filterChanged {
pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
@ -135,6 +141,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
}
tagOwnerMapHash := deephash.Hash(&tagMap)
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
if tagOwnerChanged {
log.Debug().
@ -144,6 +151,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Int("tagOwners.new", len(tagMap)).
Msg("Tag owner hash changed")
}
pm.tagOwnerMap = tagMap
pm.tagOwnerMapHash = tagOwnerMapHash
@ -153,6 +161,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
}
autoApproveMapHash := deephash.Hash(&autoMap)
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
if autoApproveChanged {
log.Debug().
@ -162,10 +171,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Int("autoApprovers.new", len(autoMap)).
Msg("Auto-approvers hash changed")
}
pm.autoApproveMap = autoMap
pm.autoApproveMapHash = autoApproveMapHash
exitSetHash := deephash.Hash(&exitSet)
exitSetChanged := exitSetHash != pm.exitSetHash
if exitSetChanged {
log.Debug().
@ -173,6 +184,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
Str("exitSet.hash.new", exitSetHash.String()[:8]).
Msg("Exit node set hash changed")
}
pm.exitSet = exitSet
pm.exitSetHash = exitSetHash
@ -199,6 +211,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
if !needsUpdate {
log.Trace().
Msg("Policy evaluation detected no changes - all hashes match")
return false, nil
}
@ -224,6 +237,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err
if err != nil {
return nil, fmt.Errorf("compiling SSH policy: %w", err)
}
pm.sshPolicyMap[node.ID()] = sshPol
return sshPol, nil
@ -318,6 +332,7 @@ func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[typ
if err != nil || len(filter) == 0 {
continue
}
nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter)
}
@ -398,6 +413,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil
reducedFilter := policyutil.ReduceFilterRules(node, pm.filter)
pm.filterRulesMap[node.ID()] = reducedFilter
return reducedFilter, nil
}
@ -442,7 +458,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul
// This is different from FilterForNode which returns REDUCED rules for packet filtering.
//
// For global policies: returns the global matchers (same for all nodes)
// For autogroup:self: returns node-specific matchers from unreduced compiled rules
// For autogroup:self: returns node-specific matchers from unreduced compiled rules.
func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
if pm == nil {
return nil, nil
@ -474,6 +490,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
pm.mu.Lock()
defer pm.mu.Unlock()
pm.users = users
// Clear SSH policy map when users change to force SSH policy recomputation
@ -685,6 +702,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
if pm.exitSet == nil {
return false
}
if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) {
return true
}
@ -724,7 +742,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
}
func (pm *PolicyManager) Version() int {
return 2
return PolicyVersion
}
func (pm *PolicyManager) DebugString() string {
@ -748,8 +766,10 @@ func (pm *PolicyManager) DebugString() string {
}
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
for prefix, approveAddrs := range pm.autoApproveMap {
fmt.Fprintf(&sb, "\t%s:\n", prefix)
for _, iprange := range approveAddrs.Ranges() {
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
}
@ -758,14 +778,17 @@ func (pm *PolicyManager) DebugString() string {
sb.WriteString("\n\n")
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
for prefix, tagOwners := range pm.tagOwnerMap {
fmt.Fprintf(&sb, "\t%s:\n", prefix)
for _, iprange := range tagOwners.Ranges() {
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
}
}
sb.WriteString("\n\n")
if pm.filter != nil {
filter, err := json.MarshalIndent(pm.filter, "", " ")
if err == nil {
@ -778,6 +801,7 @@ func (pm *PolicyManager) DebugString() string {
sb.WriteString("\n\n")
sb.WriteString("Matchers:\n")
sb.WriteString("an internal structure used to filter nodes and routes\n")
for _, match := range pm.matchers {
sb.WriteString(match.DebugString())
sb.WriteString("\n")
@ -785,6 +809,7 @@ func (pm *PolicyManager) DebugString() string {
sb.WriteString("\n\n")
sb.WriteString("Nodes:\n")
for _, node := range pm.nodes.All() {
sb.WriteString(node.String())
sb.WriteString("\n")
@ -841,6 +866,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
// Check if IPs changed (simple check - could be more sophisticated)
oldIPs := oldNode.IPs()
newIPs := newNode.IPs()
if len(oldIPs) != len(newIPs) {
affectedUsers[newNode.User().ID()] = struct{}{}
@ -862,6 +888,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
for nodeID := range pm.filterRulesMap {
// Find the user for this cached node
var nodeUserID uint
found := false
// Check in new nodes first
@ -869,6 +896,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
if node.ID() == nodeID {
nodeUserID = node.User().ID()
found = true
break
}
}
@ -879,6 +907,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
if node.ID() == nodeID {
nodeUserID = node.User().ID()
found = true
break
}
}
@ -956,14 +985,7 @@ func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types.
// It will return a Owners list where all the Tag types have been resolved to their underlying Owners.
func flattenTags(tagOwners TagOwners, tag Tag, visiting map[Tag]bool, chain []Tag) (Owners, error) {
if visiting[tag] {
cycleStart := 0
for i, t := range chain {
if t == tag {
cycleStart = i
break
}
}
cycleStart := slices.Index(chain, tag)
cycleTags := make([]string, len(chain[cycleStart:]))
for i, t := range chain[cycleStart:] {

View file

@ -11,18 +11,16 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
func node(name, ipv4, ipv6 string, user types.User, _ *tailcfg.Hostinfo) *types.Node {
return &types.Node{
ID: 0,
Hostname: name,
IPv4: ap(ipv4),
IPv6: ap(ipv6),
User: ptr.To(user),
UserID: ptr.To(user.ID),
Hostinfo: hostinfo,
User: new(user),
UserID: new(user.ID),
}
}
@ -57,6 +55,7 @@ func TestPolicyManager(t *testing.T) {
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(
tt.wantMatchers,
matchers,
@ -77,6 +76,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
{Model: gorm.Model{ID: 3}, Name: "user3", Email: "user3@headscale.net"},
}
//nolint:goconst
policy := `{
"acls": [
{
@ -95,7 +95,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
}
for i, n := range initialNodes {
n.ID = types.NodeID(i + 1)
n.ID = types.NodeID(i + 1) //nolint:gosec
}
pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice())
@ -107,7 +107,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
require.NoError(t, err)
}
require.Equal(t, len(initialNodes), len(pm.filterRulesMap))
require.Len(t, pm.filterRulesMap, len(initialNodes))
tests := []struct {
name string
@ -177,15 +177,18 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
for i, n := range tt.newNodes {
found := false
for _, origNode := range initialNodes {
if n.Hostname == origNode.Hostname {
n.ID = origNode.ID
found = true
break
}
}
if !found {
n.ID = types.NodeID(len(initialNodes) + i + 1)
n.ID = types.NodeID(len(initialNodes) + i + 1) //nolint:gosec
}
}
@ -370,7 +373,7 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) {
// TestAutogroupSelfReducedVsUnreducedRules verifies that:
// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships
// 2. FilterForNode returns reduced compiled rules for packet filters
// 2. FilterForNode returns reduced compiled rules for packet filters.
func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}
@ -410,6 +413,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
// FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations
// For node1, destinations should only be node1's IPs
node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"}
for _, rule := range filterNode1 {
for _, dst := range rule.DstPorts {
require.Contains(t, node1IPs, dst.IP,
@ -419,6 +423,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
// For node2, destinations should only be node2's IPs
node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"}
for _, rule := range filterNode2 {
for _, dst := range rule.DstPorts {
require.Contains(t, node2IPs, dst.IP,
@ -457,8 +462,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
Hostname: "test-1-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@ -468,8 +473,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
Hostname: "test-2-router",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
Tags: []string{"tag:node-router"},
Hostinfo: &tailcfg.Hostinfo{},
}
@ -537,8 +542,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
Hostname: "test-1-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@ -547,8 +552,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
Hostname: "test-2-device",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@ -647,8 +652,8 @@ func TestTagPropagationToPeerMap(t *testing.T) {
Hostname: "user1-node",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Tags: []string{"tag:web", "tag:internal"},
}
@ -658,8 +663,8 @@ func TestTagPropagationToPeerMap(t *testing.T) {
Hostname: "user2-node",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
}
initialNodes := types.Nodes{user1Node, user2Node}
@ -686,8 +691,8 @@ func TestTagPropagationToPeerMap(t *testing.T) {
Hostname: "user1-node",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Tags: []string{"tag:internal"}, // tag:web removed!
}
@ -749,8 +754,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) {
Hostname: "admin-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@ -760,8 +765,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) {
Hostname: "user1-server",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
Tags: []string{"tag:server"},
Hostinfo: &tailcfg.Hostinfo{},
}
@ -832,8 +837,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) {
Hostname: "device-a",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
User: new(users[0]),
UserID: new(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@ -843,8 +848,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) {
Hostname: "device-b",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
User: new(users[1]),
UserID: new(users[1].ID),
Tags: []string{"tag:web"},
Hostinfo: &tailcfg.Hostinfo{},
}

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -9,6 +9,21 @@ import (
"tailscale.com/tailcfg"
)
// portRangeParts is the expected number of parts in a port range (start-end).
const portRangeParts = 2
// Sentinel errors for port and destination parsing.
var (
ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port")
ErrInputStartsWithColon = errors.New("input cannot start with a colon character")
ErrInputEndsWithColon = errors.New("input cannot end with a colon character")
ErrInvalidPortRange = errors.New("invalid port range format")
ErrPortRangeInverted = errors.New("invalid port range: first port is greater than last port")
ErrPortMustBePositive = errors.New("first port must be >0, or use '*' for wildcard")
ErrInvalidPortNumber = errors.New("invalid port number")
ErrPortOutOfRange = errors.New("port number out of range")
)
// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
func splitDestinationAndPort(input string) (string, string, error) {
// Find the last occurrence of the colon character
@ -16,13 +31,15 @@ func splitDestinationAndPort(input string) (string, string, error) {
// Check if the colon character is present and not at the beginning or end of the string
if lastColonIndex == -1 {
return "", "", errors.New("input must contain a colon character separating destination and port")
return "", "", ErrInputMissingColon
}
if lastColonIndex == 0 {
return "", "", errors.New("input cannot start with a colon character")
return "", "", ErrInputStartsWithColon
}
if lastColonIndex == len(input)-1 {
return "", "", errors.New("input cannot end with a colon character")
return "", "", ErrInputEndsWithColon
}
// Split the string into destination and port based on the last colon
@ -45,11 +62,12 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
for part := range parts {
if strings.Contains(part, "-") {
rangeParts := strings.Split(part, "-")
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
return e == ""
})
if len(rangeParts) != 2 {
return nil, errors.New("invalid port range format")
if len(rangeParts) != portRangeParts {
return nil, ErrInvalidPortRange
}
first, err := parsePort(rangeParts[0])
@ -63,7 +81,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
}
if first > last {
return nil, errors.New("invalid port range: first port is greater than last port")
return nil, ErrPortRangeInverted
}
portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
@ -74,7 +92,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
}
if port < 1 {
return nil, errors.New("first port must be >0, or use '*' for wildcard")
return nil, ErrPortMustBePositive
}
portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
@ -88,11 +106,11 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
func parsePort(portStr string) (uint16, error) {
port, err := strconv.Atoi(portStr)
if err != nil {
return 0, errors.New("invalid port number")
return 0, ErrInvalidPortNumber
}
if port < 0 || port > 65535 {
return 0, errors.New("port number out of range")
return 0, ErrPortOutOfRange
}
return uint16(port), nil

View file

@ -24,14 +24,14 @@ func TestParseDestinationAndPort(t *testing.T) {
{"tag:api-server:443", "tag:api-server", "443", nil},
{"example-host-1:*", "example-host-1", "*", nil},
{"hostname:80-90", "hostname", "80-90", nil},
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
{":invalid", "", "", errors.New("input cannot start with a colon character")},
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
{"invalidinput", "", "", ErrInputMissingColon},
{":invalid", "", "", ErrInputStartsWithColon},
{"invalid:", "", "", ErrInputEndsWithColon},
}
for _, testCase := range testCases {
dst, port, err := splitDestinationAndPort(testCase.input)
if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) {
if dst != testCase.expectedDst || port != testCase.expectedPort || !errors.Is(err, testCase.expectedErr) {
t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)",
testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr)
}
@ -42,25 +42,23 @@ func TestParsePort(t *testing.T) {
tests := []struct {
input string
expected uint16
err string
err error
}{
{"80", 80, ""},
{"0", 0, ""},
{"65535", 65535, ""},
{"-1", 0, "port number out of range"},
{"65536", 0, "port number out of range"},
{"abc", 0, "invalid port number"},
{"", 0, "invalid port number"},
{"80", 80, nil},
{"0", 0, nil},
{"65535", 65535, nil},
{"-1", 0, ErrPortOutOfRange},
{"65536", 0, ErrPortOutOfRange},
{"abc", 0, ErrInvalidPortNumber},
{"", 0, ErrInvalidPortNumber},
}
for _, test := range tests {
result, err := parsePort(test.input)
if err != nil && err.Error() != test.err {
if !errors.Is(err, test.err) {
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
}
if err == nil && test.err != "" {
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
}
if result != test.expected {
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
}
@ -71,30 +69,28 @@ func TestParsePortRange(t *testing.T) {
tests := []struct {
input string
expected []tailcfg.PortRange
err string
err error
}{
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
{"80-", nil, "invalid port range format"},
{"-90", nil, "invalid port range format"},
{"80-90,", nil, "invalid port number"},
{"80,90-", nil, "invalid port range format"},
{"80-90,abc", nil, "invalid port number"},
{"80-90,65536", nil, "port number out of range"},
{"80-90,90-80", nil, "invalid port range: first port is greater than last port"},
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, nil},
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, nil},
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, nil},
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, nil},
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, nil},
{"80-", nil, ErrInvalidPortRange},
{"-90", nil, ErrInvalidPortRange},
{"80-90,", nil, ErrInvalidPortNumber},
{"80,90-", nil, ErrInvalidPortRange},
{"80-90,abc", nil, ErrInvalidPortNumber},
{"80-90,65536", nil, ErrPortOutOfRange},
{"80-90,90-80", nil, ErrPortRangeInverted},
}
for _, test := range tests {
result, err := parsePortRange(test.input)
if err != nil && err.Error() != test.err {
if !errors.Is(err, test.err) {
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
}
if err == nil && test.err != "" {
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
}
if diff := cmp.Diff(result, test.expected); diff != "" {
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
}

View file

@ -152,6 +152,7 @@ func (m *mapSession) serveLongPoll() {
// This is not my favourite solution, but it kind of works in our eventually consistent world.
ticker := time.NewTicker(time.Second)
defer ticker.Stop()
disconnected := true
// Wait up to 10 seconds for the node to reconnect.
// 10 seconds was arbitrary chosen as a reasonable time to reconnect.
@ -160,6 +161,7 @@ func (m *mapSession) serveLongPoll() {
disconnected = false
break
}
<-ticker.C
}
@ -212,11 +214,14 @@ func (m *mapSession) serveLongPoll() {
// adding this before connecting it to the state ensure that
// it does not miss any updates that might be sent in the split
// time between the node connecting and the batcher being ready.
//nolint:noinlineerr
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
m.errf(err, "failed to add node to batcher")
log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session")
return
}
log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher")
m.h.Change(mapReqChange)
@ -245,7 +250,8 @@ func (m *mapSession) serveLongPoll() {
return
}
if err := m.writeMap(update); err != nil {
err := m.writeMap(update)
if err != nil {
m.errf(err, "cannot write update to client")
return
}
@ -254,7 +260,8 @@ func (m *mapSession) serveLongPoll() {
m.resetKeepAlive()
case <-m.keepAliveTicker.C:
if err := m.writeMap(&keepAlive); err != nil {
err := m.writeMap(&keepAlive)
if err != nil {
m.errf(err, "cannot write keep alive")
return
}
@ -282,8 +289,9 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
}
//nolint:prealloc
data := make([]byte, reservedResponseHeaderSize)
//nolint:gosec // G115: JSON response size will not exceed uint32 max
//nolint:gosec
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
data = append(data, jsonBody...)
@ -328,13 +336,13 @@ func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event {
Str("node.name", m.node.Hostname)
}
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
//nolint:zerologlint
func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) }
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
//nolint:zerologlint
func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) }
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
//nolint:zerologlint
func (m *mapSession) errf(err error, msg string, a ...any) {
m.logf(log.Error().Caller()).Err(err).Msgf(msg, a...)
}

View file

@ -4,7 +4,6 @@ import (
"fmt"
"net/netip"
"slices"
"sort"
"strings"
"sync"
@ -57,7 +56,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
// this is important so the same node is chosen two times in a row
// as the primary route.
ids := types.NodeIDs(xmaps.Keys(pr.routes))
sort.Sort(ids)
slices.Sort(ids)
// Create a map of prefixes to nodes that serve them so we
// can determine the primary route for each prefix.
@ -108,9 +107,11 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
Msg("Current primary no longer available")
}
}
if len(nodes) >= 1 {
pr.primaries[prefix] = nodes[0]
changed = true
log.Debug().
Caller().
Str("prefix", prefix.String()).
@ -127,6 +128,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
Str("prefix", prefix.String()).
Msg("Cleaning up primary route that no longer has available nodes")
delete(pr.primaries, prefix)
changed = true
}
}
@ -162,14 +164,18 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix)
// If no routes are being set, remove the node from the routes map.
if len(prefixes) == 0 {
wasPresent := false
if _, ok := pr.routes[node]; ok {
delete(pr.routes, node)
wasPresent = true
log.Debug().
Caller().
Uint64("node.id", node.Uint64()).
Msg("Removed node from primary routes (no prefixes)")
}
changed := pr.updatePrimaryLocked()
log.Debug().
Caller().
@ -236,7 +242,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
}
}
tsaddr.SortPrefixes(routes)
slices.SortFunc(routes, netip.Prefix.Compare)
return routes
}
@ -254,13 +260,15 @@ func (pr *PrimaryRoutes) stringLocked() string {
fmt.Fprintln(&sb, "Available routes:")
ids := types.NodeIDs(xmaps.Keys(pr.routes))
sort.Sort(ids)
slices.Sort(ids)
for _, id := range ids {
prefixes := pr.routes[id]
fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", "))
}
fmt.Fprintln(&sb, "\n\nCurrent primary routes:")
for route, nodeID := range pr.primaries {
fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID)
}
@ -294,7 +302,7 @@ func (pr *PrimaryRoutes) DebugJSON() DebugRoutes {
// Populate available routes
for nodeID, routes := range pr.routes {
prefixes := routes.Slice()
tsaddr.SortPrefixes(prefixes)
slices.SortFunc(prefixes, netip.Prefix.Compare)
debug.AvailableRoutes[nodeID] = prefixes
}

View file

@ -130,6 +130,7 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24"))
pr.SetRoutes(2, mp("192.168.2.0/24"))
pr.SetRoutes(1) // Deregister by setting no routes
return pr.SetRoutes(1, mp("192.168.3.0/24"))
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
@ -153,8 +154,9 @@ func TestPrimaryRoutes(t *testing.T) {
{
name: "multiple-nodes-register-same-route",
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
return pr.SetRoutes(3, mp("192.168.1.0/24")) // false
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
@ -182,7 +184,8 @@ func TestPrimaryRoutes(t *testing.T) {
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
return pr.SetRoutes(1) // true, 2 primary
return pr.SetRoutes(1) // true, 2 primary
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
2: {
@ -393,6 +396,7 @@ func TestPrimaryRoutes(t *testing.T) {
operations: func(pr *PrimaryRoutes) bool {
pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0"))
pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0"))
return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0"))
},
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
@ -413,15 +417,20 @@ func TestPrimaryRoutes(t *testing.T) {
operations: func(pr *PrimaryRoutes) bool {
var wg sync.WaitGroup
wg.Add(2)
var change1, change2 bool
go func() {
defer wg.Done()
change1 = pr.SetRoutes(1, mp("192.168.1.0/24"))
}()
go func() {
defer wg.Done()
change2 = pr.SetRoutes(2, mp("192.168.2.0/24"))
}()
wg.Wait()
return change1 || change2
@ -449,17 +458,21 @@ func TestPrimaryRoutes(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pr := New()
change := tt.operations(pr)
if change != tt.expectedChange {
t.Errorf("change = %v, want %v", change, tt.expectedChange)
}
comps := append(util.Comparers, cmpopts.EquateEmpty())
if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" {
t.Errorf("routes mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" {
t.Errorf("primaries mismatch (-want +got):\n%s", diff)
}
if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" {
t.Errorf("isPrimary mismatch (-want +got):\n%s", diff)
}

View file

@ -77,6 +77,7 @@ func (s *State) DebugOverview() string {
ephemeralCount := 0
now := time.Now()
for _, node := range allNodes.All() {
if node.Valid() {
userName := node.Owner().Name()
@ -103,17 +104,21 @@ func (s *State) DebugOverview() string {
// User statistics
sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users)))
for userName, nodeCount := range userNodeCounts {
sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount))
}
sb.WriteString("\n")
// Policy information
sb.WriteString("Policy:\n")
sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode))
if s.cfg.Policy.Mode == types.PolicyModeFile {
sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path))
}
sb.WriteString("\n")
// DERP information
@ -123,6 +128,7 @@ func (s *State) DebugOverview() string {
} else {
sb.WriteString("DERP: not configured\n")
}
sb.WriteString("\n")
// Route information
@ -130,6 +136,7 @@ func (s *State) DebugOverview() string {
if s.primaryRoutes.String() == "" {
routeCount = 0
}
sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount))
sb.WriteString("\n")
@ -165,10 +172,12 @@ func (s *State) DebugDERPMap() string {
for _, node := range region.Nodes {
sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n",
node.Name, node.HostName, node.DERPPort))
if node.STUNPort != 0 {
sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort))
}
}
sb.WriteString("\n")
}
@ -236,7 +245,7 @@ func (s *State) DebugPolicy() (string, error) {
return string(pol), nil
default:
return "", fmt.Errorf("unsupported policy mode: %s", s.cfg.Policy.Mode)
return "", fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, s.cfg.Policy.Mode)
}
}
@ -319,6 +328,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo {
if s.primaryRoutes.String() == "" {
routeCount = 0
}
info.PrimaryRoutes = routeCount
return info

View file

@ -8,7 +8,6 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/ptr"
)
// TestEphemeralNodeDeleteWithConcurrentUpdate tests the race condition where UpdateNode and DeleteNode
@ -21,6 +20,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
// Create NodeStore
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -44,20 +44,26 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
// 6. If DELETE came after UPDATE, the returned node should be invalid
done := make(chan bool, 2)
var updatedNode types.NodeView
var updateOk bool
var (
updatedNode types.NodeView
updateOk bool
)
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
go func() {
updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
n.LastSeen = new(time.Now())
})
done <- true
}()
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
go func() {
store.DeleteNode(node.ID)
done <- true
}()
@ -91,6 +97,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) {
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -106,7 +113,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) {
// Start UpdateNode in goroutine - it will queue and wait for batch
go func() {
node, ok := store.UpdateNode(node.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
n.LastSeen = new(time.Now())
})
resultChan <- struct {
node types.NodeView
@ -148,6 +155,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) {
node := createTestNode(3, 1, "test-user", "test-node-3")
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -156,7 +164,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) {
// Simulate UpdateNode being called
updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
n.LastSeen = new(time.Now())
})
require.True(t, ok, "UpdateNode should succeed")
require.True(t, updatedNode.Valid(), "UpdateNode should return valid node")
@ -204,6 +212,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -214,21 +223,26 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
// 1. UpdateNode (from UpdateNodeFromMapRequest during polling)
// 2. DeleteNode (from handleLogout when client sends logout request)
var updatedNode types.NodeView
var updateOk bool
var (
updatedNode types.NodeView
updateOk bool
)
done := make(chan bool, 2)
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
go func() {
updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
n.LastSeen = new(time.Now())
})
done <- true
}()
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
go func() {
store.DeleteNode(ephemeralNode.ID)
done <- true
}()
@ -267,7 +281,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
// 5. UpdateNode and DeleteNode batch together
// 6. UpdateNode returns a valid node (from before delete in batch)
// 7. persistNodeToDB is called with the stale valid node
// 8. Node gets re-inserted into database instead of staying deleted
// 8. Node gets re-inserted into database instead of staying deleted.
func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5")
ephemeralNode.AuthKey = &types.PreAuthKey{
@ -279,6 +293,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -294,7 +309,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
go func() {
node, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
n.LastSeen = new(time.Now())
endpoint := netip.MustParseAddrPort("10.0.0.1:41641")
n.Endpoints = []netip.AddrPort{endpoint}
})
@ -349,6 +364,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -363,7 +379,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
go func() {
updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
n.LastSeen = new(time.Now())
})
updateDone <- struct {
node types.NodeView
@ -399,7 +415,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
// 3. UpdateNode and DeleteNode batch together
// 4. UpdateNode returns a valid node (from before delete in batch)
// 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node
// 6. persistNodeToDB must detect the node is deleted and refuse to persist
// 6. persistNodeToDB must detect the node is deleted and refuse to persist.
func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7")
ephemeralNode.AuthKey = &types.PreAuthKey{
@ -409,6 +425,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
}
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -417,7 +434,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
// UpdateNode returns a node
updatedNode, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
n.LastSeen = ptr.To(time.Now())
n.LastSeen = new(time.Now())
})
require.True(t, ok, "UpdateNode should succeed")
require.True(t, updatedNode.Valid(), "updated node should be valid")

View file

@ -29,6 +29,7 @@ func netInfoFromMapRequest(
Uint64("node.id", nodeID.Uint64()).
Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP).
Msg("using NetInfo from previous Hostinfo in MapRequest")
return currentHostinfo.NetInfo
}

View file

@ -1,15 +1,12 @@
package state
import (
"net/netip"
"testing"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestNetInfoFromMapRequest(t *testing.T) {
@ -136,26 +133,3 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) {
assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node")
})
}
// Simple helper function for tests
func createTestNodeSimple(id types.NodeID) *types.Node {
user := types.User{
Name: "test-user",
}
machineKey := key.NewMachine()
nodeKey := key.NewNode()
node := &types.Node{
ID: id,
Hostname: "test-node",
UserID: ptr.To(uint(id)),
User: &user,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
IPv4: &netip.Addr{},
IPv6: &netip.Addr{},
}
return node
}

View file

@ -55,8 +55,8 @@ var (
})
nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "nodestore_nodes_total",
Help: "Total number of nodes in the NodeStore",
Name: "nodestore_nodes",
Help: "Number of nodes in the NodeStore",
})
nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
@ -97,6 +97,7 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batc
for _, n := range allNodes {
nodes[n.ID] = *n
}
snap := snapshotFromNodes(nodes, peersFunc)
store := &NodeStore{
@ -165,11 +166,14 @@ func (s *NodeStore) PutNode(n types.Node) types.NodeView {
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
resultNode := <-work.nodeResult
nodeStoreOperations.WithLabelValues("put").Inc()
return resultNode
@ -205,11 +209,14 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
resultNode := <-work.nodeResult
nodeStoreOperations.WithLabelValues("update").Inc()
// Return the node and whether it exists (is valid)
@ -229,7 +236,9 @@ func (s *NodeStore) DeleteNode(id types.NodeID) {
}
nodeStoreQueueDepth.Inc()
s.writeQueue <- work
<-work.result
nodeStoreQueueDepth.Dec()
@ -262,8 +271,10 @@ func (s *NodeStore) processWrite() {
if len(batch) != 0 {
s.applyBatch(batch)
}
return
}
batch = append(batch, w)
if len(batch) >= s.batchSize {
s.applyBatch(batch)
@ -321,6 +332,7 @@ func (s *NodeStore) applyBatch(batch []work) {
w.updateFn(&n)
nodes[w.nodeID] = n
}
if w.nodeResult != nil {
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
}
@ -349,12 +361,14 @@ func (s *NodeStore) applyBatch(batch []work) {
nodeView := node.View()
for _, w := range workItems {
w.nodeResult <- nodeView
close(w.nodeResult)
}
} else {
// Node was deleted or doesn't exist
for _, w := range workItems {
w.nodeResult <- types.NodeView{} // Send invalid view
close(w.nodeResult)
}
}
@ -400,6 +414,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
peersByNode: func() map[types.NodeID][]types.NodeView {
peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
defer peersTimer.ObserveDuration()
return peersFunc(allNodes)
}(),
nodesByUser: make(map[types.UserID][]types.NodeView),
@ -417,6 +432,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
if newSnap.nodesByMachineKey[n.MachineKey] == nil {
newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView)
}
newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView
}
@ -511,10 +527,12 @@ func (s *NodeStore) DebugString() string {
// User distribution (shows internal UserID tracking, not display owner)
sb.WriteString("Nodes by Internal User ID:\n")
for userID, nodes := range snapshot.nodesByUser {
if len(nodes) > 0 {
userName := "unknown"
taggedCount := 0
if len(nodes) > 0 && nodes[0].Valid() {
userName = nodes[0].User().Name()
// Count tagged nodes (which have UserID set but are owned by "tagged-devices")
@ -532,23 +550,29 @@ func (s *NodeStore) DebugString() string {
}
}
}
sb.WriteString("\n")
// Peer relationships summary
sb.WriteString("Peer Relationships:\n")
totalPeers := 0
for nodeID, peers := range snapshot.peersByNode {
peerCount := len(peers)
totalPeers += peerCount
if node, exists := snapshot.nodesByID[nodeID]; exists {
sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n",
nodeID, node.Hostname, peerCount))
}
}
if len(snapshot.peersByNode) > 0 {
avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode))
sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers))
}
sb.WriteString("\n")
// Node key index
@ -591,6 +615,7 @@ func (s *NodeStore) RebuildPeerMaps() {
}
s.writeQueue <- w
<-result
}

View file

@ -2,6 +2,7 @@ package state
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime"
@ -13,7 +14,13 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
// Test sentinel errors for concurrent operations.
var (
errTestUpdateNodeFailed = errors.New("UpdateNode failed")
errTestGetNodeFailed = errors.New("GetNode failed")
errTestPutNodeFailed = errors.New("PutNode failed")
)
func TestSnapshotFromNodes(t *testing.T) {
@ -33,6 +40,7 @@ func TestSnapshotFromNodes(t *testing.T) {
return nodes, peersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
t.Helper()
assert.Empty(t, snapshot.nodesByID)
assert.Empty(t, snapshot.allNodes)
assert.Empty(t, snapshot.peersByNode)
@ -45,9 +53,11 @@ func TestSnapshotFromNodes(t *testing.T) {
nodes := map[types.NodeID]types.Node{
1: createTestNode(1, 1, "user1", "node1"),
}
return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
t.Helper()
assert.Len(t, snapshot.nodesByID, 1)
assert.Len(t, snapshot.allNodes, 1)
assert.Len(t, snapshot.peersByNode, 1)
@ -71,6 +81,7 @@ func TestSnapshotFromNodes(t *testing.T) {
return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
t.Helper()
assert.Len(t, snapshot.nodesByID, 2)
assert.Len(t, snapshot.allNodes, 2)
assert.Len(t, snapshot.peersByNode, 2)
@ -96,6 +107,7 @@ func TestSnapshotFromNodes(t *testing.T) {
return nodes, allowAllPeersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
t.Helper()
assert.Len(t, snapshot.nodesByID, 3)
assert.Len(t, snapshot.allNodes, 3)
assert.Len(t, snapshot.peersByNode, 3)
@ -125,6 +137,7 @@ func TestSnapshotFromNodes(t *testing.T) {
return nodes, peersFunc
},
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
t.Helper()
assert.Len(t, snapshot.nodesByID, 4)
assert.Len(t, snapshot.allNodes, 4)
assert.Len(t, snapshot.peersByNode, 4)
@ -174,7 +187,7 @@ func createTestNode(nodeID types.NodeID, userID uint, username, hostname string)
DiscoKey: discoKey.Public(),
Hostname: hostname,
GivenName: hostname,
UserID: ptr.To(userID),
UserID: new(userID),
User: &types.User{
Name: username,
DisplayName: username,
@ -193,11 +206,13 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
for _, node := range nodes {
var peers []types.NodeView
for _, n := range nodes {
if n.ID() != node.ID() {
peers = append(peers, n)
}
}
ret[node.ID()] = peers
}
@ -208,6 +223,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
for _, node := range nodes {
var peers []types.NodeView
nodeIsOdd := node.ID()%2 == 1
for _, n := range nodes {
@ -222,6 +238,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
peers = append(peers, n)
}
}
ret[node.ID()] = peers
}
@ -237,6 +254,7 @@ func TestNodeStoreOperations(t *testing.T) {
{
name: "create empty store and add single node",
setupFunc: func(t *testing.T) *NodeStore {
t.Helper()
return NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
},
steps: []testStep{
@ -455,10 +473,13 @@ func TestNodeStoreOperations(t *testing.T) {
// Add nodes in sequence
n1 := store.PutNode(createTestNode(1, 1, "user1", "node1"))
assert.True(t, n1.Valid())
n2 := store.PutNode(createTestNode(2, 2, "user2", "node2"))
assert.True(t, n2.Valid())
n3 := store.PutNode(createTestNode(3, 3, "user3", "node3"))
assert.True(t, n3.Valid())
n4 := store.PutNode(createTestNode(4, 4, "user4", "node4"))
assert.True(t, n4.Valid())
@ -526,16 +547,20 @@ func TestNodeStoreOperations(t *testing.T) {
done2 := make(chan struct{})
done3 := make(chan struct{})
var resultNode1, resultNode2 types.NodeView
var newNode3 types.NodeView
var ok1, ok2 bool
var (
resultNode1, resultNode2 types.NodeView
newNode3 types.NodeView
ok1, ok2 bool
)
// These should all be processed in the same batch
go func() {
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "batch-updated-node1"
n.GivenName = "batch-given-1"
})
close(done1)
}()
@ -544,12 +569,14 @@ func TestNodeStoreOperations(t *testing.T) {
n.Hostname = "batch-updated-node2"
n.GivenName = "batch-given-2"
})
close(done2)
}()
go func() {
node3 := createTestNode(3, 1, "user1", "node3")
newNode3 = store.PutNode(node3)
close(done3)
}()
@ -602,20 +629,23 @@ func TestNodeStoreOperations(t *testing.T) {
// This test verifies that when multiple updates to the same node
// are batched together, each returned node reflects ALL changes
// in the batch, not just the individual update's changes.
done1 := make(chan struct{})
done2 := make(chan struct{})
done3 := make(chan struct{})
var resultNode1, resultNode2, resultNode3 types.NodeView
var ok1, ok2, ok3 bool
var (
resultNode1, resultNode2, resultNode3 types.NodeView
ok1, ok2, ok3 bool
)
// These updates all modify node 1 and should be batched together
// The final state should have all three modifications applied
go func() {
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "multi-update-hostname"
})
close(done1)
}()
@ -623,6 +653,7 @@ func TestNodeStoreOperations(t *testing.T) {
resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) {
n.GivenName = "multi-update-givenname"
})
close(done2)
}()
@ -630,6 +661,7 @@ func TestNodeStoreOperations(t *testing.T) {
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.Tags = []string{"tag1", "tag2"}
})
close(done3)
}()
@ -723,14 +755,18 @@ func TestNodeStoreOperations(t *testing.T) {
done2 := make(chan struct{})
done3 := make(chan struct{})
var result1, result2, result3 types.NodeView
var ok1, ok2, ok3 bool
var (
result1, result2, result3 types.NodeView
ok1, ok2, ok3 bool
)
// Start concurrent updates
go func() {
result1, ok1 = store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "concurrent-db-hostname"
})
close(done1)
}()
@ -738,6 +774,7 @@ func TestNodeStoreOperations(t *testing.T) {
result2, ok2 = store.UpdateNode(1, func(n *types.Node) {
n.GivenName = "concurrent-db-given"
})
close(done2)
}()
@ -745,6 +782,7 @@ func TestNodeStoreOperations(t *testing.T) {
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.Tags = []string{"concurrent-tag"}
})
close(done3)
}()
@ -828,6 +866,7 @@ func TestNodeStoreOperations(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
store := tt.setupFunc(t)
store.Start()
defer store.Stop()
@ -847,88 +886,107 @@ type testStep struct {
// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests ---
// Helper for concurrent test nodes
// Helper for concurrent test nodes.
func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
machineKey := key.NewMachine()
nodeKey := key.NewNode()
return types.Node{
ID: id,
Hostname: hostname,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
UserID: ptr.To(uint(1)),
UserID: new(uint(1)),
User: &types.User{
Name: "concurrent-test-user",
},
}
}
// --- Concurrency: concurrent PutNode operations ---
// --- Concurrency: concurrent PutNode operations ---.
func TestNodeStoreConcurrentPutNode(t *testing.T) {
const concurrentOps = 20
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
var wg sync.WaitGroup
results := make(chan bool, concurrentOps)
for i := range concurrentOps {
wg.Add(1)
go func(nodeID int) {
defer wg.Done()
//nolint:gosec
node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node")
resultNode := store.PutNode(node)
results <- resultNode.Valid()
}(i + 1)
}
wg.Wait()
close(results)
successCount := 0
for success := range results {
if success {
successCount++
}
}
require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed")
}
// --- Batching: concurrent ops fit in one batch ---
// --- Batching: concurrent ops fit in one batch ---.
func TestNodeStoreBatchingEfficiency(t *testing.T) {
const batchSize = 10
const ops = 15 // more than batchSize
const ops = 15
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
var wg sync.WaitGroup
results := make(chan bool, ops)
for i := range ops {
wg.Add(1)
go func(nodeID int) {
defer wg.Done()
//nolint:gosec
node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node")
resultNode := store.PutNode(node)
results <- resultNode.Valid()
}(i + 1)
}
wg.Wait()
close(results)
successCount := 0
for success := range results {
if success {
successCount++
}
}
require.Equal(t, ops, successCount, "All batch PutNode operations should succeed")
}
// --- Race conditions: many goroutines on same node ---
// --- Race conditions: many goroutines on same node ---.
func TestNodeStoreRaceConditions(t *testing.T) {
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -937,13 +995,18 @@ func TestNodeStoreRaceConditions(t *testing.T) {
resultNode := store.PutNode(node)
require.True(t, resultNode.Valid())
const numGoroutines = 30
const opsPerGoroutine = 10
const (
numGoroutines = 30
opsPerGoroutine = 10
)
var wg sync.WaitGroup
errors := make(chan error, numGoroutines*opsPerGoroutine)
for i := range numGoroutines {
wg.Add(1)
go func(gid int) {
defer wg.Done()
@ -954,40 +1017,46 @@ func TestNodeStoreRaceConditions(t *testing.T) {
n.Hostname = "race-updated"
})
if !resultNode.Valid() {
errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j)
errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestUpdateNodeFailed, gid, j)
}
case 1:
retrieved, found := store.GetNode(nodeID)
if !found || !retrieved.Valid() {
errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j)
errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestGetNodeFailed, gid, j)
}
case 2:
newNode := createConcurrentTestNode(nodeID, "race-put")
resultNode := store.PutNode(newNode)
if !resultNode.Valid() {
errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j)
errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestPutNodeFailed, gid, j)
}
}
}
}(i)
}
wg.Wait()
close(errors)
errorCount := 0
for err := range errors {
t.Error(err)
errorCount++
}
if errorCount > 0 {
t.Fatalf("Race condition test failed with %d errors", errorCount)
}
}
// --- Resource cleanup: goroutine leak detection ---
// --- Resource cleanup: goroutine leak detection ---.
func TestNodeStoreResourceCleanup(t *testing.T) {
// initialGoroutines := runtime.NumGoroutine()
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -1001,7 +1070,7 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
const ops = 100
for i := range ops {
nodeID := types.NodeID(i + 1)
nodeID := types.NodeID(i + 1) //nolint:gosec
node := createConcurrentTestNode(nodeID, "cleanup-node")
resultNode := store.PutNode(node)
assert.True(t, resultNode.Valid())
@ -1010,10 +1079,12 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
})
retrieved, found := store.GetNode(nodeID)
assert.True(t, found && retrieved.Valid())
if i%10 == 9 {
store.DeleteNode(nodeID)
}
}
runtime.GC()
// Wait for goroutines to settle and check for leaks
@ -1024,9 +1095,10 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
}, time.Second, 10*time.Millisecond, "goroutines should not leak")
}
// --- Timeout/deadlock: operations complete within reasonable time ---
// --- Timeout/deadlock: operations complete within reasonable time ---.
func TestNodeStoreOperationTimeout(t *testing.T) {
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
@ -1034,36 +1106,47 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
defer cancel()
const ops = 30
var wg sync.WaitGroup
putResults := make([]error, ops)
updateResults := make([]error, ops)
// Launch all PutNode operations concurrently
for i := 1; i <= ops; i++ {
nodeID := types.NodeID(i)
nodeID := types.NodeID(i) //nolint:gosec
wg.Add(1)
go func(idx int, id types.NodeID) {
defer wg.Done()
startPut := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id)
node := createConcurrentTestNode(id, "timeout-node")
resultNode := store.PutNode(node)
endPut := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut))
if !resultNode.Valid() {
putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id)
putResults[idx-1] = fmt.Errorf("%w for node %d", errTestPutNodeFailed, id)
}
}(i, nodeID)
}
wg.Wait()
// Launch all UpdateNode operations concurrently
wg = sync.WaitGroup{}
for i := 1; i <= ops; i++ {
nodeID := types.NodeID(i)
nodeID := types.NodeID(i) //nolint:gosec
wg.Add(1)
go func(idx int, id types.NodeID) {
defer wg.Done()
startUpdate := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id)
resultNode, ok := store.UpdateNode(id, func(n *types.Node) {
@ -1071,31 +1154,40 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
})
endUpdate := time.Now()
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate))
if !ok || !resultNode.Valid() {
updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id)
updateResults[idx-1] = fmt.Errorf("%w for node %d", errTestUpdateNodeFailed, id)
}
}(i, nodeID)
}
done := make(chan struct{})
go func() {
wg.Wait()
close(done)
}()
select {
case <-done:
errorCount := 0
for _, err := range putResults {
if err != nil {
t.Error(err)
errorCount++
}
}
for _, err := range updateResults {
if err != nil {
t.Error(err)
errorCount++
}
}
if errorCount == 0 {
t.Log("All concurrent operations completed successfully within timeout")
} else {
@ -1107,13 +1199,15 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
}
}
// --- Edge case: update non-existent node ---
// --- Edge case: update non-existent node ---.
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
for i := range 10 {
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
nonExistentID := types.NodeID(999 + i)
nonExistentID := types.NodeID(999 + i) //nolint:gosec
updateCallCount := 0
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID)
resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) {
updateCallCount++
@ -1127,20 +1221,22 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
}
}
// --- Allocation benchmark ---
// --- Allocation benchmark ---.
func BenchmarkNodeStoreAllocations(b *testing.B) {
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
for i := 0; b.Loop(); i++ {
nodeID := types.NodeID(i + 1)
nodeID := types.NodeID(i + 1) //nolint:gosec
node := createConcurrentTestNode(nodeID, "bench-node")
store.PutNode(node)
store.UpdateNode(nodeID, func(n *types.Node) {
n.Hostname = "bench-updated"
})
store.GetNode(nodeID)
if i%10 == 9 {
store.DeleteNode(nodeID)
}

View file

@ -25,10 +25,8 @@ import (
"github.com/rs/zerolog/log"
"golang.org/x/sync/errgroup"
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
"tailscale.com/types/views"
zcache "zgo.at/zcache/v2"
)
@ -133,7 +131,7 @@ func NewState(cfg *types.Config) (*State, error) {
// On startup, all nodes should be marked as offline until they reconnect
// This ensures we don't have stale online status from previous runs
for _, node := range nodes {
node.IsOnline = ptr.To(false)
node.IsOnline = new(false)
}
users, err := db.ListUsers()
if err != nil {
@ -189,7 +187,8 @@ func NewState(cfg *types.Config) (*State, error) {
func (s *State) Close() error {
s.nodeStore.Stop()
if err := s.db.Close(); err != nil {
err := s.db.Close()
if err != nil {
return fmt.Errorf("closing database: %w", err)
}
@ -225,6 +224,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) {
// propagate correctly when switching between policy types.
s.nodeStore.RebuildPeerMaps()
//nolint:prealloc
cs := []change.Change{change.PolicyChange()}
// Always call autoApproveNodes during policy reload, regardless of whether
@ -255,6 +255,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) {
// CreateUser creates a new user and updates the policy manager.
// Returns the created user, change set, and any error.
func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) {
//nolint:noinlineerr
if err := s.db.DB.Save(&user).Error; err != nil {
return nil, change.Change{}, fmt.Errorf("creating user: %w", err)
}
@ -289,6 +290,7 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
return nil, err
}
//nolint:noinlineerr
if err := updateFn(user); err != nil {
return nil, err
}
@ -468,10 +470,8 @@ func (s *State) Connect(id types.NodeID) []change.Change {
// CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification
// This ensures that when the NodeCameOnline change is distributed and processed by other nodes,
// the NodeStore already reflects the correct online status for full map generation.
// now := time.Now()
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
n.IsOnline = ptr.To(true)
// n.LastSeen = ptr.To(now)
n.IsOnline = new(true)
})
if !ok {
return nil
@ -495,16 +495,17 @@ func (s *State) Connect(id types.NodeID) []change.Change {
// Disconnect marks a node as disconnected and updates its primary routes in the state.
func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
//nolint:staticcheck
now := time.Now()
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
n.LastSeen = ptr.To(now)
n.LastSeen = new(now)
// NodeStore is the source of truth for all node state including online status.
n.IsOnline = ptr.To(false)
n.IsOnline = new(false)
})
if !ok {
return nil, fmt.Errorf("node not found: %d", id)
return nil, fmt.Errorf("%w: %d", ErrNodeNotFound, id)
}
log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node disconnected")
@ -745,13 +746,14 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t
// RenameNode changes the display name of a node.
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.Change, error) {
if err := util.ValidateHostname(newName); err != nil {
err := util.ValidateHostname(newName)
if err != nil {
return types.NodeView{}, change.Change{}, fmt.Errorf("renaming node: %w", err)
}
// Check name uniqueness against NodeStore
allNodes := s.nodeStore.ListNodes()
for i := 0; i < allNodes.Len(); i++ {
for i := range allNodes.Len() {
node := allNodes.At(i)
if node.ID() != nodeID && node.AsStruct().GivenName == newName {
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName)
@ -790,7 +792,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
// Preserve online status and NetInfo when refreshing from database
existingNode, exists := s.nodeStore.GetNode(node.ID)
if exists && existingNode.Valid() {
node.IsOnline = ptr.To(existingNode.IsOnline().Get())
node.IsOnline = new(existingNode.IsOnline().Get())
// TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics
// when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest).
@ -818,6 +820,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
var updates []change.Change
//nolint:unqueryvet
for _, node := range s.nodeStore.ListNodes().All() {
if !node.Valid() {
continue
@ -1117,7 +1120,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
DiscoKey: params.DiscoKey,
Hostinfo: params.Hostinfo,
Endpoints: params.Endpoints,
LastSeen: ptr.To(time.Now()),
LastSeen: new(time.Now()),
RegisterMethod: params.RegisterMethod,
Expiry: params.Expiry,
}
@ -1218,7 +1221,8 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
// New node - database first to get ID, then NodeStore
savedNode, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
if err := tx.Save(&nodeToRegister).Error; err != nil {
err := tx.Save(&nodeToRegister).Error
if err != nil {
return nil, fmt.Errorf("failed to save node: %w", err)
}
@ -1407,8 +1411,8 @@ func (s *State) HandleNodeFromAuthPath(
node.Endpoints = regEntry.Node.Endpoints
node.RegisterMethod = regEntry.Node.RegisterMethod
node.IsOnline = ptr.To(false)
node.LastSeen = ptr.To(time.Now())
node.IsOnline = new(false)
node.LastSeen = new(time.Now())
// Tagged nodes keep their existing expiry (disabled).
// User-owned nodes update expiry from the provided value or registration entry.
@ -1669,8 +1673,8 @@ func (s *State) HandleNodeFromPreAuthKey(
// Only update AuthKey reference
node.AuthKey = pak
node.AuthKeyID = &pak.ID
node.IsOnline = ptr.To(false)
node.LastSeen = ptr.To(time.Now())
node.IsOnline = new(false)
node.LastSeen = new(time.Now())
// Tagged nodes keep their existing expiry (disabled).
// User-owned nodes update expiry from the client request.
@ -1697,7 +1701,7 @@ func (s *State) HandleNodeFromPreAuthKey(
}
}
return nil, nil
return nil, nil //nolint:nilnil
})
if err != nil {
return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err)
@ -2122,8 +2126,8 @@ func routesChanged(oldNode types.NodeView, newHI *tailcfg.Hostinfo) bool {
newRoutes = []netip.Prefix{}
}
tsaddr.SortPrefixes(oldRoutes)
tsaddr.SortPrefixes(newRoutes)
slices.SortFunc(oldRoutes, netip.Prefix.Compare)
slices.SortFunc(newRoutes, netip.Prefix.Compare)
return !slices.Equal(oldRoutes, newRoutes)
}

View file

@ -13,6 +13,9 @@ import (
"tailscale.com/types/logger"
)
// ErrNoCertDomains is returned when no cert domains are available for HTTPS.
var ErrNoCertDomains = errors.New("no cert domains available for HTTPS")
func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath string) error {
opts := tailsql.Options{
Hostname: "tailsql-headscale",
@ -71,7 +74,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
// When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443.
certDomains := tsNode.CertDomains()
if len(certDomains) == 0 {
return errors.New("no cert domains available for HTTPS")
return ErrNoCertDomains
}
base := "https://" + certDomains[0]
go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
@ -93,6 +96,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
mux := tsql.NewMux()
tsweb.Debugger(mux)
go http.Serve(lst, mux)
logf("TailSQL started")
<-ctx.Done()
logf("TailSQL shutting down...")

View file

@ -15,43 +15,43 @@ import (
// Material for MkDocs design system - exact values from official docs.
const (
// Text colors - from --md-default-fg-color CSS variables.
colorTextPrimary = "#000000de" //nolint:unused // rgba(0,0,0,0.87) - Body text
colorTextSecondary = "#0000008a" //nolint:unused // rgba(0,0,0,0.54) - Headings (--md-default-fg-color--light)
colorTextTertiary = "#00000052" //nolint:unused // rgba(0,0,0,0.32) - Lighter text
colorTextLightest = "#00000012" //nolint:unused // rgba(0,0,0,0.07) - Lightest text
colorTextPrimary = "#000000de" //nolint:unused
colorTextSecondary = "#0000008a" //nolint:unused
colorTextTertiary = "#00000052" //nolint:unused
colorTextLightest = "#00000012" //nolint:unused
// Code colors - from --md-code-* CSS variables.
colorCodeFg = "#36464e" //nolint:unused // Code text color (--md-code-fg-color)
colorCodeBg = "#f5f5f5" //nolint:unused // Code background (--md-code-bg-color)
colorCodeFg = "#36464e" //nolint:unused
colorCodeBg = "#f5f5f5" //nolint:unused
// Border colors.
colorBorderLight = "#e5e7eb" //nolint:unused // Light borders
colorBorderMedium = "#d1d5db" //nolint:unused // Medium borders
colorBorderLight = "#e5e7eb" //nolint:unused
colorBorderMedium = "#d1d5db" //nolint:unused
// Background colors.
colorBackgroundPage = "#ffffff" //nolint:unused // Page background
colorBackgroundCard = "#ffffff" //nolint:unused // Card/content background
colorBackgroundPage = "#ffffff" //nolint:unused
colorBackgroundCard = "#ffffff" //nolint:unused
// Accent colors - from --md-primary/accent-fg-color.
colorPrimaryAccent = "#4051b5" //nolint:unused // Primary accent (links)
colorAccent = "#526cfe" //nolint:unused // Secondary accent
colorPrimaryAccent = "#4051b5" //nolint:unused
colorAccent = "#526cfe" //nolint:unused
// Success colors.
colorSuccess = "#059669" //nolint:unused // Success states
colorSuccessLight = "#d1fae5" //nolint:unused // Success backgrounds
colorSuccess = "#059669" //nolint:unused
colorSuccessLight = "#d1fae5" //nolint:unused
)
// Spacing System
// Based on 4px/8px base unit for consistent rhythm.
// Uses rem units for scalability with user font size preferences.
const (
spaceXS = "0.25rem" //nolint:unused // 4px - Tight spacing
spaceS = "0.5rem" //nolint:unused // 8px - Small spacing
spaceM = "1rem" //nolint:unused // 16px - Medium spacing (base)
spaceL = "1.5rem" //nolint:unused // 24px - Large spacing
spaceXL = "2rem" //nolint:unused // 32px - Extra large spacing
space2XL = "3rem" //nolint:unused // 48px - 2x extra large spacing
space3XL = "4rem" //nolint:unused // 64px - 3x extra large spacing
spaceXS = "0.25rem" //nolint:unused
spaceS = "0.5rem" //nolint:unused
spaceM = "1rem" //nolint:unused
spaceL = "1.5rem" //nolint:unused
spaceXL = "2rem" //nolint:unused
space2XL = "3rem" //nolint:unused
space3XL = "4rem" //nolint:unused
)
// Typography System
@ -63,26 +63,26 @@ const (
fontFamilyCode = `"Roboto Mono", "SF Mono", Monaco, "Cascadia Code", Consolas, "Courier New", monospace` //nolint:unused
// Font sizes - from .md-typeset CSS rules.
fontSizeBase = "0.8rem" //nolint:unused // 12.8px - Base text (.md-typeset)
fontSizeH1 = "2em" //nolint:unused // 2x base - Main headings
fontSizeH2 = "1.5625em" //nolint:unused // 1.5625x base - Section headings
fontSizeH3 = "1.25em" //nolint:unused // 1.25x base - Subsection headings
fontSizeSmall = "0.8em" //nolint:unused // 0.8x base - Small text
fontSizeCode = "0.85em" //nolint:unused // 0.85x base - Inline code
fontSizeBase = "0.8rem" //nolint:unused
fontSizeH1 = "2em" //nolint:unused
fontSizeH2 = "1.5625em" //nolint:unused
fontSizeH3 = "1.25em" //nolint:unused
fontSizeSmall = "0.8em" //nolint:unused
fontSizeCode = "0.85em" //nolint:unused
// Line heights - from .md-typeset CSS rules.
lineHeightBase = "1.6" //nolint:unused // Body text (.md-typeset)
lineHeightH1 = "1.3" //nolint:unused // H1 headings
lineHeightH2 = "1.4" //nolint:unused // H2 headings
lineHeightH3 = "1.5" //nolint:unused // H3 headings
lineHeightCode = "1.4" //nolint:unused // Code blocks (pre)
lineHeightBase = "1.6" //nolint:unused
lineHeightH1 = "1.3" //nolint:unused
lineHeightH2 = "1.4" //nolint:unused
lineHeightH3 = "1.5" //nolint:unused
lineHeightCode = "1.4" //nolint:unused
)
// Responsive Container Component
// Creates a centered container with responsive padding and max-width.
// Mobile-first approach: starts at 100% width with padding, constrains on larger screens.
//
//nolint:unused // Reserved for future use in Phase 4.
//nolint:unused
func responsiveContainer(children ...elem.Node) *elem.Element {
return elem.Div(attrs.Props{
attrs.Style: styles.Props{
@ -100,7 +100,7 @@ func responsiveContainer(children ...elem.Node) *elem.Element {
// - title: Optional title for the card (empty string for no title)
// - children: Content elements to display in the card
//
//nolint:unused // Reserved for future use in Phase 4.
//nolint:unused
func card(title string, children ...elem.Node) *elem.Element {
cardContent := children
if title != "" {
@ -134,7 +134,7 @@ func card(title string, children ...elem.Node) *elem.Element {
// EXTRACTED FROM: .md-typeset pre CSS rules
// Exact styling from Material for MkDocs documentation.
//
//nolint:unused // Used across apple.go, windows.go, register_web.go templates.
//nolint:unused
func codeBlock(code string) *elem.Element {
return elem.Pre(attrs.Props{
attrs.Style: styles.Props{
@ -164,7 +164,7 @@ func codeBlock(code string) *elem.Element {
// Returns inline styles for the main content container that matches .md-typeset.
// EXTRACTED FROM: .md-typeset CSS rule from Material for MkDocs.
//
//nolint:unused // Used in general.go for mdTypesetBody.
//nolint:unused
func baseTypesetStyles() styles.Props {
return styles.Props{
styles.FontSize: fontSizeBase, // 0.8rem
@ -180,7 +180,7 @@ func baseTypesetStyles() styles.Props {
// Returns inline styles for H1 headings that match .md-typeset h1.
// EXTRACTED FROM: .md-typeset h1 CSS rule from Material for MkDocs.
//
//nolint:unused // Used across templates for main headings.
//nolint:unused
func h1Styles() styles.Props {
return styles.Props{
styles.Color: colorTextSecondary, // rgba(0, 0, 0, 0.54)
@ -198,7 +198,7 @@ func h1Styles() styles.Props {
// Returns inline styles for H2 headings that match .md-typeset h2.
// EXTRACTED FROM: .md-typeset h2 CSS rule from Material for MkDocs.
//
//nolint:unused // Used across templates for section headings.
//nolint:unused
func h2Styles() styles.Props {
return styles.Props{
styles.FontSize: fontSizeH2, // 1.5625em
@ -216,7 +216,7 @@ func h2Styles() styles.Props {
// Returns inline styles for H3 headings that match .md-typeset h3.
// EXTRACTED FROM: .md-typeset h3 CSS rule from Material for MkDocs.
//
//nolint:unused // Used across templates for subsection headings.
//nolint:unused
func h3Styles() styles.Props {
return styles.Props{
styles.FontSize: fontSizeH3, // 1.25em
@ -234,7 +234,7 @@ func h3Styles() styles.Props {
// Returns inline styles for paragraphs that match .md-typeset p.
// EXTRACTED FROM: .md-typeset p CSS rule from Material for MkDocs.
//
//nolint:unused // Used for consistent paragraph spacing.
//nolint:unused
func paragraphStyles() styles.Props {
return styles.Props{
styles.Margin: "1em 0",
@ -250,7 +250,7 @@ func paragraphStyles() styles.Props {
// Returns inline styles for ordered lists that match .md-typeset ol.
// EXTRACTED FROM: .md-typeset ol CSS rule from Material for MkDocs.
//
//nolint:unused // Used for numbered instruction lists.
//nolint:unused
func orderedListStyles() styles.Props {
return styles.Props{
styles.MarginBottom: "1em",
@ -268,7 +268,7 @@ func orderedListStyles() styles.Props {
// Returns inline styles for unordered lists that match .md-typeset ul.
// EXTRACTED FROM: .md-typeset ul CSS rule from Material for MkDocs.
//
//nolint:unused // Used for bullet point lists.
//nolint:unused
func unorderedListStyles() styles.Props {
return styles.Props{
styles.MarginBottom: "1em",
@ -287,7 +287,7 @@ func unorderedListStyles() styles.Props {
// EXTRACTED FROM: .md-typeset a CSS rule from Material for MkDocs.
// Note: Hover states cannot be implemented with inline styles.
//
//nolint:unused // Used for text links.
//nolint:unused
func linkStyles() styles.Props {
return styles.Props{
styles.Color: colorPrimaryAccent, // #4051b5 - var(--md-primary-fg-color)
@ -301,7 +301,7 @@ func linkStyles() styles.Props {
// Returns inline styles for inline code that matches .md-typeset code.
// EXTRACTED FROM: .md-typeset code CSS rule from Material for MkDocs.
//
//nolint:unused // Used for inline code snippets.
//nolint:unused
func inlineCodeStyles() styles.Props {
return styles.Props{
styles.BackgroundColor: colorCodeBg, // #f5f5f5
@ -317,7 +317,7 @@ func inlineCodeStyles() styles.Props {
// Inline Code Component
// For inline code snippets within text.
//
//nolint:unused // Reserved for future inline code usage.
//nolint:unused
func inlineCode(code string) *elem.Element {
return elem.Code(attrs.Props{
attrs.Style: inlineCodeStyles().ToInline(),
@ -327,7 +327,7 @@ func inlineCode(code string) *elem.Element {
// orDivider creates a visual "or" divider between sections.
// Styled with lines on either side for better visual separation.
//
//nolint:unused // Used in apple.go template.
//nolint:unused
func orDivider() *elem.Element {
return elem.Div(attrs.Props{
attrs.Style: styles.Props{
@ -367,7 +367,7 @@ func orDivider() *elem.Element {
// warningBox creates a warning message box with icon and content.
//
//nolint:unused // Used in apple.go template.
//nolint:unused
func warningBox(title, message string) *elem.Element {
return elem.Div(attrs.Props{
attrs.Style: styles.Props{
@ -404,7 +404,7 @@ func warningBox(title, message string) *elem.Element {
// downloadButton creates a nice button-style link for downloads.
//
//nolint:unused // Used in apple.go template.
//nolint:unused
func downloadButton(href, text string) *elem.Element {
return elem.A(attrs.Props{
attrs.Href: href,
@ -428,7 +428,7 @@ func downloadButton(href, text string) *elem.Element {
// Creates a link with proper security attributes for external URLs.
// Automatically adds rel="noreferrer noopener" and target="_blank".
//
//nolint:unused // Used in apple.go, oidc_callback.go templates.
//nolint:unused
func externalLink(href, text string) *elem.Element {
return elem.A(attrs.Props{
attrs.Href: href,
@ -444,7 +444,7 @@ func externalLink(href, text string) *elem.Element {
// Instruction Step Component
// For numbered instruction lists with consistent formatting.
//
//nolint:unused // Reserved for future use in Phase 4.
//nolint:unused
func instructionStep(_ int, text string) *elem.Element {
return elem.Li(attrs.Props{
attrs.Style: styles.Props{
@ -457,7 +457,7 @@ func instructionStep(_ int, text string) *elem.Element {
// Status Message Component
// For displaying success/error/info messages with appropriate styling.
//
//nolint:unused // Reserved for future use in Phase 4.
//nolint:unused
func statusMessage(message string, isSuccess bool) *elem.Element {
bgColor := colorSuccessLight
textColor := colorSuccess

View file

@ -333,7 +333,7 @@ func NodeOnline(nodeID types.NodeID) Change {
PeerPatches: []*tailcfg.PeerChange{
{
NodeID: nodeID.NodeID(),
Online: ptrTo(true),
Online: new(true),
},
},
}
@ -346,7 +346,7 @@ func NodeOffline(nodeID types.NodeID) Change {
PeerPatches: []*tailcfg.PeerChange{
{
NodeID: nodeID.NodeID(),
Online: ptrTo(false),
Online: new(false),
},
},
}
@ -365,11 +365,6 @@ func KeyExpiry(nodeID types.NodeID, expiry *time.Time) Change {
}
}
// ptrTo returns a pointer to the given value.
func ptrTo[T any](v T) *T {
return &v
}
// High-level change constructors
// NodeAdded returns a Change for when a node is added or updated.

View file

@ -16,8 +16,8 @@ func TestChange_FieldSync(t *testing.T) {
typ := reflect.TypeFor[Change]()
boolCount := 0
for i := range typ.NumField() {
if typ.Field(i).Type.Kind() == reflect.Bool {
for field := range typ.Fields() {
if field.Type.Kind() == reflect.Bool {
boolCount++
}
}

View file

@ -20,7 +20,10 @@ const (
DatabaseSqlite = "sqlite3"
)
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
var (
ErrCannotParsePrefix = errors.New("cannot parse prefix")
ErrInvalidRegIDLength = errors.New("registration ID has invalid length")
)
type StateUpdateType int
@ -175,8 +178,9 @@ func MustRegistrationID() RegistrationID {
func RegistrationIDFromString(str string) (RegistrationID, error) {
if len(str) != RegistrationIDLength {
return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength)
return "", fmt.Errorf("%w: expected %d characters", ErrInvalidRegIDLength, RegistrationIDLength)
}
return RegistrationID(str), nil
}

View file

@ -30,13 +30,16 @@ const (
PKCEMethodS256 string = "S256"
defaultNodeStoreBatchSize = 100
defaultWALAutocheckpoint = 1000 // SQLite default
)
var (
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable")
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable")
errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'")
errNoPrefixConfigured = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
errInvalidAllocationStrategy = errors.New("invalid prefixes.allocation strategy")
)
type IPAllocationStrategy string
@ -301,6 +304,7 @@ func validatePKCEMethod(method string) error {
if method != PKCEMethodPlain && method != PKCEMethodS256 {
return errInvalidPKCEMethod
}
return nil
}
@ -377,7 +381,7 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
viper.SetDefault("database.sqlite.write_ahead_log", true)
viper.SetDefault("database.sqlite.wal_autocheckpoint", 1000) // SQLite default
viper.SetDefault("database.sqlite.wal_autocheckpoint", defaultWALAutocheckpoint)
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
@ -402,7 +406,8 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))
if err := viper.ReadInConfig(); err != nil {
if _, ok := err.(viper.ConfigFileNotFoundError); ok {
var configFileNotFoundError viper.ConfigFileNotFoundError
if errors.As(err, &configFileNotFoundError) {
log.Warn().Msg("No config file found, using defaults")
return nil
}
@ -442,7 +447,8 @@ func validateServerConfig() error {
depr.fatal("oidc.map_legacy_users")
if viper.GetBool("oidc.enabled") {
if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil {
err := validatePKCEMethod(viper.GetString("oidc.pkce.method"))
if err != nil {
return err
}
}
@ -910,6 +916,7 @@ func LoadCLIConfig() (*Config, error) {
// LoadServerConfig returns the full Headscale configuration to
// host a Headscale server. This is called as part of `headscale serve`.
func LoadServerConfig() (*Config, error) {
//nolint:noinlineerr
if err := validateServerConfig(); err != nil {
return nil, err
}
@ -928,7 +935,7 @@ func LoadServerConfig() (*Config, error) {
}
if prefix4 == nil && prefix6 == nil {
return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
return nil, errNoPrefixConfigured
}
allocStr := viper.GetString("prefixes.allocation")
@ -940,7 +947,8 @@ func LoadServerConfig() (*Config, error) {
alloc = IPAllocationStrategyRandom
default:
return nil, fmt.Errorf(
"config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s",
"%w: %q, allowed options: %s, %s",
errInvalidAllocationStrategy,
allocStr,
IPAllocationStrategySequential,
IPAllocationStrategyRandom,
@ -979,7 +987,8 @@ func LoadServerConfig() (*Config, error) {
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
if dnsConfig.BaseDomain != "" {
if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil {
err := isSafeServerURL(serverURL, dnsConfig.BaseDomain)
if err != nil {
return nil, err
}
}
@ -1082,6 +1091,7 @@ func LoadServerConfig() (*Config, error) {
if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
return workers
}
return DefaultBatcherWorkers()
}(),
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
@ -1117,6 +1127,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
}
s := len(serverDomainParts)
b := len(baseDomainParts)
for i := range baseDomainParts {
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {

View file

@ -349,11 +349,8 @@ func TestReadConfigFromEnv(t *testing.T) {
}
func TestTLSConfigValidation(t *testing.T) {
tmpDir, err := os.MkdirTemp("", "headscale")
if err != nil {
t.Fatal(err)
}
// defer os.RemoveAll(tmpDir)
tmpDir := t.TempDir()
configYaml := []byte(`---
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: ""
@ -363,7 +360,8 @@ noise:
// Populate a custom config file
configFilePath := filepath.Join(tmpDir, "config.yaml")
err = os.WriteFile(configFilePath, configYaml, 0o600)
err := os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
@ -398,10 +396,12 @@ server_url: http://127.0.0.1:8080
tls_letsencrypt_hostname: example.com
tls_letsencrypt_challenge_type: TLS-ALPN-01
`)
err = os.WriteFile(configFilePath, configYaml, 0o600)
if err != nil {
t.Fatalf("Couldn't write file %s", configFilePath)
}
err = LoadConfig(tmpDir, false)
require.NoError(t, err)
}
@ -463,6 +463,7 @@ func TestSafeServerURL(t *testing.T) {
return
}
assert.NoError(t, err)
})
}

View file

@ -42,10 +42,6 @@ type (
NodeIDs []NodeID
)
func (n NodeIDs) Len() int { return len(n) }
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
func (id NodeID) StableID() tailcfg.StableNodeID {
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
}
@ -160,6 +156,7 @@ func (node *Node) GivenNameHasBeenChanged() bool {
// Strip invalid DNS characters for givenName comparison
normalised := strings.ToLower(node.Hostname)
normalised = invalidDNSRegex.ReplaceAllString(normalised, "")
return node.GivenName == normalised
}
@ -197,13 +194,7 @@ func (node *Node) IPs() []netip.Addr {
// HasIP reports if a node has a given IP address.
func (node *Node) HasIP(i netip.Addr) bool {
for _, ip := range node.IPs() {
if ip.Compare(i) == 0 {
return true
}
}
return false
return slices.Contains(node.IPs(), i)
}
// IsTagged reports if a device is tagged and therefore should not be treated
@ -243,6 +234,7 @@ func (node *Node) RequestTags() []string {
}
func (node *Node) Prefixes() []netip.Prefix {
//nolint:prealloc
var addrs []netip.Prefix
for _, nodeAddress := range node.IPs() {
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
@ -272,6 +264,7 @@ func (node *Node) IsExitNode() bool {
}
func (node *Node) IPsAsString() []string {
//nolint:prealloc
var ret []string
for _, ip := range node.IPs() {
@ -355,13 +348,9 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
}
func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool {
for _, node := range nodes {
if node.NodeKey == nodeKey {
return true
}
}
return false
return slices.ContainsFunc(nodes, func(node *Node) bool {
return node.NodeKey == nodeKey
})
}
func (node *Node) Proto() *v1.Node {
@ -478,7 +467,7 @@ func (node *Node) IsSubnetRouter() bool {
return len(node.SubnetRoutes()) > 0
}
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes.
func (node *Node) AllApprovedRoutes() []netip.Prefix {
return append(node.SubnetRoutes(), node.ExitRoutes()...)
}
@ -586,13 +575,16 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
}
newHostname := strings.ToLower(hostInfo.Hostname)
if err := util.ValidateHostname(newHostname); err != nil {
err := util.ValidateHostname(newHostname)
if err != nil {
log.Warn().
Str("node.id", node.ID.String()).
Str("current_hostname", node.Hostname).
Str("rejected_hostname", hostInfo.Hostname).
Err(err).
Msg("Rejecting invalid hostname update from hostinfo")
return
}
@ -684,6 +676,7 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
func (nodes Nodes) DebugString() string {
var sb strings.Builder
sb.WriteString("Nodes:\n")
for _, node := range nodes {
sb.WriteString(node.DebugString())
sb.WriteString("\n")
@ -855,7 +848,7 @@ func (nv NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.Peer
// GetFQDN returns the fully qualified domain name for the node.
func (nv NodeView) GetFQDN(baseDomain string) (string, error) {
if !nv.Valid() {
return "", errors.New("failed to create valid FQDN: node view is invalid")
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrInvalidNodeView)
}
return nv.ж.GetFQDN(baseDomain)
@ -934,11 +927,11 @@ func (nv NodeView) TailscaleUserID() tailcfg.UserID {
}
if nv.IsTagged() {
//nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64
//nolint:gosec
return tailcfg.UserID(int64(TaggedDevices.ID))
}
//nolint:gosec // G115: UserID values are within int64 range
//nolint:gosec
return tailcfg.UserID(int64(nv.UserID().Get()))
}
@ -1048,7 +1041,7 @@ func (nv NodeView) TailNode(
primaryRoutes := primaryRouteFunc(nv.ID())
allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes())
tsaddr.SortPrefixes(allowedIPs)
slices.SortFunc(allowedIPs, netip.Prefix.Compare)
capMap := tailcfg.NodeCapMap{
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
@ -1063,7 +1056,7 @@ func (nv NodeView) TailNode(
}
tNode := tailcfg.Node{
//nolint:gosec // G115: NodeID values are within int64 range
//nolint:gosec
ID: tailcfg.NodeID(nv.ID()),
StableID: nv.ID().StableID(),
Name: hostname,

View file

@ -6,7 +6,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"tailscale.com/types/ptr"
)
// TestNodeIsTagged tests the IsTagged() method for determining if a node is tagged.
@ -69,7 +68,7 @@ func TestNodeIsTagged(t *testing.T) {
{
name: "node with user and no tags - not tagged",
node: Node{
UserID: ptr.To(uint(42)),
UserID: new(uint(42)),
Tags: []string{},
},
want: false,
@ -112,7 +111,7 @@ func TestNodeViewIsTagged(t *testing.T) {
{
name: "user-owned node",
node: Node{
UserID: ptr.To(uint(1)),
UserID: new(uint(1)),
},
want: false,
},
@ -223,7 +222,7 @@ func TestNodeTagsImmutableAfterRegistration(t *testing.T) {
// Test that a user-owned node is not tagged
userNode := Node{
ID: 2,
UserID: ptr.To(uint(42)),
UserID: new(uint(42)),
Tags: []string{},
RegisterMethod: util.RegisterMethodOIDC,
}
@ -243,7 +242,7 @@ func TestNodeOwnershipModel(t *testing.T) {
name: "tagged node has tags, UserID is informational",
node: Node{
ID: 1,
UserID: ptr.To(uint(5)), // "created by" user 5
UserID: new(uint(5)), // "created by" user 5
Tags: []string{"tag:server"},
},
wantIsTagged: true,
@ -253,7 +252,7 @@ func TestNodeOwnershipModel(t *testing.T) {
name: "user-owned node has no tags",
node: Node{
ID: 2,
UserID: ptr.To(uint(5)),
UserID: new(uint(5)),
Tags: []string{},
},
wantIsTagged: false,
@ -265,7 +264,7 @@ func TestNodeOwnershipModel(t *testing.T) {
name: "node with only authkey tags - not tagged (tags should be copied)",
node: Node{
ID: 3,
UserID: ptr.To(uint(5)), // "created by" user 5
UserID: new(uint(5)), // "created by" user 5
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},

View file

@ -407,7 +407,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "我的电脑",
Hostname: "我的电脑", //nolint:gosmopolitan
},
want: Node{
GivenName: "valid-hostname",
@ -491,7 +491,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "server-北京-01",
Hostname: "server-北京-01", //nolint:gosmopolitan
},
want: Node{
GivenName: "valid-hostname",
@ -505,7 +505,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "我的电脑",
Hostname: "我的电脑", //nolint:gosmopolitan
},
want: Node{
GivenName: "valid-hostname",
@ -533,7 +533,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
Hostname: "valid-hostname",
},
change: &tailcfg.Hostinfo{
Hostname: "测试💻机器",
Hostname: "测试💻机器", //nolint:gosmopolitan
},
want: Node{
GivenName: "valid-hostname",

View file

@ -114,7 +114,7 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey {
return &protoKey
}
// canUsePreAuthKey checks if a pre auth key can be used.
// Validate checks if a pre auth key can be used.
func (pak *PreAuthKey) Validate() error {
if pak == nil {
return PAKError("invalid authkey")
@ -128,6 +128,7 @@ func (pak *PreAuthKey) Validate() error {
if pak.Expiration != nil {
return *pak.Expiration
}
return time.Time{}
}()).
Time("now", time.Now()).

View file

@ -110,8 +110,7 @@ func TestCanUsePreAuthKey(t *testing.T) {
if err == nil {
t.Errorf("expected error but got none")
} else {
var httpErr PAKError
ok := errors.As(err, &httpErr)
httpErr, ok := errors.AsType[PAKError](err)
if !ok {
t.Errorf("expected HTTPError but got %T", err)
} else {

View file

@ -4,6 +4,7 @@ import (
"cmp"
"database/sql"
"encoding/json"
"errors"
"fmt"
"net/mail"
"net/url"
@ -18,6 +19,9 @@ import (
"tailscale.com/tailcfg"
)
// ErrCannotParseBool is returned when a value cannot be parsed as a boolean.
var ErrCannotParseBool = errors.New("could not parse value as boolean")
type UserID uint64
type Users []User
@ -40,9 +44,11 @@ var TaggedDevices = User{
func (u Users) String() string {
var sb strings.Builder
sb.WriteString("[ ")
for _, user := range u {
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
}
sb.WriteString(" ]")
return sb.String()
@ -89,11 +95,12 @@ func (u *User) StringID() string {
if u == nil {
return ""
}
return strconv.FormatUint(uint64(u.ID), 10)
}
// TypedID returns a pointer to the user's ID as a UserID type.
// This is a convenience method to avoid ugly casting like ptr.To(types.UserID(user.ID)).
// This is a convenience method to avoid ugly casting like new(types.UserID(user.ID)).
func (u *User) TypedID() *UserID {
uid := UserID(u.ID)
return &uid
@ -148,7 +155,7 @@ func (u UserView) ID() uint {
func (u *User) TailscaleLogin() tailcfg.Login {
return tailcfg.Login{
ID: tailcfg.LoginID(u.ID),
ID: tailcfg.LoginID(u.ID), //nolint:gosec
Provider: u.Provider,
LoginName: u.Username(),
DisplayName: u.Display(),
@ -194,8 +201,8 @@ func (u *User) Proto() *v1.User {
}
}
// JumpCloud returns a JSON where email_verified is returned as a
// string "true" or "false" instead of a boolean.
// FlexibleBoolean handles JSON where email_verified is returned as a
// string "true" or "false" instead of a boolean (e.g., JumpCloud).
// This maps bool to a specific type with a custom unmarshaler to
// ensure we can decode it from a string.
// https://github.com/juanfont/headscale/issues/2293
@ -203,6 +210,7 @@ type FlexibleBoolean bool
func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
var val any
err := json.Unmarshal(data, &val)
if err != nil {
return fmt.Errorf("could not unmarshal data: %w", err)
@ -216,10 +224,11 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
if err != nil {
return fmt.Errorf("could not parse %s as boolean: %w", v, err)
}
*bit = FlexibleBoolean(pv)
default:
return fmt.Errorf("could not parse %v as boolean", v)
return fmt.Errorf("%w: %v", ErrCannotParseBool, v)
}
return nil
@ -253,9 +262,11 @@ func (c *OIDCClaims) Identifier() string {
if c.Iss == "" && c.Sub == "" {
return ""
}
if c.Iss == "" {
return CleanIdentifier(c.Sub)
}
if c.Sub == "" {
return CleanIdentifier(c.Iss)
}
@ -266,8 +277,10 @@ func (c *OIDCClaims) Identifier() string {
var result string
// Try to parse as URL to handle URL joining correctly
//nolint:noinlineerr
if u, err := url.Parse(issuer); err == nil && u.Scheme != "" {
// For URLs, use proper URL path joining
//nolint:noinlineerr
if joined, err := url.JoinPath(issuer, subject); err == nil {
result = joined
}
@ -340,6 +353,7 @@ func CleanIdentifier(identifier string) string {
cleanParts = append(cleanParts, trimmed)
}
}
if len(cleanParts) == 0 {
return ""
}
@ -382,6 +396,7 @@ func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) {
if claims.Iss == "" && !strings.HasPrefix(identifier, "/") {
identifier = "/" + identifier
}
u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true}
u.DisplayName = claims.Name
u.ProfilePicURL = claims.ProfilePictureURL

View file

@ -66,10 +66,13 @@ func TestUnmarshallOIDCClaims(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got OIDCClaims
if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil {
err := json.Unmarshal([]byte(tt.jsonstr), &got)
if err != nil {
t.Errorf("UnmarshallOIDCClaims() error = %v", err)
return
}
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff)
}
@ -190,6 +193,7 @@ func TestOIDCClaimsIdentifier(t *testing.T) {
}
result := claims.Identifier()
assert.Equal(t, tt.expected, result)
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("Identifier() mismatch (-want +got):\n%s", diff)
}
@ -282,6 +286,7 @@ func TestCleanIdentifier(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
result := CleanIdentifier(tt.identifier)
assert.Equal(t, tt.expected, result)
if diff := cmp.Diff(tt.expected, result); diff != "" {
t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff)
}
@ -479,7 +484,9 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var got OIDCClaims
if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil {
err := json.Unmarshal([]byte(tt.jsonstr), &got)
if err != nil {
t.Errorf("TestOIDCClaimsJSONToUser() error = %v", err)
return
}
@ -487,6 +494,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
var user User
user.FromClaim(&got, tt.emailVerifiedRequired)
if diff := cmp.Diff(user, tt.want); diff != "" {
t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff)
}

View file

@ -38,9 +38,7 @@ func (v *VersionInfo) String() string {
return sb.String()
}
var buildInfo = sync.OnceValues(func() (*debug.BuildInfo, bool) {
return debug.ReadBuildInfo()
})
var buildInfo = sync.OnceValues(debug.ReadBuildInfo)
var GetVersionInfo = sync.OnceValue(func() *VersionInfo {
info := &VersionInfo{

View file

@ -20,12 +20,30 @@ const (
// value related to RFC 1123 and 952.
LabelHostnameLength = 63
// minNameLength is the minimum length for usernames and hostnames.
minNameLength = 2
)
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
var ErrInvalidHostName = errors.New("invalid hostname")
// Sentinel errors for username validation.
var (
ErrUsernameTooShort = errors.New("username must be at least 2 characters long")
ErrUsernameMustStartLetter = errors.New("username must start with a letter")
ErrUsernameTooManyAt = errors.New("username cannot contain more than one '@'")
ErrUsernameInvalidChar = errors.New("username contains invalid character")
)
// Sentinel errors for hostname validation.
var (
ErrHostnameTooShort = errors.New("hostname too short, must be at least 2 characters")
ErrHostnameHyphenEnds = errors.New("hostname cannot start or end with a hyphen")
ErrHostnameDotEnds = errors.New("hostname cannot start or end with a dot")
)
// ValidateUsername checks if a username is valid.
// It must be at least 2 characters long, start with a letter, and contain
// only letters, numbers, hyphens, dots, and underscores.
@ -33,13 +51,13 @@ var ErrInvalidHostName = errors.New("invalid hostname")
// It cannot contain invalid characters.
func ValidateUsername(username string) error {
// Ensure the username meets the minimum length requirement
if len(username) < 2 {
return errors.New("username must be at least 2 characters long")
if len(username) < minNameLength {
return ErrUsernameTooShort
}
// Ensure the username starts with a letter
if !unicode.IsLetter(rune(username[0])) {
return errors.New("username must start with a letter")
return ErrUsernameMustStartLetter
}
atCount := 0
@ -55,10 +73,10 @@ func ValidateUsername(username string) error {
case char == '@':
atCount++
if atCount > 1 {
return errors.New("username cannot contain more than one '@'")
return ErrUsernameTooManyAt
}
default:
return fmt.Errorf("username contains invalid character: '%c'", char)
return fmt.Errorf("%w: '%c'", ErrUsernameInvalidChar, char)
}
}
@ -69,11 +87,8 @@ func ValidateUsername(username string) error {
// This function does NOT modify the input - it only validates.
// The hostname must already be lowercase and contain only valid characters.
func ValidateHostname(name string) error {
if len(name) < 2 {
return fmt.Errorf(
"hostname %q is too short, must be at least 2 characters",
name,
)
if len(name) < minNameLength {
return fmt.Errorf("%w: %q", ErrHostnameTooShort, name)
}
if len(name) > LabelHostnameLength {
return fmt.Errorf(
@ -90,17 +105,11 @@ func ValidateHostname(name string) error {
}
if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") {
return fmt.Errorf(
"hostname %q cannot start or end with a hyphen",
name,
)
return fmt.Errorf("%w: %q", ErrHostnameHyphenEnds, name)
}
if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") {
return fmt.Errorf(
"hostname %q cannot start or end with a dot",
name,
)
return fmt.Errorf("%w: %q", ErrHostnameDotEnds, name)
}
if invalidDNSRegex.MatchString(name) {

View file

@ -90,6 +90,7 @@ func TestNormaliseHostname(t *testing.T) {
t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr && got != tt.want {
t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want)
}
@ -172,6 +173,7 @@ func TestValidateHostname(t *testing.T) {
t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr)
return
}
if tt.wantErr && tt.errorContains != "" {
if err == nil || !strings.Contains(err.Error(), tt.errorContains) {
t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains)

View file

@ -14,11 +14,14 @@ func YesNo(msg string) bool {
fmt.Fprint(os.Stderr, msg+" [y/n] ")
var resp string
fmt.Scanln(&resp)
_, _ = fmt.Scanln(&resp)
resp = strings.ToLower(resp)
switch resp {
case "y", "yes", "sure":
return true
}
return false
}

View file

@ -86,7 +86,8 @@ func TestYesNo(t *testing.T) {
// Write test input
go func() {
defer w.Close()
w.WriteString(tt.input)
_, _ = w.WriteString(tt.input)
}()
// Call the function
@ -95,6 +96,7 @@ func TestYesNo(t *testing.T) {
// Restore stdin and stderr
os.Stdin = oldStdin
os.Stderr = oldStderr
stderrW.Close()
// Check the result
@ -104,10 +106,12 @@ func TestYesNo(t *testing.T) {
// Check that the prompt was written to stderr
var stderrBuf bytes.Buffer
io.Copy(&stderrBuf, stderrR)
_, _ = io.Copy(&stderrBuf, stderrR)
stderrR.Close()
expectedPrompt := "Test question [y/n] "
actualPrompt := stderrBuf.String()
if actualPrompt != expectedPrompt {
t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt)
@ -130,7 +134,8 @@ func TestYesNoPromptMessage(t *testing.T) {
// Write test input
go func() {
defer w.Close()
w.WriteString("n\n")
_, _ = w.WriteString("n\n")
}()
// Call the function with a custom message
@ -140,14 +145,17 @@ func TestYesNoPromptMessage(t *testing.T) {
// Restore stdin and stderr
os.Stdin = oldStdin
os.Stderr = oldStderr
stderrW.Close()
// Check that the custom message was included in the prompt
var stderrBuf bytes.Buffer
io.Copy(&stderrBuf, stderrR)
_, _ = io.Copy(&stderrBuf, stderrR)
stderrR.Close()
expectedPrompt := customMessage + " [y/n] "
actualPrompt := stderrBuf.String()
if actualPrompt != expectedPrompt {
t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt)
@ -186,7 +194,8 @@ func TestYesNoCaseInsensitive(t *testing.T) {
// Write test input
go func() {
defer w.Close()
w.WriteString(tc.input)
_, _ = w.WriteString(tc.input)
}()
// Call the function
@ -195,10 +204,11 @@ func TestYesNoCaseInsensitive(t *testing.T) {
// Restore stdin and stderr
os.Stdin = oldStdin
os.Stderr = oldStderr
stderrW.Close()
// Drain stderr
io.Copy(io.Discard, stderrR)
_, _ = io.Copy(io.Discard, stderrR)
stderrR.Close()
if result != tc.expected {

View file

@ -9,6 +9,9 @@ import (
"tailscale.com/tailcfg"
)
// invalidStringRandomLength is the length of random bytes for invalid string generation.
const invalidStringRandomLength = 8
// GenerateRandomBytes returns securely generated random bytes.
// It will return an error if the system's secure random
// number generator fails to function correctly, in which
@ -33,6 +36,7 @@ func GenerateRandomStringURLSafe(n int) (string, error) {
b, err := GenerateRandomBytes(n)
uenc := base64.RawURLEncoding.EncodeToString(b)
return uenc[:n], err
}
@ -67,7 +71,7 @@ func MustGenerateRandomStringDNSSafe(size int) string {
}
func InvalidString() string {
hash, _ := GenerateRandomStringDNSSafe(8)
hash, _ := GenerateRandomStringDNSSafe(invalidStringRandomLength)
return "invalid-" + hash
}
@ -99,6 +103,7 @@ func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string {
DstIPs: %v
}
`, rule.SrcIPs, rule.DstPorts))
if index < len(rules)-1 {
sb.WriteString(", ")
}

Some files were not shown because too many files have changed in this diff Show more