mirror of
https://github.com/juanfont/headscale.git
synced 2026-01-23 02:24:10 +00:00
all: apply golangci-lint auto-fixes
Apply auto-fixes from golangci-lint for the following linters: - wsl_v5: whitespace formatting and blank line adjustments - godot: add periods to comment sentences - nlreturn: add newlines before return statements - perfsprint: optimize fmt.Sprintf to more efficient alternatives Also add missing imports (errors, encoding/hex) where auto-fix added new code patterns that require them.
This commit is contained in:
parent
3675b65504
commit
ad7669a2d4
93 changed files with 1262 additions and 155 deletions
|
|
@ -73,6 +73,7 @@ func mockOIDC() error {
|
|||
}
|
||||
|
||||
var users []mockoidc.MockUser
|
||||
|
||||
err := json.Unmarshal([]byte(userStr), &users)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unmarshalling users: %w", err)
|
||||
|
|
@ -137,6 +138,7 @@ func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser
|
|||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
log.Info().Msgf("Request: %+v", r)
|
||||
h.ServeHTTP(w, r)
|
||||
|
||||
if r.Response != nil {
|
||||
log.Info().Msgf("Response: %+v", r.Response)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,13 +29,16 @@ func init() {
|
|||
if err := setPolicy.MarkFlagRequired("file"); err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
|
||||
setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running")
|
||||
policyCmd.AddCommand(setPolicy)
|
||||
|
||||
checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format")
|
||||
|
||||
if err := checkPolicy.MarkFlagRequired("file"); err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
|
||||
policyCmd.AddCommand(checkPolicy)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -80,6 +80,7 @@ func initConfig() {
|
|||
Repository: "headscale",
|
||||
TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }),
|
||||
}
|
||||
|
||||
res, err := latest.Check(githubTag, versionInfo.Version)
|
||||
if err == nil && res.Outdated {
|
||||
//nolint
|
||||
|
|
@ -101,6 +102,7 @@ func isPreReleaseVersion(version string) bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ func usernameAndIDFlag(cmd *cobra.Command) {
|
|||
// If both are empty, it will exit the program with an error.
|
||||
func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) {
|
||||
username, _ := cmd.Flags().GetString("name")
|
||||
|
||||
identifier, _ := cmd.Flags().GetInt64("identifier")
|
||||
if username == "" && identifier < 0 {
|
||||
err := errors.New("--name or --identifier flag is required")
|
||||
|
|
|
|||
|
|
@ -69,8 +69,10 @@ func killTestContainers(ctx context.Context) error {
|
|||
}
|
||||
|
||||
removed := 0
|
||||
|
||||
for _, cont := range containers {
|
||||
shouldRemove := false
|
||||
|
||||
for _, name := range cont.Names {
|
||||
if strings.Contains(name, "headscale-test-suite") ||
|
||||
strings.Contains(name, "hs-") ||
|
||||
|
|
@ -259,8 +261,10 @@ func cleanOldImages(ctx context.Context) error {
|
|||
}
|
||||
|
||||
removed := 0
|
||||
|
||||
for _, img := range images {
|
||||
shouldRemove := false
|
||||
|
||||
for _, tag := range img.RepoTags {
|
||||
if strings.Contains(tag, "hs-") ||
|
||||
strings.Contains(tag, "headscale-integration") ||
|
||||
|
|
@ -302,6 +306,7 @@ func cleanCacheVolume(ctx context.Context) error {
|
|||
defer cli.Close()
|
||||
|
||||
volumeName := "hs-integration-go-cache"
|
||||
|
||||
err = cli.VolumeRemove(ctx, volumeName, true)
|
||||
if err != nil {
|
||||
if errdefs.IsNotFound(err) {
|
||||
|
|
|
|||
|
|
@ -60,6 +60,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||
if config.Verbose {
|
||||
log.Printf("Running pre-test cleanup...")
|
||||
}
|
||||
|
||||
if err := cleanupBeforeTest(ctx); err != nil && config.Verbose {
|
||||
log.Printf("Warning: pre-test cleanup failed: %v", err)
|
||||
}
|
||||
|
|
@ -95,13 +96,16 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||
|
||||
// Start stats collection for container resource monitoring (if enabled)
|
||||
var statsCollector *StatsCollector
|
||||
|
||||
if config.Stats {
|
||||
var err error
|
||||
|
||||
statsCollector, err = NewStatsCollector()
|
||||
if err != nil {
|
||||
if config.Verbose {
|
||||
log.Printf("Warning: failed to create stats collector: %v", err)
|
||||
}
|
||||
|
||||
statsCollector = nil
|
||||
}
|
||||
|
||||
|
|
@ -140,6 +144,7 @@ func runTestContainer(ctx context.Context, config *RunConfig) error {
|
|||
if len(violations) > 0 {
|
||||
log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:")
|
||||
log.Printf("=================================")
|
||||
|
||||
for _, violation := range violations {
|
||||
log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB",
|
||||
violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB)
|
||||
|
|
@ -347,6 +352,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
|||
maxWaitTime := 10 * time.Second
|
||||
checkInterval := 500 * time.Millisecond
|
||||
timeout := time.After(maxWaitTime)
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
|
|
@ -356,6 +362,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
|||
if verbose {
|
||||
log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction")
|
||||
}
|
||||
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
allFinalized := true
|
||||
|
|
@ -366,12 +373,14 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
|||
if verbose {
|
||||
log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err)
|
||||
}
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if container is in a final state
|
||||
if !isContainerFinalized(inspect.State) {
|
||||
allFinalized = false
|
||||
|
||||
if verbose {
|
||||
log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status)
|
||||
}
|
||||
|
|
@ -384,6 +393,7 @@ func waitForContainerFinalization(ctx context.Context, cli *client.Client, testC
|
|||
if verbose {
|
||||
log.Printf("All test containers finalized, ready for artifact extraction")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
@ -403,10 +413,12 @@ func findProjectRoot(startPath string) string {
|
|||
if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil {
|
||||
return current
|
||||
}
|
||||
|
||||
parent := filepath.Dir(current)
|
||||
if parent == current {
|
||||
return startPath
|
||||
}
|
||||
|
||||
current = parent
|
||||
}
|
||||
}
|
||||
|
|
@ -416,6 +428,7 @@ func boolToInt(b bool) int {
|
|||
if b {
|
||||
return 1
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
|
|
@ -435,6 +448,7 @@ func createDockerClient() (*client.Client, error) {
|
|||
}
|
||||
|
||||
var clientOpts []client.Opt
|
||||
|
||||
clientOpts = append(clientOpts, client.WithAPIVersionNegotiation())
|
||||
|
||||
if contextInfo != nil {
|
||||
|
|
@ -444,6 +458,7 @@ func createDockerClient() (*client.Client, error) {
|
|||
if runConfig.Verbose {
|
||||
log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host)
|
||||
}
|
||||
|
||||
clientOpts = append(clientOpts, client.WithHost(host))
|
||||
}
|
||||
}
|
||||
|
|
@ -460,6 +475,7 @@ func createDockerClient() (*client.Client, error) {
|
|||
// getCurrentDockerContext retrieves the current Docker context information.
|
||||
func getCurrentDockerContext() (*DockerContext, error) {
|
||||
cmd := exec.Command("docker", "context", "inspect")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get docker context: %w", err)
|
||||
|
|
@ -491,6 +507,7 @@ func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageNa
|
|||
if client.IsErrNotFound(err) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err)
|
||||
}
|
||||
|
||||
|
|
@ -509,6 +526,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
|
|||
if verbose {
|
||||
log.Printf("Image %s is available locally", imageName)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -533,6 +551,7 @@ func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName str
|
|||
if err != nil {
|
||||
return fmt.Errorf("failed to read pull output: %w", err)
|
||||
}
|
||||
|
||||
log.Printf("Image %s pulled successfully", imageName)
|
||||
}
|
||||
|
||||
|
|
@ -547,9 +566,11 @@ func listControlFiles(logsDir string) {
|
|||
return
|
||||
}
|
||||
|
||||
var logFiles []string
|
||||
var dataFiles []string
|
||||
var dataDirs []string
|
||||
var (
|
||||
logFiles []string
|
||||
dataFiles []string
|
||||
dataDirs []string
|
||||
)
|
||||
|
||||
for _, entry := range entries {
|
||||
name := entry.Name()
|
||||
|
|
@ -578,6 +599,7 @@ func listControlFiles(logsDir string) {
|
|||
|
||||
if len(logFiles) > 0 {
|
||||
log.Printf("Headscale logs:")
|
||||
|
||||
for _, file := range logFiles {
|
||||
log.Printf(" %s", file)
|
||||
}
|
||||
|
|
@ -585,9 +607,11 @@ func listControlFiles(logsDir string) {
|
|||
|
||||
if len(dataFiles) > 0 || len(dataDirs) > 0 {
|
||||
log.Printf("Headscale data:")
|
||||
|
||||
for _, file := range dataFiles {
|
||||
log.Printf(" %s", file)
|
||||
}
|
||||
|
||||
for _, dir := range dataDirs {
|
||||
log.Printf(" %s/", dir)
|
||||
}
|
||||
|
|
@ -612,6 +636,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
|
|||
currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose)
|
||||
|
||||
extractedCount := 0
|
||||
|
||||
for _, cont := range currentTestContainers {
|
||||
// Extract container logs and tar files
|
||||
if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil {
|
||||
|
|
@ -622,6 +647,7 @@ func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDi
|
|||
if verbose {
|
||||
log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12])
|
||||
}
|
||||
|
||||
extractedCount++
|
||||
}
|
||||
}
|
||||
|
|
@ -645,11 +671,13 @@ func getCurrentTestContainers(containers []container.Summary, testContainerID st
|
|||
|
||||
// Find the test container to get its run ID label
|
||||
var runID string
|
||||
|
||||
for _, cont := range containers {
|
||||
if cont.ID == testContainerID {
|
||||
if cont.Labels != nil {
|
||||
runID = cont.Labels["hi.run-id"]
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -266,6 +266,7 @@ func checkGoInstallation() DoctorResult {
|
|||
}
|
||||
|
||||
cmd := exec.Command("go", "version")
|
||||
|
||||
output, err := cmd.Output()
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
|
|
@ -287,6 +288,7 @@ func checkGoInstallation() DoctorResult {
|
|||
// checkGitRepository verifies we're in a git repository.
|
||||
func checkGitRepository() DoctorResult {
|
||||
cmd := exec.Command("git", "rev-parse", "--git-dir")
|
||||
|
||||
err := cmd.Run()
|
||||
if err != nil {
|
||||
return DoctorResult{
|
||||
|
|
@ -316,6 +318,7 @@ func checkRequiredFiles() DoctorResult {
|
|||
}
|
||||
|
||||
var missingFiles []string
|
||||
|
||||
for _, file := range requiredFiles {
|
||||
cmd := exec.Command("test", "-e", file)
|
||||
if err := cmd.Run(); err != nil {
|
||||
|
|
@ -350,6 +353,7 @@ func displayDoctorResults(results []DoctorResult) {
|
|||
|
||||
for _, result := range results {
|
||||
var icon string
|
||||
|
||||
switch result.Status {
|
||||
case "PASS":
|
||||
icon = "✅"
|
||||
|
|
|
|||
|
|
@ -82,9 +82,11 @@ func cleanAll(ctx context.Context) error {
|
|||
if err := killTestContainers(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := pruneDockerNetworks(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := cleanOldImages(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ func runIntegrationTest(env *command.Env) error {
|
|||
if runConfig.Verbose {
|
||||
log.Printf("Running pre-flight system checks...")
|
||||
}
|
||||
|
||||
if err := runDoctorCheck(env.Context()); err != nil {
|
||||
return fmt.Errorf("pre-flight checks failed: %w", err)
|
||||
}
|
||||
|
|
@ -94,8 +95,10 @@ func detectGoVersion() string {
|
|||
|
||||
// splitLines splits a string into lines without using strings.Split.
|
||||
func splitLines(s string) []string {
|
||||
var lines []string
|
||||
var current string
|
||||
var (
|
||||
lines []string
|
||||
current string
|
||||
)
|
||||
|
||||
for _, char := range s {
|
||||
if char == '\n' {
|
||||
|
|
|
|||
|
|
@ -71,10 +71,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
|
|||
|
||||
// Start monitoring existing containers
|
||||
sc.wg.Add(1)
|
||||
|
||||
go sc.monitorExistingContainers(ctx, runID, verbose)
|
||||
|
||||
// Start Docker events monitoring for new containers
|
||||
sc.wg.Add(1)
|
||||
|
||||
go sc.monitorDockerEvents(ctx, runID, verbose)
|
||||
|
||||
if verbose {
|
||||
|
|
@ -88,10 +90,12 @@ func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, ver
|
|||
func (sc *StatsCollector) StopCollection() {
|
||||
// Check if already stopped without holding lock
|
||||
sc.mutex.RLock()
|
||||
|
||||
if !sc.collectionStarted {
|
||||
sc.mutex.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
// Signal stop to all goroutines
|
||||
|
|
@ -115,6 +119,7 @@ func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID s
|
|||
if verbose {
|
||||
log.Printf("Failed to list existing containers: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -168,6 +173,7 @@ func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string,
|
|||
if verbose {
|
||||
log.Printf("Error in Docker events stream: %v", err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
@ -214,6 +220,7 @@ func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerI
|
|||
}
|
||||
|
||||
sc.wg.Add(1)
|
||||
|
||||
go sc.collectStatsForContainer(ctx, containerID, verbose)
|
||||
}
|
||||
|
||||
|
|
@ -227,11 +234,13 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
|||
if verbose {
|
||||
log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
defer statsResponse.Body.Close()
|
||||
|
||||
decoder := json.NewDecoder(statsResponse.Body)
|
||||
|
||||
var prevStats *container.Stats
|
||||
|
||||
for {
|
||||
|
|
@ -247,6 +256,7 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
|||
if err.Error() != "EOF" && verbose {
|
||||
log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -262,8 +272,10 @@ func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containe
|
|||
// Store the sample (skip first sample since CPU calculation needs previous stats)
|
||||
if prevStats != nil {
|
||||
// Get container stats reference without holding the main mutex
|
||||
var containerStats *ContainerStats
|
||||
var exists bool
|
||||
var (
|
||||
containerStats *ContainerStats
|
||||
exists bool
|
||||
)
|
||||
|
||||
sc.mutex.RLock()
|
||||
containerStats, exists = sc.containers[containerID]
|
||||
|
|
@ -332,10 +344,12 @@ type StatsSummary struct {
|
|||
func (sc *StatsCollector) GetSummary() []ContainerStatsSummary {
|
||||
// Take snapshot of container references without holding main lock long
|
||||
sc.mutex.RLock()
|
||||
|
||||
containerRefs := make([]*ContainerStats, 0, len(sc.containers))
|
||||
for _, containerStats := range sc.containers {
|
||||
containerRefs = append(containerRefs, containerStats)
|
||||
}
|
||||
|
||||
sc.mutex.RUnlock()
|
||||
|
||||
summaries := make([]ContainerStatsSummary, 0, len(containerRefs))
|
||||
|
|
@ -393,9 +407,11 @@ func calculateStatsSummary(values []float64) StatsSummary {
|
|||
if value < min {
|
||||
min = value
|
||||
}
|
||||
|
||||
if value > max {
|
||||
max = value
|
||||
}
|
||||
|
||||
sum += value
|
||||
}
|
||||
|
||||
|
|
@ -435,6 +451,7 @@ func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []Memo
|
|||
}
|
||||
|
||||
summaries := sc.GetSummary()
|
||||
|
||||
var violations []MemoryViolation
|
||||
|
||||
for _, summary := range summaries {
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package main
|
|||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
|
|
@ -40,7 +41,7 @@ func main() {
|
|||
// runIntegrationTest executes the integration test workflow.
|
||||
func runOnline(env *command.Env) error {
|
||||
if mapConfig.Directory == "" {
|
||||
return fmt.Errorf("directory is required")
|
||||
return errors.New("directory is required")
|
||||
}
|
||||
|
||||
resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory)
|
||||
|
|
@ -57,5 +58,6 @@ func runOnline(env *command.Env) error {
|
|||
|
||||
os.Stderr.Write(out)
|
||||
os.Stderr.Write([]byte("\n"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -142,6 +142,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
if !ok {
|
||||
log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed")
|
||||
log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -157,10 +158,12 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
app.ephemeralGC = ephemeralGC
|
||||
|
||||
var authProvider AuthProvider
|
||||
|
||||
authProvider = NewAuthProviderWeb(cfg.ServerURL)
|
||||
if cfg.OIDC.Issuer != "" {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
oidcProvider, err := NewAuthProviderOIDC(
|
||||
ctx,
|
||||
&app,
|
||||
|
|
@ -177,6 +180,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
|
|||
authProvider = oidcProvider
|
||||
}
|
||||
}
|
||||
|
||||
app.authProvider = authProvider
|
||||
|
||||
if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS
|
||||
|
|
@ -251,9 +255,11 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||
lastExpiryCheck := time.Unix(0, 0)
|
||||
|
||||
derpTickerChan := make(<-chan time.Time)
|
||||
|
||||
if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 {
|
||||
derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency)
|
||||
defer derpTicker.Stop()
|
||||
|
||||
derpTickerChan = derpTicker.C
|
||||
}
|
||||
|
||||
|
|
@ -271,8 +277,10 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||
return
|
||||
|
||||
case <-expireTicker.C:
|
||||
var expiredNodeChanges []change.Change
|
||||
var changed bool
|
||||
var (
|
||||
expiredNodeChanges []change.Change
|
||||
changed bool
|
||||
)
|
||||
|
||||
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
|
||||
|
||||
|
|
@ -287,11 +295,13 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||
|
||||
case <-derpTickerChan:
|
||||
log.Info().Msg("Fetching DERPMap updates")
|
||||
|
||||
derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) {
|
||||
derpMap, err := derp.GetDERPMap(h.cfg.DERP)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion {
|
||||
region, _ := h.DERPServer.GenerateRegion()
|
||||
derpMap.Regions[region.RegionID] = ®ion
|
||||
|
|
@ -303,6 +313,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||
log.Error().Err(err).Msg("failed to build new DERPMap, retrying later")
|
||||
continue
|
||||
}
|
||||
|
||||
h.state.SetDERPMap(derpMap)
|
||||
|
||||
h.Change(change.DERPMap())
|
||||
|
|
@ -311,6 +322,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
|
|||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
h.cfg.TailcfgDNSConfig.ExtraRecords = records
|
||||
|
||||
h.Change(change.ExtraRecords())
|
||||
|
|
@ -390,6 +402,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler
|
|||
|
||||
writeUnauthorized := func(statusCode int) {
|
||||
writer.WriteHeader(statusCode)
|
||||
|
||||
if _, err := writer.Write([]byte("Unauthorized")); err != nil {
|
||||
log.Error().Err(err).Msg("writing HTTP response failed")
|
||||
}
|
||||
|
|
@ -486,6 +499,7 @@ func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router {
|
|||
// Serve launches the HTTP and gRPC server service Headscale and the API.
|
||||
func (h *Headscale) Serve() error {
|
||||
var err error
|
||||
|
||||
capver.CanOldCodeBeCleanedUp()
|
||||
|
||||
if profilingEnabled {
|
||||
|
|
@ -512,6 +526,7 @@ func (h *Headscale) Serve() error {
|
|||
Msg("Clients with a lower minimum version will be rejected")
|
||||
|
||||
h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state)
|
||||
|
||||
h.mapBatcher.Start()
|
||||
defer h.mapBatcher.Close()
|
||||
|
||||
|
|
@ -545,6 +560,7 @@ func (h *Headscale) Serve() error {
|
|||
// around between restarts, they will reconnect and the GC will
|
||||
// be cancelled.
|
||||
go h.ephemeralGC.Start()
|
||||
|
||||
ephmNodes := h.state.ListEphemeralNodes()
|
||||
for _, node := range ephmNodes.All() {
|
||||
h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout)
|
||||
|
|
@ -555,7 +571,9 @@ func (h *Headscale) Serve() error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("setting up extrarecord manager: %w", err)
|
||||
}
|
||||
|
||||
h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records()
|
||||
|
||||
go h.extraRecordMan.Run()
|
||||
defer h.extraRecordMan.Close()
|
||||
}
|
||||
|
|
@ -564,6 +582,7 @@ func (h *Headscale) Serve() error {
|
|||
// records updates
|
||||
scheduleCtx, scheduleCancel := context.WithCancel(context.Background())
|
||||
defer scheduleCancel()
|
||||
|
||||
go h.scheduledTasks(scheduleCtx)
|
||||
|
||||
if zl.GlobalLevel() == zl.TraceLevel {
|
||||
|
|
@ -751,7 +770,6 @@ func (h *Headscale) Serve() error {
|
|||
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
|
||||
}
|
||||
|
||||
|
||||
var tailsqlContext context.Context
|
||||
if tailsqlEnabled {
|
||||
if h.cfg.Database.Type != types.DatabaseSqlite {
|
||||
|
|
@ -863,6 +881,7 @@ func (h *Headscale) Serve() error {
|
|||
|
||||
// Close state connections
|
||||
info("closing state and database")
|
||||
|
||||
err = h.state.Close()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("failed to close state")
|
||||
|
|
|
|||
|
|
@ -51,6 +51,7 @@ func (h *Headscale) handleRegister(
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("handling logout: %w", err)
|
||||
}
|
||||
|
||||
if resp != nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
|
@ -131,7 +132,7 @@ func (h *Headscale) handleRegister(
|
|||
}
|
||||
|
||||
// handleLogout checks if the [tailcfg.RegisterRequest] is a
|
||||
// logout attempt from a node. If the node is not attempting to
|
||||
// logout attempt from a node. If the node is not attempting to.
|
||||
func (h *Headscale) handleLogout(
|
||||
node types.NodeView,
|
||||
req tailcfg.RegisterRequest,
|
||||
|
|
@ -158,6 +159,7 @@ func (h *Headscale) handleLogout(
|
|||
Interface("reg.req", req).
|
||||
Bool("unexpected", true).
|
||||
Msg("Node key expired, forcing re-authentication")
|
||||
|
||||
return &tailcfg.RegisterResponse{
|
||||
NodeKeyExpired: true,
|
||||
MachineAuthorized: false,
|
||||
|
|
@ -277,6 +279,7 @@ func (h *Headscale) waitForFollowup(
|
|||
// registration is expired in the cache, instruct the client to try a new registration
|
||||
return h.reqToNewRegisterResponse(req, machineKey)
|
||||
}
|
||||
|
||||
return nodeToRegisterResponse(node.View()), nil
|
||||
}
|
||||
}
|
||||
|
|
@ -342,6 +345,7 @@ func (h *Headscale) handleRegisterWithAuthKey(
|
|||
if errors.Is(err, gorm.ErrRecordNotFound) {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil)
|
||||
}
|
||||
|
||||
if perr, ok := errors.AsType[types.PAKError](err); ok {
|
||||
return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil)
|
||||
}
|
||||
|
|
@ -432,6 +436,7 @@ func (h *Headscale) handleRegisterInteractive(
|
|||
Str("generated.hostname", hostname).
|
||||
Msg("Received registration request with empty hostname, generated default")
|
||||
}
|
||||
|
||||
hostinfo.Hostname = hostname
|
||||
|
||||
nodeToRegister := types.NewRegisterNode(
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package hscontrol
|
|||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
|
@ -16,14 +17,14 @@ import (
|
|||
"tailscale.com/types/key"
|
||||
)
|
||||
|
||||
// Interactive step type constants
|
||||
// Interactive step type constants.
|
||||
const (
|
||||
stepTypeInitialRequest = "initial_request"
|
||||
stepTypeAuthCompletion = "auth_completion"
|
||||
stepTypeFollowupRequest = "followup_request"
|
||||
)
|
||||
|
||||
// interactiveStep defines a step in the interactive authentication workflow
|
||||
// interactiveStep defines a step in the interactive authentication workflow.
|
||||
type interactiveStep struct {
|
||||
stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest
|
||||
expectAuthURL bool
|
||||
|
|
@ -75,6 +76,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -129,6 +131,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -163,6 +166,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify both nodes exist
|
||||
node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public())
|
||||
|
||||
assert.True(t, found1)
|
||||
assert.True(t, found2)
|
||||
assert.Equal(t, "reusable-node-1", node1.Hostname())
|
||||
|
|
@ -196,6 +200,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -227,6 +232,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// First node should exist, second should not
|
||||
_, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
_, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public())
|
||||
|
||||
assert.True(t, found1)
|
||||
assert.False(t, found2)
|
||||
},
|
||||
|
|
@ -272,6 +278,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -391,6 +398,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -400,8 +408,10 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
// Wait for node to be available in NodeStore with debug info
|
||||
var attemptCount int
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
attemptCount++
|
||||
|
||||
_, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
if assert.True(c, found, "node should be available in NodeStore") {
|
||||
t.Logf("Node found in NodeStore after %d attempts", attemptCount)
|
||||
|
|
@ -451,6 +461,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -500,6 +511,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -549,25 +561,31 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
// Wait for node to be available in NodeStore
|
||||
var node types.NodeView
|
||||
var found bool
|
||||
var (
|
||||
node types.NodeView
|
||||
found bool
|
||||
)
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
node, found = app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(c, found, "node should be available in NodeStore")
|
||||
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
|
||||
|
||||
if !found {
|
||||
return "", fmt.Errorf("node not found after setup")
|
||||
return "", errors.New("node not found after setup")
|
||||
}
|
||||
|
||||
// Expire the node
|
||||
expiredTime := time.Now().Add(-1 * time.Hour)
|
||||
_, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime)
|
||||
|
||||
return "", err
|
||||
},
|
||||
request: func(_ string) tailcfg.RegisterRequest {
|
||||
|
|
@ -610,6 +628,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -673,6 +692,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// and handleRegister will receive the value when it starts waiting
|
||||
go func() {
|
||||
user := app.state.CreateUserForTest("followup-user")
|
||||
|
||||
node := app.state.CreateNodeForTest(user, "followup-success-node")
|
||||
registered <- node
|
||||
}()
|
||||
|
|
@ -782,6 +802,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -821,6 +842,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -865,6 +887,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -898,6 +921,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -922,6 +946,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found)
|
||||
assert.Equal(t, "tagged-pak-node", node.Hostname())
|
||||
|
||||
if node.AuthKey().Valid() {
|
||||
assert.NotEmpty(t, node.AuthKey().Tags())
|
||||
}
|
||||
|
|
@ -1031,6 +1056,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1047,6 +1073,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak2.Key, nil
|
||||
},
|
||||
request: func(newAuthKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1099,6 +1126,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1161,6 +1189,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1177,6 +1206,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pakRotation.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1226,6 +1256,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1265,6 +1296,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1429,6 +1461,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify custom hostinfo was preserved through interactive workflow
|
||||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found, "node should be found after interactive registration")
|
||||
|
||||
if found {
|
||||
assert.Equal(t, "custom-interactive-node", node.Hostname())
|
||||
assert.Equal(t, "linux", node.Hostinfo().OS())
|
||||
|
|
@ -1455,6 +1488,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1520,6 +1554,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify registration ID was properly generated and used
|
||||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found, "node should be registered after interactive workflow")
|
||||
|
||||
if found {
|
||||
assert.Equal(t, "registration-id-test-node", node.Hostname())
|
||||
assert.Equal(t, "test-os", node.Hostinfo().OS())
|
||||
|
|
@ -1535,6 +1570,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1577,6 +1613,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak.Key, nil
|
||||
},
|
||||
request: func(authKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1632,6 +1669,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1648,6 +1686,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return pak2.Key, nil
|
||||
},
|
||||
request: func(user2AuthKey string) tailcfg.RegisterRequest {
|
||||
|
|
@ -1712,6 +1751,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1838,6 +1878,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -1932,6 +1973,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
},
|
||||
Expiry: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
_, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public())
|
||||
if err != nil {
|
||||
return "", err
|
||||
|
|
@ -2097,6 +2139,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
// Collect results - at least one should succeed
|
||||
successCount := 0
|
||||
|
||||
for range numConcurrent {
|
||||
select {
|
||||
case err := <-results:
|
||||
|
|
@ -2217,6 +2260,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Should handle nil hostinfo gracefully
|
||||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found, "node should be registered despite nil hostinfo")
|
||||
|
||||
if found {
|
||||
// Should have some default hostname or handle nil gracefully
|
||||
hostname := node.Hostname()
|
||||
|
|
@ -2315,12 +2359,14 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public())
|
||||
require.NoError(t, err)
|
||||
|
||||
authURL2 := resp2.AuthURL
|
||||
assert.Contains(t, authURL2, "/register/")
|
||||
|
||||
// Both should have different registration IDs
|
||||
regID1, err1 := extractRegistrationIDFromAuthURL(authURL1)
|
||||
regID2, err2 := extractRegistrationIDFromAuthURL(authURL2)
|
||||
|
||||
require.NoError(t, err1)
|
||||
require.NoError(t, err2)
|
||||
assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs")
|
||||
|
|
@ -2328,6 +2374,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Both cache entries should exist simultaneously
|
||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
||||
|
||||
assert.True(t, found1, "first registration cache entry should exist")
|
||||
assert.True(t, found2, "second registration cache entry should exist")
|
||||
|
||||
|
|
@ -2371,6 +2418,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public())
|
||||
require.NoError(t, err)
|
||||
|
||||
authURL2 := resp2.AuthURL
|
||||
regID2, err := extractRegistrationIDFromAuthURL(authURL2)
|
||||
require.NoError(t, err)
|
||||
|
|
@ -2378,6 +2426,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify both exist
|
||||
_, found1 := app.state.GetRegistrationCacheEntry(regID1)
|
||||
_, found2 := app.state.GetRegistrationCacheEntry(regID2)
|
||||
|
||||
assert.True(t, found1, "first cache entry should exist")
|
||||
assert.True(t, found2, "second cache entry should exist")
|
||||
|
||||
|
|
@ -2403,6 +2452,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
errorChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
responseChan <- resp
|
||||
}()
|
||||
|
||||
|
|
@ -2430,6 +2480,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
// Verify the node was created with the second registration's data
|
||||
node, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
|
||||
assert.True(t, found, "node should be registered")
|
||||
|
||||
if found {
|
||||
assert.Equal(t, "pending-node-2", node.Hostname())
|
||||
assert.Equal(t, "second-registration-user", node.User().Name())
|
||||
|
|
@ -2463,8 +2514,10 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
|
||||
// Set up context with timeout for followup tests
|
||||
ctx := context.Background()
|
||||
|
||||
if req.Followup != "" {
|
||||
var cancel context.CancelFunc
|
||||
|
||||
ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
}
|
||||
|
|
@ -2516,7 +2569,7 @@ func TestAuthenticationFlows(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow
|
||||
// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow.
|
||||
func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
||||
name string
|
||||
setupFunc func(*testing.T, *Headscale) (string, error)
|
||||
|
|
@ -2597,6 +2650,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||
errorChan <- err
|
||||
return
|
||||
}
|
||||
|
||||
responseChan <- resp
|
||||
}()
|
||||
|
||||
|
|
@ -2650,24 +2704,27 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
|
|||
if responseToValidate == nil {
|
||||
responseToValidate = initialResp
|
||||
}
|
||||
|
||||
tt.validate(t, responseToValidate, app)
|
||||
}
|
||||
}
|
||||
|
||||
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL
|
||||
// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL.
|
||||
func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) {
|
||||
// AuthURL format: "http://localhost/register/abc123"
|
||||
const registerPrefix = "/register/"
|
||||
|
||||
idx := strings.LastIndex(authURL, registerPrefix)
|
||||
if idx == -1 {
|
||||
return "", fmt.Errorf("invalid AuthURL format: %s", authURL)
|
||||
}
|
||||
|
||||
idStr := authURL[idx+len(registerPrefix):]
|
||||
|
||||
return types.RegistrationIDFromString(idStr)
|
||||
}
|
||||
|
||||
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response
|
||||
// validateCompleteRegistrationResponse performs comprehensive validation of a registration response.
|
||||
func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) {
|
||||
// Basic response validation
|
||||
require.NotNil(t, resp, "response should not be nil")
|
||||
|
|
@ -2681,7 +2738,7 @@ func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterRe
|
|||
// Additional validation can be added here as needed
|
||||
}
|
||||
|
||||
// Simple test to validate basic node creation and lookup
|
||||
// Simple test to validate basic node creation and lookup.
|
||||
func TestNodeStoreLookup(t *testing.T) {
|
||||
app := createTestApp(t)
|
||||
|
||||
|
|
@ -2713,8 +2770,10 @@ func TestNodeStoreLookup(t *testing.T) {
|
|||
|
||||
// Wait for node to be available in NodeStore
|
||||
var node types.NodeView
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var found bool
|
||||
|
||||
node, found = app.state.GetNodeByNodeKey(nodeKey.Public())
|
||||
assert.True(c, found, "Node should be found in NodeStore")
|
||||
}, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore")
|
||||
|
|
@ -2783,8 +2842,10 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
|
|||
|
||||
// Get the node ID
|
||||
var registeredNode types.NodeView
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var found bool
|
||||
|
||||
registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public())
|
||||
assert.True(c, found, "Node should be found in NodeStore")
|
||||
}, 1*time.Second, 100*time.Millisecond, "waiting for node to be available")
|
||||
|
|
@ -2796,6 +2857,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
|
|||
// Verify initial state: user1 has 2 nodes, user2 has 2 nodes
|
||||
user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID))
|
||||
user2Nodes := app.state.ListNodesByUser(types.UserID(user2.ID))
|
||||
|
||||
require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially")
|
||||
require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially")
|
||||
|
||||
|
|
@ -2876,6 +2938,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
|
|||
|
||||
// Verify new nodes were created for user1 with the same machine keys
|
||||
t.Logf("Verifying new nodes created for user1 from user2's machine keys...")
|
||||
|
||||
for i := 2; i < 4; i++ {
|
||||
node := nodes[i]
|
||||
// Should be able to find a node with user1 and this machine key (the new one)
|
||||
|
|
@ -2899,7 +2962,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
|
|||
// Expected behavior:
|
||||
// - User1's original node should STILL EXIST (expired)
|
||||
// - User2 should get a NEW node created (NOT transfer)
|
||||
// - Both nodes share the same machine key (same physical device)
|
||||
// - Both nodes share the same machine key (same physical device).
|
||||
func TestWebFlowReauthDifferentUser(t *testing.T) {
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey1 := key.NewNode()
|
||||
|
|
@ -3043,6 +3106,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
|
|||
// Count nodes per user
|
||||
user1Nodes := 0
|
||||
user2Nodes := 0
|
||||
|
||||
for i := 0; i < allNodesSlice.Len(); i++ {
|
||||
n := allNodesSlice.At(i)
|
||||
if n.UserID().Get() == user1.ID {
|
||||
|
|
@ -3060,7 +3124,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Helper function to create test app
|
||||
// Helper function to create test app.
|
||||
func createTestApp(t *testing.T) *Headscale {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -3147,6 +3211,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
|
|||
}
|
||||
|
||||
t.Log("Step 1: Initial registration with pre-auth key")
|
||||
|
||||
initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public())
|
||||
require.NoError(t, err, "initial registration should succeed")
|
||||
require.NotNil(t, initialResp)
|
||||
|
|
@ -3172,6 +3237,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
|
|||
// - System reboots
|
||||
// The Tailscale client persists the pre-auth key in its state and sends it on every registration
|
||||
t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key")
|
||||
|
||||
restartReq := tailcfg.RegisterRequest{
|
||||
Auth: &tailcfg.RegisterResponseAuth{
|
||||
AuthKey: pakNew.Key, // Same key, now marked as Used=true
|
||||
|
|
@ -3189,9 +3255,11 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
|
|||
|
||||
// This is the assertion that currently FAILS in v0.27.0
|
||||
assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed")
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Error received (this is the bug): %v", err)
|
||||
t.Logf("Expected behavior: Node should be able to re-register with the same pre-auth key it used initially")
|
||||
|
||||
return // Stop here to show the bug clearly
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -155,6 +155,7 @@ AND auth_key_id NOT IN (
|
|||
nodeRoutes := map[uint64][]netip.Prefix{}
|
||||
|
||||
var routes []types.Route
|
||||
|
||||
err = tx.Find(&routes).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetching routes: %w", err)
|
||||
|
|
@ -255,9 +256,11 @@ AND auth_key_id NOT IN (
|
|||
|
||||
// Check if routes table exists and drop it (should have been migrated already)
|
||||
var routesExists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists)
|
||||
if err == nil && routesExists {
|
||||
log.Info().Msg("Dropping leftover routes table")
|
||||
|
||||
if err := tx.Exec("DROP TABLE routes").Error; err != nil {
|
||||
return fmt.Errorf("dropping routes table: %w", err)
|
||||
}
|
||||
|
|
@ -280,6 +283,7 @@ AND auth_key_id NOT IN (
|
|||
for _, table := range tablesToRename {
|
||||
// Check if table exists before renaming
|
||||
var exists bool
|
||||
|
||||
err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists)
|
||||
if err != nil {
|
||||
return fmt.Errorf("checking if table %s exists: %w", table, err)
|
||||
|
|
@ -761,6 +765,7 @@ AND auth_key_id NOT IN (
|
|||
|
||||
// or else it blocks...
|
||||
sqlConn.SetMaxIdleConns(maxIdleConns)
|
||||
|
||||
sqlConn.SetMaxOpenConns(maxOpenConns)
|
||||
defer sqlConn.SetMaxIdleConns(1)
|
||||
defer sqlConn.SetMaxOpenConns(1)
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ func TestSQLiteMigrationAndDataValidation(t *testing.T) {
|
|||
|
||||
// Verify api_keys data preservation
|
||||
var apiKeyCount int
|
||||
|
||||
err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema")
|
||||
|
|
@ -186,6 +187,7 @@ func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error {
|
|||
func requireConstraintFailed(t *testing.T, err error) {
|
||||
t.Helper()
|
||||
require.Error(t, err)
|
||||
|
||||
if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") {
|
||||
require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error())
|
||||
}
|
||||
|
|
@ -401,6 +403,7 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
|
|||
// skip already-applied migrations and only run new ones.
|
||||
func TestSQLiteAllTestdataMigrations(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
schemas, err := os.ReadDir("testdata/sqlite")
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
|
|||
|
|
@ -27,13 +27,17 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Basic deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var deletionWg sync.WaitGroup
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
deletionWg sync.WaitGroup
|
||||
)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
deletionWg.Done()
|
||||
}
|
||||
|
|
@ -43,10 +47,13 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||
go gc.Start()
|
||||
|
||||
// Schedule several nodes for deletion with short expiry
|
||||
const expiry = fifty
|
||||
const numNodes = 100
|
||||
const (
|
||||
expiry = fifty
|
||||
numNodes = 100
|
||||
)
|
||||
|
||||
// Set up wait group for expected deletions
|
||||
|
||||
deletionWg.Add(numNodes)
|
||||
|
||||
for i := 1; i <= numNodes; i++ {
|
||||
|
|
@ -87,14 +94,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
|
|||
// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
|
|
@ -102,11 +113,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
|||
|
||||
// Start GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
const shortExpiry = fifty
|
||||
const longExpiry = 1 * time.Hour
|
||||
const (
|
||||
shortExpiry = fifty
|
||||
longExpiry = 1 * time.Hour
|
||||
)
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
|
|
@ -136,23 +150,31 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
|
|||
// and verifies that the node is deleted only once.
|
||||
func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
}
|
||||
|
||||
// Start the GC
|
||||
gc := NewEphemeralGarbageCollector(deleteFunc)
|
||||
|
||||
go gc.Start()
|
||||
defer gc.Close()
|
||||
|
||||
nodeID := types.NodeID(1)
|
||||
|
||||
const expiry = fifty
|
||||
|
||||
// Schedule node for deletion
|
||||
|
|
@ -196,14 +218,18 @@ func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) {
|
|||
// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted.
|
||||
func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deletionNotifier := make(chan types.NodeID, 1)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
|
||||
deletionNotifier <- nodeID
|
||||
|
|
@ -246,13 +272,18 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
|||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
nodeDeleted := make(chan struct{})
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
close(nodeDeleted) // Signal that deletion happened
|
||||
}
|
||||
|
|
@ -263,10 +294,12 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
|||
// Use a WaitGroup to ensure the GC has started
|
||||
var startWg sync.WaitGroup
|
||||
startWg.Add(1)
|
||||
|
||||
go func() {
|
||||
startWg.Done() // Signal that the goroutine has started
|
||||
gc.Start()
|
||||
}()
|
||||
|
||||
startWg.Wait() // Wait for the GC to start
|
||||
|
||||
// Close GC right away
|
||||
|
|
@ -288,7 +321,9 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
|
|||
|
||||
// Check no node was deleted
|
||||
deleteMutex.Lock()
|
||||
|
||||
nodesDeleted := len(deletedIDs)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
|
||||
|
||||
|
|
@ -311,12 +346,16 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
|||
t.Logf("Initial number of goroutines: %d", initialGoroutines)
|
||||
|
||||
// Deletion tracking mechanism
|
||||
var deletedIDs []types.NodeID
|
||||
var deleteMutex sync.Mutex
|
||||
var (
|
||||
deletedIDs []types.NodeID
|
||||
deleteMutex sync.Mutex
|
||||
)
|
||||
|
||||
deleteFunc := func(nodeID types.NodeID) {
|
||||
deleteMutex.Lock()
|
||||
|
||||
deletedIDs = append(deletedIDs, nodeID)
|
||||
|
||||
deleteMutex.Unlock()
|
||||
}
|
||||
|
||||
|
|
@ -325,8 +364,10 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
|
|||
go gc.Start()
|
||||
|
||||
// Number of concurrent scheduling goroutines
|
||||
const numSchedulers = 10
|
||||
const nodesPerScheduler = 50
|
||||
const (
|
||||
numSchedulers = 10
|
||||
nodesPerScheduler = 50
|
||||
)
|
||||
|
||||
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
|
||||
|
||||
|
|
|
|||
|
|
@ -483,6 +483,7 @@ func TestBackfillIPAddresses(t *testing.T) {
|
|||
func TestIPAllocatorNextNoReservedIPs(t *testing.T) {
|
||||
db, err := newSQLiteTestDB()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer db.Close()
|
||||
|
||||
alloc, err := NewIPAllocator(
|
||||
|
|
|
|||
|
|
@ -206,6 +206,7 @@ func SetTags(
|
|||
|
||||
slices.Sort(tags)
|
||||
tags = slices.Compact(tags)
|
||||
|
||||
b, err := json.Marshal(tags)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -378,6 +379,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
|||
if ipv4 == nil {
|
||||
ipv4 = oldNode.IPv4
|
||||
}
|
||||
|
||||
if ipv6 == nil {
|
||||
ipv6 = oldNode.IPv6
|
||||
}
|
||||
|
|
@ -406,6 +408,7 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
|
|||
node.IPv6 = ipv6
|
||||
|
||||
var err error
|
||||
|
||||
node.Hostname, err = util.NormaliseHostname(node.Hostname)
|
||||
if err != nil {
|
||||
newHostname := util.InvalidString()
|
||||
|
|
@ -693,9 +696,12 @@ func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname .
|
|||
}
|
||||
|
||||
var registeredNode *types.Node
|
||||
|
||||
err = hsdb.DB.Transaction(func(tx *gorm.DB) error {
|
||||
var err error
|
||||
|
||||
registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6)
|
||||
|
||||
return err
|
||||
})
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -497,6 +497,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||
if len(expectedRoutes1) == 0 {
|
||||
expectedRoutes1 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -508,6 +509,7 @@ func TestAutoApproveRoutes(t *testing.T) {
|
|||
if len(expectedRoutes2) == 0 {
|
||||
expectedRoutes2 = nil
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" {
|
||||
t.Errorf("unexpected enabled routes (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -745,12 +747,15 @@ func TestNodeNaming(t *testing.T) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, nodeInvalidHostname, new(mpp("100.64.0.66/32").Addr()), nil)
|
||||
_, err = RegisterNodeForTest(tx, nodeShortHostname, new(mpp("100.64.0.67/32").Addr()), nil)
|
||||
|
||||
return err
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
|
@ -999,6 +1004,7 @@ func TestListPeers(t *testing.T) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
|
|
@ -1084,6 +1090,7 @@ func TestListNodes(t *testing.T) {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = RegisterNodeForTest(tx, node2, nil, nil)
|
||||
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -372,18 +372,23 @@ func (c *Config) ToURL() (string, error) {
|
|||
if c.BusyTimeout > 0 {
|
||||
pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout))
|
||||
}
|
||||
|
||||
if c.JournalMode != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode))
|
||||
}
|
||||
|
||||
if c.AutoVacuum != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum))
|
||||
}
|
||||
|
||||
if c.WALAutocheckpoint >= 0 {
|
||||
pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint))
|
||||
}
|
||||
|
||||
if c.Synchronous != "" {
|
||||
pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous))
|
||||
}
|
||||
|
||||
if c.ForeignKeys {
|
||||
pragmas = append(pragmas, "foreign_keys=ON")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -294,6 +294,7 @@ func TestConfigToURL(t *testing.T) {
|
|||
t.Errorf("Config.ToURL() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if got != tt.want {
|
||||
t.Errorf("Config.ToURL() = %q, want %q", got, tt.want)
|
||||
}
|
||||
|
|
@ -306,6 +307,7 @@ func TestConfigToURLInvalid(t *testing.T) {
|
|||
Path: "",
|
||||
BusyTimeout: -1,
|
||||
}
|
||||
|
||||
_, err := config.ToURL()
|
||||
if err == nil {
|
||||
t.Error("Config.ToURL() with invalid config should return error")
|
||||
|
|
|
|||
|
|
@ -109,7 +109,9 @@ func TestSQLiteDriverPragmaIntegration(t *testing.T) {
|
|||
for pragma, expectedValue := range tt.expected {
|
||||
t.Run("pragma_"+pragma, func(t *testing.T) {
|
||||
var actualValue any
|
||||
|
||||
query := "PRAGMA " + pragma
|
||||
|
||||
err := db.QueryRow(query).Scan(&actualValue)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query %s: %v", query, err)
|
||||
|
|
@ -249,6 +251,7 @@ func TestJournalModeValidation(t *testing.T) {
|
|||
defer db.Close()
|
||||
|
||||
var actualMode string
|
||||
|
||||
err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to query journal_mode: %v", err)
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
|||
|
||||
if dbValue != nil {
|
||||
var bytes []byte
|
||||
|
||||
switch v := dbValue.(type) {
|
||||
case []byte:
|
||||
bytes = v
|
||||
|
|
@ -55,6 +56,7 @@ func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect
|
|||
maybeInstantiatePtr(fieldValue)
|
||||
f := fieldValue.MethodByName("UnmarshalText")
|
||||
args := []reflect.Value{reflect.ValueOf(bytes)}
|
||||
|
||||
ret := f.Call(args)
|
||||
if !ret[0].IsNil() {
|
||||
return decodingError(field.Name, ret[0].Interface().(error))
|
||||
|
|
@ -89,6 +91,7 @@ func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflec
|
|||
if v == nil || (reflect.ValueOf(v).Kind() == reflect.Pointer && reflect.ValueOf(v).IsNil()) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
b, err := v.MarshalText()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
|
|||
|
|
@ -88,10 +88,12 @@ var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user")
|
|||
// not exist or if another User exists with the new name.
|
||||
func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error {
|
||||
var err error
|
||||
|
||||
oldUser, err := GetUserByID(tx, uid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err = util.ValidateHostname(newName); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,17 +25,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
overview := h.state.DebugOverviewJSON()
|
||||
|
||||
overviewJSON, err := json.MarshalIndent(overview, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(overviewJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
overview := h.state.DebugOverview()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(overview))
|
||||
|
|
@ -45,11 +48,13 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
// Configuration endpoint
|
||||
debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
config := h.state.DebugConfig()
|
||||
|
||||
configJSON, err := json.MarshalIndent(config, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(configJSON)
|
||||
|
|
@ -70,6 +75,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
} else {
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(policy))
|
||||
}))
|
||||
|
|
@ -81,11 +87,13 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
filterJSON, err := json.MarshalIndent(filter, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(filterJSON)
|
||||
|
|
@ -94,11 +102,13 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
// SSH policies endpoint
|
||||
debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
sshPolicies := h.state.DebugSSHPolicies()
|
||||
|
||||
sshJSON, err := json.MarshalIndent(sshPolicies, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(sshJSON)
|
||||
|
|
@ -112,17 +122,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
derpInfo := h.state.DebugDERPJSON()
|
||||
|
||||
derpJSON, err := json.MarshalIndent(derpInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(derpJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
derpInfo := h.state.DebugDERPMap()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(derpInfo))
|
||||
|
|
@ -137,17 +150,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
nodeStoreNodes := h.state.DebugNodeStoreJSON()
|
||||
|
||||
nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(nodeStoreJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
nodeStoreInfo := h.state.DebugNodeStore()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(nodeStoreInfo))
|
||||
|
|
@ -157,11 +173,13 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
// Registration cache endpoint
|
||||
debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
cacheInfo := h.state.DebugRegistrationCache()
|
||||
|
||||
cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(cacheJSON)
|
||||
|
|
@ -175,17 +193,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
routes := h.state.DebugRoutes()
|
||||
|
||||
routesJSON, err := json.MarshalIndent(routes, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(routesJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
routes := h.state.DebugRoutesString()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(routes))
|
||||
|
|
@ -200,17 +221,20 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
|
||||
if wantsJSON {
|
||||
policyManagerInfo := h.state.DebugPolicyManagerJSON()
|
||||
|
||||
policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ")
|
||||
if err != nil {
|
||||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(policyManagerJSON)
|
||||
} else {
|
||||
// Default to text/plain for backward compatibility
|
||||
policyManagerInfo := h.state.DebugPolicyManager()
|
||||
|
||||
w.Header().Set("Content-Type", "text/plain")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte(policyManagerInfo))
|
||||
|
|
@ -227,6 +251,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
if res == nil {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set"))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -235,6 +260,7 @@ func (h *Headscale) debugHTTPServer() *http.Server {
|
|||
httpError(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(resJSON)
|
||||
|
|
@ -313,6 +339,7 @@ func (h *Headscale) debugBatcher() string {
|
|||
activeConnections: info.ActiveConnections,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if info.Connected {
|
||||
connectedCount++
|
||||
}
|
||||
|
|
@ -327,9 +354,11 @@ func (h *Headscale) debugBatcher() string {
|
|||
activeConnections: 0,
|
||||
})
|
||||
totalNodes++
|
||||
|
||||
if connected {
|
||||
connectedCount++
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
@ -400,6 +429,7 @@ func (h *Headscale) debugBatcherJSON() DebugBatcherInfo {
|
|||
ActiveConnections: 0,
|
||||
}
|
||||
info.TotalNodes++
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -134,6 +134,7 @@ func shuffleDERPMap(dm *tailcfg.DERPMap) {
|
|||
for id := range dm.Regions {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
|
||||
slices.Sort(ids)
|
||||
|
||||
for _, id := range ids {
|
||||
|
|
@ -164,12 +165,14 @@ func derpRandom() *rand.Rand {
|
|||
rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table)))
|
||||
derpRandomInst = rnd
|
||||
})
|
||||
|
||||
return derpRandomInst
|
||||
}
|
||||
|
||||
func resetDerpRandomForTesting() {
|
||||
derpRandomMu.Lock()
|
||||
defer derpRandomMu.Unlock()
|
||||
|
||||
derpRandomOnce = sync.Once{}
|
||||
derpRandomInst = nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -242,7 +242,9 @@ func TestShuffleDERPMapDeterministic(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
viper.Set("dns.base_domain", tt.baseDomain)
|
||||
|
||||
defer viper.Reset()
|
||||
|
||||
resetDerpRandomForTesting()
|
||||
|
||||
testMap := tt.derpMap.View().AsStruct()
|
||||
|
|
|
|||
|
|
@ -74,9 +74,11 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) {
|
|||
if err != nil {
|
||||
return tailcfg.DERPRegion{}, err
|
||||
}
|
||||
var host string
|
||||
var port int
|
||||
var portStr string
|
||||
var (
|
||||
host string
|
||||
port int
|
||||
portStr string
|
||||
)
|
||||
|
||||
// Extract hostname and port from URL
|
||||
host, portStr, err = net.SplitHostPort(serverURL.Host)
|
||||
|
|
@ -205,6 +207,7 @@ func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Reques
|
|||
return
|
||||
}
|
||||
defer websocketConn.Close(websocket.StatusInternalError, "closing")
|
||||
|
||||
if websocketConn.Subprotocol() != "derp" {
|
||||
websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol")
|
||||
|
||||
|
|
@ -309,6 +312,7 @@ func DERPBootstrapDNSHandler(
|
|||
resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute)
|
||||
defer cancel()
|
||||
var resolver net.Resolver
|
||||
|
||||
for _, region := range derpMap.Regions().All() {
|
||||
for _, node := range region.Nodes().All() { // we don't care if we override some nodes
|
||||
addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName())
|
||||
|
|
@ -320,6 +324,7 @@ func DERPBootstrapDNSHandler(
|
|||
|
||||
continue
|
||||
}
|
||||
|
||||
dnsEntries[node.HostName()] = addrs
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -85,12 +85,15 @@ func (e *ExtraRecordsMan) Run() {
|
|||
log.Error().Caller().Msgf("file watcher event channel closing")
|
||||
return
|
||||
}
|
||||
|
||||
switch event.Op {
|
||||
case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
|
||||
log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
|
||||
|
||||
if event.Name != e.path {
|
||||
continue
|
||||
}
|
||||
|
||||
e.updateRecords()
|
||||
|
||||
// If a file is removed or renamed, fsnotify will loose track of it
|
||||
|
|
@ -123,6 +126,7 @@ func (e *ExtraRecordsMan) Run() {
|
|||
log.Error().Caller().Msgf("file watcher error channel closing")
|
||||
return
|
||||
}
|
||||
|
||||
log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
|
||||
}
|
||||
}
|
||||
|
|
@ -165,6 +169,7 @@ func (e *ExtraRecordsMan) updateRecords() {
|
|||
e.hashes[e.path] = newHash
|
||||
|
||||
log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len())
|
||||
|
||||
e.updateCh <- e.records.Slice()
|
||||
}
|
||||
|
||||
|
|
@ -183,6 +188,7 @@ func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error
|
|||
}
|
||||
|
||||
var records []tailcfg.DNSRecord
|
||||
|
||||
err = json.Unmarshal(b, &records)
|
||||
if err != nil {
|
||||
return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err)
|
||||
|
|
|
|||
|
|
@ -181,6 +181,7 @@ func (h *Headscale) HealthHandler(
|
|||
|
||||
json.NewEncoder(writer).Encode(res)
|
||||
}
|
||||
|
||||
err := h.state.PingDB(req.Context())
|
||||
if err != nil {
|
||||
respond(err)
|
||||
|
|
@ -217,6 +218,7 @@ func (h *Headscale) VersionHandler(
|
|||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
versionInfo := types.GetVersionInfo()
|
||||
|
||||
err := json.NewEncoder(writer).Encode(versionInfo)
|
||||
if err != nil {
|
||||
log.Error().
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ package mapper
|
|||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
|
@ -77,6 +78,7 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
|||
if err != nil {
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
|
||||
return fmt.Errorf("failed to generate initial map for node %d: %w", id, err)
|
||||
}
|
||||
|
||||
|
|
@ -86,10 +88,11 @@ func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse
|
|||
case c <- initialMap:
|
||||
// Success
|
||||
case <-time.After(5 * time.Second):
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout")
|
||||
log.Error().Uint64("node.id", id.Uint64()).Err(errors.New("timeout")).Msg("Initial map send timeout")
|
||||
log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second).
|
||||
Msg("Initial map send timed out because channel was blocked or receiver not ready")
|
||||
nodeConn.removeConnectionByChannel(c)
|
||||
|
||||
return fmt.Errorf("failed to send initial map to node %d: timeout", id)
|
||||
}
|
||||
|
||||
|
|
@ -129,6 +132,7 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
|
|||
log.Debug().Caller().Uint64("node.id", id.Uint64()).
|
||||
Int("active.connections", nodeConn.getActiveConnectionCount()).
|
||||
Msg("Node connection removed but keeping online because other connections remain")
|
||||
|
||||
return true // Node still has active connections
|
||||
}
|
||||
|
||||
|
|
@ -211,10 +215,12 @@ func (b *LockFreeBatcher) worker(workerID int) {
|
|||
// This is used for synchronous map generation.
|
||||
if w.resultCh != nil {
|
||||
var result workResult
|
||||
|
||||
if nc, exists := b.nodes.Load(w.nodeID); exists {
|
||||
var err error
|
||||
|
||||
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
|
||||
|
||||
result.err = err
|
||||
if result.err != nil {
|
||||
b.workErrors.Add(1)
|
||||
|
|
@ -397,6 +403,7 @@ func (b *LockFreeBatcher) cleanupOfflineNodes() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
@ -449,6 +456,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
|||
if nodeConn.hasActiveConnections() {
|
||||
ret.Store(id, true)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
@ -464,6 +472,7 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
|
|||
ret.Store(id, false)
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
@ -518,7 +527,8 @@ type multiChannelNodeConn struct {
|
|||
func generateConnectionID() string {
|
||||
bytes := make([]byte, 8)
|
||||
rand.Read(bytes)
|
||||
return fmt.Sprintf("%x", bytes)
|
||||
|
||||
return hex.EncodeToString(bytes)
|
||||
}
|
||||
|
||||
// newMultiChannelNodeConn creates a new multi-channel node connection.
|
||||
|
|
@ -545,11 +555,14 @@ func (mc *multiChannelNodeConn) close() {
|
|||
// addConnection adds a new connection.
|
||||
func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) {
|
||||
mutexWaitStart := time.Now()
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id).
|
||||
Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT")
|
||||
|
||||
mc.mutex.Lock()
|
||||
|
||||
mutexWaitDur := time.Since(mutexWaitStart)
|
||||
|
||||
defer mc.mutex.Unlock()
|
||||
|
||||
mc.connections = append(mc.connections, entry)
|
||||
|
|
@ -571,9 +584,11 @@ func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapR
|
|||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)).
|
||||
Int("remaining_connections", len(mc.connections)).
|
||||
Msg("Successfully removed connection")
|
||||
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
|
|
@ -607,6 +622,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||
// This is not an error - the node will receive a full map when it reconnects
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).
|
||||
Msg("send: skipping send to node with no active connections (likely rapid reconnection)")
|
||||
|
||||
return nil // Return success instead of error
|
||||
}
|
||||
|
||||
|
|
@ -615,7 +631,9 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||
Msg("send: broadcasting to all connections")
|
||||
|
||||
var lastErr error
|
||||
|
||||
successCount := 0
|
||||
|
||||
var failedConnections []int // Track failed connections for removal
|
||||
|
||||
// Send to all connections
|
||||
|
|
@ -626,6 +644,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||
|
||||
if err := conn.send(data); err != nil {
|
||||
lastErr = err
|
||||
|
||||
failedConnections = append(failedConnections, i)
|
||||
log.Warn().Err(err).
|
||||
Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
|
|
@ -633,6 +652,7 @@ func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error {
|
|||
Msg("send: connection send failed")
|
||||
} else {
|
||||
successCount++
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)).
|
||||
Str("conn.id", conn.id).Int("connection_index", i).
|
||||
Msg("send: successfully sent to connection")
|
||||
|
|
@ -797,6 +817,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
|||
Connected: connected,
|
||||
ActiveConnections: activeConnCount,
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
@ -811,6 +832,7 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
|
|||
ActiveConnections: 0,
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
|
||||
|
|
|
|||
|
|
@ -677,6 +677,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
connectedCount := 0
|
||||
|
||||
for i := range allNodes {
|
||||
node := &allNodes[i]
|
||||
|
||||
|
|
@ -694,6 +695,7 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
|
|||
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
|
||||
|
||||
t.Logf("✅ All nodes achieved full connectivity!")
|
||||
|
||||
totalTime := time.Since(startTime)
|
||||
|
||||
// Disconnect all nodes
|
||||
|
|
@ -1309,6 +1311,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
|
|||
for range i % 3 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
batcher.RemoveNode(testNode.n.ID, ch)
|
||||
|
||||
// Yield to allow workers to process and close channels
|
||||
|
|
@ -1392,6 +1395,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, reason := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
|
|
@ -1449,7 +1453,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE)
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
churningChannels[nodeID] = ch
|
||||
|
||||
churningChannelsMutex.Unlock()
|
||||
|
||||
batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100))
|
||||
|
|
@ -1463,6 +1469,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
// Channel was closed, exit gracefully
|
||||
return
|
||||
}
|
||||
|
||||
if valid, _ := validateUpdateContent(data); valid {
|
||||
tracker.recordUpdate(
|
||||
nodeID,
|
||||
|
|
@ -1495,6 +1502,7 @@ func TestBatcherConcurrentClients(t *testing.T) {
|
|||
for range i % 5 {
|
||||
runtime.Gosched() // Introduce timing variability
|
||||
}
|
||||
|
||||
churningChannelsMutex.Lock()
|
||||
|
||||
ch, exists := churningChannels[nodeID]
|
||||
|
|
@ -1879,6 +1887,7 @@ func XTestBatcherScalability(t *testing.T) {
|
|||
channel,
|
||||
tailcfg.CapabilityVersion(100),
|
||||
)
|
||||
|
||||
connectedNodesMutex.Lock()
|
||||
|
||||
connectedNodes[nodeID] = true
|
||||
|
|
@ -2287,6 +2296,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
|
||||
// Phase 1: Connect all nodes initially
|
||||
t.Logf("Phase 1: Connecting all nodes...")
|
||||
|
||||
for i, node := range allNodes {
|
||||
err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
|
|
@ -2303,6 +2313,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
|
||||
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
|
||||
t.Logf("Phase 2: Rapid disconnect all nodes...")
|
||||
|
||||
for i, node := range allNodes {
|
||||
removed := batcher.RemoveNode(node.n.ID, node.ch)
|
||||
t.Logf("Node %d RemoveNode result: %t", i, removed)
|
||||
|
|
@ -2310,9 +2321,11 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
|
||||
// Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up)
|
||||
t.Logf("Phase 3: Rapid reconnect with new channels...")
|
||||
|
||||
newChannels := make([]chan *tailcfg.MapResponse, len(allNodes))
|
||||
for i, node := range allNodes {
|
||||
newChannels[i] = make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Errorf("Failed to reconnect node %d: %v", i, err)
|
||||
|
|
@ -2343,11 +2356,13 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
disconnectedCount++
|
||||
|
||||
t.Logf("Node %d missing from debug info entirely", i)
|
||||
}
|
||||
|
||||
|
|
@ -2382,6 +2397,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
|
|||
case update := <-newChannels[i]:
|
||||
if update != nil {
|
||||
receivedCount++
|
||||
|
||||
t.Logf("Node %d received update successfully", i)
|
||||
}
|
||||
case <-timeout:
|
||||
|
|
@ -2414,6 +2430,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
// Phase 1: Connect first node with initial connection
|
||||
t.Logf("Phase 1: Connecting node 1 with first connection...")
|
||||
|
||||
err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add node1: %v", err)
|
||||
|
|
@ -2433,7 +2450,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
// Phase 2: Add second connection for node1 (multi-connection scenario)
|
||||
t.Logf("Phase 2: Adding second connection for node 1...")
|
||||
|
||||
secondChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add second connection for node1: %v", err)
|
||||
|
|
@ -2444,7 +2463,9 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
// Phase 3: Add third connection for node1
|
||||
t.Logf("Phase 3: Adding third connection for node 1...")
|
||||
|
||||
thirdChannel := make(chan *tailcfg.MapResponse, 10)
|
||||
|
||||
err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100))
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to add third connection for node1: %v", err)
|
||||
|
|
@ -2455,6 +2476,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
// Phase 4: Verify debug status shows correct connection count
|
||||
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
|
||||
|
||||
if debugBatcher, ok := batcher.(interface {
|
||||
Debug() map[types.NodeID]any
|
||||
}); ok {
|
||||
|
|
@ -2462,6 +2484,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
|
||||
if info, exists := debugInfo[node1.n.ID]; exists {
|
||||
t.Logf("Node1 debug info: %+v", info)
|
||||
|
||||
if infoMap, ok := info.(map[string]any); ok {
|
||||
if activeConnections, ok := infoMap["active_connections"].(int); ok {
|
||||
if activeConnections != 3 {
|
||||
|
|
@ -2470,6 +2493,7 @@ func TestBatcherMultiConnection(t *testing.T) {
|
|||
t.Logf("SUCCESS: Node1 correctly shows 3 active connections")
|
||||
}
|
||||
}
|
||||
|
||||
if connected, ok := infoMap["connected"].(bool); ok && !connected {
|
||||
t.Errorf("Node1 should show as connected with 3 active connections")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ const (
|
|||
// NewMapResponseBuilder creates a new builder with basic fields set.
|
||||
func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder {
|
||||
now := time.Now()
|
||||
|
||||
return &MapResponseBuilder{
|
||||
resp: &tailcfg.MapResponse{
|
||||
KeepAlive: false,
|
||||
|
|
@ -124,6 +125,7 @@ func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder {
|
|||
b.resp.Debug = &tailcfg.Debug{
|
||||
DisableLogTail: !b.mapper.cfg.LogTail.Enabled,
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
|
|
@ -281,16 +283,18 @@ func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapRe
|
|||
for _, id := range removedIDs {
|
||||
tailscaleIDs = append(tailscaleIDs, id.NodeID())
|
||||
}
|
||||
|
||||
b.resp.PeersRemoved = tailscaleIDs
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
// Build finalizes the response and returns marshaled bytes
|
||||
// Build finalizes the response and returns marshaled bytes.
|
||||
func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) {
|
||||
if len(b.errs) > 0 {
|
||||
return nil, multierr.New(b.errs...)
|
||||
}
|
||||
|
||||
if debugDumpMapResponsePath != "" {
|
||||
writeDebugMapResponse(b.resp, b.debugType, b.nodeID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -60,7 +60,6 @@ func newMapper(
|
|||
state *state.State,
|
||||
) *mapper {
|
||||
// uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength)
|
||||
|
||||
return &mapper{
|
||||
state: state,
|
||||
cfg: cfg,
|
||||
|
|
@ -80,6 +79,7 @@ func generateUserProfiles(
|
|||
userID := user.Model().ID
|
||||
userMap[userID] = &user
|
||||
ids = append(ids, userID)
|
||||
|
||||
for _, peer := range peers.All() {
|
||||
peerUser := peer.Owner()
|
||||
peerUserID := peerUser.Model().ID
|
||||
|
|
@ -90,6 +90,7 @@ func generateUserProfiles(
|
|||
slices.Sort(ids)
|
||||
ids = slices.Compact(ids)
|
||||
var profiles []tailcfg.UserProfile
|
||||
|
||||
for _, id := range ids {
|
||||
if userMap[id] != nil {
|
||||
profiles = append(profiles, userMap[id].TailscaleUserProfile())
|
||||
|
|
@ -306,6 +307,7 @@ func writeDebugMapResponse(
|
|||
|
||||
perms := fs.FileMode(debugMapResponsePerm)
|
||||
mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID))
|
||||
|
||||
err = os.MkdirAll(mPath, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
@ -319,6 +321,7 @@ func writeDebugMapResponse(
|
|||
)
|
||||
|
||||
log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath)
|
||||
|
||||
err = os.WriteFile(mapResponsePath, body, perms)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
|
@ -375,6 +378,7 @@ func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapRe
|
|||
}
|
||||
|
||||
var resp tailcfg.MapResponse
|
||||
|
||||
err = json.Unmarshal(body, &resp)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name())
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
|
|||
if m.polMan == nil {
|
||||
return tailcfg.FilterAllowAll, nil
|
||||
}
|
||||
|
||||
return m.polMan.Filter()
|
||||
}
|
||||
|
||||
|
|
@ -105,6 +106,7 @@ func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
|
|||
if m.polMan == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return m.polMan.SSHPolicy(node)
|
||||
}
|
||||
|
||||
|
|
@ -112,6 +114,7 @@ func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool {
|
|||
if m.polMan == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return m.polMan.NodeCanHaveTag(node, tag)
|
||||
}
|
||||
|
||||
|
|
@ -119,6 +122,7 @@ func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
|
|||
if m.primary == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return m.primary.PrimaryRoutes(nodeID)
|
||||
}
|
||||
|
||||
|
|
@ -126,6 +130,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
|
|||
if len(peerIDs) > 0 {
|
||||
// Filter peers by the provided IDs
|
||||
var filtered types.Nodes
|
||||
|
||||
for _, peer := range m.peers {
|
||||
if slices.Contains(peerIDs, peer.ID) {
|
||||
filtered = append(filtered, peer)
|
||||
|
|
@ -136,6 +141,7 @@ func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (typ
|
|||
}
|
||||
// Return all peers except the node itself
|
||||
var filtered types.Nodes
|
||||
|
||||
for _, peer := range m.peers {
|
||||
if peer.ID != nodeID {
|
||||
filtered = append(filtered, peer)
|
||||
|
|
@ -149,6 +155,7 @@ func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) {
|
|||
if len(nodeIDs) > 0 {
|
||||
// Filter nodes by the provided IDs
|
||||
var filtered types.Nodes
|
||||
|
||||
for _, node := range m.nodes {
|
||||
if slices.Contains(nodeIDs, node.ID) {
|
||||
filtered = append(filtered, node)
|
||||
|
|
|
|||
|
|
@ -243,10 +243,12 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||
|
||||
registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) {
|
||||
var resp *tailcfg.RegisterResponse
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
if err != nil {
|
||||
return &tailcfg.RegisterRequest{}, regErr(err)
|
||||
}
|
||||
|
||||
var regReq tailcfg.RegisterRequest
|
||||
if err := json.Unmarshal(body, ®Req); err != nil {
|
||||
return ®Req, regErr(err)
|
||||
|
|
@ -260,6 +262,7 @@ func (ns *noiseServer) NoiseRegistrationHandler(
|
|||
resp = &tailcfg.RegisterResponse{
|
||||
Error: httpErr.Msg,
|
||||
}
|
||||
|
||||
return ®Req, resp
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -163,6 +163,7 @@ func (a *AuthProviderOIDC) RegisterHandler(
|
|||
for k, v := range a.cfg.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
|
||||
extras = append(extras, oidc.Nonce(nonce))
|
||||
|
||||
// Cache the registration info
|
||||
|
|
@ -190,6 +191,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
}
|
||||
|
||||
stateCookieName := getCookieName("state", state)
|
||||
|
||||
cookieState, err := req.Cookie(stateCookieName)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err))
|
||||
|
|
@ -212,17 +214,20 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
httpError(writer, err)
|
||||
return
|
||||
}
|
||||
|
||||
if idToken.Nonce == "" {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err))
|
||||
return
|
||||
}
|
||||
|
||||
nonceCookieName := getCookieName("nonce", idToken.Nonce)
|
||||
|
||||
nonce, err := req.Cookie(nonceCookieName)
|
||||
if err != nil {
|
||||
httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err))
|
||||
return
|
||||
}
|
||||
|
||||
if idToken.Nonce != nonce.Value {
|
||||
httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil))
|
||||
return
|
||||
|
|
@ -239,6 +244,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
// Fetch user information (email, groups, name, etc) from the userinfo endpoint
|
||||
// https://openid.net/specs/openid-connect-core-1_0.html#UserInfo
|
||||
var userinfo *oidc.UserInfo
|
||||
|
||||
userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token))
|
||||
if err != nil {
|
||||
util.LogErr(err, "could not get userinfo; only using claims from id token")
|
||||
|
|
@ -255,6 +261,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified)
|
||||
claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username)
|
||||
claims.Name = cmp.Or(userinfo2.Name, claims.Name)
|
||||
|
||||
claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL)
|
||||
if userinfo2.Groups != nil {
|
||||
claims.Groups = userinfo2.Groups
|
||||
|
|
@ -279,6 +286,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
Msgf("could not create or update user")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusInternalServerError)
|
||||
|
||||
_, werr := writer.Write([]byte("Could not create or update user"))
|
||||
if werr != nil {
|
||||
log.Error().
|
||||
|
|
@ -299,6 +307,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
// Register the node if it does not exist.
|
||||
if registrationId != nil {
|
||||
verb := "Reauthenticated"
|
||||
|
||||
newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry)
|
||||
if err != nil {
|
||||
if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) {
|
||||
|
|
@ -307,7 +316,9 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
httpError(writer, err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -324,6 +335,7 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
|
|||
|
||||
writer.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
|
||||
if _, err := writer.Write(content.Bytes()); err != nil {
|
||||
util.LogErr(err, "Failed to write HTTP response")
|
||||
}
|
||||
|
|
@ -370,6 +382,7 @@ func (a *AuthProviderOIDC) getOauth2Token(
|
|||
if !ok {
|
||||
return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo)
|
||||
}
|
||||
|
||||
if regInfo.Verifier != nil {
|
||||
exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)}
|
||||
}
|
||||
|
|
@ -516,6 +529,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
|
|||
newUser bool
|
||||
c change.Change
|
||||
)
|
||||
|
||||
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
|
||||
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
|
||||
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
|
||||
|
|
|
|||
|
|
@ -21,10 +21,13 @@ func (m Match) DebugString() string {
|
|||
|
||||
sb.WriteString("Match:\n")
|
||||
sb.WriteString(" Sources:\n")
|
||||
|
||||
for _, prefix := range m.srcs.Prefixes() {
|
||||
sb.WriteString(" " + prefix.String() + "\n")
|
||||
}
|
||||
|
||||
sb.WriteString(" Destinations:\n")
|
||||
|
||||
for _, prefix := range m.dests.Prefixes() {
|
||||
sb.WriteString(" " + prefix.String() + "\n")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -38,8 +38,11 @@ type PolicyManager interface {
|
|||
|
||||
// NewPolicyManager returns a new policy manager.
|
||||
func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) {
|
||||
var polMan PolicyManager
|
||||
var err error
|
||||
var (
|
||||
polMan PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
polMan, err = policyv2.NewPolicyManager(pol, users, nodes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -59,6 +62,7 @@ func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[typ
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
polMans = append(polMans, pm)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -125,6 +125,7 @@ func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApprove
|
|||
if !slices.Equal(sortedCurrent, newApproved) {
|
||||
// Log what changed
|
||||
var added, kept []netip.Prefix
|
||||
|
||||
for _, route := range newApproved {
|
||||
if !slices.Contains(sortedCurrent, route) {
|
||||
added = append(added, route)
|
||||
|
|
|
|||
|
|
@ -312,8 +312,11 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
|
|||
nodes := types.Nodes{&node}
|
||||
|
||||
// Create policy manager or use nil if specified
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
if tt.name != "nil_policy_manager" {
|
||||
pm, err = pmf(users, nodes.ViewSlice())
|
||||
assert.NoError(t, err)
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ func TestReduceNodes(t *testing.T) {
|
|||
rules []tailcfg.FilterRule
|
||||
node *types.Node
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
|
|
@ -782,9 +783,11 @@ func TestReduceNodes(t *testing.T) {
|
|||
for _, v := range gotViews.All() {
|
||||
got = append(got, v.AsStruct())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff)
|
||||
t.Log("Matchers: ")
|
||||
|
||||
for _, m := range matchers {
|
||||
t.Log("\t+", m.DebugString())
|
||||
}
|
||||
|
|
@ -1031,8 +1034,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(nil, tt.nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -1050,9 +1056,11 @@ func TestReduceNodesFromPolicy(t *testing.T) {
|
|||
for _, v := range gotViews.All() {
|
||||
got = append(got, v.AsStruct())
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
|
||||
t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff)
|
||||
t.Log("Matchers: ")
|
||||
|
||||
for _, m := range matchers {
|
||||
t.Log("\t+", m.DebugString())
|
||||
}
|
||||
|
|
@ -1405,13 +1413,17 @@ func TestSSHPolicyRules(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice())
|
||||
|
||||
if tt.expectErr {
|
||||
require.Error(t, err)
|
||||
require.Contains(t, err.Error(), tt.errorMessage)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -1434,6 +1446,7 @@ func TestReduceRoutes(t *testing.T) {
|
|||
routes []netip.Prefix
|
||||
rules []tailcfg.FilterRule
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
|
|
@ -2055,6 +2068,7 @@ func TestReduceRoutes(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
matchers := matcher.MatchesFromFilterRules(tt.args.rules)
|
||||
|
||||
got := ReduceRoutes(
|
||||
tt.args.node.View(),
|
||||
tt.args.routes,
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcf
|
|||
for _, rule := range rules {
|
||||
// record if the rule is actually relevant for the given node.
|
||||
var dests []tailcfg.NetPortRange
|
||||
|
||||
DEST_LOOP:
|
||||
for _, dest := range rule.DstPorts {
|
||||
expanded, err := util.ParseIPSet(dest.IP, nil)
|
||||
|
|
|
|||
|
|
@ -823,10 +823,14 @@ func TestReduceFilterRules(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) {
|
||||
t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) {
|
||||
var pm policy.PolicyManager
|
||||
var err error
|
||||
var (
|
||||
pm policy.PolicyManager
|
||||
err error
|
||||
)
|
||||
|
||||
pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
got, _ := pm.Filter()
|
||||
t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " ")))
|
||||
got = policyutil.ReduceFilterRules(tt.node.View(), got)
|
||||
|
|
|
|||
|
|
@ -829,6 +829,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
|||
if tt.name == "empty policy" {
|
||||
// We expect this one to have a valid but empty policy
|
||||
require.NoError(t, err)
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -843,6 +844,7 @@ func TestNodeCanApproveRoute(t *testing.T) {
|
|||
if diff := cmp.Diff(tt.canApprove, result); diff != "" {
|
||||
t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.canApprove, result, "Unexpected route approval result")
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ func (pol *Policy) compileFilterRules(
|
|||
protocols, _ := acl.Protocol.parseProtocol()
|
||||
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
|
||||
for _, dest := range acl.Destinations {
|
||||
ips, err := dest.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
|
|
@ -127,8 +128,10 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
node types.NodeView,
|
||||
nodes views.Slice[types.NodeView],
|
||||
) ([]*tailcfg.FilterRule, error) {
|
||||
var autogroupSelfDests []AliasWithPorts
|
||||
var otherDests []AliasWithPorts
|
||||
var (
|
||||
autogroupSelfDests []AliasWithPorts
|
||||
otherDests []AliasWithPorts
|
||||
)
|
||||
|
||||
for _, dest := range acl.Destinations {
|
||||
if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
|
|
@ -139,13 +142,14 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
}
|
||||
|
||||
protocols, _ := acl.Protocol.parseProtocol()
|
||||
|
||||
var rules []*tailcfg.FilterRule
|
||||
|
||||
var resolvedSrcIPs []*netipx.IPSet
|
||||
|
||||
for _, src := range acl.Sources {
|
||||
if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
return nil, fmt.Errorf("autogroup:self cannot be used in sources")
|
||||
return nil, errors.New("autogroup:self cannot be used in sources")
|
||||
}
|
||||
|
||||
ips, err := src.Resolve(pol, users, nodes)
|
||||
|
|
@ -167,6 +171,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
if len(autogroupSelfDests) > 0 {
|
||||
// Pre-filter to same-user untagged devices once - reuse for both sources and destinations
|
||||
sameUserNodes := make([]types.NodeView, 0)
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if n.User().ID() == node.User().ID() && !n.IsTagged() {
|
||||
sameUserNodes = append(sameUserNodes, n)
|
||||
|
|
@ -176,6 +181,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
if len(sameUserNodes) > 0 {
|
||||
// Filter sources to only same-user untagged devices
|
||||
var srcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, ips := range resolvedSrcIPs {
|
||||
for _, n := range sameUserNodes {
|
||||
// Check if any of this node's IPs are in the source set
|
||||
|
|
@ -192,6 +198,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
|
|||
|
||||
if srcSet != nil && len(srcSet.Prefixes()) > 0 {
|
||||
var destPorts []tailcfg.NetPortRange
|
||||
|
||||
for _, dest := range autogroupSelfDests {
|
||||
for _, n := range sameUserNodes {
|
||||
for _, port := range dest.Ports {
|
||||
|
|
@ -297,8 +304,10 @@ func (pol *Policy) compileSSHPolicy(
|
|||
// Separate destinations into autogroup:self and others
|
||||
// This is needed because autogroup:self requires filtering sources to same-user only,
|
||||
// while other destinations should use all resolved sources
|
||||
var autogroupSelfDests []Alias
|
||||
var otherDests []Alias
|
||||
var (
|
||||
autogroupSelfDests []Alias
|
||||
otherDests []Alias
|
||||
)
|
||||
|
||||
for _, dst := range rule.Destinations {
|
||||
if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
|
|
@ -321,6 +330,7 @@ func (pol *Policy) compileSSHPolicy(
|
|||
}
|
||||
|
||||
var action tailcfg.SSHAction
|
||||
|
||||
switch rule.Action {
|
||||
case SSHActionAccept:
|
||||
action = sshAction(true, 0)
|
||||
|
|
@ -336,9 +346,11 @@ func (pol *Policy) compileSSHPolicy(
|
|||
// by default, we do not allow root unless explicitly stated
|
||||
userMap["root"] = ""
|
||||
}
|
||||
|
||||
if rule.Users.ContainsRoot() {
|
||||
userMap["root"] = "root"
|
||||
}
|
||||
|
||||
for _, u := range rule.Users.NormalUsers() {
|
||||
userMap[u.String()] = u.String()
|
||||
}
|
||||
|
|
@ -348,6 +360,7 @@ func (pol *Policy) compileSSHPolicy(
|
|||
if len(autogroupSelfDests) > 0 && !node.IsTagged() {
|
||||
// Build destination set for autogroup:self (same-user untagged devices only)
|
||||
var dest netipx.IPSetBuilder
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if n.User().ID() == node.User().ID() && !n.IsTagged() {
|
||||
n.AppendToIPSet(&dest)
|
||||
|
|
@ -364,6 +377,7 @@ func (pol *Policy) compileSSHPolicy(
|
|||
// Filter sources to only same-user untagged devices
|
||||
// Pre-filter to same-user untagged devices for efficiency
|
||||
sameUserNodes := make([]types.NodeView, 0)
|
||||
|
||||
for _, n := range nodes.All() {
|
||||
if n.User().ID() == node.User().ID() && !n.IsTagged() {
|
||||
sameUserNodes = append(sameUserNodes, n)
|
||||
|
|
@ -371,6 +385,7 @@ func (pol *Policy) compileSSHPolicy(
|
|||
}
|
||||
|
||||
var filteredSrcIPs netipx.IPSetBuilder
|
||||
|
||||
for _, n := range sameUserNodes {
|
||||
// Check if any of this node's IPs are in the source set
|
||||
if slices.ContainsFunc(n.IPs(), srcIPs.Contains) {
|
||||
|
|
@ -406,12 +421,14 @@ func (pol *Policy) compileSSHPolicy(
|
|||
if len(otherDests) > 0 {
|
||||
// Build destination set for other destinations
|
||||
var dest netipx.IPSetBuilder
|
||||
|
||||
for _, dst := range otherDests {
|
||||
ips, err := dst.Resolve(pol, users, nodes)
|
||||
if err != nil {
|
||||
log.Trace().Caller().Err(err).Msgf("resolving destination ips")
|
||||
continue
|
||||
}
|
||||
|
||||
if ips != nil {
|
||||
dest.AddSet(ips)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -589,7 +589,9 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
|
|||
if sshPolicy == nil {
|
||||
return // Expected empty result
|
||||
}
|
||||
|
||||
assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -670,7 +672,7 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
|
|||
}
|
||||
|
||||
// TestSSHIntegrationReproduction reproduces the exact scenario from the integration test
|
||||
// TestSSHOneUserToAll that was failing with empty sshUsers
|
||||
// TestSSHOneUserToAll that was failing with empty sshUsers.
|
||||
func TestSSHIntegrationReproduction(t *testing.T) {
|
||||
// Create users matching the integration test
|
||||
users := types.Users{
|
||||
|
|
@ -735,7 +737,7 @@ func TestSSHIntegrationReproduction(t *testing.T) {
|
|||
}
|
||||
|
||||
// TestSSHJSONSerialization verifies that the SSH policy can be properly serialized
|
||||
// to JSON and that the sshUsers field is not empty
|
||||
// to JSON and that the sshUsers field is not empty.
|
||||
func TestSSHJSONSerialization(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Name: "user1", Model: gorm.Model{ID: 1}},
|
||||
|
|
@ -775,6 +777,7 @@ func TestSSHJSONSerialization(t *testing.T) {
|
|||
|
||||
// Parse back to verify structure
|
||||
var parsed tailcfg.SSHPolicy
|
||||
|
||||
err = json.Unmarshal(jsonData, &parsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -859,6 +862,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
|||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if len(rules) != 1 {
|
||||
t.Fatalf("expected 1 rule, got %d", len(rules))
|
||||
}
|
||||
|
|
@ -875,6 +879,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
|||
found := false
|
||||
|
||||
addr := netip.MustParseAddr(expectedIP)
|
||||
|
||||
for _, prefix := range rule.SrcIPs {
|
||||
pref := netip.MustParsePrefix(prefix)
|
||||
if pref.Contains(addr) {
|
||||
|
|
@ -892,6 +897,7 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
|
|||
excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"}
|
||||
for _, excludedIP := range excludedSourceIPs {
|
||||
addr := netip.MustParseAddr(excludedIP)
|
||||
|
||||
for _, prefix := range rule.SrcIPs {
|
||||
pref := netip.MustParsePrefix(prefix)
|
||||
if pref.Contains(addr) {
|
||||
|
|
@ -1325,14 +1331,14 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
|
|||
assert.Empty(t, rules3, "user3 should have no rules")
|
||||
}
|
||||
|
||||
// Helper function to create IP addresses for testing
|
||||
// Helper function to create IP addresses for testing.
|
||||
func createAddr(ip string) *netip.Addr {
|
||||
addr, _ := netip.ParseAddr(ip)
|
||||
return &addr
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly
|
||||
// with autogroup:self in destinations
|
||||
// with autogroup:self in destinations.
|
||||
func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
|
|
@ -1380,6 +1386,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
|||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// Test for user2's first node
|
||||
|
|
@ -1398,12 +1405,14 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
|||
for i, p := range rule2.Principals {
|
||||
principalIPs2[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2)
|
||||
|
||||
// Test for tagged node (should have no SSH rules)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy3 != nil {
|
||||
assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self")
|
||||
}
|
||||
|
|
@ -1411,7 +1420,7 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
|
|||
|
||||
// TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user
|
||||
// is in the source and autogroup:self in destination, only that user's devices
|
||||
// can SSH (and only if they match the target user)
|
||||
// can SSH (and only if they match the target user).
|
||||
func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
|
|
@ -1453,18 +1462,20 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
|
|||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// For user2's node: should have no rules (user1's devices can't match user2's self)
|
||||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations
|
||||
// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations.
|
||||
func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
|
|
@ -1511,19 +1522,21 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
|
|||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs)
|
||||
|
||||
// For user3's node: should have no rules (not in group:admins)
|
||||
node5 := nodes[4].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices
|
||||
// are excluded from both sources and destinations when autogroup:self is used
|
||||
// are excluded from both sources and destinations when autogroup:self is used.
|
||||
func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
||||
users := types.Users{
|
||||
{Model: gorm.Model{ID: 1}, Name: "user1"},
|
||||
|
|
@ -1568,6 +1581,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
|||
for i, p := range rule.Principals {
|
||||
principalIPs[i] = p.NodeIP
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs,
|
||||
"should only include untagged devices")
|
||||
|
||||
|
|
@ -1575,6 +1589,7 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
|
|||
node3 := nodes[2].View()
|
||||
sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice())
|
||||
require.NoError(t, err)
|
||||
|
||||
if sshPolicy2 != nil {
|
||||
assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self")
|
||||
}
|
||||
|
|
@ -1623,10 +1638,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
|||
// Verify autogroup:self rule has filtered sources (only same-user devices)
|
||||
selfRule := sshPolicy1.Rules[0]
|
||||
require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices")
|
||||
|
||||
selfPrincipals := make([]string, len(selfRule.Principals))
|
||||
for i, p := range selfRule.Principals {
|
||||
selfPrincipals[i] = p.NodeIP
|
||||
}
|
||||
|
||||
require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals,
|
||||
"autogroup:self rule should only include same-user untagged devices")
|
||||
|
||||
|
|
@ -1638,10 +1655,12 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
|
|||
require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)")
|
||||
|
||||
routerRule := sshPolicyRouter.Rules[0]
|
||||
|
||||
routerPrincipals := make([]string, len(routerRule.Principals))
|
||||
for i, p := range routerRule.Principals {
|
||||
routerPrincipals[i] = p.NodeIP
|
||||
}
|
||||
|
||||
require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)")
|
||||
require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)")
|
||||
require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)")
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Filter: filter,
|
||||
Policy: pm.pol,
|
||||
})
|
||||
|
||||
filterChanged := filterHash != pm.filterHash
|
||||
if filterChanged {
|
||||
log.Debug().
|
||||
|
|
@ -120,7 +121,9 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Int("filter.rules.new", len(filter)).
|
||||
Msg("Policy filter hash changed")
|
||||
}
|
||||
|
||||
pm.filter = filter
|
||||
|
||||
pm.filterHash = filterHash
|
||||
if filterChanged {
|
||||
pm.matchers = matcher.MatchesFromFilterRules(pm.filter)
|
||||
|
|
@ -135,6 +138,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
}
|
||||
|
||||
tagOwnerMapHash := deephash.Hash(&tagMap)
|
||||
|
||||
tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash
|
||||
if tagOwnerChanged {
|
||||
log.Debug().
|
||||
|
|
@ -144,6 +148,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Int("tagOwners.new", len(tagMap)).
|
||||
Msg("Tag owner hash changed")
|
||||
}
|
||||
|
||||
pm.tagOwnerMap = tagMap
|
||||
pm.tagOwnerMapHash = tagOwnerMapHash
|
||||
|
||||
|
|
@ -153,6 +158,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
}
|
||||
|
||||
autoApproveMapHash := deephash.Hash(&autoMap)
|
||||
|
||||
autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash
|
||||
if autoApproveChanged {
|
||||
log.Debug().
|
||||
|
|
@ -162,10 +168,12 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Int("autoApprovers.new", len(autoMap)).
|
||||
Msg("Auto-approvers hash changed")
|
||||
}
|
||||
|
||||
pm.autoApproveMap = autoMap
|
||||
pm.autoApproveMapHash = autoApproveMapHash
|
||||
|
||||
exitSetHash := deephash.Hash(&exitSet)
|
||||
|
||||
exitSetChanged := exitSetHash != pm.exitSetHash
|
||||
if exitSetChanged {
|
||||
log.Debug().
|
||||
|
|
@ -173,6 +181,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
Str("exitSet.hash.new", exitSetHash.String()[:8]).
|
||||
Msg("Exit node set hash changed")
|
||||
}
|
||||
|
||||
pm.exitSet = exitSet
|
||||
pm.exitSetHash = exitSetHash
|
||||
|
||||
|
|
@ -199,6 +208,7 @@ func (pm *PolicyManager) updateLocked() (bool, error) {
|
|||
if !needsUpdate {
|
||||
log.Trace().
|
||||
Msg("Policy evaluation detected no changes - all hashes match")
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
|
|
@ -224,6 +234,7 @@ func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, err
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("compiling SSH policy: %w", err)
|
||||
}
|
||||
|
||||
pm.sshPolicyMap[node.ID()] = sshPol
|
||||
|
||||
return sshPol, nil
|
||||
|
|
@ -318,6 +329,7 @@ func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[typ
|
|||
if err != nil || len(filter) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter)
|
||||
}
|
||||
|
||||
|
|
@ -398,6 +410,7 @@ func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.Fil
|
|||
reducedFilter := policyutil.ReduceFilterRules(node, pm.filter)
|
||||
|
||||
pm.filterRulesMap[node.ID()] = reducedFilter
|
||||
|
||||
return reducedFilter, nil
|
||||
}
|
||||
|
||||
|
|
@ -442,7 +455,7 @@ func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRul
|
|||
// This is different from FilterForNode which returns REDUCED rules for packet filtering.
|
||||
//
|
||||
// For global policies: returns the global matchers (same for all nodes)
|
||||
// For autogroup:self: returns node-specific matchers from unreduced compiled rules
|
||||
// For autogroup:self: returns node-specific matchers from unreduced compiled rules.
|
||||
func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
|
||||
if pm == nil {
|
||||
return nil, nil
|
||||
|
|
@ -474,6 +487,7 @@ func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) {
|
|||
|
||||
pm.mu.Lock()
|
||||
defer pm.mu.Unlock()
|
||||
|
||||
pm.users = users
|
||||
|
||||
// Clear SSH policy map when users change to force SSH policy recomputation
|
||||
|
|
@ -685,6 +699,7 @@ func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Pr
|
|||
if pm.exitSet == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) {
|
||||
return true
|
||||
}
|
||||
|
|
@ -748,8 +763,10 @@ func (pm *PolicyManager) DebugString() string {
|
|||
}
|
||||
|
||||
fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap))
|
||||
|
||||
for prefix, approveAddrs := range pm.autoApproveMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
|
||||
for _, iprange := range approveAddrs.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
|
|
@ -758,14 +775,17 @@ func (pm *PolicyManager) DebugString() string {
|
|||
sb.WriteString("\n\n")
|
||||
|
||||
fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap))
|
||||
|
||||
for prefix, tagOwners := range pm.tagOwnerMap {
|
||||
fmt.Fprintf(&sb, "\t%s:\n", prefix)
|
||||
|
||||
for _, iprange := range tagOwners.Ranges() {
|
||||
fmt.Fprintf(&sb, "\t\t%s\n", iprange)
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n\n")
|
||||
|
||||
if pm.filter != nil {
|
||||
filter, err := json.MarshalIndent(pm.filter, "", " ")
|
||||
if err == nil {
|
||||
|
|
@ -778,6 +798,7 @@ func (pm *PolicyManager) DebugString() string {
|
|||
sb.WriteString("\n\n")
|
||||
sb.WriteString("Matchers:\n")
|
||||
sb.WriteString("an internal structure used to filter nodes and routes\n")
|
||||
|
||||
for _, match := range pm.matchers {
|
||||
sb.WriteString(match.DebugString())
|
||||
sb.WriteString("\n")
|
||||
|
|
@ -785,6 +806,7 @@ func (pm *PolicyManager) DebugString() string {
|
|||
|
||||
sb.WriteString("\n\n")
|
||||
sb.WriteString("Nodes:\n")
|
||||
|
||||
for _, node := range pm.nodes.All() {
|
||||
sb.WriteString(node.String())
|
||||
sb.WriteString("\n")
|
||||
|
|
@ -841,6 +863,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
|||
|
||||
// Check if IPs changed (simple check - could be more sophisticated)
|
||||
oldIPs := oldNode.IPs()
|
||||
|
||||
newIPs := newNode.IPs()
|
||||
if len(oldIPs) != len(newIPs) {
|
||||
affectedUsers[newNode.User().ID()] = struct{}{}
|
||||
|
|
@ -862,6 +885,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
|||
for nodeID := range pm.filterRulesMap {
|
||||
// Find the user for this cached node
|
||||
var nodeUserID uint
|
||||
|
||||
found := false
|
||||
|
||||
// Check in new nodes first
|
||||
|
|
@ -869,6 +893,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
|||
if node.ID() == nodeID {
|
||||
nodeUserID = node.User().ID()
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -879,6 +904,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
|
|||
if node.ID() == nodeID {
|
||||
nodeUserID = node.User().ID()
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ func TestPolicyManager(t *testing.T) {
|
|||
if diff := cmp.Diff(tt.wantFilter, filter); diff != "" {
|
||||
t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(
|
||||
tt.wantMatchers,
|
||||
matchers,
|
||||
|
|
@ -176,13 +177,16 @@ func TestInvalidateAutogroupSelfCache(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
for i, n := range tt.newNodes {
|
||||
found := false
|
||||
|
||||
for _, origNode := range initialNodes {
|
||||
if n.Hostname == origNode.Hostname {
|
||||
n.ID = origNode.ID
|
||||
found = true
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !found {
|
||||
n.ID = types.NodeID(len(initialNodes) + i + 1)
|
||||
}
|
||||
|
|
@ -369,7 +373,7 @@ func TestInvalidateGlobalPolicyCache(t *testing.T) {
|
|||
|
||||
// TestAutogroupSelfReducedVsUnreducedRules verifies that:
|
||||
// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships
|
||||
// 2. FilterForNode returns reduced compiled rules for packet filters
|
||||
// 2. FilterForNode returns reduced compiled rules for packet filters.
|
||||
func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
||||
user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}
|
||||
user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}
|
||||
|
|
@ -409,6 +413,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
|||
// FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations
|
||||
// For node1, destinations should only be node1's IPs
|
||||
node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"}
|
||||
|
||||
for _, rule := range filterNode1 {
|
||||
for _, dst := range rule.DstPorts {
|
||||
require.Contains(t, node1IPs, dst.IP,
|
||||
|
|
@ -418,6 +423,7 @@ func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) {
|
|||
|
||||
// For node2, destinations should only be node2's IPs
|
||||
node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"}
|
||||
|
||||
for _, rule := range filterNode2 {
|
||||
for _, dst := range rule.DstPorts {
|
||||
require.Contains(t, node2IPs, dst.IP,
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ import (
|
|||
"tailscale.com/util/slicesx"
|
||||
)
|
||||
|
||||
// Global JSON options for consistent parsing across all struct unmarshaling
|
||||
// Global JSON options for consistent parsing across all struct unmarshaling.
|
||||
var policyJSONOpts = []json.Options{
|
||||
json.DefaultOptionsV2(),
|
||||
json.MatchCaseInsensitiveNames(true),
|
||||
|
|
@ -58,6 +58,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
var alias string
|
||||
|
||||
switch v := a.Alias.(type) {
|
||||
case *Username:
|
||||
alias = string(*v)
|
||||
|
|
@ -89,6 +90,7 @@ func (a AliasWithPorts) MarshalJSON() ([]byte, error) {
|
|||
|
||||
// Otherwise, format as "alias:ports"
|
||||
var ports []string
|
||||
|
||||
for _, port := range a.Ports {
|
||||
if port.First == port.Last {
|
||||
ports = append(ports, strconv.FormatUint(uint64(port.First), 10))
|
||||
|
|
@ -123,6 +125,7 @@ func (u Username) Validate() error {
|
|||
if isUser(string(u)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("Username has to contain @, got: %q", u)
|
||||
}
|
||||
|
||||
|
|
@ -194,8 +197,10 @@ func (u Username) resolveUser(users types.Users) (types.User, error) {
|
|||
}
|
||||
|
||||
func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
var (
|
||||
ips netipx.IPSetBuilder
|
||||
errs []error
|
||||
)
|
||||
|
||||
user, err := u.resolveUser(users)
|
||||
if err != nil {
|
||||
|
|
@ -228,6 +233,7 @@ func (g Group) Validate() error {
|
|||
if isGroup(string(g)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf(`Group has to start with "group:", got: %q`, g)
|
||||
}
|
||||
|
||||
|
|
@ -268,8 +274,10 @@ func (g Group) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
var (
|
||||
ips netipx.IPSetBuilder
|
||||
errs []error
|
||||
)
|
||||
|
||||
for _, user := range p.Groups[g] {
|
||||
uips, err := user.Resolve(nil, users, nodes)
|
||||
|
|
@ -290,6 +298,7 @@ func (t Tag) Validate() error {
|
|||
if isTag(string(t)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf(`tag has to start with "tag:", got: %q`, t)
|
||||
}
|
||||
|
||||
|
|
@ -339,6 +348,7 @@ func (h Host) Validate() error {
|
|||
if isHost(string(h)) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("Hostname %q is invalid", h)
|
||||
}
|
||||
|
||||
|
|
@ -352,13 +362,16 @@ func (h *Host) UnmarshalJSON(b []byte) error {
|
|||
}
|
||||
|
||||
func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
var (
|
||||
ips netipx.IPSetBuilder
|
||||
errs []error
|
||||
)
|
||||
|
||||
pref, ok := p.Hosts[h]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("unable to resolve host: %q", h)
|
||||
}
|
||||
|
||||
err := pref.Validate()
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
|
|
@ -376,6 +389,7 @@ func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView
|
|||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
}
|
||||
|
||||
for _, node := range nodes.All() {
|
||||
if node.InIPSet(ipsTemp) {
|
||||
node.AppendToIPSet(&ips)
|
||||
|
|
@ -391,6 +405,7 @@ func (p Prefix) Validate() error {
|
|||
if netip.Prefix(p).IsValid() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("Prefix %q is invalid", p)
|
||||
}
|
||||
|
||||
|
|
@ -404,6 +419,7 @@ func (p *Prefix) parseString(addr string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
addrPref, err := addr.Prefix(addr.BitLen())
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -418,6 +434,7 @@ func (p *Prefix) parseString(addr string) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*p = Prefix(pref)
|
||||
|
||||
return nil
|
||||
|
|
@ -428,6 +445,7 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := p.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -441,8 +459,10 @@ func (p *Prefix) UnmarshalJSON(b []byte) error {
|
|||
//
|
||||
// See [Policy], [types.Users], and [types.Nodes] for more details.
|
||||
func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
var (
|
||||
ips netipx.IPSetBuilder
|
||||
errs []error
|
||||
)
|
||||
|
||||
ips.AddPrefix(netip.Prefix(p))
|
||||
// If the IP is a single host, look for a node to ensure we add all the IPs of
|
||||
|
|
@ -587,8 +607,10 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
|
|||
|
||||
switch vs := v.(type) {
|
||||
case string:
|
||||
var portsPart string
|
||||
var err error
|
||||
var (
|
||||
portsPart string
|
||||
err error
|
||||
)
|
||||
|
||||
if strings.Contains(vs, ":") {
|
||||
vs, portsPart, err = splitDestinationAndPort(vs)
|
||||
|
|
@ -600,6 +622,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ve.Ports = ports
|
||||
} else {
|
||||
return errors.New(`hostport must contain a colon (":")`)
|
||||
|
|
@ -609,6 +632,7 @@ func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ve.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -646,6 +670,7 @@ func isHost(str string) bool {
|
|||
|
||||
func parseAlias(vs string) (Alias, error) {
|
||||
var pref Prefix
|
||||
|
||||
err := pref.parseString(vs)
|
||||
if err == nil {
|
||||
return &pref, nil
|
||||
|
|
@ -690,6 +715,7 @@ func (ve *AliasEnc) UnmarshalJSON(b []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ve.Alias = ptr
|
||||
|
||||
return nil
|
||||
|
|
@ -699,6 +725,7 @@ type Aliases []Alias
|
|||
|
||||
func (a *Aliases) UnmarshalJSON(b []byte) error {
|
||||
var aliases []AliasEnc
|
||||
|
||||
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -744,8 +771,10 @@ func (a Aliases) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
var (
|
||||
ips netipx.IPSetBuilder
|
||||
errs []error
|
||||
)
|
||||
|
||||
for _, alias := range a {
|
||||
aips, err := alias.Resolve(p, users, nodes)
|
||||
|
|
@ -770,6 +799,7 @@ func unmarshalPointer[T any](
|
|||
parseFunc func(string) (T, error),
|
||||
) (T, error) {
|
||||
var s string
|
||||
|
||||
err := json.Unmarshal(b, &s)
|
||||
if err != nil {
|
||||
var t T
|
||||
|
|
@ -789,6 +819,7 @@ type AutoApprovers []AutoApprover
|
|||
|
||||
func (aa *AutoApprovers) UnmarshalJSON(b []byte) error {
|
||||
var autoApprovers []AutoApproverEnc
|
||||
|
||||
err := json.Unmarshal(b, &autoApprovers, policyJSONOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -854,6 +885,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ve.AutoApprover = ptr
|
||||
|
||||
return nil
|
||||
|
|
@ -876,6 +908,7 @@ func (ve *OwnerEnc) UnmarshalJSON(b []byte) error {
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ve.Owner = ptr
|
||||
|
||||
return nil
|
||||
|
|
@ -885,6 +918,7 @@ type Owners []Owner
|
|||
|
||||
func (o *Owners) UnmarshalJSON(b []byte) error {
|
||||
var owners []OwnerEnc
|
||||
|
||||
err := json.Unmarshal(b, &owners, policyJSONOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -979,11 +1013,13 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
|
|||
|
||||
// Then validate each field can be converted to []string
|
||||
rawGroups := make(map[string][]string)
|
||||
|
||||
for key, value := range rawMap {
|
||||
switch v := value.(type) {
|
||||
case []any:
|
||||
// Convert []interface{} to []string
|
||||
var stringSlice []string
|
||||
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
stringSlice = append(stringSlice, str)
|
||||
|
|
@ -991,6 +1027,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
|
|||
return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item)
|
||||
}
|
||||
}
|
||||
|
||||
rawGroups[key] = stringSlice
|
||||
case string:
|
||||
return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v)
|
||||
|
|
@ -1000,6 +1037,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
|
|||
}
|
||||
|
||||
*g = make(Groups)
|
||||
|
||||
for key, value := range rawGroups {
|
||||
group := Group(key)
|
||||
// Group name already validated above
|
||||
|
|
@ -1014,6 +1052,7 @@ func (g *Groups) UnmarshalJSON(b []byte) error {
|
|||
|
||||
return err
|
||||
}
|
||||
|
||||
usernames = append(usernames, username)
|
||||
}
|
||||
|
||||
|
|
@ -1033,6 +1072,7 @@ func (h *Hosts) UnmarshalJSON(b []byte) error {
|
|||
}
|
||||
|
||||
*h = make(Hosts)
|
||||
|
||||
for key, value := range rawHosts {
|
||||
host := Host(key)
|
||||
if err := host.Validate(); err != nil {
|
||||
|
|
@ -1076,6 +1116,7 @@ func (to TagOwners) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
rawTagOwners := make(map[string][]string)
|
||||
|
||||
for tag, owners := range to {
|
||||
tagStr := string(tag)
|
||||
ownerStrs := make([]string, len(owners))
|
||||
|
|
@ -1152,6 +1193,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
|
|||
if p == nil {
|
||||
return nil, nil, nil
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
routes := make(map[netip.Prefix]*netipx.IPSetBuilder)
|
||||
|
|
@ -1160,6 +1202,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
|
|||
if _, ok := routes[prefix]; !ok {
|
||||
routes[prefix] = new(netipx.IPSetBuilder)
|
||||
}
|
||||
|
||||
for _, autoApprover := range autoApprovers {
|
||||
aa, ok := autoApprover.(Alias)
|
||||
if !ok {
|
||||
|
|
@ -1173,6 +1216,7 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
|
|||
}
|
||||
|
||||
var exitNodeSetBuilder netipx.IPSetBuilder
|
||||
|
||||
if len(p.AutoApprovers.ExitNode) > 0 {
|
||||
for _, autoApprover := range p.AutoApprovers.ExitNode {
|
||||
aa, ok := autoApprover.(Alias)
|
||||
|
|
@ -1187,11 +1231,13 @@ func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.
|
|||
}
|
||||
|
||||
ret := make(map[netip.Prefix]*netipx.IPSet)
|
||||
|
||||
for prefix, builder := range routes {
|
||||
ipSet, err := builder.IPSet()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
ret[prefix] = ipSet
|
||||
}
|
||||
|
||||
|
|
@ -1235,6 +1281,7 @@ func (a *Action) UnmarshalJSON(b []byte) error {
|
|||
default:
|
||||
return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -1259,6 +1306,7 @@ func (a *SSHAction) UnmarshalJSON(b []byte) error {
|
|||
default:
|
||||
return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -1399,7 +1447,7 @@ func (p Protocol) validate() error {
|
|||
return nil
|
||||
case ProtocolWildcard:
|
||||
// Wildcard "*" is not allowed - Tailscale rejects it
|
||||
return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)")
|
||||
return errors.New("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)")
|
||||
default:
|
||||
// Try to parse as a numeric protocol number
|
||||
str := string(p)
|
||||
|
|
@ -1427,7 +1475,7 @@ func (p Protocol) MarshalJSON() ([]byte, error) {
|
|||
return json.Marshal(string(p))
|
||||
}
|
||||
|
||||
// Protocol constants matching the IANA numbers
|
||||
// Protocol constants matching the IANA numbers.
|
||||
const (
|
||||
protocolICMP = 1 // Internet Control Message
|
||||
protocolIGMP = 2 // Internet Group Management
|
||||
|
|
@ -1464,6 +1512,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error {
|
|||
|
||||
// Remove any fields that start with '#'
|
||||
filtered := make(map[string]any)
|
||||
|
||||
for key, value := range raw {
|
||||
if !strings.HasPrefix(key, "#") {
|
||||
filtered[key] = value
|
||||
|
|
@ -1478,6 +1527,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error {
|
|||
|
||||
// Create a type alias to avoid infinite recursion
|
||||
type aclAlias ACL
|
||||
|
||||
var temp aclAlias
|
||||
|
||||
// Unmarshal into the temporary struct using the v2 JSON options
|
||||
|
|
@ -1487,6 +1537,7 @@ func (a *ACL) UnmarshalJSON(b []byte) error {
|
|||
|
||||
// Copy the result back to the original struct
|
||||
*a = ACL(temp)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -1733,6 +1784,7 @@ func (p *Policy) validate() error {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, dst := range ssh.Destinations {
|
||||
switch dst := dst.(type) {
|
||||
case *AutoGroup:
|
||||
|
|
@ -1846,6 +1898,7 @@ func (g Groups) MarshalJSON() ([]byte, error) {
|
|||
for i, username := range usernames {
|
||||
users[i] = string(username)
|
||||
}
|
||||
|
||||
raw[string(group)] = users
|
||||
}
|
||||
|
||||
|
|
@ -1854,6 +1907,7 @@ func (g Groups) MarshalJSON() ([]byte, error) {
|
|||
|
||||
func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
|
||||
var aliases []AliasEnc
|
||||
|
||||
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -1877,6 +1931,7 @@ func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error {
|
|||
|
||||
func (a *SSHDstAliases) UnmarshalJSON(b []byte) error {
|
||||
var aliases []AliasEnc
|
||||
|
||||
err := json.Unmarshal(b, &aliases, policyJSONOpts...)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
@ -1960,8 +2015,10 @@ func (a SSHSrcAliases) MarshalJSON() ([]byte, error) {
|
|||
}
|
||||
|
||||
func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
|
||||
var ips netipx.IPSetBuilder
|
||||
var errs []error
|
||||
var (
|
||||
ips netipx.IPSetBuilder
|
||||
errs []error
|
||||
)
|
||||
|
||||
for _, alias := range a {
|
||||
aips, err := alias.Resolve(p, users, nodes)
|
||||
|
|
@ -2015,18 +2072,22 @@ func unmarshalPolicy(b []byte) (*Policy, error) {
|
|||
}
|
||||
|
||||
var policy Policy
|
||||
|
||||
ast, err := hujson.Parse(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parsing HuJSON: %w", err)
|
||||
}
|
||||
|
||||
ast.Standardize()
|
||||
|
||||
if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil {
|
||||
if serr, ok := errors.AsType[*json.SemanticError](err); ok && serr.Err == json.ErrUnknownName {
|
||||
ptr := serr.JSONPointer
|
||||
name := ptr.LastToken()
|
||||
|
||||
return nil, fmt.Errorf("unknown field %q", name)
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("parsing policy from bytes: %w", err)
|
||||
}
|
||||
|
||||
|
|
@ -2073,6 +2134,7 @@ func (p *Policy) usesAutogroupSelf() bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, dest := range acl.Destinations {
|
||||
if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
return true
|
||||
|
|
@ -2087,6 +2149,7 @@ func (p *Policy) usesAutogroupSelf() bool {
|
|||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, dest := range ssh.Destinations {
|
||||
if ag, ok := dest.(*AutoGroup); ok && ag.Is(AutoGroupSelf) {
|
||||
return true
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ func TestMarshalJSON(t *testing.T) {
|
|||
|
||||
// Unmarshal back to verify round trip
|
||||
var roundTripped Policy
|
||||
|
||||
err = json.Unmarshal(marshalled, &roundTripped)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -2020,6 +2021,7 @@ func TestResolvePolicy(t *testing.T) {
|
|||
}
|
||||
|
||||
var prefs []netip.Prefix
|
||||
|
||||
if ips != nil {
|
||||
if p := ips.Prefixes(); len(p) > 0 {
|
||||
prefs = p
|
||||
|
|
@ -2191,9 +2193,11 @@ func TestResolveAutoApprovers(t *testing.T) {
|
|||
t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmps...); diff != "" {
|
||||
t.Errorf("resolveAutoApprovers() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if tt.wantAllIPRoutes != nil {
|
||||
if gotAllIPRoutes == nil {
|
||||
t.Error("resolveAutoApprovers() expected non-nil allIPRoutes, got nil")
|
||||
|
|
@ -2340,6 +2344,7 @@ func mustIPSet(prefixes ...string) *netipx.IPSet {
|
|||
for _, p := range prefixes {
|
||||
builder.AddPrefix(mp(p))
|
||||
}
|
||||
|
||||
ipSet, _ := builder.IPSet()
|
||||
|
||||
return ipSet
|
||||
|
|
@ -2349,6 +2354,7 @@ func ipSetComparer(x, y *netipx.IPSet) bool {
|
|||
if x == nil || y == nil {
|
||||
return x == y
|
||||
}
|
||||
|
||||
return cmp.Equal(x.Prefixes(), y.Prefixes(), util.Comparers...)
|
||||
}
|
||||
|
||||
|
|
@ -2577,6 +2583,7 @@ func TestResolveTagOwners(t *testing.T) {
|
|||
t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.want, got, cmps...); diff != "" {
|
||||
t.Errorf("resolveTagOwners() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -2852,6 +2859,7 @@ func TestNodeCanHaveTag(t *testing.T) {
|
|||
require.ErrorContains(t, err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
got := pm.NodeCanHaveTag(tt.node.View(), tt.tag)
|
||||
|
|
@ -3112,6 +3120,7 @@ func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var acl ACL
|
||||
|
||||
err := json.Unmarshal([]byte(tt.input), &acl)
|
||||
|
||||
if tt.wantErr {
|
||||
|
|
@ -3163,6 +3172,7 @@ func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) {
|
|||
|
||||
// Unmarshal back
|
||||
var unmarshaled ACL
|
||||
|
||||
err = json.Unmarshal(jsonBytes, &unmarshaled)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -3241,12 +3251,13 @@ func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) {
|
|||
assert.Contains(t, err.Error(), `invalid action "deny"`)
|
||||
}
|
||||
|
||||
// Helper function to parse aliases for testing
|
||||
// Helper function to parse aliases for testing.
|
||||
func mustParseAlias(s string) Alias {
|
||||
alias, err := parseAlias(s)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return alias
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -18,9 +18,11 @@ func splitDestinationAndPort(input string) (string, string, error) {
|
|||
if lastColonIndex == -1 {
|
||||
return "", "", errors.New("input must contain a colon character separating destination and port")
|
||||
}
|
||||
|
||||
if lastColonIndex == 0 {
|
||||
return "", "", errors.New("input cannot start with a colon character")
|
||||
}
|
||||
|
||||
if lastColonIndex == len(input)-1 {
|
||||
return "", "", errors.New("input cannot end with a colon character")
|
||||
}
|
||||
|
|
@ -45,6 +47,7 @@ func parsePortRange(portDef string) ([]tailcfg.PortRange, error) {
|
|||
for part := range parts {
|
||||
if strings.Contains(part, "-") {
|
||||
rangeParts := strings.Split(part, "-")
|
||||
|
||||
rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool {
|
||||
return e == ""
|
||||
})
|
||||
|
|
|
|||
|
|
@ -58,9 +58,11 @@ func TestParsePort(t *testing.T) {
|
|||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
|
||||
if result != test.expected {
|
||||
t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected)
|
||||
}
|
||||
|
|
@ -92,9 +94,11 @@ func TestParsePortRange(t *testing.T) {
|
|||
if err != nil && err.Error() != test.err {
|
||||
t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err)
|
||||
}
|
||||
|
||||
if err == nil && test.err != "" {
|
||||
t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(result, test.expected); diff != "" {
|
||||
t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -152,6 +152,7 @@ func (m *mapSession) serveLongPoll() {
|
|||
// This is not my favourite solution, but it kind of works in our eventually consistent world.
|
||||
ticker := time.NewTicker(time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
disconnected := true
|
||||
// Wait up to 10 seconds for the node to reconnect.
|
||||
// 10 seconds was arbitrary chosen as a reasonable time to reconnect.
|
||||
|
|
@ -160,6 +161,7 @@ func (m *mapSession) serveLongPoll() {
|
|||
disconnected = false
|
||||
break
|
||||
}
|
||||
|
||||
<-ticker.C
|
||||
}
|
||||
|
||||
|
|
@ -215,8 +217,10 @@ func (m *mapSession) serveLongPoll() {
|
|||
if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil {
|
||||
m.errf(err, "failed to add node to batcher")
|
||||
log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher")
|
||||
|
||||
m.h.Change(mapReqChange)
|
||||
|
|
|
|||
|
|
@ -107,9 +107,11 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
|
|||
Msg("Current primary no longer available")
|
||||
}
|
||||
}
|
||||
|
||||
if len(nodes) >= 1 {
|
||||
pr.primaries[prefix] = nodes[0]
|
||||
changed = true
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Str("prefix", prefix.String()).
|
||||
|
|
@ -126,6 +128,7 @@ func (pr *PrimaryRoutes) updatePrimaryLocked() bool {
|
|||
Str("prefix", prefix.String()).
|
||||
Msg("Cleaning up primary route that no longer has available nodes")
|
||||
delete(pr.primaries, prefix)
|
||||
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
|
|
@ -161,14 +164,18 @@ func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix)
|
|||
// If no routes are being set, remove the node from the routes map.
|
||||
if len(prefixes) == 0 {
|
||||
wasPresent := false
|
||||
|
||||
if _, ok := pr.routes[node]; ok {
|
||||
delete(pr.routes, node)
|
||||
|
||||
wasPresent = true
|
||||
|
||||
log.Debug().
|
||||
Caller().
|
||||
Uint64("node.id", node.Uint64()).
|
||||
Msg("Removed node from primary routes (no prefixes)")
|
||||
}
|
||||
|
||||
changed := pr.updatePrimaryLocked()
|
||||
log.Debug().
|
||||
Caller().
|
||||
|
|
@ -254,12 +261,14 @@ func (pr *PrimaryRoutes) stringLocked() string {
|
|||
|
||||
ids := types.NodeIDs(xmaps.Keys(pr.routes))
|
||||
slices.Sort(ids)
|
||||
|
||||
for _, id := range ids {
|
||||
prefixes := pr.routes[id]
|
||||
fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", "))
|
||||
}
|
||||
|
||||
fmt.Fprintln(&sb, "\n\nCurrent primary routes:")
|
||||
|
||||
for route, nodeID := range pr.primaries {
|
||||
fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -130,6 +130,7 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
pr.SetRoutes(2, mp("192.168.2.0/24"))
|
||||
pr.SetRoutes(1) // Deregister by setting no routes
|
||||
|
||||
return pr.SetRoutes(1, mp("192.168.3.0/24"))
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
|
|
@ -153,8 +154,9 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
{
|
||||
name: "multiple-nodes-register-same-route",
|
||||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
|
||||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true
|
||||
|
||||
return pr.SetRoutes(3, mp("192.168.1.0/24")) // false
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
|
|
@ -182,7 +184,8 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
pr.SetRoutes(1, mp("192.168.1.0/24")) // false
|
||||
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary
|
||||
pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary
|
||||
return pr.SetRoutes(1) // true, 2 primary
|
||||
|
||||
return pr.SetRoutes(1) // true, 2 primary
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
2: {
|
||||
|
|
@ -393,6 +396,7 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
operations: func(pr *PrimaryRoutes) bool {
|
||||
pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0"))
|
||||
pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0"))
|
||||
|
||||
return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0"))
|
||||
},
|
||||
expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{
|
||||
|
|
@ -413,15 +417,20 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
operations: func(pr *PrimaryRoutes) bool {
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(2)
|
||||
|
||||
var change1, change2 bool
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
change1 = pr.SetRoutes(1, mp("192.168.1.0/24"))
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
change2 = pr.SetRoutes(2, mp("192.168.2.0/24"))
|
||||
}()
|
||||
|
||||
wg.Wait()
|
||||
|
||||
return change1 || change2
|
||||
|
|
@ -449,17 +458,21 @@ func TestPrimaryRoutes(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
pr := New()
|
||||
|
||||
change := tt.operations(pr)
|
||||
if change != tt.expectedChange {
|
||||
t.Errorf("change = %v, want %v", change, tt.expectedChange)
|
||||
}
|
||||
|
||||
comps := append(util.Comparers, cmpopts.EquateEmpty())
|
||||
if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" {
|
||||
t.Errorf("routes mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" {
|
||||
t.Errorf("primaries mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" {
|
||||
t.Errorf("isPrimary mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ func (s *State) DebugOverview() string {
|
|||
ephemeralCount := 0
|
||||
|
||||
now := time.Now()
|
||||
|
||||
for _, node := range allNodes.All() {
|
||||
if node.Valid() {
|
||||
userName := node.Owner().Name()
|
||||
|
|
@ -103,17 +104,21 @@ func (s *State) DebugOverview() string {
|
|||
|
||||
// User statistics
|
||||
sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users)))
|
||||
|
||||
for userName, nodeCount := range userNodeCounts {
|
||||
sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Policy information
|
||||
sb.WriteString("Policy:\n")
|
||||
sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode))
|
||||
|
||||
if s.cfg.Policy.Mode == types.PolicyModeFile {
|
||||
sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// DERP information
|
||||
|
|
@ -123,6 +128,7 @@ func (s *State) DebugOverview() string {
|
|||
} else {
|
||||
sb.WriteString("DERP: not configured\n")
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Route information
|
||||
|
|
@ -130,6 +136,7 @@ func (s *State) DebugOverview() string {
|
|||
if s.primaryRoutes.String() == "" {
|
||||
routeCount = 0
|
||||
}
|
||||
|
||||
sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount))
|
||||
sb.WriteString("\n")
|
||||
|
||||
|
|
@ -165,10 +172,12 @@ func (s *State) DebugDERPMap() string {
|
|||
for _, node := range region.Nodes {
|
||||
sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n",
|
||||
node.Name, node.HostName, node.DERPPort))
|
||||
|
||||
if node.STUNPort != 0 {
|
||||
sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort))
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
}
|
||||
|
||||
|
|
@ -319,6 +328,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo {
|
|||
if s.primaryRoutes.String() == "" {
|
||||
routeCount = 0
|
||||
}
|
||||
|
||||
info.PrimaryRoutes = routeCount
|
||||
|
||||
return info
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
|
|||
|
||||
// Create NodeStore
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -43,20 +44,26 @@ func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) {
|
|||
// 6. If DELETE came after UPDATE, the returned node should be invalid
|
||||
|
||||
done := make(chan bool, 2)
|
||||
var updatedNode types.NodeView
|
||||
var updateOk bool
|
||||
|
||||
var (
|
||||
updatedNode types.NodeView
|
||||
updateOk bool
|
||||
)
|
||||
|
||||
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
|
||||
|
||||
go func() {
|
||||
updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) {
|
||||
n.LastSeen = new(time.Now())
|
||||
})
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
|
||||
go func() {
|
||||
store.DeleteNode(node.ID)
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
|
|
@ -90,6 +97,7 @@ func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) {
|
|||
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -147,6 +155,7 @@ func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) {
|
|||
node := createTestNode(3, 1, "test-user", "test-node-3")
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -203,6 +212,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
|||
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -213,8 +223,11 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
|||
// 1. UpdateNode (from UpdateNodeFromMapRequest during polling)
|
||||
// 2. DeleteNode (from handleLogout when client sends logout request)
|
||||
|
||||
var updatedNode types.NodeView
|
||||
var updateOk bool
|
||||
var (
|
||||
updatedNode types.NodeView
|
||||
updateOk bool
|
||||
)
|
||||
|
||||
done := make(chan bool, 2)
|
||||
|
||||
// Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest)
|
||||
|
|
@ -222,12 +235,14 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
|||
updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) {
|
||||
n.LastSeen = new(time.Now())
|
||||
})
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
// Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node)
|
||||
go func() {
|
||||
store.DeleteNode(ephemeralNode.ID)
|
||||
|
||||
done <- true
|
||||
}()
|
||||
|
||||
|
|
@ -266,7 +281,7 @@ func TestEphemeralNodeLogoutRaceCondition(t *testing.T) {
|
|||
// 5. UpdateNode and DeleteNode batch together
|
||||
// 6. UpdateNode returns a valid node (from before delete in batch)
|
||||
// 7. persistNodeToDB is called with the stale valid node
|
||||
// 8. Node gets re-inserted into database instead of staying deleted
|
||||
// 8. Node gets re-inserted into database instead of staying deleted.
|
||||
func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
|
||||
ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5")
|
||||
ephemeralNode.AuthKey = &types.PreAuthKey{
|
||||
|
|
@ -278,6 +293,7 @@ func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) {
|
|||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -348,6 +364,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
|
|||
|
||||
// Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -398,7 +415,7 @@ func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) {
|
|||
// 3. UpdateNode and DeleteNode batch together
|
||||
// 4. UpdateNode returns a valid node (from before delete in batch)
|
||||
// 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node
|
||||
// 6. persistNodeToDB must detect the node is deleted and refuse to persist
|
||||
// 6. persistNodeToDB must detect the node is deleted and refuse to persist.
|
||||
func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
|
||||
ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7")
|
||||
ephemeralNode.AuthKey = &types.PreAuthKey{
|
||||
|
|
@ -408,6 +425,7 @@ func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) {
|
|||
}
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ func netInfoFromMapRequest(
|
|||
Uint64("node.id", nodeID.Uint64()).
|
||||
Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP).
|
||||
Msg("using NetInfo from previous Hostinfo in MapRequest")
|
||||
|
||||
return currentHostinfo.NetInfo
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -136,7 +136,7 @@ func TestNetInfoPreservationInRegistrationFlow(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
// Simple helper function for tests
|
||||
// Simple helper function for tests.
|
||||
func createTestNodeSimple(id types.NodeID) *types.Node {
|
||||
user := types.User{
|
||||
Name: "test-user",
|
||||
|
|
|
|||
|
|
@ -97,6 +97,7 @@ func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batc
|
|||
for _, n := range allNodes {
|
||||
nodes[n.ID] = *n
|
||||
}
|
||||
|
||||
snap := snapshotFromNodes(nodes, peersFunc)
|
||||
|
||||
store := &NodeStore{
|
||||
|
|
@ -165,11 +166,14 @@ func (s *NodeStore) PutNode(n types.Node) types.NodeView {
|
|||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
|
||||
s.writeQueue <- work
|
||||
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
resultNode := <-work.nodeResult
|
||||
|
||||
nodeStoreOperations.WithLabelValues("put").Inc()
|
||||
|
||||
return resultNode
|
||||
|
|
@ -205,11 +209,14 @@ func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)
|
|||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
|
||||
s.writeQueue <- work
|
||||
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
resultNode := <-work.nodeResult
|
||||
|
||||
nodeStoreOperations.WithLabelValues("update").Inc()
|
||||
|
||||
// Return the node and whether it exists (is valid)
|
||||
|
|
@ -229,7 +236,9 @@ func (s *NodeStore) DeleteNode(id types.NodeID) {
|
|||
}
|
||||
|
||||
nodeStoreQueueDepth.Inc()
|
||||
|
||||
s.writeQueue <- work
|
||||
|
||||
<-work.result
|
||||
nodeStoreQueueDepth.Dec()
|
||||
|
||||
|
|
@ -262,8 +271,10 @@ func (s *NodeStore) processWrite() {
|
|||
if len(batch) != 0 {
|
||||
s.applyBatch(batch)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
batch = append(batch, w)
|
||||
if len(batch) >= s.batchSize {
|
||||
s.applyBatch(batch)
|
||||
|
|
@ -321,6 +332,7 @@ func (s *NodeStore) applyBatch(batch []work) {
|
|||
w.updateFn(&n)
|
||||
nodes[w.nodeID] = n
|
||||
}
|
||||
|
||||
if w.nodeResult != nil {
|
||||
nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w)
|
||||
}
|
||||
|
|
@ -349,12 +361,14 @@ func (s *NodeStore) applyBatch(batch []work) {
|
|||
nodeView := node.View()
|
||||
for _, w := range workItems {
|
||||
w.nodeResult <- nodeView
|
||||
|
||||
close(w.nodeResult)
|
||||
}
|
||||
} else {
|
||||
// Node was deleted or doesn't exist
|
||||
for _, w := range workItems {
|
||||
w.nodeResult <- types.NodeView{} // Send invalid view
|
||||
|
||||
close(w.nodeResult)
|
||||
}
|
||||
}
|
||||
|
|
@ -400,6 +414,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
|
|||
peersByNode: func() map[types.NodeID][]types.NodeView {
|
||||
peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration)
|
||||
defer peersTimer.ObserveDuration()
|
||||
|
||||
return peersFunc(allNodes)
|
||||
}(),
|
||||
nodesByUser: make(map[types.UserID][]types.NodeView),
|
||||
|
|
@ -417,6 +432,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
|
|||
if newSnap.nodesByMachineKey[n.MachineKey] == nil {
|
||||
newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView)
|
||||
}
|
||||
|
||||
newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView
|
||||
}
|
||||
|
||||
|
|
@ -511,10 +527,12 @@ func (s *NodeStore) DebugString() string {
|
|||
|
||||
// User distribution (shows internal UserID tracking, not display owner)
|
||||
sb.WriteString("Nodes by Internal User ID:\n")
|
||||
|
||||
for userID, nodes := range snapshot.nodesByUser {
|
||||
if len(nodes) > 0 {
|
||||
userName := "unknown"
|
||||
taggedCount := 0
|
||||
|
||||
if len(nodes) > 0 && nodes[0].Valid() {
|
||||
userName = nodes[0].User().Name()
|
||||
// Count tagged nodes (which have UserID set but are owned by "tagged-devices")
|
||||
|
|
@ -532,23 +550,29 @@ func (s *NodeStore) DebugString() string {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Peer relationships summary
|
||||
sb.WriteString("Peer Relationships:\n")
|
||||
|
||||
totalPeers := 0
|
||||
|
||||
for nodeID, peers := range snapshot.peersByNode {
|
||||
peerCount := len(peers)
|
||||
|
||||
totalPeers += peerCount
|
||||
if node, exists := snapshot.nodesByID[nodeID]; exists {
|
||||
sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n",
|
||||
nodeID, node.Hostname, peerCount))
|
||||
}
|
||||
}
|
||||
|
||||
if len(snapshot.peersByNode) > 0 {
|
||||
avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode))
|
||||
sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers))
|
||||
}
|
||||
|
||||
sb.WriteString("\n")
|
||||
|
||||
// Node key index
|
||||
|
|
@ -591,6 +615,7 @@ func (s *NodeStore) RebuildPeerMaps() {
|
|||
}
|
||||
|
||||
s.writeQueue <- w
|
||||
|
||||
<-result
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ func TestSnapshotFromNodes(t *testing.T) {
|
|||
nodes := map[types.NodeID]types.Node{
|
||||
1: createTestNode(1, 1, "user1", "node1"),
|
||||
}
|
||||
|
||||
return nodes, allowAllPeersFunc
|
||||
},
|
||||
validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) {
|
||||
|
|
@ -192,11 +193,13 @@ func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
|||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
|
||||
for _, n := range nodes {
|
||||
if n.ID() != node.ID() {
|
||||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
|
||||
|
|
@ -207,6 +210,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
|||
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
|
||||
for _, node := range nodes {
|
||||
var peers []types.NodeView
|
||||
|
||||
nodeIsOdd := node.ID()%2 == 1
|
||||
|
||||
for _, n := range nodes {
|
||||
|
|
@ -221,6 +225,7 @@ func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView
|
|||
peers = append(peers, n)
|
||||
}
|
||||
}
|
||||
|
||||
ret[node.ID()] = peers
|
||||
}
|
||||
|
||||
|
|
@ -454,10 +459,13 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
// Add nodes in sequence
|
||||
n1 := store.PutNode(createTestNode(1, 1, "user1", "node1"))
|
||||
assert.True(t, n1.Valid())
|
||||
|
||||
n2 := store.PutNode(createTestNode(2, 2, "user2", "node2"))
|
||||
assert.True(t, n2.Valid())
|
||||
|
||||
n3 := store.PutNode(createTestNode(3, 3, "user3", "node3"))
|
||||
assert.True(t, n3.Valid())
|
||||
|
||||
n4 := store.PutNode(createTestNode(4, 4, "user4", "node4"))
|
||||
assert.True(t, n4.Valid())
|
||||
|
||||
|
|
@ -525,16 +533,20 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var resultNode1, resultNode2 types.NodeView
|
||||
var newNode3 types.NodeView
|
||||
var ok1, ok2 bool
|
||||
var (
|
||||
resultNode1, resultNode2 types.NodeView
|
||||
newNode3 types.NodeView
|
||||
ok1, ok2 bool
|
||||
)
|
||||
|
||||
// These should all be processed in the same batch
|
||||
|
||||
go func() {
|
||||
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "batch-updated-node1"
|
||||
n.GivenName = "batch-given-1"
|
||||
})
|
||||
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
|
|
@ -543,12 +555,14 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
n.Hostname = "batch-updated-node2"
|
||||
n.GivenName = "batch-given-2"
|
||||
})
|
||||
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
node3 := createTestNode(3, 1, "user1", "node3")
|
||||
newNode3 = store.PutNode(node3)
|
||||
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
|
|
@ -601,20 +615,23 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
// This test verifies that when multiple updates to the same node
|
||||
// are batched together, each returned node reflects ALL changes
|
||||
// in the batch, not just the individual update's changes.
|
||||
|
||||
done1 := make(chan struct{})
|
||||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var resultNode1, resultNode2, resultNode3 types.NodeView
|
||||
var ok1, ok2, ok3 bool
|
||||
var (
|
||||
resultNode1, resultNode2, resultNode3 types.NodeView
|
||||
ok1, ok2, ok3 bool
|
||||
)
|
||||
|
||||
// These updates all modify node 1 and should be batched together
|
||||
// The final state should have all three modifications applied
|
||||
|
||||
go func() {
|
||||
resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "multi-update-hostname"
|
||||
})
|
||||
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
|
|
@ -622,6 +639,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.GivenName = "multi-update-givenname"
|
||||
})
|
||||
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
|
|
@ -629,6 +647,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Tags = []string{"tag1", "tag2"}
|
||||
})
|
||||
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
|
|
@ -722,14 +741,18 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
done2 := make(chan struct{})
|
||||
done3 := make(chan struct{})
|
||||
|
||||
var result1, result2, result3 types.NodeView
|
||||
var ok1, ok2, ok3 bool
|
||||
var (
|
||||
result1, result2, result3 types.NodeView
|
||||
ok1, ok2, ok3 bool
|
||||
)
|
||||
|
||||
// Start concurrent updates
|
||||
|
||||
go func() {
|
||||
result1, ok1 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Hostname = "concurrent-db-hostname"
|
||||
})
|
||||
|
||||
close(done1)
|
||||
}()
|
||||
|
||||
|
|
@ -737,6 +760,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
result2, ok2 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.GivenName = "concurrent-db-given"
|
||||
})
|
||||
|
||||
close(done2)
|
||||
}()
|
||||
|
||||
|
|
@ -744,6 +768,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
|
||||
n.Tags = []string{"concurrent-tag"}
|
||||
})
|
||||
|
||||
close(done3)
|
||||
}()
|
||||
|
||||
|
|
@ -827,6 +852,7 @@ func TestNodeStoreOperations(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
store := tt.setupFunc(t)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -846,10 +872,11 @@ type testStep struct {
|
|||
|
||||
// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests ---
|
||||
|
||||
// Helper for concurrent test nodes
|
||||
// Helper for concurrent test nodes.
|
||||
func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
|
||||
machineKey := key.NewMachine()
|
||||
nodeKey := key.NewNode()
|
||||
|
||||
return types.Node{
|
||||
ID: id,
|
||||
Hostname: hostname,
|
||||
|
|
@ -862,72 +889,90 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
|
|||
}
|
||||
}
|
||||
|
||||
// --- Concurrency: concurrent PutNode operations ---
|
||||
// --- Concurrency: concurrent PutNode operations ---.
|
||||
func TestNodeStoreConcurrentPutNode(t *testing.T) {
|
||||
const concurrentOps = 20
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
results := make(chan bool, concurrentOps)
|
||||
for i := range concurrentOps {
|
||||
wg.Add(1)
|
||||
|
||||
go func(nodeID int) {
|
||||
defer wg.Done()
|
||||
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node")
|
||||
|
||||
resultNode := store.PutNode(node)
|
||||
results <- resultNode.Valid()
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successCount := 0
|
||||
|
||||
for success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed")
|
||||
}
|
||||
|
||||
// --- Batching: concurrent ops fit in one batch ---
|
||||
// --- Batching: concurrent ops fit in one batch ---.
|
||||
func TestNodeStoreBatchingEfficiency(t *testing.T) {
|
||||
const batchSize = 10
|
||||
|
||||
const ops = 15 // more than batchSize
|
||||
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
results := make(chan bool, ops)
|
||||
for i := range ops {
|
||||
wg.Add(1)
|
||||
|
||||
go func(nodeID int) {
|
||||
defer wg.Done()
|
||||
|
||||
node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node")
|
||||
|
||||
resultNode := store.PutNode(node)
|
||||
results <- resultNode.Valid()
|
||||
}(i + 1)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(results)
|
||||
|
||||
successCount := 0
|
||||
|
||||
for success := range results {
|
||||
if success {
|
||||
successCount++
|
||||
}
|
||||
}
|
||||
|
||||
require.Equal(t, ops, successCount, "All batch PutNode operations should succeed")
|
||||
}
|
||||
|
||||
// --- Race conditions: many goroutines on same node ---
|
||||
// --- Race conditions: many goroutines on same node ---.
|
||||
func TestNodeStoreRaceConditions(t *testing.T) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -936,13 +981,18 @@ func TestNodeStoreRaceConditions(t *testing.T) {
|
|||
resultNode := store.PutNode(node)
|
||||
require.True(t, resultNode.Valid())
|
||||
|
||||
const numGoroutines = 30
|
||||
const opsPerGoroutine = 10
|
||||
const (
|
||||
numGoroutines = 30
|
||||
opsPerGoroutine = 10
|
||||
)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
errors := make(chan error, numGoroutines*opsPerGoroutine)
|
||||
|
||||
for i := range numGoroutines {
|
||||
wg.Add(1)
|
||||
|
||||
go func(gid int) {
|
||||
defer wg.Done()
|
||||
|
||||
|
|
@ -962,6 +1012,7 @@ func TestNodeStoreRaceConditions(t *testing.T) {
|
|||
}
|
||||
case 2:
|
||||
newNode := createConcurrentTestNode(nodeID, "race-put")
|
||||
|
||||
resultNode := store.PutNode(newNode)
|
||||
if !resultNode.Valid() {
|
||||
errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j)
|
||||
|
|
@ -970,23 +1021,28 @@ func TestNodeStoreRaceConditions(t *testing.T) {
|
|||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
close(errors)
|
||||
|
||||
errorCount := 0
|
||||
|
||||
for err := range errors {
|
||||
t.Error(err)
|
||||
|
||||
errorCount++
|
||||
}
|
||||
|
||||
if errorCount > 0 {
|
||||
t.Fatalf("Race condition test failed with %d errors", errorCount)
|
||||
}
|
||||
}
|
||||
|
||||
// --- Resource cleanup: goroutine leak detection ---
|
||||
// --- Resource cleanup: goroutine leak detection ---.
|
||||
func TestNodeStoreResourceCleanup(t *testing.T) {
|
||||
// initialGoroutines := runtime.NumGoroutine()
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -1009,10 +1065,12 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
|
|||
})
|
||||
retrieved, found := store.GetNode(nodeID)
|
||||
assert.True(t, found && retrieved.Valid())
|
||||
|
||||
if i%10 == 9 {
|
||||
store.DeleteNode(nodeID)
|
||||
}
|
||||
}
|
||||
|
||||
runtime.GC()
|
||||
|
||||
// Wait for goroutines to settle and check for leaks
|
||||
|
|
@ -1023,9 +1081,10 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
|
|||
}, time.Second, 10*time.Millisecond, "goroutines should not leak")
|
||||
}
|
||||
|
||||
// --- Timeout/deadlock: operations complete within reasonable time ---
|
||||
// --- Timeout/deadlock: operations complete within reasonable time ---.
|
||||
func TestNodeStoreOperationTimeout(t *testing.T) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -1033,36 +1092,47 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
const ops = 30
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
putResults := make([]error, ops)
|
||||
updateResults := make([]error, ops)
|
||||
|
||||
// Launch all PutNode operations concurrently
|
||||
for i := 1; i <= ops; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func(idx int, id types.NodeID) {
|
||||
defer wg.Done()
|
||||
|
||||
startPut := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id)
|
||||
node := createConcurrentTestNode(id, "timeout-node")
|
||||
resultNode := store.PutNode(node)
|
||||
endPut := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut))
|
||||
|
||||
if !resultNode.Valid() {
|
||||
putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id)
|
||||
}
|
||||
}(i, nodeID)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
// Launch all UpdateNode operations concurrently
|
||||
wg = sync.WaitGroup{}
|
||||
|
||||
for i := 1; i <= ops; i++ {
|
||||
nodeID := types.NodeID(i)
|
||||
|
||||
wg.Add(1)
|
||||
|
||||
go func(idx int, id types.NodeID) {
|
||||
defer wg.Done()
|
||||
|
||||
startUpdate := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id)
|
||||
resultNode, ok := store.UpdateNode(id, func(n *types.Node) {
|
||||
|
|
@ -1070,31 +1140,40 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
|||
})
|
||||
endUpdate := time.Now()
|
||||
fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate))
|
||||
|
||||
if !ok || !resultNode.Valid() {
|
||||
updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id)
|
||||
}
|
||||
}(i, nodeID)
|
||||
}
|
||||
|
||||
done := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
wg.Wait()
|
||||
close(done)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-done:
|
||||
errorCount := 0
|
||||
|
||||
for _, err := range putResults {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
|
||||
for _, err := range updateResults {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
|
||||
errorCount++
|
||||
}
|
||||
}
|
||||
|
||||
if errorCount == 0 {
|
||||
t.Log("All concurrent operations completed successfully within timeout")
|
||||
} else {
|
||||
|
|
@ -1106,13 +1185,15 @@ func TestNodeStoreOperationTimeout(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// --- Edge case: update non-existent node ---
|
||||
// --- Edge case: update non-existent node ---.
|
||||
func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
|
||||
for i := range 10 {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
store.Start()
|
||||
|
||||
nonExistentID := types.NodeID(999 + i)
|
||||
updateCallCount := 0
|
||||
|
||||
fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID)
|
||||
resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) {
|
||||
updateCallCount++
|
||||
|
|
@ -1126,9 +1207,10 @@ func TestNodeStoreUpdateNonExistentNode(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
// --- Allocation benchmark ---
|
||||
// --- Allocation benchmark ---.
|
||||
func BenchmarkNodeStoreAllocations(b *testing.B) {
|
||||
store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout)
|
||||
|
||||
store.Start()
|
||||
defer store.Stop()
|
||||
|
||||
|
|
@ -1140,6 +1222,7 @@ func BenchmarkNodeStoreAllocations(b *testing.B) {
|
|||
n.Hostname = "bench-updated"
|
||||
})
|
||||
store.GetNode(nodeID)
|
||||
|
||||
if i%10 == 9 {
|
||||
store.DeleteNode(nodeID)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -93,6 +93,7 @@ func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath s
|
|||
mux := tsql.NewMux()
|
||||
tsweb.Debugger(mux)
|
||||
go http.Serve(lst, mux)
|
||||
|
||||
logf("TailSQL started")
|
||||
<-ctx.Done()
|
||||
logf("TailSQL shutting down...")
|
||||
|
|
|
|||
|
|
@ -177,6 +177,7 @@ func RegistrationIDFromString(str string) (RegistrationID, error) {
|
|||
if len(str) != RegistrationIDLength {
|
||||
return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength)
|
||||
}
|
||||
|
||||
return RegistrationID(str), nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -301,6 +301,7 @@ func validatePKCEMethod(method string) error {
|
|||
if method != PKCEMethodPlain && method != PKCEMethodS256 {
|
||||
return errInvalidPKCEMethod
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
@ -1082,6 +1083,7 @@ func LoadServerConfig() (*Config, error) {
|
|||
if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 {
|
||||
return workers
|
||||
}
|
||||
|
||||
return DefaultBatcherWorkers()
|
||||
}(),
|
||||
RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"),
|
||||
|
|
@ -1117,6 +1119,7 @@ func isSafeServerURL(serverURL, baseDomain string) error {
|
|||
}
|
||||
|
||||
s := len(serverDomainParts)
|
||||
|
||||
b := len(baseDomainParts)
|
||||
for i := range baseDomainParts {
|
||||
if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] {
|
||||
|
|
|
|||
|
|
@ -363,6 +363,7 @@ noise:
|
|||
|
||||
// Populate a custom config file
|
||||
configFilePath := filepath.Join(tmpDir, "config.yaml")
|
||||
|
||||
err = os.WriteFile(configFilePath, configYaml, 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't write file %s", configFilePath)
|
||||
|
|
@ -398,10 +399,12 @@ server_url: http://127.0.0.1:8080
|
|||
tls_letsencrypt_hostname: example.com
|
||||
tls_letsencrypt_challenge_type: TLS-ALPN-01
|
||||
`)
|
||||
|
||||
err = os.WriteFile(configFilePath, configYaml, 0o600)
|
||||
if err != nil {
|
||||
t.Fatalf("Couldn't write file %s", configFilePath)
|
||||
}
|
||||
|
||||
err = LoadConfig(tmpDir, false)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
|
@ -463,6 +466,7 @@ func TestSafeServerURL(t *testing.T) {
|
|||
|
||||
return
|
||||
}
|
||||
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -156,6 +156,7 @@ func (node *Node) GivenNameHasBeenChanged() bool {
|
|||
// Strip invalid DNS characters for givenName comparison
|
||||
normalised := strings.ToLower(node.Hostname)
|
||||
normalised = invalidDNSRegex.ReplaceAllString(normalised, "")
|
||||
|
||||
return node.GivenName == normalised
|
||||
}
|
||||
|
||||
|
|
@ -464,7 +465,7 @@ func (node *Node) IsSubnetRouter() bool {
|
|||
return len(node.SubnetRoutes()) > 0
|
||||
}
|
||||
|
||||
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes
|
||||
// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes.
|
||||
func (node *Node) AllApprovedRoutes() []netip.Prefix {
|
||||
return append(node.SubnetRoutes(), node.ExitRoutes()...)
|
||||
}
|
||||
|
|
@ -579,6 +580,7 @@ func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) {
|
|||
Str("rejected_hostname", hostInfo.Hostname).
|
||||
Err(err).
|
||||
Msg("Rejecting invalid hostname update from hostinfo")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
|
|
@ -670,6 +672,7 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
|
|||
func (nodes Nodes) DebugString() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("Nodes:\n")
|
||||
|
||||
for _, node := range nodes {
|
||||
sb.WriteString(node.DebugString())
|
||||
sb.WriteString("\n")
|
||||
|
|
|
|||
|
|
@ -128,6 +128,7 @@ func (pak *PreAuthKey) Validate() error {
|
|||
if pak.Expiration != nil {
|
||||
return *pak.Expiration
|
||||
}
|
||||
|
||||
return time.Time{}
|
||||
}()).
|
||||
Time("now", time.Now()).
|
||||
|
|
|
|||
|
|
@ -40,9 +40,11 @@ var TaggedDevices = User{
|
|||
func (u Users) String() string {
|
||||
var sb strings.Builder
|
||||
sb.WriteString("[ ")
|
||||
|
||||
for _, user := range u {
|
||||
fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name)
|
||||
}
|
||||
|
||||
sb.WriteString(" ]")
|
||||
|
||||
return sb.String()
|
||||
|
|
@ -89,6 +91,7 @@ func (u *User) StringID() string {
|
|||
if u == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strconv.FormatUint(uint64(u.ID), 10)
|
||||
}
|
||||
|
||||
|
|
@ -203,6 +206,7 @@ type FlexibleBoolean bool
|
|||
|
||||
func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
|
||||
var val any
|
||||
|
||||
err := json.Unmarshal(data, &val)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not unmarshal data: %w", err)
|
||||
|
|
@ -216,6 +220,7 @@ func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error {
|
|||
if err != nil {
|
||||
return fmt.Errorf("could not parse %s as boolean: %w", v, err)
|
||||
}
|
||||
|
||||
*bit = FlexibleBoolean(pv)
|
||||
|
||||
default:
|
||||
|
|
@ -253,9 +258,11 @@ func (c *OIDCClaims) Identifier() string {
|
|||
if c.Iss == "" && c.Sub == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
if c.Iss == "" {
|
||||
return CleanIdentifier(c.Sub)
|
||||
}
|
||||
|
||||
if c.Sub == "" {
|
||||
return CleanIdentifier(c.Iss)
|
||||
}
|
||||
|
|
@ -340,6 +347,7 @@ func CleanIdentifier(identifier string) string {
|
|||
cleanParts = append(cleanParts, trimmed)
|
||||
}
|
||||
}
|
||||
|
||||
if len(cleanParts) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
|
@ -382,6 +390,7 @@ func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) {
|
|||
if claims.Iss == "" && !strings.HasPrefix(identifier, "/") {
|
||||
identifier = "/" + identifier
|
||||
}
|
||||
|
||||
u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true}
|
||||
u.DisplayName = claims.Name
|
||||
u.ProfilePicURL = claims.ProfilePictureURL
|
||||
|
|
|
|||
|
|
@ -70,6 +70,7 @@ func TestUnmarshallOIDCClaims(t *testing.T) {
|
|||
t.Errorf("UnmarshallOIDCClaims() error = %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -190,6 +191,7 @@ func TestOIDCClaimsIdentifier(t *testing.T) {
|
|||
}
|
||||
result := claims.Identifier()
|
||||
assert.Equal(t, tt.expected, result)
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("Identifier() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -282,6 +284,7 @@ func TestCleanIdentifier(t *testing.T) {
|
|||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := CleanIdentifier(tt.identifier)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
|
||||
if diff := cmp.Diff(tt.expected, result); diff != "" {
|
||||
t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
@ -487,6 +490,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
|
|||
var user User
|
||||
|
||||
user.FromClaim(&got, tt.emailVerifiedRequired)
|
||||
|
||||
if diff := cmp.Diff(user, tt.want); diff != "" {
|
||||
t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -90,6 +90,7 @@ func TestNormaliseHostname(t *testing.T) {
|
|||
t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if !tt.wantErr && got != tt.want {
|
||||
t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want)
|
||||
}
|
||||
|
|
@ -172,6 +173,7 @@ func TestValidateHostname(t *testing.T) {
|
|||
t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
|
||||
if tt.wantErr && tt.errorContains != "" {
|
||||
if err == nil || !strings.Contains(err.Error(), tt.errorContains) {
|
||||
t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains)
|
||||
|
|
|
|||
|
|
@ -15,10 +15,12 @@ func YesNo(msg string) bool {
|
|||
|
||||
var resp string
|
||||
fmt.Scanln(&resp)
|
||||
|
||||
resp = strings.ToLower(resp)
|
||||
switch resp {
|
||||
case "y", "yes", "sure":
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ func TestYesNo(t *testing.T) {
|
|||
// Write test input
|
||||
go func() {
|
||||
defer w.Close()
|
||||
|
||||
w.WriteString(tt.input)
|
||||
}()
|
||||
|
||||
|
|
@ -95,6 +96,7 @@ func TestYesNo(t *testing.T) {
|
|||
// Restore stdin and stderr
|
||||
os.Stdin = oldStdin
|
||||
os.Stderr = oldStderr
|
||||
|
||||
stderrW.Close()
|
||||
|
||||
// Check the result
|
||||
|
|
@ -108,6 +110,7 @@ func TestYesNo(t *testing.T) {
|
|||
stderrR.Close()
|
||||
|
||||
expectedPrompt := "Test question [y/n] "
|
||||
|
||||
actualPrompt := stderrBuf.String()
|
||||
if actualPrompt != expectedPrompt {
|
||||
t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt)
|
||||
|
|
@ -130,6 +133,7 @@ func TestYesNoPromptMessage(t *testing.T) {
|
|||
// Write test input
|
||||
go func() {
|
||||
defer w.Close()
|
||||
|
||||
w.WriteString("n\n")
|
||||
}()
|
||||
|
||||
|
|
@ -140,6 +144,7 @@ func TestYesNoPromptMessage(t *testing.T) {
|
|||
// Restore stdin and stderr
|
||||
os.Stdin = oldStdin
|
||||
os.Stderr = oldStderr
|
||||
|
||||
stderrW.Close()
|
||||
|
||||
// Check that the custom message was included in the prompt
|
||||
|
|
@ -148,6 +153,7 @@ func TestYesNoPromptMessage(t *testing.T) {
|
|||
stderrR.Close()
|
||||
|
||||
expectedPrompt := customMessage + " [y/n] "
|
||||
|
||||
actualPrompt := stderrBuf.String()
|
||||
if actualPrompt != expectedPrompt {
|
||||
t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt)
|
||||
|
|
@ -186,6 +192,7 @@ func TestYesNoCaseInsensitive(t *testing.T) {
|
|||
// Write test input
|
||||
go func() {
|
||||
defer w.Close()
|
||||
|
||||
w.WriteString(tc.input)
|
||||
}()
|
||||
|
||||
|
|
@ -195,6 +202,7 @@ func TestYesNoCaseInsensitive(t *testing.T) {
|
|||
// Restore stdin and stderr
|
||||
os.Stdin = oldStdin
|
||||
os.Stderr = oldStderr
|
||||
|
||||
stderrW.Close()
|
||||
|
||||
// Drain stderr
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@ func GenerateRandomStringURLSafe(n int) (string, error) {
|
|||
b, err := GenerateRandomBytes(n)
|
||||
|
||||
uenc := base64.RawURLEncoding.EncodeToString(b)
|
||||
|
||||
return uenc[:n], err
|
||||
}
|
||||
|
||||
|
|
@ -99,6 +100,7 @@ func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string {
|
|||
DstIPs: %v
|
||||
}
|
||||
`, rule.SrcIPs, rule.DstPorts))
|
||||
|
||||
if index < len(rules)-1 {
|
||||
sb.WriteString(", ")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool {
|
|||
// It returns an error if not exactly one URL is found.
|
||||
func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
|
||||
lines := strings.Split(output, "\n")
|
||||
|
||||
var urlStr string
|
||||
|
||||
for _, line := range lines {
|
||||
|
|
@ -38,6 +39,7 @@ func ParseLoginURLFromCLILogin(output string) (*url.URL, error) {
|
|||
if urlStr != "" {
|
||||
return nil, fmt.Errorf("multiple URLs found: %s and %s", urlStr, line)
|
||||
}
|
||||
|
||||
urlStr = line
|
||||
}
|
||||
}
|
||||
|
|
@ -94,6 +96,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
|||
|
||||
// Parse the header line - handle both 'traceroute' and 'tracert' (Windows)
|
||||
headerRegex := regexp.MustCompile(`(?i)(?:traceroute|tracing route) to ([^ ]+) (?:\[([^\]]+)\]|\(([^)]+)\))`)
|
||||
|
||||
headerMatches := headerRegex.FindStringSubmatch(lines[0])
|
||||
if len(headerMatches) < 2 {
|
||||
return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0])
|
||||
|
|
@ -105,6 +108,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
|||
if ipStr == "" {
|
||||
ipStr = headerMatches[3]
|
||||
}
|
||||
|
||||
ip, err := netip.ParseAddr(ipStr)
|
||||
if err != nil {
|
||||
return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err)
|
||||
|
|
@ -144,13 +148,17 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
|||
}
|
||||
|
||||
remainder := strings.TrimSpace(matches[2])
|
||||
var hopHostname string
|
||||
var hopIP netip.Addr
|
||||
var latencies []time.Duration
|
||||
|
||||
var (
|
||||
hopHostname string
|
||||
hopIP netip.Addr
|
||||
latencies []time.Duration
|
||||
)
|
||||
|
||||
// Check for Windows tracert format which has latencies before hostname
|
||||
// Format: " 1 <1 ms <1 ms <1 ms router.local [192.168.1.1]"
|
||||
latencyFirst := false
|
||||
|
||||
if strings.Contains(remainder, " ms ") && !strings.HasPrefix(remainder, "*") {
|
||||
// Check if latencies appear before any hostname/IP
|
||||
firstSpace := strings.Index(remainder, " ")
|
||||
|
|
@ -171,12 +179,14 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
|||
}
|
||||
// Extract and remove the latency from the beginning
|
||||
latStr := strings.TrimPrefix(remainder[latMatch[2]:latMatch[3]], "<")
|
||||
|
||||
ms, err := strconv.ParseFloat(latStr, 64)
|
||||
if err == nil {
|
||||
// Round to nearest microsecond to avoid floating point precision issues
|
||||
duration := time.Duration(ms * float64(time.Millisecond))
|
||||
latencies = append(latencies, duration.Round(time.Microsecond))
|
||||
}
|
||||
|
||||
remainder = strings.TrimSpace(remainder[latMatch[1]:])
|
||||
}
|
||||
}
|
||||
|
|
@ -205,6 +215,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
|||
if ip, err := netip.ParseAddr(parts[0]); err == nil {
|
||||
hopIP = ip
|
||||
}
|
||||
|
||||
remainder = strings.TrimSpace(strings.Join(parts[1:], " "))
|
||||
}
|
||||
}
|
||||
|
|
@ -216,6 +227,7 @@ func ParseTraceroute(output string) (Traceroute, error) {
|
|||
if len(match) > 1 {
|
||||
// Remove '<' prefix if present (e.g., "<1 ms")
|
||||
latStr := strings.TrimPrefix(match[1], "<")
|
||||
|
||||
ms, err := strconv.ParseFloat(latStr, 64)
|
||||
if err == nil {
|
||||
// Round to nearest microsecond to avoid floating point precision issues
|
||||
|
|
@ -280,11 +292,13 @@ func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) stri
|
|||
if key == "" {
|
||||
return "unknown-node"
|
||||
}
|
||||
|
||||
keyPrefix := key
|
||||
if len(key) > 8 {
|
||||
keyPrefix = key[:8]
|
||||
}
|
||||
return fmt.Sprintf("node-%s", keyPrefix)
|
||||
|
||||
return "node-" + keyPrefix
|
||||
}
|
||||
|
||||
lowercased := strings.ToLower(hostinfo.Hostname)
|
||||
|
|
|
|||
|
|
@ -180,6 +180,7 @@ Success.`,
|
|||
if err != nil {
|
||||
t.Errorf("ParseLoginURLFromCLILogin() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
|
||||
if gotURL.String() != tt.wantURL {
|
||||
t.Errorf("ParseLoginURLFromCLILogin() = %v, want %v", gotURL, tt.wantURL)
|
||||
}
|
||||
|
|
@ -1066,6 +1067,7 @@ func TestEnsureHostname(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
|
||||
// For invalid hostnames, we just check the prefix since the random part varies
|
||||
if strings.HasPrefix(tt.want, "invalid-") {
|
||||
|
|
@ -1103,9 +1105,11 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
|
|||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
|
||||
if hi.Hostname != "test-node" {
|
||||
t.Errorf("hostname = %v, want test-node", hi.Hostname)
|
||||
}
|
||||
|
||||
if hi.OS != "linux" {
|
||||
t.Errorf("OS = %v, want linux", hi.OS)
|
||||
}
|
||||
|
|
@ -1147,6 +1151,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
|
|||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
|
||||
if hi.Hostname != "node-nkey1234" {
|
||||
t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname)
|
||||
}
|
||||
|
|
@ -1162,6 +1167,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
|
|||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
|
||||
if hi.Hostname != "unknown-node" {
|
||||
t.Errorf("hostname = %v, want unknown-node", hi.Hostname)
|
||||
}
|
||||
|
|
@ -1179,6 +1185,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
|
|||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
|
||||
if hi.Hostname != "unknown-node" {
|
||||
t.Errorf("hostname = %v, want unknown-node", hi.Hostname)
|
||||
}
|
||||
|
|
@ -1200,18 +1207,23 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
|
|||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
|
||||
if hi.Hostname != "test" {
|
||||
t.Errorf("hostname = %v, want test", hi.Hostname)
|
||||
}
|
||||
|
||||
if hi.OS != "windows" {
|
||||
t.Errorf("OS = %v, want windows", hi.OS)
|
||||
}
|
||||
|
||||
if hi.OSVersion != "10.0.19044" {
|
||||
t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion)
|
||||
}
|
||||
|
||||
if hi.DeviceModel != "test-device" {
|
||||
t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel)
|
||||
}
|
||||
|
||||
if hi.BackendLogID != "log123" {
|
||||
t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID)
|
||||
}
|
||||
|
|
@ -1229,6 +1241,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
|
|||
if hi == nil {
|
||||
t.Error("hostinfo should not be nil")
|
||||
}
|
||||
|
||||
if len(hi.Hostname) != 63 {
|
||||
t.Errorf("hostname length = %v, want 63", len(hi.Hostname))
|
||||
}
|
||||
|
|
@ -1239,6 +1252,7 @@ func TestEnsureHostnameWithHostinfo(t *testing.T) {
|
|||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey)
|
||||
// For invalid hostnames, we just check the prefix since the random part varies
|
||||
if strings.HasPrefix(tt.wantHostname, "invalid-") {
|
||||
|
|
@ -1265,6 +1279,7 @@ func TestEnsureHostname_DNSLabelLimit(t *testing.T) {
|
|||
for i, hostname := range testCases {
|
||||
t.Run(cmp.Diff("", ""), func(t *testing.T) {
|
||||
hostinfo := &tailcfg.Hostinfo{Hostname: hostname}
|
||||
|
||||
result := EnsureHostname(hostinfo, "mkey", "nkey")
|
||||
if len(result) > 63 {
|
||||
t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result))
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -46,6 +47,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
|
||||
// Create an API key using the CLI
|
||||
var validAPIKey string
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
apiKeyOutput, err := headscale.Execute(
|
||||
[]string{
|
||||
|
|
@ -63,7 +65,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
|
||||
// Get the API endpoint
|
||||
endpoint := headscale.GetEndpoint()
|
||||
apiURL := fmt.Sprintf("%s/api/v1/user", endpoint)
|
||||
apiURL := endpoint + "/api/v1/user"
|
||||
|
||||
// Create HTTP client
|
||||
client := &http.Client{
|
||||
|
|
@ -81,6 +83,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
|
@ -99,6 +102,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
// Should NOT contain user data after "Unauthorized"
|
||||
// This is the security bypass - if users array is present, auth was bypassed
|
||||
var jsonCheck map[string]any
|
||||
|
||||
jsonErr := json.Unmarshal(body, &jsonCheck)
|
||||
|
||||
// If we can unmarshal JSON and it contains "users", that's the bypass
|
||||
|
|
@ -132,6 +136,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
|
@ -165,6 +170,7 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
|
@ -193,10 +199,11 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
// Expected: Should return 200 with user data (this is the authorized case)
|
||||
req, err := http.NewRequest("GET", apiURL, nil)
|
||||
require.NoError(t, err)
|
||||
req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", validAPIKey))
|
||||
req.Header.Set("Authorization", "Bearer "+validAPIKey)
|
||||
|
||||
resp, err := client.Do(req)
|
||||
require.NoError(t, err)
|
||||
|
||||
defer resp.Body.Close()
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
|
|
@ -208,16 +215,19 @@ func TestAPIAuthenticationBypass(t *testing.T) {
|
|||
|
||||
// Should be able to parse as protobuf JSON
|
||||
var response v1.ListUsersResponse
|
||||
|
||||
err = protojson.Unmarshal(body, &response)
|
||||
assert.NoError(t, err, "Response should be valid protobuf JSON with valid API key")
|
||||
|
||||
// Should contain our test users
|
||||
users := response.GetUsers()
|
||||
assert.Len(t, users, 3, "Should have 3 users")
|
||||
|
||||
userNames := make([]string, len(users))
|
||||
for i, u := range users {
|
||||
userNames[i] = u.GetName()
|
||||
}
|
||||
|
||||
assert.Contains(t, userNames, "user1")
|
||||
assert.Contains(t, userNames, "user2")
|
||||
assert.Contains(t, userNames, "user3")
|
||||
|
|
@ -234,6 +244,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -254,10 +265,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
validAPIKey := strings.TrimSpace(apiKeyOutput)
|
||||
|
||||
endpoint := headscale.GetEndpoint()
|
||||
apiURL := fmt.Sprintf("%s/api/v1/user", endpoint)
|
||||
apiURL := endpoint + "/api/v1/user"
|
||||
|
||||
t.Run("Curl_NoAuth", func(t *testing.T) {
|
||||
// Execute curl from inside the headscale container without auth
|
||||
|
|
@ -274,16 +286,21 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||
|
||||
// Parse the output
|
||||
lines := strings.Split(curlOutput, "\n")
|
||||
var httpCode string
|
||||
var responseBody string
|
||||
|
||||
var (
|
||||
httpCode string
|
||||
responseBody string
|
||||
)
|
||||
|
||||
var responseBodySb295 strings.Builder
|
||||
for _, line := range lines {
|
||||
if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
|
||||
httpCode = after
|
||||
} else {
|
||||
responseBody += line
|
||||
responseBodySb295.WriteString(line)
|
||||
}
|
||||
}
|
||||
responseBody += responseBodySb295.String()
|
||||
|
||||
// Should return 401
|
||||
assert.Equal(t, "401", httpCode,
|
||||
|
|
@ -320,16 +337,21 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
lines := strings.Split(curlOutput, "\n")
|
||||
var httpCode string
|
||||
var responseBody string
|
||||
|
||||
var (
|
||||
httpCode string
|
||||
responseBody string
|
||||
)
|
||||
|
||||
var responseBodySb344 strings.Builder
|
||||
for _, line := range lines {
|
||||
if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
|
||||
httpCode = after
|
||||
} else {
|
||||
responseBody += line
|
||||
responseBodySb344.WriteString(line)
|
||||
}
|
||||
}
|
||||
responseBody += responseBodySb344.String()
|
||||
|
||||
assert.Equal(t, "401", httpCode)
|
||||
assert.Contains(t, responseBody, "Unauthorized")
|
||||
|
|
@ -346,7 +368,7 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||
"curl",
|
||||
"-s",
|
||||
"-H",
|
||||
fmt.Sprintf("Authorization: Bearer %s", validAPIKey),
|
||||
"Authorization: Bearer " + validAPIKey,
|
||||
"-w",
|
||||
"\nHTTP_CODE:%{http_code}",
|
||||
apiURL,
|
||||
|
|
@ -355,8 +377,11 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||
require.NoError(t, err)
|
||||
|
||||
lines := strings.Split(curlOutput, "\n")
|
||||
var httpCode string
|
||||
var responseBody strings.Builder
|
||||
|
||||
var (
|
||||
httpCode string
|
||||
responseBody strings.Builder
|
||||
)
|
||||
|
||||
for _, line := range lines {
|
||||
if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok {
|
||||
|
|
@ -372,8 +397,10 @@ func TestAPIAuthenticationBypassCurl(t *testing.T) {
|
|||
|
||||
// Should contain user data
|
||||
var response v1.ListUsersResponse
|
||||
|
||||
err = protojson.Unmarshal([]byte(responseBody.String()), &response)
|
||||
assert.NoError(t, err, "Response should be valid protobuf JSON")
|
||||
|
||||
users := response.GetUsers()
|
||||
assert.Len(t, users, 2, "Should have 2 users")
|
||||
})
|
||||
|
|
@ -391,6 +418,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -420,11 +448,12 @@ func TestGRPCAuthenticationBypass(t *testing.T) {
|
|||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
validAPIKey := strings.TrimSpace(apiKeyOutput)
|
||||
|
||||
// Get the gRPC endpoint
|
||||
// For gRPC, we need to use the hostname and port 50443
|
||||
grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname())
|
||||
grpcAddress := headscale.GetHostname() + ":50443"
|
||||
|
||||
t.Run("gRPC_NoAPIKey", func(t *testing.T) {
|
||||
// Test 1: Try to use CLI without API key (should fail)
|
||||
|
|
@ -487,6 +516,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) {
|
|||
// CLI outputs the users array directly, not wrapped in ListUsersResponse
|
||||
// Parse as JSON array (CLI uses json.Marshal, not protojson)
|
||||
var users []*v1.User
|
||||
|
||||
err = json.Unmarshal([]byte(output), &users)
|
||||
assert.NoError(t, err, "Response should be valid JSON array")
|
||||
assert.Len(t, users, 2, "Should have 2 users")
|
||||
|
|
@ -495,6 +525,7 @@ func TestGRPCAuthenticationBypass(t *testing.T) {
|
|||
for i, u := range users {
|
||||
userNames[i] = u.GetName()
|
||||
}
|
||||
|
||||
assert.Contains(t, userNames, "grpcuser1")
|
||||
assert.Contains(t, userNames, "grpcuser2")
|
||||
})
|
||||
|
|
@ -513,6 +544,7 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -540,9 +572,10 @@ func TestCLIWithConfigAuthenticationBypass(t *testing.T) {
|
|||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
validAPIKey := strings.TrimSpace(apiKeyOutput)
|
||||
|
||||
grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname())
|
||||
grpcAddress := headscale.GetHostname() + ":50443"
|
||||
|
||||
// Create a config file for testing
|
||||
configWithoutKey := fmt.Sprintf(`
|
||||
|
|
@ -643,6 +676,7 @@ cli:
|
|||
// CLI outputs the users array directly, not wrapped in ListUsersResponse
|
||||
// Parse as JSON array (CLI uses json.Marshal, not protojson)
|
||||
var users []*v1.User
|
||||
|
||||
err = json.Unmarshal([]byte(output), &users)
|
||||
assert.NoError(t, err, "Response should be valid JSON array")
|
||||
assert.Len(t, users, 2, "Should have 2 users")
|
||||
|
|
@ -651,6 +685,7 @@ cli:
|
|||
for i, u := range users {
|
||||
userNames[i] = u.GetName()
|
||||
}
|
||||
|
||||
assert.Contains(t, userNames, "cliuser1")
|
||||
assert.Contains(t, userNames, "cliuser2")
|
||||
})
|
||||
|
|
|
|||
|
|
@ -30,6 +30,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -68,18 +69,24 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
|||
// assertClientsState(t, allClients)
|
||||
|
||||
clientIPs := make(map[TailscaleClient][]netip.Addr)
|
||||
|
||||
for _, client := range allClients {
|
||||
ips, err := client.IPs()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
clientIPs[client] = ips
|
||||
}
|
||||
|
||||
var listNodes []*v1.Node
|
||||
var nodeCountBeforeLogout int
|
||||
var (
|
||||
listNodes []*v1.Node
|
||||
nodeCountBeforeLogout int
|
||||
)
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, listNodes, len(allClients))
|
||||
|
|
@ -110,6 +117,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
|||
t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes after logout")
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
|
||||
|
|
@ -147,6 +155,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
|||
t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes after relogin")
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
|
||||
|
|
@ -200,6 +209,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
|
|||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, listNodes, nodeCountBeforeLogout)
|
||||
|
|
@ -254,10 +264,14 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
|||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
|
||||
|
||||
var listNodes []*v1.Node
|
||||
var nodeCountBeforeLogout int
|
||||
var (
|
||||
listNodes []*v1.Node
|
||||
nodeCountBeforeLogout int
|
||||
)
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, listNodes, len(allClients))
|
||||
|
|
@ -300,9 +314,11 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
|||
}
|
||||
|
||||
var user1Nodes []*v1.Node
|
||||
|
||||
t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
user1Nodes, err = headscale.ListNodes("user1")
|
||||
assert.NoError(ct, err, "Failed to list nodes for user1 after relogin")
|
||||
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes))
|
||||
|
|
@ -322,15 +338,18 @@ func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) {
|
|||
// When nodes re-authenticate with a different user's pre-auth key, NEW nodes are created
|
||||
// for the new user. The original nodes remain with the original user.
|
||||
var user2Nodes []*v1.Node
|
||||
|
||||
t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
user2Nodes, err = headscale.ListNodes("user2")
|
||||
assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin")
|
||||
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes))
|
||||
}, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)")
|
||||
|
||||
t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat))
|
||||
|
||||
for _, client := range allClients {
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
|
|
@ -351,6 +370,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -376,11 +396,13 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
|||
// assertClientsState(t, allClients)
|
||||
|
||||
clientIPs := make(map[TailscaleClient][]netip.Addr)
|
||||
|
||||
for _, client := range allClients {
|
||||
ips, err := client.IPs()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err)
|
||||
}
|
||||
|
||||
clientIPs[client] = ips
|
||||
}
|
||||
|
||||
|
|
@ -394,10 +416,14 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
|
|||
requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
|
||||
requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
|
||||
|
||||
var listNodes []*v1.Node
|
||||
var nodeCountBeforeLogout int
|
||||
var (
|
||||
listNodes []*v1.Node
|
||||
nodeCountBeforeLogout int
|
||||
)
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, listNodes, len(allClients))
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -176,6 +177,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
|||
syncCompleteTime := time.Now()
|
||||
err = scenario.WaitForTailscaleSync()
|
||||
requireNoErrSync(t, err)
|
||||
|
||||
loginDuration := time.Since(syncCompleteTime)
|
||||
t.Logf("Login and sync completed in %v", loginDuration)
|
||||
|
||||
|
|
@ -207,6 +209,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
// Check each client's status individually to provide better diagnostics
|
||||
expiredCount := 0
|
||||
|
||||
for _, client := range allClients {
|
||||
status, err := client.Status()
|
||||
if assert.NoError(ct, err, "failed to get status for client %s", client.Hostname()) {
|
||||
|
|
@ -356,6 +359,7 @@ func TestOIDC024UserCreation(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -413,6 +417,7 @@ func TestOIDCAuthenticationWithPKCE(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -470,6 +475,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
oidcMockUser("user1", true),
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -508,6 +514,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
listUsers, err := headscale.ListUsers()
|
||||
assert.NoError(ct, err, "Failed to list users during initial validation")
|
||||
assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers))
|
||||
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
|
|
@ -528,9 +535,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
}, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login")
|
||||
|
||||
t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat))
|
||||
|
||||
var listNodes []*v1.Node
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes during initial validation")
|
||||
assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes))
|
||||
|
|
@ -538,14 +548,19 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
|
||||
// Collect expected node IDs for validation after user1 initial login
|
||||
expectedNodes := make([]types.NodeID, 0, 1)
|
||||
|
||||
var nodeID uint64
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status := ts.MustStatus()
|
||||
assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status")
|
||||
|
||||
var err error
|
||||
|
||||
nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64)
|
||||
assert.NoError(ct, err, "Failed to parse node ID from status")
|
||||
}, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login")
|
||||
|
||||
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
|
||||
|
||||
// Validate initial connection state for user1
|
||||
|
|
@ -583,6 +598,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
listUsers, err := headscale.ListUsers()
|
||||
assert.NoError(ct, err, "Failed to list users after user2 login")
|
||||
assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers))
|
||||
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
|
|
@ -638,10 +654,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
|
||||
// Security validation: Only user2's node should be active after user switch
|
||||
var activeUser2NodeID types.NodeID
|
||||
|
||||
for _, node := range listNodesAfterNewUserLogin {
|
||||
if node.GetUser().GetId() == 2 { // user2
|
||||
activeUser2NodeID = types.NodeID(node.GetId())
|
||||
t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName())
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -655,6 +673,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
// Check user2 node is online
|
||||
if node, exists := nodeStore[activeUser2NodeID]; exists {
|
||||
assert.NotNil(c, node.IsOnline, "User2 node should have online status")
|
||||
|
||||
if node.IsOnline != nil {
|
||||
assert.True(c, *node.IsOnline, "User2 node should be online after login")
|
||||
}
|
||||
|
|
@ -747,6 +766,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
listUsers, err := headscale.ListUsers()
|
||||
assert.NoError(ct, err, "Failed to list users during final validation")
|
||||
assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers))
|
||||
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
|
|
@ -816,10 +836,12 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
|
||||
// Security validation: Only user1's node should be active after relogin
|
||||
var activeUser1NodeID types.NodeID
|
||||
|
||||
for _, node := range listNodesAfterLoggingBackIn {
|
||||
if node.GetUser().GetId() == 1 { // user1
|
||||
activeUser1NodeID = types.NodeID(node.GetId())
|
||||
t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), node.GetUser().GetName())
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
|
@ -833,6 +855,7 @@ func TestOIDCReloginSameNodeNewUser(t *testing.T) {
|
|||
// Check user1 node is online
|
||||
if node, exists := nodeStore[activeUser1NodeID]; exists {
|
||||
assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin")
|
||||
|
||||
if node.IsOnline != nil {
|
||||
assert.True(c, *node.IsOnline, "User1 node should be online after relogin")
|
||||
}
|
||||
|
|
@ -907,6 +930,7 @@ func TestOIDCFollowUpUrl(t *testing.T) {
|
|||
time.Sleep(2 * time.Minute)
|
||||
|
||||
var newUrl *url.URL
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
st, err := ts.Status()
|
||||
assert.NoError(c, err)
|
||||
|
|
@ -1103,6 +1127,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
|
|||
oidcMockUser("user1", true), // Relogin with same user
|
||||
},
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -1142,6 +1167,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
|
|||
listUsers, err := headscale.ListUsers()
|
||||
assert.NoError(ct, err, "Failed to list users during initial validation")
|
||||
assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers))
|
||||
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
|
|
@ -1162,9 +1188,12 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
|
|||
}, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login")
|
||||
|
||||
t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat))
|
||||
|
||||
var initialNodes []*v1.Node
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
initialNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes during initial validation")
|
||||
assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes))
|
||||
|
|
@ -1172,14 +1201,19 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
|
|||
|
||||
// Collect expected node IDs for validation after user1 initial login
|
||||
expectedNodes := make([]types.NodeID, 0, 1)
|
||||
|
||||
var nodeID uint64
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status := ts.MustStatus()
|
||||
assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status")
|
||||
|
||||
var err error
|
||||
|
||||
nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64)
|
||||
assert.NoError(ct, err, "Failed to parse node ID from status")
|
||||
}, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login")
|
||||
|
||||
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
|
||||
|
||||
// Validate initial connection state for user1
|
||||
|
|
@ -1236,6 +1270,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
|
|||
listUsers, err := headscale.ListUsers()
|
||||
assert.NoError(ct, err, "Failed to list users during final validation")
|
||||
assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers))
|
||||
|
||||
wantUsers := []*v1.User{
|
||||
{
|
||||
Id: 1,
|
||||
|
|
@ -1256,6 +1291,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
|
|||
}, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin cycle")
|
||||
|
||||
var finalNodes []*v1.Node
|
||||
|
||||
t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
finalNodes, err = headscale.ListNodes()
|
||||
|
|
@ -1279,6 +1315,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
|
|||
|
||||
// Security validation: user1's node should be active after relogin
|
||||
activeUser1NodeID := types.NodeID(finalNodes[0].GetId())
|
||||
|
||||
t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat))
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
|
|
@ -1287,6 +1324,7 @@ func TestOIDCReloginSameNodeSameUser(t *testing.T) {
|
|||
// Check user1 node is online
|
||||
if node, exists := nodeStore[activeUser1NodeID]; exists {
|
||||
assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin")
|
||||
|
||||
if node.IsOnline != nil {
|
||||
assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin")
|
||||
}
|
||||
|
|
@ -1356,6 +1394,7 @@ func TestOIDCExpiryAfterRestart(t *testing.T) {
|
|||
|
||||
// Verify initial expiry is set
|
||||
var initialExpiry time.Time
|
||||
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
nodes, err := headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
|
|
|
|||
|
|
@ -67,6 +67,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -106,13 +107,16 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) {
|
|||
validateInitialConnection(t, headscale, expectedNodes)
|
||||
|
||||
var listNodes []*v1.Node
|
||||
|
||||
t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes after web authentication")
|
||||
assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes))
|
||||
}, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication")
|
||||
|
||||
nodeCountBeforeLogout := len(listNodes)
|
||||
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||
|
||||
|
|
@ -152,6 +156,7 @@ func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) {
|
|||
t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes after web flow logout")
|
||||
assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes))
|
||||
|
|
@ -226,6 +231,7 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -256,13 +262,16 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
|
|||
validateInitialConnection(t, headscale, expectedNodes)
|
||||
|
||||
var listNodes []*v1.Node
|
||||
|
||||
t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
listNodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err, "Failed to list nodes after initial web authentication")
|
||||
assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes))
|
||||
}, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication")
|
||||
|
||||
nodeCountBeforeLogout := len(listNodes)
|
||||
t.Logf("node count before logout: %d", nodeCountBeforeLogout)
|
||||
|
||||
|
|
@ -313,9 +322,11 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
|
|||
t.Logf("all clients logged back in as user1")
|
||||
|
||||
var user1Nodes []*v1.Node
|
||||
|
||||
t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
user1Nodes, err = headscale.ListNodes("user1")
|
||||
assert.NoError(ct, err, "Failed to list nodes for user1 after web flow relogin")
|
||||
assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes))
|
||||
|
|
@ -333,15 +344,18 @@ func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) {
|
|||
// Validate that user2's old nodes still exist in database (but are expired/offline)
|
||||
// When CLI registration creates new nodes for user1, user2's old nodes remain
|
||||
var user2Nodes []*v1.Node
|
||||
|
||||
t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat))
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
user2Nodes, err = headscale.ListNodes("user2")
|
||||
assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1")
|
||||
assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes))
|
||||
}, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1")
|
||||
|
||||
t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat))
|
||||
|
||||
for _, client := range allClients {
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
status, err := client.Status()
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
|
|||
// Generate random hostname for the headscale instance
|
||||
hash, err := util.GenerateRandomStringDNSSafe(6)
|
||||
require.NoError(t, err)
|
||||
|
||||
testName := "derpverify"
|
||||
hostname := fmt.Sprintf("hs-%s-%s", testName, hash)
|
||||
|
||||
|
|
@ -40,6 +41,7 @@ func TestDERPVerifyEndpoint(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -107,6 +109,7 @@ func DERPVerify(
|
|||
if err := c.Connect(t.Context()); err != nil {
|
||||
result = fmt.Errorf("client Connect: %w", err)
|
||||
}
|
||||
|
||||
if m, err := c.Recv(); err != nil {
|
||||
result = fmt.Errorf("client first Recv: %w", err)
|
||||
} else if v, ok := m.(derp.ServerInfoMessage); !ok {
|
||||
|
|
|
|||
|
|
@ -34,6 +34,7 @@ func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) {
|
|||
if opts.Labels == nil {
|
||||
opts.Labels = make(map[string]string)
|
||||
}
|
||||
|
||||
opts.Labels["hi.run-id"] = runID
|
||||
opts.Labels["hi.test-type"] = testType
|
||||
}
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ type buffer struct {
|
|||
func (b *buffer) Write(p []byte) (n int, err error) {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
return b.store.Write(p)
|
||||
}
|
||||
|
||||
|
|
@ -49,6 +50,7 @@ func (b *buffer) Write(p []byte) (n int, err error) {
|
|||
func (b *buffer) String() string {
|
||||
b.mutex.Lock()
|
||||
defer b.mutex.Unlock()
|
||||
|
||||
return b.store.String()
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -47,6 +47,7 @@ func SaveLog(
|
|||
}
|
||||
|
||||
var stdout, stderr bytes.Buffer
|
||||
|
||||
err = WriteLog(pool, resource, &stdout, &stderr)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (*dockertest.Ne
|
|||
if err != nil {
|
||||
return nil, fmt.Errorf("looking up network names: %w", err)
|
||||
}
|
||||
|
||||
if len(networks) == 0 {
|
||||
if _, err := pool.CreateNetwork(name); err == nil {
|
||||
// Create does not give us an updated version of the resource, so we need to
|
||||
|
|
@ -90,6 +91,7 @@ func RandomFreeHostPort() (int, error) {
|
|||
// CleanUnreferencedNetworks removes networks that are not referenced by any containers.
|
||||
func CleanUnreferencedNetworks(pool *dockertest.Pool) error {
|
||||
filter := "name=hs-"
|
||||
|
||||
networks, err := pool.NetworksByName(filter)
|
||||
if err != nil {
|
||||
return fmt.Errorf("getting networks by filter %q: %w", filter, err)
|
||||
|
|
@ -122,6 +124,7 @@ func CleanImagesInCI(pool *dockertest.Pool) error {
|
|||
}
|
||||
|
||||
removedCount := 0
|
||||
|
||||
for _, image := range images {
|
||||
// Only remove dangling (untagged) images to avoid forcing rebuilds
|
||||
// Dangling images have no RepoTags or only have "<none>:<none>"
|
||||
|
|
|
|||
|
|
@ -159,10 +159,12 @@ func New(
|
|||
} else {
|
||||
hostname = fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash)
|
||||
}
|
||||
|
||||
tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err)
|
||||
}
|
||||
|
||||
dsic := &DERPServerInContainer{
|
||||
version: version,
|
||||
hostname: hostname,
|
||||
|
|
@ -185,6 +187,7 @@ func New(
|
|||
fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort)
|
||||
fmt.Fprintf(&cmdArgs, " --stun=true")
|
||||
fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort)
|
||||
|
||||
if dsic.withVerifyClientURL != "" {
|
||||
fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL)
|
||||
}
|
||||
|
|
@ -214,11 +217,13 @@ func New(
|
|||
}
|
||||
|
||||
var container *dockertest.Resource
|
||||
|
||||
buildOptions := &dockertest.BuildOptions{
|
||||
Dockerfile: "Dockerfile.derper",
|
||||
ContextDir: dockerContextPath,
|
||||
BuildArgs: []docker.BuildArg{},
|
||||
}
|
||||
|
||||
switch version {
|
||||
case "head":
|
||||
buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{
|
||||
|
|
@ -249,6 +254,7 @@ func New(
|
|||
err,
|
||||
)
|
||||
}
|
||||
|
||||
log.Printf("Created %s container\n", hostname)
|
||||
|
||||
dsic.container = container
|
||||
|
|
@ -259,12 +265,14 @@ func New(
|
|||
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(dsic.tlsCert) != 0 {
|
||||
err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(dsic.tlsKey) != 0 {
|
||||
err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey)
|
||||
if err != nil {
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ package integration
|
|||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
|
|
@ -47,7 +48,7 @@ const (
|
|||
TimestampFormatRunID = "20060102-150405"
|
||||
)
|
||||
|
||||
// NodeSystemStatus represents the status of a node across different systems
|
||||
// NodeSystemStatus represents the status of a node across different systems.
|
||||
type NodeSystemStatus struct {
|
||||
Batcher bool
|
||||
BatcherConnCount int
|
||||
|
|
@ -104,7 +105,7 @@ func requireNoErrLogout(t *testing.T, err error) {
|
|||
require.NoError(t, err, "failed to log out tailscale nodes")
|
||||
}
|
||||
|
||||
// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes
|
||||
// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes.
|
||||
func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -113,8 +114,10 @@ func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.Nod
|
|||
status := client.MustStatus()
|
||||
nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedNodes = append(expectedNodes, types.NodeID(nodeID))
|
||||
}
|
||||
|
||||
return expectedNodes
|
||||
}
|
||||
|
||||
|
|
@ -148,15 +151,17 @@ func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNode
|
|||
}
|
||||
|
||||
// requireAllClientsOnline validates that all nodes are online/offline across all headscale systems
|
||||
// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems
|
||||
// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems.
|
||||
func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
stateStr := "offline"
|
||||
if expectedOnline {
|
||||
stateStr = "online"
|
||||
}
|
||||
|
||||
t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message)
|
||||
|
||||
if expectedOnline {
|
||||
|
|
@ -171,15 +176,17 @@ func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNode
|
|||
t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message)
|
||||
}
|
||||
|
||||
// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state
|
||||
// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state.
|
||||
func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
var prevReport string
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// Get batcher state
|
||||
debugInfo, err := headscale.DebugBatcher()
|
||||
assert.NoError(c, err, "Failed to get batcher debug info")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -187,6 +194,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
|
|||
// Get map responses
|
||||
mapResponses, err := headscale.GetAllMapReponses()
|
||||
assert.NoError(c, err, "Failed to get map responses")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -194,6 +202,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
|
|||
// Get nodestore state
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -264,6 +273,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
|
|||
if id == nodeID {
|
||||
continue // Skip self-references
|
||||
}
|
||||
|
||||
expectedPeerMaps++
|
||||
|
||||
if online, exists := peerMap[nodeID]; exists && online {
|
||||
|
|
@ -278,6 +288,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check")
|
||||
|
||||
// Update status with map response data
|
||||
|
|
@ -301,10 +312,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
|
|||
|
||||
// Verify all systems show nodes in expected state and report failures
|
||||
allMatch := true
|
||||
|
||||
var failureReport strings.Builder
|
||||
|
||||
ids := types.NodeIDs(maps.Keys(nodeStatus))
|
||||
slices.Sort(ids)
|
||||
|
||||
for _, nodeID := range ids {
|
||||
status := nodeStatus[nodeID]
|
||||
systemsMatch := (status.Batcher == expectedOnline) &&
|
||||
|
|
@ -313,10 +326,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
|
|||
|
||||
if !systemsMatch {
|
||||
allMatch = false
|
||||
|
||||
stateStr := "offline"
|
||||
if expectedOnline {
|
||||
stateStr = "online"
|
||||
}
|
||||
|
||||
failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat)))
|
||||
failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline))
|
||||
failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount))
|
||||
|
|
@ -331,6 +346,7 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
|
|||
t.Logf("Previous report:\n%s", prevReport)
|
||||
t.Logf("Current report:\n%s", failureReport.String())
|
||||
t.Logf("Report diff:\n%s", diff)
|
||||
|
||||
prevReport = failureReport.String()
|
||||
}
|
||||
|
||||
|
|
@ -344,11 +360,12 @@ func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlSer
|
|||
if expectedOnline {
|
||||
stateStr = "online"
|
||||
}
|
||||
|
||||
assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr))
|
||||
}, timeout, 2*time.Second, message)
|
||||
}
|
||||
|
||||
// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components
|
||||
// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components.
|
||||
func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) {
|
||||
t.Helper()
|
||||
|
||||
|
|
@ -357,18 +374,22 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
|
|||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
debugInfo, err := headscale.DebugBatcher()
|
||||
assert.NoError(c, err, "Failed to get batcher debug info")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
allBatcherOffline := true
|
||||
|
||||
for _, nodeID := range expectedNodes {
|
||||
nodeIDStr := fmt.Sprintf("%d", nodeID)
|
||||
if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected {
|
||||
allBatcherOffline = false
|
||||
|
||||
assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID)
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher")
|
||||
}, 15*time.Second, 1*time.Second, "batcher disconnection validation")
|
||||
|
||||
|
|
@ -377,20 +398,24 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
|
|||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
allNodeStoreOffline := true
|
||||
|
||||
for _, nodeID := range expectedNodes {
|
||||
if node, exists := nodeStore[nodeID]; exists {
|
||||
isOnline := node.IsOnline != nil && *node.IsOnline
|
||||
if isOnline {
|
||||
allNodeStoreOffline = false
|
||||
|
||||
assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore")
|
||||
}, 20*time.Second, 1*time.Second, "nodestore offline validation")
|
||||
|
||||
|
|
@ -399,6 +424,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
|
|||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
mapResponses, err := headscale.GetAllMapReponses()
|
||||
assert.NoError(c, err, "Failed to get map responses")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -411,6 +437,7 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
|
|||
for nodeID := range onlineMap {
|
||||
if slices.Contains(expectedNodes, nodeID) {
|
||||
allMapResponsesOffline = false
|
||||
|
||||
assert.False(c, true, "Node %d should not appear in map responses", nodeID)
|
||||
}
|
||||
}
|
||||
|
|
@ -421,13 +448,16 @@ func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expec
|
|||
if id == nodeID {
|
||||
continue // Skip self-references
|
||||
}
|
||||
|
||||
if online, exists := peerMap[nodeID]; exists && online {
|
||||
allMapResponsesOffline = false
|
||||
|
||||
assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses")
|
||||
}, 60*time.Second, 2*time.Second, "map response propagation validation")
|
||||
|
||||
|
|
@ -447,6 +477,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe
|
|||
// Get nodestore state
|
||||
nodeStore, err := headscale.DebugNodeStore()
|
||||
assert.NoError(c, err, "Failed to get nodestore debug info")
|
||||
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -461,12 +492,14 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe
|
|||
for _, nodeID := range expectedNodes {
|
||||
node, exists := nodeStore[nodeID]
|
||||
assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID)
|
||||
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
// Validate that the node has Hostinfo
|
||||
assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname)
|
||||
|
||||
if node.Hostinfo == nil {
|
||||
t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat))
|
||||
continue
|
||||
|
|
@ -474,6 +507,7 @@ func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expe
|
|||
|
||||
// Validate that the node has NetInfo
|
||||
assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname)
|
||||
|
||||
if node.Hostinfo.NetInfo == nil {
|
||||
t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat))
|
||||
continue
|
||||
|
|
@ -524,6 +558,7 @@ func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) {
|
|||
// Returns the total number of successful ping operations.
|
||||
func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int {
|
||||
t.Helper()
|
||||
|
||||
success := 0
|
||||
|
||||
for _, client := range clients {
|
||||
|
|
@ -545,6 +580,7 @@ func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts
|
|||
// for validating NAT traversal and relay functionality. Returns success count.
|
||||
func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int {
|
||||
t.Helper()
|
||||
|
||||
success := 0
|
||||
|
||||
for _, client := range clients {
|
||||
|
|
@ -602,9 +638,12 @@ func assertClientsState(t *testing.T, clients []TailscaleClient) {
|
|||
|
||||
for _, client := range clients {
|
||||
wg.Add(1)
|
||||
|
||||
c := client // Avoid loop pointer
|
||||
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
|
||||
assertValidStatus(t, c)
|
||||
assertValidNetcheck(t, c)
|
||||
assertValidNetmap(t, c)
|
||||
|
|
@ -635,6 +674,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
|||
assert.NoError(c, err, "getting netmap for %q", client.Hostname())
|
||||
|
||||
assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname())
|
||||
|
||||
if hi := netmap.SelfNode.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services())
|
||||
}
|
||||
|
|
@ -653,6 +693,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
|||
assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP())
|
||||
|
||||
assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname())
|
||||
|
||||
if hi := peer.Hostinfo(); hi.Valid() {
|
||||
assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services())
|
||||
|
||||
|
|
@ -681,6 +722,7 @@ func assertValidNetmap(t *testing.T, client TailscaleClient) {
|
|||
// and network map presence. This test is not suitable for ACL/partial connection tests.
|
||||
func assertValidStatus(t *testing.T, client TailscaleClient) {
|
||||
t.Helper()
|
||||
|
||||
status, err := client.Status(true)
|
||||
if err != nil {
|
||||
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
|
||||
|
|
@ -738,6 +780,7 @@ func assertValidStatus(t *testing.T, client TailscaleClient) {
|
|||
// which is essential for NAT traversal and connectivity in restricted networks.
|
||||
func assertValidNetcheck(t *testing.T, client TailscaleClient) {
|
||||
t.Helper()
|
||||
|
||||
report, err := client.Netcheck()
|
||||
if err != nil {
|
||||
t.Fatalf("getting status for %q: %s", client.Hostname(), err)
|
||||
|
|
@ -792,6 +835,7 @@ func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool {
|
|||
t.Helper()
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
err := client.WriteLogs(buf, buf)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err)
|
||||
|
|
@ -815,6 +859,7 @@ func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error)
|
|||
scanner := bufio.NewScanner(in)
|
||||
{
|
||||
const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB
|
||||
|
||||
buff := make([]byte, logBufferInitialSize)
|
||||
scanner.Buffer(buff, len(buff))
|
||||
scanner.Split(bufio.ScanLines)
|
||||
|
|
@ -941,17 +986,20 @@ func GetUserByName(headscale ControlServer, username string) (*v1.User, error) {
|
|||
func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) {
|
||||
for _, client := range updated {
|
||||
isOriginal := false
|
||||
|
||||
for _, origClient := range original {
|
||||
if client.Hostname() == origClient.Hostname() {
|
||||
isOriginal = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !isOriginal {
|
||||
return client, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("no new client found")
|
||||
|
||||
return nil, errors.New("no new client found")
|
||||
}
|
||||
|
||||
// AddAndLoginClient adds a new tailscale client to a user and logs it in.
|
||||
|
|
@ -959,7 +1007,7 @@ func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error)
|
|||
// 1. Creating a new node
|
||||
// 2. Finding the new node in the client list
|
||||
// 3. Getting the user to create a preauth key
|
||||
// 4. Logging in the new node
|
||||
// 4. Logging in the new node.
|
||||
func (s *Scenario) AddAndLoginClient(
|
||||
t *testing.T,
|
||||
username string,
|
||||
|
|
@ -1037,5 +1085,6 @@ func (s *Scenario) MustAddAndLoginClient(
|
|||
|
||||
client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...)
|
||||
require.NoError(t, err)
|
||||
|
||||
return client
|
||||
}
|
||||
|
|
|
|||
|
|
@ -725,12 +725,14 @@ func extractTarToDirectory(tarData []byte, targetDir string) error {
|
|||
|
||||
// Find the top-level directory to strip
|
||||
var topLevelDir string
|
||||
|
||||
firstPass := tar.NewReader(bytes.NewReader(tarData))
|
||||
for {
|
||||
header, err := firstPass.Next()
|
||||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read tar header: %w", err)
|
||||
}
|
||||
|
|
@ -747,6 +749,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error {
|
|||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read tar header: %w", err)
|
||||
}
|
||||
|
|
@ -794,6 +797,7 @@ func extractTarToDirectory(tarData []byte, targetDir string) error {
|
|||
outFile.Close()
|
||||
return fmt.Errorf("failed to copy file contents: %w", err)
|
||||
}
|
||||
|
||||
outFile.Close()
|
||||
|
||||
// Set file permissions
|
||||
|
|
@ -844,10 +848,12 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
|
|||
|
||||
// Check if the database file exists and has a schema
|
||||
dbPath := "/tmp/integration_test_db.sqlite3"
|
||||
|
||||
fileInfo, err := t.Execute([]string{"ls", "-la", dbPath})
|
||||
if err != nil {
|
||||
return fmt.Errorf("database file does not exist at %s: %w", dbPath, err)
|
||||
}
|
||||
|
||||
log.Printf("Database file info: %s", fileInfo)
|
||||
|
||||
// Check if the database has any tables (schema)
|
||||
|
|
@ -872,6 +878,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
|
|||
if err == io.EOF {
|
||||
break
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read tar header: %w", err)
|
||||
}
|
||||
|
|
@ -886,6 +893,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
|
|||
// Extract the first regular file we find
|
||||
if header.Typeflag == tar.TypeReg {
|
||||
dbPath := path.Join(savePath, t.hostname+".db")
|
||||
|
||||
outFile, err := os.Create(dbPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create database file: %w", err)
|
||||
|
|
@ -893,6 +901,7 @@ func (t *HeadscaleInContainer) SaveDatabase(savePath string) error {
|
|||
|
||||
written, err := io.Copy(outFile, tarReader)
|
||||
outFile.Close()
|
||||
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to copy database file: %w", err)
|
||||
}
|
||||
|
|
@ -1059,6 +1068,7 @@ func (t *HeadscaleInContainer) CreateUser(
|
|||
}
|
||||
|
||||
var u v1.User
|
||||
|
||||
err = json.Unmarshal([]byte(result), &u)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal user: %w", err)
|
||||
|
|
@ -1195,6 +1205,7 @@ func (t *HeadscaleInContainer) ListNodes(
|
|||
users ...string,
|
||||
) ([]*v1.Node, error) {
|
||||
var ret []*v1.Node
|
||||
|
||||
execUnmarshal := func(command []string) error {
|
||||
result, _, err := dockertestutil.ExecuteCommand(
|
||||
t.container,
|
||||
|
|
@ -1206,6 +1217,7 @@ func (t *HeadscaleInContainer) ListNodes(
|
|||
}
|
||||
|
||||
var nodes []*v1.Node
|
||||
|
||||
err = json.Unmarshal([]byte(result), &nodes)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unmarshal nodes: %w", err)
|
||||
|
|
@ -1245,7 +1257,7 @@ func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error {
|
|||
"nodes",
|
||||
"delete",
|
||||
"--identifier",
|
||||
fmt.Sprintf("%d", nodeID),
|
||||
strconv.FormatUint(nodeID, 10),
|
||||
"--output",
|
||||
"json",
|
||||
"--force",
|
||||
|
|
@ -1309,6 +1321,7 @@ func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) {
|
|||
}
|
||||
|
||||
var users []*v1.User
|
||||
|
||||
err = json.Unmarshal([]byte(result), &users)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal nodes: %w", err)
|
||||
|
|
@ -1439,6 +1452,7 @@ func (h *HeadscaleInContainer) PID() (int, error) {
|
|||
if pidInt == 1 {
|
||||
continue
|
||||
}
|
||||
|
||||
pids = append(pids, pidInt)
|
||||
}
|
||||
|
||||
|
|
@ -1494,6 +1508,7 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
|
|||
}
|
||||
|
||||
var node *v1.Node
|
||||
|
||||
err = json.Unmarshal([]byte(result), &node)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err)
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ func PeerSyncTimeout() time.Duration {
|
|||
if util.IsCI() {
|
||||
return 120 * time.Second
|
||||
}
|
||||
|
||||
return 60 * time.Second
|
||||
}
|
||||
|
||||
|
|
@ -205,6 +206,7 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type
|
|||
res := make(map[types.NodeID]map[types.NodeID]bool)
|
||||
for nid, mrs := range all {
|
||||
res[nid] = make(map[types.NodeID]bool)
|
||||
|
||||
for _, mr := range mrs {
|
||||
for _, peer := range mr.Peers {
|
||||
if peer.Online != nil {
|
||||
|
|
@ -225,5 +227,6 @@ func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[type
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
||||
|
|
|
|||
|
|
@ -48,6 +48,7 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoErrorf(t, err, "failed to create scenario: %s", err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -90,6 +91,7 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
// Wait for route advertisements to propagate to NodeStore
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
|
||||
|
|
@ -126,6 +128,7 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
// Wait for route approvals to propagate to NodeStore
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
|
||||
|
|
@ -148,9 +151,11 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
|
||||
assert.NotNil(c, peerStatus.PrimaryRoutes)
|
||||
assert.NotNil(c, peerStatus.AllowedIPs)
|
||||
|
||||
if peerStatus.AllowedIPs != nil {
|
||||
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3)
|
||||
}
|
||||
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])})
|
||||
}
|
||||
}
|
||||
|
|
@ -171,6 +176,7 @@ func TestEnablingRoutes(t *testing.T) {
|
|||
// Wait for route state changes to propagate to nodes
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
|
||||
|
|
@ -270,6 +276,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
|
||||
prefp, err := scenario.SubnetOfNetwork("usernet1")
|
||||
require.NoError(t, err)
|
||||
|
||||
pref := *prefp
|
||||
t.Logf("usernet1 prefix: %s", pref.String())
|
||||
|
||||
|
|
@ -289,6 +296,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
slices.SortStableFunc(allClients, func(a, b TailscaleClient) int {
|
||||
statusA := a.MustStatus()
|
||||
statusB := b.MustStatus()
|
||||
|
||||
return cmp.Compare(statusA.Self.ID, statusB.Self.ID)
|
||||
})
|
||||
|
||||
|
|
@ -308,6 +316,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
t.Logf(" - Router 2 (%s): Advertising route %s - will be STANDBY when approved", subRouter2.Hostname(), pref.String())
|
||||
t.Logf(" - Router 3 (%s): Advertising route %s - will be STANDBY when approved", subRouter3.Hostname(), pref.String())
|
||||
t.Logf(" Expected: All 3 routers advertise the same route for redundancy, but only one will be primary at a time")
|
||||
|
||||
for _, client := range allClients[:3] {
|
||||
command := []string{
|
||||
"tailscale",
|
||||
|
|
@ -323,6 +332,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
|
||||
// Wait for route configuration changes after advertising routes
|
||||
var nodes []*v1.Node
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
|
|
@ -362,10 +372,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
checkFailureAndPrintRoutes := func(t *testing.T, client TailscaleClient) {
|
||||
if t.Failed() {
|
||||
t.Logf("[%s] Test failed at this checkpoint", time.Now().Format(TimestampFormat))
|
||||
|
||||
status, err := client.Status()
|
||||
if err == nil {
|
||||
printCurrentRouteMap(t, xmaps.Values(status.Peer)...)
|
||||
}
|
||||
|
||||
t.FailNow()
|
||||
}
|
||||
}
|
||||
|
|
@ -384,6 +396,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
t.Logf(" Expected: Router 1 becomes PRIMARY with route %s active", pref.String())
|
||||
t.Logf(" Expected: Routers 2 & 3 remain with advertised but unapproved routes")
|
||||
t.Logf(" Expected: Client can access webservice through router 1 only")
|
||||
|
||||
_, err = headscale.ApproveRoutes(
|
||||
MustFindNode(subRouter1.Hostname(), nodes).GetId(),
|
||||
[]netip.Prefix{pref},
|
||||
|
|
@ -454,10 +467,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := subRouter1.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1")
|
||||
|
||||
|
|
@ -481,6 +496,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
t.Logf(" Expected: Router 2 becomes STANDBY (approved but not primary)")
|
||||
t.Logf(" Expected: Router 1 remains PRIMARY (no flapping - stability preferred)")
|
||||
t.Logf(" Expected: HA is now active - if router 1 fails, router 2 can take over")
|
||||
|
||||
_, err = headscale.ApproveRoutes(
|
||||
MustFindNode(subRouter2.Hostname(), nodes).GetId(),
|
||||
[]netip.Prefix{pref},
|
||||
|
|
@ -492,6 +508,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 6)
|
||||
|
||||
if len(nodes) >= 3 {
|
||||
requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1)
|
||||
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
|
||||
|
|
@ -567,10 +584,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := subRouter1.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 1 in HA mode")
|
||||
|
||||
|
|
@ -596,6 +615,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
t.Logf(" Expected: Router 3 becomes second STANDBY (approved but not primary)")
|
||||
t.Logf(" Expected: Router 1 remains PRIMARY, Router 2 remains first STANDBY")
|
||||
t.Logf(" Expected: Full HA configuration with 1 PRIMARY + 2 STANDBY routers")
|
||||
|
||||
_, err = headscale.ApproveRoutes(
|
||||
MustFindNode(subRouter3.Hostname(), nodes).GetId(),
|
||||
[]netip.Prefix{pref},
|
||||
|
|
@ -670,12 +690,14 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.NotEmpty(c, ips, "subRouter1 should have IP addresses")
|
||||
|
||||
var expectedIP netip.Addr
|
||||
|
||||
for _, ip := range ips {
|
||||
if ip.Is4() {
|
||||
expectedIP = ip
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address")
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, expectedIP)
|
||||
|
|
@ -752,10 +774,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := subRouter2.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after failover")
|
||||
|
||||
|
|
@ -823,10 +847,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := subRouter3.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after second failover")
|
||||
|
||||
|
|
@ -851,6 +877,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
t.Logf(" Expected: Router 3 remains PRIMARY (stability - no unnecessary failover)")
|
||||
t.Logf(" Expected: Router 1 becomes STANDBY (ready for HA)")
|
||||
t.Logf(" Expected: HA is restored with 2 routers available")
|
||||
|
||||
err = subRouter1.Up()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -900,10 +927,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := subRouter3.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 3 after router 1 recovery")
|
||||
|
||||
|
|
@ -930,6 +959,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
t.Logf(" Expected: Router 1 (%s) remains first STANDBY", subRouter1.Hostname())
|
||||
t.Logf(" Expected: Router 2 (%s) becomes second STANDBY", subRouter2.Hostname())
|
||||
t.Logf(" Expected: Full HA restored with all 3 routers online")
|
||||
|
||||
err = subRouter2.Up()
|
||||
require.NoError(t, err)
|
||||
|
||||
|
|
@ -980,10 +1010,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := subRouter3.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after full recovery")
|
||||
|
||||
|
|
@ -1065,10 +1097,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := subRouter1.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1 after route disable")
|
||||
|
||||
|
|
@ -1151,10 +1185,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := subRouter2.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after second route disable")
|
||||
|
||||
|
|
@ -1180,6 +1216,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability - no unnecessary flapping)", subRouter2.Hostname())
|
||||
t.Logf(" Expected: Router 1 (%s) becomes STANDBY (approved but not primary)", subRouter1.Hostname())
|
||||
t.Logf(" Expected: HA fully restored with Router 2 PRIMARY and Router 1 STANDBY")
|
||||
|
||||
r1Node := MustFindNode(subRouter1.Hostname(), nodes)
|
||||
_, err = headscale.ApproveRoutes(
|
||||
r1Node.GetId(),
|
||||
|
|
@ -1235,10 +1272,12 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := subRouter2.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 2 after route re-enable")
|
||||
|
||||
|
|
@ -1264,6 +1303,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
|
|||
t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability preferred)", subRouter2.Hostname())
|
||||
t.Logf(" Expected: Routers 1 & 3 are both STANDBY")
|
||||
t.Logf(" Expected: Full HA restored with all 3 routers available")
|
||||
|
||||
r3Node := MustFindNode(subRouter3.Hostname(), nodes)
|
||||
_, err = headscale.ApproveRoutes(
|
||||
r3Node.GetId(),
|
||||
|
|
@ -1313,6 +1353,7 @@ func TestSubnetRouteACL(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoErrorf(t, err, "failed to create scenario: %s", err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -1360,6 +1401,7 @@ func TestSubnetRouteACL(t *testing.T) {
|
|||
slices.SortStableFunc(allClients, func(a, b TailscaleClient) int {
|
||||
statusA := a.MustStatus()
|
||||
statusB := b.MustStatus()
|
||||
|
||||
return cmp.Compare(statusA.Self.ID, statusB.Self.ID)
|
||||
})
|
||||
|
||||
|
|
@ -1389,15 +1431,20 @@ func TestSubnetRouteACL(t *testing.T) {
|
|||
|
||||
// Wait for route advertisements to propagate to the server
|
||||
var nodes []*v1.Node
|
||||
|
||||
require.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
|
||||
// Find the node that should have the route by checking node IDs
|
||||
var routeNode *v1.Node
|
||||
var otherNode *v1.Node
|
||||
var (
|
||||
routeNode *v1.Node
|
||||
otherNode *v1.Node
|
||||
)
|
||||
|
||||
for _, node := range nodes {
|
||||
nodeIDStr := strconv.FormatUint(node.GetId(), 10)
|
||||
if _, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute {
|
||||
|
|
@ -1460,6 +1507,7 @@ func TestSubnetRouteACL(t *testing.T) {
|
|||
srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey]
|
||||
|
||||
assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist")
|
||||
|
||||
if srs1PeerStatus == nil {
|
||||
return
|
||||
}
|
||||
|
|
@ -1570,6 +1618,7 @@ func TestEnablingExitRoutes(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoErrorf(t, err, "failed to create scenario")
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -1591,8 +1640,10 @@ func TestEnablingExitRoutes(t *testing.T) {
|
|||
requireNoErrSync(t, err)
|
||||
|
||||
var nodes []*v1.Node
|
||||
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
|
|
@ -1650,6 +1701,7 @@ func TestEnablingExitRoutes(t *testing.T) {
|
|||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
assert.NotNil(c, peerStatus.AllowedIPs)
|
||||
|
||||
if peerStatus.AllowedIPs != nil {
|
||||
assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4)
|
||||
assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4())
|
||||
|
|
@ -1680,6 +1732,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoErrorf(t, err, "failed to create scenario: %s", err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -1710,10 +1763,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
|
|||
if s.User[s.Self.UserID].LoginName == "user1@test.no" {
|
||||
user1c = c
|
||||
}
|
||||
|
||||
if s.User[s.Self.UserID].LoginName == "user2@test.no" {
|
||||
user2c = c
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, user1c)
|
||||
require.NotNil(t, user2c)
|
||||
|
||||
|
|
@ -1730,6 +1785,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
|
|||
// Wait for route advertisements to propagate to NodeStore
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, nodes, 2)
|
||||
|
|
@ -1760,6 +1816,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
|
|||
// Wait for route state changes to propagate to nodes
|
||||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
|
|
@ -1777,6 +1834,7 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
|
|||
if peerStatus.PrimaryRoutes != nil {
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref)
|
||||
}
|
||||
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref})
|
||||
}
|
||||
}, 10*time.Second, 500*time.Millisecond, "routes should be visible to client")
|
||||
|
|
@ -1803,10 +1861,12 @@ func TestSubnetRouterMultiNetwork(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := user2c.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := user1c.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for user1c") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, 5*time.Second, 200*time.Millisecond, "Verifying traceroute goes through subnet router")
|
||||
}
|
||||
|
|
@ -1827,6 +1887,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoErrorf(t, err, "failed to create scenario: %s", err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -1854,10 +1915,12 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
|
|||
if s.User[s.Self.UserID].LoginName == "user1@test.no" {
|
||||
user1c = c
|
||||
}
|
||||
|
||||
if s.User[s.Self.UserID].LoginName == "user2@test.no" {
|
||||
user2c = c
|
||||
}
|
||||
}
|
||||
|
||||
require.NotNil(t, user1c)
|
||||
require.NotNil(t, user2c)
|
||||
|
||||
|
|
@ -1874,6 +1937,7 @@ func TestSubnetRouterMultiNetworkExitNode(t *testing.T) {
|
|||
// Wait for route advertisements to propagate to NodeStore
|
||||
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
|
||||
var err error
|
||||
|
||||
nodes, err = headscale.ListNodes()
|
||||
assert.NoError(ct, err)
|
||||
assert.Len(ct, nodes, 2)
|
||||
|
|
@ -1956,6 +2020,7 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node {
|
|||
return node
|
||||
}
|
||||
}
|
||||
|
||||
panic("node not found")
|
||||
}
|
||||
|
||||
|
|
@ -2239,10 +2304,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(tt.spec)
|
||||
|
||||
require.NoErrorf(t, err, "failed to create scenario: %s", err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
var nodes []*v1.Node
|
||||
|
||||
opts := []hsic.Option{
|
||||
hsic.WithTestName("autoapprovemulti"),
|
||||
hsic.WithEmbeddedDERPServerOnly(),
|
||||
|
|
@ -2298,6 +2365,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
// Add the Docker network route to the auto-approvers
|
||||
// Keep existing auto-approvers (like bigRoute) in place
|
||||
var approvers policyv2.AutoApprovers
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(tt.approver, "tag:"):
|
||||
approvers = append(approvers, tagApprover(tt.approver))
|
||||
|
|
@ -2366,6 +2434,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
} else {
|
||||
pak, err = scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
|
||||
|
|
@ -2404,6 +2473,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
slices.SortStableFunc(allClients, func(a, b TailscaleClient) int {
|
||||
statusA := a.MustStatus()
|
||||
statusB := b.MustStatus()
|
||||
|
||||
return cmp.Compare(statusA.Self.ID, statusB.Self.ID)
|
||||
})
|
||||
|
||||
|
|
@ -2456,11 +2526,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
t.Logf("Client %s sees %d peers", client.Hostname(), len(status.Peers()))
|
||||
|
||||
routerPeerFound := false
|
||||
|
||||
for _, peerKey := range status.Peers() {
|
||||
peerStatus := status.Peer[peerKey]
|
||||
|
||||
if peerStatus.ID == routerUsernet1ID.StableID() {
|
||||
routerPeerFound = true
|
||||
|
||||
t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v",
|
||||
peerStatus.HostName,
|
||||
peerStatus.ID,
|
||||
|
|
@ -2468,9 +2540,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
peerStatus.PrimaryRoutes)
|
||||
|
||||
assert.NotNil(c, peerStatus.PrimaryRoutes)
|
||||
|
||||
if peerStatus.PrimaryRoutes != nil {
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
|
||||
}
|
||||
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
|
||||
} else {
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
|
||||
|
|
@ -2507,10 +2581,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := routerUsernet1.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through auto-approved router")
|
||||
|
||||
|
|
@ -2547,9 +2623,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
|
||||
if peerStatus.ID == routerUsernet1ID.StableID() {
|
||||
assert.NotNil(c, peerStatus.PrimaryRoutes)
|
||||
|
||||
if peerStatus.PrimaryRoutes != nil {
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
|
||||
}
|
||||
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
|
||||
} else {
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
|
||||
|
|
@ -2569,10 +2647,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := routerUsernet1.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, assertTimeout, 200*time.Millisecond, "Verifying traceroute still goes through router after policy change")
|
||||
|
||||
|
|
@ -2606,6 +2686,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
// Add the route back to the auto approver in the policy, the route should
|
||||
// now become available again.
|
||||
var newApprovers policyv2.AutoApprovers
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(tt.approver, "tag:"):
|
||||
newApprovers = append(newApprovers, tagApprover(tt.approver))
|
||||
|
|
@ -2639,9 +2720,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
|
||||
if peerStatus.ID == routerUsernet1ID.StableID() {
|
||||
assert.NotNil(c, peerStatus.PrimaryRoutes)
|
||||
|
||||
if peerStatus.PrimaryRoutes != nil {
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
|
||||
}
|
||||
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
|
||||
} else {
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
|
||||
|
|
@ -2661,10 +2744,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := client.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := routerUsernet1.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through router after re-approval")
|
||||
|
||||
|
|
@ -2700,11 +2785,13 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
if peerStatus.PrimaryRoutes != nil {
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
|
||||
}
|
||||
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
|
||||
} else if peerStatus.ID == "2" {
|
||||
if peerStatus.PrimaryRoutes != nil {
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), subRoute)
|
||||
}
|
||||
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{subRoute})
|
||||
} else {
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
|
||||
|
|
@ -2742,9 +2829,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
|
||||
if peerStatus.ID == routerUsernet1ID.StableID() {
|
||||
assert.NotNil(c, peerStatus.PrimaryRoutes)
|
||||
|
||||
if peerStatus.PrimaryRoutes != nil {
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
|
||||
}
|
||||
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
|
||||
} else {
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
|
||||
|
|
@ -2782,6 +2871,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
|
|||
if peerStatus.PrimaryRoutes != nil {
|
||||
assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
|
||||
}
|
||||
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
|
||||
} else if peerStatus.ID == "3" {
|
||||
requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()})
|
||||
|
|
@ -2816,10 +2906,12 @@ func SortPeerStatus(a, b *ipnstate.PeerStatus) int {
|
|||
func printCurrentRouteMap(t *testing.T, routers ...*ipnstate.PeerStatus) {
|
||||
t.Logf("== Current routing map ==")
|
||||
slices.SortFunc(routers, SortPeerStatus)
|
||||
|
||||
for _, router := range routers {
|
||||
got := filterNonRoutes(router)
|
||||
t.Logf(" Router %s (%s) is serving:", router.HostName, router.ID)
|
||||
t.Logf(" AllowedIPs: %v", got)
|
||||
|
||||
if router.PrimaryRoutes != nil {
|
||||
t.Logf(" PrimaryRoutes: %v", router.PrimaryRoutes.AsSlice())
|
||||
}
|
||||
|
|
@ -2832,6 +2924,7 @@ func filterNonRoutes(status *ipnstate.PeerStatus) []netip.Prefix {
|
|||
if tsaddr.IsExitRoute(p) {
|
||||
return true
|
||||
}
|
||||
|
||||
return !slices.ContainsFunc(status.TailscaleIPs, p.Contains)
|
||||
})
|
||||
}
|
||||
|
|
@ -2883,6 +2976,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
|
|||
}
|
||||
|
||||
scenario, err := NewScenario(spec)
|
||||
|
||||
require.NoErrorf(t, err, "failed to create scenario: %s", err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -3023,6 +3117,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
// List nodes and verify the router has 3 available routes
|
||||
var err error
|
||||
|
||||
nodes, err := headscale.NodesByUser()
|
||||
assert.NoError(c, err)
|
||||
assert.Len(c, nodes, 2)
|
||||
|
|
@ -3058,10 +3153,12 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
|
|||
assert.EventuallyWithT(t, func(c *assert.CollectT) {
|
||||
tr, err := nodeClient.Traceroute(webip)
|
||||
assert.NoError(c, err)
|
||||
|
||||
ip, err := routerClient.IPv4()
|
||||
if !assert.NoError(c, err, "failed to get IPv4 for routerClient") {
|
||||
return
|
||||
}
|
||||
|
||||
assertTracerouteViaIPWithCollect(c, tr, ip)
|
||||
}, 60*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router")
|
||||
}
|
||||
|
|
|
|||
|
|
@ -191,9 +191,11 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
|
|||
}
|
||||
|
||||
var userToNetwork map[string]*dockertest.Network
|
||||
|
||||
if spec.Networks != nil || len(spec.Networks) != 0 {
|
||||
for name, users := range s.spec.Networks {
|
||||
networkName := testHashPrefix + "-" + name
|
||||
|
||||
network, err := s.AddNetwork(networkName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -203,6 +205,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
|
|||
if n2, ok := userToNetwork[user]; ok {
|
||||
return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name)
|
||||
}
|
||||
|
||||
mak.Set(&userToNetwork, user, network)
|
||||
}
|
||||
}
|
||||
|
|
@ -219,6 +222,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mak.Set(&s.extraServices, s.prefixedNetworkName(network), append(s.extraServices[s.prefixedNetworkName(network)], svc))
|
||||
}
|
||||
}
|
||||
|
|
@ -230,6 +234,7 @@ func NewScenario(spec ScenarioSpec) (*Scenario, error) {
|
|||
if spec.OIDCAccessTTL != 0 {
|
||||
ttl = spec.OIDCAccessTTL
|
||||
}
|
||||
|
||||
err = s.runMockOIDC(ttl, spec.OIDCUsers)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -268,6 +273,7 @@ func (s *Scenario) Networks() []*dockertest.Network {
|
|||
if len(s.networks) == 0 {
|
||||
panic("Scenario.Networks called with empty network list")
|
||||
}
|
||||
|
||||
return xmaps.Values(s.networks)
|
||||
}
|
||||
|
||||
|
|
@ -337,6 +343,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
|
|||
for userName, user := range s.users {
|
||||
for _, client := range user.Clients {
|
||||
log.Printf("removing client %s in user %s", client.Hostname(), userName)
|
||||
|
||||
stdoutPath, stderrPath, err := client.Shutdown()
|
||||
if err != nil {
|
||||
log.Printf("failed to tear down client: %s", err)
|
||||
|
|
@ -353,6 +360,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Unlock()
|
||||
|
||||
for _, derp := range s.derpServers {
|
||||
|
|
@ -373,6 +381,7 @@ func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) {
|
|||
|
||||
if s.mockOIDC.r != nil {
|
||||
s.mockOIDC.r.Close()
|
||||
|
||||
if err := s.mockOIDC.r.Close(); err != nil {
|
||||
log.Printf("failed to tear down oidc server: %s", err)
|
||||
}
|
||||
|
|
@ -552,6 +561,7 @@ func (s *Scenario) CreateTailscaleNode(
|
|||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
opts = append(opts,
|
||||
tsic.WithCACert(cert),
|
||||
tsic.WithHeadscaleName(hostname),
|
||||
|
|
@ -591,6 +601,7 @@ func (s *Scenario) CreateTailscaleNodesInUser(
|
|||
) error {
|
||||
if user, ok := s.users[userStr]; ok {
|
||||
var versions []string
|
||||
|
||||
for i := range count {
|
||||
version := requestedVersion
|
||||
if requestedVersion == "all" {
|
||||
|
|
@ -749,10 +760,12 @@ func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Durat
|
|||
for _, client := range user.Clients {
|
||||
c := client
|
||||
expectedCount := expectedPeers
|
||||
|
||||
user.syncWaitGroup.Go(func() error {
|
||||
return c.WaitForPeers(expectedCount, timeout, retryInterval)
|
||||
})
|
||||
}
|
||||
|
||||
if err := user.syncWaitGroup.Wait(); err != nil {
|
||||
allErrors = append(allErrors, err)
|
||||
}
|
||||
|
|
@ -871,6 +884,7 @@ func (s *Scenario) createHeadscaleEnvWithTags(
|
|||
} else {
|
||||
key, err = s.CreatePreAuthKey(u.GetId(), true, false)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
@ -887,9 +901,11 @@ func (s *Scenario) createHeadscaleEnvWithTags(
|
|||
|
||||
func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error {
|
||||
log.Printf("running tailscale up for user %s", userStr)
|
||||
|
||||
if user, ok := s.users[userStr]; ok {
|
||||
for _, client := range user.Clients {
|
||||
tsc := client
|
||||
|
||||
user.joinWaitGroup.Go(func() error {
|
||||
loginURL, err := tsc.LoginWithURL(loginServer)
|
||||
if err != nil {
|
||||
|
|
@ -945,6 +961,7 @@ func newDebugJar() (*debugJar, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &debugJar{
|
||||
inner: jar,
|
||||
store: make(map[string]map[string]map[string]*http.Cookie),
|
||||
|
|
@ -961,20 +978,25 @@ func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) {
|
|||
if c == nil || c.Name == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
domain := c.Domain
|
||||
if domain == "" {
|
||||
domain = u.Hostname()
|
||||
}
|
||||
|
||||
path := c.Path
|
||||
if path == "" {
|
||||
path = "/"
|
||||
}
|
||||
|
||||
if _, ok := j.store[domain]; !ok {
|
||||
j.store[domain] = make(map[string]map[string]*http.Cookie)
|
||||
}
|
||||
|
||||
if _, ok := j.store[domain][path]; !ok {
|
||||
j.store[domain][path] = make(map[string]*http.Cookie)
|
||||
}
|
||||
|
||||
j.store[domain][path][c.Name] = copyCookie(c)
|
||||
}
|
||||
}
|
||||
|
|
@ -989,8 +1011,10 @@ func (j *debugJar) Dump(w io.Writer) {
|
|||
|
||||
for domain, paths := range j.store {
|
||||
fmt.Fprintf(w, "Domain: %s\n", domain)
|
||||
|
||||
for path, byName := range paths {
|
||||
fmt.Fprintf(w, " Path: %s\n", path)
|
||||
|
||||
for _, c := range byName {
|
||||
fmt.Fprintf(
|
||||
w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n",
|
||||
|
|
@ -1054,7 +1078,9 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f
|
|||
}
|
||||
|
||||
log.Printf("%s logging in with url: %s", hostname, loginURL.String())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err)
|
||||
|
|
@ -1066,6 +1092,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f
|
|||
return http.ErrUseLastResponse
|
||||
}
|
||||
}
|
||||
|
||||
defer func() {
|
||||
hc.CheckRedirect = originalRedirect
|
||||
}()
|
||||
|
|
@ -1080,6 +1107,7 @@ func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, f
|
|||
if err != nil {
|
||||
return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err)
|
||||
}
|
||||
|
||||
body := string(bodyBytes)
|
||||
|
||||
var redirectURL *url.URL
|
||||
|
|
@ -1126,6 +1154,7 @@ func (s *Scenario) runHeadscaleRegister(userStr string, body string) error {
|
|||
if len(keySep) != 2 {
|
||||
return errParseAuthPage
|
||||
}
|
||||
|
||||
key := keySep[1]
|
||||
key = strings.SplitN(key, " ", 2)[0]
|
||||
log.Printf("registering node %s", key)
|
||||
|
|
@ -1154,6 +1183,7 @@ func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error
|
|||
noTls := &http.Transport{
|
||||
TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint
|
||||
}
|
||||
|
||||
resp, err := noTls.RoundTrip(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
|
@ -1361,6 +1391,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
|
|||
if err != nil {
|
||||
log.Fatalf("could not find an open port: %s", err)
|
||||
}
|
||||
|
||||
portNotation := fmt.Sprintf("%d/tcp", port)
|
||||
|
||||
hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
|
||||
|
|
@ -1421,6 +1452,7 @@ func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUse
|
|||
ipAddr := s.mockOIDC.r.GetIPInNetwork(network)
|
||||
|
||||
log.Println("Waiting for headscale mock oidc to be ready for tests")
|
||||
|
||||
hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port))
|
||||
|
||||
if err := s.pool.Retry(func() error {
|
||||
|
|
@ -1468,7 +1500,6 @@ func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) {
|
|||
// log.Fatalf("could not find an open port: %s", err)
|
||||
// }
|
||||
// portNotation := fmt.Sprintf("%d/tcp", port)
|
||||
|
||||
hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength)
|
||||
|
||||
hostname := "hs-webservice-" + hash
|
||||
|
|
|
|||
|
|
@ -35,6 +35,7 @@ func TestHeadscale(t *testing.T) {
|
|||
user := "test-space"
|
||||
|
||||
scenario, err := NewScenario(ScenarioSpec{})
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
@ -83,6 +84,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) {
|
|||
count := 1
|
||||
|
||||
scenario, err := NewScenario(ScenarioSpec{})
|
||||
|
||||
require.NoError(t, err)
|
||||
defer scenario.ShutdownAssertNoPanics(t)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue