Mirror of https://github.com/juanfont/headscale.git, synced 2026-01-23 02:24:10 +00:00
Merge 35364bfc9a into 4e1834adaf
Commit 487716f645
126 changed files with 3264 additions and 1938 deletions
@@ -25,6 +25,7 @@ linters:
- revive
- tagliatelle
- testpackage
- thelper
- varnamelen
- wrapcheck
- wsl
@@ -2,7 +2,6 @@ package cli
import (
"encoding/json"
"errors"
"fmt"
"net"
"net/http"
@@ -19,6 +18,7 @@ const (
errMockOidcClientIDNotDefined = Error("MOCKOIDC_CLIENT_ID not defined")
errMockOidcClientSecretNotDefined = Error("MOCKOIDC_CLIENT_SECRET not defined")
errMockOidcPortNotDefined = Error("MOCKOIDC_PORT not defined")
errMockOidcUsersNotDefined = Error("MOCKOIDC_USERS not defined")
refreshTTL = 60 * time.Minute
)
@@ -69,10 +69,11 @@ func mockOIDC() error {
userStr := os.Getenv("MOCKOIDC_USERS")
if userStr == "" {
return errors.New("MOCKOIDC_USERS not defined")
return errMockOidcUsersNotDefined
}
var users []mockoidc.MockUser
err := json.Unmarshal([]byte(userStr), &users)
if err != nil {
return fmt.Errorf("unmarshalling users: %w", err)
@@ -133,10 +134,11 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser
ErrorQueue: &mockoidc.ErrorQueue{},
}
mock.AddMiddleware(func(h http.Handler) http.Handler {
_ = mock.AddMiddleware(func(h http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
log.Info().Msgf("Request: %+v", r)
h.ServeHTTP(w, r)
if r.Response != nil {
log.Info().Msgf("Response: %+v", r.Response)
}
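A note on the pattern above, not part of the diff: the hunk swaps an inline errors.New for a package-level sentinel, so callers can branch on the condition with errors.Is. A minimal, self-contained sketch of that pattern; the Error string type mirrors the constants in the hunk, but the loader function and names here are illustrative assumptions only.

package main

import (
	"errors"
	"fmt"
	"os"
)

// Error is a string-based error type, so sentinels can be declared as constants.
type Error string

func (e Error) Error() string { return string(e) }

const errUsersNotDefined = Error("MOCKOIDC_USERS not defined")

func loadUsers() (string, error) {
	s := os.Getenv("MOCKOIDC_USERS")
	if s == "" {
		return "", errUsersNotDefined
	}
	return s, nil
}

func main() {
	if _, err := loadUsers(); errors.Is(err, errUsersNotDefined) {
		fmt.Println("users are not configured")
	}
}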
@@ -72,12 +72,12 @@ func init() {
nodeCmd.AddCommand(deleteNodeCmd)
tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
tagCmd.MarkFlagRequired("identifier")
_ = tagCmd.MarkFlagRequired("identifier")
tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node")
nodeCmd.AddCommand(tagCmd)
approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")
approveRoutesCmd.MarkFlagRequired("identifier")
_ = approveRoutesCmd.MarkFlagRequired("identifier")
approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. "10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`)
nodeCmd.AddCommand(approveRoutesCmd)
@@ -233,10 +233,7 @@ var listNodeRoutesCmd = &cobra.Command{
return
}
tableData, err := nodeRoutesToPtables(nodes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
}
tableData := nodeRoutesToPtables(nodes)
err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render()
if err != nil {
@@ -601,9 +598,7 @@ func nodesToPtables(
return tableData, nil
}
func nodeRoutesToPtables(
nodes []*v1.Node,
) (pterm.TableData, error) {
func nodeRoutesToPtables(nodes []*v1.Node) pterm.TableData {
tableHeader := []string{
"ID",
"Hostname",
@@ -627,7 +622,7 @@ func nodeRoutesToPtables(
)
}
return tableData, nil
return tableData
}
var tagCmd = &cobra.Command{
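For context on the `_ =` additions above: cobra's MarkFlagRequired returns an error (for example when the named flag does not exist), so under the stricter linter set enabled earlier in this commit the call must either be checked or explicitly discarded. A hedged sketch of both options; the command and flag names are made up for illustration.

package main

import (
	"log"

	"github.com/spf13/cobra"
)

func main() {
	cmd := &cobra.Command{Use: "example"}
	cmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)")

	// Option 1: explicitly discard the error, as the diff does for flags registered just above.
	_ = cmd.MarkFlagRequired("identifier")

	// Option 2: fail loudly, useful if the flag name could drift out of sync with its registration.
	if err := cmd.MarkFlagRequired("identifier"); err != nil {
		log.Fatal(err)
	}
}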
@@ -16,6 +16,7 @@ import (
)
const (
//nolint:gosec
bypassFlag = "bypass-grpc-and-access-database-directly"
)
@@ -29,13 +30,17 @@ func init() {
if err := setPolicy.MarkFlagRequired("file"); err != nil {
log.Fatal().Err(err).Msg("")
}
setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
policyCmd.AddCommand(setPolicy)
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
if err := checkPolicy.MarkFlagRequired("file"); err != nil {
err := checkPolicy.MarkFlagRequired("file")
if err != nil {
log.Fatal().Err(err).Msg("")
}
policyCmd.AddCommand(checkPolicy)
}
@@ -173,6 +178,7 @@ var setPolicy = &cobra.Command{
defer cancel()
defer conn.Close()
//nolint:noinlineerr
if _, err := client.SetPolicy(ctx, request); err != nil {
ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output)
}
@@ -80,6 +80,7 @@ func initConfig() {
Repository: "headscale",
TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
}
res, err := latest.Check(githubTag, versionInfo.Version)
if err == nil && res.Outdated {
//nolint
@@ -101,6 +102,7 @@ func isPreReleaseVersion(version string) bool {
return true
}
}
return false
}
@@ -23,8 +23,7 @@ var serveCmd = &cobra.Command{
Run: func(cmd *cobra.Command, args []string) {
app, err := newHeadscaleServerWithConfig()
if err != nil {
var squibbleErr squibble.ValidationError
if errors.As(err, &squibbleErr) {
if squibbleErr, ok := errors.AsType[squibble.ValidationError](err); ok {
fmt.Printf("SQLite schema failed to validate:\n")
fmt.Println(squibbleErr.Diff)
}
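Several hunks in this commit make the same change as the one above: the two-step errors.As pattern becomes the generic errors.AsType, which returns the typed value and a bool. A minimal before/after sketch, assuming a toolchain where errors.AsType[T] is available (the go.mod change later in this diff moves the module to go 1.26rc2); the ValidationError type here is hypothetical, not squibble's.

package main

import (
	"errors"
	"fmt"
)

type ValidationError struct{ Diff string }

func (e ValidationError) Error() string { return "schema validation failed" }

func main() {
	err := fmt.Errorf("startup: %w", ValidationError{Diff: "missing column"})

	// Before: declare a variable, then let errors.As fill it.
	var vErr ValidationError
	if errors.As(err, &vErr) {
		fmt.Println(vErr.Diff)
	}

	// After: errors.AsType returns the value and a bool in one step.
	if vErr, ok := errors.AsType[ValidationError](err); ok {
		fmt.Println(vErr.Diff)
	}
}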
@@ -14,6 +14,12 @@ import (
"google.golang.org/grpc/status"
)
// Sentinel errors for CLI commands.
var (
ErrNameOrIDRequired = errors.New("--name or --identifier flag is required")
ErrMultipleUsersFoundUseID = errors.New("unable to determine user, query returned multiple users, use ID")
)
func usernameAndIDFlag(cmd *cobra.Command) {
cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)")
cmd.Flags().StringP("name", "n", "", "Username")
@@ -23,12 +29,12 @@ func usernameAndIDFlag(cmd *cobra.Command) {
// If both are empty, it will exit the program with an error.
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
username, _ := cmd.Flags().GetString("name")
identifier, _ := cmd.Flags().GetInt64("identifier")
if username == "" && identifier < 0 {
err := errors.New("--name or --identifier flag is required")
ErrorOutput(
err,
"Cannot rename user: "+status.Convert(err).Message(),
ErrNameOrIDRequired,
"Cannot rename user: "+status.Convert(ErrNameOrIDRequired).Message(),
"",
)
}
@@ -50,7 +56,7 @@ func init() {
userCmd.AddCommand(renameUserCmd)
usernameAndIDFlag(renameUserCmd)
renameUserCmd.Flags().StringP("new-name", "r", "", "New username")
renameNodeCmd.MarkFlagRequired("new-name")
_ = renameUserCmd.MarkFlagRequired("new-name")
}
var errMissingParameter = errors.New("missing parameters")
@@ -94,6 +100,7 @@ var createUserCmd = &cobra.Command{
}
if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" {
//nolint:noinlineerr
if _, err := url.Parse(pictureURL); err != nil {
ErrorOutput(
err,
@@ -148,7 +155,7 @@ var destroyUserCmd = &cobra.Command{
}
if len(users.GetUsers()) != 1 {
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
err := ErrMultipleUsersFoundUseID
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
@@ -276,7 +283,7 @@ var renameUserCmd = &cobra.Command{
}
if len(users.GetUsers()) != 1 {
err := errors.New("Unable to determine user to delete, query returned multiple users, use ID")
err := ErrMultipleUsersFoundUseID
ErrorOutput(
err,
"Error: "+status.Convert(err).Message(),
@@ -10,11 +10,11 @@ import (
"time"
"github.com/cenkalti/backoff/v5"
cerrdefs "github.com/containerd/errdefs"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/client"
"github.com/docker/docker/errdefs"
)
// cleanupBeforeTest performs cleanup operations before running tests.
@@ -25,6 +25,7 @@ func cleanupBeforeTest(ctx context.Context) error {
return fmt.Errorf("failed to clean stale test containers: %w", err)
}
//nolint:noinlineerr
if err := pruneDockerNetworks(ctx); err != nil {
return fmt.Errorf("failed to prune networks: %w", err)
}
@@ -55,6 +56,7 @@ func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runI
// killTestContainers terminates and removes all test containers.
func killTestContainers(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@@ -69,8 +71,10 @@ func killTestContainers(ctx context.Context) error {
}
removed := 0
for _, cont := range containers {
shouldRemove := false
for _, name := range cont.Names {
if strings.Contains(name, "headscale-test-suite") ||
strings.Contains(name, "hs-") ||
@@ -107,6 +111,7 @@ func killTestContainers(ctx context.Context) error {
// This function filters containers by the hi.run-id label to only affect containers
// belonging to the specified test run, leaving other concurrent test runs untouched.
func killTestContainersByRunID(ctx context.Context, runID string) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@@ -149,6 +154,7 @@ func killTestContainersByRunID(ctx context.Context, runID string) error {
// This is useful for cleaning up leftover containers from previous crashed or interrupted test runs
// without interfering with currently running concurrent tests.
func cleanupStaleTestContainers(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@@ -223,6 +229,7 @@ func removeContainerWithRetry(ctx context.Context, cli *client.Client, container
// pruneDockerNetworks removes unused Docker networks.
func pruneDockerNetworks(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@@ -245,6 +252,7 @@ func pruneDockerNetworks(ctx context.Context) error {
// cleanOldImages removes test-related and old dangling Docker images.
func cleanOldImages(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@@ -259,8 +267,10 @@ func cleanOldImages(ctx context.Context) error {
}
removed := 0
for _, img := range images {
shouldRemove := false
for _, tag := range img.RepoTags {
if strings.Contains(tag, "hs-") ||
strings.Contains(tag, "headscale-integration") ||
@@ -295,6 +305,7 @@ func cleanOldImages(ctx context.Context) error {
// cleanCacheVolume removes the Docker volume used for Go module cache.
func cleanCacheVolume(ctx context.Context) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@@ -302,11 +313,12 @@ func cleanCacheVolume(ctx context.Context) error {
defer cli.Close()
volumeName := "hs-integration-go-cache"
err = cli.VolumeRemove(ctx, volumeName, true)
if err != nil {
if errdefs.IsNotFound(err) {
if cerrdefs.IsNotFound(err) {
fmt.Printf("Go module cache volume not found: %s\n", volumeName)
} else if errdefs.IsConflict(err) {
} else if cerrdefs.IsConflict(err) {
fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName)
} else {
fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err)
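The import swap above moves from the deprecated github.com/docker/docker/errdefs helpers to github.com/containerd/errdefs (imported as cerrdefs), which exposes the same IsNotFound/IsConflict style of error classification. A hedged usage sketch under those assumptions; the volume name is taken from the hunk, everything else is illustrative.

package main

import (
	"context"
	"fmt"

	cerrdefs "github.com/containerd/errdefs"
	"github.com/docker/docker/client"
)

// removeVolume classifies VolumeRemove failures with the containerd errdefs helpers.
func removeVolume(ctx context.Context, cli *client.Client, name string) {
	err := cli.VolumeRemove(ctx, name, true)
	switch {
	case err == nil:
		fmt.Println("removed", name)
	case cerrdefs.IsNotFound(err):
		fmt.Println("volume not found:", name)
	case cerrdefs.IsConflict(err):
		fmt.Println("volume in use:", name)
	default:
		fmt.Println("remove failed:", err)
	}
}

func main() {
	cli, err := client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation())
	if err != nil {
		fmt.Println("no docker client:", err)
		return
	}
	defer cli.Close()
	removeVolume(context.Background(), cli, "hs-integration-go-cache")
}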
@@ -14,6 +14,7 @@ import (
"strings"
"time"
cerrdefs "github.com/containerd/errdefs"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/image"
"github.com/docker/docker/api/types/mount"
@@ -26,10 +27,21 @@ var (
ErrTestFailed = errors.New("test failed")
ErrUnexpectedContainerWait = errors.New("unexpected end of container wait")
ErrNoDockerContext = errors.New("no docker context found")
ErrMemoryLimitExceeded = errors.New("container exceeded memory limits")
)
// Docker container constants.
const (
containerFinalStateWait = 10 * time.Second
containerStateCheckInterval = 500 * time.Millisecond
dirPermissions = 0o755
)
// runTestContainer executes integration tests in a Docker container.
//
//nolint:gocyclo
func runTestContainer(ctx context.Context, config *RunConfig) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@@ -52,6 +64,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
}
const dirPerm = 0o755
//nolint:noinlineerr
if err := os.MkdirAll(absLogsDir, dirPerm); err != nil {
return fmt.Errorf("failed to create logs directory: %w", err)
}
@@ -60,7 +73,9 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
if config.Verbose {
log.Printf("Running pre-test cleanup...")
}
if err := cleanupBeforeTest(ctx); err != nil && config.Verbose {
err := cleanupBeforeTest(ctx)
if err != nil && config.Verbose {
log.Printf("Warning: pre-test cleanup failed: %v", err)
}
}
@@ -71,6 +86,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
}
imageName := "golang:" + config.GoVersion
//nolint:noinlineerr
if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil {
return fmt.Errorf("failed to ensure image availability: %w", err)
}
@@ -84,6 +100,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
log.Printf("Created container: %s", resp.ID)
}
//nolint:noinlineerr
if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil {
return fmt.Errorf("failed to start container: %w", err)
}
@@ -95,13 +112,17 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
// Start stats collection for container resource monitoring (if enabled)
var statsCollector *StatsCollector
if config.Stats {
var err error
//nolint:contextcheck
statsCollector, err = NewStatsCollector()
if err != nil {
if config.Verbose {
log.Printf("Warning: failed to create stats collector: %v", err)
}
statsCollector = nil
}
@@ -110,7 +131,8 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
// Start stats collection immediately - no need for complex retry logic
// The new implementation monitors Docker events and will catch containers as they start
if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil {
err := statsCollector.StartCollection(ctx, runID, config.Verbose)
if err != nil {
if config.Verbose {
log.Printf("Warning: failed to start stats collection: %v", err)
}
@@ -122,11 +144,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
exitCode, err := streamAndWait(ctx, cli, resp.ID)
// Ensure all containers have finished and logs are flushed before extracting artifacts
if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose {
waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose)
if waitErr != nil && config.Verbose {
log.Printf("Warning: failed to wait for container finalization: %v", waitErr)
}
// Extract artifacts from test containers before cleanup
//nolint:noinlineerr
if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose {
log.Printf("Warning: failed to extract artifacts from containers: %v", err)
}
@@ -140,12 +164,13 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
if len(violations) > 0 {
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
log.Printf("=================================")
for _, violation := range violations {
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
}
return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations))
return fmt.Errorf("%w: %d container(s)", ErrMemoryLimitExceeded, len(violations))
}
}
@@ -344,9 +369,10 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
testContainers := getCurrentTestContainers(containers, testContainerID, verbose)
// Wait for all test containers to reach a final state
maxWaitTime := 10 * time.Second
checkInterval := 500 * time.Millisecond
maxWaitTime := containerFinalStateWait
checkInterval := containerStateCheckInterval
timeout := time.After(maxWaitTime)
ticker := time.NewTicker(checkInterval)
defer ticker.Stop()
@@ -356,6 +382,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
if verbose {
log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction")
}
return nil
case <-ticker.C:
allFinalized := true
@@ -366,12 +393,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
if verbose {
log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err)
}
continue
}
// Check if container is in a final state
if !isContainerFinalized(inspect.State) {
allFinalized = false
if verbose {
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
}
@@ -384,6 +413,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
if verbose {
log.Printf("All test containers finalized, ready for artifact extraction")
}
return nil
}
}
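The memory-violation hunk above replaces a one-off fmt.Errorf message with a %w-wrapped sentinel, so callers can branch on the failure class while the count stays in the message. A small self-contained sketch of that idea; the sentinel name is reused from the hunk, the caller is hypothetical.

package main

import (
	"errors"
	"fmt"
)

var ErrMemoryLimitExceeded = errors.New("container exceeded memory limits")

func checkViolations(n int) error {
	if n > 0 {
		// %w keeps the sentinel in the error chain for errors.Is.
		return fmt.Errorf("%w: %d container(s)", ErrMemoryLimitExceeded, n)
	}
	return nil
}

func main() {
	err := checkViolations(2)
	if errors.Is(err, ErrMemoryLimitExceeded) {
		fmt.Println("fail the run:", err)
	}
}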
@@ -400,13 +430,16 @@ func isContainerFinalized(state *container.State) bool {
func findProjectRoot(startPath string) string {
current := startPath
for {
//nolint:noinlineerr
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil {
return current
}
parent := filepath.Dir(current)
if parent == current {
return startPath
}
current = parent
}
}
@@ -416,6 +449,7 @@ func boolToInt(b bool) int {
if b {
return 1
}
return 0
}
@@ -435,6 +469,7 @@ func createDockerClient() (*client.Client, error) {
}
var clientOpts []client.Opt
clientOpts = append(clientOpts, client.WithAPIVersionNegotiation())
if contextInfo != nil {
@@ -444,6 +479,7 @@ func createDockerClient() (*client.Client, error) {
if runConfig.Verbose {
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
}
clientOpts = append(clientOpts, client.WithHost(host))
}
}
@@ -459,13 +495,15 @@ func createDockerClient() (*client.Client, error) {
// getCurrentDockerContext retrieves the current Docker context information.
func getCurrentDockerContext() (*DockerContext, error) {
cmd := exec.Command("docker", "context", "inspect")
cmd := exec.CommandContext(context.Background(), "docker", "context", "inspect")
output, err := cmd.Output()
if err != nil {
return nil, fmt.Errorf("failed to get docker context: %w", err)
}
var contexts []DockerContext
//nolint:noinlineerr
if err := json.Unmarshal(output, &contexts); err != nil {
return nil, fmt.Errorf("failed to parse docker context: %w", err)
}
@@ -486,11 +524,12 @@ func getDockerSocketPath() string {
// checkImageAvailableLocally checks if the specified Docker image is available locally.
func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) {
_, _, err := cli.ImageInspectWithRaw(ctx, imageName)
_, err := cli.ImageInspect(ctx, imageName)
if err != nil {
if client.IsErrNotFound(err) {
if cerrdefs.IsNotFound(err) {
return false, nil
}
return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
}
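The getCurrentDockerContext hunk above swaps exec.Command for exec.CommandContext, which ties the child process to a context so it can be cancelled or time-limited. A minimal sketch with a short timeout; using the same docker invocation as the hunk, everything else is illustrative.

package main

import (
	"context"
	"fmt"
	"os/exec"
	"time"
)

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// The process is killed if ctx is cancelled or the timeout elapses.
	out, err := exec.CommandContext(ctx, "docker", "context", "inspect").Output()
	if err != nil {
		fmt.Println("docker context inspect failed:", err)
		return
	}
	fmt.Printf("%d bytes of context metadata\n", len(out))
}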
@@ -509,6 +548,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
if verbose {
log.Printf("Image %s is available locally", imageName)
}
return nil
}
@@ -533,6 +573,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
if err != nil {
return fmt.Errorf("failed to read pull output: %w", err)
}
log.Printf("Image %s pulled successfully", imageName)
}
@@ -547,9 +588,11 @@ func listControlFiles(logsDir string) {
return
}
var logFiles []string
var dataFiles []string
var dataDirs []string
var (
logFiles []string
dataFiles []string
dataDirs []string
)
for _, entry := range entries {
name := entry.Name()
@@ -578,6 +621,7 @@ func listControlFiles(logsDir string) {
if len(logFiles) > 0 {
log.Printf("Headscale logs:")
for _, file := range logFiles {
log.Printf(" %s", file)
}
@@ -585,9 +629,11 @@ func listControlFiles(logsDir string) {
if len(dataFiles) > 0 || len(dataDirs) > 0 {
log.Printf("Headscale data:")
for _, file := range dataFiles {
log.Printf(" %s", file)
}
for _, dir := range dataDirs {
log.Printf(" %s/", dir)
}
@@ -596,6 +642,7 @@ func listControlFiles(logsDir string) {
// extractArtifactsFromContainers collects container logs and files from the specific test run.
func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return fmt.Errorf("failed to create Docker client: %w", err)
@@ -612,9 +659,11 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose)
extractedCount := 0
for _, cont := range currentTestContainers {
// Extract container logs and tar files
if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil {
err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose)
if err != nil {
if verbose {
log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err)
}
@@ -622,6 +671,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
if verbose {
log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12])
}
extractedCount++
}
}
@@ -645,11 +695,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
// Find the test container to get its run ID label
var runID string
for _, cont := range containers {
if cont.ID == testContainerID {
if cont.Labels != nil {
runID = cont.Labels["hi.run-id"]
}
break
}
}
@@ -690,18 +742,21 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
// extractContainerArtifacts saves logs and tar files from a container.
func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error {
// Ensure the logs directory exists
if err := os.MkdirAll(logsDir, 0o755); err != nil {
err := os.MkdirAll(logsDir, dirPermissions)
if err != nil {
return fmt.Errorf("failed to create logs directory: %w", err)
}
// Extract container logs
//nolint:noinlineerr
if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
return fmt.Errorf("failed to extract logs: %w", err)
}
// Extract tar files for headscale containers only
if strings.HasPrefix(containerName, "hs-") {
if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil {
err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose)
if err != nil {
if verbose {
log.Printf("Warning: failed to extract files from %s: %v", containerName, err)
}
@@ -741,11 +796,13 @@ func extractContainerLogs(ctx context.Context, cli *client.Client, containerID,
}
// Write stdout logs
//nolint:gosec,mnd,noinlineerr
if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stdout log: %w", err)
}
// Write stderr logs
//nolint:gosec,mnd,noinlineerr
if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil {
return fmt.Errorf("failed to write stderr log: %w", err)
}
@@ -38,12 +38,15 @@ func runDoctorCheck(ctx context.Context) error {
}
// Check 3: Go installation
//nolint:contextcheck
results = append(results, checkGoInstallation())
// Check 4: Git repository
//nolint:contextcheck
results = append(results, checkGitRepository())
// Check 5: Required files
//nolint:contextcheck
results = append(results, checkRequiredFiles())
// Display results
@@ -86,6 +89,7 @@ func checkDockerBinary() DoctorResult {
// checkDockerDaemon verifies Docker daemon is running and accessible.
func checkDockerDaemon(ctx context.Context) DoctorResult {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return DoctorResult{
@@ -125,6 +129,7 @@ func checkDockerDaemon(ctx context.Context) DoctorResult {
// checkDockerContext verifies Docker context configuration.
func checkDockerContext(_ context.Context) DoctorResult {
//nolint:contextcheck
contextInfo, err := getCurrentDockerContext()
if err != nil {
return DoctorResult{
@@ -155,6 +160,7 @@ func checkDockerContext(_ context.Context) DoctorResult {
// checkDockerSocket verifies Docker socket accessibility.
func checkDockerSocket(ctx context.Context) DoctorResult {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return DoctorResult{
@@ -192,6 +198,7 @@ func checkDockerSocket(ctx context.Context) DoctorResult {
// checkGolangImage verifies the golang Docker image is available locally or can be pulled.
func checkGolangImage(ctx context.Context) DoctorResult {
//nolint:contextcheck
cli, err := createDockerClient()
if err != nil {
return DoctorResult{
@@ -265,7 +272,8 @@ func checkGoInstallation() DoctorResult {
}
}
cmd := exec.Command("go", "version")
cmd := exec.CommandContext(context.Background(), "go", "version")
output, err := cmd.Output()
if err != nil {
return DoctorResult{
@@ -286,7 +294,8 @@ func checkGoInstallation() DoctorResult {
// checkGitRepository verifies we're in a git repository.
func checkGitRepository() DoctorResult {
cmd := exec.Command("git", "rev-parse", "--git-dir")
cmd := exec.CommandContext(context.Background(), "git", "rev-parse", "--git-dir")
err := cmd.Run()
if err != nil {
return DoctorResult{
@@ -316,9 +325,12 @@ func checkRequiredFiles() DoctorResult {
}
var missingFiles []string
for _, file := range requiredFiles {
cmd := exec.Command("test", "-e", file)
if err := cmd.Run(); err != nil {
cmd := exec.CommandContext(context.Background(), "test", "-e", file)
err := cmd.Run()
if err != nil {
missingFiles = append(missingFiles, file)
}
}
@@ -350,6 +362,7 @@ func displayDoctorResults(results []DoctorResult) {
for _, result := range results {
var icon string
switch result.Status {
case "PASS":
icon = "✅"
@@ -79,13 +79,18 @@ func main() {
}
func cleanAll(ctx context.Context) error {
if err := killTestContainers(ctx); err != nil {
err := killTestContainers(ctx)
if err != nil {
return err
}
if err := pruneDockerNetworks(ctx); err != nil {
err = pruneDockerNetworks(ctx)
if err != nil {
return err
}
if err := cleanOldImages(ctx); err != nil {
err = cleanOldImages(ctx)
if err != nil {
return err
}
@@ -48,7 +48,9 @@ func runIntegrationTest(env *command.Env) error {
if runConfig.Verbose {
log.Printf("Running pre-flight system checks...")
}
if err := runDoctorCheck(env.Context()); err != nil {
err := runDoctorCheck(env.Context())
if err != nil {
return fmt.Errorf("pre-flight checks failed: %w", err)
}
@@ -66,8 +68,10 @@ func runIntegrationTest(env *command.Env) error {
func detectGoVersion() string {
goModPath := filepath.Join("..", "..", "go.mod")
//nolint:noinlineerr
if _, err := os.Stat("go.mod"); err == nil {
goModPath = "go.mod"
//nolint:noinlineerr
} else if _, err := os.Stat("../../go.mod"); err == nil {
goModPath = "../../go.mod"
}
@@ -94,8 +98,10 @@ func detectGoVersion() string {
// splitLines splits a string into lines without using strings.Split.
func splitLines(s string) []string {
var lines []string
var current string
var (
lines []string
current string
)
for _, char := range s {
if char == '\n' {
@@ -1,23 +1,32 @@
package main
import (
"cmp"
"context"
"encoding/json"
"errors"
"fmt"
"log"
"sort"
"slices"
"strings"
"sync"
"time"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/events"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/client"
)
// ErrStatsCollectionAlreadyStarted is returned when stats collection is already running.
var ErrStatsCollectionAlreadyStarted = errors.New("stats collection already started")
// Stats calculation constants.
const (
bytesPerKB = 1024
percentageMultiplier = 100.0
)
// ContainerStats represents statistics for a single container.
type ContainerStats struct {
ContainerID string
@@ -63,17 +72,19 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
defer sc.mutex.Unlock()
if sc.collectionStarted {
return errors.New("stats collection already started")
return ErrStatsCollectionAlreadyStarted
}
sc.collectionStarted = true
// Start monitoring existing containers
sc.wg.Add(1)
go sc.monitorExistingContainers(ctx, runID, verbose)
// Start Docker events monitoring for new containers
sc.wg.Add(1)
go sc.monitorDockerEvents(ctx, runID, verbose)
if verbose {
@@ -87,10 +98,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
func (sc *StatsCollector) StopCollection() {
// Check if already stopped without holding lock
sc.mutex.RLock()
if !sc.collectionStarted {
sc.mutex.RUnlock()
return
}
sc.mutex.RUnlock()
// Signal stop to all goroutines
@@ -114,6 +127,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s
if verbose {
log.Printf("Failed to list existing containers: %v", err)
}
return
}
@@ -147,13 +161,13 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
case event := <-events:
if event.Type == "container" && event.Action == "start" {
// Get container details
containerInfo, err := sc.client.ContainerInspect(ctx, event.ID)
containerInfo, err := sc.client.ContainerInspect(ctx, event.Actor.ID)
if err != nil {
continue
}
// Convert to types.Container format for consistency
cont := types.Container{
// Convert to container.Summary format for consistency
cont := container.Summary{
ID: containerInfo.ID,
Names: []string{containerInfo.Name},
Labels: containerInfo.Config.Labels,
@@ -167,13 +181,14 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
if verbose {
log.Printf("Error in Docker events stream: %v", err)
}
return
}
}
}
// shouldMonitorContainer determines if a container should be monitored.
func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool {
func (sc *StatsCollector) shouldMonitorContainer(cont container.Summary, runID string) bool {
// Check if it has the correct run ID label
if cont.Labels == nil || cont.Labels["hi.run-id"] != runID {
return false
@@ -213,6 +228,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI
}
sc.wg.Add(1)
go sc.collectStatsForContainer(ctx, containerID, verbose)
}
@@ -226,12 +242,14 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
if verbose {
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
}
return
}
defer statsResponse.Body.Close()
decoder := json.NewDecoder(statsResponse.Body)
var prevStats *container.Stats
var prevStats *container.StatsResponse
for {
select {
@@ -240,12 +258,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
case <-ctx.Done():
return
default:
var stats container.Stats
if err := decoder.Decode(&stats); err != nil {
var stats container.StatsResponse
err := decoder.Decode(&stats)
if err != nil {
// EOF is expected when container stops or stream ends
if err.Error() != "EOF" && verbose {
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
}
return
}
@@ -256,13 +277,15 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
}
// Calculate memory usage in MB
memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024)
memoryMB := float64(stats.MemoryStats.Usage) / (bytesPerKB * bytesPerKB)
// Store the sample (skip first sample since CPU calculation needs previous stats)
if prevStats != nil {
// Get container stats reference without holding the main mutex
var containerStats *ContainerStats
var exists bool
var (
containerStats *ContainerStats
exists bool
)
sc.mutex.RLock()
containerStats, exists = sc.containers[containerID]
@@ -286,7 +309,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
}
// calculateCPUPercent calculates CPU usage percentage from Docker stats.
func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
func calculateCPUPercent(prevStats, stats *container.StatsResponse) float64 {
// CPU calculation based on Docker's implementation
cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage)
systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage)
@@ -299,7 +322,7 @@ func calculateCPUPercent(prevStats, stats *container.Stats) float64 {
numCPUs = 1.0
}
return (cpuDelta / systemDelta) * numCPUs * 100.0
return (cpuDelta / systemDelta) * numCPUs * percentageMultiplier
}
return 0.0
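The calculateCPUPercent change above only swaps the literal 100.0 for the percentageMultiplier constant; the formula itself is the usual Docker delta calculation. A standalone sketch of the arithmetic with made-up sample values, kept free of the Docker types so it runs on its own.

package main

import "fmt"

// cpuPercent mirrors the delta calculation: (cpuDelta / systemDelta) * numCPUs * 100.
func cpuPercent(prevCPU, curCPU, prevSys, curSys, numCPUs float64) float64 {
	cpuDelta := curCPU - prevCPU
	systemDelta := curSys - prevSys
	if systemDelta <= 0 || cpuDelta < 0 {
		return 0.0
	}
	if numCPUs <= 0 {
		numCPUs = 1.0
	}
	return (cpuDelta / systemDelta) * numCPUs * 100.0
}

func main() {
	// Hypothetical samples: the container used 50ms of CPU while the host accumulated 1000ms across 4 CPUs.
	fmt.Printf("%.1f%%\n", cpuPercent(0, 50e6, 0, 1000e6, 4)) // prints 20.0%
}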
@@ -331,10 +354,12 @@ type StatsSummary struct {
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
// Take snapshot of container references without holding main lock long
sc.mutex.RLock()
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
for _, containerStats := range sc.containers {
containerRefs = append(containerRefs, containerStats)
}
sc.mutex.RUnlock()
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
@@ -371,8 +396,8 @@ func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
}
// Sort by container name for consistent output
sort.Slice(summaries, func(i, j int) bool {
return summaries[i].ContainerName < summaries[j].ContainerName
slices.SortFunc(summaries, func(a, b ContainerStatsSummary) int {
return cmp.Compare(a.ContainerName, b.ContainerName)
})
return summaries
@@ -384,23 +409,25 @@ func calculateStatsSummary(values []float64) StatsSummary {
return StatsSummary{}
}
min := values[0]
max := values[0]
minVal := values[0]
maxVal := values[0]
sum := 0.0
for _, value := range values {
if value < min {
min = value
if value < minVal {
minVal = value
}
if value > max {
max = value
if value > maxVal {
maxVal = value
}
sum += value
}
return StatsSummary{
Min: min,
Max: max,
Min: minVal,
Max: maxVal,
Average: sum / float64(len(values)),
}
}
@@ -434,6 +461,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo
}
summaries := sc.GetSummary()
var violations []MemoryViolation
for _, summary := range summaries {
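GetSummary above migrates from sort.Slice with a less-func to slices.SortFunc with cmp.Compare, the standard-library replacement available since Go 1.21. A minimal sketch with a throwaway struct standing in for ContainerStatsSummary.

package main

import (
	"cmp"
	"fmt"
	"slices"
)

type summary struct{ ContainerName string }

func main() {
	s := []summary{{"ts-b"}, {"hs-a"}, {"derp-c"}}

	// SortFunc takes a three-way comparison returning negative/zero/positive, which cmp.Compare provides.
	slices.SortFunc(s, func(a, b summary) int {
		return cmp.Compare(a.ContainerName, b.ContainerName)
	})

	fmt.Println(s) // [{derp-c} {hs-a} {ts-b}]
}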
@@ -2,6 +2,7 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"os"
@@ -11,6 +12,8 @@ import (
"github.com/juanfont/headscale/integration/integrationutil"
)
var errDirectoryRequired = errors.New("directory is required")
type MapConfig struct {
Directory string `flag:"directory,Directory to read map responses from"`
}
@@ -40,7 +43,7 @@ func main() {
// runIntegrationTest executes the integration test workflow.
func runOnline(env *command.Env) error {
if mapConfig.Directory == "" {
return fmt.Errorf("directory is required")
return errDirectoryRequired
}
resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory)
@@ -57,5 +60,6 @@ func runOnline(env *command.Env) error {
os.Stderr.Write(out)
os.Stderr.Write([]byte("\n"))
return nil
}
8 flake.lock generated
@@ -20,16 +20,16 @@
},
"nixpkgs": {
"locked": {
"lastModified": 1766840161,
"narHash": "sha256-Ss/LHpJJsng8vz1Pe33RSGIWUOcqM1fjrehjUkdrWio=",
"lastModified": 1769011238,
"narHash": "sha256-WPiOcgZv7GQ/AVd9giOrlZjzXHwBNM4yQ+JzLrgI3Xk=",
"owner": "NixOS",
"repo": "nixpkgs",
"rev": "3edc4a30ed3903fdf6f90c837f961fa6b49582d1",
"rev": "a895ec2c048eba3bceab06d5dfee5026a6b1c875",
"type": "github"
},
"original": {
"owner": "NixOS",
"ref": "nixpkgs-unstable",
"ref": "master",
"repo": "nixpkgs",
"type": "github"
}
19 flake.nix
@@ -2,7 +2,8 @@
description = "headscale - Open Source Tailscale Control server";
inputs = {
nixpkgs.url = "github:NixOS/nixpkgs/nixpkgs-unstable";
# TODO: Move back to nixpkgs-unstable once Go 1.26 is available there
nixpkgs.url = "github:NixOS/nixpkgs/master";
flake-utils.url = "github:numtide/flake-utils";
};
@@ -23,11 +24,11 @@
default = headscale;
};
overlay = _: prev:
overlays.default = _: prev:
let
pkgs = nixpkgs.legacyPackages.${prev.system};
buildGo = pkgs.buildGo125Module;
vendorHash = "sha256-escboufgbk+lEitw48eWEIltXbaCPdysb/g4YR+extg=";
buildGo = pkgs.buildGo126Module;
vendorHash = "sha256-hL9vHunaxodGt3g/CIVirXy4OjZKTI3XwbVPPRb34OY=";
in
{
headscale = buildGo {
@@ -129,10 +130,10 @@
(system:
let
pkgs = import nixpkgs {
overlays = [ self.overlay ];
overlays = [ self.overlays.default ];
inherit system;
};
buildDeps = with pkgs; [ git go_1_25 gnumake ];
buildDeps = with pkgs; [ git go_1_26 gnumake ];
devDeps = with pkgs;
buildDeps
++ [
@@ -167,7 +168,7 @@
clang-tools # clang-format
protobuf-language-server
]
++ lib.optional pkgs.stdenv.isLinux [ traceroute ];
++ lib.optional pkgs.stdenv.hostPlatform.isLinux [ traceroute ];
# Add entry to build a docker image with headscale
# caveat: only works on Linux
@@ -184,7 +185,7 @@
in
rec {
# `nix develop`
devShell = pkgs.mkShell {
devShells.default = pkgs.mkShell {
buildInputs =
devDeps
++ [
@@ -219,8 +220,8 @@
packages = with pkgs; {
inherit headscale;
inherit headscale-docker;
default = headscale;
};
defaultPackage = pkgs.headscale;
# `nix run`
apps.headscale = flake-utils.lib.mkApp {
2 go.mod
@@ -1,6 +1,6 @@
module github.com/juanfont/headscale
go 1.25
go 1.26rc2
require (
github.com/arl/statsviz v0.7.2
@@ -67,6 +67,9 @@ var (
)
)
// oidcProviderInitTimeout is the timeout for initializing the OIDC provider.
const oidcProviderInitTimeout = 30 * time.Second
var (
debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK")
debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT")
@@ -142,6 +145,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
if !ok {
log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
return
}
@@ -157,10 +161,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
app.ephemeralGC = ephemeralGC
var authProvider AuthProvider
authProvider = NewAuthProviderWeb(cfg.ServerURL)
if cfg.OIDC.Issuer != "" {
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), oidcProviderInitTimeout)
defer cancel()
oidcProvider, err := NewAuthProviderOIDC(
ctx,
&app,
@@ -177,6 +183,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
authProvider = oidcProvider
}
}
app.authProvider = authProvider
if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
@@ -251,9 +258,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
lastExpiryCheck := time.Unix(0, 0)
derpTickerChan := make(<-chan time.Time)
if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
defer derpTicker.Stop()
derpTickerChan = derpTicker.C
}
@@ -271,8 +280,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
return
case <-expireTicker.C:
var expiredNodeChanges []change.Change
var changed bool
var (
expiredNodeChanges []change.Change
changed bool
)
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
@@ -287,11 +298,14 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
case <-derpTickerChan:
log.Info().Msg("Fetching DERPMap updates")
//nolint:contextcheck
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) {
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
if err != nil {
return nil, err
}
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
region, _ := h.DERPServer.GenerateRegion()
derpMap.Regions[region.RegionID] = &region
@@ -303,6 +317,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
continue
}
h.state.SetDERPMap(derpMap)
h.Change(change.DERPMap())
@@ -311,6 +326,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
if !ok {
continue
}
h.cfg.TailcfgDNSConfig.ExtraRecords = records
h.Change(change.ExtraRecords())
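The NewHeadscale hunk above replaces the magic 30*time.Second with the named oidcProviderInitTimeout constant; the mechanics of context.WithTimeout are unchanged. A small sketch of the pattern, where the init function is hypothetical and only stands in for the real OIDC provider setup.

package main

import (
	"context"
	"fmt"
	"time"
)

// A named constant documents intent and keeps magic-number linters quiet.
const oidcProviderInitTimeout = 30 * time.Second

func initProvider(ctx context.Context) error {
	select {
	case <-time.After(10 * time.Millisecond): // stand-in for the real OIDC discovery call
		return nil
	case <-ctx.Done():
		return ctx.Err()
	}
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), oidcProviderInitTimeout)
	defer cancel()
	fmt.Println(initProvider(ctx))
}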
@@ -390,6 +406,8 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
writeUnauthorized := func(statusCode int) {
writer.WriteHeader(statusCode)
//nolint:noinlineerr
if _, err := writer.Write([]byte("Unauthorized")); err != nil {
log.Error().Err(err).Msg("writing HTTP response failed")
}
@@ -486,6 +504,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
// Serve launches the HTTP and gRPC server service Headscale and the API.
func (h *Headscale) Serve() error {
var err error
capver.CanOldCodeBeCleanedUp()
if profilingEnabled {
@@ -512,6 +531,7 @@ func (h *Headscale) Serve() error {
Msg("Clients with a lower minimum version will be rejected")
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
h.mapBatcher.Start()
defer h.mapBatcher.Close()
@@ -545,6 +565,7 @@ func (h *Headscale) Serve() error {
// around between restarts, they will reconnect and the GC will
// be cancelled.
go h.ephemeralGC.Start()
ephmNodes := h.state.ListEphemeralNodes()
for _, node := range ephmNodes.All() {
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
@@ -555,7 +576,9 @@ func (h *Headscale) Serve() error {
if err != nil {
return fmt.Errorf("setting up extrarecord manager: %w", err)
}
h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
go h.extraRecordMan.Run()
defer h.extraRecordMan.Close()
}
@@ -564,6 +587,7 @@ func (h *Headscale) Serve() error {
// records updates
scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
defer scheduleCancel()
go h.scheduledTasks(scheduleCtx)
if zl.GlobalLevel() == zl.TraceLevel {
@@ -751,7 +775,6 @@ func (h *Headscale) Serve() error {
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
}
var tailsqlContext context.Context
if tailsqlEnabled {
if h.cfg.Database.Type != types.DatabaseSqlite {
@@ -863,6 +886,8 @@ func (h *Headscale) Serve() error {
// Close state connections
info("closing state and database")
//nolint:contextcheck
err = h.state.Close()
if err != nil {
log.Error().Err(err).Msg("failed to close state")
@@ -16,12 +16,11 @@ import (
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
type AuthProvider interface {
RegisterHandler(http.ResponseWriter, *http.Request)
AuthURL(types.RegistrationID) string
RegisterHandler(w http.ResponseWriter, r *http.Request)
AuthURL(regID types.RegistrationID) string
}
func (h *Headscale) handleRegister(
@@ -52,6 +51,7 @@ func (h *Headscale) handleRegister(
if err != nil {
return nil, fmt.Errorf("handling logout: %w", err)
}
if resp != nil {
return resp, nil
}
@@ -113,8 +113,7 @@ func (h *Headscale) handleRegister(
resp, err := h.handleRegisterWithAuthKey(req, machineKey)
if err != nil {
// Preserve HTTPError types so they can be handled properly by the HTTP layer
var httpErr HTTPError
if errors.As(err, &httpErr) {
if httpErr, ok := errors.AsType[HTTPError](err); ok {
return nil, httpErr
}
@@ -133,7 +132,7 @@ func (h *Headscale) handleRegister(
}
// handleLogout checks if the [tailcfg.RegisterRequest] is a
// logout attempt from a node. If the node is not attempting to
// logout attempt from a node. If the node is not attempting to.
func (h *Headscale) handleLogout(
node types.NodeView,
req tailcfg.RegisterRequest,
@@ -160,6 +159,7 @@ func (h *Headscale) handleLogout(
Interface("reg.req", req).
Bool("unexpected", true).
Msg("Node key expired, forcing re-authentication")
return &tailcfg.RegisterResponse{
NodeKeyExpired: true,
MachineAuthorized: false,
@@ -279,6 +279,7 @@ func (h *Headscale) waitForFollowup(
// registration is expired in the cache, instruct the client to try a new registration
return h.reqToNewRegisterResponse(req, machineKey)
}
return nodeToRegisterResponse(node.View()), nil
}
}
@@ -316,7 +317,7 @@ func (h *Headscale) reqToNewRegisterResponse(
MachineKey: machineKey,
NodeKey: req.NodeKey,
Hostinfo: hostinfo,
LastSeen: ptr.To(time.Now()),
LastSeen: new(time.Now()),
},
)
|
@ -344,8 +345,8 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
|||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
|
||||
}
|
||||
var perr types.PAKError
|
||||
if errors.As(err, &perr) {
|
||||
|
||||
if perr, ok := errors.AsType[types.PAKError](err); ok {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
|
||||
}
|
||||
|
||||
|
|
@ -355,7 +356,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
|||
// If node is not valid, it means an ephemeral node was deleted during logout
|
||||
if !node.Valid() {
|
||||
h.Change(changed)
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
// This is a bit of a back and forth, but we have a bit of a chicken and egg
|
||||
|
|
@ -435,6 +436,7 @@ func (h *Headscale) handleRegisterInteractive(
|
|||
Str("generated.hostname", hostname).
|
||||
Msg("Received registration request with empty hostname, generated default")
|
||||
}
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
|
|
@ -443,7 +445,7 @@ func (h *Headscale) handleRegisterInteractive(
|
|||
MachineKey: machineKey,
|
||||
NodeKey: req.NodeKey,
|
||||
Hostinfo: hostinfo,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
LastSeen: new(time.Now()),
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
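Both LastSeen hunks above replace tailscale.com/types/ptr's ptr.To(time.Now()) with new(time.Now()). That relies on the expression-argument form of new that this branch's go 1.26rc2 toolchain accepts; on earlier toolchains the equivalent is a small generic helper, sketched here under that assumption.

package main

import (
	"fmt"
	"time"
)

// ptrTo is the classic pre-Go-1.26 helper for taking the address of an expression.
func ptrTo[T any](v T) *T { return &v }

func main() {
	lastSeen := ptrTo(time.Now()) // with the newer toolchain the diff instead writes: new(time.Now())
	fmt.Println(lastSeen.Format(time.RFC3339))
}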
@@ -2,6 +2,7 @@ package hscontrol
import (
"context"
"errors"
"fmt"
"net/url"
"strings"
@@ -16,14 +17,20 @@ import (
"tailscale.com/types/key"
)
// Interactive step type constants
// Test sentinel errors.
var (
errNodeNotFoundAfterSetup = errors.New("node not found after setup")
errInvalidAuthURLFormat = errors.New("invalid AuthURL format")
)
// Interactive step type constants.
const (
stepTypeInitialRequest = "initial_request"
stepTypeAuthCompletion = "auth_completion"
stepTypeFollowupRequest = "followup_request"
)
// interactiveStep defines a step in the interactive authentication workflow
// interactiveStep defines a step in the interactive authentication workflow.
type interactiveStep struct {
stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest
expectAuthURL bool
@@ -31,6 +38,7 @@ type interactiveStep struct {
callAuthPath bool // Real call to HandleNodeFromAuthPath, not mocked
}
//nolint:gocyclo
func TestAuthenticationFlows(t *testing.T) {
// Shared test keys for consistent behavior across test cases
machineKey1 := key.NewMachine()
@ -69,12 +77,15 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
{
|
||||
name: "preauth_key_valid_new_node",
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
|
||||
t.Helper()
|
||||
|
||||
user := app.state.CreateUserForTest("preauth-user")
|
||||
|
||||
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
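The t.Helper() calls added to the setupFunc closures above make any failure inside them be reported at the calling test's line instead of inside the helper. A small hedged illustration with made-up names (createKey is not a headscale function):

package example

import "testing"

// createKey is an illustrative helper; t.Helper() makes a failure inside it
// point at the test that called it rather than at this function.
func createKey(t *testing.T, reusable bool) string {
	t.Helper()

	if !reusable {
		t.Fatal("this sketch only hands out reusable keys")
	}

	return "dummy-preauth-key"
}

func TestHelperReporting(t *testing.T) {
	key := createKey(t, true)
	if key == "" {
		t.Fatal("expected a key")
	}
}
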
@ -89,7 +100,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
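The machineKey change above (repeated throughout this file) swaps a one-line closure for a method value: machineKey1.Public already has the signature func() key.MachinePublic, so it can be stored in the field directly. A minimal sketch of the same idea, using an illustrative signer type rather than the tailscale key package:

package main

import "fmt"

type signer struct{ id string }

// Public has the signature func() string, so it can be used as a method value.
func (s signer) Public() string { return "pub-" + s.id }

type testCase struct {
	// Any func() string fits here: a wrapper closure or a bound method value.
	key func() string
}

func main() {
	s := signer{id: "machine1"}

	viaClosure := testCase{key: func() string { return s.Public() }}
	viaMethodValue := testCase{key: s.Public}

	fmt.Println(viaClosure.key(), viaMethodValue.key())
}
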
@ -111,6 +122,8 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
{
|
||||
name: "preauth_key_reusable_multiple_nodes",
|
||||
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
|
||||
t.Helper()
|
||||
|
||||
user := app.state.CreateUserForTest("reusable-user")
|
||||
|
||||
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
|
||||
|
|
@ -129,6 +142,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -154,7 +168,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey2.Public() },
|
||||
machineKey: machineKey2.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -163,6 +177,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify both nodes exist
|
||||
node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public())
|
||||
|
||||
assert.True(t, found1)
|
||||
assert.True(t, found2)
|
||||
assert.Equal(t, "reusable-node-1", node1.Hostname())
|
||||
|
|
@ -196,6 +211,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -221,12 +237,13 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey2.Public() },
|
||||
machineKey: machineKey2.Public,
|
||||
wantError: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
// First node should exist, second should not
|
||||
_, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
_, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public())
|
||||
|
||||
assert.True(t, found1)
|
||||
assert.False(t, found2)
|
||||
},
|
||||
|
|
@ -254,7 +271,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantError: true,
|
||||
},
|
||||
|
||||
|
|
@ -272,6 +289,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -286,7 +304,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -323,7 +341,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -352,7 +370,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -391,6 +409,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -400,8 +419,10 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
// Wait for node to be available in NodeStore with debug info
|
||||
var attemptCount int
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
attemptCount++
|
||||
|
||||
_, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
if assert.True(c, found, "node should be available in NodeStore") {
|
||||
t.Logf("Node found in NodeStore after %d attempts", attemptCount)
|
||||
|
|
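Several of these setups poll the NodeStore with require.EventuallyWithT rather than sleeping for a fixed time. A standalone sketch of that polling pattern; memStore is a toy stand-in for headscale's NodeStore, not the real type:

package example

import (
	"sync"
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// memStore is a toy stand-in for the NodeStore used in the tests above.
type memStore struct {
	mu    sync.Mutex
	nodes map[string]bool
}

func (s *memStore) add(key string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.nodes[key] = true
}

func (s *memStore) has(key string) bool {
	s.mu.Lock()
	defer s.mu.Unlock()
	return s.nodes[key]
}

func TestWaitForNode(t *testing.T) {
	store := &memStore{nodes: map[string]bool{}}

	// Simulate the asynchronous registration that eventually publishes the node.
	go func() {
		time.Sleep(20 * time.Millisecond)
		store.add("nodekey-1")
	}()

	// Poll until the condition holds or the timeout expires; assertions on the
	// CollectT only fail the test once the deadline is reached.
	require.EventuallyWithT(t, func(c *assert.CollectT) {
		assert.True(c, store.has("nodekey-1"), "node should be available in the store")
	}, time.Second, 10*time.Millisecond, "waiting for node to be published")
}
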
@ -417,7 +438,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(-1 * time.Hour), // Past expiry = logout
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
wantExpired: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
|
|
@ -451,6 +472,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -471,7 +493,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(-1 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey2.Public() }, // Different machine key
|
||||
machineKey: machineKey2.Public, // Different machine key
|
||||
wantError: true,
|
||||
},
|
||||
// TEST: Existing node cannot extend expiry without re-auth
|
||||
|
|
@ -500,6 +522,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -520,7 +543,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(48 * time.Hour), // Future time = extend attempt
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantError: true,
|
||||
},
|
||||
// TEST: Expired node must re-authenticate
|
||||
|
|
@ -549,25 +572,31 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Wait for node to be available in NodeStore
|
||||
var node types.NodeView
|
||||
var found bool
|
||||
var (
|
||||
node types.NodeView
|
||||
found bool
|
||||
)
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
node, found = app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(c, found, "node should be available in NodeStore")
|
||||
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
|
||||
|
||||
if !found {
|
||||
return "", fmt.Errorf("node not found after setup")
|
||||
return "", errNodeNotFoundAfterSetup
|
||||
}
|
||||
|
||||
// Expire the node
|
||||
expiredTime := time.Now().Add(-1 * time.Hour)
|
||||
_, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime)
|
||||
|
||||
return "", err
|
||||
},
|
||||
request: func(_ string) tailcfg.RegisterRequest {
|
||||
|
|
@ -577,7 +606,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour), // Future expiry
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantExpired: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.NodeKeyExpired)
|
||||
|
|
@ -610,6 +639,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -630,7 +660,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(-1 * time.Hour), // Logout
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantExpired: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.NodeKeyExpired)
|
||||
|
|
@ -673,6 +703,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// and handleRegister will receive the value when it starts waiting
|
||||
go func() {
|
||||
user := app.state.CreateUserForTest("followup-user")
|
||||
|
||||
node := app.state.CreateNodeForTest(user, "followup-success-node")
|
||||
registered <- node
|
||||
}()
|
||||
|
|
@ -685,7 +716,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
NodeKey: nodeKey1.Public(),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -723,7 +754,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
NodeKey: nodeKey1.Public(),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantError: true,
|
||||
},
|
||||
// TEST: Invalid followup URL is rejected
|
||||
|
|
@ -742,7 +773,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
NodeKey: nodeKey1.Public(),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantError: true,
|
||||
},
|
||||
// TEST: Non-existent registration ID is rejected
|
||||
|
|
@ -761,7 +792,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
NodeKey: nodeKey1.Public(),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantError: true,
|
||||
},
|
||||
|
||||
|
|
@ -782,6 +813,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -796,7 +828,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -821,6 +853,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -833,7 +866,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -865,6 +898,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -879,7 +913,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantError: true,
|
||||
},
|
||||
|
||||
|
|
@ -898,6 +932,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -912,7 +947,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -922,6 +957,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "tagged-pak-node", node.Hostname())
|
||||
|
||||
if node.AuthKey().Valid() {
|
||||
assert.NotEmpty(t, node.AuthKey().Tags())
|
||||
}
|
||||
|
|
@ -1031,6 +1067,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1047,6 +1084,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak2.Key, nil
|
||||
},
|
||||
request: func(newAuthKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1061,7 +1099,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(48 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -1099,6 +1137,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1124,7 +1163,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(48 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuthURL: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.Contains(t, resp.AuthURL, "register/")
|
||||
|
|
@ -1161,6 +1200,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1177,6 +1217,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pakRotation.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1191,7 +1232,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -1226,6 +1267,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1240,7 +1282,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Time{}, // Zero time
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -1265,6 +1307,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1286,7 +1329,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -1342,7 +1385,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: false, // Should not be authorized yet - needs to use new AuthURL
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
// Should get a new AuthURL, not an error
|
||||
|
|
@ -1367,7 +1410,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
NodeKey: nodeKey1.Public(),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantError: true,
|
||||
},
|
||||
// TEST: Wrong followup path format is rejected
|
||||
|
|
@ -1386,7 +1429,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
NodeKey: nodeKey1.Public(),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantError: true,
|
||||
},
|
||||
|
||||
|
|
@ -1417,7 +1460,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -1429,6 +1472,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify custom hostinfo was preserved through interactive workflow
|
||||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found, "node should be found after interactive registration")
|
||||
|
||||
if found {
|
||||
assert.Equal(t, "custom-interactive-node", node.Hostname())
|
||||
assert.Equal(t, "linux", node.Hostinfo().OS())
|
||||
|
|
@ -1455,6 +1499,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1469,7 +1514,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -1508,7 +1553,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -1520,6 +1565,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify registration ID was properly generated and used
|
||||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found, "node should be registered after interactive workflow")
|
||||
|
||||
if found {
|
||||
assert.Equal(t, "registration-id-test-node", node.Hostname())
|
||||
assert.Equal(t, "test-os", node.Hostinfo().OS())
|
||||
|
|
@ -1535,6 +1581,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1549,7 +1596,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -1577,6 +1624,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1592,7 +1640,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(12 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -1632,6 +1680,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1648,6 +1697,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak2.Key, nil
|
||||
},
|
||||
request: func(user2AuthKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1662,7 +1712,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
assert.True(t, resp.MachineAuthorized)
|
||||
|
|
@ -1712,6 +1762,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1735,7 +1786,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() }, // Same machine key
|
||||
machineKey: machineKey1.Public, // Same machine key
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -1789,7 +1840,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: false, // Should not be authorized yet - needs to use new AuthURL
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
// Should get a new AuthURL, not an error
|
||||
|
|
@ -1799,13 +1850,13 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
// Verify the response contains a valid registration URL
|
||||
authURL, err := url.Parse(resp.AuthURL)
|
||||
assert.NoError(t, err, "AuthURL should be a valid URL")
|
||||
require.NoError(t, err, "AuthURL should be a valid URL")
|
||||
assert.True(t, strings.HasPrefix(authURL.Path, "/register/"), "AuthURL path should start with /register/")
|
||||
|
||||
// Extract and validate the new registration ID exists in cache
|
||||
newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/")
|
||||
newRegID, err := types.RegistrationIDFromString(newRegIDStr)
|
||||
assert.NoError(t, err, "should be able to parse new registration ID")
|
||||
require.NoError(t, err, "should be able to parse new registration ID")
|
||||
|
||||
// Verify new registration entry exists in cache
|
||||
_, found := app.state.GetRegistrationCacheEntry(newRegID)
|
||||
|
|
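Switching the checks above from assert.NoError to require.NoError matters because the parsed values are used on the following lines: require aborts the test immediately, while assert records the failure and keeps running, which would dereference a nil result here. A minimal illustration (raw URL and names are made up):

package example

import (
	"net/url"
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestAssertVersusRequire(t *testing.T) {
	raw := "http://localhost/register/abc123"

	u, err := url.Parse(raw)
	// assert.NoError(t, err) would mark the failure but fall through; if parsing
	// failed, u would be nil and the next line would panic. require stops here.
	require.NoError(t, err, "AuthURL should be a valid URL")

	assert.Equal(t, "/register/abc123", u.Path)
}
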
@ -1838,6 +1889,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1858,7 +1910,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now(), // Exactly now (edge case between past and future)
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
wantAuth: true,
|
||||
wantExpired: true,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
|
|
@ -1890,7 +1942,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey2.Public() },
|
||||
machineKey: machineKey2.Public,
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -1932,6 +1984,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1955,7 +2008,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -2001,7 +2054,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -2055,7 +2108,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
// This test validates concurrent interactive registration attempts
|
||||
assert.Contains(t, resp.AuthURL, "/register/")
|
||||
|
|
@ -2097,6 +2150,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
// Collect results - at least one should succeed
|
||||
successCount := 0
|
||||
|
||||
for range numConcurrent {
|
||||
select {
|
||||
case err := <-results:
|
||||
|
|
@ -2162,7 +2216,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -2206,7 +2260,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
requiresInteractiveFlow: true,
|
||||
interactiveSteps: []interactiveStep{
|
||||
{stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true},
|
||||
|
|
@ -2217,6 +2271,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Should handle nil hostinfo gracefully
|
||||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found, "node should be registered despite nil hostinfo")
|
||||
|
||||
if found {
|
||||
// Should have some default hostname or handle nil gracefully
|
||||
hostname := node.Hostname()
|
||||
|
|
@ -2243,7 +2298,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
// Get initial AuthURL and extract registration ID
|
||||
authURL := resp.AuthURL
|
||||
|
|
@ -2265,7 +2320,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
nil,
|
||||
"error-test-method",
|
||||
)
|
||||
assert.Error(t, err, "should fail with invalid user ID")
|
||||
require.Error(t, err, "should fail with invalid user ID")
|
||||
|
||||
// Cache entry should still exist after auth error (for retry scenarios)
|
||||
_, stillFound := app.state.GetRegistrationCacheEntry(registrationID)
|
||||
|
|
@ -2297,7 +2352,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
// Test multiple interactive registration attempts for the same node can coexist
|
||||
authURL1 := resp.AuthURL
|
||||
|
|
@ -2315,12 +2370,14 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public())
|
||||
require.NoError(t, err)
|
||||
|
||||
authURL2 := resp2.AuthURL
|
||||
assert.Contains(t, authURL2, "/register/")
|
||||
|
||||
// Both should have different registration IDs
|
||||
regID1, err1 := extractRegistrationIDFromAuthURL(authURL1)
|
||||
regID2, err2 := extractRegistrationIDFromAuthURL(authURL2)
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
|
||||
|
|
@ -2328,6 +2385,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Both cache entries should exist simultaneously
|
||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
||||
|
||||
assert.True(t, found1, "first registration cache entry should exist")
|
||||
assert.True(t, found2, "second registration cache entry should exist")
|
||||
|
||||
|
|
@ -2354,7 +2412,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
},
|
||||
machineKey: func() key.MachinePublic { return machineKey1.Public() },
|
||||
machineKey: machineKey1.Public,
|
||||
validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) {
|
||||
authURL1 := resp.AuthURL
|
||||
regID1, err := extractRegistrationIDFromAuthURL(authURL1)
|
||||
|
|
@ -2371,6 +2429,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public())
|
||||
require.NoError(t, err)
|
||||
|
||||
authURL2 := resp2.AuthURL
|
||||
regID2, err := extractRegistrationIDFromAuthURL(authURL2)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -2378,6 +2437,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify both exist
|
||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
||||
|
||||
assert.True(t, found1, "first cache entry should exist")
|
||||
assert.True(t, found2, "second cache entry should exist")
|
||||
|
||||
|
|
@ -2403,6 +2463,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
errorChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
responseChan <- resp
|
||||
}()
|
||||
|
||||
|
|
@ -2430,6 +2491,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify the node was created with the second registration's data
|
||||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found, "node should be registered")
|
||||
|
||||
if found {
|
||||
assert.Equal(t, "pending-node-2", node.Hostname())
|
||||
assert.Equal(t, "second-registration-user", node.User().Name())
|
||||
|
|
@ -2463,8 +2525,10 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
// Set up context with timeout for followup tests
|
||||
ctx := context.Background()
|
||||
|
||||
if req.Followup != "" {
|
||||
var cancel context.CancelFunc
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
}
|
||||
|
|
@ -2516,7 +2580,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow
|
||||
// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow.
|
||||
func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
name string
|
||||
setupFunc func(*testing.T, *Headscale) (string, error)
|
||||
|
|
@ -2535,6 +2599,8 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||
validateCompleteResponse bool
|
||||
}, app *Headscale, dynamicValue string,
|
||||
) {
|
||||
t.Helper()
|
||||
|
||||
// Build initial request
|
||||
req := tt.request(dynamicValue)
|
||||
machineKey := tt.machineKey()
|
||||
|
|
@ -2597,6 +2663,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||
errorChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
responseChan <- resp
|
||||
}()
|
||||
|
||||
|
|
@ -2650,25 +2717,30 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||
if responseToValidate == nil {
|
||||
responseToValidate = initialResp
|
||||
}
|
||||
|
||||
tt.validate(t, responseToValidate, app)
|
||||
}
|
||||
}
|
||||
|
||||
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL
|
||||
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
|
||||
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
|
||||
// AuthURL format: "http://localhost/register/abc123"
|
||||
const registerPrefix = "/register/"
|
||||
|
||||
idx := strings.LastIndex(authURL, registerPrefix)
|
||||
if idx == -1 {
|
||||
return "", fmt.Errorf("invalid AuthURL format: %s", authURL)
|
||||
return "", fmt.Errorf("%w: %s", errInvalidAuthURLFormat, authURL)
|
||||
}
|
||||
|
||||
idStr := authURL[idx+len(registerPrefix):]
|
||||
|
||||
return types.RegistrationIDFromString(idStr)
|
||||
}
|
||||
|
||||
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response
|
||||
func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) {
|
||||
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
|
||||
func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, _ tailcfg.RegisterRequest) {
|
||||
t.Helper()
|
||||
|
||||
// Basic response validation
|
||||
require.NotNil(t, resp, "response should not be nil")
|
||||
require.True(t, resp.MachineAuthorized, "machine should be authorized")
|
||||
|
|
@ -2681,7 +2753,7 @@ func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterRe
|
|||
// Additional validation can be added here as needed
|
||||
}
|
||||
|
||||
// Simple test to validate basic node creation and lookup
|
||||
// Simple test to validate basic node creation and lookup.
|
||||
func TestNodeStoreLookup(t *testing.T) {
|
||||
app := createTestApp(t)
|
||||
|
||||
|
|
@ -2713,8 +2785,10 @@ func TestNodeStoreLookup(t *testing.T) {
|
|||
|
||||
// Wait for node to be available in NodeStore
|
||||
var node types.NodeView
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var found bool
|
||||
|
||||
node, found = app.state.GetNodeByNodeKey(nodeKey.Public())
|
||||
assert.True(c, found, "Node should be found in NodeStore")
|
||||
}, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore")
|
||||
|
|
@ -2783,8 +2857,10 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
|
|||
|
||||
// Get the node ID
|
||||
var registeredNode types.NodeView
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var found bool
|
||||
|
||||
registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public())
|
||||
assert.True(c, found, "Node should be found in NodeStore")
|
||||
}, 1*time.Second, 100*time.Millisecond, "waiting for node to be available")
|
||||
|
|
@ -2796,6 +2872,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
|
|||
// Verify initial state: user1 has 2 nodes, user2 has 2 nodes
|
||||
user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID))
|
||||
user2Nodes := app.state.ListNodesByUser(types.UserID(user2.ID))
|
||||
|
||||
require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially")
|
||||
require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially")
|
||||
|
||||
|
|
@ -2876,6 +2953,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
|
|||
|
||||
// Verify new nodes were created for user1 with the same machine keys
|
||||
t.Logf("Verifying new nodes created for user1 from user2's machine keys...")
|
||||
|
||||
for i := 2; i < 4; i++ {
|
||||
node := nodes[i]
|
||||
// Should be able to find a node with user1 and this machine key (the new one)
|
||||
|
|
@ -2899,7 +2977,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
|
|||
// Expected behavior:
|
||||
// - User1's original node should STILL EXIST (expired)
|
||||
// - User2 should get a NEW node created (NOT transfer)
|
||||
// - Both nodes share the same machine key (same physical device)
|
||||
// - Both nodes share the same machine key (same physical device).
|
||||
func TestWebFlowReauthDifferentUser(t *testing.T) {
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey1 := key.NewNode()
|
||||
|
|
@ -3043,7 +3121,8 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
|
|||
// Count nodes per user
|
||||
user1Nodes := 0
|
||||
user2Nodes := 0
|
||||
for i := 0; i < allNodesSlice.Len(); i++ {
|
||||
|
||||
for i := range allNodesSlice.Len() {
|
||||
n := allNodesSlice.At(i)
|
||||
if n.UserID().Get() == user1.ID {
|
||||
user1Nodes++
|
||||
|
|
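The loop rewrite above uses Go 1.22's range-over-int form, which is equivalent to the classic three-clause loop when only the index is needed. A tiny sketch using only the standard library:

package main

import "fmt"

func main() {
	names := []string{"node-a", "node-b", "node-c"}

	// Classic form:
	for i := 0; i < len(names); i++ {
		fmt.Println(i, names[i])
	}

	// Go 1.22+ equivalent when only the index is needed:
	for i := range len(names) {
		fmt.Println(i, names[i])
	}
}
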
@ -3060,7 +3139,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Helper function to create test app
|
||||
// Helper function to create test app.
|
||||
func createTestApp(t *testing.T) *Headscale {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -3147,6 +3226,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Log("Step 1: Initial registration with pre-auth key")
|
||||
|
||||
initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public())
|
||||
require.NoError(t, err, "initial registration should succeed")
|
||||
require.NotNil(t, initialResp)
|
||||
|
|
@ -3172,6 +3252,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
|
|||
// - System reboots
|
||||
// The Tailscale client persists the pre-auth key in its state and sends it on every registration
|
||||
t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key")
|
||||
|
||||
restartReq := tailcfg.RegisterRequest{
|
||||
Auth: &tailcfg.RegisterResponseAuth{
|
||||
AuthKey: pakNew.Key, // Same key, now marked as Used=true
|
||||
|
|
@ -3188,10 +3269,12 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
|
|||
restartResp, err := app.handleRegister(context.Background(), restartReq, machineKey.Public())
|
||||
|
||||
// This is the assertion that currently FAILS in v0.27.0
|
||||
assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed")
|
||||
require.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed")
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Error received (this is the bug): %v", err)
|
||||
t.Logf("Expected behavior: Node should be able to re-register with the same pre-auth key it used initially")
|
||||
|
||||
return // Stop here to show the bug clearly
|
||||
}
|
||||
|
||||
|
|
@ -3289,7 +3372,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) {
|
|||
}
|
||||
|
||||
_, err = app.handleRegister(context.Background(), req, machineKey.Public())
|
||||
assert.Error(t, err, "expired pre-auth key should be rejected")
|
||||
require.Error(t, err, "expired pre-auth key should be rejected")
|
||||
assert.Contains(t, err.Error(), "authkey expired", "error should mention key expiration")
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ import (
|
|||
)
|
||||
|
||||
const (
|
||||
apiKeyPrefix = "hskey-api-" //nolint:gosec // This is a prefix, not a credential
|
||||
apiKeyPrefix = "hskey-api-" //nolint:gosec
|
||||
apiKeyPrefixLength = 12
|
||||
apiKeyHashLength = 64
|
||||
|
||||
|
|
|
|||
|
|
@ -24,7 +24,6 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/logger"
|
||||
"gorm.io/gorm/schema"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"zgo.at/zcache/v2"
|
||||
)
|
||||
|
||||
|
|
@ -76,6 +75,7 @@ func NewHeadscaleDatabase(
|
|||
ID: "202501221827",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// Remove any invalid routes associated with a node that does not exist.
|
||||
//nolint:staticcheck
|
||||
if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) {
|
||||
err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error
|
||||
if err != nil {
|
||||
|
|
@ -84,6 +84,7 @@ func NewHeadscaleDatabase(
|
|||
}
|
||||
|
||||
// Remove any invalid routes without a node_id.
|
||||
//nolint:staticcheck
|
||||
if tx.Migrator().HasTable(&types.Route{}) {
|
||||
err := tx.Exec("delete from routes where node_id is null").Error
|
||||
if err != nil {
|
||||
|
|
@ -91,6 +92,7 @@ func NewHeadscaleDatabase(
|
|||
}
|
||||
}
|
||||
|
||||
//nolint:staticcheck
|
||||
err := tx.AutoMigrate(&types.Route{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.Route: %w", err)
|
||||
|
|
@ -109,6 +111,7 @@ func NewHeadscaleDatabase(
|
|||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.PreAuthKey: %w", err)
|
||||
}
|
||||
|
||||
err = tx.AutoMigrate(&types.Node{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("automigrating types.Node: %w", err)
|
||||
|
|
@ -155,7 +158,9 @@ AND auth_key_id NOT IN (
|
|||
|
||||
nodeRoutes := map[uint64][]netip.Prefix{}
|
||||
|
||||
//nolint:staticcheck
|
||||
var routes []types.Route
|
||||
|
||||
err = tx.Find(&routes).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching routes: %w", err)
|
||||
|
|
@ -168,10 +173,13 @@ AND auth_key_id NOT IN (
|
|||
}
|
||||
|
||||
for nodeID, routes := range nodeRoutes {
|
||||
tsaddr.SortPrefixes(routes)
|
||||
slices.SortFunc(routes, netip.Prefix.Compare)
|
||||
routes = slices.Compact(routes)
|
||||
|
||||
data, err := json.Marshal(routes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("marshaling routes for node %d: %w", nodeID, err)
|
||||
}
|
||||
|
||||
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error
|
||||
if err != nil {
|
||||
|
|
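The migration above now sorts the approved routes with slices.SortFunc and netip.Prefix.Compare, then drops duplicates with slices.Compact; Compact only removes neighbouring equal elements, which is why the sort comes first. A standalone sketch using only the standard library:

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

func main() {
	routes := []netip.Prefix{
		netip.MustParsePrefix("192.168.0.0/24"),
		netip.MustParsePrefix("10.0.0.0/8"),
		netip.MustParsePrefix("192.168.0.0/24"), // duplicate
	}

	// Sort with the Prefix.Compare method so duplicates become adjacent,
	// then collapse them.
	slices.SortFunc(routes, netip.Prefix.Compare)
	routes = slices.Compact(routes)

	fmt.Println(routes) // [10.0.0.0/8 192.168.0.0/24]
}
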
@ -180,6 +188,7 @@ AND auth_key_id NOT IN (
|
|||
}
|
||||
|
||||
// Drop the old table.
|
||||
//nolint:staticcheck
|
||||
_ = tx.Migrator().DropTable(&types.Route{})
|
||||
|
||||
return nil
|
||||
|
|
@ -256,10 +265,13 @@ AND auth_key_id NOT IN (
|
|||
|
||||
// Check if routes table exists and drop it (should have been migrated already)
|
||||
var routesExists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
|
||||
if err == nil && routesExists {
|
||||
log.Info().Msg("Dropping leftover routes table")
|
||||
if err := tx.Exec("DROP TABLE routes").Error; err != nil {
|
||||
|
||||
err := tx.Exec("DROP TABLE routes").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("dropping routes table: %w", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -281,6 +293,7 @@ AND auth_key_id NOT IN (
|
|||
for _, table := range tablesToRename {
|
||||
// Check if table exists before renaming
|
||||
var exists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking if table %s exists: %w", table, err)
|
||||
|
|
@ -291,7 +304,8 @@ AND auth_key_id NOT IN (
|
|||
_ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||
|
||||
// Rename current table to _old
|
||||
if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil {
|
||||
err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err)
|
||||
}
|
||||
}
|
||||
|
|
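Several hunks in this migration un-inline the error assignment to satisfy the noinlineerr nolint comments seen later in the file; the two forms are behaviourally identical, the second just keeps err on its own line. A short sketch with an illustrative run function standing in for tx.Exec(...).Error:

package main

import (
	"errors"
	"fmt"
)

// run is an illustrative stand-in for the tx.Exec(...).Error calls above.
func run(stmt string) error {
	if stmt == "" {
		return errors.New("empty statement")
	}
	return nil
}

func main() {
	// Inlined form: err is scoped to the if statement.
	if err := run(""); err != nil {
		fmt.Println("inline:", err)
	}

	// Un-inlined form used in the migration code: assignment and check are separate.
	err := run("")
	if err != nil {
		fmt.Println("separate:", err)
	}
}
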
@ -365,7 +379,8 @@ AND auth_key_id NOT IN (
|
|||
}
|
||||
|
||||
for _, createSQL := range tableCreationSQL {
|
||||
if err := tx.Exec(createSQL).Error; err != nil {
|
||||
err := tx.Exec(createSQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating new table: %w", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -394,7 +409,8 @@ AND auth_key_id NOT IN (
|
|||
}
|
||||
|
||||
for _, copySQL := range dataCopySQL {
|
||||
if err := tx.Exec(copySQL).Error; err != nil {
|
||||
err := tx.Exec(copySQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("copying data: %w", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -417,14 +433,16 @@ AND auth_key_id NOT IN (
|
|||
}
|
||||
|
||||
for _, indexSQL := range indexes {
|
||||
if err := tx.Exec(indexSQL).Error; err != nil {
|
||||
err := tx.Exec(indexSQL).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating index: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Drop old tables only after everything succeeds
|
||||
for _, table := range tablesToRename {
|
||||
if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil {
|
||||
err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error
|
||||
if err != nil {
|
||||
log.Warn().Str("table", table+"_old").Err(err).Msg("Failed to drop old table, but migration succeeded")
|
||||
}
|
||||
}
|
||||
|
|
@ -762,6 +780,7 @@ AND auth_key_id NOT IN (
|
|||
|
||||
// or else it blocks...
|
||||
sqlConn.SetMaxIdleConns(maxIdleConns)
|
||||
|
||||
sqlConn.SetMaxOpenConns(maxOpenConns)
|
||||
defer sqlConn.SetMaxIdleConns(1)
|
||||
defer sqlConn.SetMaxOpenConns(1)
|
||||
|
|
@ -779,6 +798,7 @@ AND auth_key_id NOT IN (
|
|||
},
|
||||
}
|
||||
|
||||
//nolint:noinlineerr
|
||||
if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil {
|
||||
return nil, fmt.Errorf("validating schema: %w", err)
|
||||
}
|
||||
|
|
@ -913,6 +933,7 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
|||
|
||||
// Get the current foreign key status
|
||||
var fkOriginallyEnabled int
|
||||
//nolint:noinlineerr
|
||||
if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil {
|
||||
return fmt.Errorf("checking foreign key status: %w", err)
|
||||
}
|
||||
|
|
@ -942,27 +963,32 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
|||
|
||||
if needsFKDisabled {
|
||||
// Disable foreign keys for this migration
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil {
|
||||
err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err)
|
||||
}
|
||||
} else {
|
||||
// Ensure foreign keys are enabled for this migration
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
||||
err := dbConn.Exec("PRAGMA foreign_keys = ON").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Run up to this specific migration (will only run the next pending migration)
|
||||
if err := migrations.MigrateTo(migrationID); err != nil {
|
||||
err := migrations.MigrateTo(migrationID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("running migration %s: %w", migrationID, err)
|
||||
}
|
||||
}
|
||||
|
||||
//nolint:noinlineerr
|
||||
if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil {
|
||||
return fmt.Errorf("restoring foreign keys: %w", err)
|
||||
}
|
||||
|
||||
// Run the rest of the migrations
|
||||
//nolint:noinlineerr
|
||||
if err := migrations.Migrate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
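The migration runner above toggles PRAGMA foreign_keys per migration because SQLite treats that pragma as a no-op inside a transaction, so enforcement has to be switched off between migrations for the ones that rebuild tables, and restored afterwards. A hedged sketch of the same toggle with database/sql; the modernc.org/sqlite driver import and in-memory DSN are assumptions, not what headscale uses here:

package main

import (
	"database/sql"
	"fmt"
	"log"

	_ "modernc.org/sqlite" // any SQLite driver works; this one is an assumption
)

func main() {
	db, err := sql.Open("sqlite", "file:demo.db?mode=memory")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// Remember the original setting so it can be restored afterwards.
	var fkEnabled int
	if err := db.QueryRow("PRAGMA foreign_keys").Scan(&fkEnabled); err != nil {
		log.Fatal(err)
	}

	// Disable enforcement around a table-rebuild style migration...
	if _, err := db.Exec("PRAGMA foreign_keys = OFF"); err != nil {
		log.Fatal(err)
	}

	// ... run the migration here ...

	// ... then turn enforcement back on.
	if _, err := db.Exec("PRAGMA foreign_keys = ON"); err != nil {
		log.Fatal(err)
	}

	fmt.Println("foreign keys originally:", fkEnabled)
}
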
@ -1005,7 +1031,8 @@ func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormig
|
|||
}
|
||||
} else {
|
||||
// PostgreSQL can run all migrations in one block - no foreign key issues
|
||||
if err := migrations.Migrate(); err != nil {
|
||||
err := migrations.Migrate()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -1031,7 +1058,7 @@ func (hsdb *HSDatabase) Close() error {
|
|||
}
|
||||
|
||||
if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
|
||||
db.Exec("VACUUM")
|
||||
_, _ = db.ExecContext(context.Background(), "VACUUM")
|
||||
}
|
||||
|
||||
return db.Close()
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"os"
|
||||
"os/exec"
|
||||
|
|
@ -44,6 +45,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
|||
|
||||
// Verify api_keys data preservation
|
||||
var apiKeyCount int
|
||||
|
||||
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
|
||||
|
|
@ -176,7 +178,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
|||
return err
|
||||
}
|
||||
|
||||
_, err = db.Exec(string(schemaContent))
|
||||
_, err = db.ExecContext(context.Background(), string(schemaContent))
|
||||
|
||||
return err
|
||||
}
|
||||
|
|
@ -186,6 +188,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
|||
func requireConstraintFailed(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
|
||||
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
|
||||
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
|
||||
}
|
||||
|
|
@ -320,7 +323,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
|
|||
}
|
||||
|
||||
// Construct the pg_restore command
|
||||
cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
cmd := exec.CommandContext(context.Background(), pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath)
|
||||
|
||||
// Set the output streams
|
||||
cmd.Stdout = os.Stdout
|
||||
|
|
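The test now builds the pg_restore invocation with exec.CommandContext, mirroring the ExecContext change earlier in this commit: the subprocess is bound to a context and killed when it expires, unlike plain exec.Command. A small sketch; "echo" stands in for pg_restore and the timeout value is arbitrary:

package main

import (
	"context"
	"fmt"
	"os"
	"os/exec"
	"time"
)

func main() {
	// The context bounds how long the external command may run.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	cmd := exec.CommandContext(ctx, "echo", "restore complete")
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr

	if err := cmd.Run(); err != nil {
		fmt.Println("command failed:", err)
	}
}
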
@ -401,6 +404,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
|
|||
// skip already-applied migrations and only run new ones.
|
||||
func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schemas, err := os.ReadDir("testdata/sqlite")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Basic deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var deletionWg sync.WaitGroup
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
deletionWg sync.WaitGroup
|
||||
)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
deletionWg.Done()
|
||||
}
|
||||
|
|
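The tracking setup above pairs a mutex-guarded slice with a sync.WaitGroup so the test can block until every scheduled deletion callback has fired before asserting. A compact, self-contained sketch of that pattern (names and counts are illustrative):

package example

import (
	"sync"
	"testing"
)

func TestWaitForCallbacks(t *testing.T) {
	const numNodes = 100

	var (
		mu      sync.Mutex
		deleted []int
		wg      sync.WaitGroup
	)

	// The callback records the ID and marks one expected deletion as done.
	deleteFunc := func(id int) {
		mu.Lock()
		deleted = append(deleted, id)
		mu.Unlock()
		wg.Done()
	}

	wg.Add(numNodes)

	// Fire the callbacks asynchronously, as the garbage collector would.
	for i := 1; i <= numNodes; i++ {
		go deleteFunc(i)
	}

	// Block until every expected callback has run, then assert on the result.
	wg.Wait()

	if len(deleted) != numNodes {
		t.Fatalf("expected %d deletions, got %d", numNodes, len(deleted))
	}
}
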
@ -43,14 +47,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||
go gc.Start()
|
||||
|
||||
// Schedule several nodes for deletion with short expiry
|
||||
const expiry = fifty
|
||||
const numNodes = 100
|
||||
const (
|
||||
expiry = fifty
|
||||
numNodes = 100
|
||||
)
|
||||
|
||||
// Set up wait group for expected deletions
|
||||
|
||||
deletionWg.Add(numNodes)
|
||||
|
||||
for i := 1; i <= numNodes; i++ {
|
||||
gc.Schedule(types.NodeID(i), expiry)
|
||||
gc.Schedule(types.NodeID(i), expiry) //nolint:gosec
|
||||
}
|
||||
|
||||
// Wait for all scheduled deletions to complete
|
||||
|
|
@ -63,7 +70,7 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||
|
||||
// Schedule and immediately cancel to test that part of the code
|
||||
for i := numNodes + 1; i <= numNodes*2; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
nodeID := types.NodeID(i) //nolint:gosec
|
||||
gc.Schedule(nodeID, time.Hour)
|
||||
gc.Cancel(nodeID)
|
||||
}
|
||||
|
|
@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
|
|
@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
|||
|
||||
// Start GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
const shortExpiry = fifty
|
||||
const longExpiry = 1 * time.Hour
|
||||
const (
|
||||
shortExpiry = fifty
|
||||
longExpiry = 1 * time.Hour
|
||||
)
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
|
@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
|||
// and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
}
|
||||
|
||||
// Start the GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
const expiry = fifty
|
||||
|
||||
// Schedule node for deletion
|
||||
|
|
@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
|||
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
|
||||
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
|
|
@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
|||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
nodeDeleted := make(chan struct{})
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
close(nodeDeleted) // Signal that deletion happened
|
||||
}
|
||||
|
|
@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
|||
// Use a WaitGroup to ensure the GC has started
|
||||
var startWg sync.WaitGroup
|
||||
startWg.Add(1)
|
||||
|
||||
go func() {
|
||||
startWg.Done() // Signal that the goroutine has started
|
||||
gc.Start()
|
||||
}()
|
||||
|
||||
startWg.Wait() // Wait for the GC to start
|
||||
|
||||
// Close GC right away
|
||||
|
|
@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
|||
|
||||
// Check no node was deleted
|
||||
deleteMutex.Lock()
|
||||
|
||||
nodesDeleted := len(deletedIDs)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
|
||||
|
||||
|
|
@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
|||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
}
|
||||
|
||||
|
|
@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
|||
go gc.Start()
|
||||
|
||||
// Number of concurrent scheduling goroutines
|
||||
const numSchedulers = 10
|
||||
const nodesPerScheduler = 50
|
||||
const (
|
||||
numSchedulers = 10
|
||||
nodesPerScheduler = 50
|
||||
)
|
||||
|
||||
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
|
||||
|
||||
|
|
@ -353,8 +394,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
|||
case <-stopScheduling:
|
||||
return
|
||||
default:
|
||||
nodeID := types.NodeID(baseNodeID + j + 1)
|
||||
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
||||
nodeID := types.NodeID(baseNodeID + j + 1) //nolint:gosec
|
||||
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
|
||||
atomic.AddInt64(&scheduledCount, 1)
|
||||
|
||||
// Yield to other goroutines to introduce variability
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var mpp = func(pref string) *netip.Prefix {
|
||||
|
|
@ -484,12 +483,13 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||
db, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer db.Close()
|
||||
|
||||
alloc, err := NewIPAllocator(
|
||||
db,
|
||||
ptr.To(tsaddr.CGNATRange()),
ptr.To(tsaddr.TailscaleULARange()),
|
||||
types.IPAllocationStrategySequential,
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -497,17 +497,17 @@ func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
|||
}
|
||||
|
||||
// Validate that we do not give out 100.100.100.100
|
||||
nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange()))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, na("100.100.100.101"), *nextQuad100)
|
||||
|
||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||
nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange()))
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6)
|
||||
|
||||
// Validate that we do not give out fd7a:115c:a1e0::53
|
||||
nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange()))
|
||||
t.Logf("chrome: %s", nextChrome.String())
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, na("100.115.94.0"), *nextChrome)
@ -1,13 +1,13 @@
|
|||
package db
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"regexp"
|
||||
"slices"
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -20,7 +20,6 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
const (
|
||||
|
|
@ -37,6 +36,7 @@ var (
|
|||
"node not found in registration cache",
|
||||
)
|
||||
ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface")
|
||||
ErrNameNotUnique = errors.New("name is not unique")
|
||||
)
|
||||
|
||||
// ListPeers returns peers of node, regardless of any Policy or if the node is expired.
|
||||
|
|
@ -60,7 +60,7 @@ func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types
|
|||
return types.Nodes{}, err
|
||||
}
|
||||
|
||||
sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID })
|
||||
slices.SortFunc(nodes, func(a, b *types.Node) int { return cmp.Compare(a.ID, b.ID) })
|
||||
|
||||
return nodes, nil
|
||||
}
|
||||
|
|
@ -207,6 +207,7 @@ func SetTags(
|
|||
|
||||
slices.Sort(tags)
|
||||
tags = slices.Compact(tags)
|
||||
|
||||
b, err := json.Marshal(tags)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -220,7 +221,7 @@ func SetTags(
|
|||
return nil
|
||||
}
|
||||
|
||||
// SetTags takes a Node struct pointer and update the forced tags.
|
||||
// SetApprovedRoutes updates the approved routes for a node.
|
||||
func SetApprovedRoutes(
|
||||
tx *gorm.DB,
|
||||
nodeID types.NodeID,
|
||||
|
|
@ -228,7 +229,8 @@ func SetApprovedRoutes(
|
|||
) error {
|
||||
if len(routes) == 0 {
|
||||
// if no routes are provided, we remove all
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil {
|
||||
err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("removing approved routes: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -251,6 +253,7 @@ func SetApprovedRoutes(
|
|||
return err
|
||||
}
|
||||
|
||||
//nolint:noinlineerr
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
|
||||
return fmt.Errorf("updating approved routes: %w", err)
|
||||
}
|
||||
|
|
@ -277,18 +280,21 @@ func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
|
|||
func RenameNode(tx *gorm.DB,
|
||||
nodeID types.NodeID, newName string,
|
||||
) error {
|
||||
if err := util.ValidateHostname(newName); err != nil {
|
||||
err := util.ValidateHostname(newName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("renaming node: %w", err)
|
||||
}
|
||||
|
||||
// Check if the new name is unique
|
||||
var count int64
|
||||
if err := tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error; err != nil {
|
||||
|
||||
err = tx.Model(&types.Node{}).Where("given_name = ? AND id != ?", newName, nodeID).Count(&count).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check name uniqueness: %w", err)
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return errors.New("name is not unique")
|
||||
return ErrNameNotUnique
|
||||
}
|
||||
|
||||
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil {
|
||||
|
|
@ -379,6 +385,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
|||
if ipv4 == nil {
|
||||
ipv4 = oldNode.IPv4
|
||||
}
|
||||
|
||||
if ipv6 == nil {
|
||||
ipv6 = oldNode.IPv6
|
||||
}
|
||||
|
|
@ -407,6 +414,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
|||
node.IPv6 = ipv6
|
||||
|
||||
var err error
|
||||
|
||||
node.Hostname, err = util.NormaliseHostname(node.Hostname)
|
||||
if err != nil {
|
||||
newHostname := util.InvalidString()
|
||||
|
|
@ -491,8 +499,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) {
|
|||
|
||||
func isUniqueName(tx *gorm.DB, name string) (bool, error) {
|
||||
nodes := types.Nodes{}
|
||||
if err := tx.
|
||||
Where("given_name = ?", name).Find(&nodes).Error; err != nil {
|
||||
|
||||
err := tx.
|
||||
Where("given_name = ?", name).Find(&nodes).Error
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
|
|
@ -646,7 +656,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
|
|||
panic("CreateNodeForTest requires a valid user")
|
||||
}
|
||||
|
||||
nodeName := "testnode"
|
||||
nodeName := "testnode" //nolint:goconst
|
||||
if len(hostname) > 0 && hostname[0] != "" {
|
||||
nodeName = hostname[0]
|
||||
}
|
||||
|
|
@ -668,7 +678,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
|
|||
Hostname: nodeName,
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: &pak.ID,
|
||||
}
|
||||
|
||||
err = hsdb.DB.Save(node).Error
|
||||
|
|
@ -694,9 +704,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
|
|||
}
|
||||
|
||||
var registeredNode *types.Node
|
||||
|
||||
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
|
||||
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
|
||||
|
||||
return err
|
||||
})
|
||||
if err != nil {
@ -22,7 +22,6 @@ import (
|
|||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestGetNode(t *testing.T) {
|
||||
|
|
@ -99,6 +98,7 @@ func TestExpireNode(t *testing.T) {
|
|||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
require.NoError(t, err)
|
||||
|
||||
//nolint:staticcheck
|
||||
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -115,7 +115,7 @@ func TestExpireNode(t *testing.T) {
|
|||
Hostname: "testnode",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: &pak.ID,
|
||||
Expiry: &time.Time{},
|
||||
}
|
||||
db.DB.Save(node)
|
||||
|
|
@ -143,6 +143,7 @@ func TestSetTags(t *testing.T) {
|
|||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
require.NoError(t, err)
|
||||
|
||||
//nolint:staticcheck
|
||||
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -159,7 +160,7 @@ func TestSetTags(t *testing.T) {
|
|||
Hostname: "testnode",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: &pak.ID,
|
||||
}
|
||||
|
||||
trx := db.DB.Save(node)
|
||||
|
|
@ -443,7 +444,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.routes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
}
|
||||
|
||||
err = adb.DB.Save(&node).Error
|
||||
|
|
@ -460,17 +461,17 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||
RoutableIPs: tt.routes,
|
||||
},
|
||||
Tags: []string{"tag:exit"},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.2")),
|
||||
}
|
||||
|
||||
err = adb.DB.Save(&nodeTagged).Error
|
||||
require.NoError(t, err)
|
||||
|
||||
users, err := adb.ListUsers()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
nodes, err := adb.ListNodes()
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
pm, err := pmf(users, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
|
@ -498,6 +499,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||
if len(expectedRoutes1) == 0 {
|
||||
expectedRoutes1 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -509,6 +511,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||
if len(expectedRoutes2) == 0 {
|
||||
expectedRoutes2 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -597,7 +600,7 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
|
|||
|
||||
// Use shorter expiry for faster tests
|
||||
for i := range want {
|
||||
go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec // test code, no overflow risk
|
||||
go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec
|
||||
}
|
||||
|
||||
// Wait for all deletions to complete
|
||||
|
|
@ -636,9 +639,11 @@ func TestListEphemeralNodes(t *testing.T) {
|
|||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
require.NoError(t, err)
|
||||
|
||||
//nolint:staticcheck
|
||||
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
//nolint:staticcheck
|
||||
pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -649,7 +654,7 @@ func TestListEphemeralNodes(t *testing.T) {
|
|||
Hostname: "test",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: &pak.ID,
|
||||
}
|
||||
|
||||
nodeEph := types.Node{
|
||||
|
|
@ -659,7 +664,7 @@ func TestListEphemeralNodes(t *testing.T) {
|
|||
Hostname: "ephemeral",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pakEph.ID),
|
||||
AuthKeyID: &pakEph.ID,
|
||||
}
|
||||
|
||||
err = db.DB.Save(&node).Error
|
||||
|
|
@ -719,6 +724,7 @@ func TestNodeNaming(t *testing.T) {
|
|||
// break your network, so they should be replaced when registering
|
||||
// a node.
|
||||
// https://github.com/juanfont/headscale/issues/2343
|
||||
//nolint:gosmopolitan
|
||||
nodeInvalidHostname := types.Node{
|
||||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
|
|
@ -746,12 +752,19 @@ func TestNodeNaming(t *testing.T) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
|
||||
_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
|
||||
|
||||
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil)
|
||||
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
|
@ -810,25 +823,26 @@ func TestNodeNaming(t *testing.T) {
|
|||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "test")
|
||||
})
|
||||
assert.ErrorContains(t, err, "name is not unique")
|
||||
require.ErrorContains(t, err, "name is not unique")
|
||||
|
||||
// Rename invalid chars
|
||||
//nolint:gosmopolitan
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[2].ID, "我的电脑")
|
||||
})
|
||||
assert.ErrorContains(t, err, "invalid characters")
|
||||
require.ErrorContains(t, err, "invalid characters")
|
||||
|
||||
// Rename too short
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[3].ID, "a")
|
||||
})
|
||||
assert.ErrorContains(t, err, "at least 2 characters")
|
||||
require.ErrorContains(t, err, "at least 2 characters")
|
||||
|
||||
// Rename with emoji
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
return RenameNode(tx, nodes[0].ID, "hostname-with-💩")
|
||||
})
|
||||
assert.ErrorContains(t, err, "invalid characters")
|
||||
require.ErrorContains(t, err, "invalid characters")
|
||||
|
||||
// Rename with only emoji
|
||||
err = db.Write(func(tx *gorm.DB) error {
|
||||
|
|
@ -896,12 +910,12 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "chinese_chars_with_dash_rejected",
|
||||
newName: "server-北京-01",
|
||||
newName: "server-北京-01", //nolint:gosmopolitan
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
name: "chinese_only_rejected",
|
||||
newName: "我的电脑",
|
||||
newName: "我的电脑", //nolint:gosmopolitan
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
|
|
@ -911,7 +925,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
|
|||
},
|
||||
{
|
||||
name: "mixed_chinese_emoji_rejected",
|
||||
newName: "测试💻机器",
|
||||
newName: "测试💻机器", //nolint:gosmopolitan
|
||||
wantErr: "invalid characters",
|
||||
},
|
||||
{
|
||||
|
|
@ -1000,6 +1014,7 @@ func TestListPeers(t *testing.T) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
|
|
@ -1085,6 +1100,7 @@ func TestListNodes(t *testing.T) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -332,7 +332,7 @@ func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error {
return nil
}

// MarkExpirePreAuthKey marks a PreAuthKey as expired.
// ExpirePreAuthKey marks a PreAuthKey as expired.
func ExpirePreAuthKey(tx *gorm.DB, id uint64) error {
now := time.Now()
return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error
@ -11,7 +11,6 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestCreatePreAuthKey(t *testing.T) {
|
||||
|
|
@ -24,7 +23,7 @@ func TestCreatePreAuthKey(t *testing.T) {
|
|||
test: func(t *testing.T, db *HSDatabase) {
|
||||
t.Helper()
|
||||
|
||||
_, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil)
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
|
|
@ -127,7 +126,7 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) {
|
|||
Hostname: "testest",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(key.ID),
|
||||
AuthKeyID: &key.ID,
|
||||
}
|
||||
db.DB.Save(&node)
|
||||
|
||||
|
|
|
|||
|
|
@ -22,6 +22,9 @@ var (
|
|||
const (
|
||||
// DefaultBusyTimeout is the default busy timeout in milliseconds.
|
||||
DefaultBusyTimeout = 10000
|
||||
// DefaultWALAutocheckpoint is the default WAL autocheckpoint value (number of pages).
|
||||
// SQLite default is 1000 pages.
|
||||
DefaultWALAutocheckpoint = 1000
|
||||
)
|
||||
|
||||
// JournalMode represents SQLite journal_mode pragma values.
|
||||
|
|
@ -310,7 +313,7 @@ func Default(path string) *Config {
|
|||
BusyTimeout: DefaultBusyTimeout,
|
||||
JournalMode: JournalModeWAL,
|
||||
AutoVacuum: AutoVacuumIncremental,
|
||||
WALAutocheckpoint: 1000,
|
||||
WALAutocheckpoint: DefaultWALAutocheckpoint,
|
||||
Synchronous: SynchronousNormal,
|
||||
ForeignKeys: true,
|
||||
TxLock: TxLockImmediate,
|
||||
|
|
@ -362,7 +365,8 @@ func (c *Config) Validate() error {
|
|||
// ToURL builds a properly encoded SQLite connection string using _pragma parameters
|
||||
// compatible with modernc.org/sqlite driver.
|
||||
func (c *Config) ToURL() (string, error) {
|
||||
if err := c.Validate(); err != nil {
|
||||
err := c.Validate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("invalid config: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -372,18 +376,23 @@ func (c *Config) ToURL() (string, error) {
|
|||
if c.BusyTimeout > 0 {
|
||||
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
|
||||
}
|
||||
|
||||
if c.JournalMode != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
|
||||
}
|
||||
|
||||
if c.AutoVacuum != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
|
||||
}
|
||||
|
||||
if c.WALAutocheckpoint >= 0 {
|
||||
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
|
||||
}
|
||||
|
||||
if c.Synchronous != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
|
||||
}
|
||||
|
||||
if c.ForeignKeys {
|
||||
pragmas = append(pragmas, "foreign_keys=ON")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
|
|||
t.Errorf("Config.ToURL() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
|
||||
}
|
||||
|
|
@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
|
|||
Path: "",
|
||||
BusyTimeout: -1,
|
||||
}
|
||||
|
||||
_, err := config.ToURL()
|
||||
if err == nil {
|
||||
t.Error("Config.ToURL() with invalid config should return error")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
package sqliteconfig
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
|
@ -101,7 +102,8 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
|||
defer db.Close()
|
||||
|
||||
// Test connection
|
||||
if err := db.Ping(); err != nil {
|
||||
//nolint:noinlineerr
|
||||
if err := db.PingContext(context.Background()); err != nil {
|
||||
t.Fatalf("Failed to ping database: %v", err)
|
||||
}
|
||||
|
||||
|
|
@ -109,8 +111,10 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
|||
for pragma, expectedValue := range tt.expected {
|
||||
t.Run("pragma_"+pragma, func(t *testing.T) {
|
||||
var actualValue any
|
||||
|
||||
query := "PRAGMA " + pragma
|
||||
err := db.QueryRow(query).Scan(&actualValue)
|
||||
|
||||
err := db.QueryRowContext(context.Background(), query).Scan(&actualValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query %s: %v", query, err)
|
||||
}
|
||||
|
|
@ -178,23 +182,25 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
|||
);
|
||||
`
|
||||
|
||||
if _, err := db.Exec(schema); err != nil {
|
||||
//nolint:noinlineerr
|
||||
if _, err := db.ExecContext(context.Background(), schema); err != nil {
|
||||
t.Fatalf("Failed to create schema: %v", err)
|
||||
}
|
||||
|
||||
// Insert parent record
|
||||
if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
|
||||
//nolint:noinlineerr
|
||||
if _, err := db.ExecContext(context.Background(), "INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil {
|
||||
t.Fatalf("Failed to insert parent: %v", err)
|
||||
}
|
||||
|
||||
// Test 1: Valid foreign key should work
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
_, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')")
|
||||
if err != nil {
|
||||
t.Fatalf("Valid foreign key insert failed: %v", err)
|
||||
}
|
||||
|
||||
// Test 2: Invalid foreign key should fail
|
||||
_, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
_, err = db.ExecContext(context.Background(), "INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation, but insert succeeded")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
|
|
@ -204,7 +210,7 @@ func TestForeignKeyConstraintEnforcement(t *testing.T) {
|
|||
}
|
||||
|
||||
// Test 3: Deleting referenced parent should fail
|
||||
_, err = db.Exec("DELETE FROM parent WHERE id = 1")
|
||||
_, err = db.ExecContext(context.Background(), "DELETE FROM parent WHERE id = 1")
|
||||
if err == nil {
|
||||
t.Error("Expected foreign key constraint violation when deleting referenced parent")
|
||||
} else if !contains(err.Error(), "FOREIGN KEY constraint failed") {
|
||||
|
|
@ -249,7 +255,8 @@ func TestJournalModeValidation(t *testing.T) {
|
|||
defer db.Close()
|
||||
|
||||
var actualMode string
|
||||
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
|
||||
|
||||
err = db.QueryRowContext(context.Background(), "PRAGMA journal_mode").Scan(&actualMode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query journal_mode: %v", err)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,12 +3,20 @@ package db
|
|||
import (
|
||||
"context"
|
||||
"encoding"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"gorm.io/gorm/schema"
|
||||
)
|
||||
|
||||
// Sentinel errors for text serialisation.
|
||||
var (
|
||||
ErrTextUnmarshalFailed = errors.New("failed to unmarshal text value")
|
||||
ErrUnsupportedType = errors.New("unsupported type")
|
||||
ErrTextMarshalerOnly = errors.New("only encoding.TextMarshaler is supported")
|
||||
)
|
||||
|
||||
// Got from https://github.com/xdg-go/strum/blob/main/types.go
|
||||
var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]()
|
||||
|
||||
|
|
@ -17,7 +25,7 @@ func isTextUnmarshaler(rv reflect.Value) bool {
|
|||
}
|
||||
|
||||
func maybeInstantiatePtr(rv reflect.Value) {
|
||||
if rv.Kind() == reflect.Ptr && rv.IsNil() {
|
||||
if rv.Kind() == reflect.Pointer && rv.IsNil() {
|
||||
np := reflect.New(rv.Type().Elem())
|
||||
rv.Set(np)
|
||||
}
|
||||
|
|
@ -36,27 +44,30 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
|||
|
||||
// If the field is a pointer, we need to dereference it to get the actual type
|
||||
// so we do not end with a second pointer.
|
||||
if fieldValue.Elem().Kind() == reflect.Ptr {
|
||||
if fieldValue.Elem().Kind() == reflect.Pointer {
|
||||
fieldValue = fieldValue.Elem()
|
||||
}
|
||||
|
||||
if dbValue != nil {
|
||||
var bytes []byte
|
||||
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
case string:
|
||||
bytes = []byte(v)
|
||||
default:
|
||||
return fmt.Errorf("failed to unmarshal text value: %#v", dbValue)
|
||||
return fmt.Errorf("%w: %#v", ErrTextUnmarshalFailed, dbValue)
|
||||
}
|
||||
|
||||
if isTextUnmarshaler(fieldValue) {
|
||||
maybeInstantiatePtr(fieldValue)
|
||||
f := fieldValue.MethodByName("UnmarshalText")
|
||||
args := []reflect.Value{reflect.ValueOf(bytes)}
|
||||
|
||||
ret := f.Call(args)
|
||||
if !ret[0].IsNil() {
|
||||
//nolint:forcetypeassert
|
||||
return decodingError(field.Name, ret[0].Interface().(error))
|
||||
}
|
||||
|
||||
|
|
@ -65,7 +76,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
|||
// If it is not a pointer, we need to assign the value to the
|
||||
// field.
|
||||
dstField := field.ReflectValueOf(ctx, dst)
|
||||
if dstField.Kind() == reflect.Ptr {
|
||||
if dstField.Kind() == reflect.Pointer {
|
||||
dstField.Set(fieldValue)
|
||||
} else {
|
||||
dstField.Set(fieldValue.Elem())
|
||||
|
|
@ -73,7 +84,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
|||
|
||||
return nil
|
||||
} else {
|
||||
return fmt.Errorf("unsupported type: %T", fieldValue.Interface())
|
||||
return fmt.Errorf("%w: %T", ErrUnsupportedType, fieldValue.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -86,9 +97,10 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
|||
// If the value is nil, we return nil, however, go nil values are not
|
||||
// always comparable, particularly when reflection is involved:
|
||||
// https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8
|
||||
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) {
|
||||
return nil, nil
|
||||
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) {
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
b, err := v.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -96,6 +108,6 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
|||
|
||||
return string(b), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %t", v)
|
||||
return nil, fmt.Errorf("%w, got %T", ErrTextMarshalerOnly, v)
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ var (
|
|||
ErrUserExists = errors.New("user already exists")
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
ErrUserStillHasNodes = errors.New("user not empty: node(s) found")
|
||||
ErrTooManyWhereArgs = errors.New("expect 0 or 1 where User structs")
|
||||
ErrMultipleUsers = errors.New("expected exactly one user")
|
||||
)
|
||||
|
||||
func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
||||
|
|
@ -26,7 +28,8 @@ func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) {
|
|||
// CreateUser creates a new User. Returns error if could not be created
|
||||
// or another user already exists.
|
||||
func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) {
|
||||
if err := util.ValidateHostname(user.Name); err != nil {
|
||||
err := util.ValidateHostname(user.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
|
|
@ -88,10 +91,13 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
|
|||
// not exist or if another User exists with the new name.
|
||||
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||
var err error
|
||||
|
||||
oldUser, err := GetUserByID(tx, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//nolint:noinlineerr
|
||||
if err = util.ValidateHostname(newName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -151,7 +157,7 @@ func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) {
|
|||
// ListUsers gets all the existing users.
|
||||
func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
||||
if len(where) > 1 {
|
||||
return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where))
|
||||
return nil, fmt.Errorf("%w, got %d", ErrTooManyWhereArgs, len(where))
|
||||
}
|
||||
|
||||
var user *types.User
|
||||
|
|
@ -160,7 +166,9 @@ func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) {
|
|||
}
|
||||
|
||||
users := []types.User{}
|
||||
if err := tx.Where(user).Find(&users).Error; err != nil {
|
||||
|
||||
err := tx.Where(user).Find(&users).Error
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
|
@ -180,7 +188,7 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
|
|||
}
|
||||
|
||||
if len(users) != 1 {
|
||||
return nil, fmt.Errorf("expected exactly one user, found %d", len(users))
|
||||
return nil, fmt.Errorf("%w, found %d", ErrMultipleUsers, len(users))
|
||||
}
|
||||
|
||||
return &users[0], nil
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestCreateAndDestroyUser(t *testing.T) {
|
||||
|
|
@ -71,6 +70,7 @@ func TestDestroyUserErrors(t *testing.T) {
|
|||
user, err := db.CreateUser(types.User{Name: "test"})
|
||||
require.NoError(t, err)
|
||||
|
||||
//nolint:staticcheck
|
||||
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -79,7 +79,7 @@ func TestDestroyUserErrors(t *testing.T) {
|
|||
Hostname: "testnode",
|
||||
UserID: &user.ID,
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
AuthKeyID: ptr.To(pak.ID),
|
||||
AuthKeyID: &pak.ID,
|
||||
}
|
||||
trx := db.DB.Save(&node)
|
||||
require.NoError(t, trx.Error)
|
||||
|
|
|
|||
|
|
@ -25,34 +25,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
overview := h.state.DebugOverviewJSON()
|
||||
|
||||
overviewJSON, err := json.MarshalIndent(overview, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(overviewJSON)
|
||||
_, _ = w.Write(overviewJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
overview := h.state.DebugOverview()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(overview))
|
||||
_, _ = w.Write([]byte(overview))
|
||||
}
|
||||
}))
|
||||
|
||||
// Configuration endpoint
|
||||
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
config := h.state.DebugConfig()
|
||||
|
||||
configJSON, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(configJSON)
|
||||
_, _ = w.Write(configJSON)
|
||||
}))
|
||||
|
||||
// Policy endpoint
|
||||
|
|
@ -70,8 +75,9 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
} else {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(policy))
|
||||
_, _ = w.Write([]byte(policy))
|
||||
}))
|
||||
|
||||
// Filter rules endpoint
|
||||
|
|
@ -81,27 +87,31 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
filterJSON, err := json.MarshalIndent(filter, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(filterJSON)
|
||||
_, _ = w.Write(filterJSON)
|
||||
}))
|
||||
|
||||
// SSH policies endpoint
|
||||
debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sshPolicies := h.state.DebugSSHPolicies()
|
||||
|
||||
sshJSON, err := json.MarshalIndent(sshPolicies, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(sshJSON)
|
||||
_, _ = w.Write(sshJSON)
|
||||
}))
|
||||
|
||||
// DERP map endpoint
|
||||
|
|
@ -112,20 +122,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
derpInfo := h.state.DebugDERPJSON()
|
||||
|
||||
derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(derpJSON)
|
||||
_, _ = w.Write(derpJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
derpInfo := h.state.DebugDERPMap()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(derpInfo))
|
||||
_, _ = w.Write([]byte(derpInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
|
|
@ -137,34 +150,39 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
nodeStoreNodes := h.state.DebugNodeStoreJSON()
|
||||
|
||||
nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(nodeStoreJSON)
|
||||
_, _ = w.Write(nodeStoreJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
nodeStoreInfo := h.state.DebugNodeStore()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(nodeStoreInfo))
|
||||
_, _ = w.Write([]byte(nodeStoreInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
// Registration cache endpoint
|
||||
debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cacheInfo := h.state.DebugRegistrationCache()
|
||||
|
||||
cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(cacheJSON)
|
||||
_, _ = w.Write(cacheJSON)
|
||||
}))
|
||||
|
||||
// Routes endpoint
|
||||
|
|
@ -175,20 +193,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
routes := h.state.DebugRoutes()
|
||||
|
||||
routesJSON, err := json.MarshalIndent(routes, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(routesJSON)
|
||||
_, _ = w.Write(routesJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
routes := h.state.DebugRoutesString()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(routes))
|
||||
_, _ = w.Write([]byte(routes))
|
||||
}
|
||||
}))
|
||||
|
||||
|
|
@ -200,20 +221,23 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
policyManagerInfo := h.state.DebugPolicyManagerJSON()
|
||||
|
||||
policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(policyManagerJSON)
|
||||
_, _ = w.Write(policyManagerJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
policyManagerInfo := h.state.DebugPolicyManager()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(policyManagerInfo))
|
||||
_, _ = w.Write([]byte(policyManagerInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
|
|
@ -226,7 +250,8 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if res == nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
|
||||
_, _ = w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -235,9 +260,10 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(resJSON)
|
||||
_, _ = w.Write(resJSON)
|
||||
}))
|
||||
|
||||
// Batcher endpoint
|
||||
|
|
@ -257,14 +283,14 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(batcherJSON)
|
||||
_, _ = w.Write(batcherJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
batcherInfo := h.debugBatcher()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(batcherInfo))
|
||||
_, _ = w.Write([]byte(batcherInfo))
|
||||
}
|
||||
}))
|
||||
|
||||
|
|
@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string {
|
|||
activeConnections: info.ActiveConnections,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if info.Connected {
|
||||
connectedCount++
|
||||
}
|
||||
|
|
@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string {
|
|||
activeConnections: 0,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if connected {
|
||||
connectedCount++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
|
|||
ActiveConnections: 0,
|
||||
}
|
||||
info.TotalNodes++
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -134,6 +134,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) {
|
|||
for id := range dm.Regions {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
|
||||
for _, id := range ids {
|
||||
|
|
@ -160,16 +161,20 @@ func derpRandom() *rand.Rand {
|
|||
|
||||
derpRandomOnce.Do(func() {
|
||||
seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String())
|
||||
//nolint:gosec
|
||||
rnd := rand.New(rand.NewSource(0))
|
||||
//nolint:gosec
|
||||
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table)))
|
||||
derpRandomInst = rnd
|
||||
})
|
||||
|
||||
return derpRandomInst
|
||||
}
|
||||
|
||||
func resetDerpRandomForTesting() {
|
||||
derpRandomMu.Lock()
|
||||
defer derpRandomMu.Unlock()
|
||||
|
||||
derpRandomOnce = sync.Once{}
|
||||
derpRandomInst = nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
viper.Set("dns.base_domain", tt.baseDomain)
|
||||
|
||||
defer viper.Reset()
|
||||
|
||||
resetDerpRandomForTesting()
|
||||
|
||||
testMap := tt.derpMap.View().AsStruct()
|
||||
|
|
|
|||
|
|
@ -74,9 +74,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
|||
if err != nil {
|
||||
return tailcfg.DERPRegion{}, err
|
||||
}
|
||||
var host string
|
||||
var port int
|
||||
var portStr string
|
||||
|
||||
var (
|
||||
host string
|
||||
port int
|
||||
portStr string
|
||||
)
|
||||
|
||||
// Extract hostname and port from URL
|
||||
host, portStr, err = net.SplitHostPort(serverURL.Host)
|
||||
|
|
@ -97,12 +100,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
|||
|
||||
// If debug flag is set, resolve hostname to IP address
|
||||
if debugUseDERPIP {
|
||||
ips, err := net.LookupIP(host)
|
||||
addrs, err := net.DefaultResolver.LookupIPAddr(context.Background(), host)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host)
|
||||
} else if len(ips) > 0 {
|
||||
} else if len(addrs) > 0 {
|
||||
// Use the first IP address
|
||||
ipStr := ips[0].String()
|
||||
ipStr := addrs[0].IP.String()
|
||||
log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr)
|
||||
host = ipStr
|
||||
}
|
||||
|
|
@ -205,6 +208,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
|
|||
return
|
||||
}
|
||||
defer websocketConn.Close(websocket.StatusInternalError, "closing")
|
||||
|
||||
if websocketConn.Subprotocol() != "derp" {
|
||||
websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
|
||||
|
||||
|
|
@ -309,6 +313,8 @@ func DERPBootstrapDNSHandler(
|
|||
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
|
||||
defer cancel()
|
||||
var resolver net.Resolver
|
||||
|
||||
//nolint:unqueryvet
|
||||
for _, region := range derpMap.Regions().All() {
|
||||
for _, node := range region.Nodes().All() { // we don't care if we override some nodes
|
||||
addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName())
|
||||
|
|
@ -320,6 +326,7 @@ func DERPBootstrapDNSHandler(
|
|||
|
||||
continue
|
||||
}
|
||||
|
||||
dnsEntries[node.HostName()] = addrs
|
||||
}
|
||||
}
|
||||
|
|
@ -411,7 +418,9 @@ type DERPVerifyTransport struct {
|
|||
|
||||
func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
buf := new(bytes.Buffer)
|
||||
if err := t.handleVerifyRequest(req, buf); err != nil {
|
||||
|
||||
err := t.handleVerifyRequest(req, buf)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msg("Failed to handle client verify request: ")
|
||||
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
|
|
@ -15,6 +16,9 @@ import (
|
|||
"tailscale.com/util/set"
|
||||
)
|
||||
|
||||
// ErrPathIsDirectory is returned when a path is a directory instead of a file.
|
||||
var ErrPathIsDirectory = errors.New("path is a directory, only file is supported")
|
||||
|
||||
type ExtraRecordsMan struct {
|
||||
mu sync.RWMutex
|
||||
records set.Set[tailcfg.DNSRecord]
|
||||
|
|
@ -39,7 +43,7 @@ func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) {
|
|||
}
|
||||
|
||||
if fi.IsDir() {
|
||||
return nil, fmt.Errorf("path is a directory, only file is supported: %s", path)
|
||||
return nil, fmt.Errorf("%w: %s", ErrPathIsDirectory, path)
|
||||
}
|
||||
|
||||
records, hash, err := readExtraRecordsFromPath(path)
|
||||
|
|
@ -85,18 +89,22 @@ func (e *ExtraRecordsMan) Run() {
|
|||
log.Error().Caller().Msgf("file watcher event channel closing")
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Op {
|
||||
case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
|
||||
log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
|
||||
|
||||
if event.Name != e.path {
|
||||
continue
|
||||
}
|
||||
|
||||
e.updateRecords()
|
||||
|
||||
// If a file is removed or renamed, fsnotify will lose track of it
|
||||
// and not watch it. We will therefore attempt to re-add it with a backoff.
|
||||
case fsnotify.Remove, fsnotify.Rename:
|
||||
_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
|
||||
//nolint:noinlineerr
|
||||
if _, err := os.Stat(e.path); err != nil {
|
||||
return struct{}{}, err
|
||||
}
|
||||
|
|
@ -123,6 +131,7 @@ func (e *ExtraRecordsMan) Run() {
|
|||
log.Error().Caller().Msgf("file watcher error channel closing")
|
||||
return
|
||||
}
|
||||
|
||||
log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -165,6 +174,7 @@ func (e *ExtraRecordsMan) updateRecords() {
|
|||
e.hashes[e.path] = newHash
|
||||
|
||||
log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len())
|
||||
|
||||
e.updateCh <- e.records.Slice()
|
||||
}
|
||||
|
||||
|
|
@ -183,6 +193,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error
|
|||
}
|
||||
|
||||
var records []tailcfg.DNSRecord
|
||||
|
||||
err = json.Unmarshal(b, &records)
|
||||
if err != nil {
|
||||
return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err)
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
package hscontrol
|
||||
|
||||
import (
|
||||
"cmp"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
|
@ -11,7 +12,6 @@ import (
|
|||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
|
|
@ -135,8 +135,8 @@ func (api headscaleV1APIServer) ListUsers(
|
|||
response[index] = user.Proto()
|
||||
}
|
||||
|
||||
sort.Slice(response, func(i, j int) bool {
|
||||
return response[i].Id < response[j].Id
|
||||
slices.SortFunc(response, func(a, b *v1.User) int {
|
||||
return cmp.Compare(a.Id, b.Id)
|
||||
})
|
||||
|
||||
return &v1.ListUsersResponse{Users: response}, nil
|
||||
|
|
@ -221,8 +221,8 @@ func (api headscaleV1APIServer) ListPreAuthKeys(
|
|||
response[index] = key.Proto()
|
||||
}
|
||||
|
||||
sort.Slice(response, func(i, j int) bool {
|
||||
return response[i].Id < response[j].Id
|
||||
slices.SortFunc(response, func(a, b *v1.PreAuthKey) int {
|
||||
return cmp.Compare(a.Id, b.Id)
|
||||
})
|
||||
|
||||
return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil
|
||||
|
|
@ -387,7 +387,7 @@ func (api headscaleV1APIServer) SetApprovedRoutes(
|
|||
newApproved = append(newApproved, prefix)
|
||||
}
|
||||
}
|
||||
tsaddr.SortPrefixes(newApproved)
|
||||
slices.SortFunc(newApproved, netip.Prefix.Compare)
|
||||
newApproved = slices.Compact(newApproved)
|
||||
|
||||
node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved)
|
||||
|
|
@ -535,8 +535,8 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N
|
|||
response[index] = resp
|
||||
}
|
||||
|
||||
sort.Slice(response, func(i, j int) bool {
|
||||
return response[i].Id < response[j].Id
|
||||
slices.SortFunc(response, func(a, b *v1.Node) int {
|
||||
return cmp.Compare(a.Id, b.Id)
|
||||
})
|
||||
|
||||
return response
|
||||
|
|
@ -632,8 +632,8 @@ func (api headscaleV1APIServer) ListApiKeys(
|
|||
response[index] = key.Proto()
|
||||
}
|
||||
|
||||
sort.Slice(response, func(i, j int) bool {
|
||||
return response[i].Id < response[j].Id
|
||||
slices.SortFunc(response, func(a, b *v1.ApiKey) int {
|
||||
return cmp.Compare(a.Id, b.Id)
|
||||
})
|
||||
|
||||
return &v1.ListApiKeysResponse{ApiKeys: response}, nil
|
||||
|
|
|
|||
|
|
@ -36,8 +36,7 @@ const (
|
|||
|
||||
// httpError logs an error and sends an HTTP error response with the given.
|
||||
func httpError(w http.ResponseWriter, err error) {
|
||||
var herr HTTPError
|
||||
if errors.As(err, &herr) {
|
||||
if herr, ok := errors.AsType[HTTPError](err); ok {
|
||||
http.Error(w, herr.Msg, herr.Code)
|
||||
log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg)
|
||||
} else {
|
||||
|
|
@ -56,7 +55,7 @@ type HTTPError struct {
|
|||
func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) }
|
||||
func (e HTTPError) Unwrap() error { return e.Err }
|
||||
|
||||
// Error returns an HTTPError containing the given information.
|
||||
// NewHTTPError returns an HTTPError containing the given information.
|
||||
func NewHTTPError(code int, msg string, err error) HTTPError {
|
||||
return HTTPError{Code: code, Msg: msg, Err: err}
|
||||
}
|
||||
|
|
@ -92,6 +91,7 @@ func (h *Headscale) handleVerifyRequest(
|
|||
}
|
||||
|
||||
var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest
|
||||
//nolint:noinlineerr
|
||||
if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil {
|
||||
return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err))
|
||||
}
|
||||
|
|
@ -155,7 +155,11 @@ func (h *Headscale) KeyHandler(
|
|||
}
|
||||
|
||||
writer.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(writer).Encode(resp)
|
||||
|
||||
err := json.NewEncoder(writer).Encode(resp)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to encode key response")
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
|
@ -180,8 +184,12 @@ func (h *Headscale) HealthHandler(
|
|||
res.Status = "fail"
|
||||
}
|
||||
|
||||
json.NewEncoder(writer).Encode(res)
|
||||
//nolint:noinlineerr
|
||||
if err := json.NewEncoder(writer).Encode(res); err != nil {
|
||||
log.Error().Err(err).Msg("failed to encode health response")
|
||||
}
|
||||
}
|
||||
|
||||
err := h.state.PingDB(req.Context())
|
||||
if err != nil {
|
||||
respond(err)
|
||||
|
|
@ -218,6 +226,7 @@ func (h *Headscale) VersionHandler(
|
|||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
versionInfo := types.GetVersionInfo()
|
||||
|
||||
err := json.NewEncoder(writer).Encode(versionInfo)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
|
@ -267,7 +276,7 @@ func (a *AuthProviderWeb) RegisterHandler(
|
|||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
|
||||
_, _ = writer.Write([]byte(templates.RegisterWeb(registrationId).Render()))
|
||||
}
|
||||
|
||||
func FaviconHandler(writer http.ResponseWriter, req *http.Request) {
|
||||
|
|
|
|||
|
|
@ -15,6 +15,17 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// Sentinel errors for batcher operations.
|
||||
var (
|
||||
ErrInvalidNodeID = errors.New("invalid nodeID")
|
||||
ErrMapperNil = errors.New("mapper is nil")
|
||||
ErrNodeConnectionNil = errors.New("nodeConnection is nil")
|
||||
)
|
||||
|
||||
// workChannelMultiplier is the multiplier for work channel capacity based on worker count.
|
||||
// The size is arbitrarily chosen; the sizing should be revisited.
|
||||
const workChannelMultiplier = 200
|
||||
|
||||
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
|
||||
Namespace: "headscale",
|
||||
Name: "mapresponse_generated_total",
|
||||
|
|
@ -42,8 +53,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
|
|||
workers: workers,
|
||||
tick: time.NewTicker(batchTime),
|
||||
|
||||
// The size of this channel is arbitrary chosen, the sizing should be revisited.
|
||||
workCh: make(chan work, workers*200),
|
||||
workCh: make(chan work, workers*workChannelMultiplier),
|
||||
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
|
||||
connected: xsync.NewMap[types.NodeID, *time.Time](),
|
||||
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
|
||||
|
|
@ -76,20 +86,20 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
|||
version := nc.version()
|
||||
|
||||
if r.IsEmpty() {
|
||||
return nil, nil //nolint:nilnil // Empty response means nothing to send
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
if nodeID == 0 {
|
||||
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
|
||||
return nil, fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
|
||||
}
|
||||
|
||||
if mapper == nil {
|
||||
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
|
||||
return nil, fmt.Errorf("%w for nodeID %d", ErrMapperNil, nodeID)
|
||||
}
|
||||
|
||||
// Handle self-only responses
|
||||
if r.IsSelfOnly() && r.TargetNode != nodeID {
|
||||
return nil, nil //nolint:nilnil // No response needed for other nodes when self-only
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
// Check if this is a self-update (the changed node is the receiving node).
|
||||
|
|
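The hunks above swap inline errors.New/fmt.Errorf messages for package-level sentinels wrapped with %w. A minimal, self-contained sketch of that pattern (the sentinel mirrors ErrInvalidNodeID declared above; validate and its caller are illustrative, not part of the commit):

package main

import (
	"errors"
	"fmt"
)

// Package-level sentinel, matching the style introduced above.
var ErrInvalidNodeID = errors.New("invalid nodeID")

func validate(nodeID uint64) error {
	if nodeID == 0 {
		// %w keeps the sentinel in the wrap chain while adding context.
		return fmt.Errorf("%w: %d", ErrInvalidNodeID, nodeID)
	}

	return nil
}

func main() {
	err := validate(0)
	// errors.Is walks the chain, so callers can still match the sentinel.
	if errors.Is(err, ErrInvalidNodeID) {
		fmt.Println("matched:", err) // matched: invalid nodeID: 0
	}
}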
@ -135,7 +145,7 @@ func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*t
|
|||
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
|
||||
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
|
||||
if nc == nil {
|
||||
return errors.New("nodeConnection is nil")
|
||||
return ErrNodeConnectionNil
|
||||
}
|
||||
|
||||
nodeID := nc.nodeID()
|
||||
@ -2,6 +2,7 @@ package mapper
|
|||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
|
@ -13,10 +14,35 @@ import (
|
|||
"github.com/puzpuzpuz/xsync/v4"
|
||||
"github.com/rs/zerolog/log"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var errConnectionClosed = errors.New("connection channel already closed")
|
||||
// Sentinel errors for lock-free batcher operations.
|
||||
var (
|
||||
errConnectionClosed = errors.New("connection channel already closed")
|
||||
ErrInitialMapTimeout = errors.New("failed to send initial map: timeout")
|
||||
ErrNodeNotFound = errors.New("node not found")
|
||||
ErrBatcherShutdown = errors.New("batcher shutting down")
|
||||
ErrConnectionTimeout = errors.New("connection timeout sending to channel (likely stale connection)")
|
||||
)
|
||||
|
||||
// Batcher configuration constants.
|
||||
const (
|
||||
// initialMapSendTimeout is the timeout for sending the initial map response to a new connection.
|
||||
initialMapSendTimeout = 5 * time.Second
|
||||
|
||||
// offlineNodeCleanupThreshold is how long a node must be offline before it's cleaned up.
|
||||
offlineNodeCleanupThreshold = 15 * time.Minute
|
||||
|
||||
// offlineNodeCleanupInterval is the interval between cleanup runs.
|
||||
offlineNodeCleanupInterval = 5 * time.Minute
|
||||
|
||||
// connectionSendTimeout is the timeout for detecting stale connections.
|
||||
// Kept short to quickly detect Docker containers that are forcefully terminated.
|
||||
connectionSendTimeout = 50 * time.Millisecond
|
||||
|
||||
// connectionIDBytes is the number of random bytes used for connection IDs.
|
||||
connectionIDBytes = 8
|
||||
)
|
||||
|
||||
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
|
||||
type LockFreeBatcher struct {
|
||||
|
|
@ -78,6 +104,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
|||
if err != nil {
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
|
|
@ -86,12 +113,13 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
|||
select {
|
||||
case c <- initialMap:
|
||||
// Success
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
|
||||
case <-time.After(initialMapSendTimeout):
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(ErrInitialMapTimeout).Msg("Initial map send timeout")
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", initialMapSendTimeout).
|
||||
Msg("Initial map send timed out because channel was blocked or receiver not ready")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
|
||||
|
||||
return fmt.Errorf("%w for node %d", ErrInitialMapTimeout, id)
|
||||
}
|
||||
|
||||
// Update connection status
|
||||
|
|
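The AddNode hunk above bounds the initial map send with a select and time.After so a receiver that never drains its channel cannot block the batcher forever. A hedged sketch of the same shape; trySend and sendTimeout are illustrative names, not part of the commit:

package main

import (
	"fmt"
	"time"
)

const sendTimeout = 5 * time.Second // mirrors initialMapSendTimeout above

// trySend gives up after sendTimeout instead of blocking forever on a
// receiver that never drains its channel.
func trySend(ch chan<- string, msg string) error {
	select {
	case ch <- msg:
		return nil
	case <-time.After(sendTimeout):
		return fmt.Errorf("send timed out after %s", sendTimeout)
	}
}

func main() {
	ch := make(chan string, 1)
	fmt.Println(trySend(ch, "initial map")) // buffered slot is free: <nil>
}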
@ -130,13 +158,14 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
|||
log.Debug().Caller().Uint64("node.id", id.Uint64()).
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("Node connection removed but keeping online because other connections remain")
|
||||
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
// No active connections - keep the node entry alive for rapid reconnections
|
||||
// The node will get a fresh full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection")
|
||||
b.connected.Store(id, ptr.To(time.Now()))
|
||||
b.connected.Store(id, new(time.Now()))
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
@ -177,7 +206,7 @@ func (b *LockFreeBatcher) doWork() {
|
|||
}
|
||||
|
||||
// Create a cleanup ticker for removing truly disconnected nodes
|
||||
cleanupTicker := time.NewTicker(5 * time.Minute)
|
||||
cleanupTicker := time.NewTicker(offlineNodeCleanupInterval)
|
||||
defer cleanupTicker.Stop()
|
||||
|
||||
for {
|
||||
|
|
@ -212,10 +241,12 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
|||
// This is used for synchronous map generation.
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
var err error
|
||||
|
||||
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
|
||||
|
||||
result.err = err
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
|
|
@ -229,7 +260,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
|||
nc.updateSentPeers(result.mapResponse)
|
||||
}
|
||||
} else {
|
||||
result.err = fmt.Errorf("node %d not found", w.nodeID)
|
||||
result.err = fmt.Errorf("%w: %d", ErrNodeNotFound, w.nodeID)
|
||||
|
||||
b.workErrors.Add(1)
|
||||
log.Error().Err(result.err).
|
||||
|
|
@ -383,14 +414,13 @@ func (b *LockFreeBatcher) processBatchedChanges() {
|
|||
// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks.
|
||||
// TODO(kradalby): reevaluate if we want to keep this.
|
||||
func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
||||
cleanupThreshold := 15 * time.Minute
|
||||
now := time.Now()
|
||||
|
||||
var nodesToCleanup []types.NodeID
|
||||
|
||||
// Find nodes that have been offline for too long
|
||||
b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool {
|
||||
if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold {
|
||||
if disconnectTime != nil && now.Sub(*disconnectTime) > offlineNodeCleanupThreshold {
|
||||
// Double-check the node doesn't have active connections
|
||||
if nodeConn, exists := b.nodes.Load(nodeID); exists {
|
||||
if !nodeConn.hasActiveConnections() {
|
||||
|
|
@ -398,13 +428,14 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
// Clean up the identified nodes
|
||||
for _, nodeID := range nodesToCleanup {
|
||||
log.Info().Uint64("node.id", nodeID.Uint64()).
|
||||
Dur("offline_duration", cleanupThreshold).
|
||||
Dur("offline_duration", offlineNodeCleanupThreshold).
|
||||
Msg("Cleaning up node that has been offline for too long")
|
||||
|
||||
b.nodes.Delete(nodeID)
|
||||
|
|
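cleanupOfflineNodes above collects candidate node IDs while ranging and deletes them afterwards, rather than mutating the map mid-iteration. The sketch below shows the same two-phase shape with the stdlib sync.Map so it stays self-contained (xsync's Range/Delete follow the same contract; the keys and values here are illustrative):

package main

import (
	"fmt"
	"sync"
	"time"
)

const cleanupThreshold = 15 * time.Minute // mirrors offlineNodeCleanupThreshold

func main() {
	// nodeID -> disconnect time; a nil pointer means the node is online.
	var connected sync.Map

	hourAgo := time.Now().Add(-time.Hour)
	connected.Store(uint64(1), &hourAgo)
	connected.Store(uint64(2), (*time.Time)(nil))

	// Phase 1: collect stale IDs while ranging.
	var stale []uint64

	connected.Range(func(k, v any) bool {
		if ts, _ := v.(*time.Time); ts != nil && time.Since(*ts) > cleanupThreshold {
			stale = append(stale, k.(uint64))
		}

		return true
	})

	// Phase 2: delete outside the Range callback.
	for _, id := range stale {
		connected.Delete(id)
	}

	fmt.Println("cleaned up:", stale) // cleaned up: [1]
}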
@ -450,6 +481,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
|||
if nodeConn.hasActiveConnections() {
|
||||
ret.Store(id, true)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
@ -465,6 +497,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
|||
ret.Store(id, false)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
@ -484,7 +517,7 @@ func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Chang
|
|||
case result := <-resultCh:
|
||||
return result.mapResponse, result.err
|
||||
case <-b.done:
|
||||
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
|
||||
return nil, fmt.Errorf("%w: generating map response for node %d", ErrBatcherShutdown, id)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -517,9 +550,10 @@ type multiChannelNodeConn struct {
|
|||
|
||||
// generateConnectionID generates a unique connection identifier.
|
||||
func generateConnectionID() string {
|
||||
bytes := make([]byte, 8)
|
||||
rand.Read(bytes)
|
||||
return fmt.Sprintf("%x", bytes)
|
||||
bytes := make([]byte, connectionIDBytes)
|
||||
_, _ = rand.Read(bytes)
|
||||
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
||||
|
|
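generateConnectionID above now uses hex.EncodeToString and explicitly discards rand.Read's return values. A small sketch showing the two hex spellings are equivalent for a byte slice (the buffer size mirrors connectionIDBytes):

package main

import (
	"crypto/rand"
	"encoding/hex"
	"fmt"
)

func main() {
	buf := make([]byte, 8) // mirrors connectionIDBytes
	_, _ = rand.Read(buf)  // recent Go documents crypto/rand.Read as never failing

	// Both forms yield the same lowercase hex string; EncodeToString skips
	// fmt's reflection-based formatting.
	fmt.Println(hex.EncodeToString(buf) == fmt.Sprintf("%x", buf)) // true
}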
@ -546,11 +580,14 @@ func (mc *multiChannelNodeConn) close() {
|
|||
// addConnection adds a new connection.
|
||||
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||
mutexWaitStart := time.Now()
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||
|
||||
mc.mutex.Lock()
|
||||
|
||||
mutexWaitDur := time.Since(mutexWaitStart)
|
||||
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
mc.connections = append(mc.connections, entry)
|
||||
|
|
@ -572,9 +609,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
|
|||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("Successfully removed connection")
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
@ -608,6 +647,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||
// This is not an error - the node will receive a full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
|
||||
return nil // Return success instead of error
|
||||
}
|
||||
|
||||
|
|
@ -616,7 +656,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||
Msg("send: broadcasting to all connections")
|
||||
|
||||
var lastErr error
|
||||
|
||||
successCount := 0
|
||||
|
||||
var failedConnections []int // Track failed connections for removal
|
||||
|
||||
// Send to all connections
|
||||
|
|
@ -625,8 +667,10 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: attempting to send to connection")
|
||||
|
||||
if err := conn.send(data); err != nil {
|
||||
err := conn.send(data)
|
||||
if err != nil {
|
||||
lastErr = err
|
||||
|
||||
failedConnections = append(failedConnections, i)
|
||||
log.Warn().Err(err).
|
||||
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
|
|
@ -634,6 +678,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||
Msg("send: connection send failed")
|
||||
} else {
|
||||
successCount++
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: successfully sent to connection")
|
||||
|
|
@ -685,10 +730,10 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
|
|||
// Update last used timestamp on successful send
|
||||
entry.lastUsed.Store(time.Now().Unix())
|
||||
return nil
|
||||
case <-time.After(50 * time.Millisecond):
|
||||
case <-time.After(connectionSendTimeout):
|
||||
// Connection is likely stale - client isn't reading from channel
|
||||
// This catches the case where Docker containers are killed but channels remain open
|
||||
return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id)
|
||||
return fmt.Errorf("%w: connection %s", ErrConnectionTimeout, entry.id)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -798,6 +843,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
|||
Connected: connected,
|
||||
ActiveConnections: activeConnCount,
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
@ -812,6 +858,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
|||
ActiveConnections: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
@ -5,6 +5,7 @@ import (
|
|||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
|
@ -35,6 +36,7 @@ type batcherTestCase struct {
|
|||
// that would normally be sent by poll.go in production.
|
||||
type testBatcherWrapper struct {
|
||||
Batcher
|
||||
|
||||
state *state.State
|
||||
}
|
||||
|
||||
|
|
@ -80,12 +82,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
|
|||
}
|
||||
|
||||
// Finally remove from the real batcher
|
||||
removed := t.Batcher.RemoveNode(id, c)
|
||||
if !removed {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
return t.Batcher.RemoveNode(id, c)
|
||||
}
|
||||
|
||||
// wrapBatcherForTest wraps a batcher with test-specific behavior.
|
||||
|
|
@ -129,15 +126,13 @@ const (
|
|||
SMALL_BUFFER_SIZE = 3
|
||||
TINY_BUFFER_SIZE = 1 // For maximum contention
|
||||
LARGE_BUFFER_SIZE = 200
|
||||
|
||||
reservedResponseHeaderSize = 4
|
||||
)
|
||||
|
||||
// TestData contains all test entities created for a test scenario.
|
||||
type TestData struct {
|
||||
Database *db.HSDatabase
|
||||
Users []*types.User
|
||||
Nodes []node
|
||||
Nodes []*node
|
||||
State *state.State
|
||||
Config *types.Config
|
||||
Batcher Batcher
|
||||
|
|
@ -223,11 +218,11 @@ func setupBatcherWithTestData(
|
|||
// Create test users and nodes in the database
|
||||
users := database.CreateUsersForTest(userCount, "testuser")
|
||||
|
||||
allNodes := make([]node, 0, userCount*nodesPerUser)
|
||||
allNodes := make([]*node, 0, userCount*nodesPerUser)
|
||||
for _, user := range users {
|
||||
dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node")
|
||||
for i := range dbNodes {
|
||||
allNodes = append(allNodes, node{
|
||||
allNodes = append(allNodes, &node{
|
||||
n: dbNodes[i],
|
||||
ch: make(chan *tailcfg.MapResponse, bufferSize),
|
||||
})
|
||||
|
|
@ -241,8 +236,8 @@ func setupBatcherWithTestData(
|
|||
}
|
||||
|
||||
derpMap, err := derp.GetDERPMap(cfg.DERP)
|
||||
assert.NoError(t, err)
|
||||
assert.NotNil(t, derpMap)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, derpMap)
|
||||
|
||||
state.SetDERPMap(derpMap)
|
||||
|
||||
|
|
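The setup hunk above switches the DERP map checks from assert to require. The toy test below, which fails by design, shows why: require aborts the test immediately, so the code after it never dereferences a nil value, while assert records the failure and keeps executing. The package name and struct are placeholders, not headscale types:

package demo

import (
	"testing"

	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

// TestRequireStopsOnNil fails on purpose to contrast the two packages.
func TestRequireStopsOnNil(t *testing.T) {
	var derpMap *struct{ Regions int } // stand-in for the real DERP map

	assert.NotNil(t, derpMap)  // marks the test failed but execution continues
	require.NotNil(t, derpMap) // aborts here, so the dereference below never runs

	_ = derpMap.Regions
}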
@ -318,23 +313,6 @@ func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) {
|
|||
stats.LastUpdate = time.Now()
|
||||
}
|
||||
|
||||
// getStats returns a copy of the statistics for a node.
|
||||
func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats {
|
||||
ut.mu.RLock()
|
||||
defer ut.mu.RUnlock()
|
||||
|
||||
if stats, exists := ut.stats[nodeID]; exists {
|
||||
// Return a copy to avoid race conditions
|
||||
return UpdateStats{
|
||||
TotalUpdates: stats.TotalUpdates,
|
||||
UpdateSizes: append([]int{}, stats.UpdateSizes...),
|
||||
LastUpdate: stats.LastUpdate,
|
||||
}
|
||||
}
|
||||
|
||||
return UpdateStats{}
|
||||
}
|
||||
|
||||
// getAllStats returns a copy of all statistics.
|
||||
func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats {
|
||||
ut.mu.RLock()
|
||||
|
|
@ -344,7 +322,7 @@ func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats {
|
|||
for nodeID, stats := range ut.stats {
|
||||
result[nodeID] = UpdateStats{
|
||||
TotalUpdates: stats.TotalUpdates,
|
||||
UpdateSizes: append([]int{}, stats.UpdateSizes...),
|
||||
UpdateSizes: slices.Clone(stats.UpdateSizes),
|
||||
LastUpdate: stats.LastUpdate,
|
||||
}
|
||||
}
|
||||
|
|
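getAllStats above replaces the append([]int{}, src...) copy idiom with slices.Clone. A quick sketch showing both produce an independent copy:

package main

import (
	"fmt"
	"slices"
)

func main() {
	src := []int{1, 2, 3}

	// The old idiom and its stdlib replacement both return independent copies.
	a := append([]int{}, src...)
	b := slices.Clone(src)

	src[0] = 99
	fmt.Println(a, b) // [1 2 3] [1 2 3]
}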
@ -386,16 +364,14 @@ type UpdateInfo struct {
|
|||
}
|
||||
|
||||
// parseUpdateAndAnalyze parses an update and returns detailed information.
|
||||
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) {
|
||||
info := UpdateInfo{
|
||||
func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) UpdateInfo {
|
||||
return UpdateInfo{
|
||||
PeerCount: len(resp.Peers),
|
||||
PatchCount: len(resp.PeersChangedPatch),
|
||||
IsFull: len(resp.Peers) > 0,
|
||||
IsPatch: len(resp.PeersChangedPatch) > 0,
|
||||
IsDERP: resp.DERPMap != nil,
|
||||
}
|
||||
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// start begins consuming updates from the node's channel and tracking stats.
|
||||
|
|
@ -417,36 +393,36 @@ func (n *node) start() {
|
|||
atomic.AddInt64(&n.updateCount, 1)
|
||||
|
||||
// Parse update and track detailed stats
|
||||
if info, err := parseUpdateAndAnalyze(data); err == nil {
|
||||
// Track update types
|
||||
if info.IsFull {
|
||||
atomic.AddInt64(&n.fullCount, 1)
|
||||
n.lastPeerCount.Store(int64(info.PeerCount))
|
||||
// Update max peers seen using compare-and-swap for thread safety
|
||||
for {
|
||||
current := n.maxPeersCount.Load()
|
||||
if int64(info.PeerCount) <= current {
|
||||
break
|
||||
}
|
||||
info := parseUpdateAndAnalyze(data)
|
||||
|
||||
if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) {
|
||||
break
|
||||
}
|
||||
// Track update types
|
||||
if info.IsFull {
|
||||
atomic.AddInt64(&n.fullCount, 1)
|
||||
n.lastPeerCount.Store(int64(info.PeerCount))
|
||||
// Update max peers seen using compare-and-swap for thread safety
|
||||
for {
|
||||
current := n.maxPeersCount.Load()
|
||||
if int64(info.PeerCount) <= current {
|
||||
break
|
||||
}
|
||||
|
||||
if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if info.IsPatch {
|
||||
atomic.AddInt64(&n.patchCount, 1)
|
||||
// For patches, we track how many patch items using compare-and-swap
|
||||
for {
|
||||
current := n.maxPeersCount.Load()
|
||||
if int64(info.PatchCount) <= current {
|
||||
break
|
||||
}
|
||||
if info.IsPatch {
|
||||
atomic.AddInt64(&n.patchCount, 1)
|
||||
// For patches, we track how many patch items using compare-and-swap
|
||||
for {
|
||||
current := n.maxPeersCount.Load()
|
||||
if int64(info.PatchCount) <= current {
|
||||
break
|
||||
}
|
||||
|
||||
if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) {
|
||||
break
|
||||
}
|
||||
if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -540,7 +516,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
|||
defer cleanup()
|
||||
|
||||
batcher := testData.Batcher
|
||||
testNode := &testData.Nodes[0]
|
||||
testNode := testData.Nodes[0]
|
||||
|
||||
t.Logf("Testing enhanced tracking with node ID %d", testNode.n.ID)
|
||||
|
||||
|
|
@ -548,7 +524,7 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
|
|||
testNode.start()
|
||||
|
||||
// Connect the node to the batcher
|
||||
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Wait for connection to be established
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
|
|
@ -656,8 +632,8 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||
t.Logf("Joining %d nodes as fast as possible...", len(allNodes))
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
node := allNodes[i]
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Issue full update after each join to ensure connectivity
|
||||
batcher.AddWork(change.FullUpdate())
|
||||
|
|
@ -676,8 +652,9 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
connectedCount := 0
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
node := allNodes[i]
|
||||
|
||||
currentMaxPeers := int(node.maxPeersCount.Load())
|
||||
if currentMaxPeers >= expectedPeers {
|
||||
|
|
@ -693,11 +670,12 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
|
||||
|
||||
t.Logf("✅ All nodes achieved full connectivity!")
|
||||
|
||||
totalTime := time.Since(startTime)
|
||||
|
||||
// Disconnect all nodes
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
node := allNodes[i]
|
||||
batcher.RemoveNode(node.n.ID, node.ch)
|
||||
}
|
||||
|
||||
|
|
@ -718,7 +696,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||
nodeDetails := make([]string, 0, min(10, len(allNodes)))
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
node := allNodes[i]
|
||||
stats := node.cleanup()
|
||||
|
||||
totalUpdates += stats.TotalUpdates
|
||||
|
|
@ -824,7 +802,7 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||
tn2 := testData.Nodes[1]
|
||||
|
||||
// Test AddNode with real node ID
|
||||
batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
_ = batcher.AddNode(tn.n.ID, tn.ch, 100)
|
||||
|
||||
if !batcher.IsConnected(tn.n.ID) {
|
||||
t.Error("Node should be connected after AddNode")
|
||||
|
|
@ -845,7 +823,7 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||
drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond)
|
||||
|
||||
// Add the second node and verify update message
|
||||
batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
_ = batcher.AddNode(tn2.n.ID, tn2.ch, 100)
|
||||
assert.True(t, batcher.IsConnected(tn2.n.ID))
|
||||
|
||||
// First node should get an update that second node has connected.
|
||||
|
|
@ -911,7 +889,7 @@ func TestBatcherBasicOperations(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) {
|
||||
func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, _ string, timeout time.Duration) {
|
||||
count := 0
|
||||
|
||||
timer := time.NewTimer(timeout)
|
||||
|
|
@ -1050,7 +1028,7 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
|||
testNodes := testData.Nodes
|
||||
|
||||
ch := make(chan *tailcfg.MapResponse, 10)
|
||||
batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Track update content for validation
|
||||
var receivedUpdates []*tailcfg.MapResponse
|
||||
|
|
@ -1130,6 +1108,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
|
|||
// The test verifies that channels are closed synchronously and deterministically
|
||||
// even when real node updates are being processed, ensuring no race conditions
|
||||
// occur during channel replacement with actual workload.
|
||||
//
|
||||
|
||||
func XTestBatcherChannelClosingRace(t *testing.T) {
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
|
|
@ -1154,7 +1134,7 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||
ch1 := make(chan *tailcfg.MapResponse, 1)
|
||||
|
||||
wg.Go(func() {
|
||||
batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100))
|
||||
})
|
||||
|
||||
// Add real work during connection chaos
|
||||
|
|
@ -1167,7 +1147,8 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
|
|||
|
||||
wg.Go(func() {
|
||||
runtime.Gosched() // Yield to introduce timing variability
|
||||
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
|
||||
_ = batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
|
||||
})
|
||||
|
||||
// Remove second connection
|
||||
|
|
@ -1258,7 +1239,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||
ch := make(chan *tailcfg.MapResponse, 5)
|
||||
|
||||
// Add node and immediately queue real work
|
||||
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
batcher.AddWork(change.DERPMap())
|
||||
|
||||
// Consumer goroutine to validate data and detect channel issues
|
||||
|
|
@ -1308,6 +1289,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||
for range i % 3 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
batcher.RemoveNode(testNode.n.ID, ch)
|
||||
|
||||
// Yield to allow workers to process and close channels
|
||||
|
|
@ -1350,6 +1332,8 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||
// real node data. The test validates that stable clients continue to function
|
||||
// normally and receive proper updates despite the connection churn from other clients,
|
||||
// ensuring system stability under concurrent load.
|
||||
//
|
||||
//nolint:gocyclo
|
||||
func TestBatcherConcurrentClients(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping concurrent client test in short mode")
|
||||
|
|
@ -1380,7 +1364,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
for _, node := range stableNodes {
|
||||
ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE)
|
||||
stableChannels[node.n.ID] = ch
|
||||
batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Monitor updates for each stable client
|
||||
go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) {
|
||||
|
|
@ -1391,6 +1375,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, reason := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
|
|
@ -1448,10 +1433,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
churningChannels[nodeID] = ch
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
// Consume updates to prevent blocking
|
||||
go func() {
|
||||
|
|
@ -1462,6 +1449,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, _ := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
|
|
@ -1494,6 +1482,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
for range i % 5 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
ch, exists := churningChannels[nodeID]
|
||||
|
|
@ -1623,6 +1612,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
// It validates that the system remains stable with no deadlocks, panics, or
|
||||
// missed updates under sustained high load. The test uses real node data to
|
||||
// generate authentic update scenarios and tracks comprehensive statistics.
|
||||
//
|
||||
//nolint:gocyclo,thelper
|
||||
func XTestBatcherScalability(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping scalability test in short mode")
|
||||
|
|
@ -1651,6 +1642,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||
description string
|
||||
}
|
||||
|
||||
//nolint:prealloc
|
||||
var testCases []testCase
|
||||
|
||||
// Generate all combinations of the test matrix
|
||||
|
|
@ -1761,8 +1753,9 @@ func XTestBatcherScalability(t *testing.T) {
|
|||
var connectedNodesMutex sync.RWMutex
|
||||
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
node := testNodes[i]
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[node.n.ID] = true
|
||||
|
|
@ -1878,6 +1871,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||
channel,
|
||||
tailcfg.CapabilityVersion(100),
|
||||
)
|
||||
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = true
|
||||
|
|
@ -1991,7 +1985,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||
|
||||
// Now disconnect all nodes from batcher to stop new updates
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
node := testNodes[i]
|
||||
batcher.RemoveNode(node.n.ID, node.ch)
|
||||
}
|
||||
|
||||
|
|
@ -2010,7 +2004,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||
nodeStatsReport := make([]string, 0, len(testNodes))
|
||||
|
||||
for i := range testNodes {
|
||||
node := &testNodes[i]
|
||||
node := testNodes[i]
|
||||
stats := node.cleanup()
|
||||
totalUpdates += stats.TotalUpdates
|
||||
totalPatches += stats.PatchUpdates
|
||||
|
|
@ -2139,7 +2133,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
|
|||
|
||||
// Connect nodes one at a time and wait for each to be connected
|
||||
for i, node := range allNodes {
|
||||
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
_ = batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
|
||||
|
||||
// Wait for node to be connected
|
||||
|
|
@ -2286,6 +2280,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
|
||||
// Phase 1: Connect all nodes initially
|
||||
t.Logf("Phase 1: Connecting all nodes...")
|
||||
|
||||
for i, node := range allNodes {
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
|
|
@ -2302,6 +2297,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
|
||||
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
||||
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
||||
|
||||
for i, node := range allNodes {
|
||||
removed := batcher.RemoveNode(node.n.ID, node.ch)
|
||||
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
||||
|
|
@ -2309,9 +2305,11 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
|
||||
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
||||
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
||||
|
||||
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
||||
for i, node := range allNodes {
|
||||
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to reconnect node %d: %v", i, err)
|
||||
|
|
@ -2342,11 +2340,13 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("Node %d missing from debug info entirely", i)
|
||||
}
|
||||
|
||||
|
|
@ -2381,6 +2381,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
case update := <-newChannels[i]:
|
||||
if update != nil {
|
||||
receivedCount++
|
||||
|
||||
t.Logf("Node %d received update successfully", i)
|
||||
}
|
||||
case <-timeout:
|
||||
|
|
@ -2399,6 +2400,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
//nolint:gocyclo
|
||||
func TestBatcherMultiConnection(t *testing.T) {
|
||||
for _, batcherFunc := range allBatcherFunctions {
|
||||
t.Run(batcherFunc.name, func(t *testing.T) {
|
||||
|
|
@ -2413,6 +2415,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
// Phase 1: Connect first node with initial connection
|
||||
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||
|
||||
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node1: %v", err)
|
||||
|
|
@ -2432,7 +2435,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||
|
||||
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add second connection for node1: %v", err)
|
||||
|
|
@ -2443,7 +2448,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
// Phase 3: Add third connection for node1
|
||||
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||
|
||||
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add third connection for node1: %v", err)
|
||||
|
|
@ -2454,6 +2461,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
// Phase 4: Verify debug status shows correct connection count
|
||||
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
|
|
@ -2461,6 +2469,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
t.Logf("Node1 debug info: %+v", info)
|
||||
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 3 {
|
||||
|
|
@ -2469,6 +2478,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||
}
|
||||
}
|
||||
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||
}
|
||||
|
|
@ -2651,9 +2661,9 @@ func TestNodeDeletedWhileChangesPending(t *testing.T) {
|
|||
|
||||
batcher := testData.Batcher
|
||||
st := testData.State
|
||||
node1 := &testData.Nodes[0]
|
||||
node2 := &testData.Nodes[1]
|
||||
node3 := &testData.Nodes[2]
|
||||
node1 := testData.Nodes[0]
|
||||
node2 := testData.Nodes[1]
|
||||
node3 := testData.Nodes[2]
|
||||
|
||||
t.Logf("Testing issue #2924: Node1=%d, Node2=%d, Node3=%d",
|
||||
node1.n.ID, node2.n.ID, node3.n.ID)
|
||||
@ -1,9 +1,9 @@
|
|||
package mapper
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"cmp"
|
||||
"net/netip"
|
||||
"sort"
|
||||
"slices"
|
||||
"time"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
|
|
@ -36,6 +36,7 @@ const (
|
|||
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
|
||||
return &MapResponseBuilder{
|
||||
resp: &tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
|
|
@ -69,7 +70,7 @@ func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVers
|
|||
func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
|
||||
nv, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFound)
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -123,6 +124,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
|||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -130,7 +132,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
|||
func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFound)
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -149,7 +151,7 @@ func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder {
|
|||
func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFound)
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -162,7 +164,7 @@ func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder {
|
|||
func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFound)
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -175,7 +177,7 @@ func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView])
|
|||
func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
b.addError(errors.New("node not found"))
|
||||
b.addError(ErrNodeNotFound)
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -229,7 +231,7 @@ func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView])
|
|||
func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) {
|
||||
node, ok := b.mapper.state.GetNodeByID(b.nodeID)
|
||||
if !ok {
|
||||
return nil, errors.New("node not found")
|
||||
return nil, ErrNodeNotFound
|
||||
}
|
||||
|
||||
// Get unreduced matchers for peer relationship determination.
|
||||
|
|
@ -261,8 +263,8 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) (
|
|||
}
|
||||
|
||||
// Peers is always returned sorted by Node.ID.
|
||||
sort.SliceStable(tailPeers, func(x, y int) bool {
|
||||
return tailPeers[x].ID < tailPeers[y].ID
|
||||
slices.SortStableFunc(tailPeers, func(a, b *tailcfg.Node) int {
|
||||
return cmp.Compare(a.ID, b.ID)
|
||||
})
|
||||
|
||||
return tailPeers, nil
|
||||
|
|
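buildTailPeers above moves from sort.SliceStable with a less-func to slices.SortStableFunc with cmp.Compare. A minimal sketch of the same migration on a toy type (peer is illustrative):

package main

import (
	"cmp"
	"fmt"
	"slices"
)

type peer struct{ ID int }

func main() {
	peers := []peer{{3}, {1}, {2}}

	// Stable sort by ID: same ordering as the sort.SliceStable version,
	// but the comparator is a func(a, b T) int built from cmp.Compare.
	slices.SortStableFunc(peers, func(a, b peer) int {
		return cmp.Compare(a.ID, b.ID)
	})

	fmt.Println(peers) // [{1} {2} {3}]
}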
@ -276,20 +278,23 @@ func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange)
|
|||
|
||||
// WithPeersRemoved adds removed peer IDs.
|
||||
func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder {
|
||||
//nolint:prealloc
|
||||
var tailscaleIDs []tailcfg.NodeID
|
||||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Build finalizes the response and returns marshaled bytes
|
||||
// Build finalizes the response and returns marshaled bytes.
|
||||
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
|
||||
if len(b.errs) > 0 {
|
||||
return nil, multierr.New(b.errs...)
|
||||
}
|
||||
|
||||
if debugDumpMapResponsePath != "" {
|
||||
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
|
||||
}
|
||||
@ -340,7 +340,7 @@ func TestMapResponseBuilder_MultipleErrors(t *testing.T) {
|
|||
// Build should return a multierr
|
||||
data, err := result.Build()
|
||||
assert.Nil(t, data)
|
||||
assert.Error(t, err)
|
||||
require.Error(t, err)
|
||||
|
||||
// The error should contain information about multiple errors
|
||||
assert.Contains(t, err.Error(), "multiple errors")
|
||||
@ -60,7 +60,6 @@ func newMapper(
|
|||
state *state.State,
|
||||
) *mapper {
|
||||
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
|
||||
return &mapper{
|
||||
state: state,
|
||||
cfg: cfg,
|
||||
|
|
@ -80,6 +79,7 @@ func generateUserProfiles(
|
|||
userID := user.Model().ID
|
||||
userMap[userID] = &user
|
||||
ids = append(ids, userID)
|
||||
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.Owner()
|
||||
peerUserID := peerUser.Model().ID
|
||||
|
|
@ -90,6 +90,7 @@ func generateUserProfiles(
|
|||
slices.Sort(ids)
|
||||
ids = slices.Compact(ids)
|
||||
var profiles []tailcfg.UserProfile
|
||||
|
||||
for _, id := range ids {
|
||||
if userMap[id] != nil {
|
||||
profiles = append(profiles, userMap[id].TailscaleUserProfile())
|
||||
|
|
@ -138,29 +139,6 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) {
|
|||
}
|
||||
}
|
||||
|
||||
// fullMapResponse returns a MapResponse for the given node.
|
||||
func (m *mapper) fullMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
peers := m.state.ListPeers(nodeID)
|
||||
|
||||
return m.NewMapResponseBuilder(nodeID).
|
||||
WithDebugType(fullResponseDebug).
|
||||
WithCapabilityVersion(capVer).
|
||||
WithSelfNode().
|
||||
WithDERPMap().
|
||||
WithDomain().
|
||||
WithCollectServicesDisabled().
|
||||
WithDebugConfig().
|
||||
WithSSHPolicy().
|
||||
WithDNSConfig().
|
||||
WithUserProfiles(peers).
|
||||
WithPacketFilters().
|
||||
WithPeers(peers).
|
||||
Build()
|
||||
}
|
||||
|
||||
func (m *mapper) selfMapResponse(
|
||||
nodeID types.NodeID,
|
||||
capVer tailcfg.CapabilityVersion,
|
||||
|
|
@ -214,7 +192,7 @@ func (m *mapper) policyChangeResponse(
|
|||
// Convert tailcfg.NodeID to types.NodeID for WithPeersRemoved
|
||||
removedIDs := make([]types.NodeID, len(removedPeers))
|
||||
for i, id := range removedPeers {
|
||||
removedIDs[i] = types.NodeID(id) //nolint:gosec // NodeID types are equivalent
|
||||
removedIDs[i] = types.NodeID(id) //nolint:gosec
|
||||
}
|
||||
|
||||
builder.WithPeersRemoved(removedIDs...)
|
||||
|
|
@ -237,7 +215,7 @@ func (m *mapper) buildFromChange(
|
|||
resp *change.Change,
|
||||
) (*tailcfg.MapResponse, error) {
|
||||
if resp.IsEmpty() {
|
||||
return nil, nil //nolint:nilnil // Empty response means nothing to send, not an error
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
// If this is a self-update (the changed node is the receiving node),
|
||||
|
|
@ -306,6 +284,7 @@ func writeDebugMapResponse(
|
|||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
|
||||
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
@ -319,6 +298,7 @@ func writeDebugMapResponse(
|
|||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
@ -327,7 +307,7 @@ func writeDebugMapResponse(
|
|||
|
||||
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
|
||||
if debugDumpMapResponsePath == "" {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
return ReadMapResponsesFromDirectory(debugDumpMapResponsePath)
|
||||
|
|
@ -375,6 +355,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
|||
}
|
||||
|
||||
var resp tailcfg.MapResponse
|
||||
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name())
|
||||
@ -3,18 +3,13 @@ package mapper
|
|||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/google/go-cmp/cmp/cmpopts"
|
||||
"github.com/juanfont/headscale/hscontrol/policy"
|
||||
"github.com/juanfont/headscale/hscontrol/policy/matcher"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/dnstype"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var iap = func(ipStr string) *netip.Addr {
|
||||
|
|
@ -51,7 +46,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
|||
mach := func(hostname, username string, userid uint) *types.Node {
|
||||
return &types.Node{
|
||||
Hostname: hostname,
|
||||
UserID: ptr.To(userid),
|
||||
UserID: new(userid),
|
||||
User: &types.User{
|
||||
Name: username,
|
||||
},
|
||||
|
|
@ -82,86 +77,6 @@ func TestDNSConfigMapResponse(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// mockState is a mock implementation that provides the required methods.
|
||||
type mockState struct {
|
||||
polMan policy.PolicyManager
|
||||
derpMap *tailcfg.DERPMap
|
||||
primary *routes.PrimaryRoutes
|
||||
nodes types.Nodes
|
||||
peers types.Nodes
|
||||
}
|
||||
|
||||
func (m *mockState) DERPMap() *tailcfg.DERPMap {
|
||||
return m.derpMap
|
||||
}
|
||||
|
||||
func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
||||
if m.polMan == nil {
|
||||
return tailcfg.FilterAllowAll, nil
|
||||
}
|
||||
return m.polMan.Filter()
|
||||
}
|
||||
|
||||
func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
||||
if m.polMan == nil {
|
||||
return nil, nil
|
||||
}
|
||||
return m.polMan.SSHPolicy(node)
|
||||
}
|
||||
|
||||
func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
|
||||
if m.polMan == nil {
|
||||
return false
|
||||
}
|
||||
return m.polMan.NodeCanHaveTag(node, tag)
|
||||
}
|
||||
|
||||
func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
|
||||
if m.primary == nil {
|
||||
return nil
|
||||
}
|
||||
return m.primary.PrimaryRoutes(nodeID)
|
||||
}
|
||||
|
||||
func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) {
|
||||
if len(peerIDs) > 0 {
|
||||
// Filter peers by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
if slices.Contains(peerIDs, peer.ID) {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
// Return all peers except the node itself
|
||||
var filtered types.Nodes
|
||||
for _, peer := range m.peers {
|
||||
if peer.ID != nodeID {
|
||||
filtered = append(filtered, peer)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
||||
if len(nodeIDs) > 0 {
|
||||
// Filter nodes by the provided IDs
|
||||
var filtered types.Nodes
|
||||
for _, node := range m.nodes {
|
||||
if slices.Contains(nodeIDs, node.ID) {
|
||||
filtered = append(filtered, node)
|
||||
}
|
||||
}
|
||||
|
||||
return filtered, nil
|
||||
}
|
||||
|
||||
return m.nodes, nil
|
||||
}
|
||||
|
||||
func Test_fullMapResponse(t *testing.T) {
|
||||
t.Skip("Test needs to be refactored for new state-based architecture")
|
||||
// TODO: Refactor this test to work with the new state-based mapper
|
||||
@ -13,7 +13,6 @@ import (
|
|||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestTailNode(t *testing.T) {
|
||||
|
|
@ -95,7 +94,7 @@ func TestTailNode(t *testing.T) {
|
|||
IPv4: iap("100.64.0.1"),
|
||||
Hostname: "mini",
|
||||
GivenName: "mini",
|
||||
UserID: ptr.To(uint(0)),
|
||||
UserID: new(uint(0)),
|
||||
User: &types.User{
|
||||
Name: "mini",
|
||||
},
|
||||
|
|
@ -136,10 +135,10 @@ func TestTailNode(t *testing.T) {
|
|||
),
|
||||
Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")},
|
||||
AllowedIPs: []netip.Prefix{
|
||||
tsaddr.AllIPv4(),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
netip.MustParsePrefix("100.64.0.1/32"),
|
||||
tsaddr.AllIPv6(),
|
||||
tsaddr.AllIPv4(), // 0.0.0.0/0
|
||||
netip.MustParsePrefix("100.64.0.1/32"), // lower IPv4
|
||||
netip.MustParsePrefix("192.168.0.0/24"), // higher IPv4
|
||||
tsaddr.AllIPv6(), // ::/0 (IPv6)
|
||||
},
|
||||
PrimaryRoutes: []netip.Prefix{
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
@ -31,6 +31,9 @@ const (
|
|||
earlyPayloadMagic = "\xff\xff\xffTS"
|
||||
)
|
||||
|
||||
// ErrUnsupportedClientVersion is returned when a client version is not supported.
|
||||
var ErrUnsupportedClientVersion = errors.New("unsupported client version")
|
||||
|
||||
type noiseServer struct {
|
||||
headscale *Headscale
|
||||
|
||||
|
|
@ -117,7 +120,7 @@ func (h *Headscale) NoiseUpgradeHandler(
|
|||
}
|
||||
|
||||
func unsupportedClientError(version tailcfg.CapabilityVersion) error {
|
||||
return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version)
|
||||
return fmt.Errorf("%w: %s (%d)", ErrUnsupportedClientVersion, capver.TailscaleVersion(version), version)
|
||||
}
|
||||
|
||||
func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error {
|
||||
|
|
@ -241,13 +244,17 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||
return
|
||||
}
|
||||
|
||||
//nolint:contextcheck
|
||||
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) {
|
||||
var resp *tailcfg.RegisterResponse
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return &tailcfg.RegisterRequest{}, regErr(err)
|
||||
}
|
||||
|
||||
var regReq tailcfg.RegisterRequest
|
||||
//nolint:noinlineerr
|
||||
if err := json.Unmarshal(body, ®Req); err != nil {
|
||||
return ®Req, regErr(err)
|
||||
}
|
||||
|
|
@ -256,11 +263,11 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||
|
||||
resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer())
|
||||
if err != nil {
|
||||
var httpErr HTTPError
|
||||
if errors.As(err, &httpErr) {
|
||||
if httpErr, ok := errors.AsType[HTTPError](err); ok {
|
||||
resp = &tailcfg.RegisterResponse{
|
||||
Error: httpErr.Msg,
|
||||
}
|
||||
|
||||
return ®Req, resp
|
||||
}
|
||||
|
||||
|
|
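The hunk above swaps the errors.As target-variable dance for the generic errors.AsType[HTTPError] helper. The sketch below shows the classic errors.As form, which the generic call abbreviates; the HTTPError fields here are assumptions for illustration, and errors.AsType itself requires a Go toolchain that ships it:

package main

import (
	"errors"
	"fmt"
)

// HTTPError is a stand-in; the real type lives in hscontrol and its exact
// fields are an assumption here.
type HTTPError struct {
	Code int
	Msg  string
}

func (e HTTPError) Error() string { return e.Msg }

func describe(err error) string {
	// Classic form: errors.As fills a target value if the chain contains
	// the concrete type. errors.AsType[HTTPError](err) in the hunk above
	// is the generic shorthand for this same check.
	var httpErr HTTPError
	if errors.As(err, &httpErr) {
		return fmt.Sprintf("http %d: %s", httpErr.Code, httpErr.Msg)
	}

	return "internal error"
}

func main() {
	err := fmt.Errorf("register: %w", HTTPError{Code: 400, Msg: "bad node key"})
	fmt.Println(describe(err)) // http 400: bad node key
}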
@ -278,7 +285,8 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||
writer.Header().Set("Content-Type", "application/json; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
if err := json.NewEncoder(writer).Encode(registerResponse); err != nil {
|
||||
err := json.NewEncoder(writer).Encode(registerResponse)
|
||||
if err != nil {
|
||||
log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse")
|
||||
return
|
||||
}
|
||||
@ -28,6 +28,7 @@ const (
|
|||
defaultOAuthOptionsCount = 3
|
||||
registerCacheExpiration = time.Minute * 15
|
||||
registerCacheCleanup = time.Minute * 20
|
||||
csrfTokenLength = 64
|
||||
)
|
||||
|
||||
var (
|
||||
|
|
@ -68,6 +69,7 @@ func NewAuthProviderOIDC(
|
|||
) (*AuthProviderOIDC, error) {
|
||||
var err error
|
||||
// grab oidc config if it hasn't been already
|
||||
//nolint:contextcheck
|
||||
oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err)
|
||||
|
|
@ -163,6 +165,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
|||
for k, v := range a.cfg.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
|
||||
extras = append(extras, oidc.Nonce(nonce))
|
||||
|
||||
// Cache the registration info
|
||||
|
|
@ -190,6 +193,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
}
|
||||
|
||||
stateCookieName := getCookieName("state", state)
|
||||
|
||||
cookieState, err := req.Cookie(stateCookieName)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
|
||||
|
|
@ -212,17 +216,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
if idToken.Nonce == "" {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
|
||||
return
|
||||
}
|
||||
|
||||
nonceCookieName := getCookieName("nonce", idToken.Nonce)
|
||||
|
||||
nonce, err := req.Cookie(nonceCookieName)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
|
||||
return
|
||||
}
|
||||
|
||||
if idToken.Nonce != nonce.Value {
|
||||
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
|
||||
return
|
||||
|
|
@ -231,6 +238,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
nodeExpiry := a.determineNodeExpiry(idToken.Expiry)
|
||||
|
||||
var claims types.OIDCClaims
|
||||
//nolint:noinlineerr
|
||||
if err := idToken.Claims(&claims); err != nil {
|
||||
httpError(writer, fmt.Errorf("decoding ID token claims: %w", err))
|
||||
return
|
||||
|
|
@ -239,6 +247,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
// Fetch user information (email, groups, name, etc) from the userinfo endpoint
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
var userinfo *oidc.UserInfo
|
||||
|
||||
userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token))
|
||||
if err != nil {
|
||||
util.LogErr(err, "could not get userinfo; only using claims from id token")
|
||||
|
|
@ -255,6 +264,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified)
|
||||
claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username)
|
||||
claims.Name = cmp.Or(userinfo2.Name, claims.Name)
|
||||
|
||||
claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL)
|
||||
if userinfo2.Groups != nil {
|
||||
claims.Groups = userinfo2.Groups
|
||||
|
|
@ -279,6 +289,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
Msgf("could not create or update user")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
_, werr := writer.Write([]byte("Could not create or update user"))
|
||||
if werr != nil {
|
||||
log.Error().
|
||||
|
|
@ -299,6 +310,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
// Register the node if it does not exist.
|
||||
if registrationId != nil {
|
||||
verb := "Reauthenticated"
|
||||
|
||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||
|
|
@ -307,7 +319,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
httpError(writer, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -324,6 +338,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
//nolint:noinlineerr
|
||||
if _, err := writer.Write(content.Bytes()); err != nil {
|
||||
util.LogErr(err, "Failed to write HTTP response")
|
||||
}
|
||||
|
|
@ -370,6 +386,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
|||
if !ok {
|
||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
||||
}
|
||||
|
||||
if regInfo.Verifier != nil {
|
||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
||||
}
|
||||
|
|
@ -516,6 +533,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
|||
newUser bool
|
||||
c change.Change
|
||||
)
|
||||
|
||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
|
||||
|
|
@ -600,7 +618,7 @@ func getCookieName(baseName, value string) string {
|
|||
}
|
||||
|
||||
func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) {
|
||||
val, err := util.GenerateRandomStringURLSafe(64)
|
||||
val, err := util.GenerateRandomStringURLSafe(csrfTokenLength)
|
||||
if err != nil {
|
||||
return val, err
|
||||
}
|
||||
@@ -19,7 +19,7 @@ func (h *Headscale) WindowsConfigMessage(
) {
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
_, _ = writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render()))
}

// AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it.
@@ -29,7 +29,7 @@ func (h *Headscale) AppleConfigMessage(
) {
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
writer.WriteHeader(http.StatusOK)
writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
_, _ = writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render()))
}

func (h *Headscale) ApplePlatformConfig(
@@ -98,7 +98,7 @@ func (h *Headscale) ApplePlatformConfig(
writer.Header().
Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8")
writer.WriteHeader(http.StatusOK)
writer.Write(content.Bytes())
_, _ = writer.Write(content.Bytes())
}

type AppleMobileConfig struct {

@@ -21,10 +21,13 @@ func (m Match) DebugString() string {

sb.WriteString("Match:\n")
sb.WriteString(" Sources:\n")

for _, prefix := range m.srcs.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}

sb.WriteString(" Destinations:\n")

for _, prefix := range m.dests.Prefixes() {
sb.WriteString(" " + prefix.String() + "\n")
}

@@ -19,18 +19,18 @@ type PolicyManager interface {
MatchersForNode(node types.NodeView) ([]matcher.Match, error)
// BuildPeerMap constructs peer relationship maps for the given nodes
BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView
SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy([]byte) (bool, error)
SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error)
SetPolicy(data []byte) (bool, error)
SetUsers(users []types.User) (bool, error)
SetNodes(nodes views.Slice[types.NodeView]) (bool, error)
// NodeCanHaveTag reports whether the given node can have the given tag.
NodeCanHaveTag(types.NodeView, string) bool
NodeCanHaveTag(node types.NodeView, tag string) bool

// TagExists reports whether the given tag is defined in the policy.
TagExists(tag string) bool

// NodeCanApproveRoute reports whether the given node can approve the given route.
NodeCanApproveRoute(types.NodeView, netip.Prefix) bool
NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool

Version() int
DebugString() string
@@ -38,8 +38,11 @@ type PolicyManager interface {

// NewPolicyManager returns a new policy manager.
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
var polMan PolicyManager
var err error
var (
polMan PolicyManager
err error
)

polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
if err != nil {
return nil, err
@@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
if err != nil {
return nil, err
}

polMans = append(polMans, pm)
}

@@ -66,6 +70,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
}

func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) {
//nolint:prealloc
var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error)

polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) {

@@ -9,7 +9,6 @@ import (
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/types/views"
)

@@ -111,7 +110,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
}

// Sort and deduplicate
tsaddr.SortPrefixes(newApproved)
slices.SortFunc(newApproved, netip.Prefix.Compare)
newApproved = slices.Compact(newApproved)
newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool {
return route.IsValid()
@@ -120,12 +119,13 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
// Sort the current approved for comparison
sortedCurrent := make([]netip.Prefix, len(currentApproved))
copy(sortedCurrent, currentApproved)
tsaddr.SortPrefixes(sortedCurrent)
slices.SortFunc(sortedCurrent, netip.Prefix.Compare)

// Only update if the routes actually changed
if !slices.Equal(sortedCurrent, newApproved) {
// Log what changed
var added, kept []netip.Prefix

for _, route := range newApproved {
if !slices.Contains(sortedCurrent, route) {
added = append(added, route)

@ -3,16 +3,16 @@ package policy
|
|||
import (
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
|
|
@ -32,10 +32,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
|||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "test-node",
|
||||
UserID: ptr.To(user1.ID),
|
||||
User: ptr.To(user1),
|
||||
UserID: new(user1.ID),
|
||||
User: new(user1),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
Tags: []string{"tag:test"},
|
||||
}
|
||||
|
||||
|
|
@ -44,10 +44,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
|||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "other-node",
|
||||
UserID: ptr.To(user2.ID),
|
||||
User: ptr.To(user2),
|
||||
UserID: new(user2.ID),
|
||||
User: new(user2),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.2")),
|
||||
}
|
||||
|
||||
// Create a policy that auto-approves specific routes
|
||||
|
|
@ -76,7 +76,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
|||
}`
|
||||
|
||||
pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()}))
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
@ -194,7 +194,7 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
|
|||
assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description)
|
||||
|
||||
// Sort for comparison since ApproveRoutesWithPolicy sorts the results
|
||||
tsaddr.SortPrefixes(tt.wantApproved)
|
||||
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
|
||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description)
|
||||
|
||||
// Verify that all previously approved routes are still present
|
||||
|
|
@ -304,20 +304,23 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
|||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: ptr.To(user.ID),
|
||||
User: ptr.To(user),
|
||||
UserID: new(user.ID),
|
||||
User: new(user),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
|
||||
// Create policy manager or use nil if specified
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
if tt.name != "nil_policy_manager" {
|
||||
pm, err = pmf(users, nodes.ViewSlice())
|
||||
assert.NoError(t, err)
|
||||
require.NoError(t, err)
|
||||
} else {
|
||||
pm = nil
|
||||
}
|
||||
|
|
@ -330,7 +333,7 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
|||
if tt.wantApproved == nil {
|
||||
assert.Nil(t, gotApproved, "expected nil approved routes")
|
||||
} else {
|
||||
tsaddr.SortPrefixes(tt.wantApproved)
|
||||
slices.SortFunc(tt.wantApproved, netip.Prefix.Compare)
|
||||
assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch")
|
||||
}
|
||||
})
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
||||
|
|
@ -91,9 +90,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
|||
},
|
||||
announcedRoutes: []netip.Prefix{}, // No routes announced anymore
|
||||
nodeUser: "test",
|
||||
// Sorted by netip.Prefix.Compare: by IP address then by prefix length
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("10.0.0.0/24"),
|
||||
netip.MustParsePrefix("172.16.0.0/16"),
|
||||
netip.MustParsePrefix("192.168.0.0/24"),
|
||||
},
|
||||
wantChanged: false,
|
||||
|
|
@ -123,9 +123,10 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
|||
},
|
||||
nodeUser: "test",
|
||||
nodeTags: []string{"tag:approved"},
|
||||
// Sorted by netip.Prefix.Compare: by IP address then by prefix length
|
||||
wantApproved: []netip.Prefix{
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
|
||||
netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved
|
||||
netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved
|
||||
},
|
||||
wantChanged: true,
|
||||
},
|
||||
|
|
@ -168,13 +169,13 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
|
|||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: tt.nodeHostname,
|
||||
UserID: ptr.To(user.ID),
|
||||
User: ptr.To(user),
|
||||
UserID: new(user.ID),
|
||||
User: new(user),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
Tags: tt.nodeTags,
|
||||
}
|
||||
|
|
@ -294,13 +295,13 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
|
|||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: ptr.To(user.ID),
|
||||
User: ptr.To(user),
|
||||
UserID: new(user.ID),
|
||||
User: new(user),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tt.announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: tt.currentApproved,
|
||||
}
|
||||
nodes := types.Nodes{&node}
|
||||
|
|
@ -326,6 +327,7 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
|
||||
//nolint:staticcheck
|
||||
user := types.User{
|
||||
Model: gorm.Model{ID: 1},
|
||||
Name: "test",
|
||||
|
|
@ -343,13 +345,13 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
|
|||
MachineKey: key.NewMachine().Public(),
|
||||
NodeKey: key.NewNode().Public(),
|
||||
Hostname: "testnode",
|
||||
UserID: ptr.To(user.ID),
|
||||
User: ptr.To(user),
|
||||
UserID: new(user.ID),
|
||||
User: new(user),
|
||||
RegisterMethod: util.RegisterMethodAuthKey,
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: announcedRoutes,
|
||||
},
|
||||
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
|
||||
IPv4: new(netip.MustParseAddr("100.64.0.1")),
|
||||
ApprovedRoutes: currentApproved,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
var ap = func(ipStr string) *netip.Addr {
|
||||
|
|
@ -33,6 +32,7 @@ func TestReduceNodes(t *testing.T) {
|
|||
rules []tailcfg.FilterRule
|
||||
node *types.Node
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
|
|
@ -783,9 +783,11 @@ func TestReduceNodes(t *testing.T) {
|
|||
for _, v := range gotViews.All() {
|
||||
got = append(got, v.AsStruct())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff)
|
||||
t.Log("Matchers: ")
|
||||
|
||||
for _, m := range matchers {
|
||||
t.Log("\t+", m.DebugString())
|
||||
}
|
||||
|
|
@ -796,6 +798,7 @@ func TestReduceNodes(t *testing.T) {
|
|||
|
||||
func TestReduceNodesFromPolicy(t *testing.T) {
|
||||
n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node {
|
||||
//nolint:prealloc
|
||||
var routes []netip.Prefix
|
||||
for _, route := range routess {
|
||||
routes = append(routes, netip.MustParsePrefix(route))
|
||||
|
|
@ -1032,8 +1035,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(nil, tt.nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -1051,9 +1057,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
|||
for _, v := range gotViews.All() {
|
||||
got = append(got, v.AsStruct())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff)
|
||||
t.Log("Matchers: ")
|
||||
|
||||
for _, m := range matchers {
|
||||
t.Log("\t+", m.DebugString())
|
||||
}
|
||||
|
|
@ -1074,21 +1082,21 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||
nodeUser1 := types.Node{
|
||||
Hostname: "user1-device",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
UserID: ptr.To(uint(1)),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(uint(1)),
|
||||
User: new(users[0]),
|
||||
}
|
||||
nodeUser2 := types.Node{
|
||||
Hostname: "user2-device",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
UserID: ptr.To(uint(2)),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(uint(2)),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
taggedClient := types.Node{
|
||||
Hostname: "tagged-client",
|
||||
IPv4: ap("100.64.0.4"),
|
||||
UserID: ptr.To(uint(2)),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(uint(2)),
|
||||
User: new(users[1]),
|
||||
Tags: []string{"tag:client"},
|
||||
}
|
||||
|
||||
|
|
@ -1207,7 +1215,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||
]
|
||||
}`,
|
||||
expectErr: true,
|
||||
errorMessage: `invalid SSH action "invalid", must be one of: accept, check`,
|
||||
errorMessage: `invalid SSH action: "invalid", must be one of: accept, check`,
|
||||
},
|
||||
{
|
||||
name: "invalid-check-period",
|
||||
|
|
@ -1242,7 +1250,7 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||
]
|
||||
}`,
|
||||
expectErr: true,
|
||||
errorMessage: "autogroup \"autogroup:invalid\" is not supported",
|
||||
errorMessage: `autogroup not supported for SSH: "autogroup:invalid" for SSH user`,
|
||||
},
|
||||
{
|
||||
name: "autogroup-nonroot-should-use-wildcard-with-root-excluded",
|
||||
|
|
@ -1406,13 +1414,17 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
|
||||
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMessage)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -1435,6 +1447,7 @@ func TestReduceRoutes(t *testing.T) {
|
|||
routes []netip.Prefix
|
||||
rules []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
|
|
@ -2056,6 +2069,7 @@ func TestReduceRoutes(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
||||
|
||||
got := ReduceRoutes(
|
||||
tt.args.node.View(),
|
||||
tt.args.routes,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
|||
for _, rule := range rules {
|
||||
// record if the rule is actually relevant for the given node.
|
||||
var dests []tailcfg.NetPortRange
|
||||
|
||||
DEST_LOOP:
|
||||
for _, dest := range rule.DstPorts {
|
||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||
|
|
|
|||
|
|
@ -16,7 +16,6 @@ import (
|
|||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/util/must"
|
||||
)
|
||||
|
||||
|
|
@ -144,13 +143,13 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{},
|
||||
|
|
@ -191,7 +190,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{
|
||||
netip.MustParsePrefix("10.33.0.0/16"),
|
||||
|
|
@ -202,7 +201,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
|
@ -283,19 +282,19 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
// "internal" exit node
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
|
|
@ -344,7 +343,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
|
|
@ -353,12 +352,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
|
@ -453,7 +452,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: tsaddr.ExitRoutes(),
|
||||
},
|
||||
|
|
@ -462,12 +461,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
|
@ -565,7 +564,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
|
||||
},
|
||||
|
|
@ -574,12 +573,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
|
@ -655,7 +654,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
|
||||
},
|
||||
|
|
@ -664,12 +663,12 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[2]),
|
||||
User: new(users[2]),
|
||||
},
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
|
@ -737,7 +736,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.100"),
|
||||
IPv6: ap("fd7a:115c:a1e0::100"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
|
||||
},
|
||||
|
|
@ -747,7 +746,7 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
},
|
||||
},
|
||||
want: []tailcfg.FilterRule{
|
||||
|
|
@ -804,13 +803,13 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
node: &types.Node{
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[3]),
|
||||
User: new(users[3]),
|
||||
},
|
||||
peers: types.Nodes{
|
||||
&types.Node{
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
Hostinfo: &tailcfg.Hostinfo{
|
||||
RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
|
||||
},
|
||||
|
|
@ -824,10 +823,14 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm policy.PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm policy.PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
got, _ := pm.Filter()
|
||||
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
|
||||
got = policyutil.ReduceFilterRules(tt.node.View(), got)
|
||||
|
|
|
|||
|
|
@ -10,7 +10,6 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestNodeCanApproveRoute(t *testing.T) {
|
||||
|
|
@ -25,24 +24,24 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
|||
ID: 1,
|
||||
Hostname: "user1-device",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
UserID: ptr.To(uint(1)),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(uint(1)),
|
||||
User: new(users[0]),
|
||||
}
|
||||
|
||||
exitNode := types.Node{
|
||||
ID: 2,
|
||||
Hostname: "user2-device",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
UserID: ptr.To(uint(2)),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(uint(2)),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
taggedNode := types.Node{
|
||||
ID: 3,
|
||||
Hostname: "tagged-server",
|
||||
IPv4: ap("100.64.0.3"),
|
||||
UserID: ptr.To(uint(3)),
|
||||
User: ptr.To(users[2]),
|
||||
UserID: new(uint(3)),
|
||||
User: new(users[2]),
|
||||
Tags: []string{"tag:router"},
|
||||
}
|
||||
|
||||
|
|
@ -50,8 +49,8 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
|||
ID: 4,
|
||||
Hostname: "multi-tag-node",
|
||||
IPv4: ap("100.64.0.4"),
|
||||
UserID: ptr.To(uint(2)),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(uint(2)),
|
||||
User: new(users[1]),
|
||||
Tags: []string{"tag:router", "tag:server"},
|
||||
}
|
||||
|
||||
|
|
@ -830,6 +829,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
|||
if tt.name == "empty policy" {
|
||||
// We expect this one to have a valid but empty policy
|
||||
require.NoError(t, err)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -844,6 +844,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
|||
if diff := cmp.Diff(tt.canApprove, result); diff != "" {
|
||||
t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.canApprove, result, "Unexpected route approval result")
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
package v2
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"time"
|
||||
|
|
@ -14,8 +13,6 @@ import (
|
|||
"tailscale.com/types/views"
|
||||
)
|
||||
|
||||
var ErrInvalidAction = errors.New("invalid action")
|
||||
|
||||
// compileFilterRules takes a set of nodes and an ACLPolicy and generates a
|
||||
// set of Tailscale compatible FilterRules used to allow traffic on clients.
|
||||
func (pol *Policy) compileFilterRules(
|
||||
|
|
@ -42,9 +39,10 @@ func (pol *Policy) compileFilterRules(
|
|||
continue
|
||||
}
|
||||
|
||||
protocols, _ := acl.Protocol.parseProtocol()
|
||||
protocols := acl.Protocol.parseProtocol()
|
||||
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
|
||||
for _, dest := range acl.Destinations {
|
||||
ips, err := dest.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
|
|
@ -121,14 +119,18 @@ func (pol *Policy) compileFilterRulesForNode(
|
|||
// It returns a slice of filter rules because when an ACL has both autogroup:self
|
||||
// and other destinations, they need to be split into separate rules with different
|
||||
// source filtering logic.
|
||||
//
|
||||
//nolint:gocyclo
|
||||
func (pol *Policy) compileACLWithAutogroupSelf(
|
||||
acl ACL,
|
||||
users types.Users,
|
||||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) ([]*tailcfg.FilterRule, error) {
|
||||
var autogroupSelfDests []AliasWithPorts
|
||||
var otherDests []AliasWithPorts
|
||||
var (
|
||||
autogroupSelfDests []AliasWithPorts
|
||||
otherDests []AliasWithPorts
|
||||
)
|
||||
|
||||
for _, dest := range acl.Destinations {
|
||||
if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
|
|
@ -138,14 +140,15 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
}
|
||||
}
|
||||
|
||||
protocols, _ := acl.Protocol.parseProtocol()
|
||||
protocols := acl.Protocol.parseProtocol()
|
||||
|
||||
var rules []*tailcfg.FilterRule
|
||||
|
||||
var resolvedSrcIPs []*netipx.IPSet
|
||||
|
||||
for _, src := range acl.Sources {
|
||||
if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
return nil, fmt.Errorf("autogroup:self cannot be used in sources")
|
||||
return nil, ErrAutogroupSelfInSource
|
||||
}
|
||||
|
||||
ips, err := src.Resolve(pol, users, nodes)
|
||||
|
|
@ -167,6 +170,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
if len(autogroupSelfDests) > 0 {
|
||||
// Pre-filter to same-user untagged devices once - reuse for both sources and destinations
|
||||
sameUserNodes := make([]types.NodeView, 0)
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if n.User().ID() == node.User().ID() && !n.IsTagged() {
|
||||
sameUserNodes = append(sameUserNodes, n)
|
||||
|
|
@ -176,6 +180,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
if len(sameUserNodes) > 0 {
|
||||
// Filter sources to only same-user untagged devices
|
||||
var srcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, ips := range resolvedSrcIPs {
|
||||
for _, n := range sameUserNodes {
|
||||
// Check if any of this node's IPs are in the source set
|
||||
|
|
@ -192,6 +197,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
|
||||
if srcSet != nil && len(srcSet.Prefixes()) > 0 {
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
|
||||
for _, dest := range autogroupSelfDests {
|
||||
for _, n := range sameUserNodes {
|
||||
for _, port := range dest.Ports {
|
||||
|
|
@ -280,13 +286,14 @@ func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction {
|
|||
}
|
||||
}
|
||||
|
||||
//nolint:gocyclo
|
||||
func (pol *Policy) compileSSHPolicy(
|
||||
users types.Users,
|
||||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) (*tailcfg.SSHPolicy, error) {
|
||||
if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 {
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil
|
||||
}
|
||||
|
||||
log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname())
|
||||
|
|
@ -297,8 +304,10 @@ func (pol *Policy) compileSSHPolicy(
|
|||
// Separate destinations into autogroup:self and others
|
||||
// This is needed because autogroup:self requires filtering sources to same-user only,
|
||||
// while other destinations should use all resolved sources
|
||||
var autogroupSelfDests []Alias
|
||||
var otherDests []Alias
|
||||
var (
|
||||
autogroupSelfDests []Alias
|
||||
otherDests []Alias
|
||||
)
|
||||
|
||||
for _, dst := range rule.Destinations {
|
||||
if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
|
|
@ -321,6 +330,7 @@ func (pol *Policy) compileSSHPolicy(
|
|||
}
|
||||
|
||||
var action tailcfg.SSHAction
|
||||
|
||||
switch rule.Action {
|
||||
case SSHActionAccept:
|
||||
action = sshAction(true, 0)
|
||||
|
|
@ -336,9 +346,11 @@ func (pol *Policy) compileSSHPolicy(
|
|||
// by default, we do not allow root unless explicitly stated
|
||||
userMap["root"] = ""
|
||||
}
|
||||
|
||||
if rule.Users.ContainsRoot() {
|
||||
userMap["root"] = "root"
|
||||
}
|
||||
|
||||
for _, u := range rule.Users.NormalUsers() {
|
||||
userMap[u.String()] = u.String()
|
||||
}
|
||||
|
|
@ -348,6 +360,7 @@ func (pol *Policy) compileSSHPolicy(
|
|||
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
|
||||
// Build destination set for autogroup:self (same-user untagged devices only)
|
||||
var dest netipx.IPSetBuilder
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if n.User().ID() == node.User().ID() && !n.IsTagged() {
|
||||
n.AppendToIPSet(&dest)
|
||||
|
|
@ -364,6 +377,7 @@ func (pol *Policy) compileSSHPolicy(
|
|||
// Filter sources to only same-user untagged devices
|
||||
// Pre-filter to same-user untagged devices for efficiency
|
||||
sameUserNodes := make([]types.NodeView, 0)
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if n.User().ID() == node.User().ID() && !n.IsTagged() {
|
||||
sameUserNodes = append(sameUserNodes, n)
|
||||
|
|
@ -371,6 +385,7 @@ func (pol *Policy) compileSSHPolicy(
|
|||
}
|
||||
|
||||
var filteredSrcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, n := range sameUserNodes {
|
||||
// Check if any of this node's IPs are in the source set
|
||||
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
|
||||
|
|
@ -406,12 +421,14 @@ func (pol *Policy) compileSSHPolicy(
|
|||
if len(otherDests) > 0 {
|
||||
// Build destination set for other destinations
|
||||
var dest netipx.IPSetBuilder
|
||||
|
||||
for _, dst := range otherDests {
|
||||
ips, err := dst.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
|
||||
continue
|
||||
}
|
||||
|
||||
if ips != nil {
|
||||
dest.AddSet(ips)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
// aliasWithPorts creates an AliasWithPorts structure from an alias and ports.
|
||||
|
|
@ -410,14 +409,14 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
|||
nodeUser1 := types.Node{
|
||||
Hostname: "user1-device",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
}
|
||||
nodeUser2 := types.Node{
|
||||
Hostname: "user2-device",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
nodes := types.Nodes{&nodeUser1, &nodeUser2}
|
||||
|
|
@ -590,7 +589,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
|||
if sshPolicy == nil {
|
||||
return // Expected empty result
|
||||
}
|
||||
|
||||
assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -622,14 +623,14 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
|||
nodeUser1 := types.Node{
|
||||
Hostname: "user1-device",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
}
|
||||
nodeUser2 := types.Node{
|
||||
Hostname: "user2-device",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
nodes := types.Nodes{&nodeUser1, &nodeUser2}
|
||||
|
|
@ -671,7 +672,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
|||
}
|
||||
|
||||
// TestSSHIntegrationReproduction reproduces the exact scenario from the integration test
|
||||
// TestSSHOneUserToAll that was failing with empty sshUsers
|
||||
// TestSSHOneUserToAll that was failing with empty sshUsers.
|
||||
func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
// Create users matching the integration test
|
||||
users := types.Users{
|
||||
|
|
@ -683,15 +684,15 @@ func TestSSHIntegrationReproduction(t *testing.T) {
|
|||
node1 := &types.Node{
|
||||
Hostname: "user1-node",
|
||||
IPv4: createAddr("100.64.0.1"),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
User: new(users[0]),
|
||||
}
|
||||
|
||||
node2 := &types.Node{
|
||||
Hostname: "user2-node",
|
||||
IPv4: createAddr("100.64.0.2"),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
User: new(users[1]),
|
||||
}
|
||||
|
||||
nodes := types.Nodes{node1, node2}
|
||||
|
|
@ -736,7 +737,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
|
|||
}
|
||||
|
||||
// TestSSHJSONSerialization verifies that the SSH policy can be properly serialized
|
||||
// to JSON and that the sshUsers field is not empty
|
||||
// to JSON and that the sshUsers field is not empty.
|
||||
func TestSSHJSONSerialization(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Name: "user1", Model: gorm.Model{ID: 1}},
|
||||
|
|
@ -776,6 +777,7 @@ func TestSSHJSONSerialization(t *testing.T) {
|
|||
|
||||
// Parse back to verify structure
|
||||
var parsed tailcfg.SSHPolicy
|
||||
|
||||
err = json.Unmarshal(jsonData, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -806,19 +808,19 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
|||
|
||||
nodes := types.Nodes{
|
||||
{
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
IPv4: ap("100.64.0.1"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
IPv4: ap("100.64.0.2"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
IPv4: ap("100.64.0.3"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
IPv4: ap("100.64.0.4"),
|
||||
},
|
||||
// Tagged device for user1
|
||||
|
|
@ -860,6 +862,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(rules) != 1 {
|
||||
t.Fatalf("expected 1 rule, got %d", len(rules))
|
||||
}
|
||||
|
|
@ -876,6 +879,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
|||
found := false
|
||||
|
||||
addr := netip.MustParseAddr(expectedIP)
|
||||
|
||||
for _, prefix := range rule.SrcIPs {
|
||||
pref := netip.MustParsePrefix(prefix)
|
||||
if pref.Contains(addr) {
|
||||
|
|
@ -893,6 +897,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
|||
excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"}
|
||||
for _, excludedIP := range excludedSourceIPs {
|
||||
addr := netip.MustParseAddr(excludedIP)
|
||||
|
||||
for _, prefix := range rule.SrcIPs {
|
||||
pref := netip.MustParsePrefix(prefix)
|
||||
if pref.Contains(addr) {
|
||||
|
|
@ -938,11 +943,11 @@ func TestTagUserMutualExclusivity(t *testing.T) {
|
|||
nodes := types.Nodes{
|
||||
// User-owned nodes
|
||||
{
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
IPv4: ap("100.64.0.1"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
IPv4: ap("100.64.0.2"),
|
||||
},
|
||||
// Tagged nodes
|
||||
|
|
@ -960,8 +965,8 @@ func TestTagUserMutualExclusivity(t *testing.T) {
|
|||
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
|
||||
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
|
||||
Tag("tag:server"): Owners{new(Username("user1@"))},
|
||||
Tag("tag:database"): Owners{new(Username("user2@"))},
|
||||
},
|
||||
ACLs: []ACL{
|
||||
// Rule 1: user1 (user-owned) should NOT be able to reach tagged nodes
|
||||
|
|
@ -1056,11 +1061,11 @@ func TestAutogroupTagged(t *testing.T) {
|
|||
nodes := types.Nodes{
|
||||
// User-owned nodes (not tagged)
|
||||
{
|
||||
User: ptr.To(users[0]),
|
||||
User: new(users[0]),
|
||||
IPv4: ap("100.64.0.1"),
|
||||
},
|
||||
{
|
||||
User: ptr.To(users[1]),
|
||||
User: new(users[1]),
|
||||
IPv4: ap("100.64.0.2"),
|
||||
},
|
||||
// Tagged nodes
|
||||
|
|
@ -1083,10 +1088,10 @@ func TestAutogroupTagged(t *testing.T) {
|
|||
|
||||
policy := &Policy{
|
||||
TagOwners: TagOwners{
|
||||
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
|
||||
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
|
||||
Tag("tag:web"): Owners{ptr.To(Username("user1@"))},
|
||||
Tag("tag:prod"): Owners{ptr.To(Username("user1@"))},
|
||||
Tag("tag:server"): Owners{new(Username("user1@"))},
|
||||
Tag("tag:database"): Owners{new(Username("user2@"))},
|
||||
Tag("tag:web"): Owners{new(Username("user1@"))},
|
||||
Tag("tag:prod"): Owners{new(Username("user1@"))},
|
||||
},
|
||||
ACLs: []ACL{
|
||||
// Rule: autogroup:tagged can reach user-owned nodes
|
||||
|
|
@ -1206,10 +1211,10 @@ func TestAutogroupSelfWithSpecificUserSource(t *testing.T) {
|
|||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4")},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
|
|
@ -1273,11 +1278,11 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
|
|||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: new(users[2]), IPv4: ap("100.64.0.5")},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
|
|
@ -1326,14 +1331,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
|
|||
assert.Empty(t, rules3, "user3 should have no rules")
|
||||
}
|
||||
|
||||
// Helper function to create IP addresses for testing
|
||||
// Helper function to create IP addresses for testing.
|
||||
func createAddr(ip string) *netip.Addr {
|
||||
addr, _ := netip.ParseAddr(ip)
|
||||
return &addr
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly
|
||||
// with autogroup:self in destinations
|
||||
// with autogroup:self in destinations.
|
||||
func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
|
|
@ -1342,13 +1347,13 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
|||
|
||||
nodes := types.Nodes{
|
||||
// User1's nodes
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
|
||||
// User2's nodes
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
|
||||
// Tagged node for user1 (should be excluded)
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
|
|
@ -1381,6 +1386,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
|||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// Test for user2's first node
|
||||
|
|
@ -1399,12 +1405,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
|||
for i, p := range rule2.Principals {
|
||||
principalIPs2[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2)
|
||||
|
||||
// Test for tagged node (should have no SSH rules)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy3 != nil {
|
||||
assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self")
|
||||
}
|
||||
|
|
@ -1412,7 +1420,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
|||
|
||||
// TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user
|
||||
// is in the source and autogroup:self in destination, only that user's devices
|
||||
// can SSH (and only if they match the target user)
|
||||
// can SSH (and only if they match the target user).
|
||||
func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
|
|
@ -1420,10 +1428,10 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
|||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4")},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
|
|
@ -1454,18 +1462,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
|||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// For user2's node: should have no rules (user1's devices can't match user2's self)
|
||||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations
|
||||
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations.
|
||||
func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
|
|
@ -1474,11 +1484,11 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
|||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1")},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3")},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4")},
|
||||
{User: new(users[2]), IPv4: ap("100.64.0.5")},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
|
|
@ -1512,29 +1522,31 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
|||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// For user3's node: should have no rules (not in group:admins)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices
|
||||
// are excluded from both sources and destinations when autogroup:self is used
|
||||
// are excluded from both sources and destinations when autogroup:self is used.
|
||||
func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
|
|
@ -1569,6 +1581,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
|||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs,
|
||||
"should only include untagged devices")
|
||||
|
||||
|
|
@ -1576,6 +1589,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
|||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self")
|
||||
}
|
||||
|
|
@ -1591,10 +1605,10 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
|||
}
|
||||
|
||||
nodes := types.Nodes{
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
|
||||
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
|
||||
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
|
||||
{User: new(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
|
||||
{User: new(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}},
|
||||
}
|
||||
|
||||
policy := &Policy{
|
||||
|
|
@ -1624,10 +1638,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
|||
// Verify autogroup:self rule has filtered sources (only same-user devices)
|
||||
selfRule := sshPolicy1.Rules[0]
|
||||
require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices")
|
||||
|
||||
selfPrincipals := make([]string, len(selfRule.Principals))
|
||||
for i, p := range selfRule.Principals {
|
||||
selfPrincipals[i] = p.NodeIP
|
||||
}
|
||||
|
||||
require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals,
|
||||
"autogroup:self rule should only include same-user untagged devices")
|
||||
|
||||
|
|
@ -1639,10 +1655,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
|||
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")
|
||||
|
||||
routerRule := sshPolicyRouter.Rules[0]
|
||||
|
||||
routerPrincipals := make([]string, len(routerRule.Principals))
|
||||
for i, p := range routerRule.Principals {
|
||||
routerPrincipals[i] = p.NodeIP
|
||||
}
|
||||
|
||||
require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)")
|
||||
require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)")
|
||||
require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)")
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ import (
|
|||
"tailscale.com/util/deephash"
|
||||
)
|
||||
|
||||
// PolicyVersion is the version number of this policy implementation.
|
||||
const PolicyVersion = 2
|
||||
|
||||
// ErrInvalidTagOwner is returned when a tag owner is not an Alias type.
|
||||
var ErrInvalidTagOwner = errors.New("tag owner is not an Alias")
|
||||
|
||||
|
|
@ -111,6 +114,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Filter: filter,
|
||||
Policy: pm.pol,
|
||||
})
|
||||
|
||||
filterChanged := filterHash != pm.filterHash
|
||||
if filterChanged {
|
||||
log.Debug().
|
||||
|
|
@ -120,7 +124,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Int("filter.rules.new", len(filter)).
|
||||
Msg("Policy filter hash changed")
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
|
||||
pm.filterHash = filterHash
|
||||
if filterChanged {
|
||||
pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
|
||||
|
|
@ -135,6 +141,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
}
|
||||
|
||||
tagOwnerMapHash := deephash.Hash(&tagMap)
|
||||
|
||||
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
|
||||
if tagOwnerChanged {
|
||||
log.Debug().
|
||||
|
|
@ -144,6 +151,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Int("tagOwners.new", len(tagMap)).
|
||||
Msg("Tag owner hash changed")
|
||||
}
|
||||
|
||||
pm.tagOwnerMap = tagMap
|
||||
pm.tagOwnerMapHash = tagOwnerMapHash
|
||||
|
||||
|
|
@ -153,6 +161,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
}
|
||||
|
||||
autoApproveMapHash := deephash.Hash(&autoMap)
|
||||
|
||||
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
|
||||
if autoApproveChanged {
|
||||
log.Debug().
|
||||
|
|
@ -162,10 +171,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Int("autoApprovers.new", len(autoMap)).
|
||||
Msg("Auto-approvers hash changed")
|
||||
}
|
||||
|
||||
pm.autoApproveMap = autoMap
|
||||
pm.autoApproveMapHash = autoApproveMapHash
|
||||
|
||||
exitSetHash := deephash.Hash(&exitSet)
|
||||
|
||||
exitSetChanged := exitSetHash != pm.exitSetHash
|
||||
if exitSetChanged {
|
||||
log.Debug().
|
||||
|
|
@ -173,6 +184,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Str("exitSet.hash.new", exitSetHash.String()[:8]).
|
||||
Msg("Exit node set hash changed")
|
||||
}
|
||||
|
||||
pm.exitSet = exitSet
|
||||
pm.exitSetHash = exitSetHash
|
||||
|
||||
|
|
@ -199,6 +211,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
if !needsUpdate {
|
||||
log.Trace().
|
||||
Msg("Policy evaluation detected no changes - all hashes match")
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
|
@ -224,6 +237,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("compiling SSH policy: %w", err)
|
||||
}
|
||||
|
||||
pm.sshPolicyMap[node.ID()] = sshPol
|
||||
|
||||
return sshPol, nil
|
||||
|
|
@ -318,6 +332,7 @@ func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[typ
|
|||
if err != nil || len(filter) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter)
|
||||
}
|
||||
|
||||
|
|
@ -398,6 +413,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil
|
|||
reducedFilter := policyutil.ReduceFilterRules(node, pm.filter)
|
||||
|
||||
pm.filterRulesMap[node.ID()] = reducedFilter
|
||||
|
||||
return reducedFilter, nil
|
||||
}
|
||||
|
||||
|
|
@ -442,7 +458,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul
|
|||
// This is different from FilterForNode which returns REDUCED rules for packet filtering.
|
||||
//
|
||||
// For global policies: returns the global matchers (same for all nodes)
|
||||
// For autogroup:self: returns node-specific matchers from unreduced compiled rules
|
||||
// For autogroup:self: returns node-specific matchers from unreduced compiled rules.
|
||||
func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
|
||||
if pm == nil {
|
||||
return nil, nil
|
||||
|
|
@ -474,6 +490,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
|||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
|
||||
// Clear SSH policy map when users change to force SSH policy recomputation
|
||||
|
|
@ -685,6 +702,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
|
|||
if pm.exitSet == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) {
|
||||
return true
|
||||
}
|
||||
|
|
@ -724,7 +742,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
|
|||
}
|
||||
|
||||
func (pm *PolicyManager) Version() int {
|
||||
return 2
|
||||
return PolicyVersion
|
||||
}
|
||||
|
||||
func (pm *PolicyManager) DebugString() string {
|
||||
|
|
@ -748,8 +766,10 @@ func (pm *PolicyManager) DebugString() string {
|
|||
}
|
||||
|
||||
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
|
||||
|
||||
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
|
||||
for _, iprange := range approveAddrs.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
|
|
@ -758,14 +778,17 @@ func (pm *PolicyManager) DebugString() string {
|
|||
sb.WriteString("\n\n")
|
||||
|
||||
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
|
||||
|
||||
for prefix, tagOwners := range pm.tagOwnerMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
|
||||
for _, iprange := range tagOwners.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if pm.filter != nil {
|
||||
filter, err := json.MarshalIndent(pm.filter, "", " ")
|
||||
if err == nil {
|
||||
|
|
@ -778,6 +801,7 @@ func (pm *PolicyManager) DebugString() string {
|
|||
sb.WriteString("\n\n")
|
||||
sb.WriteString("Matchers:\n")
|
||||
sb.WriteString("an internal structure used to filter nodes and routes\n")
|
||||
|
||||
for _, match := range pm.matchers {
|
||||
sb.WriteString(match.DebugString())
|
||||
sb.WriteString("\n")
|
||||
|
|
@ -785,6 +809,7 @@ func (pm *PolicyManager) DebugString() string {
|
|||
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString("Nodes:\n")
|
||||
|
||||
for _, node := range pm.nodes.All() {
|
||||
sb.WriteString(node.String())
|
||||
sb.WriteString("\n")
|
||||
|
|
@ -841,6 +866,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
|||
|
||||
// Check if IPs changed (simple check - could be more sophisticated)
|
||||
oldIPs := oldNode.IPs()
|
||||
|
||||
newIPs := newNode.IPs()
|
||||
if len(oldIPs) != len(newIPs) {
|
||||
affectedUsers[newNode.User().ID()] = struct{}{}
|
||||
|
|
@ -862,6 +888,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
|||
for nodeID := range pm.filterRulesMap {
|
||||
// Find the user for this cached node
|
||||
var nodeUserID uint
|
||||
|
||||
found := false
|
||||
|
||||
// Check in new nodes first
|
||||
|
|
@ -869,6 +896,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
|||
if node.ID() == nodeID {
|
||||
nodeUserID = node.User().ID()
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -879,6 +907,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
|||
if node.ID() == nodeID {
|
||||
nodeUserID = node.User().ID()
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -956,14 +985,7 @@ func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types.
|
|||
// It will return an Owners list where all the Tag types have been resolved to their underlying Owners.
|
||||
func flattenTags(tagOwners TagOwners, tag Tag, visiting map[Tag]bool, chain []Tag) (Owners, error) {
|
||||
if visiting[tag] {
|
||||
cycleStart := 0
|
||||
|
||||
for i, t := range chain {
|
||||
if t == tag {
|
||||
cycleStart = i
|
||||
break
|
||||
}
|
||||
}
|
||||
cycleStart := slices.Index(chain, tag)
|
||||
|
||||
cycleTags := make([]string, len(chain[cycleStart:]))
|
||||
for i, t := range chain[cycleStart:] {
|
||||
|
|
|
|||
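The flattenTags hunk above replaces a hand-rolled loop that locates the start of a tag-ownership cycle with slices.Index. A self-contained sketch of the same idea, with invented names (resolve, errTagCycle) and a simplified owner map standing in for the real TagOwners type:

package main

import (
	"errors"
	"fmt"
	"slices"
)

var errTagCycle = errors.New("cycle detected in tag owners")

// resolve expands tag -> owners, following tags that own other tags.
// chain records the path so a repeated tag can be reported as a cycle,
// using slices.Index the same way the rewritten flattenTags does.
func resolve(owners map[string][]string, tag string, visiting map[string]bool, chain []string) ([]string, error) {
	if visiting[tag] {
		cycleStart := slices.Index(chain, tag)
		return nil, fmt.Errorf("%w: %v", errTagCycle, chain[cycleStart:])
	}

	visiting[tag] = true
	defer delete(visiting, tag)

	var out []string
	for _, o := range owners[tag] {
		if _, isTag := owners[o]; isTag {
			sub, err := resolve(owners, o, visiting, append(chain, tag))
			if err != nil {
				return nil, err
			}
			out = append(out, sub...)
			continue
		}
		out = append(out, o)
	}

	return out, nil
}

func main() {
	owners := map[string][]string{
		"tag:a": {"tag:b"},
		"tag:b": {"tag:a"},
	}
	_, err := resolve(owners, "tag:a", map[string]bool{}, nil)
	fmt.Println(err) // cycle detected in tag owners: [tag:a tag:b]
}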
|
|
@ -11,18 +11,16 @@ import (
|
|||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
|
||||
func node(name, ipv4, ipv6 string, user types.User, _ *tailcfg.Hostinfo) *types.Node {
|
||||
return &types.Node{
|
||||
ID: 0,
|
||||
Hostname: name,
|
||||
IPv4: ap(ipv4),
|
||||
IPv6: ap(ipv6),
|
||||
User: ptr.To(user),
|
||||
UserID: ptr.To(user.ID),
|
||||
Hostinfo: hostinfo,
|
||||
User: new(user),
|
||||
UserID: new(user.ID),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -57,6 +55,7 @@ func TestPolicyManager(t *testing.T) {
|
|||
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
|
||||
t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(
|
||||
tt.wantMatchers,
|
||||
matchers,
|
||||
|
|
@ -77,6 +76,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
|||
{Model: gorm.Model{ID: 3}, Name: "user3", Email: "user3@headscale.net"},
|
||||
}
|
||||
|
||||
//nolint:goconst
|
||||
policy := `{
|
||||
"acls": [
|
||||
{
|
||||
|
|
@ -95,7 +95,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
|||
}
|
||||
|
||||
for i, n := range initialNodes {
|
||||
n.ID = types.NodeID(i + 1)
|
||||
n.ID = types.NodeID(i + 1) //nolint:gosec
|
||||
}
|
||||
|
||||
pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice())
|
||||
|
|
@ -107,7 +107,7 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
require.Equal(t, len(initialNodes), len(pm.filterRulesMap))
|
||||
require.Len(t, pm.filterRulesMap, len(initialNodes))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
|
|
@ -177,15 +177,18 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
for i, n := range tt.newNodes {
|
||||
found := false
|
||||
|
||||
for _, origNode := range initialNodes {
|
||||
if n.Hostname == origNode.Hostname {
|
||||
n.ID = origNode.ID
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
n.ID = types.NodeID(len(initialNodes) + i + 1)
|
||||
n.ID = types.NodeID(len(initialNodes) + i + 1) //nolint:gosec
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -370,7 +373,7 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) {
|
|||
|
||||
// TestAutogroupSelfReducedVsUnreducedRules verifies that:
|
||||
// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships
|
||||
// 2. FilterForNode returns reduced compiled rules for packet filters
|
||||
// 2. FilterForNode returns reduced compiled rules for packet filters.
|
||||
func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
||||
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}
|
||||
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}
|
||||
|
|
@ -410,6 +413,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
|||
// FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations
|
||||
// For node1, destinations should only be node1's IPs
|
||||
node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"}
|
||||
|
||||
for _, rule := range filterNode1 {
|
||||
for _, dst := range rule.DstPorts {
|
||||
require.Contains(t, node1IPs, dst.IP,
|
||||
|
|
@ -419,6 +423,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
|||
|
||||
// For node2, destinations should only be node2's IPs
|
||||
node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"}
|
||||
|
||||
for _, rule := range filterNode2 {
|
||||
for _, dst := range rule.DstPorts {
|
||||
require.Contains(t, node2IPs, dst.IP,
|
||||
|
|
@ -457,8 +462,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
|
|||
Hostname: "test-1-device",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: new(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
}
|
||||
|
||||
|
|
@ -468,8 +473,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
|
|||
Hostname: "test-2-router",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: new(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
Tags: []string{"tag:node-router"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
}
|
||||
|
|
@ -537,8 +542,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
|
|||
Hostname: "test-1-device",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: new(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
}
|
||||
|
||||
|
|
@ -547,8 +552,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
|
|||
Hostname: "test-2-device",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: new(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
}
|
||||
|
||||
|
|
@ -647,8 +652,8 @@ func TestTagPropagationToPeerMap(t *testing.T) {
|
|||
Hostname: "user1-node",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: new(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
Tags: []string{"tag:web", "tag:internal"},
|
||||
}
|
||||
|
||||
|
|
@ -658,8 +663,8 @@ func TestTagPropagationToPeerMap(t *testing.T) {
|
|||
Hostname: "user2-node",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: new(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
}
|
||||
|
||||
initialNodes := types.Nodes{user1Node, user2Node}
|
||||
|
|
@ -686,8 +691,8 @@ func TestTagPropagationToPeerMap(t *testing.T) {
|
|||
Hostname: "user1-node",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: new(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
Tags: []string{"tag:internal"}, // tag:web removed!
|
||||
}
|
||||
|
||||
|
|
@ -749,8 +754,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) {
|
|||
Hostname: "admin-device",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: new(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
}
|
||||
|
||||
|
|
@ -760,8 +765,8 @@ func TestAutogroupSelfWithAdminOverride(t *testing.T) {
|
|||
Hostname: "user1-server",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: new(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
Tags: []string{"tag:server"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
}
|
||||
|
|
@ -832,8 +837,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) {
|
|||
Hostname: "device-a",
|
||||
IPv4: ap("100.64.0.1"),
|
||||
IPv6: ap("fd7a:115c:a1e0::1"),
|
||||
User: ptr.To(users[0]),
|
||||
UserID: ptr.To(users[0].ID),
|
||||
User: new(users[0]),
|
||||
UserID: new(users[0].ID),
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
}
|
||||
|
||||
|
|
@ -843,8 +848,8 @@ func TestAutogroupSelfSymmetricVisibility(t *testing.T) {
|
|||
Hostname: "device-b",
|
||||
IPv4: ap("100.64.0.2"),
|
||||
IPv6: ap("fd7a:115c:a1e0::2"),
|
||||
User: ptr.To(users[1]),
|
||||
UserID: ptr.To(users[1].ID),
|
||||
User: new(users[1]),
|
||||
UserID: new(users[1].ID),
|
||||
Tags: []string{"tag:web"},
|
||||
Hostinfo: &tailcfg.Hostinfo{},
|
||||
}
|
||||
|
|
|
|||
File diff suppressed because it is too large
File diff suppressed because it is too large
|
|
@ -9,6 +9,21 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// portRangeParts is the expected number of parts in a port range (start-end).
|
||||
const portRangeParts = 2
|
||||
|
||||
// Sentinel errors for port and destination parsing.
|
||||
var (
|
||||
ErrInputMissingColon = errors.New("input must contain a colon character separating destination and port")
|
||||
ErrInputStartsWithColon = errors.New("input cannot start with a colon character")
|
||||
ErrInputEndsWithColon = errors.New("input cannot end with a colon character")
|
||||
ErrInvalidPortRange = errors.New("invalid port range format")
|
||||
ErrPortRangeInverted = errors.New("invalid port range: first port is greater than last port")
|
||||
ErrPortMustBePositive = errors.New("first port must be >0, or use '*' for wildcard")
|
||||
ErrInvalidPortNumber = errors.New("invalid port number")
|
||||
ErrPortOutOfRange = errors.New("port number out of range")
|
||||
)
|
||||
|
||||
// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid.
|
||||
func splitDestinationAndPort(input string) (string, string, error) {
|
||||
// Find the last occurrence of the colon character
|
||||
|
|
@ -16,13 +31,15 @@ func splitDestinationAndPort(input string) (string, string, error) {
|
|||
|
||||
// Check if the colon character is present and not at the beginning or end of the string
|
||||
if lastColonIndex == -1 {
|
||||
return "", "", errors.New("input must contain a colon character separating destination and port")
|
||||
return "", "", ErrInputMissingColon
|
||||
}
|
||||
|
||||
if lastColonIndex == 0 {
|
||||
return "", "", errors.New("input cannot start with a colon character")
|
||||
return "", "", ErrInputStartsWithColon
|
||||
}
|
||||
|
||||
if lastColonIndex == len(input)-1 {
|
||||
return "", "", errors.New("input cannot end with a colon character")
|
||||
return "", "", ErrInputEndsWithColon
|
||||
}
|
||||
|
||||
// Split the string into destination and port based on the last colon
|
||||
|
|
@ -45,11 +62,12 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||
for part := range parts {
|
||||
if strings.Contains(part, "-") {
|
||||
rangeParts := strings.Split(part, "-")
|
||||
|
||||
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
|
||||
return e == ""
|
||||
})
|
||||
if len(rangeParts) != 2 {
|
||||
return nil, errors.New("invalid port range format")
|
||||
if len(rangeParts) != portRangeParts {
|
||||
return nil, ErrInvalidPortRange
|
||||
}
|
||||
|
||||
first, err := parsePort(rangeParts[0])
|
||||
|
|
@ -63,7 +81,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||
}
|
||||
|
||||
if first > last {
|
||||
return nil, errors.New("invalid port range: first port is greater than last port")
|
||||
return nil, ErrPortRangeInverted
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last})
|
||||
|
|
@ -74,7 +92,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||
}
|
||||
|
||||
if port < 1 {
|
||||
return nil, errors.New("first port must be >0, or use '*' for wildcard")
|
||||
return nil, ErrPortMustBePositive
|
||||
}
|
||||
|
||||
portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port})
|
||||
|
|
@ -88,11 +106,11 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||
func parsePort(portStr string) (uint16, error) {
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return 0, errors.New("invalid port number")
|
||||
return 0, ErrInvalidPortNumber
|
||||
}
|
||||
|
||||
if port < 0 || port > 65535 {
|
||||
return 0, errors.New("port number out of range")
|
||||
return 0, ErrPortOutOfRange
|
||||
}
|
||||
|
||||
return uint16(port), nil
|
||||
|
|
|
|||
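The hunks above replace ad-hoc errors.New calls with package-level sentinel errors so callers can match them with errors.Is. A small sketch of the pattern, reusing two of the sentinel names from the diff but with an abbreviated parser body:

package main

import (
	"errors"
	"fmt"
	"strconv"
)

// Package-level sentinels: callers compare with errors.Is instead of
// matching error strings.
var (
	ErrInvalidPortNumber = errors.New("invalid port number")
	ErrPortOutOfRange    = errors.New("port number out of range")
)

func parsePort(s string) (uint16, error) {
	port, err := strconv.Atoi(s)
	if err != nil {
		return 0, ErrInvalidPortNumber
	}

	if port < 0 || port > 65535 {
		return 0, ErrPortOutOfRange
	}

	return uint16(port), nil
}

func main() {
	_, err := parsePort("65536")
	fmt.Println(errors.Is(err, ErrPortOutOfRange)) // true
}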
|
|
@ -24,14 +24,14 @@ func TestParseDestinationAndPort(t *testing.T) {
|
|||
{"tag:api-server:443", "tag:api-server", "443", nil},
|
||||
{"example-host-1:*", "example-host-1", "*", nil},
|
||||
{"hostname:80-90", "hostname", "80-90", nil},
|
||||
{"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")},
|
||||
{":invalid", "", "", errors.New("input cannot start with a colon character")},
|
||||
{"invalid:", "", "", errors.New("input cannot end with a colon character")},
|
||||
{"invalidinput", "", "", ErrInputMissingColon},
|
||||
{":invalid", "", "", ErrInputStartsWithColon},
|
||||
{"invalid:", "", "", ErrInputEndsWithColon},
|
||||
}
|
||||
|
||||
for _, testCase := range testCases {
|
||||
dst, port, err := splitDestinationAndPort(testCase.input)
|
||||
if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) {
|
||||
if dst != testCase.expectedDst || port != testCase.expectedPort || !errors.Is(err, testCase.expectedErr) {
|
||||
t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)",
|
||||
testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr)
|
||||
}
|
||||
|
|
@ -42,25 +42,23 @@ func TestParsePort(t *testing.T) {
|
|||
tests := []struct {
|
||||
input string
|
||||
expected uint16
|
||||
err string
|
||||
err error
|
||||
}{
|
||||
{"80", 80, ""},
|
||||
{"0", 0, ""},
|
||||
{"65535", 65535, ""},
|
||||
{"-1", 0, "port number out of range"},
|
||||
{"65536", 0, "port number out of range"},
|
||||
{"abc", 0, "invalid port number"},
|
||||
{"", 0, "invalid port number"},
|
||||
{"80", 80, nil},
|
||||
{"0", 0, nil},
|
||||
{"65535", 65535, nil},
|
||||
{"-1", 0, ErrPortOutOfRange},
|
||||
{"65536", 0, ErrPortOutOfRange},
|
||||
{"abc", 0, ErrInvalidPortNumber},
|
||||
{"", 0, ErrInvalidPortNumber},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePort(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
if !errors.Is(err, test.err) {
|
||||
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
|
||||
if result != test.expected {
|
||||
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
|
||||
}
|
||||
|
|
@ -71,30 +69,28 @@ func TestParsePortRange(t *testing.T) {
|
|||
tests := []struct {
|
||||
input string
|
||||
expected []tailcfg.PortRange
|
||||
err string
|
||||
err error
|
||||
}{
|
||||
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""},
|
||||
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""},
|
||||
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""},
|
||||
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""},
|
||||
{"80-", nil, "invalid port range format"},
|
||||
{"-90", nil, "invalid port range format"},
|
||||
{"80-90,", nil, "invalid port number"},
|
||||
{"80,90-", nil, "invalid port range format"},
|
||||
{"80-90,abc", nil, "invalid port number"},
|
||||
{"80-90,65536", nil, "port number out of range"},
|
||||
{"80-90,90-80", nil, "invalid port range: first port is greater than last port"},
|
||||
{"80", []tailcfg.PortRange{{First: 80, Last: 80}}, nil},
|
||||
{"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, nil},
|
||||
{"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, nil},
|
||||
{"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, nil},
|
||||
{"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, nil},
|
||||
{"80-", nil, ErrInvalidPortRange},
|
||||
{"-90", nil, ErrInvalidPortRange},
|
||||
{"80-90,", nil, ErrInvalidPortNumber},
|
||||
{"80,90-", nil, ErrInvalidPortRange},
|
||||
{"80-90,abc", nil, ErrInvalidPortNumber},
|
||||
{"80-90,65536", nil, ErrPortOutOfRange},
|
||||
{"80-90,90-80", nil, ErrPortRangeInverted},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
result, err := parsePortRange(test.input)
|
||||
if err != nil && err.Error() != test.err {
|
||||
if !errors.Is(err, test.err) {
|
||||
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(result, test.expected); diff != "" {
|
||||
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
|
||||
}
|
||||
|
|
|
|||
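The test changes above drop err.Error() string comparison in favour of errors.Is against the expected sentinel, with a nil expectation covering the success rows. A hedged, standalone illustration of that comparison style; validate and errBadInput are made up for the example:

package example

import (
	"errors"
	"testing"
)

var errBadInput = errors.New("bad input")

func validate(s string) error {
	if s == "" {
		return errBadInput
	}

	return nil
}

// TestValidate compares errors with errors.Is, the same style the
// updated parsePort/parsePortRange tests use; a nil expected error
// matches only a nil returned error.
func TestValidate(t *testing.T) {
	tests := []struct {
		input string
		want  error
	}{
		{"ok", nil},
		{"", errBadInput},
	}

	for _, tc := range tests {
		if err := validate(tc.input); !errors.Is(err, tc.want) {
			t.Errorf("validate(%q) = %v, want %v", tc.input, err, tc.want)
		}
	}
}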
|
|
@ -152,6 +152,7 @@ func (m *mapSession) serveLongPoll() {
|
|||
// This is not my favourite solution, but it kind of works in our eventually consistent world.
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
disconnected := true
|
||||
// Wait up to 10 seconds for the node to reconnect.
|
||||
// 10 seconds was arbitrarily chosen as a reasonable time to reconnect.
|
||||
|
|
@ -160,6 +161,7 @@ func (m *mapSession) serveLongPoll() {
|
|||
disconnected = false
|
||||
break
|
||||
}
|
||||
|
||||
<-ticker.C
|
||||
}
|
||||
|
||||
|
|
@ -212,11 +214,14 @@ func (m *mapSession) serveLongPoll() {
|
|||
// adding this before connecting it to the state ensures that
|
||||
// it does not miss any updates that might be sent in the split
|
||||
// time between the node connecting and the batcher being ready.
|
||||
//nolint:noinlineerr
|
||||
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
|
||||
m.errf(err, "failed to add node to batcher")
|
||||
log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher")
|
||||
|
||||
m.h.Change(mapReqChange)
|
||||
|
|
@ -245,7 +250,8 @@ func (m *mapSession) serveLongPoll() {
|
|||
return
|
||||
}
|
||||
|
||||
if err := m.writeMap(update); err != nil {
|
||||
err := m.writeMap(update)
|
||||
if err != nil {
|
||||
m.errf(err, "cannot write update to client")
|
||||
return
|
||||
}
|
||||
|
|
@ -254,7 +260,8 @@ func (m *mapSession) serveLongPoll() {
|
|||
m.resetKeepAlive()
|
||||
|
||||
case <-m.keepAliveTicker.C:
|
||||
if err := m.writeMap(&keepAlive); err != nil {
|
||||
err := m.writeMap(&keepAlive)
|
||||
if err != nil {
|
||||
m.errf(err, "cannot write keep alive")
|
||||
return
|
||||
}
|
||||
|
|
@ -282,8 +289,9 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
|
|||
jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression)
|
||||
}
|
||||
|
||||
//nolint:prealloc
|
||||
data := make([]byte, reservedResponseHeaderSize)
|
||||
//nolint:gosec // G115: JSON response size will not exceed uint32 max
|
||||
//nolint:gosec
|
||||
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
|
||||
data = append(data, jsonBody...)
|
||||
|
||||
|
|
@ -328,13 +336,13 @@ func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event {
|
|||
Str("node.name", m.node.Hostname)
|
||||
}
|
||||
|
||||
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
|
||||
//nolint:zerologlint
|
||||
func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) }
|
||||
|
||||
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
|
||||
//nolint:zerologlint
|
||||
func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) }
|
||||
|
||||
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
|
||||
//nolint:zerologlint
|
||||
func (m *mapSession) errf(err error, msg string, a ...any) {
|
||||
m.logf(log.Error().Caller()).Err(err).Msgf(msg, a...)
|
||||
}
|
||||
|
|
|
|||
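writeMap above reserves a fixed-size header, stores the body length with binary.LittleEndian.PutUint32, and appends the (optionally zstd-compressed) JSON body. A stripped-down sketch of that length-prefixed framing; frame, readFrame, and the 4-byte headerSize are assumptions for illustration, not headscale's exact wire constants.

package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
	"io"
)

// headerSize is an assumption for this sketch: a 4-byte little-endian
// length prefix in front of every message.
const headerSize = 4

// frame prepends the payload length, the same framing writeMap applies
// before handing a MapResponse body to the response writer.
func frame(payload []byte) []byte {
	data := make([]byte, headerSize, headerSize+len(payload))
	binary.LittleEndian.PutUint32(data, uint32(len(payload)))

	return append(data, payload...)
}

// readFrame is the counterpart a client would run: read the prefix,
// then exactly that many payload bytes.
func readFrame(r io.Reader) ([]byte, error) {
	var hdr [headerSize]byte
	if _, err := io.ReadFull(r, hdr[:]); err != nil {
		return nil, err
	}

	payload := make([]byte, binary.LittleEndian.Uint32(hdr[:]))
	_, err := io.ReadFull(r, payload)

	return payload, err
}

func main() {
	buf := bytes.NewReader(frame([]byte(`{"keepalive":true}`)))
	msg, _ := readFrame(buf)
	fmt.Println(string(msg))
}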
|
|
@ -4,7 +4,6 @@ import (
|
|||
"fmt"
|
||||
"net/netip"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
|
@ -57,7 +56,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
|
|||
// this is important so the same node is chosen two times in a row
|
||||
// as the primary route.
|
||||
ids := types.NodeIDs(xmaps.Keys(pr.routes))
|
||||
sort.Sort(ids)
|
||||
slices.Sort(ids)
|
||||
|
||||
// Create a map of prefixes to nodes that serve them so we
|
||||
// can determine the primary route for each prefix.
|
||||
|
|
@ -108,9 +107,11 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
|
|||
Msg("Current primary no longer available")
|
||||
}
|
||||
}
|
||||
|
||||
if len(nodes) >= 1 {
|
||||
pr.primaries[prefix] = nodes[0]
|
||||
changed = true
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("prefix", prefix.String()).
|
||||
|
|
@ -127,6 +128,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
|
|||
Str("prefix", prefix.String()).
|
||||
Msg("Cleaning up primary route that no longer has available nodes")
|
||||
delete(pr.primaries, prefix)
|
||||
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
|
@ -162,14 +164,18 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix)
|
|||
// If no routes are being set, remove the node from the routes map.
|
||||
if len(prefixes) == 0 {
|
||||
wasPresent := false
|
||||
|
||||
if _, ok := pr.routes[node]; ok {
|
||||
delete(pr.routes, node)
|
||||
|
||||
wasPresent = true
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", node.Uint64()).
|
||||
Msg("Removed node from primary routes (no prefixes)")
|
||||
}
|
||||
|
||||
changed := pr.updatePrimaryLocked()
|
||||
log.Debug().
|
||||
Caller().
|
||||
|
|
@ -236,7 +242,7 @@ func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix {
|
|||
}
|
||||
}
|
||||
|
||||
tsaddr.SortPrefixes(routes)
|
||||
slices.SortFunc(routes, netip.Prefix.Compare)
|
||||
|
||||
return routes
|
||||
}
|
||||
|
|
@ -254,13 +260,15 @@ func (pr *PrimaryRoutes) stringLocked() string {
|
|||
fmt.Fprintln(&sb, "Available routes:")
|
||||
|
||||
ids := types.NodeIDs(xmaps.Keys(pr.routes))
|
||||
sort.Sort(ids)
|
||||
slices.Sort(ids)
|
||||
|
||||
for _, id := range ids {
|
||||
prefixes := pr.routes[id]
|
||||
fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", "))
|
||||
}
|
||||
|
||||
fmt.Fprintln(&sb, "\n\nCurrent primary routes:")
|
||||
|
||||
for route, nodeID := range pr.primaries {
|
||||
fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID)
|
||||
}
|
||||
|
|
@ -294,7 +302,7 @@ func (pr *PrimaryRoutes) DebugJSON() DebugRoutes {
|
|||
// Populate available routes
|
||||
for nodeID, routes := range pr.routes {
|
||||
prefixes := routes.Slice()
|
||||
tsaddr.SortPrefixes(prefixes)
|
||||
slices.SortFunc(prefixes, netip.Prefix.Compare)
|
||||
debug.AvailableRoutes[nodeID] = prefixes
|
||||
}
|
||||
|
||||
|
|
|
|||
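The routes hunks swap sort.Sort and tsaddr.SortPrefixes for slices.Sort and slices.SortFunc; the sort exists so that primary-route election stays deterministic and the same node keeps winning for a given prefix. A toy sketch of that idea, with invented names (electPrimaries, nodeID) and without the real code's preference for keeping the currently elected primary when it is still available:

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

type nodeID uint64

// electPrimaries picks, for every advertised prefix, the lowest node ID
// that still serves it. Sorting the IDs first keeps the choice stable
// across calls, which is the point of the slices.Sort in updatePrimaryLocked.
func electPrimaries(routes map[nodeID][]netip.Prefix) map[netip.Prefix]nodeID {
	ids := make([]nodeID, 0, len(routes))
	for id := range routes {
		ids = append(ids, id)
	}
	slices.Sort(ids)

	primaries := make(map[netip.Prefix]nodeID)
	for _, id := range ids {
		for _, p := range routes[id] {
			if _, ok := primaries[p]; !ok {
				primaries[p] = id
			}
		}
	}

	return primaries
}

func main() {
	pfx := netip.MustParsePrefix("192.168.1.0/24")
	fmt.Println(electPrimaries(map[nodeID][]netip.Prefix{
		2: {pfx},
		1: {pfx},
	})) // map[192.168.1.0/24:1] regardless of map iteration order
}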
|
|
@ -130,6 +130,7 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
pr.SetRoutes(2, mp("192.168.2.0/24"))
|
||||
pr.SetRoutes(1) // Deregister by setting no routes
|
||||
|
||||
return pr.SetRoutes(1, mp("192.168.3.0/24"))
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
|
|
@ -153,8 +154,9 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
{
|
||||
name: "multiple-nodes-register-same-route",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
|
||||
|
||||
return pr.SetRoutes(3, mp("192.168.1.0/24")) // false
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
|
|
@ -182,7 +184,8 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
return pr.SetRoutes(1) // true, 2 primary
|
||||
|
||||
return pr.SetRoutes(1) // true, 2 primary
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
2: {
|
||||
|
|
@ -393,6 +396,7 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0"))
|
||||
pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0"))
|
||||
|
||||
return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0"))
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
|
|
@ -413,15 +417,20 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
operations: func(pr *PrimaryRoutes) bool {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var change1, change2 bool
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
change1 = pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
change2 = pr.SetRoutes(2, mp("192.168.2.0/24"))
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return change1 || change2
|
||||
|
|
@ -449,17 +458,21 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pr := New()
|
||||
|
||||
change := tt.operations(pr)
|
||||
if change != tt.expectedChange {
|
||||
t.Errorf("change = %v, want %v", change, tt.expectedChange)
|
||||
}
|
||||
|
||||
comps := append(util.Comparers, cmpopts.EquateEmpty())
|
||||
if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" {
|
||||
t.Errorf("routes mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" {
|
||||
t.Errorf("primaries mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" {
|
||||
t.Errorf("isPrimary mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ func (s *State) DebugOverview() string {
|
|||
ephemeralCount := 0
|
||||
|
||||
now := time.Now()
|
||||
|
||||
for _, node := range allNodes.All() {
|
||||
if node.Valid() {
|
||||
userName := node.Owner().Name()
|
||||
|
|
@ -103,17 +104,21 @@ func (s *State) DebugOverview() string {
|
|||
|
||||
// User statistics
|
||||
sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users)))
|
||||
|
||||
for userName, nodeCount := range userNodeCounts {
|
||||
sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Policy information
|
||||
sb.WriteString("Policy:\n")
|
||||
sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode))
|
||||
|
||||
if s.cfg.Policy.Mode == types.PolicyModeFile {
|
||||
sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// DERP information
|
||||
|
|
@ -123,6 +128,7 @@ func (s *State) DebugOverview() string {
|
|||
} else {
|
||||
sb.WriteString("DERP: not configured\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Route information
|
||||
|
|
@ -130,6 +136,7 @@ func (s *State) DebugOverview() string {
|
|||
if s.primaryRoutes.String() == "" {
|
||||
routeCount = 0
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount))
|
||||
sb.WriteString("\n")
|
||||
|
||||
|
|
@ -165,10 +172,12 @@ func (s *State) DebugDERPMap() string {
|
|||
for _, node := range region.Nodes {
|
||||
sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n",
|
||||
node.Name, node.HostName, node.DERPPort))
|
||||
|
||||
if node.STUNPort != 0 {
|
||||
sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort))
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
|
|
@ -236,7 +245,7 @@ func (s *State) DebugPolicy() (string, error) {
|
|||
|
||||
return string(pol), nil
|
||||
default:
|
||||
return "", fmt.Errorf("unsupported policy mode: %s", s.cfg.Policy.Mode)
|
||||
return "", fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, s.cfg.Policy.Mode)
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -319,6 +328,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo {
|
|||
if s.primaryRoutes.String() == "" {
|
||||
routeCount = 0
|
||||
}
|
||||
|
||||
info.PrimaryRoutes = routeCount
|
||||
|
||||
return info
|
||||
|
|
|
|||
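DebugPolicy above now wraps a sentinel with fmt.Errorf("%w: %s", ...) instead of formatting a one-off string, so the offending mode still appears in the message while callers can test for the sentinel. A small sketch; ErrUnsupportedPolicyMode matches the name in the diff, the rest (policyText, the mode names) is illustrative.

package main

import (
	"errors"
	"fmt"
)

// ErrUnsupportedPolicyMode mirrors the sentinel the hunk introduces.
var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode")

func policyText(mode string) (string, error) {
	switch mode {
	case "file", "database":
		return "{}", nil
	default:
		// Wrapping with %w keeps the detail in the message but still
		// lets callers match the sentinel with errors.Is.
		return "", fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, mode)
	}
}

func main() {
	_, err := policyText("bogus")
	fmt.Println(errors.Is(err, ErrUnsupportedPolicyMode)) // true
	fmt.Println(err)                                      // unsupported policy mode: bogus
}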
|
|
@ -8,7 +8,6 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
// TestEphemeralNodeDeleteWithConcurrentUpdate tests the race condition where UpdateNode and DeleteNode
|
||||
|
|
@ -21,6 +20,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
|
|||
|
||||
// Create NodeStore
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -44,20 +44,26 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
|
|||
// 6. If DELETE came after UPDATE, the returned node should be invalid
|
||||
|
||||
done := make(chan bool, 2)
|
||||
var updatedNode types.NodeView
|
||||
var updateOk bool
|
||||
|
||||
var (
|
||||
updatedNode types.NodeView
|
||||
updateOk bool
|
||||
)
|
||||
|
||||
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
|
||||
|
||||
go func() {
|
||||
updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
n.LastSeen = new(time.Now())
|
||||
})
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
|
||||
go func() {
|
||||
store.DeleteNode(node.ID)
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
|
|
@ -91,6 +97,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) {
|
|||
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -106,7 +113,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) {
|
|||
// Start UpdateNode in goroutine - it will queue and wait for batch
|
||||
go func() {
|
||||
node, ok := store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
n.LastSeen = new(time.Now())
|
||||
})
|
||||
resultChan <- struct {
|
||||
node types.NodeView
|
||||
|
|
@ -148,6 +155,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) {
|
|||
node := createTestNode(3, 1, "test-user", "test-node-3")
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -156,7 +164,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) {
|
|||
|
||||
// Simulate UpdateNode being called
|
||||
updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
n.LastSeen = new(time.Now())
|
||||
})
|
||||
require.True(t, ok, "UpdateNode should succeed")
|
||||
require.True(t, updatedNode.Valid(), "UpdateNode should return valid node")
|
||||
|
|
@ -204,6 +212,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
|||
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -214,21 +223,26 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
|||
// 1. UpdateNode (from UpdateNodeFromMapRequest during polling)
|
||||
// 2. DeleteNode (from handleLogout when client sends logout request)
|
||||
|
||||
var updatedNode types.NodeView
|
||||
var updateOk bool
|
||||
var (
|
||||
updatedNode types.NodeView
|
||||
updateOk bool
|
||||
)
|
||||
|
||||
done := make(chan bool, 2)
|
||||
|
||||
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
|
||||
go func() {
|
||||
updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
n.LastSeen = new(time.Now())
|
||||
})
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
|
||||
go func() {
|
||||
store.DeleteNode(ephemeralNode.ID)
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
|
|
@ -267,7 +281,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
|||
// 5. UpdateNode and DeleteNode batch together
|
||||
// 6. UpdateNode returns a valid node (from before delete in batch)
|
||||
// 7. persistNodeToDB is called with the stale valid node
|
||||
// 8. Node gets re-inserted into database instead of staying deleted
|
||||
// 8. Node gets re-inserted into database instead of staying deleted.
|
||||
func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
|
||||
ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5")
|
||||
ephemeralNode.AuthKey = &types.PreAuthKey{
|
||||
|
|
@ -279,6 +293,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
|
|||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -294,7 +309,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
|
|||
|
||||
go func() {
|
||||
node, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
n.LastSeen = new(time.Now())
|
||||
endpoint := netip.MustParseAddrPort("10.0.0.1:41641")
|
||||
n.Endpoints = []netip.AddrPort{endpoint}
|
||||
})
|
||||
|
|
@ -349,6 +364,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
|
|||
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -363,7 +379,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
|
|||
|
||||
go func() {
|
||||
updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
n.LastSeen = new(time.Now())
|
||||
})
|
||||
updateDone <- struct {
|
||||
node types.NodeView
|
||||
|
|
@ -399,7 +415,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
|
|||
// 3. UpdateNode and DeleteNode batch together
|
||||
// 4. UpdateNode returns a valid node (from before delete in batch)
|
||||
// 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node
|
||||
// 6. persistNodeToDB must detect the node is deleted and refuse to persist
|
||||
// 6. persistNodeToDB must detect the node is deleted and refuse to persist.
|
||||
func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
|
||||
ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7")
|
||||
ephemeralNode.AuthKey = &types.PreAuthKey{
|
||||
|
|
@ -409,6 +425,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
|
|||
}
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -417,7 +434,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
|
|||
|
||||
// UpdateNode returns a node
|
||||
updatedNode, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(time.Now())
|
||||
n.LastSeen = new(time.Now())
|
||||
})
|
||||
require.True(t, ok, "UpdateNode should succeed")
|
||||
require.True(t, updatedNode.Valid(), "updated node should be valid")
|
||||
|
|
|
|||
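The tests above pin down a race where an UpdateNode that lands in the same batch as a DeleteNode can hand back a stale-but-valid node, and persisting it blindly would re-insert a logged-out ephemeral node. A toy sketch of the guard they expect, with an invented store type and a persist helper that re-checks existence before any database write:

package main

import (
	"errors"
	"fmt"
)

var errNodeGone = errors.New("node deleted; refusing to persist")

// store is a stand-in for the NodeStore: the only property that
// matters here is being able to ask whether a node still exists.
type store struct {
	nodes map[uint64]string
}

func (s *store) exists(id uint64) bool {
	_, ok := s.nodes[id]
	return ok
}

// persist re-checks the in-memory store before writing, so an update
// that raced with a delete cannot re-insert the node into the database.
func persist(s *store, id uint64, hostname string) error {
	if !s.exists(id) {
		return errNodeGone
	}

	// The database write would go here.
	fmt.Printf("persisted node %d (%s)\n", id, hostname)

	return nil
}

func main() {
	s := &store{nodes: map[uint64]string{5: "ephemeral-node-5"}}
	delete(s.nodes, 5) // logout removed the ephemeral node
	fmt.Println(persist(s, 5, "stale-update"))
}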
|
|
@ -29,6 +29,7 @@ func netInfoFromMapRequest(
|
|||
Uint64("node.id", nodeID.Uint64()).
|
||||
Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP).
|
||||
Msg("using NetInfo from previous Hostinfo in MapRequest")
|
||||
|
||||
return currentHostinfo.NetInfo
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,15 +1,12 @@
|
|||
package state
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
func TestNetInfoFromMapRequest(t *testing.T) {
|
||||
|
|
@ -136,26 +133,3 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) {
|
|||
assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node")
|
||||
})
|
||||
}
|
||||
|
||||
// Simple helper function for tests
|
||||
func createTestNodeSimple(id types.NodeID) *types.Node {
|
||||
user := types.User{
|
||||
Name: "test-user",
|
||||
}
|
||||
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
node := &types.Node{
|
||||
ID: id,
|
||||
Hostname: "test-node",
|
||||
UserID: ptr.To(uint(id)),
|
||||
User: &user,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
IPv4: &netip.Addr{},
|
||||
IPv6: &netip.Addr{},
|
||||
}
|
||||
|
||||
return node
|
||||
}
|
||||
|
|
|
|||
|
|
@ -55,8 +55,8 @@ var (
|
|||
})
|
||||
nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
Name: "nodestore_nodes_total",
|
||||
Help: "Total number of nodes in the NodeStore",
|
||||
Name: "nodestore_nodes",
|
||||
Help: "Number of nodes in the NodeStore",
|
||||
})
|
||||
nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{
|
||||
Namespace: prometheusNamespace,
|
||||
|
|
@ -97,6 +97,7 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batc
|
|||
for _, n := range allNodes {
|
||||
nodes[n.ID] = *n
|
||||
}
|
||||
|
||||
snap := snapshotFromNodes(nodes, peersFunc)
|
||||
|
||||
store := &NodeStore{
|
||||
|
|
@ -165,11 +166,14 @@ func (s *NodeStore) PutNode(n types.Node) types.NodeView {
|
|||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
|
||||
s.writeQueue <- work
|
||||
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
resultNode := <-work.nodeResult
|
||||
|
||||
nodeStoreOperations.WithLabelValues("put").Inc()
|
||||
|
||||
return resultNode
|
||||
|
|
@ -205,11 +209,14 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
|
|||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
|
||||
s.writeQueue <- work
|
||||
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
resultNode := <-work.nodeResult
|
||||
|
||||
nodeStoreOperations.WithLabelValues("update").Inc()
|
||||
|
||||
// Return the node and whether it exists (is valid)
|
||||
|
|
@ -229,7 +236,9 @@ func (s *NodeStore) DeleteNode(id types.NodeID) {
|
|||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
|
||||
s.writeQueue <- work
|
||||
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
|
|
@ -262,8 +271,10 @@ func (s *NodeStore) processWrite() {
|
|||
if len(batch) != 0 {
|
||||
s.applyBatch(batch)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
batch = append(batch, w)
|
||||
if len(batch) >= s.batchSize {
|
||||
s.applyBatch(batch)
|
||||
|
|
@ -321,6 +332,7 @@ func (s *NodeStore) applyBatch(batch []work) {
|
|||
w.updateFn(&n)
|
||||
nodes[w.nodeID] = n
|
||||
}
|
||||
|
||||
if w.nodeResult != nil {
|
||||
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
|
||||
}
|
||||
|
|
@ -349,12 +361,14 @@ func (s *NodeStore) applyBatch(batch []work) {
|
|||
nodeView := node.View()
|
||||
for _, w := range workItems {
|
||||
w.nodeResult <- nodeView
|
||||
|
||||
close(w.nodeResult)
|
||||
}
|
||||
} else {
|
||||
// Node was deleted or doesn't exist
|
||||
for _, w := range workItems {
|
||||
w.nodeResult <- types.NodeView{} // Send invalid view
|
||||
|
||||
close(w.nodeResult)
|
||||
}
|
||||
}
|
||||
|
|
@ -400,6 +414,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
|
|||
peersByNode: func() map[types.NodeID][]types.NodeView {
|
||||
peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
|
||||
defer peersTimer.ObserveDuration()
|
||||
|
||||
return peersFunc(allNodes)
|
||||
}(),
|
||||
nodesByUser: make(map[types.UserID][]types.NodeView),
|
||||
|
|
@ -417,6 +432,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
|
|||
if newSnap.nodesByMachineKey[n.MachineKey] == nil {
|
||||
newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView)
|
||||
}
|
||||
|
||||
newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView
|
||||
}
|
||||
|
||||
|
|
@ -511,10 +527,12 @@ func (s *NodeStore) DebugString() string {
|
|||
|
||||
// User distribution (shows internal UserID tracking, not display owner)
|
||||
sb.WriteString("Nodes by Internal User ID:\n")
|
||||
|
||||
for userID, nodes := range snapshot.nodesByUser {
|
||||
if len(nodes) > 0 {
|
||||
userName := "unknown"
|
||||
taggedCount := 0
|
||||
|
||||
if len(nodes) > 0 && nodes[0].Valid() {
|
||||
userName = nodes[0].User().Name()
|
||||
// Count tagged nodes (which have UserID set but are owned by "tagged-devices")
|
||||
|
|
@ -532,23 +550,29 @@ func (s *NodeStore) DebugString() string {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Peer relationships summary
|
||||
sb.WriteString("Peer Relationships:\n")
|
||||
|
||||
totalPeers := 0
|
||||
|
||||
for nodeID, peers := range snapshot.peersByNode {
|
||||
peerCount := len(peers)
|
||||
|
||||
totalPeers += peerCount
|
||||
if node, exists := snapshot.nodesByID[nodeID]; exists {
|
||||
sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n",
|
||||
nodeID, node.Hostname, peerCount))
|
||||
}
|
||||
}
|
||||
|
||||
if len(snapshot.peersByNode) > 0 {
|
||||
avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode))
|
||||
sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Node key index
|
||||
|
|
@ -591,6 +615,7 @@ func (s *NodeStore) RebuildPeerMaps() {
|
|||
}
|
||||
|
||||
s.writeQueue <- w
|
||||
|
||||
<-result
|
||||
}
|
||||
|
||||
|
|
|
|||
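PutNode, UpdateNode, and DeleteNode above all follow the same shape: enqueue a work item, bump the queue-depth gauge, and block on a result channel until processWrite has applied the batch containing it. A condensed sketch of that write-queue/batch pattern; op, run, and the map payload are invented, and the real store additionally returns per-node views on a second channel and records Prometheus metrics.

package main

import (
	"fmt"
	"time"
)

// op is a queued write plus a channel the caller blocks on until the
// batch containing it has been applied.
type op struct {
	apply func(m map[int]string)
	done  chan struct{}
}

// run drains the queue, applying work in batches of up to batchSize or
// whenever the timer fires, then releases every waiting caller.
func run(queue <-chan op, batchSize int, timeout time.Duration) {
	m := make(map[int]string)

	var batch []op
	flush := func() {
		for _, o := range batch {
			o.apply(m)
		}
		for _, o := range batch {
			close(o.done)
		}
		batch = nil
	}

	timer := time.NewTimer(timeout)
	for {
		select {
		case o, ok := <-queue:
			if !ok {
				flush()
				return
			}
			batch = append(batch, o)
			if len(batch) >= batchSize {
				flush()
			}
		case <-timer.C:
			flush()
			timer.Reset(timeout)
		}
	}
}

func main() {
	queue := make(chan op)
	go run(queue, 2, 50*time.Millisecond)

	w := op{apply: func(m map[int]string) { m[1] = "node1" }, done: make(chan struct{})}
	queue <- w
	<-w.done // returns once the batch containing this write was applied
	fmt.Println("write applied")
	close(queue)
}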
|
|
@ -2,6 +2,7 @@ package state
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
|
@ -13,7 +14,13 @@ import (
|
|||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
// Test sentinel errors for concurrent operations.
|
||||
var (
|
||||
errTestUpdateNodeFailed = errors.New("UpdateNode failed")
|
||||
errTestGetNodeFailed = errors.New("GetNode failed")
|
||||
errTestPutNodeFailed = errors.New("PutNode failed")
|
||||
)
|
||||
|
||||
func TestSnapshotFromNodes(t *testing.T) {
|
||||
|
|
@ -33,6 +40,7 @@ func TestSnapshotFromNodes(t *testing.T) {
|
|||
return nodes, peersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
t.Helper()
|
||||
assert.Empty(t, snapshot.nodesByID)
|
||||
assert.Empty(t, snapshot.allNodes)
|
||||
assert.Empty(t, snapshot.peersByNode)
|
||||
|
|
@ -45,9 +53,11 @@ func TestSnapshotFromNodes(t *testing.T) {
|
|||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
}
|
||||
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
t.Helper()
|
||||
assert.Len(t, snapshot.nodesByID, 1)
|
||||
assert.Len(t, snapshot.allNodes, 1)
|
||||
assert.Len(t, snapshot.peersByNode, 1)
|
||||
|
|
@ -71,6 +81,7 @@ func TestSnapshotFromNodes(t *testing.T) {
|
|||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
t.Helper()
|
||||
assert.Len(t, snapshot.nodesByID, 2)
|
||||
assert.Len(t, snapshot.allNodes, 2)
|
||||
assert.Len(t, snapshot.peersByNode, 2)
|
||||
|
|
@ -96,6 +107,7 @@ func TestSnapshotFromNodes(t *testing.T) {
|
|||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
t.Helper()
|
||||
assert.Len(t, snapshot.nodesByID, 3)
|
||||
assert.Len(t, snapshot.allNodes, 3)
|
||||
assert.Len(t, snapshot.peersByNode, 3)
|
||||
|
|
@ -125,6 +137,7 @@ func TestSnapshotFromNodes(t *testing.T) {
|
|||
return nodes, peersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
t.Helper()
|
||||
assert.Len(t, snapshot.nodesByID, 4)
|
||||
assert.Len(t, snapshot.allNodes, 4)
|
||||
assert.Len(t, snapshot.peersByNode, 4)
|
||||
|
|
@ -174,7 +187,7 @@ func createTestNode(nodeID types.NodeID, userID uint, username, hostname string)
|
|||
DiscoKey: discoKey.Public(),
|
||||
Hostname: hostname,
|
||||
GivenName: hostname,
|
||||
UserID: ptr.To(userID),
|
||||
UserID: new(userID),
|
||||
User: &types.User{
|
||||
Name: username,
|
||||
DisplayName: username,
|
||||
|
|
@ -193,11 +206,13 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
|||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
|
||||
for _, n := range nodes {
|
||||
if n.ID() != node.ID() {
|
||||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
|
||||
|
|
@ -208,6 +223,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
|||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
|
||||
nodeIsOdd := node.ID()%2 == 1
|
||||
|
||||
for _, n := range nodes {
|
||||
|
|
@ -222,6 +238,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
|||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
|
||||
|
|
@ -237,6 +254,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
{
|
||||
name: "create empty store and add single node",
|
||||
setupFunc: func(t *testing.T) *NodeStore {
|
||||
t.Helper()
|
||||
return NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
},
|
||||
steps: []testStep{
|
||||
|
|
@ -455,10 +473,13 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
// Add nodes in sequence
|
||||
n1 := store.PutNode(createTestNode(1, 1, "user1", "node1"))
|
||||
assert.True(t, n1.Valid())
|
||||
|
||||
n2 := store.PutNode(createTestNode(2, 2, "user2", "node2"))
|
||||
assert.True(t, n2.Valid())
|
||||
|
||||
n3 := store.PutNode(createTestNode(3, 3, "user3", "node3"))
|
||||
assert.True(t, n3.Valid())
|
||||
|
||||
n4 := store.PutNode(createTestNode(4, 4, "user4", "node4"))
|
||||
assert.True(t, n4.Valid())
|
||||
|
||||
|
|
@ -526,16 +547,20 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var resultNode1, resultNode2 types.NodeView
|
||||
var newNode3 types.NodeView
|
||||
var ok1, ok2 bool
|
||||
var (
|
||||
resultNode1, resultNode2 types.NodeView
|
||||
newNode3 types.NodeView
|
||||
ok1, ok2 bool
|
||||
)
|
||||
|
||||
// These should all be processed in the same batch
|
||||
|
||||
go func() {
|
||||
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "batch-updated-node1"
|
||||
n.GivenName = "batch-given-1"
|
||||
})
|
||||
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
|
|
@ -544,12 +569,14 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
n.Hostname = "batch-updated-node2"
|
||||
n.GivenName = "batch-given-2"
|
||||
})
|
||||
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
node3 := createTestNode(3, 1, "user1", "node3")
|
||||
newNode3 = store.PutNode(node3)
|
||||
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
|
|
@ -602,20 +629,23 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
// This test verifies that when multiple updates to the same node
|
||||
// are batched together, each returned node reflects ALL changes
|
||||
// in the batch, not just the individual update's changes.
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var resultNode1, resultNode2, resultNode3 types.NodeView
|
||||
var ok1, ok2, ok3 bool
|
||||
var (
|
||||
resultNode1, resultNode2, resultNode3 types.NodeView
|
||||
ok1, ok2, ok3 bool
|
||||
)
|
||||
|
||||
// These updates all modify node 1 and should be batched together
|
||||
// The final state should have all three modifications applied
|
||||
|
||||
go func() {
|
||||
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "multi-update-hostname"
|
||||
})
|
||||
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
|
|
@ -623,6 +653,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.GivenName = "multi-update-givenname"
|
||||
})
|
||||
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
|
|
@ -630,6 +661,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Tags = []string{"tag1", "tag2"}
|
||||
})
|
||||
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
|
|
@ -723,14 +755,18 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var result1, result2, result3 types.NodeView
|
||||
var ok1, ok2, ok3 bool
|
||||
var (
|
||||
result1, result2, result3 types.NodeView
|
||||
ok1, ok2, ok3 bool
|
||||
)
|
||||
|
||||
// Start concurrent updates
|
||||
|
||||
go func() {
|
||||
result1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "concurrent-db-hostname"
|
||||
})
|
||||
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
|
|
@ -738,6 +774,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
result2, ok2 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.GivenName = "concurrent-db-given"
|
||||
})
|
||||
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
|
|
@ -745,6 +782,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Tags = []string{"concurrent-tag"}
|
||||
})
|
||||
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
|
|
@ -828,6 +866,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := tt.setupFunc(t)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -847,88 +886,107 @@ type testStep struct {
|
|||
|
||||
// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests ---
|
||||
|
||||
// Helper for concurrent test nodes
|
||||
// Helper for concurrent test nodes.
|
||||
func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
return types.Node{
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
MachineKey: machineKey.Public(),
|
||||
NodeKey: nodeKey.Public(),
|
||||
UserID: ptr.To(uint(1)),
|
||||
UserID: new(uint(1)),
|
||||
User: &types.User{
|
||||
Name: "concurrent-test-user",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// --- Concurrency: concurrent PutNode operations ---
|
||||
// --- Concurrency: concurrent PutNode operations ---.
|
||||
func TestNodeStoreConcurrentPutNode(t *testing.T) {
|
||||
const concurrentOps = 20
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
results := make(chan bool, concurrentOps)
|
||||
for i := range concurrentOps {
|
||||
wg.Add(1)
|
||||
|
||||
go func(nodeID int) {
|
||||
defer wg.Done()
|
||||
|
||||
//nolint:gosec
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node")
|
||||
|
||||
resultNode := store.PutNode(node)
|
||||
results <- resultNode.Valid()
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successCount := 0
|
||||
|
||||
for success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed")
|
||||
}
|
||||
|
||||
// --- Batching: concurrent ops fit in one batch ---
|
||||
// --- Batching: concurrent ops fit in one batch ---.
|
||||
func TestNodeStoreBatchingEfficiency(t *testing.T) {
|
||||
const batchSize = 10
|
||||
const ops = 15 // more than batchSize
|
||||
const ops = 15
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
results := make(chan bool, ops)
|
||||
for i := range ops {
|
||||
wg.Add(1)
|
||||
|
||||
go func(nodeID int) {
|
||||
defer wg.Done()
|
||||
|
||||
//nolint:gosec
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node")
|
||||
|
||||
resultNode := store.PutNode(node)
|
||||
results <- resultNode.Valid()
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successCount := 0
|
||||
|
||||
for success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, ops, successCount, "All batch PutNode operations should succeed")
|
||||
}
|
||||
|
||||
// --- Race conditions: many goroutines on same node ---
|
||||
// --- Race conditions: many goroutines on same node ---.
|
||||
func TestNodeStoreRaceConditions(t *testing.T) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -937,13 +995,18 @@ func TestNodeStoreRaceConditions(t *testing.T) {
|
|||
resultNode := store.PutNode(node)
|
||||
require.True(t, resultNode.Valid())
|
||||
|
||||
const numGoroutines = 30
|
||||
const opsPerGoroutine = 10
|
||||
const (
|
||||
numGoroutines = 30
|
||||
opsPerGoroutine = 10
|
||||
)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
errors := make(chan error, numGoroutines*opsPerGoroutine)
|
||||
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
|
||||
go func(gid int) {
|
||||
defer wg.Done()
|
||||
|
||||
|
|
@ -954,40 +1017,46 @@ func TestNodeStoreRaceConditions(t *testing.T) {
|
|||
n.Hostname = "race-updated"
|
||||
})
|
||||
if !resultNode.Valid() {
|
||||
errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j)
|
||||
errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestUpdateNodeFailed, gid, j)
|
||||
}
|
||||
case 1:
|
||||
retrieved, found := store.GetNode(nodeID)
|
||||
if !found || !retrieved.Valid() {
|
||||
errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j)
|
||||
errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestGetNodeFailed, gid, j)
|
||||
}
|
||||
case 2:
|
||||
newNode := createConcurrentTestNode(nodeID, "race-put")
|
||||
|
||||
resultNode := store.PutNode(newNode)
|
||||
if !resultNode.Valid() {
|
||||
errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j)
|
||||
errors <- fmt.Errorf("%w in goroutine %d, op %d", errTestPutNodeFailed, gid, j)
|
||||
}
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
errorCount := 0
|
||||
|
||||
for err := range errors {
|
||||
t.Error(err)
|
||||
|
||||
errorCount++
|
||||
}
|
||||
|
||||
if errorCount > 0 {
|
||||
t.Fatalf("Race condition test failed with %d errors", errorCount)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Resource cleanup: goroutine leak detection ---
|
||||
// --- Resource cleanup: goroutine leak detection ---.
|
||||
func TestNodeStoreResourceCleanup(t *testing.T) {
|
||||
// initialGoroutines := runtime.NumGoroutine()
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -1001,7 +1070,7 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
|
|||
|
||||
const ops = 100
|
||||
for i := range ops {
|
||||
nodeID := types.NodeID(i + 1)
|
||||
nodeID := types.NodeID(i + 1) //nolint:gosec
|
||||
node := createConcurrentTestNode(nodeID, "cleanup-node")
|
||||
resultNode := store.PutNode(node)
|
||||
assert.True(t, resultNode.Valid())
|
||||
|
|
@ -1010,10 +1079,12 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
|
|||
})
|
||||
retrieved, found := store.GetNode(nodeID)
|
||||
assert.True(t, found && retrieved.Valid())
|
||||
|
||||
if i%10 == 9 {
|
||||
store.DeleteNode(nodeID)
|
||||
}
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
|
||||
// Wait for goroutines to settle and check for leaks
|
||||
|
|
@ -1024,9 +1095,10 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
|
|||
}, time.Second, 10*time.Millisecond, "goroutines should not leak")
|
||||
}
|
||||
|
||||
// --- Timeout/deadlock: operations complete within reasonable time ---
|
||||
// --- Timeout/deadlock: operations complete within reasonable time ---.
|
||||
func TestNodeStoreOperationTimeout(t *testing.T) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -1034,36 +1106,47 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
const ops = 30
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
putResults := make([]error, ops)
|
||||
updateResults := make([]error, ops)
|
||||
|
||||
// Launch all PutNode operations concurrently
|
||||
for i := 1; i <= ops; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
nodeID := types.NodeID(i) //nolint:gosec
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func(idx int, id types.NodeID) {
|
||||
defer wg.Done()
|
||||
|
||||
startPut := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id)
|
||||
node := createConcurrentTestNode(id, "timeout-node")
|
||||
resultNode := store.PutNode(node)
|
||||
endPut := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut))
|
||||
|
||||
if !resultNode.Valid() {
|
||||
putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id)
|
||||
putResults[idx-1] = fmt.Errorf("%w for node %d", errTestPutNodeFailed, id)
|
||||
}
|
||||
}(i, nodeID)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Launch all UpdateNode operations concurrently
|
||||
wg = sync.WaitGroup{}
|
||||
|
||||
for i := 1; i <= ops; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
nodeID := types.NodeID(i) //nolint:gosec
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func(idx int, id types.NodeID) {
|
||||
defer wg.Done()
|
||||
|
||||
startUpdate := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id)
|
||||
resultNode, ok := store.UpdateNode(id, func(n *types.Node) {
|
||||
|
|
@ -1071,31 +1154,40 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
|||
})
|
||||
endUpdate := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate))
|
||||
|
||||
if !ok || !resultNode.Valid() {
|
||||
updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id)
|
||||
updateResults[idx-1] = fmt.Errorf("%w for node %d", errTestUpdateNodeFailed, id)
|
||||
}
|
||||
}(i, nodeID)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
errorCount := 0
|
||||
|
||||
for _, err := range putResults {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
|
||||
for _, err := range updateResults {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
|
||||
if errorCount == 0 {
|
||||
t.Log("All concurrent operations completed successfully within timeout")
|
||||
} else {
|
||||
|
|
@ -1107,13 +1199,15 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// --- Edge case: update non-existent node ---
|
||||
// --- Edge case: update non-existent node ---.
|
||||
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
|
||||
for i := range 10 {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
store.Start()
|
||||
nonExistentID := types.NodeID(999 + i)
|
||||
|
||||
nonExistentID := types.NodeID(999 + i) //nolint:gosec
|
||||
updateCallCount := 0
|
||||
|
||||
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID)
|
||||
resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) {
|
||||
updateCallCount++
|
||||
|
|
@ -1127,20 +1221,22 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// --- Allocation benchmark ---
|
||||
// --- Allocation benchmark ---.
|
||||
func BenchmarkNodeStoreAllocations(b *testing.B) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
for i := 0; b.Loop(); i++ {
|
||||
nodeID := types.NodeID(i + 1)
|
||||
nodeID := types.NodeID(i + 1) //nolint:gosec
|
||||
node := createConcurrentTestNode(nodeID, "bench-node")
|
||||
store.PutNode(node)
|
||||
store.UpdateNode(nodeID, func(n *types.Node) {
|
||||
n.Hostname = "bench-updated"
|
||||
})
|
||||
store.GetNode(nodeID)
|
||||
|
||||
if i%10 == 9 {
|
||||
store.DeleteNode(nodeID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,10 +25,8 @@ import (
|
|||
"github.com/rs/zerolog/log"
|
||||
"golang.org/x/sync/errgroup"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/net/tsaddr"
|
||||
"tailscale.com/tailcfg"
|
||||
"tailscale.com/types/key"
|
||||
"tailscale.com/types/ptr"
|
||||
"tailscale.com/types/views"
|
||||
zcache "zgo.at/zcache/v2"
|
||||
)
|
||||
|
|
@ -133,7 +131,7 @@ func NewState(cfg *types.Config) (*State, error) {
|
|||
// On startup, all nodes should be marked as offline until they reconnect
|
||||
// This ensures we don't have stale online status from previous runs
|
||||
for _, node := range nodes {
|
||||
node.IsOnline = ptr.To(false)
|
||||
node.IsOnline = new(false)
|
||||
}
|
||||
users, err := db.ListUsers()
|
||||
if err != nil {
|
||||
|
|
@ -189,7 +187,8 @@ func NewState(cfg *types.Config) (*State, error) {
|
|||
func (s *State) Close() error {
|
||||
s.nodeStore.Stop()
|
||||
|
||||
if err := s.db.Close(); err != nil {
|
||||
err := s.db.Close()
|
||||
if err != nil {
|
||||
return fmt.Errorf("closing database: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -225,6 +224,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) {
|
|||
// propagate correctly when switching between policy types.
|
||||
s.nodeStore.RebuildPeerMaps()
|
||||
|
||||
//nolint:prealloc
|
||||
cs := []change.Change{change.PolicyChange()}
|
||||
|
||||
// Always call autoApproveNodes during policy reload, regardless of whether
|
||||
|
|
@ -255,6 +255,7 @@ func (s *State) ReloadPolicy() ([]change.Change, error) {
|
|||
// CreateUser creates a new user and updates the policy manager.
|
||||
// Returns the created user, change set, and any error.
|
||||
func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) {
|
||||
//nolint:noinlineerr
|
||||
if err := s.db.DB.Save(&user).Error; err != nil {
|
||||
return nil, change.Change{}, fmt.Errorf("creating user: %w", err)
|
||||
}
|
||||
|
|
@ -289,6 +290,7 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
|
|||
return nil, err
|
||||
}
|
||||
|
||||
//nolint:noinlineerr
|
||||
if err := updateFn(user); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -468,10 +470,8 @@ func (s *State) Connect(id types.NodeID) []change.Change {
|
|||
// CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification
|
||||
// This ensures that when the NodeCameOnline change is distributed and processed by other nodes,
|
||||
// the NodeStore already reflects the correct online status for full map generation.
|
||||
// now := time.Now()
|
||||
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
|
||||
n.IsOnline = ptr.To(true)
|
||||
// n.LastSeen = ptr.To(now)
|
||||
n.IsOnline = new(true)
|
||||
})
|
||||
if !ok {
|
||||
return nil
|
||||
|
|
@ -495,16 +495,17 @@ func (s *State) Connect(id types.NodeID) []change.Change {
|
|||
|
||||
// Disconnect marks a node as disconnected and updates its primary routes in the state.
|
||||
func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
|
||||
//nolint:staticcheck
|
||||
now := time.Now()
|
||||
|
||||
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
|
||||
n.LastSeen = ptr.To(now)
|
||||
n.LastSeen = new(now)
|
||||
// NodeStore is the source of truth for all node state including online status.
|
||||
n.IsOnline = ptr.To(false)
|
||||
n.IsOnline = new(false)
|
||||
})
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("node not found: %d", id)
|
||||
return nil, fmt.Errorf("%w: %d", ErrNodeNotFound, id)
|
||||
}
|
||||
|
||||
log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node disconnected")
|
||||
|
|
@ -745,13 +746,14 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t

// RenameNode changes the display name of a node.
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.Change, error) {
    if err := util.ValidateHostname(newName); err != nil {
    err := util.ValidateHostname(newName)
    if err != nil {
        return types.NodeView{}, change.Change{}, fmt.Errorf("renaming node: %w", err)
    }

    // Check name uniqueness against NodeStore
    allNodes := s.nodeStore.ListNodes()
    for i := 0; i < allNodes.Len(); i++ {
    for i := range allNodes.Len() {
        node := allNodes.At(i)
        if node.ID() != nodeID && node.AsStruct().GivenName == newName {
            return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName)
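The loop change just above swaps a counted for loop for Go 1.22's range-over-int form; a minimal sketch of the equivalence, where the container type and its Len/At methods are stand-ins rather than the real node view:

// Hypothetical container for illustration only.
type intView struct{ items []int }

func (v intView) Len() int     { return len(v.items) }
func (v intView) At(i int) int { return v.items[i] }

func sumBoth(v intView) (a, b int) {
    for i := 0; i < v.Len(); i++ { // classic counted loop
        a += v.At(i)
    }

    for i := range v.Len() { // Go 1.22+: iterates i = 0 .. Len()-1
        b += v.At(i)
    }

    return a, b
}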
|
||||
|
|
@ -790,7 +792,7 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
|
|||
// Preserve online status and NetInfo when refreshing from database
|
||||
existingNode, exists := s.nodeStore.GetNode(node.ID)
|
||||
if exists && existingNode.Valid() {
|
||||
node.IsOnline = ptr.To(existingNode.IsOnline().Get())
|
||||
node.IsOnline = new(existingNode.IsOnline().Get())
|
||||
|
||||
// TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics
|
||||
// when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest).
|
||||
|
|
@ -818,6 +820,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
|
|||
|
||||
var updates []change.Change
|
||||
|
||||
//nolint:unqueryvet
|
||||
for _, node := range s.nodeStore.ListNodes().All() {
|
||||
if !node.Valid() {
|
||||
continue
|
||||
|
|
@ -1117,7 +1120,7 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
|
|||
DiscoKey: params.DiscoKey,
|
||||
Hostinfo: params.Hostinfo,
|
||||
Endpoints: params.Endpoints,
|
||||
LastSeen: ptr.To(time.Now()),
|
||||
LastSeen: new(time.Now()),
|
||||
RegisterMethod: params.RegisterMethod,
|
||||
Expiry: params.Expiry,
|
||||
}
|
||||
|
|
@ -1218,7 +1221,8 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
|
|||
|
||||
// New node - database first to get ID, then NodeStore
|
||||
savedNode, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
|
||||
if err := tx.Save(&nodeToRegister).Error; err != nil {
|
||||
err := tx.Save(&nodeToRegister).Error
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to save node: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -1407,8 +1411,8 @@ func (s *State) HandleNodeFromAuthPath(
|
|||
|
||||
node.Endpoints = regEntry.Node.Endpoints
|
||||
node.RegisterMethod = regEntry.Node.RegisterMethod
|
||||
node.IsOnline = ptr.To(false)
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
node.IsOnline = new(false)
|
||||
node.LastSeen = new(time.Now())
|
||||
|
||||
// Tagged nodes keep their existing expiry (disabled).
|
||||
// User-owned nodes update expiry from the provided value or registration entry.
|
||||
|
|
@ -1669,8 +1673,8 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||
// Only update AuthKey reference
|
||||
node.AuthKey = pak
|
||||
node.AuthKeyID = &pak.ID
|
||||
node.IsOnline = ptr.To(false)
|
||||
node.LastSeen = ptr.To(time.Now())
|
||||
node.IsOnline = new(false)
|
||||
node.LastSeen = new(time.Now())
|
||||
|
||||
// Tagged nodes keep their existing expiry (disabled).
|
||||
// User-owned nodes update expiry from the client request.
|
||||
|
|
@ -1697,7 +1701,7 @@ func (s *State) HandleNodeFromPreAuthKey(
|
|||
}
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
return nil, nil //nolint:nilnil
|
||||
})
|
||||
if err != nil {
|
||||
return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err)
|
||||
|
|
@ -2122,8 +2126,8 @@ func routesChanged(oldNode types.NodeView, newHI *tailcfg.Hostinfo) bool {
|
|||
newRoutes = []netip.Prefix{}
|
||||
}
|
||||
|
||||
tsaddr.SortPrefixes(oldRoutes)
|
||||
tsaddr.SortPrefixes(newRoutes)
|
||||
slices.SortFunc(oldRoutes, netip.Prefix.Compare)
|
||||
slices.SortFunc(newRoutes, netip.Prefix.Compare)
|
||||
|
||||
return !slices.Equal(oldRoutes, newRoutes)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,9 @@ import (
    "tailscale.com/types/logger"
)

// ErrNoCertDomains is returned when no cert domains are available for HTTPS.
var ErrNoCertDomains = errors.New("no cert domains available for HTTPS")

func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath string) error {
    opts := tailsql.Options{
        Hostname: "tailsql-headscale",

@ -71,7 +74,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
    // When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443.
    certDomains := tsNode.CertDomains()
    if len(certDomains) == 0 {
        return errors.New("no cert domains available for HTTPS")
        return ErrNoCertDomains
    }
    base := "https://" + certDomains[0]
    go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {

@ -93,6 +96,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
    mux := tsql.NewMux()
    tsweb.Debugger(mux)
    go http.Serve(lst, mux)

    logf("TailSQL started")
    <-ctx.Done()
    logf("TailSQL shutting down...")
|
||||
|
|
|
|||
|
|
@ -15,43 +15,43 @@ import (
|
|||
// Material for MkDocs design system - exact values from official docs.
|
||||
const (
|
||||
// Text colors - from --md-default-fg-color CSS variables.
|
||||
colorTextPrimary = "#000000de" //nolint:unused // rgba(0,0,0,0.87) - Body text
|
||||
colorTextSecondary = "#0000008a" //nolint:unused // rgba(0,0,0,0.54) - Headings (--md-default-fg-color--light)
|
||||
colorTextTertiary = "#00000052" //nolint:unused // rgba(0,0,0,0.32) - Lighter text
|
||||
colorTextLightest = "#00000012" //nolint:unused // rgba(0,0,0,0.07) - Lightest text
|
||||
colorTextPrimary = "#000000de" //nolint:unused
|
||||
colorTextSecondary = "#0000008a" //nolint:unused
|
||||
colorTextTertiary = "#00000052" //nolint:unused
|
||||
colorTextLightest = "#00000012" //nolint:unused
|
||||
|
||||
// Code colors - from --md-code-* CSS variables.
|
||||
colorCodeFg = "#36464e" //nolint:unused // Code text color (--md-code-fg-color)
|
||||
colorCodeBg = "#f5f5f5" //nolint:unused // Code background (--md-code-bg-color)
|
||||
colorCodeFg = "#36464e" //nolint:unused
|
||||
colorCodeBg = "#f5f5f5" //nolint:unused
|
||||
|
||||
// Border colors.
|
||||
colorBorderLight = "#e5e7eb" //nolint:unused // Light borders
|
||||
colorBorderMedium = "#d1d5db" //nolint:unused // Medium borders
|
||||
colorBorderLight = "#e5e7eb" //nolint:unused
|
||||
colorBorderMedium = "#d1d5db" //nolint:unused
|
||||
|
||||
// Background colors.
|
||||
colorBackgroundPage = "#ffffff" //nolint:unused // Page background
|
||||
colorBackgroundCard = "#ffffff" //nolint:unused // Card/content background
|
||||
colorBackgroundPage = "#ffffff" //nolint:unused
|
||||
colorBackgroundCard = "#ffffff" //nolint:unused
|
||||
|
||||
// Accent colors - from --md-primary/accent-fg-color.
|
||||
colorPrimaryAccent = "#4051b5" //nolint:unused // Primary accent (links)
|
||||
colorAccent = "#526cfe" //nolint:unused // Secondary accent
|
||||
colorPrimaryAccent = "#4051b5" //nolint:unused
|
||||
colorAccent = "#526cfe" //nolint:unused
|
||||
|
||||
// Success colors.
|
||||
colorSuccess = "#059669" //nolint:unused // Success states
|
||||
colorSuccessLight = "#d1fae5" //nolint:unused // Success backgrounds
|
||||
colorSuccess = "#059669" //nolint:unused
|
||||
colorSuccessLight = "#d1fae5" //nolint:unused
|
||||
)
|
||||
|
||||
// Spacing System
|
||||
// Based on 4px/8px base unit for consistent rhythm.
|
||||
// Uses rem units for scalability with user font size preferences.
|
||||
const (
|
||||
spaceXS = "0.25rem" //nolint:unused // 4px - Tight spacing
|
||||
spaceS = "0.5rem" //nolint:unused // 8px - Small spacing
|
||||
spaceM = "1rem" //nolint:unused // 16px - Medium spacing (base)
|
||||
spaceL = "1.5rem" //nolint:unused // 24px - Large spacing
|
||||
spaceXL = "2rem" //nolint:unused // 32px - Extra large spacing
|
||||
space2XL = "3rem" //nolint:unused // 48px - 2x extra large spacing
|
||||
space3XL = "4rem" //nolint:unused // 64px - 3x extra large spacing
|
||||
spaceXS = "0.25rem" //nolint:unused
|
||||
spaceS = "0.5rem" //nolint:unused
|
||||
spaceM = "1rem" //nolint:unused
|
||||
spaceL = "1.5rem" //nolint:unused
|
||||
spaceXL = "2rem" //nolint:unused
|
||||
space2XL = "3rem" //nolint:unused
|
||||
space3XL = "4rem" //nolint:unused
|
||||
)
|
||||
|
||||
// Typography System
|
||||
|
|
@ -63,26 +63,26 @@ const (
|
|||
fontFamilyCode = `"Roboto Mono", "SF Mono", Monaco, "Cascadia Code", Consolas, "Courier New", monospace` //nolint:unused
|
||||
|
||||
// Font sizes - from .md-typeset CSS rules.
|
||||
fontSizeBase = "0.8rem" //nolint:unused // 12.8px - Base text (.md-typeset)
|
||||
fontSizeH1 = "2em" //nolint:unused // 2x base - Main headings
|
||||
fontSizeH2 = "1.5625em" //nolint:unused // 1.5625x base - Section headings
|
||||
fontSizeH3 = "1.25em" //nolint:unused // 1.25x base - Subsection headings
|
||||
fontSizeSmall = "0.8em" //nolint:unused // 0.8x base - Small text
|
||||
fontSizeCode = "0.85em" //nolint:unused // 0.85x base - Inline code
|
||||
fontSizeBase = "0.8rem" //nolint:unused
|
||||
fontSizeH1 = "2em" //nolint:unused
|
||||
fontSizeH2 = "1.5625em" //nolint:unused
|
||||
fontSizeH3 = "1.25em" //nolint:unused
|
||||
fontSizeSmall = "0.8em" //nolint:unused
|
||||
fontSizeCode = "0.85em" //nolint:unused
|
||||
|
||||
// Line heights - from .md-typeset CSS rules.
|
||||
lineHeightBase = "1.6" //nolint:unused // Body text (.md-typeset)
|
||||
lineHeightH1 = "1.3" //nolint:unused // H1 headings
|
||||
lineHeightH2 = "1.4" //nolint:unused // H2 headings
|
||||
lineHeightH3 = "1.5" //nolint:unused // H3 headings
|
||||
lineHeightCode = "1.4" //nolint:unused // Code blocks (pre)
|
||||
lineHeightBase = "1.6" //nolint:unused
|
||||
lineHeightH1 = "1.3" //nolint:unused
|
||||
lineHeightH2 = "1.4" //nolint:unused
|
||||
lineHeightH3 = "1.5" //nolint:unused
|
||||
lineHeightCode = "1.4" //nolint:unused
|
||||
)
|
||||
|
||||
// Responsive Container Component
|
||||
// Creates a centered container with responsive padding and max-width.
|
||||
// Mobile-first approach: starts at 100% width with padding, constrains on larger screens.
|
||||
//
|
||||
//nolint:unused // Reserved for future use in Phase 4.
|
||||
//nolint:unused
|
||||
func responsiveContainer(children ...elem.Node) *elem.Element {
|
||||
return elem.Div(attrs.Props{
|
||||
attrs.Style: styles.Props{
|
||||
|
|
@ -100,7 +100,7 @@ func responsiveContainer(children ...elem.Node) *elem.Element {
|
|||
// - title: Optional title for the card (empty string for no title)
|
||||
// - children: Content elements to display in the card
|
||||
//
|
||||
//nolint:unused // Reserved for future use in Phase 4.
|
||||
//nolint:unused
|
||||
func card(title string, children ...elem.Node) *elem.Element {
|
||||
cardContent := children
|
||||
if title != "" {
|
||||
|
|
@ -134,7 +134,7 @@ func card(title string, children ...elem.Node) *elem.Element {
|
|||
// EXTRACTED FROM: .md-typeset pre CSS rules
|
||||
// Exact styling from Material for MkDocs documentation.
|
||||
//
|
||||
//nolint:unused // Used across apple.go, windows.go, register_web.go templates.
|
||||
//nolint:unused
|
||||
func codeBlock(code string) *elem.Element {
|
||||
return elem.Pre(attrs.Props{
|
||||
attrs.Style: styles.Props{
|
||||
|
|
@ -164,7 +164,7 @@ func codeBlock(code string) *elem.Element {
|
|||
// Returns inline styles for the main content container that matches .md-typeset.
|
||||
// EXTRACTED FROM: .md-typeset CSS rule from Material for MkDocs.
|
||||
//
|
||||
//nolint:unused // Used in general.go for mdTypesetBody.
|
||||
//nolint:unused
|
||||
func baseTypesetStyles() styles.Props {
|
||||
return styles.Props{
|
||||
styles.FontSize: fontSizeBase, // 0.8rem
|
||||
|
|
@ -180,7 +180,7 @@ func baseTypesetStyles() styles.Props {
|
|||
// Returns inline styles for H1 headings that match .md-typeset h1.
|
||||
// EXTRACTED FROM: .md-typeset h1 CSS rule from Material for MkDocs.
|
||||
//
|
||||
//nolint:unused // Used across templates for main headings.
|
||||
//nolint:unused
|
||||
func h1Styles() styles.Props {
|
||||
return styles.Props{
|
||||
styles.Color: colorTextSecondary, // rgba(0, 0, 0, 0.54)
|
||||
|
|
@ -198,7 +198,7 @@ func h1Styles() styles.Props {
|
|||
// Returns inline styles for H2 headings that match .md-typeset h2.
|
||||
// EXTRACTED FROM: .md-typeset h2 CSS rule from Material for MkDocs.
|
||||
//
|
||||
//nolint:unused // Used across templates for section headings.
|
||||
//nolint:unused
|
||||
func h2Styles() styles.Props {
|
||||
return styles.Props{
|
||||
styles.FontSize: fontSizeH2, // 1.5625em
|
||||
|
|
@ -216,7 +216,7 @@ func h2Styles() styles.Props {
|
|||
// Returns inline styles for H3 headings that match .md-typeset h3.
|
||||
// EXTRACTED FROM: .md-typeset h3 CSS rule from Material for MkDocs.
|
||||
//
|
||||
//nolint:unused // Used across templates for subsection headings.
|
||||
//nolint:unused
|
||||
func h3Styles() styles.Props {
|
||||
return styles.Props{
|
||||
styles.FontSize: fontSizeH3, // 1.25em
|
||||
|
|
@ -234,7 +234,7 @@ func h3Styles() styles.Props {
|
|||
// Returns inline styles for paragraphs that match .md-typeset p.
|
||||
// EXTRACTED FROM: .md-typeset p CSS rule from Material for MkDocs.
|
||||
//
|
||||
//nolint:unused // Used for consistent paragraph spacing.
|
||||
//nolint:unused
|
||||
func paragraphStyles() styles.Props {
|
||||
return styles.Props{
|
||||
styles.Margin: "1em 0",
|
||||
|
|
@ -250,7 +250,7 @@ func paragraphStyles() styles.Props {
|
|||
// Returns inline styles for ordered lists that match .md-typeset ol.
|
||||
// EXTRACTED FROM: .md-typeset ol CSS rule from Material for MkDocs.
|
||||
//
|
||||
//nolint:unused // Used for numbered instruction lists.
|
||||
//nolint:unused
|
||||
func orderedListStyles() styles.Props {
|
||||
return styles.Props{
|
||||
styles.MarginBottom: "1em",
|
||||
|
|
@ -268,7 +268,7 @@ func orderedListStyles() styles.Props {
|
|||
// Returns inline styles for unordered lists that match .md-typeset ul.
|
||||
// EXTRACTED FROM: .md-typeset ul CSS rule from Material for MkDocs.
|
||||
//
|
||||
//nolint:unused // Used for bullet point lists.
|
||||
//nolint:unused
|
||||
func unorderedListStyles() styles.Props {
|
||||
return styles.Props{
|
||||
styles.MarginBottom: "1em",
|
||||
|
|
@ -287,7 +287,7 @@ func unorderedListStyles() styles.Props {
|
|||
// EXTRACTED FROM: .md-typeset a CSS rule from Material for MkDocs.
|
||||
// Note: Hover states cannot be implemented with inline styles.
|
||||
//
|
||||
//nolint:unused // Used for text links.
|
||||
//nolint:unused
|
||||
func linkStyles() styles.Props {
|
||||
return styles.Props{
|
||||
styles.Color: colorPrimaryAccent, // #4051b5 - var(--md-primary-fg-color)
|
||||
|
|
@ -301,7 +301,7 @@ func linkStyles() styles.Props {
|
|||
// Returns inline styles for inline code that matches .md-typeset code.
|
||||
// EXTRACTED FROM: .md-typeset code CSS rule from Material for MkDocs.
|
||||
//
|
||||
//nolint:unused // Used for inline code snippets.
|
||||
//nolint:unused
|
||||
func inlineCodeStyles() styles.Props {
|
||||
return styles.Props{
|
||||
styles.BackgroundColor: colorCodeBg, // #f5f5f5
|
||||
|
|
@ -317,7 +317,7 @@ func inlineCodeStyles() styles.Props {
|
|||
// Inline Code Component
|
||||
// For inline code snippets within text.
|
||||
//
|
||||
//nolint:unused // Reserved for future inline code usage.
|
||||
//nolint:unused
|
||||
func inlineCode(code string) *elem.Element {
|
||||
return elem.Code(attrs.Props{
|
||||
attrs.Style: inlineCodeStyles().ToInline(),
|
||||
|
|
@ -327,7 +327,7 @@ func inlineCode(code string) *elem.Element {
|
|||
// orDivider creates a visual "or" divider between sections.
|
||||
// Styled with lines on either side for better visual separation.
|
||||
//
|
||||
//nolint:unused // Used in apple.go template.
|
||||
//nolint:unused
|
||||
func orDivider() *elem.Element {
|
||||
return elem.Div(attrs.Props{
|
||||
attrs.Style: styles.Props{
|
||||
|
|
@ -367,7 +367,7 @@ func orDivider() *elem.Element {
|
|||
|
||||
// warningBox creates a warning message box with icon and content.
|
||||
//
|
||||
//nolint:unused // Used in apple.go template.
|
||||
//nolint:unused
|
||||
func warningBox(title, message string) *elem.Element {
|
||||
return elem.Div(attrs.Props{
|
||||
attrs.Style: styles.Props{
|
||||
|
|
@ -404,7 +404,7 @@ func warningBox(title, message string) *elem.Element {
|
|||
|
||||
// downloadButton creates a nice button-style link for downloads.
|
||||
//
|
||||
//nolint:unused // Used in apple.go template.
|
||||
//nolint:unused
|
||||
func downloadButton(href, text string) *elem.Element {
|
||||
return elem.A(attrs.Props{
|
||||
attrs.Href: href,
|
||||
|
|
@ -428,7 +428,7 @@ func downloadButton(href, text string) *elem.Element {
|
|||
// Creates a link with proper security attributes for external URLs.
|
||||
// Automatically adds rel="noreferrer noopener" and target="_blank".
|
||||
//
|
||||
//nolint:unused // Used in apple.go, oidc_callback.go templates.
|
||||
//nolint:unused
|
||||
func externalLink(href, text string) *elem.Element {
|
||||
return elem.A(attrs.Props{
|
||||
attrs.Href: href,
|
||||
|
|
@ -444,7 +444,7 @@ func externalLink(href, text string) *elem.Element {
|
|||
// Instruction Step Component
|
||||
// For numbered instruction lists with consistent formatting.
|
||||
//
|
||||
//nolint:unused // Reserved for future use in Phase 4.
|
||||
//nolint:unused
|
||||
func instructionStep(_ int, text string) *elem.Element {
|
||||
return elem.Li(attrs.Props{
|
||||
attrs.Style: styles.Props{
|
||||
|
|
@ -457,7 +457,7 @@ func instructionStep(_ int, text string) *elem.Element {
|
|||
// Status Message Component
|
||||
// For displaying success/error/info messages with appropriate styling.
|
||||
//
|
||||
//nolint:unused // Reserved for future use in Phase 4.
|
||||
//nolint:unused
|
||||
func statusMessage(message string, isSuccess bool) *elem.Element {
|
||||
bgColor := colorSuccessLight
|
||||
textColor := colorSuccess
|
||||
|
|
|
|||
|
|
@ -333,7 +333,7 @@ func NodeOnline(nodeID types.NodeID) Change {
        PeerPatches: []*tailcfg.PeerChange{
            {
                NodeID: nodeID.NodeID(),
                Online: ptrTo(true),
                Online: new(true),
            },
        },
    }

@ -346,7 +346,7 @@ func NodeOffline(nodeID types.NodeID) Change {
        PeerPatches: []*tailcfg.PeerChange{
            {
                NodeID: nodeID.NodeID(),
                Online: ptrTo(false),
                Online: new(false),
            },
        },
    }

@ -365,11 +365,6 @@ func KeyExpiry(nodeID types.NodeID, expiry *time.Time) Change {
    }
}

// ptrTo returns a pointer to the given value.
func ptrTo[T any](v T) *T {
    return &v
}

// High-level change constructors

// NodeAdded returns a Change for when a node is added or updated.
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ func TestChange_FieldSync(t *testing.T) {
    typ := reflect.TypeFor[Change]()
    boolCount := 0

    for i := range typ.NumField() {
        if typ.Field(i).Type.Kind() == reflect.Bool {
    for field := range typ.Fields() {
        if field.Type.Kind() == reflect.Bool {
            boolCount++
        }
    }
|
||||
|
|
|
|||
|
|
@ -20,7 +20,10 @@ const (
    DatabaseSqlite = "sqlite3"
)

var ErrCannotParsePrefix = errors.New("cannot parse prefix")
var (
    ErrCannotParsePrefix  = errors.New("cannot parse prefix")
    ErrInvalidRegIDLength = errors.New("registration ID has invalid length")
)

type StateUpdateType int

@ -175,8 +178,9 @@ func MustRegistrationID() RegistrationID {

func RegistrationIDFromString(str string) (RegistrationID, error) {
    if len(str) != RegistrationIDLength {
        return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength)
        return "", fmt.Errorf("%w: expected %d characters", ErrInvalidRegIDLength, RegistrationIDLength)
    }

    return RegistrationID(str), nil
}
|
||||
|
||||
|
|
|
|||
|
|
@ -30,13 +30,16 @@ const (
    PKCEMethodS256 string = "S256"

    defaultNodeStoreBatchSize = 100
    defaultWALAutocheckpoint  = 1000 // SQLite default
)

var (
    errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
    errServerURLSuffix       = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
    errServerURLSame         = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable")
    errInvalidPKCEMethod     = errors.New("pkce.method must be either 'plain' or 'S256'")
    errOidcMutuallyExclusive     = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive")
    errServerURLSuffix           = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable")
    errServerURLSame             = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable")
    errInvalidPKCEMethod         = errors.New("pkce.method must be either 'plain' or 'S256'")
    errNoPrefixConfigured        = errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
    errInvalidAllocationStrategy = errors.New("invalid prefixes.allocation strategy")
)

type IPAllocationStrategy string
|
||||
|
|
@ -301,6 +304,7 @@ func validatePKCEMethod(method string) error {
|
|||
if method != PKCEMethodPlain && method != PKCEMethodS256 {
|
||||
return errInvalidPKCEMethod
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -377,7 +381,7 @@ func LoadConfig(path string, isFile bool) error {
|
|||
viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600)
|
||||
|
||||
viper.SetDefault("database.sqlite.write_ahead_log", true)
|
||||
viper.SetDefault("database.sqlite.wal_autocheckpoint", 1000) // SQLite default
|
||||
viper.SetDefault("database.sqlite.wal_autocheckpoint", defaultWALAutocheckpoint)
|
||||
|
||||
viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"})
|
||||
viper.SetDefault("oidc.only_start_if_oidc_is_available", true)
|
||||
|
|
@ -402,7 +406,8 @@ func LoadConfig(path string, isFile bool) error {
    viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential))

    if err := viper.ReadInConfig(); err != nil {
        if _, ok := err.(viper.ConfigFileNotFoundError); ok {
        var configFileNotFoundError viper.ConfigFileNotFoundError
        if errors.As(err, &configFileNotFoundError) {
            log.Warn().Msg("No config file found, using defaults")
            return nil
        }
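The hunk above replaces a direct type assertion with errors.As; a small self-contained sketch (not headscale code) of why that matters once an error has been wrapped:

package main

import (
    "errors"
    "fmt"
)

type notFoundError struct{ name string }

func (e notFoundError) Error() string { return e.name + " not found" }

func main() {
    err := fmt.Errorf("loading config: %w", notFoundError{name: "config.yaml"})

    // A direct type assertion fails because the concrete error is wrapped.
    _, ok := err.(notFoundError)
    fmt.Println(ok) // false

    // errors.As walks the wrap chain and still finds the target type.
    var nfe notFoundError
    fmt.Println(errors.As(err, &nfe)) // true
}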
|
||||
|
|
@ -442,7 +447,8 @@ func validateServerConfig() error {
|
|||
depr.fatal("oidc.map_legacy_users")
|
||||
|
||||
if viper.GetBool("oidc.enabled") {
|
||||
if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil {
|
||||
err := validatePKCEMethod(viper.GetString("oidc.pkce.method"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
|
@ -910,6 +916,7 @@ func LoadCLIConfig() (*Config, error) {
|
|||
// LoadServerConfig returns the full Headscale configuration to
|
||||
// host a Headscale server. This is called as part of `headscale serve`.
|
||||
func LoadServerConfig() (*Config, error) {
|
||||
//nolint:noinlineerr
|
||||
if err := validateServerConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
@ -928,7 +935,7 @@ func LoadServerConfig() (*Config, error) {
|
|||
}
|
||||
|
||||
if prefix4 == nil && prefix6 == nil {
|
||||
return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required")
|
||||
return nil, errNoPrefixConfigured
|
||||
}
|
||||
|
||||
allocStr := viper.GetString("prefixes.allocation")
|
||||
|
|
@ -940,7 +947,8 @@ func LoadServerConfig() (*Config, error) {
|
|||
alloc = IPAllocationStrategyRandom
|
||||
default:
|
||||
return nil, fmt.Errorf(
|
||||
"config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s",
|
||||
"%w: %q, allowed options: %s, %s",
|
||||
errInvalidAllocationStrategy,
|
||||
allocStr,
|
||||
IPAllocationStrategySequential,
|
||||
IPAllocationStrategyRandom,
|
||||
|
|
@ -979,7 +987,8 @@ func LoadServerConfig() (*Config, error) {
|
|||
// - Control plane runs on login.tailscale.com/controlplane.tailscale.com
|
||||
// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net)
|
||||
if dnsConfig.BaseDomain != "" {
|
||||
if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil {
|
||||
err := isSafeServerURL(serverURL, dnsConfig.BaseDomain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
|
@ -1082,6 +1091,7 @@ func LoadServerConfig() (*Config, error) {
|
|||
if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
|
||||
return workers
|
||||
}
|
||||
|
||||
return DefaultBatcherWorkers()
|
||||
}(),
|
||||
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
|
||||
|
|
@ -1117,6 +1127,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
|
|||
}
|
||||
|
||||
s := len(serverDomainParts)
|
||||
|
||||
b := len(baseDomainParts)
|
||||
for i := range baseDomainParts {
|
||||
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
|
||||
|
|
|
|||
|
|
@ -349,11 +349,8 @@ func TestReadConfigFromEnv(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTLSConfigValidation(t *testing.T) {
|
||||
tmpDir, err := os.MkdirTemp("", "headscale")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
// defer os.RemoveAll(tmpDir)
|
||||
tmpDir := t.TempDir()
|
||||
|
||||
configYaml := []byte(`---
|
||||
tls_letsencrypt_hostname: example.com
|
||||
tls_letsencrypt_challenge_type: ""
|
||||
|
|
@ -363,7 +360,8 @@ noise:
|
|||
|
||||
// Populate a custom config file
|
||||
configFilePath := filepath.Join(tmpDir, "config.yaml")
|
||||
err = os.WriteFile(configFilePath, configYaml, 0o600)
|
||||
|
||||
err := os.WriteFile(configFilePath, configYaml, 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't write file %s", configFilePath)
|
||||
}
|
||||
|
|
@ -398,10 +396,12 @@ server_url: http://127.0.0.1:8080
|
|||
tls_letsencrypt_hostname: example.com
|
||||
tls_letsencrypt_challenge_type: TLS-ALPN-01
|
||||
`)
|
||||
|
||||
err = os.WriteFile(configFilePath, configYaml, 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't write file %s", configFilePath)
|
||||
}
|
||||
|
||||
err = LoadConfig(tmpDir, false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
|
@ -463,6 +463,7 @@ func TestSafeServerURL(t *testing.T) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,10 +42,6 @@ type (
|
|||
NodeIDs []NodeID
|
||||
)
|
||||
|
||||
func (n NodeIDs) Len() int { return len(n) }
|
||||
func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] }
|
||||
func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] }
|
||||
|
||||
func (id NodeID) StableID() tailcfg.StableNodeID {
|
||||
return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10))
|
||||
}
|
||||
|
|
@ -160,6 +156,7 @@ func (node *Node) GivenNameHasBeenChanged() bool {
|
|||
// Strip invalid DNS characters for givenName comparison
|
||||
normalised := strings.ToLower(node.Hostname)
|
||||
normalised = invalidDNSRegex.ReplaceAllString(normalised, "")
|
||||
|
||||
return node.GivenName == normalised
|
||||
}
|
||||
|
||||
|
|
@ -197,13 +194,7 @@ func (node *Node) IPs() []netip.Addr {
|
|||
|
||||
// HasIP reports if a node has a given IP address.
|
||||
func (node *Node) HasIP(i netip.Addr) bool {
|
||||
for _, ip := range node.IPs() {
|
||||
if ip.Compare(i) == 0 {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return slices.Contains(node.IPs(), i)
|
||||
}
|
||||
|
||||
// IsTagged reports if a device is tagged and therefore should not be treated
|
||||
|
|
@ -243,6 +234,7 @@ func (node *Node) RequestTags() []string {
|
|||
}
|
||||
|
||||
func (node *Node) Prefixes() []netip.Prefix {
|
||||
//nolint:prealloc
|
||||
var addrs []netip.Prefix
|
||||
for _, nodeAddress := range node.IPs() {
|
||||
ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen())
|
||||
|
|
@ -272,6 +264,7 @@ func (node *Node) IsExitNode() bool {
|
|||
}
|
||||
|
||||
func (node *Node) IPsAsString() []string {
|
||||
//nolint:prealloc
|
||||
var ret []string
|
||||
|
||||
for _, ip := range node.IPs() {
|
||||
|
|
@ -355,13 +348,9 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
|
|||
}
|
||||
|
||||
func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool {
|
||||
for _, node := range nodes {
|
||||
if node.NodeKey == nodeKey {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
return slices.ContainsFunc(nodes, func(node *Node) bool {
|
||||
return node.NodeKey == nodeKey
|
||||
})
|
||||
}
|
||||
|
||||
func (node *Node) Proto() *v1.Node {
|
||||
|
|
@ -478,7 +467,7 @@ func (node *Node) IsSubnetRouter() bool {
|
|||
return len(node.SubnetRoutes()) > 0
|
||||
}
|
||||
|
||||
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes
|
||||
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes.
|
||||
func (node *Node) AllApprovedRoutes() []netip.Prefix {
|
||||
return append(node.SubnetRoutes(), node.ExitRoutes()...)
|
||||
}
|
||||
|
|
@ -586,13 +575,16 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
|
|||
}
|
||||
|
||||
newHostname := strings.ToLower(hostInfo.Hostname)
|
||||
if err := util.ValidateHostname(newHostname); err != nil {
|
||||
|
||||
err := util.ValidateHostname(newHostname)
|
||||
if err != nil {
|
||||
log.Warn().
|
||||
Str("node.id", node.ID.String()).
|
||||
Str("current_hostname", node.Hostname).
|
||||
Str("rejected_hostname", hostInfo.Hostname).
|
||||
Err(err).
|
||||
Msg("Rejecting invalid hostname update from hostinfo")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -684,6 +676,7 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
|
|||
func (nodes Nodes) DebugString() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Nodes:\n")
|
||||
|
||||
for _, node := range nodes {
|
||||
sb.WriteString(node.DebugString())
|
||||
sb.WriteString("\n")
|
||||
|
|
@ -855,7 +848,7 @@ func (nv NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.Peer
|
|||
// GetFQDN returns the fully qualified domain name for the node.
|
||||
func (nv NodeView) GetFQDN(baseDomain string) (string, error) {
|
||||
if !nv.Valid() {
|
||||
return "", errors.New("failed to create valid FQDN: node view is invalid")
|
||||
return "", fmt.Errorf("failed to create valid FQDN: %w", ErrInvalidNodeView)
|
||||
}
|
||||
|
||||
return nv.ж.GetFQDN(baseDomain)
|
||||
|
|
@ -934,11 +927,11 @@ func (nv NodeView) TailscaleUserID() tailcfg.UserID {
|
|||
}
|
||||
|
||||
if nv.IsTagged() {
|
||||
//nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64
|
||||
//nolint:gosec
|
||||
return tailcfg.UserID(int64(TaggedDevices.ID))
|
||||
}
|
||||
|
||||
//nolint:gosec // G115: UserID values are within int64 range
|
||||
//nolint:gosec
|
||||
return tailcfg.UserID(int64(nv.UserID().Get()))
|
||||
}
|
||||
|
||||
|
|
@ -1048,7 +1041,7 @@ func (nv NodeView) TailNode(
|
|||
|
||||
primaryRoutes := primaryRouteFunc(nv.ID())
|
||||
allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes())
|
||||
tsaddr.SortPrefixes(allowedIPs)
|
||||
slices.SortFunc(allowedIPs, netip.Prefix.Compare)
|
||||
|
||||
capMap := tailcfg.NodeCapMap{
|
||||
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
|
||||
|
|
@ -1063,7 +1056,7 @@ func (nv NodeView) TailNode(
|
|||
}
|
||||
|
||||
tNode := tailcfg.Node{
|
||||
//nolint:gosec // G115: NodeID values are within int64 range
|
||||
//nolint:gosec
|
||||
ID: tailcfg.NodeID(nv.ID()),
|
||||
StableID: nv.ID().StableID(),
|
||||
Name: hostname,
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ import (
|
|||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"gorm.io/gorm"
|
||||
"tailscale.com/types/ptr"
|
||||
)
|
||||
|
||||
// TestNodeIsTagged tests the IsTagged() method for determining if a node is tagged.
|
||||
|
|
@ -69,7 +68,7 @@ func TestNodeIsTagged(t *testing.T) {
|
|||
{
|
||||
name: "node with user and no tags - not tagged",
|
||||
node: Node{
|
||||
UserID: ptr.To(uint(42)),
|
||||
UserID: new(uint(42)),
|
||||
Tags: []string{},
|
||||
},
|
||||
want: false,
|
||||
|
|
@ -112,7 +111,7 @@ func TestNodeViewIsTagged(t *testing.T) {
|
|||
{
|
||||
name: "user-owned node",
|
||||
node: Node{
|
||||
UserID: ptr.To(uint(1)),
|
||||
UserID: new(uint(1)),
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
|
|
@ -223,7 +222,7 @@ func TestNodeTagsImmutableAfterRegistration(t *testing.T) {
|
|||
// Test that a user-owned node is not tagged
|
||||
userNode := Node{
|
||||
ID: 2,
|
||||
UserID: ptr.To(uint(42)),
|
||||
UserID: new(uint(42)),
|
||||
Tags: []string{},
|
||||
RegisterMethod: util.RegisterMethodOIDC,
|
||||
}
|
||||
|
|
@ -243,7 +242,7 @@ func TestNodeOwnershipModel(t *testing.T) {
|
|||
name: "tagged node has tags, UserID is informational",
|
||||
node: Node{
|
||||
ID: 1,
|
||||
UserID: ptr.To(uint(5)), // "created by" user 5
|
||||
UserID: new(uint(5)), // "created by" user 5
|
||||
Tags: []string{"tag:server"},
|
||||
},
|
||||
wantIsTagged: true,
|
||||
|
|
@ -253,7 +252,7 @@ func TestNodeOwnershipModel(t *testing.T) {
|
|||
name: "user-owned node has no tags",
|
||||
node: Node{
|
||||
ID: 2,
|
||||
UserID: ptr.To(uint(5)),
|
||||
UserID: new(uint(5)),
|
||||
Tags: []string{},
|
||||
},
|
||||
wantIsTagged: false,
|
||||
|
|
@ -265,7 +264,7 @@ func TestNodeOwnershipModel(t *testing.T) {
|
|||
name: "node with only authkey tags - not tagged (tags should be copied)",
|
||||
node: Node{
|
||||
ID: 3,
|
||||
UserID: ptr.To(uint(5)), // "created by" user 5
|
||||
UserID: new(uint(5)), // "created by" user 5
|
||||
AuthKey: &PreAuthKey{
|
||||
Tags: []string{"tag:database"},
|
||||
},
|
||||
|
|
|
|||
|
|
@ -407,7 +407,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
|
|||
Hostname: "valid-hostname",
|
||||
},
|
||||
change: &tailcfg.Hostinfo{
|
||||
Hostname: "我的电脑",
|
||||
Hostname: "我的电脑", //nolint:gosmopolitan
|
||||
},
|
||||
want: Node{
|
||||
GivenName: "valid-hostname",
|
||||
|
|
@ -491,7 +491,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
|
|||
Hostname: "valid-hostname",
|
||||
},
|
||||
change: &tailcfg.Hostinfo{
|
||||
Hostname: "server-北京-01",
|
||||
Hostname: "server-北京-01", //nolint:gosmopolitan
|
||||
},
|
||||
want: Node{
|
||||
GivenName: "valid-hostname",
|
||||
|
|
@ -505,7 +505,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
|
|||
Hostname: "valid-hostname",
|
||||
},
|
||||
change: &tailcfg.Hostinfo{
|
||||
Hostname: "我的电脑",
|
||||
Hostname: "我的电脑", //nolint:gosmopolitan
|
||||
},
|
||||
want: Node{
|
||||
GivenName: "valid-hostname",
|
||||
|
|
@ -533,7 +533,7 @@ func TestApplyHostnameFromHostInfo(t *testing.T) {
|
|||
Hostname: "valid-hostname",
|
||||
},
|
||||
change: &tailcfg.Hostinfo{
|
||||
Hostname: "测试💻机器",
|
||||
Hostname: "测试💻机器", //nolint:gosmopolitan
|
||||
},
|
||||
want: Node{
|
||||
GivenName: "valid-hostname",
|
||||
|
|
|
|||
|
|
@ -114,7 +114,7 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey {
|
|||
return &protoKey
|
||||
}
|
||||
|
||||
// canUsePreAuthKey checks if a pre auth key can be used.
|
||||
// Validate checks if a pre auth key can be used.
|
||||
func (pak *PreAuthKey) Validate() error {
|
||||
if pak == nil {
|
||||
return PAKError("invalid authkey")
|
||||
|
|
@ -128,6 +128,7 @@ func (pak *PreAuthKey) Validate() error {
|
|||
if pak.Expiration != nil {
|
||||
return *pak.Expiration
|
||||
}
|
||||
|
||||
return time.Time{}
|
||||
}()).
|
||||
Time("now", time.Now()).
|
||||
|
|
|
|||
|
|
@ -110,8 +110,7 @@ func TestCanUsePreAuthKey(t *testing.T) {
            if err == nil {
                t.Errorf("expected error but got none")
            } else {
                var httpErr PAKError
                ok := errors.As(err, &httpErr)
                httpErr, ok := errors.AsType[PAKError](err)
                if !ok {
                    t.Errorf("expected HTTPError but got %T", err)
                } else {
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import (
|
|||
"cmp"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/mail"
|
||||
"net/url"
|
||||
|
|
@ -18,6 +19,9 @@ import (
|
|||
"tailscale.com/tailcfg"
|
||||
)
|
||||
|
||||
// ErrCannotParseBool is returned when a value cannot be parsed as a boolean.
|
||||
var ErrCannotParseBool = errors.New("could not parse value as boolean")
|
||||
|
||||
type UserID uint64
|
||||
|
||||
type Users []User
|
||||
|
|
@ -40,9 +44,11 @@ var TaggedDevices = User{
|
|||
func (u Users) String() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[ ")
|
||||
|
||||
for _, user := range u {
|
||||
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
|
||||
}
|
||||
|
||||
sb.WriteString(" ]")
|
||||
|
||||
return sb.String()
|
||||
|
|
@ -89,11 +95,12 @@ func (u *User) StringID() string {
|
|||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strconv.FormatUint(uint64(u.ID), 10)
|
||||
}
|
||||
|
||||
// TypedID returns a pointer to the user's ID as a UserID type.
|
||||
// This is a convenience method to avoid ugly casting like ptr.To(types.UserID(user.ID)).
|
||||
// This is a convenience method to avoid ugly casting like new(types.UserID(user.ID)).
|
||||
func (u *User) TypedID() *UserID {
|
||||
uid := UserID(u.ID)
|
||||
return &uid
|
||||
|
|
@ -148,7 +155,7 @@ func (u UserView) ID() uint {
|
|||
|
||||
func (u *User) TailscaleLogin() tailcfg.Login {
|
||||
return tailcfg.Login{
|
||||
ID: tailcfg.LoginID(u.ID),
|
||||
ID: tailcfg.LoginID(u.ID), //nolint:gosec
|
||||
Provider: u.Provider,
|
||||
LoginName: u.Username(),
|
||||
DisplayName: u.Display(),
|
||||
|
|
@ -194,8 +201,8 @@ func (u *User) Proto() *v1.User {
    }
}

// JumpCloud returns a JSON where email_verified is returned as a
// string "true" or "false" instead of a boolean.
// FlexibleBoolean handles JSON where email_verified is returned as a
// string "true" or "false" instead of a boolean (e.g., JumpCloud).
// This maps bool to a specific type with a custom unmarshaler to
// ensure we can decode it from a string.
// https://github.com/juanfont/headscale/issues/2293

@ -203,6 +210,7 @@ type FlexibleBoolean bool

func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
    var val any

    err := json.Unmarshal(data, &val)
    if err != nil {
        return fmt.Errorf("could not unmarshal data: %w", err)

@ -216,10 +224,11 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
        if err != nil {
            return fmt.Errorf("could not parse %s as boolean: %w", v, err)
        }

        *bit = FlexibleBoolean(pv)

    default:
        return fmt.Errorf("could not parse %v as boolean", v)
        return fmt.Errorf("%w: %v", ErrCannotParseBool, v)
    }

    return nil
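A short sketch of what the unmarshaler above enables: the same struct field decodes from either a native JSON boolean or the string form some providers send. The claims struct here is a trimmed stand-in, and the native-bool branch of the switch sits outside the visible hunk, so that behaviour is assumed:

// sketchClaims is illustrative only; the real type is OIDCClaims.
type sketchClaims struct {
    EmailVerified FlexibleBoolean `json:"email_verified"`
}

func decodeBoth() (bool, bool) {
    var a, b sketchClaims
    _ = json.Unmarshal([]byte(`{"email_verified": true}`), &a)   // native boolean
    _ = json.Unmarshal([]byte(`{"email_verified": "true"}`), &b) // string form, e.g. JumpCloud
    return bool(a.EmailVerified), bool(b.EmailVerified)
}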
|
||||
|
|
@ -253,9 +262,11 @@ func (c *OIDCClaims) Identifier() string {
|
|||
if c.Iss == "" && c.Sub == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if c.Iss == "" {
|
||||
return CleanIdentifier(c.Sub)
|
||||
}
|
||||
|
||||
if c.Sub == "" {
|
||||
return CleanIdentifier(c.Iss)
|
||||
}
|
||||
|
|
@ -266,8 +277,10 @@ func (c *OIDCClaims) Identifier() string {
|
|||
|
||||
var result string
|
||||
// Try to parse as URL to handle URL joining correctly
|
||||
//nolint:noinlineerr
|
||||
if u, err := url.Parse(issuer); err == nil && u.Scheme != "" {
|
||||
// For URLs, use proper URL path joining
|
||||
//nolint:noinlineerr
|
||||
if joined, err := url.JoinPath(issuer, subject); err == nil {
|
||||
result = joined
|
||||
}
|
||||
|
|
@ -340,6 +353,7 @@ func CleanIdentifier(identifier string) string {
|
|||
cleanParts = append(cleanParts, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cleanParts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
|
@ -382,6 +396,7 @@ func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) {
|
|||
if claims.Iss == "" && !strings.HasPrefix(identifier, "/") {
|
||||
identifier = "/" + identifier
|
||||
}
|
||||
|
||||
u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true}
|
||||
u.DisplayName = claims.Name
|
||||
u.ProfilePicURL = claims.ProfilePictureURL
|
||||
|
|
|
|||
|
|
@ -66,10 +66,13 @@ func TestUnmarshallOIDCClaims(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got OIDCClaims
|
||||
if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil {
|
||||
|
||||
err := json.Unmarshal([]byte(tt.jsonstr), &got)
|
||||
if err != nil {
|
||||
t.Errorf("UnmarshallOIDCClaims() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -190,6 +193,7 @@ func TestOIDCClaimsIdentifier(t *testing.T) {
|
|||
}
|
||||
result := claims.Identifier()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("Identifier() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -282,6 +286,7 @@ func TestCleanIdentifier(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CleanIdentifier(tt.identifier)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -479,7 +484,9 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var got OIDCClaims
|
||||
if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil {
|
||||
|
||||
err := json.Unmarshal([]byte(tt.jsonstr), &got)
|
||||
if err != nil {
|
||||
t.Errorf("TestOIDCClaimsJSONToUser() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
|
@ -487,6 +494,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
|
|||
var user User
|
||||
|
||||
user.FromClaim(&got, tt.emailVerifiedRequired)
|
||||
|
||||
if diff := cmp.Diff(user, tt.want); diff != "" {
|
||||
t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,9 +38,7 @@ func (v *VersionInfo) String() string {
    return sb.String()
}

var buildInfo = sync.OnceValues(func() (*debug.BuildInfo, bool) {
    return debug.ReadBuildInfo()
})
var buildInfo = sync.OnceValues(debug.ReadBuildInfo)

var GetVersionInfo = sync.OnceValue(func() *VersionInfo {
    info := &VersionInfo{
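A brief note on the buildInfo change above: debug.ReadBuildInfo already has the two-value shape sync.OnceValues expects, so the wrapper closure can be dropped. A minimal sketch, with an illustrative package name:

package versionutil

import (
    "runtime/debug"
    "sync"
)

// Both declarations behave the same: ReadBuildInfo runs once on first use,
// and every later call returns the cached (*debug.BuildInfo, bool) pair.
var buildInfoWrapped = sync.OnceValues(func() (*debug.BuildInfo, bool) {
    return debug.ReadBuildInfo()
})

var buildInfoDirect = sync.OnceValues(debug.ReadBuildInfo)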
|
||||
|
|
|
|||
|
|
@ -20,12 +20,30 @@ const (

    // value related to RFC 1123 and 952.
    LabelHostnameLength = 63

    // minNameLength is the minimum length for usernames and hostnames.
    minNameLength = 2
)

var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")

var ErrInvalidHostName = errors.New("invalid hostname")

// Sentinel errors for username validation.
var (
    ErrUsernameTooShort        = errors.New("username must be at least 2 characters long")
    ErrUsernameMustStartLetter = errors.New("username must start with a letter")
    ErrUsernameTooManyAt       = errors.New("username cannot contain more than one '@'")
    ErrUsernameInvalidChar     = errors.New("username contains invalid character")
)

// Sentinel errors for hostname validation.
var (
    ErrHostnameTooShort   = errors.New("hostname too short, must be at least 2 characters")
    ErrHostnameHyphenEnds = errors.New("hostname cannot start or end with a hyphen")
    ErrHostnameDotEnds    = errors.New("hostname cannot start or end with a dot")
)

// ValidateUsername checks if a username is valid.
// It must be at least 2 characters long, start with a letter, and contain
// only letters, numbers, hyphens, dots, and underscores.
|
||||
|
|
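A minimal sketch of how callers can branch on the new sentinel errors with errors.Is; validateUsername here is a trimmed-down stand-in covering only two of the checks, not the real exported function.

package main

import (
	"errors"
	"fmt"
	"unicode"
)

// Sentinels mirroring the ones introduced above.
var (
	ErrUsernameTooShort        = errors.New("username must be at least 2 characters long")
	ErrUsernameMustStartLetter = errors.New("username must start with a letter")
)

const minNameLength = 2

// validateUsername is a reduced stand-in for the real function, just enough
// to show callers branching on sentinel errors.
func validateUsername(username string) error {
	if len(username) < minNameLength {
		return ErrUsernameTooShort
	}

	if !unicode.IsLetter(rune(username[0])) {
		return ErrUsernameMustStartLetter
	}

	return nil
}

func main() {
	err := validateUsername("x")
	switch {
	case errors.Is(err, ErrUsernameTooShort):
		fmt.Println("too short")
	case errors.Is(err, ErrUsernameMustStartLetter):
		fmt.Println("must start with a letter")
	case err == nil:
		fmt.Println("ok")
	}
}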
@@ -33,13 +51,13 @@ var ErrInvalidHostName = errors.New("invalid hostname")
 // It cannot contain invalid characters.
 func ValidateUsername(username string) error {
 	// Ensure the username meets the minimum length requirement
-	if len(username) < 2 {
-		return errors.New("username must be at least 2 characters long")
+	if len(username) < minNameLength {
+		return ErrUsernameTooShort
 	}

 	// Ensure the username starts with a letter
 	if !unicode.IsLetter(rune(username[0])) {
-		return errors.New("username must start with a letter")
+		return ErrUsernameMustStartLetter
 	}

 	atCount := 0
@@ -55,10 +73,10 @@ func ValidateUsername(username string) error {
 		case char == '@':
 			atCount++
 			if atCount > 1 {
-				return errors.New("username cannot contain more than one '@'")
+				return ErrUsernameTooManyAt
 			}
 		default:
-			return fmt.Errorf("username contains invalid character: '%c'", char)
+			return fmt.Errorf("%w: '%c'", ErrUsernameInvalidChar, char)
 		}
 	}

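Wrapping the sentinel with %w, as in the default case above, keeps errors.Is matching while still reporting the offending character. A small sketch:

package main

import (
	"errors"
	"fmt"
)

// Sentinel mirroring the one introduced above.
var ErrUsernameInvalidChar = errors.New("username contains invalid character")

func main() {
	char := '!'
	// Wrapping with %w keeps the sentinel matchable while adding detail.
	err := fmt.Errorf("%w: '%c'", ErrUsernameInvalidChar, char)

	fmt.Println(err)                                    // username contains invalid character: '!'
	fmt.Println(errors.Is(err, ErrUsernameInvalidChar)) // true
}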
@@ -69,11 +87,8 @@ func ValidateUsername(username string) error {
 // This function does NOT modify the input - it only validates.
 // The hostname must already be lowercase and contain only valid characters.
 func ValidateHostname(name string) error {
-	if len(name) < 2 {
-		return fmt.Errorf(
-			"hostname %q is too short, must be at least 2 characters",
-			name,
-		)
+	if len(name) < minNameLength {
+		return fmt.Errorf("%w: %q", ErrHostnameTooShort, name)
 	}
 	if len(name) > LabelHostnameLength {
 		return fmt.Errorf(
@@ -90,17 +105,11 @@ func ValidateHostname(name string) error {
 	}

 	if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") {
-		return fmt.Errorf(
-			"hostname %q cannot start or end with a hyphen",
-			name,
-		)
+		return fmt.Errorf("%w: %q", ErrHostnameHyphenEnds, name)
 	}

 	if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") {
-		return fmt.Errorf(
-			"hostname %q cannot start or end with a dot",
-			name,
-		)
+		return fmt.Errorf("%w: %q", ErrHostnameDotEnds, name)
 	}

 	if invalidDNSRegex.MatchString(name) {
@@ -90,6 +90,7 @@ func TestNormaliseHostname(t *testing.T) {
 				t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr)
 				return
 			}
+
 			if !tt.wantErr && got != tt.want {
 				t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want)
 			}
@@ -172,6 +173,7 @@ func TestValidateHostname(t *testing.T) {
 				t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr)
 				return
 			}
+
 			if tt.wantErr && tt.errorContains != "" {
 				if err == nil || !strings.Contains(err.Error(), tt.errorContains) {
 					t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains)
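A self-contained sketch of the wantErr/errorContains table-test pattern used above, assuming a toy validate function; the names and cases are illustrative only.

package example

import (
	"errors"
	"strings"
	"testing"
)

// validate is a toy stand-in for ValidateHostname, used only to exercise the
// wantErr/errorContains pattern from the hunk above.
func validate(name string) error {
	if len(name) < 2 {
		return errors.New("hostname too short, must be at least 2 characters")
	}

	return nil
}

func TestValidateTable(t *testing.T) {
	tests := []struct {
		name          string
		input         string
		wantErr       bool
		errorContains string
	}{
		{name: "ok", input: "node-1"},
		{name: "too short", input: "a", wantErr: true, errorContains: "too short"},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validate(tt.input)
			if (err != nil) != tt.wantErr {
				t.Errorf("validate() error = %v, wantErr %v", err, tt.wantErr)
				return
			}

			if tt.wantErr && tt.errorContains != "" {
				if err == nil || !strings.Contains(err.Error(), tt.errorContains) {
					t.Errorf("validate() error = %v, should contain %q", err, tt.errorContains)
				}
			}
		})
	}
}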
@@ -14,11 +14,14 @@ func YesNo(msg string) bool {
 	fmt.Fprint(os.Stderr, msg+" [y/n] ")
+
 	var resp string
-	fmt.Scanln(&resp)
+
+	_, _ = fmt.Scanln(&resp)
+
 	resp = strings.ToLower(resp)
 	switch resp {
 	case "y", "yes", "sure":
 		return true
 	}

 	return false
 }
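A minimal sketch of the prompt flow after the change, plus a hypothetical call site; the blank identifiers make the ignored fmt.Scanln return values explicit, which is what the errcheck-style fixes in this hunk are about. yesNo and the "Delete node?" prompt are illustrative, not headscale's actual CLI wiring.

package main

import (
	"fmt"
	"os"
	"strings"
)

// yesNo mirrors the function above; the ignored Scanln return values are
// assigned to blank identifiers so the intent is explicit.
func yesNo(msg string) bool {
	fmt.Fprint(os.Stderr, msg+" [y/n] ")

	var resp string

	_, _ = fmt.Scanln(&resp)

	switch strings.ToLower(resp) {
	case "y", "yes", "sure":
		return true
	}

	return false
}

func main() {
	if yesNo("Delete node?") {
		fmt.Println("deleting")
	} else {
		fmt.Println("aborted")
	}
}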
@@ -86,7 +86,8 @@ func TestYesNo(t *testing.T) {
 			// Write test input
 			go func() {
 				defer w.Close()
-				w.WriteString(tt.input)
+
+				_, _ = w.WriteString(tt.input)
 			}()

 			// Call the function
@@ -95,6 +96,7 @@ func TestYesNo(t *testing.T) {
 			// Restore stdin and stderr
 			os.Stdin = oldStdin
 			os.Stderr = oldStderr
+
 			stderrW.Close()

 			// Check the result
@@ -104,10 +106,12 @@ func TestYesNo(t *testing.T) {

 			// Check that the prompt was written to stderr
 			var stderrBuf bytes.Buffer
-			io.Copy(&stderrBuf, stderrR)
+
+			_, _ = io.Copy(&stderrBuf, stderrR)
 			stderrR.Close()

 			expectedPrompt := "Test question [y/n] "
+
 			actualPrompt := stderrBuf.String()
 			if actualPrompt != expectedPrompt {
 				t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt)
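The YesNo tests above swap os.Stdin and os.Stderr for os.Pipe ends, feed input from a goroutine, and copy the captured stderr into a buffer after restoring the originals. A self-contained sketch of that pattern, assuming a hypothetical promptUser function under test; pipe-creation errors are ignored for brevity.

package example

import (
	"bytes"
	"fmt"
	"io"
	"os"
	"testing"
)

// promptUser is a hypothetical stand-in for the function under test: it
// writes a prompt to stderr and reads one token from stdin.
func promptUser() string {
	fmt.Fprint(os.Stderr, "proceed? ")

	var resp string

	_, _ = fmt.Fscanln(os.Stdin, &resp)

	return resp
}

func TestPromptUser(t *testing.T) {
	oldStdin, oldStderr := os.Stdin, os.Stderr

	// Errors from os.Pipe are ignored in this sketch.
	stdinR, stdinW, _ := os.Pipe()
	stderrR, stderrW, _ := os.Pipe()
	os.Stdin = stdinR
	os.Stderr = stderrW

	// Write test input from a goroutine, closing the writer when done.
	go func() {
		defer stdinW.Close()

		_, _ = stdinW.WriteString("yes\n")
	}()

	got := promptUser()

	// Restore stdin and stderr before asserting.
	os.Stdin = oldStdin
	os.Stderr = oldStderr

	stderrW.Close()

	var stderrBuf bytes.Buffer

	_, _ = io.Copy(&stderrBuf, stderrR)
	stderrR.Close()

	if got != "yes" {
		t.Errorf("promptUser() = %q, want %q", got, "yes")
	}

	if stderrBuf.String() != "proceed? " {
		t.Errorf("prompt = %q, want %q", stderrBuf.String(), "proceed? ")
	}
}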
@@ -130,7 +134,8 @@ func TestYesNoPromptMessage(t *testing.T) {
 	// Write test input
 	go func() {
 		defer w.Close()
-		w.WriteString("n\n")
+
+		_, _ = w.WriteString("n\n")
 	}()

 	// Call the function with a custom message
@@ -140,14 +145,17 @@ func TestYesNoPromptMessage(t *testing.T) {
 	// Restore stdin and stderr
 	os.Stdin = oldStdin
 	os.Stderr = oldStderr
+
 	stderrW.Close()

 	// Check that the custom message was included in the prompt
 	var stderrBuf bytes.Buffer
-	io.Copy(&stderrBuf, stderrR)
+
+	_, _ = io.Copy(&stderrBuf, stderrR)
 	stderrR.Close()

 	expectedPrompt := customMessage + " [y/n] "
+
 	actualPrompt := stderrBuf.String()
 	if actualPrompt != expectedPrompt {
 		t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt)
@@ -186,7 +194,8 @@ func TestYesNoCaseInsensitive(t *testing.T) {
 			// Write test input
 			go func() {
 				defer w.Close()
-				w.WriteString(tc.input)
+
+				_, _ = w.WriteString(tc.input)
 			}()

 			// Call the function
@@ -195,10 +204,11 @@ func TestYesNoCaseInsensitive(t *testing.T) {
 			// Restore stdin and stderr
 			os.Stdin = oldStdin
 			os.Stderr = oldStderr
+
 			stderrW.Close()

 			// Drain stderr
-			io.Copy(io.Discard, stderrR)
+			_, _ = io.Copy(io.Discard, stderrR)
 			stderrR.Close()

 			if result != tc.expected {
@@ -9,6 +9,9 @@ import (
 	"tailscale.com/tailcfg"
 )

+// invalidStringRandomLength is the length of random bytes for invalid string generation.
+const invalidStringRandomLength = 8
+
 // GenerateRandomBytes returns securely generated random bytes.
 // It will return an error if the system's secure random
 // number generator fails to function correctly, in which
@@ -33,6 +36,7 @@ func GenerateRandomStringURLSafe(n int) (string, error) {
 	b, err := GenerateRandomBytes(n)
+
 	uenc := base64.RawURLEncoding.EncodeToString(b)

 	return uenc[:n], err
 }

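A minimal sketch of the URL-safe generator above, assuming crypto/rand as the byte source; since RawURLEncoding produces at least 4n/3 characters for n input bytes, trimming back to n characters is safe. randomStringURLSafe is an illustrative stand-in, not the exported helper.

package main

import (
	"crypto/rand"
	"encoding/base64"
	"fmt"
)

// randomStringURLSafe sketches the helper above: n random bytes are
// base64url-encoded without padding and trimmed back to n characters.
func randomStringURLSafe(n int) (string, error) {
	b := make([]byte, n)
	if _, err := rand.Read(b); err != nil {
		return "", err
	}

	return base64.RawURLEncoding.EncodeToString(b)[:n], nil
}

func main() {
	s, err := randomStringURLSafe(16)
	if err != nil {
		panic(err)
	}

	fmt.Println(s)
}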
@@ -67,7 +71,7 @@ func MustGenerateRandomStringDNSSafe(size int) string {
 }

 func InvalidString() string {
-	hash, _ := GenerateRandomStringDNSSafe(8)
+	hash, _ := GenerateRandomStringDNSSafe(invalidStringRandomLength)
 	return "invalid-" + hash
 }

@@ -99,6 +103,7 @@ func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string {
 			DstIPs: %v
 		}
 	`, rule.SrcIPs, rule.DstPorts))
+
 		if index < len(rules)-1 {
 			sb.WriteString(", ")
 		}
Some files were not shown because too many files have changed in this diff.