Merge branch 'main' into fix/dns-override-local-2899

Rogan Lynch 2025-12-18 07:11:12 -08:00 committed by GitHub
commit 05e991180f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
91 changed files with 8807 additions and 3028 deletions

.editorconfig Normal file

@ -0,0 +1,16 @@
root = true
[*]
charset = utf-8
end_of_line = lf
indent_size = 2
indent_style = space
insert_final_newline = true
trim_trailing_whitespace = true
max_line_length = 120
[*.go]
indent_style = tab
[Makefile]
indent_style = tab


@ -15,7 +15,7 @@ jobs:
runs-on: ubuntu-latest
permissions: write-all
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 2
- name: Get changed files
@ -29,13 +29,12 @@ jobs:
- '**/*.go'
- 'integration_test/'
- 'config-example.yaml'
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
@ -55,7 +54,7 @@ jobs:
exit $BUILD_STATUS
- name: Nix gosum diverging
uses: actions/github-script@60a0d83039c74a4aee543508d2ffcb1c3799cdea # v7.0.1
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
if: failure() && steps.build.outcome == 'failure'
with:
github-token: ${{secrets.GITHUB_TOKEN}}
@ -67,7 +66,7 @@ jobs:
body: 'Nix build failed with wrong gosum, please update "vendorSha256" (${{ steps.build.outputs.OLD_HASH }}) for the "headscale" package in flake.nix with the new SHA: ${{ steps.build.outputs.NEW_HASH }}'
})
- uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
if: steps.changed-files.outputs.files == 'true'
with:
name: headscale-linux
@ -82,22 +81,20 @@ jobs:
- "GOARCH=arm64 GOOS=darwin"
- "GOARCH=amd64 GOOS=darwin"
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
- name: Run go cross compile
env:
CGO_ENABLED: 0
run:
env ${{ matrix.env }} nix develop --command -- go build -o "headscale"
run: env ${{ matrix.env }} nix develop --command -- go build -o "headscale"
./cmd/headscale
- uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
with:
name: "headscale-${{ matrix.env }}"
path: "headscale"


@ -16,7 +16,7 @@ jobs:
check-generated:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 2
- name: Get changed files
@ -31,7 +31,7 @@ jobs:
- '**/*.proto'
- 'buf.gen.yaml'
- 'tools/**'
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'


@ -10,7 +10,7 @@ jobs:
check-tests:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 2
- name: Get changed files
@ -24,13 +24,12 @@ jobs:
- '**/*.go'
- 'integration_test/'
- 'config-example.yaml'
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}


@ -21,15 +21,15 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 0
- name: Install python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
with:
python-version: 3.x
- name: Setup cache
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
uses: actions/cache@a7833574556fa59680c1b7cb190c1735db73ebf0 # v5.0.0
with:
key: ${{ github.ref }}
path: .cache


@ -11,13 +11,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- name: Install python
uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0
with:
python-version: 3.x
- name: Setup cache
uses: actions/cache@5a3ec84eff668545956fd18022155c47e93e2684 # v4.2.3
uses: actions/cache@a7833574556fa59680c1b7cb190c1735db73ebf0 # v5.0.0
with:
key: ${{ github.ref }}
path: .cache


@ -10,6 +10,55 @@ import (
"strings"
)
// testsToSplit defines tests that should be split into multiple CI jobs.
// Key is the test function name, value is a list of subtest prefixes.
// Each prefix becomes a separate CI job as "TestName/prefix".
//
// Example: TestAutoApproveMultiNetwork has subtests like:
// - TestAutoApproveMultiNetwork/authkey-tag-advertiseduringup-false-pol-database
// - TestAutoApproveMultiNetwork/webauth-user-advertiseduringup-true-pol-file
//
// Splitting by approver type (tag, user, group) creates 6 CI jobs with 4 tests each:
// - TestAutoApproveMultiNetwork/authkey-tag.* (4 tests)
// - TestAutoApproveMultiNetwork/authkey-user.* (4 tests)
// - TestAutoApproveMultiNetwork/authkey-group.* (4 tests)
// - TestAutoApproveMultiNetwork/webauth-tag.* (4 tests)
// - TestAutoApproveMultiNetwork/webauth-user.* (4 tests)
// - TestAutoApproveMultiNetwork/webauth-group.* (4 tests)
//
// This reduces load per CI job (4 tests instead of 12) to avoid infrastructure
// flakiness when running many sequential Docker-based integration tests.
var testsToSplit = map[string][]string{
"TestAutoApproveMultiNetwork": {
"authkey-tag",
"authkey-user",
"authkey-group",
"webauth-tag",
"webauth-user",
"webauth-group",
},
}
// expandTests takes a list of test names and expands any that need splitting
// into multiple subtest patterns.
func expandTests(tests []string) []string {
var expanded []string
for _, test := range tests {
if prefixes, ok := testsToSplit[test]; ok {
// This test should be split into multiple jobs.
// We append ".*" to each prefix because the CI runner wraps patterns
// with ^...$ anchors. Without ".*", a pattern like "authkey$" wouldn't
// match "authkey-tag-advertiseduringup-false-pol-database".
for _, prefix := range prefixes {
expanded = append(expanded, fmt.Sprintf("%s/%s.*", test, prefix))
}
} else {
expanded = append(expanded, test)
}
}
return expanded
}
func findTests() []string {
rgBin, err := exec.LookPath("rg")
if err != nil {
@ -66,8 +115,11 @@ func updateYAML(tests []string, jobName string, testPath string) {
func main() {
tests := findTests()
quotedTests := make([]string, len(tests))
for i, test := range tests {
// Expand tests that should be split into multiple jobs
expandedTests := expandTests(tests)
quotedTests := make([]string, len(expandedTests))
for i, test := range expandedTests {
quotedTests[i] = fmt.Sprintf("\"%s\"", test)
}
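To make the anchoring behaviour described in the comments above concrete, here is a small standalone Go sketch (not part of the generator) showing why the generated patterns need the trailing `.*` once the CI template wraps them in `^...$`:

```go
// Standalone illustration: the CI template anchors the test pattern as ^...$,
// so a bare prefix would only match an exact name, while the ".*" suffix
// produced by expandTests matches every subtest under that prefix.
package main

import (
	"fmt"
	"regexp"
)

func main() {
	subtest := "TestAutoApproveMultiNetwork/authkey-tag-advertiseduringup-false-pol-database"

	bare := regexp.MustCompile("^TestAutoApproveMultiNetwork/authkey-tag$")
	expanded := regexp.MustCompile("^TestAutoApproveMultiNetwork/authkey-tag.*$")

	fmt.Println(bare.MatchString(subtest))     // false
	fmt.Println(expanded.MatchString(subtest)) // true
}
```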


@ -11,13 +11,13 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
# [Required] Access token with `workflow` scope.
token: ${{ secrets.WORKFLOW_SECRET }}
- name: Run GitHub Actions Version Updater
uses: saadmk11/github-actions-version-updater@64be81ba69383f81f2be476703ea6570c4c8686e # v0.8.1
uses: saadmk11/github-actions-version-updater@d8781caf11d11168579c8e5e94f62b068038f442 # v0.9.0
with:
# [Required] Access token with `workflow` scope.
token: ${{ secrets.WORKFLOW_SECRET }}


@ -28,23 +28,12 @@ jobs:
# that triggered the build.
HAS_TAILSCALE_SECRET: ${{ secrets.TS_OAUTH_CLIENT_ID }}
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 2
- name: Get changed files
id: changed-files
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
with:
filters: |
files:
- '*.nix'
- 'go.*'
- '**/*.go'
- 'integration_test/'
- 'config-example.yaml'
- name: Tailscale
if: ${{ env.HAS_TAILSCALE_SECRET }}
uses: tailscale/github-action@6986d2c82a91fbac2949fe01f5bab95cf21b5102 # v3.2.2
uses: tailscale/github-action@a392da0a182bba0e9613b6243ebd69529b1878aa # v4.1.0
with:
oauth-client-id: ${{ secrets.TS_OAUTH_CLIENT_ID }}
oauth-secret: ${{ secrets.TS_OAUTH_SECRET }}
@ -52,31 +41,72 @@ jobs:
- name: Setup SSH server for Actor
if: ${{ env.HAS_TAILSCALE_SECRET }}
uses: alexellis/setup-sshd-actor@master
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'
- name: Download headscale image
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
name: headscale-image
path: /tmp/artifacts
- name: Download tailscale HEAD image
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
with:
name: tailscale-head-image
path: /tmp/artifacts
- name: Download hi binary
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
with:
name: hi-binary
path: /tmp/artifacts
- name: Download Go cache
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
with:
name: go-cache
path: /tmp/artifacts
- name: Download postgres image
if: ${{ inputs.postgres_flag == '--postgres=1' }}
uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0
with:
name: postgres-image
path: /tmp/artifacts
- name: Load Docker images, Go cache, and prepare binary
run: |
gunzip -c /tmp/artifacts/headscale-image.tar.gz | docker load
gunzip -c /tmp/artifacts/tailscale-head-image.tar.gz | docker load
if [ -f /tmp/artifacts/postgres-image.tar.gz ]; then
gunzip -c /tmp/artifacts/postgres-image.tar.gz | docker load
fi
chmod +x /tmp/artifacts/hi
docker images
# Extract Go cache to host directories for bind mounting
mkdir -p /tmp/go-cache
tar -xzf /tmp/artifacts/go-cache.tar.gz -C /tmp/go-cache
ls -la /tmp/go-cache/ /tmp/go-cache/.cache/
- name: Run Integration Test
if: always() && steps.changed-files.outputs.files == 'true'
run:
nix develop --command -- hi run --stats --ts-memory-limit=300 --hs-memory-limit=1500 "^${{ inputs.test }}$" \
env:
HEADSCALE_INTEGRATION_HEADSCALE_IMAGE: headscale:${{ github.sha }}
HEADSCALE_INTEGRATION_TAILSCALE_IMAGE: tailscale-head:${{ github.sha }}
HEADSCALE_INTEGRATION_POSTGRES_IMAGE: ${{ inputs.postgres_flag == '--postgres=1' && format('postgres:{0}', github.sha) || '' }}
HEADSCALE_INTEGRATION_GO_CACHE: /tmp/go-cache/go
HEADSCALE_INTEGRATION_GO_BUILD_CACHE: /tmp/go-cache/.cache/go-build
run: /tmp/artifacts/hi run --stats --ts-memory-limit=300 --hs-memory-limit=1500 "^${{ inputs.test }}$" \
--timeout=120m \
${{ inputs.postgres_flag }}
- uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
if: always() && steps.changed-files.outputs.files == 'true'
# Sanitize test name for artifact upload (replace invalid characters: " : < > | * ? \ / with -)
- name: Sanitize test name for artifacts
if: always()
id: sanitize
run: echo "name=${TEST_NAME//[\":<>|*?\\\/]/-}" >> $GITHUB_OUTPUT
env:
TEST_NAME: ${{ inputs.test }}
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
if: always()
with:
name: ${{ inputs.database_name }}-${{ inputs.test }}-logs
name: ${{ inputs.database_name }}-${{ steps.sanitize.outputs.name }}-logs
path: "control_logs/*/*.log"
- uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2
if: always() && steps.changed-files.outputs.files == 'true'
- uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
if: always()
with:
name: ${{ inputs.database_name }}-${{ inputs.test }}-archives
path: "control_logs/*/*.tar"
name: ${{ inputs.database_name }}-${{ steps.sanitize.outputs.name }}-artifacts
path: control_logs/
- name: Setup a blocking tmux session
if: ${{ env.HAS_TAILSCALE_SECRET }}
uses: alexellis/block-with-tmux-action@master
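As a side note on the sanitize step above: the bash substitution replaces the characters `" : < > | * ? \ /` with `-` because they are not accepted in artifact names. A rough Go equivalent, purely for illustration (the workflow itself relies on the shell parameter expansion shown above):

```go
// Illustrative Go equivalent of the shell substitution used in the workflow:
// map characters that are invalid in artifact names to '-'.
package main

import (
	"fmt"
	"strings"
)

func sanitizeArtifactName(name string) string {
	return strings.Map(func(r rune) rune {
		if strings.ContainsRune(`":<>|*?\/`, r) {
			return '-'
		}
		return r
	}, name)
}

func main() {
	fmt.Println(sanitizeArtifactName("TestAutoApproveMultiNetwork/authkey-tag.*"))
	// Output: TestAutoApproveMultiNetwork-authkey-tag.-
}
```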


@ -10,7 +10,7 @@ jobs:
golangci-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 2
- name: Get changed files
@ -24,13 +24,12 @@ jobs:
- '**/*.go'
- 'integration_test/'
- 'config-example.yaml'
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
@ -46,7 +45,7 @@ jobs:
prettier-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 2
- name: Get changed files
@ -65,13 +64,12 @@ jobs:
- '**/*.css'
- '**/*.scss'
- '**/*.html'
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
@ -83,12 +81,11 @@ jobs:
proto-lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}


@ -19,7 +19,7 @@ jobs:
contents: read
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 2
@ -38,14 +38,13 @@ jobs:
- 'cmd/**'
- 'hscontrol/**'
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
if: steps.changed-files.outputs.nix == 'true' || steps.changed-files.outputs.go == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.nix == 'true' || steps.changed-files.outputs.go == 'true'
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}


@ -13,28 +13,27 @@ jobs:
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 0
- name: Login to DockerHub
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
- name: Login to GHCR
uses: docker/login-action@74a5d142397b4f367a81961eba4e8cd7edddf772 # v3.4.0
uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0
with:
registry: ghcr.io
username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }}
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}


@ -12,16 +12,14 @@ jobs:
issues: write
pull-requests: write
steps:
- uses: actions/stale@5bef64f19d7facfb25b37b414482c7164d639639 # v9.1.0
- uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1
with:
days-before-issue-stale: 90
days-before-issue-close: 7
stale-issue-label: "stale"
stale-issue-message:
"This issue is stale because it has been open for 90 days with no
stale-issue-message: "This issue is stale because it has been open for 90 days with no
activity."
close-issue-message:
"This issue was closed because it has been inactive for 14 days
close-issue-message: "This issue was closed because it has been inactive for 14 days
since being marked as stale."
days-before-pr-stale: -1
days-before-pr-close: -1


@ -7,7 +7,117 @@ concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
# build: Builds binaries and Docker images once, uploads as artifacts for reuse.
# build-postgres: Pulls postgres image separately to avoid Docker Hub rate limits.
# sqlite: Runs all integration tests with SQLite backend.
# postgres: Runs a subset of tests with PostgreSQL to verify database compatibility.
build:
runs-on: ubuntu-latest
outputs:
files-changed: ${{ steps.changed-files.outputs.files }}
steps:
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 2
- name: Get changed files
id: changed-files
uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
with:
filters: |
files:
- '*.nix'
- 'go.*'
- '**/*.go'
- 'integration/**'
- 'config-example.yaml'
- '.github/workflows/test-integration.yaml'
- '.github/workflows/integration-test-template.yml'
- 'Dockerfile.*'
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'
with:
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', '**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}
- name: Build binaries and warm Go cache
if: steps.changed-files.outputs.files == 'true'
run: |
# Build all Go binaries in one nix shell to maximize cache reuse
nix develop --command -- bash -c '
go build -o hi ./cmd/hi
CGO_ENABLED=0 GOOS=linux go build -o headscale ./cmd/headscale
# Build integration test binary to warm the cache with all dependencies
go test -c ./integration -o /dev/null 2>/dev/null || true
'
- name: Upload hi binary
if: steps.changed-files.outputs.files == 'true'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
with:
name: hi-binary
path: hi
retention-days: 10
- name: Package Go cache
if: steps.changed-files.outputs.files == 'true'
run: |
# Package Go module cache and build cache
tar -czf go-cache.tar.gz -C ~ go .cache/go-build
- name: Upload Go cache
if: steps.changed-files.outputs.files == 'true'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
with:
name: go-cache
path: go-cache.tar.gz
retention-days: 10
- name: Build headscale image
if: steps.changed-files.outputs.files == 'true'
run: |
docker build \
--file Dockerfile.integration-ci \
--tag headscale:${{ github.sha }} \
.
docker save headscale:${{ github.sha }} | gzip > headscale-image.tar.gz
- name: Build tailscale HEAD image
if: steps.changed-files.outputs.files == 'true'
run: |
docker build \
--file Dockerfile.tailscale-HEAD \
--tag tailscale-head:${{ github.sha }} \
.
docker save tailscale-head:${{ github.sha }} | gzip > tailscale-head-image.tar.gz
- name: Upload headscale image
if: steps.changed-files.outputs.files == 'true'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
with:
name: headscale-image
path: headscale-image.tar.gz
retention-days: 10
- name: Upload tailscale HEAD image
if: steps.changed-files.outputs.files == 'true'
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
with:
name: tailscale-head-image
path: tailscale-head-image.tar.gz
retention-days: 10
build-postgres:
runs-on: ubuntu-latest
needs: build
if: needs.build.outputs.files-changed == 'true'
steps:
- name: Pull and save postgres image
run: |
docker pull postgres:latest
docker tag postgres:latest postgres:${{ github.sha }}
docker save postgres:${{ github.sha }} | gzip > postgres-image.tar.gz
- name: Upload postgres image
uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0
with:
name: postgres-image
path: postgres-image.tar.gz
retention-days: 10
sqlite:
needs: build
if: needs.build.outputs.files-changed == 'true'
strategy:
fail-fast: false
matrix:
@ -25,6 +135,8 @@ jobs:
- TestACLAutogroupTagged
- TestACLAutogroupSelf
- TestACLPolicyPropagationOverTime
- TestACLTagPropagation
- TestACLTagPropagationPortSpecific
- TestAPIAuthenticationBypass
- TestAPIAuthenticationBypassCurl
- TestGRPCAuthenticationBypass
@ -54,10 +166,6 @@ jobs:
- TestPreAuthKeyCommandReusableEphemeral
- TestPreAuthKeyCorrectUserLoggedInCommand
- TestApiKeyCommand
- TestNodeTagCommand
- TestTaggedNodeRegistration
- TestTagPersistenceAcrossRestart
- TestNodeAdvertiseTagCommand
- TestNodeCommand
- TestNodeExpireCommand
- TestNodeRenameCommand
@ -87,7 +195,12 @@ jobs:
- TestEnablingExitRoutes
- TestSubnetRouterMultiNetwork
- TestSubnetRouterMultiNetworkExitNode
- TestAutoApproveMultiNetwork
- TestAutoApproveMultiNetwork/authkey-tag.*
- TestAutoApproveMultiNetwork/authkey-user.*
- TestAutoApproveMultiNetwork/authkey-group.*
- TestAutoApproveMultiNetwork/webauth-tag.*
- TestAutoApproveMultiNetwork/webauth-user.*
- TestAutoApproveMultiNetwork/webauth-group.*
- TestSubnetRouteACLFiltering
- TestHeadscale
- TestTailscaleNodesJoiningHeadcale
@ -97,12 +210,42 @@ jobs:
- TestSSHIsBlockedInACL
- TestSSHUserOnlyIsolation
- TestSSHAutogroupSelf
- TestTagsAuthKeyWithTagRequestDifferentTag
- TestTagsAuthKeyWithTagNoAdvertiseFlag
- TestTagsAuthKeyWithTagCannotAddViaCLI
- TestTagsAuthKeyWithTagCannotChangeViaCLI
- TestTagsAuthKeyWithTagAdminOverrideReauthPreserves
- TestTagsAuthKeyWithTagCLICannotModifyAdminTags
- TestTagsAuthKeyWithoutTagCannotRequestTags
- TestTagsAuthKeyWithoutTagRegisterNoTags
- TestTagsAuthKeyWithoutTagCannotAddViaCLI
- TestTagsAuthKeyWithoutTagCLINoOpAfterAdminWithReset
- TestTagsAuthKeyWithoutTagCLINoOpAfterAdminWithEmptyAdvertise
- TestTagsAuthKeyWithoutTagCLICannotReduceAdminMultiTag
- TestTagsUserLoginOwnedTagAtRegistration
- TestTagsUserLoginNonExistentTagAtRegistration
- TestTagsUserLoginUnownedTagAtRegistration
- TestTagsUserLoginAddTagViaCLIReauth
- TestTagsUserLoginRemoveTagViaCLIReauth
- TestTagsUserLoginCLINoOpAfterAdminAssignment
- TestTagsUserLoginCLICannotRemoveAdminTags
- TestTagsAuthKeyWithTagRequestNonExistentTag
- TestTagsAuthKeyWithTagRequestUnownedTag
- TestTagsAuthKeyWithoutTagRequestNonExistentTag
- TestTagsAuthKeyWithoutTagRequestUnownedTag
- TestTagsAdminAPICannotSetNonExistentTag
- TestTagsAdminAPICanSetUnownedTag
- TestTagsAdminAPICannotRemoveAllTags
- TestTagsAdminAPICannotSetInvalidFormat
uses: ./.github/workflows/integration-test-template.yml
secrets: inherit
with:
test: ${{ matrix.test }}
postgres_flag: "--postgres=0"
database_name: "sqlite"
postgres:
needs: [build, build-postgres]
if: needs.build.outputs.files-changed == 'true'
strategy:
fail-fast: false
matrix:
@ -113,6 +256,7 @@ jobs:
- TestPingAllByIPManyUpDown
- TestSubnetRouterMultiNetwork
uses: ./.github/workflows/integration-test-template.yml
secrets: inherit
with:
test: ${{ matrix.test }}
postgres_flag: "--postgres=1"


@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1
with:
fetch-depth: 2
@ -27,13 +27,12 @@ jobs:
- 'integration_test/'
- 'config-example.yaml'
- uses: nixbuild/nix-quick-install-action@889f3180bb5f064ee9e3201428d04ae9e41d54ad # v31
- uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34
if: steps.changed-files.outputs.files == 'true'
- uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3
if: steps.changed-files.outputs.files == 'true'
with:
primary-key:
nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix',
'**/flake.lock') }}
restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }}


@ -7,6 +7,7 @@ linters:
- depguard
- dupl
- exhaustruct
- funcorder
- funlen
- gochecknoglobals
- gochecknoinits
@ -28,6 +29,15 @@ linters:
- wrapcheck
- wsl
settings:
forbidigo:
forbid:
# Forbid time.Sleep everywhere with context-appropriate alternatives
- pattern: 'time\.Sleep'
msg: >-
time.Sleep is forbidden.
In tests: use assert.EventuallyWithT for polling/waiting patterns.
In production code: use a backoff strategy (e.g., cenkalti/backoff) or proper synchronization primitives.
analyze-types: true
gocritic:
disabled-checks:
- appendAssign
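For reference, a minimal test sketch of the pattern the forbidigo message points to; it assumes testify (already used by this repository's tests) and a hypothetical `nodeIsOnline` condition:

```go
// Sketch only: poll a condition with assert.EventuallyWithT instead of
// time.Sleep. nodeIsOnline stands in for whatever state the real test awaits.
package example

import (
	"testing"
	"time"

	"github.com/stretchr/testify/assert"
)

func nodeIsOnline() bool { return true } // hypothetical condition

func TestNodeComesOnline(t *testing.T) {
	// Instead of time.Sleep(10 * time.Second) followed by a single check,
	// retry the assertion every 200ms for up to 10s.
	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		assert.True(c, nodeIsOnline(), "node should report online")
	}, 10*time.Second, 200*time.Millisecond)
}
```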


@ -125,7 +125,7 @@ kos:
# bare tells KO to only use the repository
# for tagging and naming the container.
bare: true
base_image: gcr.io/distroless/base-debian12
base_image: gcr.io/distroless/base-debian13
build: headscale
main: ./cmd/headscale
env:
@ -154,7 +154,7 @@ kos:
- headscale/headscale
bare: true
base_image: gcr.io/distroless/base-debian12:debug
base_image: gcr.io/distroless/base-debian13:debug
build: headscale
main: ./cmd/headscale
env:


@ -139,19 +139,19 @@ git commit -m "feat: add new feature"
# Fix the issues and try committing again
```
### Manual golangci-lint (Optional)
### Manual golangci-lint
While golangci-lint runs automatically via prek, you can also run it manually:
```bash
# Use the same logic as the pre-commit hook (recommended)
./.golangci-lint-hook.sh
# Or manually specify a base reference
# If you have upstream remote configured (recommended)
golangci-lint run --new-from-rev=upstream/main --timeout=5m --fix
# If you only have origin remote
golangci-lint run --new-from-rev=main --timeout=5m --fix
```
The `.golangci-lint-hook.sh` script automatically finds where your branch diverged from the main branch by checking `upstream/main`, `origin/main`, or `main` in that order.
**Important**: Always use `--new-from-rev` to only lint changed files. This prevents formatting the entire repository and keeps changes focused on your actual modifications.
### Skipping Hooks (Not Recommended)
@ -411,7 +411,7 @@ go run ./cmd/hi run "TestPattern*"
- Only ONE test can run at a time (Docker port conflicts)
- Tests generate ~100MB of logs per run in `control_logs/`
- Clean environment before each test: `rm -rf control_logs/202507* && docker system prune -f`
- Clean environment before each test: `sudo rm -rf control_logs/202* && docker system prune -f`
### Test Artifacts Location


@ -1,41 +1,49 @@
# CHANGELOG
## Next
## 0.28.0 (202x-xx-xx)
**Minimum supported Tailscale client version: v1.74.0**
### Web registration templates redesign
### Tags as identity
The OIDC callback and device registration web pages have been updated to use the
Material for MkDocs design system from the official documentation. The templates
now use consistent typography, spacing, and colours across all registration
flows. External links are properly secured with noreferrer/noopener attributes.
Tags are now implemented following the Tailscale model where tags and user ownership are mutually exclusive. Devices can be either
user-owned (authenticated via web/OIDC) or tagged (authenticated via tagged PreAuthKeys). Tagged devices receive their identity from
tags rather than users, making them suitable for servers and infrastructure. Applying a tag to a device removes user-based
ownership. See the [Tailscale tags documentation](https://tailscale.com/kb/1068/tags) for details on how tags work.
User-owned nodes can now request tags during registration using `--advertise-tags`. Tags are validated against the `tagOwners` policy
and applied at registration time. Tags can be managed via the CLI or API after registration.
### Smarter map updates
The map update system has been rewritten to send smaller, partial updates instead of full network maps whenever possible. This reduces bandwidth usage and improves performance, especially for large networks. The system now properly tracks peer
changes and can send removal notifications when nodes are removed due to policy changes.
[#2856](https://github.com/juanfont/headscale/pull/2856) [#2961](https://github.com/juanfont/headscale/pull/2961)
### Pre-authentication key security improvements
Pre-authentication keys now use bcrypt hashing for improved security
[#2853](https://github.com/juanfont/headscale/pull/2853). Keys are stored as a
prefix and bcrypt hash instead of plaintext. The full key is only displayed once
at creation time. When listing keys, only the prefix is shown (e.g.,
`hskey-auth-{prefix}-***`). All new keys use the format
`hskey-auth-{prefix}-{secret}`. Legacy plaintext keys continue to work for
backwards compatibility.
Pre-authentication keys now use bcrypt hashing for improved security [#2853](https://github.com/juanfont/headscale/pull/2853). Keys
are stored as a prefix and bcrypt hash instead of plaintext. The full key is only displayed once at creation time. When listing keys,
only the prefix is shown (e.g., `hskey-auth-{prefix}-***`). All new keys use the format `hskey-auth-{prefix}-{secret}`. Legacy plaintext keys in the format `{secret}` will continue to work for backwards compatibility.
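Purely as an illustration of the mechanics (this is not headscale's implementation, and every name below is hypothetical), a key in the `hskey-auth-{prefix}-{secret}` format can be checked against a stored prefix and bcrypt hash roughly like this:

```go
// Hypothetical sketch, not headscale code: verify a presented key of the form
// hskey-auth-{prefix}-{secret} against a stored prefix and bcrypt hash.
package example

import (
	"errors"
	"strings"

	"golang.org/x/crypto/bcrypt"
)

type storedKey struct {
	Prefix string // listed as hskey-auth-{prefix}-***
	Hash   []byte // bcrypt hash of the secret part only
}

func checkPreAuthKey(presented string, stored storedKey) error {
	parts := strings.SplitN(presented, "-", 4) // "hskey", "auth", prefix, secret
	if len(parts) != 4 || parts[0] != "hskey" || parts[1] != "auth" {
		return errors.New("malformed key")
	}
	if parts[2] != stored.Prefix {
		return errors.New("unknown key prefix")
	}
	return bcrypt.CompareHashAndPassword(stored.Hash, []byte(parts[3]))
}
```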
### Tags
### Web registration templates redesign
Tags are now implemented following the Tailscale model where tags and user ownership are mutually exclusive. Devices can be either user-owned (authenticated via web/OIDC) or tagged (authenticated via tagged PreAuthKeys). Tagged devices receive their identity from tags rather than users, making them suitable for servers and infrastructure. Applying a tag to a device removes user-based authentication. See the [Tailscale tags documentation](https://tailscale.com/kb/1068/tags) for details on how tags work.
The OIDC callback and device registration web pages have been updated to use the Material for MkDocs design system from the official
documentation. The templates now use consistent typography, spacing, and colours across all registration flows.
### Database migration support removed for pre-0.25.0 databases
Headscale no longer supports direct upgrades from databases created before
version 0.25.0. Users on older versions must upgrade sequentially through each
stable release, selecting the latest patch version available for each minor
release.
Headscale no longer supports direct upgrades from databases created before version 0.25.0. Users on older versions must upgrade
sequentially through each stable release, selecting the latest patch version available for each minor release.
### BREAKING
- **Tags**: The gRPC `SetTags` endpoint now allows converting user-owned nodes to tagged nodes by setting tags. Once a node is tagged, it cannot be converted back to a user-owned node.
- **Tags**: The gRPC `SetTags` endpoint now allows converting user-owned nodes to tagged nodes by setting tags. Once a node is tagged, it cannot be converted back to a user-owned node. [#2885](https://github.com/juanfont/headscale/pull/2885)
- **Tags**: Tags are now resolved from the node's stored Tags field only [#2931](https://github.com/juanfont/headscale/pull/2931)
- `--advertise-tags` is processed during registration, not on every policy evaluation
- PreAuthKey tagged devices ignore `--advertise-tags` from clients
- User-owned nodes can use `--advertise-tags` if authorized by `tagOwners` policy
- Tags can be managed via CLI (`headscale nodes tag`) or the SetTags API after registration
- Database migration support removed for pre-0.25.0 databases [#2883](https://github.com/juanfont/headscale/pull/2883)
- If you are running a version older than 0.25.0, you must upgrade to 0.25.1 first, then upgrade to this release
- See the [upgrade path documentation](https://headscale.net/stable/about/faq/#what-is-the-recommended-update-path-can-i-skip-multiple-versions-while-updating) for detailed guidance
@ -47,28 +55,28 @@ release.
### Changes
- Smarter change notifications send partial map updates and node removals instead of full maps [#2961](https://github.com/juanfont/headscale/pull/2961)
- Send lightweight endpoint and DERP region updates instead of full maps [#2856](https://github.com/juanfont/headscale/pull/2856)
- Add `oidc.email_verified_required` config option to control email verification requirement [#2860](https://github.com/juanfont/headscale/pull/2860)
- When `true` (default), only verified emails can authenticate via OIDC with `allowed_domains` or `allowed_users`
- When `false`, unverified emails are allowed for OIDC authentication
- Add NixOS module in repository for faster iteration [#2857](https://github.com/juanfont/headscale/pull/2857)
- Add favicon to webpages [#2858](https://github.com/juanfont/headscale/pull/2858)
- Redesign OIDC callback and registration web templates [#2832](https://github.com/juanfont/headscale/pull/2832)
- Reclaim IPs from the IP allocator when nodes are deleted [#2831](https://github.com/juanfont/headscale/pull/2831)
- Add bcrypt hashing for pre-authentication keys [#2853](https://github.com/juanfont/headscale/pull/2853)
- Add structured prefix format for API keys (`hskey-api-{prefix}-{secret}`) [#2853](https://github.com/juanfont/headscale/pull/2853)
- Add registration keys for web authentication tracking (`hskey-reg-{random}`) [#2853](https://github.com/juanfont/headscale/pull/2853)
- Send lightweight endpoint and DERP region updates instead of full maps [#2856](https://github.com/juanfont/headscale/pull/2856)
- Detect when only node endpoints or DERP region changed and send
PeerChangedPatch responses instead of full map updates, reducing bandwidth
and improving performance
## 0.27.2 (2025-xx-xx)
### Changes
- Fix ACL policy not applied to new OIDC nodes until client restart
[#2890](https://github.com/juanfont/headscale/pull/2890)
- Fix autogroup:self preventing visibility of nodes matched by other ACL rules
[#2882](https://github.com/juanfont/headscale/pull/2882)
- Fix nodes being rejected after pre-authentication key expiration
[#2917](https://github.com/juanfont/headscale/pull/2917)
- Add prefix to API keys (`hskey-api-{prefix}-{secret}`) [#2853](https://github.com/juanfont/headscale/pull/2853)
- Add prefix to registration keys for web authentication tracking (`hskey-reg-{random}`) [#2853](https://github.com/juanfont/headscale/pull/2853)
- Tags can now be tagOwner of other tags [#2930](https://github.com/juanfont/headscale/pull/2930)
- Add `taildrop.enabled` configuration option to enable/disable Taildrop file sharing [#2955](https://github.com/juanfont/headscale/pull/2955)
- Allow disabling the metrics server by setting empty `metrics_listen_addr` [#2914](https://github.com/juanfont/headscale/pull/2914)
- Log ACME/autocert errors for easier debugging [#2933](https://github.com/juanfont/headscale/pull/2933)
- Improve CLI list output formatting [#2951](https://github.com/juanfont/headscale/pull/2951)
- Use Debian 13 distroless base images for containers [#2944](https://github.com/juanfont/headscale/pull/2944)
- Fix ACL policy not applied to new OIDC nodes until client restart [#2890](https://github.com/juanfont/headscale/pull/2890)
- Fix autogroup:self preventing visibility of nodes matched by other ACL rules [#2882](https://github.com/juanfont/headscale/pull/2882)
- Fix nodes being rejected after pre-authentication key expiration [#2917](https://github.com/juanfont/headscale/pull/2917)
- Fix list-routes command respecting identifier filter with JSON output [#2927](https://github.com/juanfont/headscale/pull/2927)
## 0.27.1 (2025-11-11)


@ -2,27 +2,43 @@
# and are in no way endorsed by Headscale's maintainers as an
# official nor supported release or distribution.
FROM docker.io/golang:1.25-trixie
FROM docker.io/golang:1.25-trixie AS builder
ARG VERSION=dev
ENV GOPATH /go
WORKDIR /go/src/headscale
RUN apt-get --update install --no-install-recommends --yes less jq sqlite3 dnsutils \
&& apt-get dist-clean
RUN mkdir -p /var/run/headscale
# Install delve debugger
# Install delve debugger first - rarely changes, good cache candidate
RUN go install github.com/go-delve/delve/cmd/dlv@latest
# Download dependencies - only invalidated when go.mod/go.sum change
COPY go.mod go.sum /go/src/headscale/
RUN go mod download
# Copy source and build - invalidated on any source change
COPY . .
# Build debug binary with debug symbols for delve
RUN CGO_ENABLED=0 GOOS=linux go build -gcflags="all=-N -l" -o /go/bin/headscale ./cmd/headscale
# Runtime stage
FROM debian:trixie-slim
RUN apt-get --update install --no-install-recommends --yes \
bash ca-certificates curl dnsutils findutils iproute2 jq less procps python3 sqlite3 \
&& apt-get dist-clean
RUN mkdir -p /var/run/headscale
# Copy binaries from builder
COPY --from=builder /go/bin/headscale /usr/local/bin/headscale
COPY --from=builder /go/bin/dlv /usr/local/bin/dlv
# Copy source code for delve source-level debugging
COPY --from=builder /go/src/headscale /go/src/headscale
WORKDIR /go/src/headscale
# Need to reset the entrypoint or everything will run as a busybox script
ENTRYPOINT []
EXPOSE 8080/tcp 40000/tcp
CMD ["/go/bin/dlv", "--listen=0.0.0.0:40000", "--headless=true", "--api-version=2", "--accept-multiclient", "exec", "/go/bin/headscale", "--"]
CMD ["dlv", "--listen=0.0.0.0:40000", "--headless=true", "--api-version=2", "--accept-multiclient", "exec", "/usr/local/bin/headscale", "--"]

Dockerfile.integration-ci Normal file

@ -0,0 +1,17 @@
# Minimal CI image - expects pre-built headscale binary in build context
# For local development with delve debugging, use Dockerfile.integration instead
FROM debian:trixie-slim
RUN apt-get --update install --no-install-recommends --yes \
bash ca-certificates curl dnsutils findutils iproute2 jq less procps python3 sqlite3 \
&& apt-get dist-clean
RUN mkdir -p /var/run/headscale
# Copy pre-built headscale binary from build context
COPY headscale /usr/local/bin/headscale
ENTRYPOINT []
EXPOSE 8080/tcp
CMD ["/usr/local/bin/headscale"]


@ -37,7 +37,9 @@ RUN GOARCH=$TARGETARCH go install -tags="${BUILD_TAGS}" -ldflags="\
-v ./cmd/tailscale ./cmd/tailscaled ./cmd/containerboot
FROM alpine:3.22
RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables curl
# Upstream: ca-certificates ip6tables iptables iproute2
# Tests: curl python3 (traceroute via BusyBox)
RUN apk add --no-cache ca-certificates curl ip6tables iptables iproute2 python3
COPY --from=build-env /go/bin/* /usr/local/bin/
# For compat with the previous run.sh, although ideally you should be


@ -10,10 +10,6 @@ import (
"google.golang.org/grpc/status"
)
const (
errPreAuthKeyMalformed = Error("key is malformed. expected 64 hex characters with `nodekey` prefix")
)
// Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors
type Error string


@ -220,10 +220,6 @@ var listNodeRoutesCmd = &cobra.Command{
)
}
if output != "" {
SuccessOutput(response.GetNodes(), "", output)
}
nodes := response.GetNodes()
if identifier != 0 {
for _, node := range response.GetNodes() {
@ -238,6 +234,11 @@ var listNodeRoutesCmd = &cobra.Command{
return (n.GetSubnetRoutes() != nil && len(n.GetSubnetRoutes()) > 0) || (n.GetApprovedRoutes() != nil && len(n.GetApprovedRoutes()) > 0) || (n.GetAvailableRoutes() != nil && len(n.GetAvailableRoutes()) > 0)
})
if output != "" {
SuccessOutput(nodes, "", output)
return
}
tableData, err := nodeRoutesToPtables(nodes)
if err != nil {
ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output)
@ -561,23 +562,26 @@ func nodesToPtables(
var forcedTags string
for _, tag := range node.GetForcedTags() {
forcedTags += "," + tag
forcedTags += "\n" + tag
}
forcedTags = strings.TrimLeft(forcedTags, ",")
forcedTags = strings.TrimLeft(forcedTags, "\n")
var invalidTags string
for _, tag := range node.GetInvalidTags() {
if !slices.Contains(node.GetForcedTags(), tag) {
invalidTags += "," + pterm.LightRed(tag)
invalidTags += "\n" + pterm.LightRed(tag)
}
}
invalidTags = strings.TrimLeft(invalidTags, ",")
invalidTags = strings.TrimLeft(invalidTags, "\n")
var validTags string
for _, tag := range node.GetValidTags() {
if !slices.Contains(node.GetForcedTags(), tag) {
validTags += "," + pterm.LightGreen(tag)
validTags += "\n" + pterm.LightGreen(tag)
}
}
validTags = strings.TrimLeft(validTags, ",")
validTags = strings.TrimLeft(validTags, "\n")
var user string
if currentUser == "" || (currentUser == node.GetUser().GetName()) {
@ -639,9 +643,9 @@ func nodeRoutesToPtables(
nodeData := []string{
strconv.FormatUint(node.GetId(), util.Base10),
node.GetGivenName(),
strings.Join(node.GetApprovedRoutes(), ", "),
strings.Join(node.GetAvailableRoutes(), ", "),
strings.Join(node.GetSubnetRoutes(), ", "),
strings.Join(node.GetApprovedRoutes(), "\n"),
strings.Join(node.GetAvailableRoutes(), "\n"),
strings.Join(node.GetSubnetRoutes(), "\n"),
}
tableData = append(
tableData,


@ -107,10 +107,10 @@ var listPreAuthKeys = &cobra.Command{
aclTags := ""
for _, tag := range key.GetAclTags() {
aclTags += "," + tag
aclTags += "\n" + tag
}
aclTags = strings.TrimLeft(aclTags, ",")
aclTags = strings.TrimLeft(aclTags, "\n")
tableData = append(tableData, []string{
strconv.FormatUint(key.GetId(), 10),


@ -9,6 +9,7 @@ import (
"strings"
"time"
"github.com/cenkalti/backoff/v5"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/api/types/filters"
"github.com/docker/docker/api/types/image"
@ -86,30 +87,28 @@ func killTestContainers(ctx context.Context) error {
return nil
}
const (
containerRemoveInitialInterval = 100 * time.Millisecond
containerRemoveMaxElapsedTime = 2 * time.Second
)
// removeContainerWithRetry attempts to remove a container with exponential backoff retry logic.
func removeContainerWithRetry(ctx context.Context, cli *client.Client, containerID string) bool {
maxRetries := 3
baseDelay := 100 * time.Millisecond
expBackoff := backoff.NewExponentialBackOff()
expBackoff.InitialInterval = containerRemoveInitialInterval
for attempt := range maxRetries {
_, err := backoff.Retry(ctx, func() (struct{}, error) {
err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{
Force: true,
})
if err == nil {
return true
if err != nil {
return struct{}{}, err
}
// If this is the last attempt, don't wait
if attempt == maxRetries-1 {
break
}
return struct{}{}, nil
}, backoff.WithBackOff(expBackoff), backoff.WithMaxElapsedTime(containerRemoveMaxElapsedTime))
// Wait with exponential backoff
delay := baseDelay * time.Duration(1<<attempt)
time.Sleep(delay)
}
return false
return err == nil
}
// pruneDockerNetworks removes unused Docker networks.
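For context, the retry above uses the cenkalti/backoff v5 API, where the operation returns a value and an error and retry options control the policy. A self-contained sketch of the same pattern with a stand-in operation:

```go
// Standalone sketch of the backoff/v5 pattern used by removeContainerWithRetry:
// the operation is retried with exponential backoff until it succeeds or the
// elapsed-time budget runs out. The failing-then-succeeding operation here is
// a stand-in, not the real container removal call.
package main

import (
	"context"
	"errors"
	"fmt"
	"time"

	"github.com/cenkalti/backoff/v5"
)

func main() {
	expBackoff := backoff.NewExponentialBackOff()
	expBackoff.InitialInterval = 100 * time.Millisecond

	attempts := 0
	_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
		attempts++
		if attempts < 3 {
			return struct{}{}, errors.New("not ready yet") // retried with backoff
		}
		return struct{}{}, nil
	}, backoff.WithBackOff(expBackoff), backoff.WithMaxElapsedTime(2*time.Second))

	fmt.Println(attempts, err) // typically: 3 <nil>
}
```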


@ -301,6 +301,11 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
"HEADSCALE_INTEGRATION_RUN_ID=" + runID,
}
// Pass through CI environment variable for CI detection
if ci := os.Getenv("CI"); ci != "" {
env = append(env, "CI="+ci)
}
// Pass through all HEADSCALE_INTEGRATION_* environment variables
for _, e := range os.Environ() {
if strings.HasPrefix(e, "HEADSCALE_INTEGRATION_") {
@ -313,6 +318,10 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
env = append(env, e)
}
}
// Set GOCACHE to a known location (used by both bind mount and volume cases)
env = append(env, "GOCACHE=/cache/go-build")
containerConfig := &container.Config{
Image: "golang:" + config.GoVersion,
Cmd: goTestCmd,
@ -332,20 +341,43 @@ func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunC
log.Printf("Using Docker socket: %s", dockerSocketPath)
}
binds := []string{
fmt.Sprintf("%s:%s", projectRoot, projectRoot),
dockerSocketPath + ":/var/run/docker.sock",
logsDir + ":/tmp/control",
}
// Use bind mounts for Go cache if provided via environment variables,
// otherwise fall back to Docker volumes for local development
var mounts []mount.Mount
goCache := os.Getenv("HEADSCALE_INTEGRATION_GO_CACHE")
goBuildCache := os.Getenv("HEADSCALE_INTEGRATION_GO_BUILD_CACHE")
if goCache != "" {
binds = append(binds, goCache+":/go")
} else {
mounts = append(mounts, mount.Mount{
Type: mount.TypeVolume,
Source: "hs-integration-go-cache",
Target: "/go",
})
}
if goBuildCache != "" {
binds = append(binds, goBuildCache+":/cache/go-build")
} else {
mounts = append(mounts, mount.Mount{
Type: mount.TypeVolume,
Source: "hs-integration-go-build-cache",
Target: "/cache/go-build",
})
}
hostConfig := &container.HostConfig{
AutoRemove: false, // We'll remove manually for better control
Binds: []string{
fmt.Sprintf("%s:%s", projectRoot, projectRoot),
dockerSocketPath + ":/var/run/docker.sock",
logsDir + ":/tmp/control",
},
Mounts: []mount.Mount{
{
Type: mount.TypeVolume,
Source: "hs-integration-go-cache",
Target: "/go",
},
},
Binds: binds,
Mounts: mounts,
}
return cli.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, containerName)
@ -812,63 +844,3 @@ func extractContainerFiles(ctx context.Context, cli *client.Client, containerID,
// This function is kept for potential future use or other file types
return nil
}
// logExtractionError logs extraction errors with appropriate level based on error type.
func logExtractionError(artifactType, containerName string, err error, verbose bool) {
if errors.Is(err, ErrFileNotFoundInTar) {
// File not found is expected and only logged in verbose mode
if verbose {
log.Printf("No %s found in container %s", artifactType, containerName)
}
} else {
// Other errors are actual failures and should be logged as warnings
log.Printf("Warning: failed to extract %s from %s: %v", artifactType, containerName, err)
}
}
// extractSingleFile copies a single file from a container.
func extractSingleFile(ctx context.Context, cli *client.Client, containerID, sourcePath, fileName, logsDir string, verbose bool) error {
tarReader, _, err := cli.CopyFromContainer(ctx, containerID, sourcePath)
if err != nil {
return fmt.Errorf("failed to copy %s from container: %w", sourcePath, err)
}
defer tarReader.Close()
// Extract the single file from the tar
filePath := filepath.Join(logsDir, fileName)
if err := extractFileFromTar(tarReader, filepath.Base(sourcePath), filePath); err != nil {
return fmt.Errorf("failed to extract file from tar: %w", err)
}
if verbose {
log.Printf("Extracted %s from %s", fileName, containerID[:12])
}
return nil
}
// extractDirectory copies a directory from a container and extracts its contents.
func extractDirectory(ctx context.Context, cli *client.Client, containerID, sourcePath, dirName, logsDir string, verbose bool) error {
tarReader, _, err := cli.CopyFromContainer(ctx, containerID, sourcePath)
if err != nil {
return fmt.Errorf("failed to copy %s from container: %w", sourcePath, err)
}
defer tarReader.Close()
// Create target directory
targetDir := filepath.Join(logsDir, dirName)
if err := os.MkdirAll(targetDir, 0o755); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetDir, err)
}
// Extract the directory from the tar
if err := extractDirectoryFromTar(tarReader, targetDir); err != nil {
return fmt.Errorf("failed to extract directory from tar: %w", err)
}
if verbose {
log.Printf("Extracted %s/ from %s", dirName, containerID[:12])
}
return nil
}


@ -1,105 +0,0 @@
package main
import (
"archive/tar"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"strings"
)
// ErrFileNotFoundInTar indicates a file was not found in the tar archive.
var ErrFileNotFoundInTar = errors.New("file not found in tar")
// extractFileFromTar extracts a single file from a tar reader.
func extractFileFromTar(tarReader io.Reader, fileName, outputPath string) error {
tr := tar.NewReader(tarReader)
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar header: %w", err)
}
// Check if this is the file we're looking for
if filepath.Base(header.Name) == fileName {
if header.Typeflag == tar.TypeReg {
// Create the output file
outFile, err := os.Create(outputPath)
if err != nil {
return fmt.Errorf("failed to create output file: %w", err)
}
defer outFile.Close()
// Copy file contents
if _, err := io.Copy(outFile, tr); err != nil {
return fmt.Errorf("failed to copy file contents: %w", err)
}
return nil
}
}
}
return fmt.Errorf("%w: %s", ErrFileNotFoundInTar, fileName)
}
// extractDirectoryFromTar extracts all files from a tar reader to a target directory.
func extractDirectoryFromTar(tarReader io.Reader, targetDir string) error {
tr := tar.NewReader(tarReader)
for {
header, err := tr.Next()
if err == io.EOF {
break
}
if err != nil {
return fmt.Errorf("failed to read tar header: %w", err)
}
// Clean the path to prevent directory traversal
cleanName := filepath.Clean(header.Name)
if strings.Contains(cleanName, "..") {
continue // Skip potentially dangerous paths
}
targetPath := filepath.Join(targetDir, cleanName)
switch header.Typeflag {
case tar.TypeDir:
// Create directory
if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil {
return fmt.Errorf("failed to create directory %s: %w", targetPath, err)
}
case tar.TypeReg:
// Ensure parent directories exist
if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil {
return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err)
}
// Create file
outFile, err := os.Create(targetPath)
if err != nil {
return fmt.Errorf("failed to create file %s: %w", targetPath, err)
}
if _, err := io.Copy(outFile, tr); err != nil {
outFile.Close()
return fmt.Errorf("failed to copy file contents: %w", err)
}
outFile.Close()
// Set file permissions
if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil {
return fmt.Errorf("failed to set file permissions: %w", err)
}
}
}
return nil
}


@ -421,6 +421,15 @@ logtail:
# default static port 41641. This option is intended as a workaround for some buggy
# firewall devices. See https://tailscale.com/kb/1181/firewalls/ for more information.
randomize_client_port: false
# Taildrop configuration
# Taildrop is the file sharing feature of Tailscale, allowing nodes to send files to each other.
# https://tailscale.com/kb/1106/taildrop/
taildrop:
# Enable or disable Taildrop for all nodes.
# When enabled, nodes can send files to other nodes owned by the same user.
# Tagged devices and cross-user transfers are not permitted by Tailscale clients.
enabled: true
# Advanced performance tuning parameters.
# The defaults are carefully chosen and should rarely need adjustment.
# Only modify these if you have identified a specific performance issue.


@ -7,10 +7,15 @@
This page collects third-party tools, client libraries, and scripts related to headscale.
| Name | Repository Link | Description |
| --------------------- | --------------------------------------------------------------- | -------------------------------------------------------------------- |
| tailscale-manager | [Github](https://github.com/singlestore-labs/tailscale-manager) | Dynamically manage Tailscale route advertisements |
| headscalebacktosqlite | [Github](https://github.com/bigbozza/headscalebacktosqlite) | Migrate headscale from PostgreSQL back to SQLite |
| headscale-pf | [Github](https://github.com/YouSysAdmin/headscale-pf) | Populates user groups based on user groups in Jumpcloud or Authentik |
| headscale-client-go | [Github](https://github.com/hibare/headscale-client-go) | A Go client implementation for the Headscale HTTP API. |
| headscale-zabbix | [Github](https://github.com/dblanque/headscale-zabbix) | A Zabbix Monitoring Template for the Headscale Service. |
- [tailscale-manager](https://github.com/singlestore-labs/tailscale-manager) - Dynamically manage Tailscale route
advertisements
- [headscalebacktosqlite](https://github.com/bigbozza/headscalebacktosqlite) - Migrate headscale from PostgreSQL back to
SQLite
- [headscale-pf](https://github.com/YouSysAdmin/headscale-pf) - Populates user groups based on user groups in Jumpcloud
or Authentik
- [headscale-client-go](https://github.com/hibare/headscale-client-go) - A Go client implementation for the Headscale
HTTP API.
- [headscale-zabbix](https://github.com/dblanque/headscale-zabbix) - A Zabbix Monitoring Template for the Headscale
Service.
- [tailscale-exporter](https://github.com/adinhodovic/tailscale-exporter) - A Prometheus exporter for Headscale that
provides network-level metrics using the Headscale API.


@ -7,14 +7,17 @@
Headscale doesn't provide a built-in web interface but users may pick one from the available options.
| Name | Repository Link | Description |
| ---------------------- | ----------------------------------------------------------- | -------------------------------------------------------------------------------------------- |
| headscale-ui | [Github](https://github.com/gurucomputing/headscale-ui) | A web frontend for the headscale Tailscale-compatible coordination server |
| HeadscaleUi | [GitHub](https://github.com/simcu/headscale-ui) | A static headscale admin ui, no backend environment required |
| Headplane | [GitHub](https://github.com/tale/headplane) | An advanced Tailscale inspired frontend for headscale |
| headscale-admin | [Github](https://github.com/GoodiesHQ/headscale-admin) | Headscale-Admin is meant to be a simple, modern web interface for headscale |
| ouroboros | [Github](https://github.com/yellowsink/ouroboros) | Ouroboros is designed for users to manage their own devices, rather than for admins |
| unraid-headscale-admin | [Github](https://github.com/ich777/unraid-headscale-admin) | A simple headscale admin UI for Unraid, it offers Local (`docker exec`) and API Mode |
| headscale-console | [Github](https://github.com/rickli-cloud/headscale-console) | WebAssembly-based client supporting SSH, VNC and RDP with optional self-service capabilities |
- [headscale-ui](https://github.com/gurucomputing/headscale-ui) - A web frontend for the headscale Tailscale-compatible
coordination server
- [HeadscaleUi](https://github.com/simcu/headscale-ui) - A static headscale admin ui, no backend environment required
- [Headplane](https://github.com/tale/headplane) - An advanced Tailscale inspired frontend for headscale
- [headscale-admin](https://github.com/GoodiesHQ/headscale-admin) - Headscale-Admin is meant to be a simple, modern web
interface for headscale
- [ouroboros](https://github.com/yellowsink/ouroboros) - Ouroboros is designed for users to manage their own devices,
rather than for admins
- [unraid-headscale-admin](https://github.com/ich777/unraid-headscale-admin) - A simple headscale admin UI for Unraid,
it offers Local (`docker exec`) and API Mode
- [headscale-console](https://github.com/rickli-cloud/headscale-console) - WebAssembly-based client supporting SSH, VNC
and RDP with optional self-service capabilities
You can ask for support on our [Discord server](https://discord.gg/c84AZQhmpx) in the "web-interfaces" channel.

View file

@ -166,9 +166,6 @@
buf
clang-tools # clang-format
protobuf-language-server
# Add hi to make it even easier to use ci runner.
hi
]
++ lib.optional pkgs.stdenv.isLinux [ traceroute ];

6
go.sum
View file

@ -124,8 +124,6 @@ github.com/creachadair/command v0.2.0 h1:qTA9cMMhZePAxFoNdnk6F6nn94s1qPndIg9hJbq
github.com/creachadair/command v0.2.0/go.mod h1:j+Ar+uYnFsHpkMeV9kGj6lJ45y9u2xqtg8FYy6cm+0o=
github.com/creachadair/flax v0.0.5 h1:zt+CRuXQASxwQ68e9GHAOnEgAU29nF0zYMHOCrL5wzE=
github.com/creachadair/flax v0.0.5/go.mod h1:F1PML0JZLXSNDMNiRGK2yjm5f+L9QCHchyHBldFymj8=
github.com/creachadair/mds v0.25.2 h1:xc0S0AfDq5GX9KUR5sLvi5XjA61/P6S5e0xFs1vA18Q=
github.com/creachadair/mds v0.25.2/go.mod h1:+s4CFteFRj4eq2KcGHW8Wei3u9NyzSPzNV32EvjyK/Q=
github.com/creachadair/mds v0.25.10 h1:9k9JB35D1xhOCFl0liBhagBBp8fWWkKZrA7UXsfoHtA=
github.com/creachadair/mds v0.25.10/go.mod h1:4hatI3hRM+qhzuAmqPRFvaBM8mONkS7nsLxkcuTYUIs=
github.com/creachadair/taskgroup v0.13.2 h1:3KyqakBuFsm3KkXi/9XIb0QcA8tEzLHLgaoidf0MdVc=
@ -278,8 +276,6 @@ github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfC
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/jsimonetti/rtnetlink v1.4.1 h1:JfD4jthWBqZMEffc5RjgmlzpYttAVw1sdnmiNaPO3hE=
github.com/jsimonetti/rtnetlink v1.4.1/go.mod h1:xJjT7t59UIZ62GLZbv6PLLo8VFrostJMPBAheR6OM8w=
github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo=
github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ=
github.com/klauspost/compress v1.18.1 h1:bcSGx7UbpBqMChDtsF28Lw6v/G94LPrrbMbdC3JH2co=
github.com/klauspost/compress v1.18.1/go.mod h1:ZQFFVG+MdnR0P+l6wpXgIL4NTtwiKIdBnrBd8Nrxr+0=
github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
@ -463,8 +459,6 @@ github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+y
github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc=
github.com/tailscale/setec v0.0.0-20250305161714-445cadbbca3d h1:mnqtPWYyvNiPU9l9tzO2YbHXU/xV664XthZYA26lOiE=
github.com/tailscale/setec v0.0.0-20250305161714-445cadbbca3d/go.mod h1:9BzmlFc3OLqLzLTF/5AY+BMs+clxMqyhSGzgXIm8mNI=
github.com/tailscale/squibble v0.0.0-20250108170732-a4ca58afa694 h1:95eIP97c88cqAFU/8nURjgI9xxPbD+Ci6mY/a79BI/w=
github.com/tailscale/squibble v0.0.0-20250108170732-a4ca58afa694/go.mod h1:veguaG8tVg1H/JG5RfpoUW41I+O8ClPElo/fTYr8mMk=
github.com/tailscale/squibble v0.0.0-20251030164342-4d5df9caa993 h1:FyiiAvDAxpB0DrW2GW3KOVfi3YFOtsQUEeFWbf55JJU=
github.com/tailscale/squibble v0.0.0-20251030164342-4d5df9caa993/go.mod h1:xJkMmR3t+thnUQhA3Q4m2VSlS5pcOq+CIjmU/xfKKx4=
github.com/tailscale/tailsql v0.0.0-20250421235516-02f85f087b97 h1:JJkDnrAhHvOCttk8z9xeZzcDlzzkRA7+Duxj9cwOyxk=

View file

@ -5,6 +5,7 @@ import (
"crypto/tls"
"errors"
"fmt"
"io"
"net"
"net/http"
_ "net/http/pprof" // nolint
@ -270,7 +271,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
return
case <-expireTicker.C:
var expiredNodeChanges []change.ChangeSet
var expiredNodeChanges []change.Change
var changed bool
lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck)
@ -304,7 +305,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
}
h.state.SetDERPMap(derpMap)
h.Change(change.DERPSet)
h.Change(change.DERPMap())
case records, ok := <-extraRecordsUpdate:
if !ok {
@ -312,7 +313,7 @@ func (h *Headscale) scheduledTasks(ctx context.Context) {
}
h.cfg.TailcfgDNSConfig.ExtraRecords = records
h.Change(change.ExtraRecordsSet)
h.Change(change.ExtraRecords())
}
}
}
@ -729,16 +730,27 @@ func (h *Headscale) Serve() error {
log.Info().
Msgf("listening and serving HTTP on: %s", h.cfg.Addr)
debugHTTPListener, err := net.Listen("tcp", h.cfg.MetricsAddr)
if err != nil {
return fmt.Errorf("failed to bind to TCP address: %w", err)
// Only start debug/metrics server if address is configured
var debugHTTPServer *http.Server
var debugHTTPListener net.Listener
if h.cfg.MetricsAddr != "" {
debugHTTPListener, err = (&net.ListenConfig{}).Listen(ctx, "tcp", h.cfg.MetricsAddr)
if err != nil {
return fmt.Errorf("failed to bind to TCP address: %w", err)
}
debugHTTPServer = h.debugHTTPServer()
errorGroup.Go(func() error { return debugHTTPServer.Serve(debugHTTPListener) })
log.Info().
Msgf("listening and serving debug and metrics on: %s", h.cfg.MetricsAddr)
} else {
log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)")
}
debugHTTPServer := h.debugHTTPServer()
errorGroup.Go(func() error { return debugHTTPServer.Serve(debugHTTPListener) })
log.Info().
Msgf("listening and serving debug and metrics on: %s", h.cfg.MetricsAddr)
var tailsqlContext context.Context
if tailsqlEnabled {
@ -794,16 +806,25 @@ func (h *Headscale) Serve() error {
h.ephemeralGC.Close()
// Gracefully shut down servers
ctx, cancel := context.WithTimeout(
context.Background(),
shutdownCtx, cancel := context.WithTimeout(
context.WithoutCancel(ctx),
types.HTTPShutdownTimeout,
)
info("shutting down debug http server")
if err := debugHTTPServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("failed to shutdown prometheus http")
defer cancel()
if debugHTTPServer != nil {
info("shutting down debug http server")
err := debugHTTPServer.Shutdown(shutdownCtx)
if err != nil {
log.Error().Err(err).Msg("failed to shutdown prometheus http")
}
}
info("shutting down main http server")
if err := httpServer.Shutdown(ctx); err != nil {
err := httpServer.Shutdown(shutdownCtx)
if err != nil {
log.Error().Err(err).Msg("failed to shutdown http")
}
@ -829,7 +850,10 @@ func (h *Headscale) Serve() error {
// Close network listeners
info("closing network listeners")
debugHTTPListener.Close()
if debugHTTPListener != nil {
debugHTTPListener.Close()
}
httpListener.Close()
grpcGatewayConn.Close()
@ -847,9 +871,6 @@ func (h *Headscale) Serve() error {
log.Info().
Msg("Headscale stopped")
// And we're done:
cancel()
return
}
}
@ -877,6 +898,11 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
Cache: autocert.DirCache(h.cfg.TLS.LetsEncrypt.CacheDir),
Client: &acme.Client{
DirectoryURL: h.cfg.ACMEURL,
HTTPClient: &http.Client{
Transport: &acmeLogger{
rt: http.DefaultTransport,
},
},
},
Email: h.cfg.ACMEEmail,
}
@ -935,18 +961,6 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) {
}
}
func notFoundHandler(
writer http.ResponseWriter,
req *http.Request,
) {
log.Trace().
Interface("header", req.Header).
Interface("proto", req.Proto).
Interface("url", req.URL).
Msg("Request did not match")
writer.WriteHeader(http.StatusNotFound)
}
func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
dir := filepath.Dir(path)
err := util.EnsureDir(dir)
@ -994,6 +1008,31 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) {
// Change is used to send changes to nodes.
// All changes should be enqueued here; empty changes are automatically
// ignored.
func (h *Headscale) Change(cs ...change.ChangeSet) {
func (h *Headscale) Change(cs ...change.Change) {
h.mapBatcher.AddWork(cs...)
}
// acmeLogger is an http.RoundTripper middleware that inspects ACME/autocert
// HTTPS calls and logs when they fail.
type acmeLogger struct {
rt http.RoundTripper
}
// RoundTrip logs ACME/autocert failures, both when err != nil and when the
// HTTP status code indicates that the request failed.
func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) {
resp, err := l.rt.RoundTrip(req)
if err != nil {
log.Error().Err(err).Str("url", req.URL.String()).Msg("ACME request failed")
return nil, err
}
if resp.StatusCode >= http.StatusBadRequest {
defer resp.Body.Close()
body, _ := io.ReadAll(resp.Body)
log.Error().Int("status_code", resp.StatusCode).Str("url", req.URL.String()).Bytes("body", body).Msg("ACME request returned error")
}
return resp, nil
}
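For illustration, a minimal sketch of how this logger could be exercised; the test name and httptest server are hypothetical, and the snippet assumes it lives in the same package as acmeLogger:

import (
	"net/http"
	"net/http/httptest"
	"testing"
)

// Sketch only: wraps the default transport with acmeLogger and issues a
// request that returns HTTP 500, which RoundTrip logs before handing the
// response back to the caller.
func TestAcmeLoggerLogsServerErrors(t *testing.T) {
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) {
		http.Error(w, "boom", http.StatusInternalServerError)
	}))
	defer srv.Close()

	client := &http.Client{Transport: &acmeLogger{rt: http.DefaultTransport}}

	resp, err := client.Get(srv.URL)
	if err != nil {
		t.Fatalf("unexpected transport error: %v", err)
	}
	resp.Body.Close()

	if resp.StatusCode != http.StatusInternalServerError {
		t.Fatalf("expected 500, got %d", resp.StatusCode)
	}
}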

View file

@ -9,12 +9,8 @@
--md-primary-fg-color: #4051b5;
--md-accent-fg-color: #526cfe;
--md-typeset-a-color: var(--md-primary-fg-color);
--md-text-font:
"Roboto", -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue",
Arial, sans-serif;
--md-code-font:
"Roboto Mono", "SF Mono", Monaco, "Cascadia Code", Consolas, "Courier New",
monospace;
--md-text-font: "Roboto", -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif;
--md-code-font: "Roboto Mono", "SF Mono", Monaco, "Cascadia Code", Consolas, "Courier New", monospace;
}
/* Base Typography */

View file

@ -668,9 +668,10 @@ func TestAuthenticationFlows(t *testing.T) {
}
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Simulate successful registration
// Simulate successful registration - send to buffered channel
// The channel is buffered (size 1), so this can complete immediately
// and handleRegister will receive the value when it starts waiting
go func() {
time.Sleep(20 * time.Millisecond)
user := app.state.CreateUserForTest("followup-user")
node := app.state.CreateNodeForTest(user, "followup-success-node")
registered <- node
@ -927,6 +928,82 @@ func TestAuthenticationFlows(t *testing.T) {
},
},
// === ADVERTISE-TAGS (RequestTags) SCENARIOS ===
// Tests for client-provided tags via --advertise-tags flag
// TEST: PreAuthKey registration rejects client-provided RequestTags
// WHAT: Tests that PreAuthKey registrations cannot use client-provided tags
// INPUT: PreAuthKey registration with RequestTags in Hostinfo
// EXPECTED: Registration fails with "requested tags [...] are invalid or not permitted" error
// WHY: PreAuthKey nodes get their tags from the key itself, not from client requests
{
name: "preauth_key_rejects_request_tags",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
t.Helper()
user := app.state.CreateUserForTest("pak-requesttags-user")
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
return tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: authKey,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "pak-requesttags-node",
RequestTags: []string{"tag:unauthorized"},
},
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: machineKey1.Public,
wantError: true,
},
// TEST: Tagged PreAuthKey ignores client-provided RequestTags
// WHAT: Tests that tagged PreAuthKey uses key tags, not client RequestTags
// INPUT: Tagged PreAuthKey registration with different RequestTags
// EXPECTED: Registration fails because RequestTags are rejected for PreAuthKey
// WHY: Tags-as-identity: PreAuthKey tags are authoritative, client cannot override
{
name: "tagged_preauth_key_rejects_client_request_tags",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
t.Helper()
user := app.state.CreateUserForTest("tagged-pak-clienttags-user")
keyTags := []string{"tag:authorized"}
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, keyTags)
if err != nil {
return "", err
}
return pak.Key, nil
},
request: func(authKey string) tailcfg.RegisterRequest {
return tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: authKey,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-pak-clienttags-node",
RequestTags: []string{"tag:client-wants-this"}, // Should be rejected
},
Expiry: time.Now().Add(24 * time.Hour),
}
},
machineKey: machineKey1.Public,
wantError: true, // RequestTags rejected for PreAuthKey registrations
},
// === RE-AUTHENTICATION SCENARIOS ===
// TEST: Existing node re-authenticates with new pre-auth key
// WHAT: Tests that existing node can re-authenticate using new pre-auth key
@ -1202,8 +1279,9 @@ func TestAuthenticationFlows(t *testing.T) {
OS: "unknown-os",
OSVersion: "999.999.999",
DeviceModel: "test-device-model",
RequestTags: []string{"invalid:tag", "another!tag"},
Services: []tailcfg.Service{{Proto: "tcp", Port: 65535}},
// Note: RequestTags are not included for PreAuthKey registrations
// since tags come from the key itself, not client requests.
Services: []tailcfg.Service{{Proto: "tcp", Port: 65535}},
},
Expiry: time.Now().Add(24 * time.Hour),
}
@ -1247,8 +1325,8 @@ func TestAuthenticationFlows(t *testing.T) {
app.state.SetRegistrationCacheEntry(regID, nodeToRegister)
// Simulate registration that returns nil (cache expired during auth)
// The channel is buffered (size 1), so this can complete immediately
go func() {
time.Sleep(20 * time.Millisecond)
registered <- nil // Nil indicates cache expiry
}()
@ -1315,9 +1393,13 @@ func TestAuthenticationFlows(t *testing.T) {
// === AUTH PROVIDER EDGE CASES ===
// TEST: Interactive workflow preserves custom hostinfo
// WHAT: Tests that custom hostinfo fields are preserved through interactive flow
// INPUT: Interactive registration with detailed hostinfo (OS, version, model, etc.)
// INPUT: Interactive registration with detailed hostinfo (OS, version, model)
// EXPECTED: Node registers with all hostinfo fields preserved
// WHY: Ensures interactive flow doesn't lose custom hostinfo data
// NOTE: RequestTags are NOT tested here because tag authorization via
// advertise-tags requires the user to have existing nodes (for IP-based
// ownership verification). New users registering their first node cannot
// claim tags via RequestTags - they must use a tagged PreAuthKey instead.
{
name: "interactive_workflow_with_custom_hostinfo",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
@ -1331,7 +1413,6 @@ func TestAuthenticationFlows(t *testing.T) {
OS: "linux",
OSVersion: "20.04",
DeviceModel: "server",
RequestTags: []string{"tag:server"},
},
Expiry: time.Now().Add(24 * time.Hour),
}
@ -1353,7 +1434,6 @@ func TestAuthenticationFlows(t *testing.T) {
assert.Equal(t, "linux", node.Hostinfo().OS())
assert.Equal(t, "20.04", node.Hostinfo().OSVersion())
assert.Equal(t, "server", node.Hostinfo().DeviceModel())
assert.Contains(t, node.Hostinfo().RequestTags().AsSlice(), "tag:server")
}
},
},
@ -2001,11 +2081,8 @@ func TestAuthenticationFlows(t *testing.T) {
}(i)
}
// All should wait since no auth completion happened
// After a short delay, they should timeout or be waiting
time.Sleep(100 * time.Millisecond)
// Now complete the authentication to signal one of them
// Complete the authentication to signal the waiting goroutines
// The goroutines will receive from the buffered channel when ready
registrationID, err := extractRegistrationIDFromAuthURL(authURL)
require.NoError(t, err)
@ -2329,10 +2406,8 @@ func TestAuthenticationFlows(t *testing.T) {
responseChan <- resp
}()
// Give followup time to start waiting
time.Sleep(50 * time.Millisecond)
// Complete authentication for second registration
// The goroutine will receive the node from the buffered channel
_, _, err = app.state.HandleNodeFromAuthPath(
regID2,
types.UserID(user.ID),
@ -2525,10 +2600,7 @@ func runInteractiveWorkflowTest(t *testing.T, tt struct {
responseChan <- resp
}()
// Give the followup request time to start waiting
time.Sleep(50 * time.Millisecond)
// Now complete the authentication - this will signal the waiting followup request
// Complete the authentication - the goroutine will receive from the buffered channel
user := app.state.CreateUserForTest("interactive-test-user")
_, _, err = app.state.HandleNodeFromAuthPath(
registrationID,
@ -3234,7 +3306,7 @@ func TestIssue2830_ExistingNodeReregistersWithExpiredKey(t *testing.T) {
// Create a valid key (will expire it later)
expiry := time.Now().Add(1 * time.Hour)
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, &expiry, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, &expiry, nil)
require.NoError(t, err)
machineKey := key.NewMachine()
@ -3423,3 +3495,49 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing.
nodesAfterAttack := app.state.ListNodesByUser(types.UserID(user.ID))
require.Equal(t, 1, nodesAfterAttack.Len(), "Should still have exactly one node (attack prevented)")
}
// TestWebAuthRejectsUnauthorizedRequestTags tests that web auth registrations
// validate RequestTags against policy and reject unauthorized tags.
func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) {
t.Parallel()
app := createTestApp(t)
// Create a user that will authenticate via web auth
user := app.state.CreateUserForTest("webauth-tags-user")
machineKey := key.NewMachine()
nodeKey := key.NewNode()
// Simulate a registration cache entry (as would be created during web auth)
registrationID := types.MustRegistrationID()
regEntry := types.NewRegisterNode(types.Node{
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "webauth-tags-node",
Hostinfo: &tailcfg.Hostinfo{
Hostname: "webauth-tags-node",
RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy
},
})
app.state.SetRegistrationCacheEntry(registrationID, regEntry)
// Complete the web auth - should fail because tag is unauthorized
_, _, err := app.state.HandleNodeFromAuthPath(
registrationID,
types.UserID(user.ID),
nil, // no expiry
"webauth",
)
// Expect error due to unauthorized tags
require.Error(t, err, "HandleNodeFromAuthPath should reject unauthorized RequestTags")
require.Contains(t, err.Error(), "requested tags",
"Error should indicate requested tags are invalid or not permitted")
require.Contains(t, err.Error(), "tag:unauthorized",
"Error should mention the rejected tag")
// Verify no node was created
_, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.False(t, found, "Node should not be created when tags are unauthorized")
}

View file

@ -12,6 +12,14 @@ import (
"tailscale.com/util/set"
)
const (
// minVersionParts is the minimum number of version parts needed for major.minor.
minVersionParts = 2
// legacyDERPCapVer is the capability version at which LegacyDERP can be cleaned up.
legacyDERPCapVer = 111
)
// CanOldCodeBeCleanedUp is intended to be called on startup to see if
// there is old code that can be cleaned up; entries should contain
// a CapVer at which something can be cleaned up and panic if it can.
@ -19,7 +27,7 @@ import (
//
// All uses of Capability version checks should be listed here.
func CanOldCodeBeCleanedUp() {
if MinSupportedCapabilityVersion >= 111 {
if MinSupportedCapabilityVersion >= legacyDERPCapVer {
panic("LegacyDERP can be cleaned up in tail.go")
}
}
@ -44,12 +52,25 @@ func TailscaleVersion(ver tailcfg.CapabilityVersion) string {
}
// CapabilityVersion returns the CapabilityVersion for the given Tailscale version.
// It accepts both full versions (v1.90.1) and minor versions (v1.90).
func CapabilityVersion(ver string) tailcfg.CapabilityVersion {
if !strings.HasPrefix(ver, "v") {
ver = "v" + ver
}
return tailscaleToCapVer[ver]
// Try direct lookup first (works for minor versions like v1.90)
if cv, ok := tailscaleToCapVer[ver]; ok {
return cv
}
// Try extracting minor version from full version (v1.90.1 -> v1.90)
parts := strings.Split(strings.TrimPrefix(ver, "v"), ".")
if len(parts) >= minVersionParts {
minor := "v" + parts[0] + "." + parts[1]
return tailscaleToCapVer[minor]
}
return 0
}
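A minimal sketch of the lookup behaviour, assuming it sits in the same package as CapabilityVersion; the test name is hypothetical and the expected values are taken from the generated tailscaleToCapVer table in this change:

import (
	"testing"

	"tailscale.com/tailcfg"
)

// Sketch only: v1.90 maps to 130 in the generated table at the time of this change.
func TestCapabilityVersionLookupSketch(t *testing.T) {
	cases := []struct {
		in   string
		want tailcfg.CapabilityVersion
	}{
		{"v1.90", 130},  // direct minor-version lookup
		{"1.90.4", 130}, // "v" is prepended, then truncated to v1.90
		{"v0.1", 0},     // unknown versions return the zero value
	}

	for _, c := range cases {
		if got := CapabilityVersion(c.in); got != c.want {
			t.Errorf("CapabilityVersion(%q) = %d, want %d", c.in, got, c.want)
		}
	}
}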
// TailscaleLatest returns the n latest Tailscale versions.

View file

@ -5,54 +5,79 @@ package capver
import "tailscale.com/tailcfg"
var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{
"v1.70.0": 102,
"v1.72.0": 104,
"v1.72.1": 104,
"v1.74.0": 106,
"v1.74.1": 106,
"v1.76.0": 106,
"v1.76.1": 106,
"v1.76.6": 106,
"v1.78.0": 109,
"v1.78.1": 109,
"v1.80.0": 113,
"v1.80.1": 113,
"v1.80.2": 113,
"v1.80.3": 113,
"v1.82.0": 115,
"v1.82.5": 115,
"v1.84.0": 116,
"v1.84.1": 116,
"v1.84.2": 116,
"v1.86.0": 122,
"v1.86.2": 123,
"v1.88.1": 125,
"v1.88.3": 125,
"v1.90.1": 130,
"v1.90.2": 130,
"v1.90.3": 130,
"v1.90.4": 130,
"v1.90.6": 130,
"v1.90.8": 130,
"v1.90.9": 130,
"v1.24": 32,
"v1.26": 32,
"v1.28": 32,
"v1.30": 41,
"v1.32": 46,
"v1.34": 51,
"v1.36": 56,
"v1.38": 58,
"v1.40": 61,
"v1.42": 62,
"v1.44": 63,
"v1.46": 65,
"v1.48": 68,
"v1.50": 74,
"v1.52": 79,
"v1.54": 79,
"v1.56": 82,
"v1.58": 85,
"v1.60": 87,
"v1.62": 88,
"v1.64": 90,
"v1.66": 95,
"v1.68": 97,
"v1.70": 102,
"v1.72": 104,
"v1.74": 106,
"v1.76": 106,
"v1.78": 109,
"v1.80": 113,
"v1.82": 115,
"v1.84": 116,
"v1.86": 123,
"v1.88": 125,
"v1.90": 130,
"v1.92": 131,
}
var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{
102: "v1.70.0",
104: "v1.72.0",
106: "v1.74.0",
109: "v1.78.0",
113: "v1.80.0",
115: "v1.82.0",
116: "v1.84.0",
122: "v1.86.0",
123: "v1.86.2",
125: "v1.88.1",
130: "v1.90.1",
32: "v1.24",
41: "v1.30",
46: "v1.32",
51: "v1.34",
56: "v1.36",
58: "v1.38",
61: "v1.40",
62: "v1.42",
63: "v1.44",
65: "v1.46",
68: "v1.48",
74: "v1.50",
79: "v1.52",
82: "v1.56",
85: "v1.58",
87: "v1.60",
88: "v1.62",
90: "v1.64",
95: "v1.66",
97: "v1.68",
102: "v1.70",
104: "v1.72",
106: "v1.74",
109: "v1.78",
113: "v1.80",
115: "v1.82",
116: "v1.84",
123: "v1.86",
125: "v1.88",
130: "v1.90",
131: "v1.92",
}
// SupportedMajorMinorVersions is the number of major.minor Tailscale versions supported.
const SupportedMajorMinorVersions = 9
const SupportedMajorMinorVersions = 10
// MinSupportedCapabilityVersion represents the minimum capability version
// supported by this Headscale instance (latest 10 minor versions)

View file

@ -9,9 +9,9 @@ var tailscaleLatestMajorMinorTests = []struct {
stripV bool
expected []string
}{
{3, false, []string{"v1.86", "v1.88", "v1.90"}},
{2, true, []string{"1.88", "1.90"}},
{9, true, []string{
{3, false, []string{"v1.88", "v1.90", "v1.92"}},
{2, true, []string{"1.90", "1.92"}},
{10, true, []string{
"1.74",
"1.76",
"1.78",
@ -21,6 +21,7 @@ var tailscaleLatestMajorMinorTests = []struct {
"1.86",
"1.88",
"1.90",
"1.92",
}},
{0, false, nil},
}
@ -29,11 +30,11 @@ var capVerMinimumTailscaleVersionTests = []struct {
input tailcfg.CapabilityVersion
expected string
}{
{106, "v1.74.0"},
{102, "v1.70.0"},
{104, "v1.72.0"},
{109, "v1.78.0"},
{113, "v1.80.0"},
{106, "v1.74"},
{32, "v1.24"},
{41, "v1.30"},
{46, "v1.32"},
{51, "v1.34"},
{9001, ""}, // Test case for a version higher than any in the map
{60, ""}, // Test case for a version lower than any in the map
}

View file

@ -1,9 +1,9 @@
package db
import (
"math/rand"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
@ -68,31 +68,18 @@ func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) {
gc.Cancel(nodeID)
}
// Create a channel to signal when we're done with cleanup checks
cleanupDone := make(chan struct{})
// Close GC
gc.Close()
// Close GC and check for leaks in a separate goroutine
go func() {
// Close GC
gc.Close()
// Give any potential leaked goroutines a chance to exit
// Still need a small sleep here as we're checking for absence of goroutines
time.Sleep(oneHundred)
// Check for leaked goroutines
// Wait for goroutines to clean up and verify no leaks
assert.EventuallyWithT(t, func(c *assert.CollectT) {
finalGoroutines := runtime.NumGoroutine()
t.Logf("Final number of goroutines: %d", finalGoroutines)
// NB: We have to allow for a small number of extra goroutines created by the test itself
assert.LessOrEqual(t, finalGoroutines, initialGoroutines+5,
assert.LessOrEqual(c, finalGoroutines, initialGoroutines+5,
"There are significantly more goroutines after GC usage, which suggests a leak")
}, time.Second, 10*time.Millisecond, "goroutines should clean up after GC close")
close(cleanupDone)
}()
// Wait for cleanup to complete
<-cleanupDone
t.Logf("Final number of goroutines: %d", runtime.NumGoroutine())
}
// TestEphemeralGarbageCollectorReschedule is a test for the rescheduling of nodes in EphemeralGarbageCollector().
@ -103,10 +90,14 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
}
// Start GC
@ -125,10 +116,15 @@ func TestEphemeralGarbageCollectorReschedule(t *testing.T) {
// Reschedule the same node with a shorter expiry
gc.Schedule(nodeID, shortExpiry)
// Wait for deletion
time.Sleep(shortExpiry * 2)
// Wait for deletion notification with timeout
select {
case deletedNodeID := <-deletionNotifier:
assert.Equal(t, nodeID, deletedNodeID, "The correct node should be deleted")
case <-time.After(time.Second):
t.Fatal("Timed out waiting for node deletion")
}
// Verify that the node was deleted once
// Verify that the node was deleted exactly once
deleteMutex.Lock()
assert.Len(t, deletedIDs, 1, "Node should be deleted exactly once")
assert.Equal(t, nodeID, deletedIDs[0], "The correct node should be deleted")
@ -203,18 +199,24 @@ func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
var deletedIDs []types.NodeID
var deleteMutex sync.Mutex
deletionNotifier := make(chan types.NodeID, 1)
deleteFunc := func(nodeID types.NodeID) {
deleteMutex.Lock()
deletedIDs = append(deletedIDs, nodeID)
deleteMutex.Unlock()
deletionNotifier <- nodeID
}
// Start the GC
gc := NewEphemeralGarbageCollector(deleteFunc)
go gc.Start()
const longExpiry = 1 * time.Hour
const shortExpiry = fifty
const (
longExpiry = 1 * time.Hour
shortWait = fifty * 2
)
// Schedule node deletion with a long expiry
gc.Schedule(types.NodeID(1), longExpiry)
@ -222,8 +224,13 @@ func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) {
// Close the GC before the timer
gc.Close()
// Wait a short time
time.Sleep(shortExpiry * 2)
// Verify that no deletion occurred within a reasonable time
select {
case <-deletionNotifier:
t.Fatal("Node was deleted after GC was closed, which should not happen")
case <-time.After(shortWait):
// Expected: no deletion should occur
}
// Verify that no deletion occurred
deleteMutex.Lock()
@ -265,29 +272,17 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
// Close GC right away
gc.Close()
// Use a channel to signal when we should check for goroutine count
gcClosedCheck := make(chan struct{})
go func() {
// Give the GC time to fully close and clean up resources
// This is still time-based but only affects when we check the goroutine count,
// not the actual test logic
time.Sleep(oneHundred)
close(gcClosedCheck)
}()
// Now try to schedule node for deletion with a very short expiry
// If the Schedule operation incorrectly creates a timer, it would fire quickly
nodeID := types.NodeID(1)
gc.Schedule(nodeID, 1*time.Millisecond)
// Set up a timeout channel for our test
timeout := time.After(fiveHundred)
// Check if any node was deleted (which shouldn't happen)
// Use timeout to wait for potential deletion
select {
case <-nodeDeleted:
t.Fatal("Node was deleted after GC was closed, which should not happen")
case <-timeout:
case <-time.After(fiveHundred):
// This is the expected path - no deletion should occur
}
@ -298,13 +293,14 @@ func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) {
assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close")
// Check for goroutine leaks after GC is fully closed
<-gcClosedCheck
finalGoroutines := runtime.NumGoroutine()
t.Logf("Final number of goroutines: %d", finalGoroutines)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
finalGoroutines := runtime.NumGoroutine()
// Allow for small fluctuations in goroutine count for testing routines etc
assert.LessOrEqual(c, finalGoroutines, initialGoroutines+2,
"There should be no significant goroutine leaks when Schedule is called after Close")
}, time.Second, 10*time.Millisecond, "goroutines should clean up after GC close")
// Allow for small fluctuations in goroutine count for testing routines etc
assert.LessOrEqual(t, finalGoroutines, initialGoroutines+2,
"There should be no significant goroutine leaks when Schedule is called after Close")
t.Logf("Final number of goroutines: %d", runtime.NumGoroutine())
}
// TestEphemeralGarbageCollectorConcurrentScheduleAndClose tests the behavior of the garbage collector
@ -331,7 +327,8 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
// Number of concurrent scheduling goroutines
const numSchedulers = 10
const nodesPerScheduler = 50
const schedulingDuration = fiveHundred
const closeAfterNodes = 25 // Close GC after this many nodes per scheduler
// Use WaitGroup to wait for all scheduling goroutines to finish
var wg sync.WaitGroup
@ -340,6 +337,9 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
// Create a stopper channel to signal scheduling goroutines to stop
stopScheduling := make(chan struct{})
// Track how many nodes have been scheduled
var scheduledCount int64
// Launch goroutines that continuously schedule nodes
for schedulerIndex := range numSchedulers {
go func(schedulerID int) {
@ -355,18 +355,23 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
default:
nodeID := types.NodeID(baseNodeID + j + 1)
gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test
atomic.AddInt64(&scheduledCount, 1)
// Random (short) sleep to introduce randomness/variability
time.Sleep(time.Duration(rand.Intn(5)) * time.Millisecond)
// Yield to other goroutines to introduce variability
runtime.Gosched()
}
}
}(schedulerIndex)
}
// After a short delay, close the garbage collector while schedulers are still running
// Close the garbage collector after some nodes have been scheduled
go func() {
defer wg.Done()
time.Sleep(schedulingDuration / 2)
// Wait until enough nodes have been scheduled
for atomic.LoadInt64(&scheduledCount) < int64(numSchedulers*closeAfterNodes) {
runtime.Gosched()
}
// Close GC
gc.Close()
@ -378,14 +383,13 @@ func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) {
// Wait for all goroutines to complete
wg.Wait()
// Wait a bit longer to allow any leaked goroutines to do their work
time.Sleep(oneHundred)
// Check for leaks using EventuallyWithT
assert.EventuallyWithT(t, func(c *assert.CollectT) {
finalGoroutines := runtime.NumGoroutine()
// Allow for a reasonable small variable routine count due to testing
assert.LessOrEqual(c, finalGoroutines, initialGoroutines+5,
"There should be no significant goroutine leaks during concurrent Schedule and Close operations")
}, time.Second, 10*time.Millisecond, "goroutines should clean up")
// Check for leaks
finalGoroutines := runtime.NumGoroutine()
t.Logf("Final number of goroutines: %d", finalGoroutines)
// Allow for a reasonable small variable routine count due to testing
assert.LessOrEqual(t, finalGoroutines, initialGoroutines+5,
"There should be no significant goroutine leaks during concurrent Schedule and Close operations")
t.Logf("Final number of goroutines: %d", runtime.NumGoroutine())
}

View file

@ -6,7 +6,9 @@ import (
"math/big"
"net/netip"
"regexp"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
@ -445,7 +447,7 @@ func TestAutoApproveRoutes(t *testing.T) {
RoutableIPs: tt.routes,
},
Tags: []string{"tag:exit"},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
}
err = adb.DB.Save(&nodeTagged).Error
@ -507,23 +509,48 @@ func TestEphemeralGarbageCollectorOrder(t *testing.T) {
got := []types.NodeID{}
var mu sync.Mutex
deletionCount := make(chan struct{}, 10)
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
mu.Lock()
defer mu.Unlock()
got = append(got, ni)
deletionCount <- struct{}{}
})
go e.Start()
go e.Schedule(1, 1*time.Second)
go e.Schedule(2, 2*time.Second)
go e.Schedule(3, 3*time.Second)
go e.Schedule(4, 4*time.Second)
// Use shorter timeouts for faster tests
go e.Schedule(1, 50*time.Millisecond)
go e.Schedule(2, 100*time.Millisecond)
go e.Schedule(3, 150*time.Millisecond)
go e.Schedule(4, 200*time.Millisecond)
time.Sleep(time.Second)
// Wait for first deletion (node 1 at 50ms)
select {
case <-deletionCount:
case <-time.After(time.Second):
t.Fatal("timeout waiting for first deletion")
}
// Cancel nodes 2 and 4
go e.Cancel(2)
go e.Cancel(4)
time.Sleep(6 * time.Second)
// Wait for node 3 to be deleted (at 150ms)
select {
case <-deletionCount:
case <-time.After(time.Second):
t.Fatal("timeout waiting for second deletion")
}
// Give a bit more time for any unexpected deletions
select {
case <-deletionCount:
// Unexpected - more deletions than expected
case <-time.After(300 * time.Millisecond):
// Expected - no more deletions
}
e.Close()
@ -541,20 +568,30 @@ func TestEphemeralGarbageCollectorLoads(t *testing.T) {
want := 1000
var deletedCount int64
e := NewEphemeralGarbageCollector(func(ni types.NodeID) {
mu.Lock()
defer mu.Unlock()
time.Sleep(time.Duration(generateRandomNumber(t, 3)) * time.Millisecond)
// Yield to other goroutines to introduce variability
runtime.Gosched()
got = append(got, ni)
atomic.AddInt64(&deletedCount, 1)
})
go e.Start()
// Use shorter expiry for faster tests
for i := range want {
go e.Schedule(types.NodeID(i), 1*time.Second)
go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec // test code, no overflow risk
}
time.Sleep(10 * time.Second)
// Wait for all deletions to complete
assert.EventuallyWithT(t, func(c *assert.CollectT) {
count := atomic.LoadInt64(&deletedCount)
assert.Equal(c, int64(want), count, "all nodes should be deleted")
}, 10*time.Second, 50*time.Millisecond, "waiting for all deletions")
e.Close()

View file

@ -364,7 +364,13 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) {
return
}
log.Error().Caller().Err(err).Msgf("STUN ReadFrom")
time.Sleep(time.Second)
// Rate limit error logging - wait before retrying, but respect context cancellation
select {
case <-ctx.Done():
return
case <-time.After(time.Second):
}
continue
}
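The same cancellation-aware backoff could be factored into a small helper; a sketch under the assumption that no such helper already exists in the codebase (the name sleepCtx is hypothetical):

// sleepCtx blocks for d or until ctx is cancelled.
// It returns false if the context was cancelled before the delay elapsed.
// The call site above would then read: if !sleepCtx(ctx, time.Second) { return }
func sleepCtx(ctx context.Context, d time.Duration) bool {
	select {
	case <-ctx.Done():
		return false
	case <-time.After(d):
		return true
	}
}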

View file

@ -16,7 +16,6 @@ import (
"time"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
@ -59,14 +58,9 @@ func (api headscaleV1APIServer) CreateUser(
return nil, status.Errorf(codes.Internal, "failed to create user: %s", err)
}
c := change.UserAdded(types.UserID(user.ID))
// TODO(kradalby): Both of these might be policy changes, find a better way to merge.
if !policyChanged.Empty() {
c.Change = change.Policy
}
api.h.Change(c)
// CreateUser returns a policy change response if the user creation affected policy.
// This triggers a full policy re-evaluation for all connected nodes.
api.h.Change(policyChanged)
return &v1.CreateUserResponse{User: user.Proto()}, nil
}
@ -110,7 +104,8 @@ func (api headscaleV1APIServer) DeleteUser(
return nil, err
}
api.h.Change(change.UserRemoved(types.UserID(user.ID)))
// User deletion may affect policy, so trigger a full policy re-evaluation.
api.h.Change(change.UserRemoved())
return &v1.DeleteUserResponse{}, nil
}
@ -556,13 +551,7 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N
resp.User = types.TaggedDevices.Proto()
}
var tags []string
for _, tag := range node.RequestTags() {
if state.NodeCanHaveTag(node, tag) {
tags = append(tags, tag)
}
}
resp.ValidTags = lo.Uniq(append(tags, node.Tags().AsSlice()...))
resp.ValidTags = node.Tags().AsSlice()
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
response[index] = resp

View file

@ -105,7 +105,7 @@ func TestSetTags_Conversion(t *testing.T) {
tags: []string{"tag:server"},
wantErr: true,
wantCode: codes.InvalidArgument,
wantErrMessage: "invalid or unauthorized tags",
wantErrMessage: "requested tags",
},
{
// Conversion is allowed, but tag authorization fails without tagOwners
@ -114,7 +114,7 @@ func TestSetTags_Conversion(t *testing.T) {
tags: []string{"tag:server", "tag:database"},
wantErr: true,
wantCode: codes.InvalidArgument,
wantErrMessage: "invalid or unauthorized tags",
wantErrMessage: "requested tags",
},
{
name: "reject non-existent node",

View file

@ -11,7 +11,6 @@ import (
"strings"
"time"
"github.com/chasefleming/elem-go/styles"
"github.com/gorilla/mux"
"github.com/juanfont/headscale/hscontrol/assets"
"github.com/juanfont/headscale/hscontrol/templates"
@ -228,13 +227,6 @@ func (h *Headscale) VersionHandler(
}
}
var codeStyleRegisterWebAPI = styles.Props{
styles.Display: "block",
styles.Padding: "20px",
styles.Border: "1px solid #bbb",
styles.BackgroundColor: "#eee",
}
type AuthProviderWeb struct {
serverURL string
}

View file

@ -13,18 +13,13 @@ import (
"github.com/puzpuzpuz/xsync/v4"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var (
mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
Help: "total count of mapresponses generated by response type and change type",
}, []string{"response_type", "change_type"})
errNodeNotFoundInNodeStore = errors.New("node not found in NodeStore")
)
var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: "headscale",
Name: "mapresponse_generated_total",
Help: "total count of mapresponses generated by response type",
}, []string{"response_type"})
type batcherFunc func(cfg *types.Config, state *state.State) Batcher
@ -36,8 +31,8 @@ type Batcher interface {
RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool
IsConnected(id types.NodeID) bool
ConnectedMap() *xsync.Map[types.NodeID, bool]
AddWork(c ...change.ChangeSet)
MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error)
AddWork(r ...change.Change)
MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error)
DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error)
}
@ -51,7 +46,7 @@ func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeB
workCh: make(chan work, workers*200),
nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](),
connected: xsync.NewMap[types.NodeID, *time.Time](),
pendingChanges: xsync.NewMap[types.NodeID, []change.ChangeSet](),
pendingChanges: xsync.NewMap[types.NodeID, []change.Change](),
}
}
@ -69,15 +64,21 @@ type nodeConnection interface {
nodeID() types.NodeID
version() tailcfg.CapabilityVersion
send(data *tailcfg.MapResponse) error
// computePeerDiff returns peers that were previously sent but are no longer in the current list.
computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID)
// updateSentPeers updates the tracking of which peers have been sent to this node.
updateSentPeers(resp *tailcfg.MapResponse)
}
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID that is based on the provided [change.ChangeSet].
func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion, mapper *mapper, c change.ChangeSet) (*tailcfg.MapResponse, error) {
if c.Empty() {
return nil, nil
// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change].
func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) {
nodeID := nc.nodeID()
version := nc.version()
if r.IsEmpty() {
return nil, nil //nolint:nilnil // Empty response means nothing to send
}
// Validate inputs before processing
if nodeID == 0 {
return nil, fmt.Errorf("invalid nodeID: %d", nodeID)
}
@ -86,141 +87,58 @@ func generateMapResponse(nodeID types.NodeID, version tailcfg.CapabilityVersion,
return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID)
}
// Handle self-only responses
if r.IsSelfOnly() && r.TargetNode != nodeID {
return nil, nil //nolint:nilnil // No response needed for other nodes when self-only
}
var (
mapResp *tailcfg.MapResponse
err error
responseType string
mapResp *tailcfg.MapResponse
err error
)
// Record metric when function exits
defer func() {
if err == nil && mapResp != nil && responseType != "" {
mapResponseGenerated.WithLabelValues(responseType, c.Change.String()).Inc()
}
}()
// Track metric using categorized type, not free-form reason
mapResponseGenerated.WithLabelValues(r.Type()).Inc()
switch c.Change {
case change.DERP:
responseType = "derp"
mapResp, err = mapper.derpMapResponse(nodeID)
// Check if this requires runtime peer visibility computation (e.g., policy changes)
if r.RequiresRuntimePeerComputation {
currentPeers := mapper.state.ListPeers(nodeID)
case change.NodeCameOnline, change.NodeWentOffline:
if c.IsSubnetRouter {
// TODO(kradalby): This can potentially be a peer update of the old and new subnet router.
responseType = "full"
mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
// Trust the change type for online/offline status to avoid race conditions
// between NodeStore updates and change processing
responseType = string(patchResponseDebug)
onlineStatus := c.Change == change.NodeCameOnline
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
Online: ptr.To(onlineStatus),
},
})
currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len())
for _, peer := range currentPeers.All() {
currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID())
}
case change.NodeNewOrUpdate:
// If the node is the one being updated, we send a self update that preserves peer information
// to ensure the node sees changes to its own properties (e.g., hostname/DNS name changes)
// without losing its view of peer status during rapid reconnection cycles
if c.IsSelfUpdate(nodeID) {
responseType = "self"
mapResp, err = mapper.selfMapResponse(nodeID, version)
} else {
responseType = "change"
mapResp, err = mapper.peerChangeResponse(nodeID, version, c.NodeID)
}
case change.NodeRemove:
responseType = "remove"
mapResp, err = mapper.peerRemovedResponse(nodeID, c.NodeID)
case change.NodeKeyExpiry:
// If the node is the one whose key is expiring, we send a "full" self update
// as nodes will ignore patch updates about themselves (?).
if c.IsSelfUpdate(nodeID) {
responseType = "self"
mapResp, err = mapper.selfMapResponse(nodeID, version)
// mapResp, err = mapper.fullMapResponse(nodeID, version)
} else {
responseType = "patch"
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{
{
NodeID: c.NodeID.NodeID(),
KeyExpiry: c.NodeExpiry,
},
})
}
case change.NodeEndpoint, change.NodeDERP:
// Endpoint or DERP changes can be sent as lightweight patches.
// Query the NodeStore for the current peer state to construct the PeerChange.
// Even if only endpoint or only DERP changed, we include both in the patch
// since they're often updated together and it's minimal overhead.
responseType = "patch"
peer, found := mapper.state.GetNodeByID(c.NodeID)
if !found {
return nil, fmt.Errorf("%w: %d", errNodeNotFoundInNodeStore, c.NodeID)
}
peerChange := &tailcfg.PeerChange{
NodeID: c.NodeID.NodeID(),
Endpoints: peer.Endpoints().AsSlice(),
DERPRegion: 0, // Will be set below if available
}
// Extract DERP region from Hostinfo if available
if hi := peer.AsStruct().Hostinfo; hi != nil && hi.NetInfo != nil {
peerChange.DERPRegion = hi.NetInfo.PreferredDERP
}
mapResp, err = mapper.peerChangedPatchResponse(nodeID, []*tailcfg.PeerChange{peerChange})
default:
// The following will always hit this:
// change.Full, change.Policy
responseType = "full"
mapResp, err = mapper.fullMapResponse(nodeID, version)
removedPeers := nc.computePeerDiff(currentPeerIDs)
mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers)
} else {
mapResp, err = mapper.buildFromChange(nodeID, version, &r)
}
if err != nil {
return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err)
}
// TODO(kradalby): Is this necessary?
// Validate the generated map response - only check for nil response
// Note: mapResp.Node can be nil for peer updates, which is valid
if mapResp == nil && c.Change != change.DERP && c.Change != change.NodeRemove {
return nil, fmt.Errorf("generated nil map response for nodeID %d change %s", nodeID, c.Change.String())
}
return mapResp, nil
}
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.ChangeSet].
func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) error {
// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change].
func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error {
if nc == nil {
return errors.New("nodeConnection is nil")
}
nodeID := nc.nodeID()
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("change.type", c.Change.String()).Msg("Node change processing started because change notification received")
log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received")
var data *tailcfg.MapResponse
var err error
data, err = generateMapResponse(nodeID, nc.version(), mapper, c)
data, err := generateMapResponse(nc, mapper, r)
if err != nil {
return fmt.Errorf("generating map response for node %d: %w", nodeID, err)
}
if data == nil {
// No data to send is valid for some change types
// No data to send is valid for some response types
return nil
}
@ -230,6 +148,9 @@ func handleNodeChange(nc nodeConnection, mapper *mapper, c change.ChangeSet) err
return fmt.Errorf("sending map response to node %d: %w", nodeID, err)
}
// Update peer tracking after successful send
nc.updateSentPeers(data)
return nil
}
@ -241,7 +162,7 @@ type workResult struct {
// work represents a unit of work to be processed by workers.
type work struct {
c change.ChangeSet
c change.Change
nodeID types.NodeID
resultCh chan<- workResult // optional channel for synchronous operations
}

View file

@ -1,8 +1,8 @@
package mapper
import (
"context"
"crypto/rand"
"errors"
"fmt"
"sync"
"sync/atomic"
@ -16,6 +16,8 @@ import (
"tailscale.com/types/ptr"
)
var errConnectionClosed = errors.New("connection channel already closed")
// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention.
type LockFreeBatcher struct {
tick *time.Ticker
@ -26,16 +28,16 @@ type LockFreeBatcher struct {
connected *xsync.Map[types.NodeID, *time.Time]
// Work queue channel
workCh chan work
ctx context.Context
cancel context.CancelFunc
workCh chan work
workChOnce sync.Once // Ensures workCh is only closed once
done chan struct{}
doneOnce sync.Once // Ensures done is only closed once
// Batching state
pendingChanges *xsync.Map[types.NodeID, []change.ChangeSet]
pendingChanges *xsync.Map[types.NodeID, []change.Change]
// Metrics
totalNodes atomic.Int64
totalUpdates atomic.Int64
workQueuedCount atomic.Int64
workProcessed atomic.Int64
workErrors atomic.Int64
@ -140,28 +142,27 @@ func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRespo
}
// AddWork queues a change to be processed by the batcher.
func (b *LockFreeBatcher) AddWork(c ...change.ChangeSet) {
b.addWork(c...)
func (b *LockFreeBatcher) AddWork(r ...change.Change) {
b.addWork(r...)
}
func (b *LockFreeBatcher) Start() {
b.ctx, b.cancel = context.WithCancel(context.Background())
b.done = make(chan struct{})
go b.doWork()
}
func (b *LockFreeBatcher) Close() {
if b.cancel != nil {
b.cancel()
b.cancel = nil
}
// Signal shutdown to all goroutines, only once
b.doneOnce.Do(func() {
if b.done != nil {
close(b.done)
}
})
// Only close workCh once
select {
case <-b.workCh:
// Channel is already closed
default:
// Only close workCh once using sync.Once to prevent races
b.workChOnce.Do(func() {
close(b.workCh)
}
})
// Close the underlying channels supplying the data to the clients.
b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool {
@ -187,8 +188,8 @@ func (b *LockFreeBatcher) doWork() {
case <-cleanupTicker.C:
// Clean up nodes that have been offline for too long
b.cleanupOfflineNodes()
case <-b.ctx.Done():
log.Info().Msg("batcher context done, stopping to feed workers")
case <-b.done:
log.Info().Msg("batcher done channel closed, stopping to feed workers")
return
}
}
@ -213,15 +214,19 @@ func (b *LockFreeBatcher) worker(workerID int) {
var result workResult
if nc, exists := b.nodes.Load(w.nodeID); exists {
var err error
result.mapResponse, err = generateMapResponse(nc.nodeID(), nc.version(), b.mapper, w.c)
result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c)
result.err = err
if result.err != nil {
b.workErrors.Add(1)
log.Error().Err(result.err).
Int("worker.id", workerID).
Uint64("node.id", w.nodeID.Uint64()).
Str("change", w.c.Change.String()).
Str("reason", w.c.Reason).
Msg("failed to generate map response for synchronous work")
} else if result.mapResponse != nil {
// Update peer tracking for synchronous responses too
nc.updateSentPeers(result.mapResponse)
}
} else {
result.err = fmt.Errorf("node %d not found", w.nodeID)
@ -236,7 +241,7 @@ func (b *LockFreeBatcher) worker(workerID int) {
// Send result
select {
case w.resultCh <- result:
case <-b.ctx.Done():
case <-b.done:
return
}
@ -254,20 +259,20 @@ func (b *LockFreeBatcher) worker(workerID int) {
b.workErrors.Add(1)
log.Error().Err(err).
Int("worker.id", workerID).
Uint64("node.id", w.c.NodeID.Uint64()).
Str("change", w.c.Change.String()).
Uint64("node.id", w.nodeID.Uint64()).
Str("reason", w.c.Reason).
Msg("failed to apply change")
}
}
case <-b.ctx.Done():
log.Debug().Int("workder.id", workerID).Msg("batcher context is done, exiting worker")
case <-b.done:
log.Debug().Int("worker.id", workerID).Msg("batcher shutting down, exiting worker")
return
}
}
}
func (b *LockFreeBatcher) addWork(c ...change.ChangeSet) {
b.addToBatch(c...)
func (b *LockFreeBatcher) addWork(r ...change.Change) {
b.addToBatch(r...)
}
// queueWork safely queues work.
@ -277,44 +282,78 @@ func (b *LockFreeBatcher) queueWork(w work) {
select {
case b.workCh <- w:
// Successfully queued
case <-b.ctx.Done():
case <-b.done:
// Batcher is shutting down
return
}
}
// addToBatch adds a change to the pending batch.
func (b *LockFreeBatcher) addToBatch(c ...change.ChangeSet) {
// addToBatch adds changes to the pending batch.
func (b *LockFreeBatcher) addToBatch(changes ...change.Change) {
// Clean up any nodes being permanently removed from the system.
//
// This handles the case where a node is deleted from state but the batcher
// still has it registered. By cleaning up here, we prevent "node not found"
// errors when workers try to generate map responses for deleted nodes.
//
// Safety: change.Change.PeersRemoved is ONLY populated when nodes are actually
// deleted from the system (via change.NodeRemoved in state.DeleteNode). Policy
// changes that affect peer visibility do NOT use this field - they set
// RequiresRuntimePeerComputation=true and compute removed peers at runtime,
// putting them in tailcfg.MapResponse.PeersRemoved (a different struct).
// Therefore, this cleanup only removes nodes that are truly being deleted,
// not nodes that are still connected but have lost visibility of certain peers.
//
// See: https://github.com/juanfont/headscale/issues/2924
for _, ch := range changes {
for _, removedID := range ch.PeersRemoved {
if _, existed := b.nodes.LoadAndDelete(removedID); existed {
b.totalNodes.Add(-1)
log.Debug().
Uint64("node.id", removedID.Uint64()).
Msg("Removed deleted node from batcher")
}
b.connected.Delete(removedID)
b.pendingChanges.Delete(removedID)
}
}
// Short circuit if any of the changes is a full update, which
// means we can skip sending individual changes.
if change.HasFull(c) {
if change.HasFull(changes) {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
b.pendingChanges.Store(nodeID, []change.ChangeSet{{Change: change.Full}})
b.pendingChanges.Store(nodeID, []change.Change{change.FullUpdate()})
return true
})
return
}
all, self := change.SplitAllAndSelf(c)
for _, changeSet := range self {
changes, _ := b.pendingChanges.LoadOrStore(changeSet.NodeID, []change.ChangeSet{})
changes = append(changes, changeSet)
b.pendingChanges.Store(changeSet.NodeID, changes)
return
}
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
rel := change.RemoveUpdatesForSelf(nodeID, all)
broadcast, targeted := change.SplitTargetedAndBroadcast(changes)
changes, _ := b.pendingChanges.LoadOrStore(nodeID, []change.ChangeSet{})
changes = append(changes, rel...)
b.pendingChanges.Store(nodeID, changes)
// Handle targeted changes - send only to the specific node
for _, ch := range targeted {
pending, _ := b.pendingChanges.LoadOrStore(ch.TargetNode, []change.Change{})
pending = append(pending, ch)
b.pendingChanges.Store(ch.TargetNode, pending)
}
return true
})
// Handle broadcast changes - send to all nodes, filtering as needed
if len(broadcast) > 0 {
b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool {
filtered := change.FilterForNode(nodeID, broadcast)
if len(filtered) > 0 {
pending, _ := b.pendingChanges.LoadOrStore(nodeID, []change.Change{})
pending = append(pending, filtered...)
b.pendingChanges.Store(nodeID, pending)
}
return true
})
}
}
// processBatchedChanges processes all pending batched changes.
@ -324,14 +363,14 @@ func (b *LockFreeBatcher) processBatchedChanges() {
}
// Process all pending changes
b.pendingChanges.Range(func(nodeID types.NodeID, changes []change.ChangeSet) bool {
if len(changes) == 0 {
b.pendingChanges.Range(func(nodeID types.NodeID, pending []change.Change) bool {
if len(pending) == 0 {
return true
}
// Send all batched changes for this node
for _, c := range changes {
b.queueWork(work{c: c, nodeID: nodeID, resultCh: nil})
for _, ch := range pending {
b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil})
}
// Clear the pending changes for this node
@ -434,17 +473,17 @@ func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] {
// MapResponseFromChange queues work to generate a map response and waits for the result.
// This allows synchronous map generation using the same worker pool.
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, c change.ChangeSet) (*tailcfg.MapResponse, error) {
func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) {
resultCh := make(chan workResult, 1)
// Queue the work with a result channel using the safe queueing method
b.queueWork(work{c: c, nodeID: id, resultCh: resultCh})
b.queueWork(work{c: ch, nodeID: id, resultCh: resultCh})
// Wait for the result
select {
case result := <-resultCh:
return result.mapResponse, result.err
case <-b.ctx.Done():
case <-b.done:
return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id)
}
}
@ -456,6 +495,7 @@ type connectionEntry struct {
version tailcfg.CapabilityVersion
created time.Time
lastUsed atomic.Int64 // Unix timestamp of last successful send
closed atomic.Bool // Indicates if this connection has been closed
}
// multiChannelNodeConn manages multiple concurrent connections for a single node.
@ -467,6 +507,12 @@ type multiChannelNodeConn struct {
connections []*connectionEntry
updateCount atomic.Int64
// lastSentPeers tracks which peers were last sent to this node.
// This enables computing diffs for policy changes instead of sending
// full peer lists (which clients interpret as "no change" when empty).
// Using xsync.Map for lock-free concurrent access.
lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}]
}
// generateConnectionID generates a unique connection identifier.
@ -479,8 +525,9 @@ func generateConnectionID() string {
// newMultiChannelNodeConn creates a new multi-channel node connection.
func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn {
return &multiChannelNodeConn{
id: id,
mapper: mapper,
id: id,
mapper: mapper,
lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](),
}
}
@ -489,6 +536,9 @@ func (mc *multiChannelNodeConn) close() {
defer mc.mutex.Unlock()
for _, conn := range mc.connections {
// Mark as closed before closing the channel to prevent
// "send on closed channel" panics from concurrent workers.
conn.closed.Store(true)
close(conn.c)
}
}
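// Minimal sketch of the close/send coordination above, with a stripped-down
// connection type (names are illustrative; errConnectionClosed is the sentinel
// used by the real send path): the closer sets the flag before close(), and the
// sender checks it first, so a send racing with shutdown returns an error
// instead of panicking.
//
//	type exampleConn struct {
//		c      chan *tailcfg.MapResponse
//		closed atomic.Bool
//	}
//
//	func (e *exampleConn) close() {
//		e.closed.Store(true)
//		close(e.c)
//	}
//
//	func (e *exampleConn) send(m *tailcfg.MapResponse) error {
//		if e.closed.Load() {
//			return errConnectionClosed
//		}
//		e.c <- m
//		return nil
//	}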
@ -621,6 +671,12 @@ func (entry *connectionEntry) send(data *tailcfg.MapResponse) error {
return nil
}
// Check if the connection has been closed to prevent a "send on closed channel" panic.
// This can happen during shutdown when Close() is called while workers are still processing.
if entry.closed.Load() {
return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed)
}
// Use a short timeout to detect stale connections where the client isn't reading the channel.
// This is critical for detecting Docker containers that are forcefully terminated
// but still have channels that appear open.
@ -654,9 +710,59 @@ func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion {
return mc.connections[0].version
}
// updateSentPeers updates the tracked peer state based on a sent MapResponse.
// This must be called after successfully sending a response to keep track of
// what the client knows about, enabling accurate diffs for future updates.
func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) {
if resp == nil {
return
}
// Full peer list replaces tracked state entirely
if resp.Peers != nil {
mc.lastSentPeers.Clear()
for _, peer := range resp.Peers {
mc.lastSentPeers.Store(peer.ID, struct{}{})
}
}
// Incremental additions
for _, peer := range resp.PeersChanged {
mc.lastSentPeers.Store(peer.ID, struct{}{})
}
// Incremental removals
for _, id := range resp.PeersRemoved {
mc.lastSentPeers.Delete(id)
}
}
// computePeerDiff compares the current peer list against what was last sent
// and returns the peers that were removed (in lastSentPeers but not in current).
func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID {
currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers))
for _, id := range currentPeers {
currentSet[id] = struct{}{}
}
var removed []tailcfg.NodeID
// Find removed: in lastSentPeers but not in current
mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool {
if _, exists := currentSet[id]; !exists {
removed = append(removed, id)
}
return true
})
return removed
}
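// Hedged usage sketch for the two helpers above; the caller shown is a
// simplification of what the batcher does around a policy change:
//
//	mc.updateSentPeers(resp)                      // after a response was successfully sent
//	removed := mc.computePeerDiff(currentPeerIDs) // later: peers the client must now drop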
// change applies a change to all active connections for the node.
func (mc *multiChannelNodeConn) change(c change.ChangeSet) error {
return handleNodeChange(mc, mc.mapper, c)
func (mc *multiChannelNodeConn) change(r change.Change) error {
return handleNodeChange(mc, mc.mapper, r)
}
// DebugNodeInfo contains debug information about a node's connections.
@ -715,3 +821,9 @@ func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo {
func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
return b.mapper.debugMapResponses()
}
// WorkErrors returns the count of work errors encountered.
// This is primarily useful for testing and debugging.
func (b *LockFreeBatcher) WorkErrors() int64 {
return b.workErrors.Load()
}

View file

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"net/netip"
"runtime"
"strings"
"sync"
"sync/atomic"
@ -58,7 +59,7 @@ func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapRespo
return fmt.Errorf("%w: %d", errNodeNotFoundAfterAdd, id)
}
t.AddWork(change.NodeOnline(node))
t.AddWork(change.NodeOnlineFor(node))
return nil
}
@ -75,7 +76,7 @@ func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapRe
// Do this BEFORE removing from batcher so the change can be processed
node, ok := t.state.GetNodeByID(id)
if ok {
t.AddWork(change.NodeOffline(node))
t.AddWork(change.NodeOfflineFor(node))
}
// Finally remove from the real batcher
@ -146,12 +147,12 @@ type node struct {
n *types.Node
ch chan *tailcfg.MapResponse
// Update tracking
// Update tracking (all accessed atomically for thread safety)
updateCount int64
patchCount int64
fullCount int64
maxPeersCount int
lastPeerCount int
maxPeersCount atomic.Int64
lastPeerCount atomic.Int64
stop chan struct{}
stopped chan struct{}
}
@ -421,18 +422,32 @@ func (n *node) start() {
// Track update types
if info.IsFull {
atomic.AddInt64(&n.fullCount, 1)
n.lastPeerCount = info.PeerCount
// Update max peers seen
if info.PeerCount > n.maxPeersCount {
n.maxPeersCount = info.PeerCount
n.lastPeerCount.Store(int64(info.PeerCount))
// Update max peers seen using compare-and-swap for thread safety
for {
current := n.maxPeersCount.Load()
if int64(info.PeerCount) <= current {
break
}
if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) {
break
}
}
}
if info.IsPatch {
atomic.AddInt64(&n.patchCount, 1)
// For patches, we track how many patch items
if info.PatchCount > n.maxPeersCount {
n.maxPeersCount = info.PatchCount
// For patches, track the number of patch items using compare-and-swap
for {
current := n.maxPeersCount.Load()
if int64(info.PatchCount) <= current {
break
}
if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) {
break
}
}
}
}
@ -464,8 +479,8 @@ func (n *node) cleanup() NodeStats {
TotalUpdates: atomic.LoadInt64(&n.updateCount),
PatchUpdates: atomic.LoadInt64(&n.patchCount),
FullUpdates: atomic.LoadInt64(&n.fullCount),
MaxPeersSeen: n.maxPeersCount,
LastPeerCount: n.lastPeerCount,
MaxPeersSeen: int(n.maxPeersCount.Load()),
LastPeerCount: int(n.lastPeerCount.Load()),
}
}
@ -502,8 +517,10 @@ func TestEnhancedNodeTracking(t *testing.T) {
// Send the data to the node's channel
testNode.ch <- &resp
// Give it time to process
time.Sleep(100 * time.Millisecond)
// Wait for tracking goroutine to process the update
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.GreaterOrEqual(c, atomic.LoadInt64(&testNode.updateCount), int64(1), "should have processed the update")
}, time.Second, 10*time.Millisecond, "waiting for update to be processed")
// Check stats
stats := testNode.cleanup()
@ -533,17 +550,21 @@ func TestEnhancedTrackingWithBatcher(t *testing.T) {
// Connect the node to the batcher
batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100))
time.Sleep(100 * time.Millisecond) // Let connection settle
// Generate some work
batcher.AddWork(change.FullSet)
time.Sleep(100 * time.Millisecond) // Let work be processed
// Wait for connection to be established
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.True(c, batcher.IsConnected(testNode.n.ID), "node should be connected")
}, time.Second, 10*time.Millisecond, "waiting for node connection")
batcher.AddWork(change.PolicySet)
time.Sleep(100 * time.Millisecond)
// Generate work and wait for updates to be processed
batcher.AddWork(change.FullUpdate())
batcher.AddWork(change.PolicyChange())
batcher.AddWork(change.DERPMap())
batcher.AddWork(change.DERPSet)
time.Sleep(100 * time.Millisecond)
// Wait for updates to be processed (at least 1 update received)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.GreaterOrEqual(c, atomic.LoadInt64(&testNode.updateCount), int64(1), "should have received updates")
}, time.Second, 10*time.Millisecond, "waiting for updates to be processed")
// Check stats
stats := testNode.cleanup()
@ -627,8 +648,8 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
allNodes[i].start()
}
// Give time for tracking goroutines to start
time.Sleep(100 * time.Millisecond)
// Yield to allow tracking goroutines to start
runtime.Gosched()
startTime := time.Now()
@ -640,31 +661,26 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
// Issue full update after each join to ensure connectivity
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
// Add tiny delay for large node counts to prevent overwhelming
// Yield to scheduler for large node counts to prevent overwhelming the work queue
if tc.nodeCount > 100 && i%50 == 49 {
time.Sleep(10 * time.Millisecond)
runtime.Gosched()
}
}
joinTime := time.Since(startTime)
t.Logf("All nodes joined in %v, waiting for full connectivity...", joinTime)
// Wait for all updates to propagate - no timeout, continue until all nodes achieve connectivity
checkInterval := 5 * time.Second
// Wait for all updates to propagate until all nodes achieve connectivity
expectedPeers := tc.nodeCount - 1 // Each node should see all others except itself
for {
time.Sleep(checkInterval)
// Check if all nodes have seen the expected number of peers
assert.EventuallyWithT(t, func(c *assert.CollectT) {
connectedCount := 0
for i := range allNodes {
node := &allNodes[i]
// Check current stats without stopping the tracking
currentMaxPeers := node.maxPeersCount
currentMaxPeers := int(node.maxPeersCount.Load())
if currentMaxPeers >= expectedPeers {
connectedCount++
}
@ -674,12 +690,10 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
t.Logf("Progress: %d/%d nodes (%.1f%%) have seen %d+ peers",
connectedCount, len(allNodes), progress, expectedPeers)
if connectedCount == len(allNodes) {
t.Logf("✅ All nodes achieved full connectivity!")
break
}
}
assert.Equal(c, len(allNodes), connectedCount, "all nodes should achieve full connectivity")
}, 5*time.Minute, 5*time.Second, "waiting for full connectivity")
t.Logf("✅ All nodes achieved full connectivity!")
totalTime := time.Since(startTime)
// Disconnect all nodes
@ -688,8 +702,12 @@ func TestBatcherScalabilityAllToAll(t *testing.T) {
batcher.RemoveNode(node.n.ID, node.ch)
}
// Give time for final updates to process
time.Sleep(500 * time.Millisecond)
// Wait for all nodes to be disconnected
assert.EventuallyWithT(t, func(c *assert.CollectT) {
for i := range allNodes {
assert.False(c, batcher.IsConnected(allNodes[i].n.ID), "node should be disconnected")
}
}, 5*time.Second, 50*time.Millisecond, "waiting for nodes to disconnect")
// Collect final statistics
totalUpdates := int64(0)
@ -814,7 +832,7 @@ func TestBatcherBasicOperations(t *testing.T) {
}
// Test work processing with DERP change
batcher.AddWork(change.DERPChange())
batcher.AddWork(change.DERPMap())
// Wait for update and validate content
select {
@ -941,31 +959,31 @@ func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout ti
// }{
// {
// name: "DERP change",
// changeSet: change.DERPSet,
// changeSet: change.DERPMapResponse(),
// expectData: true,
// description: "DERP changes should generate map updates",
// },
// {
// name: "Node key expiry",
// changeSet: change.KeyExpiry(testNodes[1].n.ID),
// changeSet: change.KeyExpiryFor(testNodes[1].n.ID),
// expectData: true,
// description: "Node key expiry with real node data",
// },
// {
// name: "Node new registration",
// changeSet: change.NodeAdded(testNodes[1].n.ID),
// changeSet: change.NodeAddedResponse(testNodes[1].n.ID),
// expectData: true,
// description: "New node registration with real data",
// },
// {
// name: "Full update",
// changeSet: change.FullSet,
// changeSet: change.FullUpdateResponse(),
// expectData: true,
// description: "Full updates with real node data",
// },
// {
// name: "Policy change",
// changeSet: change.PolicySet,
// changeSet: change.PolicyChangeResponse(),
// expectData: true,
// description: "Policy updates with real node data",
// },
@ -1039,13 +1057,13 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
var receivedUpdates []*tailcfg.MapResponse
// Add multiple changes rapidly to test batching
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(testNodes[1].n.ID, testExpiry))
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.KeyExpiryFor(testNodes[1].n.ID, testExpiry))
batcher.AddWork(change.DERPMap())
batcher.AddWork(change.NodeAdded(testNodes[1].n.ID))
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
// Collect updates with timeout
updateCount := 0
@ -1069,8 +1087,8 @@ func TestBatcherWorkQueueBatching(t *testing.T) {
t.Logf("Update %d: nil update", updateCount)
}
case <-timeout:
// Expected: 5 changes should generate 6 updates (no batching in current implementation)
expectedUpdates := 6
// Expected: 5 explicit changes + 1 initial from AddNode + 1 NodeOnline from wrapper = 7 updates
expectedUpdates := 7
t.Logf("Received %d updates from %d changes (expected %d)",
updateCount, 5, expectedUpdates)
@ -1142,21 +1160,22 @@ func XTestBatcherChannelClosingRace(t *testing.T) {
// Add real work during connection chaos
if i%10 == 0 {
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
}
// Rapid second connection - should replace ch1
ch2 := make(chan *tailcfg.MapResponse, 1)
wg.Go(func() {
time.Sleep(1 * time.Microsecond)
runtime.Gosched() // Yield to introduce timing variability
batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100))
})
// Remove second connection
wg.Go(func() {
time.Sleep(2 * time.Microsecond)
runtime.Gosched() // Yield to introduce timing variability
runtime.Gosched() // Extra yield to offset from AddNode
batcher.RemoveNode(testNode.n.ID, ch2)
})
@ -1241,7 +1260,7 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
// Add node and immediately queue real work
batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100))
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
// Consumer goroutine to validate data and detect channel issues
go func() {
@ -1283,15 +1302,17 @@ func TestBatcherWorkerChannelSafety(t *testing.T) {
if i%10 == 0 {
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(testNode.n.ID, testExpiry))
batcher.AddWork(change.KeyExpiryFor(testNode.n.ID, testExpiry))
}
// Rapid removal creates race between worker and removal
time.Sleep(time.Duration(i%3) * 100 * time.Microsecond)
for range i % 3 {
runtime.Gosched() // Introduce timing variability
}
batcher.RemoveNode(testNode.n.ID, ch)
// Give workers time to process and close channels
time.Sleep(5 * time.Millisecond)
// Yield to allow workers to process and close channels
runtime.Gosched()
}()
}
@ -1471,7 +1492,9 @@ func TestBatcherConcurrentClients(t *testing.T) {
wg.Done()
}()
time.Sleep(time.Duration(i%5) * time.Millisecond)
for range i % 5 {
runtime.Gosched() // Introduce timing variability
}
churningChannelsMutex.Lock()
ch, exists := churningChannels[nodeID]
@ -1487,12 +1510,12 @@ func TestBatcherConcurrentClients(t *testing.T) {
// Generate various types of work during racing
if i%3 == 0 {
// DERP changes
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
}
if i%5 == 0 {
// Full updates using real node data
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
}
if i%7 == 0 && len(allNodes) > 0 {
@ -1500,11 +1523,11 @@ func TestBatcherConcurrentClients(t *testing.T) {
node := allNodes[i%len(allNodes)]
// Use a valid expiry time for testing since test nodes don't have expiry set
testExpiry := time.Now().Add(24 * time.Hour)
batcher.AddWork(change.KeyExpiry(node.n.ID, testExpiry))
batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry))
}
// Small delay to allow some batching
time.Sleep(2 * time.Millisecond)
// Yield to allow some batching
runtime.Gosched()
}
wg.Wait()
@ -1519,8 +1542,8 @@ func TestBatcherConcurrentClients(t *testing.T) {
return
}
// Allow final updates to be processed
time.Sleep(100 * time.Millisecond)
// Yield to allow any in-flight updates to complete
runtime.Gosched()
// Validate results
panicMutex.Lock()
@ -1730,8 +1753,8 @@ func XTestBatcherScalability(t *testing.T) {
testNodes[i].start()
}
// Give time for all tracking goroutines to start
time.Sleep(100 * time.Millisecond)
// Yield to allow tracking goroutines to start
runtime.Gosched()
// Connect all nodes first so they can see each other as peers
connectedNodes := make(map[types.NodeID]bool)
@ -1748,10 +1771,21 @@ func XTestBatcherScalability(t *testing.T) {
connectedNodesMutex.Unlock()
}
// Give more time for all connections to be established
time.Sleep(500 * time.Millisecond)
batcher.AddWork(change.FullSet)
time.Sleep(500 * time.Millisecond) // Allow initial update to propagate
// Wait for all connections to be established
assert.EventuallyWithT(t, func(c *assert.CollectT) {
for i := range testNodes {
assert.True(c, batcher.IsConnected(testNodes[i].n.ID), "node should be connected")
}
}, 5*time.Second, 50*time.Millisecond, "waiting for nodes to connect")
batcher.AddWork(change.FullUpdate())
// Wait for initial update to propagate
assert.EventuallyWithT(t, func(c *assert.CollectT) {
for i := range testNodes {
assert.GreaterOrEqual(c, atomic.LoadInt64(&testNodes[i].updateCount), int64(1), "should have received initial update")
}
}, 5*time.Second, 50*time.Millisecond, "waiting for initial update")
go func() {
defer close(done)
@ -1769,9 +1803,9 @@ func XTestBatcherScalability(t *testing.T) {
if cycle%10 == 0 {
t.Logf("Cycle %d/%d completed", cycle, tc.cycles)
}
// Add delays for mixed chaos
// Yield for mixed chaos to introduce timing variability
if tc.chaosType == "mixed" && cycle%10 == 0 {
time.Sleep(time.Duration(cycle%2) * time.Microsecond)
runtime.Gosched()
}
// For chaos testing, only disconnect/reconnect a subset of nodes
@ -1835,9 +1869,12 @@ func XTestBatcherScalability(t *testing.T) {
wg.Done()
}()
// Small delay before reconnecting
time.Sleep(time.Duration(index%3) * time.Millisecond)
batcher.AddNode(
// Yield before reconnecting to introduce timing variability
for range index % 3 {
runtime.Gosched()
}
_ = batcher.AddNode(
nodeID,
channel,
tailcfg.CapabilityVersion(100),
@ -1850,7 +1887,7 @@ func XTestBatcherScalability(t *testing.T) {
// Add work to create load
if index%5 == 0 {
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
}
}(
node.n.ID,
@ -1877,11 +1914,11 @@ func XTestBatcherScalability(t *testing.T) {
// Generate different types of work to ensure updates are sent
switch index % 4 {
case 0:
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
case 1:
batcher.AddWork(change.PolicySet)
batcher.AddWork(change.PolicyChange())
case 2:
batcher.AddWork(change.DERPSet)
batcher.AddWork(change.DERPMap())
default:
// Pick a random node and generate a node change
if len(testNodes) > 0 {
@ -1890,7 +1927,7 @@ func XTestBatcherScalability(t *testing.T) {
change.NodeAdded(testNodes[nodeIdx].n.ID),
)
} else {
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
}
}
}(i)
@ -1941,9 +1978,17 @@ func XTestBatcherScalability(t *testing.T) {
}
}
// Give time for batcher workers to process all the work and send updates
// BEFORE disconnecting nodes
time.Sleep(1 * time.Second)
// Wait for batcher workers to process all work and send updates
// before disconnecting nodes
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// Check that at least some updates were processed
var totalUpdates int64
for i := range testNodes {
totalUpdates += atomic.LoadInt64(&testNodes[i].updateCount)
}
assert.Positive(c, totalUpdates, "should have processed some updates")
}, 5*time.Second, 50*time.Millisecond, "waiting for updates to be processed")
// Now disconnect all nodes from batcher to stop new updates
for i := range testNodes {
@ -1951,8 +1996,12 @@ func XTestBatcherScalability(t *testing.T) {
batcher.RemoveNode(node.n.ID, node.ch)
}
// Give time for enhanced tracking goroutines to process any remaining data in channels
time.Sleep(200 * time.Millisecond)
// Wait for nodes to be disconnected
assert.EventuallyWithT(t, func(c *assert.CollectT) {
for i := range testNodes {
assert.False(c, batcher.IsConnected(testNodes[i].n.ID), "node should be disconnected")
}
}, 5*time.Second, 50*time.Millisecond, "waiting for nodes to disconnect")
// Cleanup nodes and get their final stats
totalUpdates := int64(0)
@ -2089,17 +2138,24 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
t.Logf("Created %d nodes in database", len(allNodes))
// Connect nodes one at a time to avoid overwhelming the work queue
// Connect nodes one at a time and wait for each to be connected
for i, node := range allNodes {
batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100))
t.Logf("Connected node %d (ID: %d)", i, node.n.ID)
// Small delay between connections to allow NodeCameOnline processing
time.Sleep(50 * time.Millisecond)
// Wait for node to be connected
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.True(c, batcher.IsConnected(node.n.ID), "node should be connected")
}, time.Second, 10*time.Millisecond, "waiting for node connection")
}
// Give additional time for all NodeCameOnline events to be processed
// Wait for all NodeCameOnline events to be processed
t.Logf("Waiting for NodeCameOnline events to settle...")
time.Sleep(500 * time.Millisecond)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
for i := range allNodes {
assert.True(c, batcher.IsConnected(allNodes[i].n.ID), "all nodes should be connected")
}
}, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect")
// Check how many peers each node should see
for i, node := range allNodes {
@ -2109,11 +2165,23 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
// Send a full update - this should generate full peer lists
t.Logf("Sending FullSet update...")
batcher.AddWork(change.FullSet)
batcher.AddWork(change.FullUpdate())
// Give much more time for workers to process the FullSet work items
// Wait for FullSet work items to be processed
t.Logf("Waiting for FullSet to be processed...")
time.Sleep(1 * time.Second)
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// Check that some data is available in at least one channel
found := false
for i := range allNodes {
if len(allNodes[i].ch) > 0 {
found = true
break
}
}
assert.True(c, found, "no updates received yet")
}, 5*time.Second, 50*time.Millisecond, "waiting for FullSet updates")
// Check what each node receives - read multiple updates
totalUpdates := 0
@ -2193,7 +2261,7 @@ func TestBatcherFullPeerUpdates(t *testing.T) {
t.Logf("Total updates received across all nodes: %d", totalUpdates)
if !foundFullUpdate {
t.Errorf("CRITICAL: No FULL updates received despite sending change.FullSet!")
t.Errorf("CRITICAL: No FULL updates received despite sending change.FullUpdateResponse()!")
t.Errorf(
"This confirms the bug - FullSet updates are not generating full peer responses",
)
@ -2226,7 +2294,12 @@ func TestBatcherRapidReconnection(t *testing.T) {
}
}
time.Sleep(100 * time.Millisecond) // Let connections settle
// Wait for all connections to settle
assert.EventuallyWithT(t, func(c *assert.CollectT) {
for i := range allNodes {
assert.True(c, batcher.IsConnected(allNodes[i].n.ID), "node should be connected")
}
}, 5*time.Second, 50*time.Millisecond, "waiting for connections to settle")
// Phase 2: Rapid disconnect ALL nodes (simulating nodes going down)
t.Logf("Phase 2: Rapid disconnect all nodes...")
@ -2246,7 +2319,12 @@ func TestBatcherRapidReconnection(t *testing.T) {
}
}
time.Sleep(100 * time.Millisecond) // Let reconnections settle
// Wait for all reconnections to settle
assert.EventuallyWithT(t, func(c *assert.CollectT) {
for i := range allNodes {
assert.True(c, batcher.IsConnected(allNodes[i].n.ID), "node should be reconnected")
}
}, 5*time.Second, 50*time.Millisecond, "waiting for reconnections to settle")
// Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR
t.Logf("Phase 4: Checking debug status...")
@ -2294,7 +2372,7 @@ func TestBatcherRapidReconnection(t *testing.T) {
t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...")
// Send a change that should reach all nodes
batcher.AddWork(change.DERPChange())
batcher.AddWork(change.DERPMap())
receivedCount := 0
timeout := time.After(500 * time.Millisecond)
@ -2347,7 +2425,11 @@ func TestBatcherMultiConnection(t *testing.T) {
t.Fatalf("Failed to add node2: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Wait for initial connections
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.True(c, batcher.IsConnected(node1.n.ID), "node1 should be connected")
assert.True(c, batcher.IsConnected(node2.n.ID), "node2 should be connected")
}, time.Second, 10*time.Millisecond, "waiting for initial connections")
// Phase 2: Add second connection for node1 (multi-connection scenario)
t.Logf("Phase 2: Adding second connection for node 1...")
@ -2357,7 +2439,8 @@ func TestBatcherMultiConnection(t *testing.T) {
t.Fatalf("Failed to add second connection for node1: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Yield to allow connection to be processed
runtime.Gosched()
// Phase 3: Add third connection for node1
t.Logf("Phase 3: Adding third connection for node 1...")
@ -2367,7 +2450,8 @@ func TestBatcherMultiConnection(t *testing.T) {
t.Fatalf("Failed to add third connection for node1: %v", err)
}
time.Sleep(50 * time.Millisecond)
// Yield to allow connection to be processed
runtime.Gosched()
// Phase 4: Verify debug status shows correct connection count
t.Logf("Phase 4: Verifying debug status shows multiple connections...")
@ -2424,15 +2508,14 @@ func TestBatcherMultiConnection(t *testing.T) {
clearChannel(node2.ch)
// Send a change notification from node2 (so node1 should receive it on all connections)
testChangeSet := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
testChangeSet := change.NodeAdded(node2.n.ID)
batcher.AddWork(testChangeSet)
time.Sleep(100 * time.Millisecond) // Let updates propagate
// Wait for updates to propagate to at least one channel
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Positive(c, len(node1.ch)+len(secondChannel)+len(thirdChannel), "should have received updates")
}, 5*time.Second, 50*time.Millisecond, "waiting for updates to propagate")
// Verify all three connections for node1 receive the update
connection1Received := false
@ -2479,7 +2562,8 @@ func TestBatcherMultiConnection(t *testing.T) {
t.Errorf("Failed to remove second connection for node1")
}
time.Sleep(50 * time.Millisecond)
// Yield to allow removal to be processed
runtime.Gosched()
// Verify debug status shows 2 connections now
if debugBatcher, ok := batcher.(interface {
@ -2503,14 +2587,14 @@ func TestBatcherMultiConnection(t *testing.T) {
clearChannel(node1.ch)
clearChannel(thirdChannel)
testChangeSet2 := change.ChangeSet{
NodeID: node2.n.ID,
Change: change.NodeNewOrUpdate,
SelfUpdateOnly: false,
}
testChangeSet2 := change.NodeAdded(node2.n.ID)
batcher.AddWork(testChangeSet2)
time.Sleep(100 * time.Millisecond)
// Wait for updates to propagate to remaining channels
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.Positive(c, len(node1.ch)+len(thirdChannel), "should have received updates")
}, 5*time.Second, 50*time.Millisecond, "waiting for updates to propagate")
// Verify remaining connections still receive updates
remaining1Received := false
@ -2537,7 +2621,11 @@ func TestBatcherMultiConnection(t *testing.T) {
remaining1Received, remaining3Received)
}
// Verify second channel no longer receives updates (should be closed/removed)
// Drain secondChannel of any messages received before removal
// (the test wrapper sends NodeOffline before removal, which may have reached this channel)
clearChannel(secondChannel)
// Verify second channel no longer receives new updates after being removed
select {
case <-secondChannel:
t.Errorf("Removed connection still received update - this should not happen")
@ -2547,3 +2635,140 @@ func TestBatcherMultiConnection(t *testing.T) {
})
}
}
// TestNodeDeletedWhileChangesPending reproduces issue #2924 where deleting a node
// from state while there are pending changes for that node in the batcher causes
// "node not found" errors. The race condition occurs when:
// 1. Node is connected and changes are queued for it
// 2. Node is deleted from state (NodeStore) but not from batcher
// 3. Batcher worker tries to generate map response for deleted node
// 4. Mapper fails to find node in state, causing repeated "node not found" errors.
func TestNodeDeletedWhileChangesPending(t *testing.T) {
for _, batcherFunc := range allBatcherFunctions {
t.Run(batcherFunc.name, func(t *testing.T) {
// Create test environment with 3 nodes
testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, NORMAL_BUFFER_SIZE)
defer cleanup()
batcher := testData.Batcher
st := testData.State
node1 := &testData.Nodes[0]
node2 := &testData.Nodes[1]
node3 := &testData.Nodes[2]
t.Logf("Testing issue #2924: Node1=%d, Node2=%d, Node3=%d",
node1.n.ID, node2.n.ID, node3.n.ID)
// Helper to drain channels
drainCh := func(ch chan *tailcfg.MapResponse) {
for {
select {
case <-ch:
// drain
default:
return
}
}
}
// Start update consumers for all nodes
node1.start()
node2.start()
node3.start()
defer node1.cleanup()
defer node2.cleanup()
defer node3.cleanup()
// Connect all nodes to the batcher
require.NoError(t, batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)))
require.NoError(t, batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100)))
require.NoError(t, batcher.AddNode(node3.n.ID, node3.ch, tailcfg.CapabilityVersion(100)))
// Wait for all nodes to be connected
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.True(c, batcher.IsConnected(node1.n.ID), "node1 should be connected")
assert.True(c, batcher.IsConnected(node2.n.ID), "node2 should be connected")
assert.True(c, batcher.IsConnected(node3.n.ID), "node3 should be connected")
}, 5*time.Second, 50*time.Millisecond, "waiting for nodes to connect")
// Get initial work errors count
var initialWorkErrors int64
if lfb, ok := unwrapBatcher(batcher).(*LockFreeBatcher); ok {
initialWorkErrors = lfb.WorkErrors()
t.Logf("Initial work errors: %d", initialWorkErrors)
}
// Clear channels to prepare for the test
drainCh(node1.ch)
drainCh(node2.ch)
drainCh(node3.ch)
// Get node view for deletion
nodeToDelete, ok := st.GetNodeByID(node3.n.ID)
require.True(t, ok, "node3 should exist in state")
// Delete the node from state - this returns a NodeRemoved change
// In production, this change is sent to batcher via app.Change()
nodeChange, err := st.DeleteNode(nodeToDelete)
require.NoError(t, err, "should be able to delete node from state")
t.Logf("Deleted node %d from state, change: %s", node3.n.ID, nodeChange.Reason)
// Verify node is deleted from state
_, exists := st.GetNodeByID(node3.n.ID)
require.False(t, exists, "node3 should be deleted from state")
// Send the NodeRemoved change to batcher (this is what app.Change() does)
// With the fix, this should clean up node3 from batcher's internal state
batcher.AddWork(nodeChange)
// Wait for the batcher to process the removal and clean up the node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
assert.False(c, batcher.IsConnected(node3.n.ID), "node3 should be disconnected from batcher")
}, 5*time.Second, 50*time.Millisecond, "waiting for node removal to be processed")
t.Logf("Node %d connected in batcher after NodeRemoved: %v", node3.n.ID, batcher.IsConnected(node3.n.ID))
// Now queue changes that would have caused errors before the fix
// With the fix, these should NOT cause "node not found" errors
// because node3 was cleaned up when NodeRemoved was processed
batcher.AddWork(change.FullUpdate())
batcher.AddWork(change.PolicyChange())
// Wait for work to be processed and verify no errors occurred
// With the fix, no new errors should occur because the deleted node
// was cleaned up from batcher state when NodeRemoved was processed
assert.EventuallyWithT(t, func(c *assert.CollectT) {
var finalWorkErrors int64
if lfb, ok := unwrapBatcher(batcher).(*LockFreeBatcher); ok {
finalWorkErrors = lfb.WorkErrors()
}
newErrors := finalWorkErrors - initialWorkErrors
assert.Zero(c, newErrors, "Fix for #2924: should have no work errors after node deletion")
}, 5*time.Second, 100*time.Millisecond, "waiting for work processing to complete without errors")
// Verify remaining nodes still work correctly
drainCh(node1.ch)
drainCh(node2.ch)
batcher.AddWork(change.NodeAdded(node1.n.ID))
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// Node 1 and 2 should receive updates
stats1 := NodeStats{TotalUpdates: atomic.LoadInt64(&node1.updateCount)}
stats2 := NodeStats{TotalUpdates: atomic.LoadInt64(&node2.updateCount)}
assert.Positive(c, stats1.TotalUpdates, "node1 should have received updates")
assert.Positive(c, stats2.TotalUpdates, "node2 should have received updates")
}, 5*time.Second, 100*time.Millisecond, "waiting for remaining nodes to receive updates")
})
}
}
// unwrapBatcher extracts the underlying batcher from wrapper types.
func unwrapBatcher(b Batcher) Batcher {
if wrapper, ok := b.(*testBatcherWrapper); ok {
return unwrapBatcher(wrapper.Batcher)
}
return b
}

View file

@ -29,10 +29,8 @@ type debugType string
const (
fullResponseDebug debugType = "full"
selfResponseDebug debugType = "self"
patchResponseDebug debugType = "patch"
removeResponseDebug debugType = "remove"
changeResponseDebug debugType = "change"
derpResponseDebug debugType = "derp"
policyResponseDebug debugType = "policy"
)
// NewMapResponseBuilder creates a new builder with basic fields set.
@ -76,8 +74,9 @@ func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder {
}
_, matchers := b.mapper.state.Filter()
tailnode, err := tailNode(
nv, b.capVer, b.mapper.state,
tailnode, err := nv.TailNode(
b.capVer,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(nv, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
},
@ -251,8 +250,8 @@ func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) (
changedViews = peers
}
tailPeers, err := tailNodes(
changedViews, b.capVer, b.mapper.state,
tailPeers, err := types.TailNodes(
changedViews, b.capVer,
func(id types.NodeID) []netip.Prefix {
return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers)
},

View file

@ -4,7 +4,6 @@ import (
"encoding/json"
"fmt"
"io/fs"
"net/netip"
"net/url"
"os"
"path"
@ -15,6 +14,7 @@ import (
"github.com/juanfont/headscale/hscontrol/state"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/types/change"
"github.com/rs/zerolog/log"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
@ -180,52 +180,108 @@ func (m *mapper) selfMapResponse(
return ma, err
}
func (m *mapper) derpMapResponse(
nodeID types.NodeID,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(derpResponseDebug).
WithDERPMap().
Build()
}
// PeerChangedPatchResponse creates a patch MapResponse with
// incoming update from a state change.
func (m *mapper) peerChangedPatchResponse(
nodeID types.NodeID,
changed []*tailcfg.PeerChange,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(patchResponseDebug).
WithPeerChangedPatch(changed).
Build()
}
// peerChangeResponse returns a MapResponse with changed or added nodes.
func (m *mapper) peerChangeResponse(
// policyChangeResponse creates a MapResponse for policy changes.
// It sends:
// - PeersRemoved for peers that are no longer visible after the policy change
// - PeersChanged for remaining peers (their AllowedIPs may have changed due to policy)
// - Updated PacketFilters
// - Updated SSHPolicy (SSH rules may reference users/groups that changed)
// This avoids the issue where an empty Peers slice is interpreted by Tailscale
// clients as "no change" rather than "no peers".
func (m *mapper) policyChangeResponse(
nodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
changedNodeID types.NodeID,
removedPeers []tailcfg.NodeID,
currentPeers views.Slice[types.NodeView],
) (*tailcfg.MapResponse, error) {
peers := m.state.ListPeers(nodeID, changedNodeID)
return m.NewMapResponseBuilder(nodeID).
WithDebugType(changeResponseDebug).
builder := m.NewMapResponseBuilder(nodeID).
WithDebugType(policyResponseDebug).
WithCapabilityVersion(capVer).
WithUserProfiles(peers).
WithPeerChanges(peers).
Build()
WithPacketFilters().
WithSSHPolicy()
if len(removedPeers) > 0 {
// Convert tailcfg.NodeID to types.NodeID for WithPeersRemoved
removedIDs := make([]types.NodeID, len(removedPeers))
for i, id := range removedPeers {
removedIDs[i] = types.NodeID(id) //nolint:gosec // NodeID types are equivalent
}
builder.WithPeersRemoved(removedIDs...)
}
// Send remaining peers in PeersChanged - their AllowedIPs may have
// changed due to the policy update (e.g., different routes allowed).
if currentPeers.Len() > 0 {
builder.WithPeerChanges(currentPeers)
}
return builder.Build()
}
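// Hedged sketch of how this response is expected to be assembled by a caller
// (variable names are illustrative): the per-connection diff supplies the
// removed peers and the node's current peer list supplies PeersChanged, so the
// client never has to infer removals from an empty peer list.
//
//	removed := conn.computePeerDiff(currentPeerIDs) // peers no longer visible under the new policy
//	resp, err := m.policyChangeResponse(nodeID, capVer, removed, currentPeers)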
// peerRemovedResponse creates a MapResponse indicating that a peer has been removed.
func (m *mapper) peerRemovedResponse(
// buildFromChange builds a MapResponse from a change.Change specification.
// This provides fine-grained control over what gets included in the response.
func (m *mapper) buildFromChange(
nodeID types.NodeID,
removedNodeID types.NodeID,
capVer tailcfg.CapabilityVersion,
resp *change.Change,
) (*tailcfg.MapResponse, error) {
return m.NewMapResponseBuilder(nodeID).
WithDebugType(removeResponseDebug).
WithPeersRemoved(removedNodeID).
Build()
if resp.IsEmpty() {
return nil, nil //nolint:nilnil // Empty response means nothing to send, not an error
}
// If this is a self-update (the changed node is the receiving node),
// send a self-update response to ensure the node sees its own changes.
if resp.OriginNode != 0 && resp.OriginNode == nodeID {
return m.selfMapResponse(nodeID, capVer)
}
builder := m.NewMapResponseBuilder(nodeID).
WithCapabilityVersion(capVer).
WithDebugType(changeResponseDebug)
if resp.IncludeSelf {
builder.WithSelfNode()
}
if resp.IncludeDERPMap {
builder.WithDERPMap()
}
if resp.IncludeDNS {
builder.WithDNSConfig()
}
if resp.IncludeDomain {
builder.WithDomain()
}
if resp.IncludePolicy {
builder.WithPacketFilters()
builder.WithSSHPolicy()
}
if resp.SendAllPeers {
peers := m.state.ListPeers(nodeID)
builder.WithUserProfiles(peers)
builder.WithPeers(peers)
} else {
if len(resp.PeersChanged) > 0 {
peers := m.state.ListPeers(nodeID, resp.PeersChanged...)
builder.WithUserProfiles(peers)
builder.WithPeerChanges(peers)
}
if len(resp.PeersRemoved) > 0 {
builder.WithPeersRemoved(resp.PeersRemoved...)
}
}
if len(resp.PeerPatches) > 0 {
builder.WithPeerChangedPatch(resp.PeerPatches)
}
return builder.Build()
}
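// Hedged sketch of the kind of change.Change this function consumes (the field
// values are illustrative only): a DERP map refresh combined with an
// incremental peer update for two nodes and one removed peer.
//
//	ch := change.Change{
//		IncludeDERPMap: true,
//		PeersChanged:   []types.NodeID{2, 3},
//		PeersRemoved:   []types.NodeID{4},
//	}
//	resp, err := m.buildFromChange(nodeID, capVer, &ch)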
func writeDebugMapResponse(
@ -259,11 +315,6 @@ func writeDebugMapResponse(
}
}
// routeFilterFunc is a function that takes a node ID and returns a list of
// netip.Prefixes that are allowed for that node. It is used to filter routes
// from the primary route manager to the node.
type routeFilterFunc func(id types.NodeID) []netip.Prefix
func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) {
if debugDumpMapResponsePath == "" {
return nil, nil

View file

@ -1,146 +0,0 @@
package mapper
import (
"fmt"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/samber/lo"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/views"
)
// NodeCanHaveTagChecker is an interface for checking if a node can have a tag.
type NodeCanHaveTagChecker interface {
NodeCanHaveTag(node types.NodeView, tag string) bool
}
func tailNodes(
nodes views.Slice[types.NodeView],
capVer tailcfg.CapabilityVersion,
checker NodeCanHaveTagChecker,
primaryRouteFunc routeFilterFunc,
cfg *types.Config,
) ([]*tailcfg.Node, error) {
tNodes := make([]*tailcfg.Node, 0, nodes.Len())
for _, node := range nodes.All() {
tNode, err := tailNode(
node,
capVer,
checker,
primaryRouteFunc,
cfg,
)
if err != nil {
return nil, err
}
tNodes = append(tNodes, tNode)
}
return tNodes, nil
}
// tailNode converts a Node into a Tailscale Node.
func tailNode(
node types.NodeView,
capVer tailcfg.CapabilityVersion,
checker NodeCanHaveTagChecker,
primaryRouteFunc routeFilterFunc,
cfg *types.Config,
) (*tailcfg.Node, error) {
addrs := node.Prefixes()
var derp int
// TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077
// and should be removed after 111 is the minimum capver.
var legacyDERP string
if node.Hostinfo().Valid() && node.Hostinfo().NetInfo().Valid() {
legacyDERP = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo().NetInfo().PreferredDERP())
derp = node.Hostinfo().NetInfo().PreferredDERP()
} else {
legacyDERP = "127.3.3.40:0" // Zero means disconnected or unknown.
}
var keyExpiry time.Time
if node.Expiry().Valid() {
keyExpiry = node.Expiry().Get()
} else {
keyExpiry = time.Time{}
}
hostname, err := node.GetFQDN(cfg.BaseDomain)
if err != nil {
return nil, err
}
var tags []string
for _, tag := range node.RequestTagsSlice().All() {
if checker.NodeCanHaveTag(node, tag) {
tags = append(tags, tag)
}
}
for _, tag := range node.Tags().All() {
tags = append(tags, tag)
}
tags = lo.Uniq(tags)
routes := primaryRouteFunc(node.ID())
allowed := append(addrs, routes...)
allowed = append(allowed, node.ExitRoutes()...)
tsaddr.SortPrefixes(allowed)
tNode := tailcfg.Node{
ID: tailcfg.NodeID(node.ID()), // this is the actual ID
StableID: node.ID().StableID(),
Name: hostname,
Cap: capVer,
User: node.TailscaleUserID(),
Key: node.NodeKey(),
KeyExpiry: keyExpiry.UTC(),
Machine: node.MachineKey(),
DiscoKey: node.DiscoKey(),
Addresses: addrs,
PrimaryRoutes: routes,
AllowedIPs: allowed,
Endpoints: node.Endpoints().AsSlice(),
HomeDERP: derp,
LegacyDERPString: legacyDERP,
Hostinfo: node.Hostinfo(),
Created: node.CreatedAt().UTC(),
Online: node.IsOnline().Clone(),
Tags: tags,
MachineAuthorized: !node.IsExpired(),
Expired: node.IsExpired(),
}
tNode.CapMap = tailcfg.NodeCapMap{
tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{},
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
}
if cfg.RandomizeClientPort {
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
}
// Set LastSeen only for offline nodes to avoid confusing Tailscale clients
// during rapid reconnection cycles. Online nodes should not have LastSeen set
// as this can make clients interpret them as "not online" despite Online=true.
if node.LastSeen().Valid() && node.IsOnline().Valid() && !node.IsOnline().Get() {
lastSeen := node.LastSeen().Get()
tNode.LastSeen = &lastSeen
}
return &tNode, nil
}

View file

@ -8,10 +8,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/policy"
"github.com/juanfont/headscale/hscontrol/routes"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/require"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
@ -71,7 +69,6 @@ func TestTailNode(t *testing.T) {
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Tags: []string{},
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
@ -186,7 +183,6 @@ func TestTailNode(t *testing.T) {
HomeDERP: 0,
LegacyDERPString: "127.3.3.40:0",
Hostinfo: hiview(tailcfg.Hostinfo{}),
Tags: []string{},
MachineAuthorized: true,
CapMap: tailcfg.NodeCapMap{
@ -204,23 +200,20 @@ func TestTailNode(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
polMan, err := policy.NewPolicyManager(tt.pol, []types.User{}, types.Nodes{tt.node}.ViewSlice())
require.NoError(t, err)
primary := routes.New()
cfg := &types.Config{
BaseDomain: tt.baseDomain,
TailcfgDNSConfig: tt.dnsConfig,
RandomizeClientPort: false,
Taildrop: types.TaildropConfig{Enabled: true},
}
_ = primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...)
// This is a hack to avoid having a second node to test the primary route.
// This should be baked into the test case proper if it is extended in the future.
_ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24"))
got, err := tailNode(
tt.node.View(),
got, err := tt.node.View().TailNode(
0,
polMan,
func(id types.NodeID) []netip.Prefix {
return primary.PrimaryRoutes(id)
},
@ -228,13 +221,13 @@ func TestTailNode(t *testing.T) {
)
if (err != nil) != tt.wantErr {
t.Errorf("tailNode() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("TailNode() error = %v, wantErr %v", err, tt.wantErr)
return
}
if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("tailNode() unexpected result (-want +got):\n%s", diff)
t.Errorf("TailNode() unexpected result (-want +got):\n%s", diff)
}
})
}
@ -274,17 +267,13 @@ func TestNodeExpiry(t *testing.T) {
GivenName: "test",
Expiry: tt.exp,
}
polMan, err := policy.NewPolicyManager(nil, nil, types.Nodes{}.ViewSlice())
require.NoError(t, err)
tn, err := tailNode(
node.View(),
tn, err := node.View().TailNode(
0,
polMan,
func(id types.NodeID) []netip.Prefix {
return []netip.Prefix{}
},
&types.Config{},
&types.Config{Taildrop: types.TaildropConfig{Enabled: true}},
)
if err != nil {
t.Fatalf("nodeExpiry() error = %v", err)

View file

@ -1,47 +0,0 @@
package mapper
import "tailscale.com/tailcfg"
// mergePatch takes the current patch and a newer patch
// and override any field that has changed.
func mergePatch(currPatch, newPatch *tailcfg.PeerChange) {
if newPatch.DERPRegion != 0 {
currPatch.DERPRegion = newPatch.DERPRegion
}
if newPatch.Cap != 0 {
currPatch.Cap = newPatch.Cap
}
if newPatch.CapMap != nil {
currPatch.CapMap = newPatch.CapMap
}
if newPatch.Endpoints != nil {
currPatch.Endpoints = newPatch.Endpoints
}
if newPatch.Key != nil {
currPatch.Key = newPatch.Key
}
if newPatch.KeySignature != nil {
currPatch.KeySignature = newPatch.KeySignature
}
if newPatch.DiscoKey != nil {
currPatch.DiscoKey = newPatch.DiscoKey
}
if newPatch.Online != nil {
currPatch.Online = newPatch.Online
}
if newPatch.LastSeen != nil {
currPatch.LastSeen = newPatch.LastSeen
}
if newPatch.KeyExpiry != nil {
currPatch.KeyExpiry = newPatch.KeyExpiry
}
}

View file

@ -32,31 +32,16 @@ var (
Name: "mapresponse_sent_total",
Help: "total count of mapresponses sent to clients",
}, []string{"status", "type"})
mapResponseUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_updates_received_total",
Help: "total count of mapresponse updates received on update channel",
}, []string{"type"})
mapResponseEndpointUpdates = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_endpoint_updates_total",
Help: "total count of endpoint updates received",
}, []string{"status"})
mapResponseReadOnly = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_readonly_requests_total",
Help: "total count of readonly requests received",
}, []string{"status"})
mapResponseEnded = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_ended_total",
Help: "total count of new mapsessions ended",
}, []string{"reason"})
mapResponseClosed = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "mapresponse_closed_total",
Help: "total count of calls to mapresponse close",
}, []string{"return"})
httpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{
Namespace: prometheusNamespace,
Name: "http_duration_seconds",

View file

@ -29,9 +29,6 @@ const (
// of length. Then that many bytes of JSON-encoded tailcfg.EarlyNoise.
// The early payload is optional. Some servers may not send it... But we do!
earlyPayloadMagic = "\xff\xff\xffTS"
// EarlyNoise was added in protocol version 49.
earlyNoiseCapabilityVersion = 49
)
type noiseServer struct {

View file

@ -41,6 +41,7 @@ var (
errOIDCAllowedUsers = errors.New(
"authenticated principal does not match any allowed user",
)
errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email")
)
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
@ -173,11 +174,6 @@ func (a *AuthProviderOIDC) RegisterHandler(
http.Redirect(writer, req, authURL, http.StatusFound)
}
type oidcCallbackTemplateConfig struct {
User string
Verb string
}
// OIDCCallbackHandler handles the callback from the OIDC endpoint
// Retrieves the nkey from the state cache and adds the node to the users email user
// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities
@ -269,17 +265,8 @@ func (a *AuthProviderOIDC) OIDCCallbackHandler(
// The user claims are now updated from the userinfo endpoint so we can verify the user
// against allowed emails, email domains, and groups.
if err := validateOIDCAllowedDomains(a.cfg.AllowedDomains, &claims); err != nil {
httpError(writer, err)
return
}
if err := validateOIDCAllowedGroups(a.cfg.AllowedGroups, &claims); err != nil {
httpError(writer, err)
return
}
if err := validateOIDCAllowedUsers(a.cfg.AllowedUsers, &claims); err != nil {
err = doOIDCAuthorization(a.cfg, &claims)
if err != nil {
httpError(writer, err)
return
}
@ -439,17 +426,13 @@ func validateOIDCAllowedGroups(
allowedGroups []string,
claims *types.OIDCClaims,
) error {
if len(allowedGroups) > 0 {
for _, group := range allowedGroups {
if slices.Contains(claims.Groups, group) {
return nil
}
for _, group := range allowedGroups {
if slices.Contains(claims.Groups, group) {
return nil
}
return NewHTTPError(http.StatusUnauthorized, "unauthorised group", errOIDCAllowedGroups)
}
return nil
return NewHTTPError(http.StatusUnauthorized, "unauthorised group", errOIDCAllowedGroups)
}
// validateOIDCAllowedUsers checks that if AllowedUsers is provided,
@ -458,14 +441,62 @@ func validateOIDCAllowedUsers(
allowedUsers []string,
claims *types.OIDCClaims,
) error {
if len(allowedUsers) > 0 &&
!slices.Contains(allowedUsers, claims.Email) {
if !slices.Contains(allowedUsers, claims.Email) {
return NewHTTPError(http.StatusUnauthorized, "unauthorised user", errOIDCAllowedUsers)
}
return nil
}
// doOIDCAuthorization applies authorization tests to claims.
//
// The following test is applied whenever cfg.AllowedGroups is configured,
// regardless of email verification:
//
// - validateOIDCAllowedGroups
//
// The following tests are applied only when the email can be trusted
// (cfg.EmailVerifiedRequired=false or claims.email_verified=true); if either
// is configured and the email cannot be trusted, authorization fails:
//
// - validateOIDCAllowedDomains
// - validateOIDCAllowedUsers
//
// Note that, despite its name, validateOIDCAllowedUsers matches against the
// email address, not the username.
func doOIDCAuthorization(
cfg *types.OIDCConfig,
claims *types.OIDCClaims,
) error {
if len(cfg.AllowedGroups) > 0 {
err := validateOIDCAllowedGroups(cfg.AllowedGroups, claims)
if err != nil {
return err
}
}
trustEmail := !cfg.EmailVerifiedRequired || bool(claims.EmailVerified)
hasEmailTests := len(cfg.AllowedDomains) > 0 || len(cfg.AllowedUsers) > 0
if !trustEmail && hasEmailTests {
return NewHTTPError(http.StatusUnauthorized, "unverified email", errOIDCUnverifiedEmail)
}
if len(cfg.AllowedDomains) > 0 {
err := validateOIDCAllowedDomains(cfg.AllowedDomains, claims)
if err != nil {
return err
}
}
if len(cfg.AllowedUsers) > 0 {
err := validateOIDCAllowedUsers(cfg.AllowedUsers, claims)
if err != nil {
return err
}
}
return nil
}
// getRegistrationIDFromState retrieves the registration ID from the state.
func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID {
regInfo, ok := a.registrationCache.Get(state)
@ -478,14 +509,16 @@ func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.Regis
func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
claims *types.OIDCClaims,
) (*types.User, change.ChangeSet, error) {
var user *types.User
var err error
var newUser bool
var c change.ChangeSet
) (*types.User, change.Change, error) {
var (
user *types.User
err error
newUser bool
c change.Change
)
user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier())
if err != nil && !errors.Is(err, db.ErrUserNotFound) {
return nil, change.EmptySet, fmt.Errorf("creating or updating user: %w", err)
return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err)
}
// if the user is still not found, create a new empty user.
@ -496,12 +529,12 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
user = &types.User{}
}
user.FromClaim(claims)
user.FromClaim(claims, a.cfg.EmailVerifiedRequired)
if newUser {
user, c, err = a.h.state.CreateUser(*user)
if err != nil {
return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
return nil, change.Change{}, fmt.Errorf("creating user: %w", err)
}
} else {
_, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error {
@ -509,7 +542,7 @@ func (a *AuthProviderOIDC) createOrUpdateUserFromClaim(
return nil
})
if err != nil {
return nil, change.EmptySet, fmt.Errorf("updating user: %w", err)
return nil, change.Change{}, fmt.Errorf("updating user: %w", err)
}
}
@ -550,7 +583,7 @@ func (a *AuthProviderOIDC) handleRegistration(
// Send both changes. Empty changes are ignored by Change().
a.h.Change(nodeChange, routesChange)
return !nodeChange.Empty(), nil
return !nodeChange.IsEmpty(), nil
}
func renderOIDCCallbackTemplate(

173
hscontrol/oidc_test.go Normal file
View file

@ -0,0 +1,173 @@
package hscontrol
import (
"testing"
"github.com/juanfont/headscale/hscontrol/types"
)
func TestDoOIDCAuthorization(t *testing.T) {
testCases := []struct {
name string
cfg *types.OIDCConfig
claims *types.OIDCClaims
wantErr bool
}{
{
name: "verified email domain",
wantErr: false,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{"test.com"},
AllowedUsers: []string{},
AllowedGroups: []string{},
},
claims: &types.OIDCClaims{
Email: "user@test.com",
EmailVerified: true,
},
},
{
name: "verified email user",
wantErr: false,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{},
AllowedUsers: []string{"user@test.com"},
AllowedGroups: []string{},
},
claims: &types.OIDCClaims{
Email: "user@test.com",
EmailVerified: true,
},
},
{
name: "unverified email domain",
wantErr: true,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{"test.com"},
AllowedUsers: []string{},
AllowedGroups: []string{},
},
claims: &types.OIDCClaims{
Email: "user@test.com",
EmailVerified: false,
},
},
{
name: "group member",
wantErr: false,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{},
AllowedUsers: []string{},
AllowedGroups: []string{"test"},
},
claims: &types.OIDCClaims{Groups: []string{"test"}},
},
{
name: "non group member",
wantErr: true,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{},
AllowedUsers: []string{},
AllowedGroups: []string{"nope"},
},
claims: &types.OIDCClaims{Groups: []string{"testo"}},
},
{
name: "group member but bad domain",
wantErr: true,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{"user@good.com"},
AllowedUsers: []string{},
AllowedGroups: []string{"test group"},
},
claims: &types.OIDCClaims{Groups: []string{"test group"}, Email: "bad@bad.com", EmailVerified: true},
},
{
name: "all checks pass",
wantErr: false,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{"test.com"},
AllowedUsers: []string{"user@test.com"},
AllowedGroups: []string{"test group"},
},
claims: &types.OIDCClaims{Groups: []string{"test group"}, Email: "user@test.com", EmailVerified: true},
},
{
name: "all checks pass with unverified email",
wantErr: false,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: false,
AllowedDomains: []string{"test.com"},
AllowedUsers: []string{"user@test.com"},
AllowedGroups: []string{"test group"},
},
claims: &types.OIDCClaims{Groups: []string{"test group"}, Email: "user@test.com", EmailVerified: false},
},
{
name: "fail on unverified email",
wantErr: true,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{"test.com"},
AllowedUsers: []string{"user@test.com"},
AllowedGroups: []string{"test group"},
},
claims: &types.OIDCClaims{Groups: []string{"test group"}, Email: "user@test.com", EmailVerified: false},
},
{
name: "unverified email user only",
wantErr: true,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{},
AllowedUsers: []string{"user@test.com"},
AllowedGroups: []string{},
},
claims: &types.OIDCClaims{
Email: "user@test.com",
EmailVerified: false,
},
},
{
name: "no filters configured",
wantErr: false,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{},
AllowedUsers: []string{},
AllowedGroups: []string{},
},
claims: &types.OIDCClaims{
Email: "anyone@anywhere.com",
EmailVerified: false,
},
},
{
name: "multiple allowed groups second matches",
wantErr: false,
cfg: &types.OIDCConfig{
EmailVerifiedRequired: true,
AllowedDomains: []string{},
AllowedUsers: []string{},
AllowedGroups: []string{"group1", "group2", "group3"},
},
claims: &types.OIDCClaims{Groups: []string{"group2"}},
},
}
for _, tC := range testCases {
t.Run(tC.name, func(t *testing.T) {
err := doOIDCAuthorization(tC.cfg, tC.claims)
if (err != nil) != tC.wantErr {
t.Errorf("doOIDCAuthorization() error = %v, wantErr %v", err, tC.wantErr)
}
})
}
}

View file

@ -26,6 +26,9 @@ type PolicyManager interface {
// NodeCanHaveTag reports whether the given node can have the given tag.
NodeCanHaveTag(types.NodeView, string) bool
// TagExists reports whether the given tag is defined in the policy.
TagExists(tag string) bool
// NodeCanApproveRoute reports whether the given node can approve the given route.
NodeCanApproveRoute(types.NodeView, netip.Prefix) bool

View file

@ -748,6 +748,32 @@ func TestNodeCanApproveRoute(t *testing.T) {
}`,
canApprove: true,
},
{
// Tags-as-identity: Tagged nodes are identified by their tags, not by the
// user who created them. Group membership of the creator is irrelevant.
// A tagged node can only be auto-approved via tag-based autoApprovers,
// not group-based ones (even if the creator is in the group).
name: "tagged-node-with-group-autoapprover-not-approved",
node: taggedNode, // Has tag:router, owned by user3
route: p("10.30.0.0/16"),
policy: `{
"tagOwners": {
"tag:router": ["user3@"]
},
"groups": {
"group:ops": ["user3@"]
},
"acls": [
{"action": "accept", "src": ["*"], "dst": ["*:*"]}
],
"autoApprovers": {
"routes": {
"10.30.0.0/16": ["group:ops"]
}
}
}`,
canApprove: false, // Tagged nodes don't inherit group membership for auto-approval
},
{
name: "small-subnet-with-exitnode-only-approval",
node: normalNode,

View file

@ -1,7 +1,9 @@
package v2
import (
"cmp"
"encoding/json"
"errors"
"fmt"
"net/netip"
"slices"
@ -19,6 +21,9 @@ import (
"tailscale.com/util/deephash"
)
// ErrInvalidTagOwner is returned when a tag owner is not an Alias type.
var ErrInvalidTagOwner = errors.New("tag owner is not an Alias")
type PolicyManager struct {
mu sync.Mutex
pol *Policy
@ -493,8 +498,7 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
pm.mu.Lock()
defer pm.mu.Unlock()
oldNodeCount := pm.nodes.Len()
newNodeCount := nodes.Len()
policyChanged := pm.nodesHavePolicyAffectingChanges(nodes)
// Invalidate cache entries for nodes that changed.
// For autogroup:self: invalidate all nodes belonging to affected users (peer changes).
@ -503,19 +507,17 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
pm.nodes = nodes
nodesChanged := oldNodeCount != newNodeCount
// When nodes are added/removed, we must recompile filters because:
// When policy-affecting node properties change, we must recompile filters because:
// 1. User/group aliases (like "user1@") resolve to node IPs
// 2. Filter compilation needs nodes to generate rules
// 3. Without nodes, filters compile to empty (0 rules)
// 2. Tag aliases (like "tag:server") match nodes based on their tags
// 3. Filter compilation needs nodes to generate rules
//
// For autogroup:self: return true when nodes change even if the global filter
// hash didn't change. The global filter is empty for autogroup:self (each node
// has its own filter), so the hash never changes. But peer relationships DO
// change when nodes are added/removed, so we must signal this to trigger updates.
// For global policies: the filter must be recompiled to include the new nodes.
if nodesChanged {
if policyChanged {
// Recompile filter with the new node list
needsUpdate, err := pm.updateLocked()
if err != nil {
@ -536,23 +538,132 @@ func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, erro
return false, nil
}
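// nodesHavePolicyAffectingChanges reports whether the incoming node list differs from
// the current one in a way that can affect policy evaluation: nodes were added or
// removed, or an existing node changed a policy-relevant property (as reported by
// HasPolicyChange, e.g. its tags).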
func (pm *PolicyManager) nodesHavePolicyAffectingChanges(newNodes views.Slice[types.NodeView]) bool {
if pm.nodes.Len() != newNodes.Len() {
return true
}
oldNodes := make(map[types.NodeID]types.NodeView, pm.nodes.Len())
for _, node := range pm.nodes.All() {
oldNodes[node.ID()] = node
}
for _, newNode := range newNodes.All() {
oldNode, exists := oldNodes[newNode.ID()]
if !exists {
return true
}
if newNode.HasPolicyChange(oldNode) {
return true
}
}
return false
}
// NodeCanHaveTag checks if a node can have the specified tag during client-initiated
// registration or reauth flows (e.g., tailscale up --advertise-tags).
//
// This function is NOT used by the admin API's SetNodeTags - admins can set any
// existing tag on any node by calling State.SetNodeTags directly, which bypasses
// this authorization check.
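//
// The check is two-step: the node's IPs are first matched against the pre-resolved
// tagOwnerMap; if that yields no match (for example, a newly registering node that
// has no IPs yet), the node's user is matched directly against the tag's owners.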
func (pm *PolicyManager) NodeCanHaveTag(node types.NodeView, tag string) bool {
if pm == nil {
if pm == nil || pm.pol == nil {
return false
}
pm.mu.Lock()
defer pm.mu.Unlock()
// Check if tag exists in policy
owners, exists := pm.pol.TagOwners[Tag(tag)]
if !exists {
return false
}
// Check if node's owner can assign this tag via the pre-resolved tagOwnerMap.
// The tagOwnerMap contains IP sets built from resolving TagOwners entries
// (usernames/groups) to their nodes' IPs, so checking if the node's IP
// is in the set answers "does this node's owner own this tag?"
if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok {
if slices.ContainsFunc(node.IPs(), ips.Contains) {
return true
}
}
// For new nodes being registered, their IP may not yet be in the tagOwnerMap.
// Fall back to checking the node's user directly against the TagOwners.
// This handles the case where a user registers a new node with --advertise-tags.
if node.User().Valid() {
for _, owner := range owners {
if pm.userMatchesOwner(node.User(), owner) {
return true
}
}
}
return false
}
// userMatchesOwner checks if a user matches a tag owner entry.
// This is used as a fallback when the node's IP is not in the tagOwnerMap.
func (pm *PolicyManager) userMatchesOwner(user types.UserView, owner Owner) bool {
switch o := owner.(type) {
case *Username:
if o == nil {
return false
}
// Resolve the username to find the user it refers to
resolvedUser, err := o.resolveUser(pm.users)
if err != nil {
return false
}
return user.ID() == resolvedUser.ID
case *Group:
if o == nil || pm.pol == nil {
return false
}
// Resolve the group to get usernames
usernames, ok := pm.pol.Groups[*o]
if !ok {
return false
}
// Check if the user matches any username in the group
for _, uname := range usernames {
resolvedUser, err := uname.resolveUser(pm.users)
if err != nil {
continue
}
if user.ID() == resolvedUser.ID {
return true
}
}
return false
default:
return false
}
}
// TagExists reports whether the given tag is defined in the policy.
func (pm *PolicyManager) TagExists(tag string) bool {
if pm == nil || pm.pol == nil {
return false
}
pm.mu.Lock()
defer pm.mu.Unlock()
_, exists := pm.pol.TagOwners[Tag(tag)]
return exists
}
func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool {
if pm == nil {
return false
@ -834,3 +945,126 @@ func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types.
}
}
}
// flattenTags flattens the TagOwners by resolving nested tags and detecting cycles.
// It returns an Owners list in which all Tag types have been resolved to their underlying Owners.
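// The visiting set tracks the tags on the current resolution path so cycles can be
// detected, and chain records that path so the error message can report the cycle.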
func flattenTags(tagOwners TagOwners, tag Tag, visiting map[Tag]bool, chain []Tag) (Owners, error) {
if visiting[tag] {
cycleStart := 0
for i, t := range chain {
if t == tag {
cycleStart = i
break
}
}
cycleTags := make([]string, len(chain[cycleStart:]))
for i, t := range chain[cycleStart:] {
cycleTags[i] = string(t)
}
slices.Sort(cycleTags)
return nil, fmt.Errorf("%w: %s", ErrCircularReference, strings.Join(cycleTags, " -> "))
}
visiting[tag] = true
chain = append(chain, tag)
defer delete(visiting, tag)
var result Owners
for _, owner := range tagOwners[tag] {
switch o := owner.(type) {
case *Tag:
if _, ok := tagOwners[*o]; !ok {
return nil, fmt.Errorf("tag %q %w %q", tag, ErrUndefinedTagReference, *o)
}
nested, err := flattenTags(tagOwners, *o, visiting, chain)
if err != nil {
return nil, err
}
result = append(result, nested...)
default:
result = append(result, owner)
}
}
return result, nil
}
// flattenTagOwners flattens all TagOwners by resolving nested tags and detecting cycles.
// It will return a new TagOwners map where all the Tag types have been resolved to their underlying Owners.
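//
// For example, assuming a hypothetical policy where "tag:parent" is owned by
// "group:ops" and "tag:child" is owned by ["tag:parent"], the flattened map
// lists "group:ops" as the owner of both tags:
//
// tag:parent -> [group:ops]
// tag:child -> [group:ops]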
func flattenTagOwners(tagOwners TagOwners) (TagOwners, error) {
ret := make(TagOwners)
for tag := range tagOwners {
flattened, err := flattenTags(tagOwners, tag, make(map[Tag]bool), nil)
if err != nil {
return nil, err
}
slices.SortFunc(flattened, func(a, b Owner) int {
return cmp.Compare(a.String(), b.String())
})
ret[tag] = slices.CompactFunc(flattened, func(a, b Owner) bool {
return a.String() == b.String()
})
}
return ret, nil
}
// resolveTagOwners resolves the TagOwners to a map of Tag to netipx.IPSet.
// The resulting map can be used to quickly look up the IPSet for a given Tag.
// It is intended for internal use in a PolicyManager.
func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[Tag]*netipx.IPSet, error) {
if p == nil {
return make(map[Tag]*netipx.IPSet), nil
}
if len(p.TagOwners) == 0 {
return make(map[Tag]*netipx.IPSet), nil
}
ret := make(map[Tag]*netipx.IPSet)
tagOwners, err := flattenTagOwners(p.TagOwners)
if err != nil {
return nil, err
}
for tag, owners := range tagOwners {
var ips netipx.IPSetBuilder
for _, owner := range owners {
switch o := owner.(type) {
case *Tag:
// After flattening, Tag types should not appear in the owners list.
// If they do, skip them as they represent already-resolved references.
case Alias:
// If it does not resolve, that means the tag is not associated with any IP addresses.
resolved, _ := o.Resolve(p, users, nodes)
ips.AddSet(resolved)
default:
// Should never happen - after flattening, all owners should be Alias types
return nil, fmt.Errorf("%w: %v", ErrInvalidTagOwner, owner)
}
}
ipSet, err := ips.IPSet()
if err != nil {
return nil, err
}
ret[tag] = ipSet
}
return ret, nil
}

View file

@ -464,14 +464,14 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
// test-2 has a router device with tag:node-router
test2RouterNode := &types.Node{
ID: 2,
Hostname: "test-2-router",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
Tags: []string{"tag:node-router"},
Hostinfo: &tailcfg.Hostinfo{},
ID: 2,
Hostname: "test-2-router",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
Tags: []string{"tag:node-router"},
Hostinfo: &tailcfg.Hostinfo{},
}
nodes := types.Nodes{test1Node, test2RouterNode}
@ -537,8 +537,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
Hostname: "test-1-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[0],
UserID: users[0].ID,
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@ -547,8 +547,8 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
Hostname: "test-2-device",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
UserID: users[1].ID,
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@ -606,3 +606,126 @@ func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) {
require.NoError(t, err)
require.False(t, policyChanged2, "SetPolicy should return false when policy content hasn't changed")
}
// TestTagPropagationToPeerMap tests that when a node's tags change,
// the peer map is correctly updated. This is a regression test for
// https://github.com/juanfont/headscale/issues/2389
func TestTagPropagationToPeerMap(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"},
{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"},
}
// Policy: user2 can access tag:web nodes
policy := `{
"tagOwners": {
"tag:web": ["user1@headscale.net"],
"tag:internal": ["user1@headscale.net"]
},
"acls": [
{
"action": "accept",
"src": ["user2@headscale.net"],
"dst": ["user2@headscale.net:*"]
},
{
"action": "accept",
"src": ["user2@headscale.net"],
"dst": ["tag:web:*"]
},
{
"action": "accept",
"src": ["tag:web"],
"dst": ["user2@headscale.net:*"]
}
]
}`
// user1's node starts with tag:web and tag:internal
user1Node := &types.Node{
ID: 1,
Hostname: "user1-node",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
Tags: []string{"tag:web", "tag:internal"},
}
// user2's node (no tags)
user2Node := &types.Node{
ID: 2,
Hostname: "user2-node",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
}
initialNodes := types.Nodes{user1Node, user2Node}
pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice())
require.NoError(t, err)
// Initial state: user2 should see user1 as a peer (user1 has tag:web)
initialPeerMap := pm.BuildPeerMap(initialNodes.ViewSlice())
// Check user2's peers - should include user1
user2Peers := initialPeerMap[user2Node.ID]
require.Len(t, user2Peers, 1, "user2 should have 1 peer initially (user1 with tag:web)")
require.Equal(t, user1Node.ID, user2Peers[0].ID(), "user2's peer should be user1")
// Check user1's peers - should include user2 (bidirectional ACL)
user1Peers := initialPeerMap[user1Node.ID]
require.Len(t, user1Peers, 1, "user1 should have 1 peer initially (user2)")
require.Equal(t, user2Node.ID, user1Peers[0].ID(), "user1's peer should be user2")
// Now change user1's tags: remove tag:web, keep only tag:internal
user1NodeUpdated := &types.Node{
ID: 1,
Hostname: "user1-node",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
Tags: []string{"tag:internal"}, // tag:web removed!
}
updatedNodes := types.Nodes{user1NodeUpdated, user2Node}
// SetNodes should detect the tag change
changed, err := pm.SetNodes(updatedNodes.ViewSlice())
require.NoError(t, err)
require.True(t, changed, "SetNodes should return true when tags change")
// After tag change: user2 should NOT see user1 as a peer anymore
// (no ACL allows user2 to access tag:internal)
updatedPeerMap := pm.BuildPeerMap(updatedNodes.ViewSlice())
// Check user2's peers - should be empty now
user2PeersAfter := updatedPeerMap[user2Node.ID]
require.Empty(t, user2PeersAfter, "user2 should have no peers after tag:web is removed from user1")
// Check user1's peers - should also be empty
user1PeersAfter := updatedPeerMap[user1Node.ID]
require.Empty(t, user1PeersAfter, "user1 should have no peers after tag:web is removed")
// Also verify MatchersForNode returns non-empty matchers and ReduceNodes filters correctly
// This simulates what buildTailPeers does in the mapper
matchersForUser2, err := pm.MatchersForNode(user2Node.View())
require.NoError(t, err)
require.NotEmpty(t, matchersForUser2, "MatchersForNode should return non-empty matchers (at least self-access rule)")
// Test ReduceNodes logic with the updated nodes and matchers
// This is what buildTailPeers does - it takes peers from ListPeers (which might include user1)
// and filters them using ReduceNodes with the updated matchers
// Inline the ReduceNodes logic to avoid import cycle
user2View := user2Node.View()
user1UpdatedView := user1NodeUpdated.View()
// Check if user2 can access user1 OR user1 can access user2
canAccess := user2View.CanAccess(matchersForUser2, user1UpdatedView) ||
user1UpdatedView.CanAccess(matchersForUser2, user2View)
require.False(t, canAccess, "user2 should NOT be able to access user1 after tag:web is removed (ReduceNodes should filter out)")
}

View file

@ -9,7 +9,6 @@ import (
"strings"
"github.com/go-json-experiment/json"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/prometheus/common/model"
@ -34,6 +33,10 @@ const Wildcard = Asterix(0)
var ErrAutogroupSelfRequiresPerNodeResolution = errors.New("autogroup:self requires per-node resolution and cannot be resolved in this context")
var ErrCircularReference = errors.New("circular reference detected")
var ErrUndefinedTagReference = errors.New("references undefined tag")
type Asterix int
func (a Asterix) Validate() error {
@ -201,7 +204,7 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
}
for _, node := range nodes.All() {
// Skip tagged nodes
// Skip tagged nodes - they are identified by tags, not users
if node.IsTagged() {
continue
}
@ -303,35 +306,11 @@ func (t *Tag) UnmarshalJSON(b []byte) error {
func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
var ips netipx.IPSetBuilder
// TODO(kradalby): This is currently resolved twice, and should be resolved once.
// It is added temporary until we sort out the story on how and when we resolve tags
// from the three places they can be "approved":
// - As part of a PreAuthKey (handled in HasTag)
// - As part of ForcedTags (set via CLI) (handled in HasTag)
// - As part of HostInfo.RequestTags and approved by policy (this is happening here)
// Part of #2417
tagMap, err := resolveTagOwners(p, users, nodes)
if err != nil {
return nil, err
}
for _, node := range nodes.All() {
// Check if node has this tag
if node.HasTag(string(t)) {
node.AppendToIPSet(&ips)
}
// TODO(kradalby): remove as part of #2417, see comment above
if tagMap != nil {
if tagips, ok := tagMap[t]; ok && node.InIPSet(tagips) && node.Hostinfo().Valid() {
for _, tag := range node.RequestTagsSlice().All() {
if tag == string(t) {
node.AppendToIPSet(&ips)
break
}
}
}
}
}
return ips.IPSet()
@ -341,6 +320,10 @@ func (t Tag) CanBeAutoApprover() bool {
return true
}
func (t Tag) CanBeTagOwner() bool {
return true
}
func (t Tag) String() string {
return string(t)
}
@ -537,61 +520,26 @@ func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[type
return util.TheInternet(), nil
case AutoGroupMember:
// autogroup:member represents all untagged devices in the tailnet.
tagMap, err := resolveTagOwners(p, users, nodes)
if err != nil {
return nil, err
}
for _, node := range nodes.All() {
// Skip if node is tagged
if node.IsTagged() {
continue
}
// Skip if node has any allowed requested tags
hasAllowedTag := false
if node.RequestTagsSlice().Len() != 0 {
for _, tag := range node.RequestTagsSlice().All() {
if _, ok := tagMap[Tag(tag)]; ok {
hasAllowedTag = true
break
}
}
}
if hasAllowedTag {
continue
}
// Node is a member if it has no forced tags and no allowed requested tags
// Node is a member if it is not tagged
node.AppendToIPSet(&build)
}
return build.IPSet()
case AutoGroupTagged:
// autogroup:tagged represents all devices with a tag in the tailnet.
tagMap, err := resolveTagOwners(p, users, nodes)
if err != nil {
return nil, err
}
for _, node := range nodes.All() {
// Include if node is tagged
if node.IsTagged() {
node.AppendToIPSet(&build)
if !node.IsTagged() {
continue
}
// Include if node has any allowed requested tags
if node.RequestTagsSlice().Len() != 0 {
for _, tag := range node.RequestTagsSlice().All() {
if _, ok := tagMap[Tag(tag)]; ok {
node.AppendToIPSet(&build)
break
}
}
}
node.AppendToIPSet(&build)
}
return build.IPSet()
@ -915,6 +863,7 @@ func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error {
type Owner interface {
CanBeTagOwner() bool
UnmarshalJSON([]byte) error
String() string
}
// OwnerEnc is used to deserialize a Owner.
@ -963,6 +912,8 @@ func (o Owners) MarshalJSON() ([]byte, error) {
owners[i] = string(*v)
case *Group:
owners[i] = string(*v)
case *Tag:
owners[i] = string(*v)
default:
return nil, fmt.Errorf("unknown owner type: %T", v)
}
@ -977,6 +928,8 @@ func parseOwner(s string) (Owner, error) {
return ptr.To(Username(s)), nil
case isGroup(s):
return ptr.To(Group(s)), nil
case isTag(s):
return ptr.To(Tag(s)), nil
}
return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types:
@ -1134,6 +1087,8 @@ func (to TagOwners) MarshalJSON() ([]byte, error) {
ownerStrs[i] = string(*v)
case *Group:
ownerStrs[i] = string(*v)
case *Tag:
ownerStrs[i] = string(*v)
default:
return nil, fmt.Errorf("unknown owner type: %T", v)
}
@ -1162,41 +1117,6 @@ func (to TagOwners) Contains(tagOwner *Tag) error {
return fmt.Errorf(`Tag %q is not defined in the Policy, please define or remove the reference to it`, tagOwner)
}
// resolveTagOwners resolves the TagOwners to a map of Tag to netipx.IPSet.
// The resulting map can be used to quickly look up the IPSet for a given Tag.
// It is intended for internal use in a PolicyManager.
func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[Tag]*netipx.IPSet, error) {
if p == nil {
return nil, nil
}
ret := make(map[Tag]*netipx.IPSet)
for tag, owners := range p.TagOwners {
var ips netipx.IPSetBuilder
for _, owner := range owners {
o, ok := owner.(Alias)
if !ok {
// Should never happen
return nil, fmt.Errorf("owner %v is not an Alias", owner)
}
// If it does not resolve, that means the tag is not associated with any IP addresses.
resolved, _ := o.Resolve(p, users, nodes)
ips.AddSet(resolved)
}
ipSet, err := ips.IPSet()
if err != nil {
return nil, err
}
ret[tag] = ipSet
}
return ret, nil
}
type AutoApproverPolicy struct {
Routes map[netip.Prefix]AutoApprovers `json:"routes,omitempty"`
ExitNode AutoApprovers `json:"exitNode,omitempty"`
@ -1844,10 +1764,23 @@ func (p *Policy) validate() error {
if err := p.Groups.Contains(g); err != nil {
errs = append(errs, err)
}
case *Tag:
t := tagOwner
err := p.TagOwners.Contains(t)
if err != nil {
errs = append(errs, err)
}
}
}
}
// Validate tag ownership chains for circular references and undefined tags.
_, err := flattenTagOwners(p.TagOwners)
if err != nil {
errs = append(errs, err)
}
for _, approvers := range p.AutoApprovers.Routes {
for _, approver := range approvers {
switch approver := approver.(type) {

View file

@ -1470,6 +1470,57 @@ func TestUnmarshalPolicy(t *testing.T) {
},
},
},
{
name: "tags-can-own-other-tags",
input: `
{
"tagOwners": {
"tag:bigbrother": [],
"tag:smallbrother": ["tag:bigbrother"],
},
"acls": [
{
"action": "accept",
"proto": "tcp",
"src": ["*"],
"dst": ["tag:smallbrother:9000"]
}
]
}
`,
want: &Policy{
TagOwners: TagOwners{
Tag("tag:bigbrother"): {},
Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))},
},
ACLs: []ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: Aliases{
Wildcard,
},
Destinations: []AliasWithPorts{
{
Alias: ptr.To(Tag("tag:smallbrother")),
Ports: []tailcfg.PortRange{{First: 9000, Last: 9000}},
},
},
},
},
},
},
{
name: "tag-owner-references-undefined-tag",
input: `
{
"tagOwners": {
"tag:child": ["tag:nonexistent"],
},
}
`,
wantErr: `tag "tag:child" references undefined tag "tag:nonexistent"`,
},
}
cmps := append(util.Comparers,
@ -1596,7 +1647,7 @@ func TestResolvePolicy(t *testing.T) {
{
User: ptr.To(testuser),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.2"),
IPv4: ap("100.100.101.2"),
},
// not matching because it's tagged (tags copied from AuthKey)
{
@ -1628,7 +1679,7 @@ func TestResolvePolicy(t *testing.T) {
{
User: ptr.To(groupuser),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.5"),
IPv4: ap("100.100.101.5"),
},
// not matching because it's tagged (tags copied from AuthKey)
{
@ -1665,7 +1716,7 @@ func TestResolvePolicy(t *testing.T) {
// Not matching forced tags
{
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.10"),
IPv4: ap("100.100.101.10"),
},
// not matching pak tag
{
@ -1677,7 +1728,7 @@ func TestResolvePolicy(t *testing.T) {
// Not matching forced tags
{
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
IPv4: ap("100.100.101.234"),
},
// matching tag (tags copied from AuthKey during registration)
{
@ -1689,6 +1740,52 @@ func TestResolvePolicy(t *testing.T) {
pol: &Policy{},
want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")},
},
{
name: "tag-owned-by-tag-call-child",
toResolve: tp("tag:smallbrother"),
pol: &Policy{
TagOwners: TagOwners{
Tag("tag:bigbrother"): {},
Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))},
},
},
nodes: types.Nodes{
// Should not match as we resolve the "child" tag.
{
Tags: []string{"tag:bigbrother"},
IPv4: ap("100.100.101.234"),
},
// Should match.
{
Tags: []string{"tag:smallbrother"},
IPv4: ap("100.100.101.239"),
},
},
want: []netip.Prefix{mp("100.100.101.239/32")},
},
{
name: "tag-owned-by-tag-call-parent",
toResolve: tp("tag:bigbrother"),
pol: &Policy{
TagOwners: TagOwners{
Tag("tag:bigbrother"): {},
Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))},
},
},
nodes: types.Nodes{
// Should match - we are resolving "tag:bigbrother" which this node has.
{
Tags: []string{"tag:bigbrother"},
IPv4: ap("100.100.101.234"),
},
// Should not match - this node has "tag:smallbrother", not the tag we're resolving.
{
Tags: []string{"tag:smallbrother"},
IPv4: ap("100.100.101.239"),
},
},
want: []netip.Prefix{mp("100.100.101.234/32")},
},
{
name: "empty-policy",
toResolve: pp("100.100.101.101/32"),
@ -1747,7 +1844,7 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
{
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
IPv4: ap("100.100.101.234"),
},
},
},
@ -1765,124 +1862,108 @@ func TestResolvePolicy(t *testing.T) {
name: "autogroup-member-comprehensive",
toResolve: ptr.To(AutoGroup(AutoGroupMember)),
nodes: types.Nodes{
// Node with no tags (should be included)
// Node with no tags (should be included - is a member)
{
User: ptr.To(testuser),
IPv4: ap("100.100.101.1"),
},
// Node with forced tags (should be excluded)
// Node with single tag (should be excluded - tagged nodes are not members)
{
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be excluded)
// Node with multiple tags, all defined in policy (should be excluded)
{
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
Tags: []string{"tag:test", "tag:other"},
IPv4: ap("100.100.101.3"),
},
// Node with non-allowed requested tag (should be included)
// Node with tag not defined in policy (should be excluded - still tagged)
{
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed"},
},
Tags: []string{"tag:undefined"},
IPv4: ap("100.100.101.4"),
},
// Node with multiple requested tags, one allowed (should be excluded)
// Node with mixed tags - some defined, some not (should be excluded)
{
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test", "tag:notallowed"},
},
Tags: []string{"tag:test", "tag:undefined"},
IPv4: ap("100.100.101.5"),
},
// Node with multiple requested tags, none allowed (should be included)
// Another untagged node from different user (should be included)
{
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed1", "tag:notallowed2"},
},
User: ptr.To(testuser2),
IPv4: ap("100.100.101.6"),
},
},
pol: &Policy{
TagOwners: TagOwners{
Tag("tag:test"): Owners{ptr.To(Username("testuser@"))},
Tag("tag:test"): Owners{ptr.To(Username("testuser@"))},
Tag("tag:other"): Owners{ptr.To(Username("testuser@"))},
},
},
want: []netip.Prefix{
mp("100.100.101.1/32"), // No tags
mp("100.100.101.4/32"), // Non-allowed requested tag
mp("100.100.101.6/32"), // Multiple non-allowed requested tags
mp("100.100.101.1/32"), // No tags - is a member
mp("100.100.101.6/32"), // No tags, different user - is a member
},
},
{
name: "autogroup-tagged",
toResolve: ptr.To(AutoGroup(AutoGroupTagged)),
nodes: types.Nodes{
// Node with no tags (should be excluded)
// Node with no tags (should be excluded - not tagged)
{
User: ptr.To(testuser),
IPv4: ap("100.100.101.1"),
},
// Node with forced tag (should be included)
// Node with single tag defined in policy (should be included)
{
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be included)
{
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
IPv4: ap("100.100.101.3"),
},
// Node with non-allowed requested tag (should be excluded)
{
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed"},
},
IPv4: ap("100.100.101.4"),
},
// Node with multiple requested tags, one allowed (should be included)
{
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test", "tag:notallowed"},
},
IPv4: ap("100.100.101.5"),
},
// Node with multiple requested tags, none allowed (should be excluded)
{
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed1", "tag:notallowed2"},
},
IPv4: ap("100.100.101.6"),
},
// Node with multiple forced tags (should be included)
// Node with multiple tags, all defined in policy (should be included)
{
User: ptr.To(testuser),
Tags: []string{"tag:test", "tag:other"},
IPv4: ap("100.100.101.7"),
IPv4: ap("100.100.101.3"),
},
// Node with tag not defined in policy (should be included - still tagged)
{
User: ptr.To(testuser),
Tags: []string{"tag:undefined"},
IPv4: ap("100.100.101.4"),
},
// Node with mixed tags - some defined, some not (should be included)
{
User: ptr.To(testuser),
Tags: []string{"tag:test", "tag:undefined"},
IPv4: ap("100.100.101.5"),
},
// Another untagged node from different user (should be excluded)
{
User: ptr.To(testuser2),
IPv4: ap("100.100.101.6"),
},
// Tagged node from different user (should be included)
{
User: ptr.To(testuser2),
Tags: []string{"tag:server"},
IPv4: ap("100.100.101.7"),
},
},
pol: &Policy{
TagOwners: TagOwners{
Tag("tag:test"): Owners{ptr.To(Username("testuser@"))},
Tag("tag:test"): Owners{ptr.To(Username("testuser@"))},
Tag("tag:other"): Owners{ptr.To(Username("testuser@"))},
Tag("tag:server"): Owners{ptr.To(Username("testuser2@"))},
},
},
want: []netip.Prefix{
mp("100.100.101.2/31"), // Forced tag and allowed requested tag consecutive IPs are put in 31 prefix
mp("100.100.101.5/32"), // Multiple requested tags, one allowed
mp("100.100.101.7/32"), // Multiple forced tags
mp("100.100.101.2/31"), // .2, .3 consecutive tagged nodes
mp("100.100.101.4/31"), // .4, .5 consecutive tagged nodes
mp("100.100.101.7/32"), // Tagged node from different user
},
},
{
@ -1900,13 +1981,11 @@ func TestResolvePolicy(t *testing.T) {
{
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.3"),
IPv4: ap("100.100.101.3"),
},
{
User: ptr.To(testuser2),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.4"),
},
},
@ -1976,11 +2055,11 @@ func TestResolveAutoApprovers(t *testing.T) {
User: &users[2],
},
{
IPv4: ap("100.64.0.4"),
IPv4: ap("100.64.0.4"),
Tags: []string{"tag:testtag"},
},
{
IPv4: ap("100.64.0.5"),
IPv4: ap("100.64.0.5"),
Tags: []string{"tag:exittest"},
},
}
@ -2474,6 +2553,20 @@ func TestResolveTagOwners(t *testing.T) {
},
wantErr: false,
},
{
name: "tag-owns-tag",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:bigbrother"): Owners{ptr.To(Username("user1@"))},
Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))},
},
},
want: map[Tag]*netipx.IPSet{
Tag("tag:bigbrother"): mustIPSet("100.64.0.1/32"),
Tag("tag:smallbrother"): mustIPSet("100.64.0.1/32"),
},
wantErr: false,
},
}
cmps := append(util.Comparers, cmp.Comparer(ipSetComparer))
@ -2627,6 +2720,127 @@ func TestNodeCanHaveTag(t *testing.T) {
tag: "tag:dev", // This tag is not defined in tagOwners
want: false,
},
// Test cases for nodes without IPs (new registration scenario)
// These test the user-based fallback in NodeCanHaveTag
{
name: "node-without-ip-user-owns-tag",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:test"): Owners{ptr.To(Username("user1@"))},
},
},
node: &types.Node{
// No IPv4 or IPv6 - simulates new node registration
User: &users[0],
UserID: ptr.To(users[0].ID),
},
tag: "tag:test",
want: true, // Should succeed via user-based fallback
},
{
name: "node-without-ip-user-does-not-own-tag",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:test"): Owners{ptr.To(Username("user2@"))},
},
},
node: &types.Node{
// No IPv4 or IPv6 - simulates new node registration
User: &users[0], // user1, but tag owned by user2
UserID: ptr.To(users[0].ID),
},
tag: "tag:test",
want: false, // user1 does not own tag:test
},
{
name: "node-without-ip-group-owns-tag",
policy: &Policy{
Groups: Groups{
"group:admins": Usernames{"user1@", "user2@"},
},
TagOwners: TagOwners{
Tag("tag:admin"): Owners{ptr.To(Group("group:admins"))},
},
},
node: &types.Node{
// No IPv4 or IPv6 - simulates new node registration
User: &users[1], // user2 is in group:admins
UserID: ptr.To(users[1].ID),
},
tag: "tag:admin",
want: true, // Should succeed via group membership
},
{
name: "node-without-ip-not-in-group",
policy: &Policy{
Groups: Groups{
"group:admins": Usernames{"user1@"},
},
TagOwners: TagOwners{
Tag("tag:admin"): Owners{ptr.To(Group("group:admins"))},
},
},
node: &types.Node{
// No IPv4 or IPv6 - simulates new node registration
User: &users[1], // user2 is NOT in group:admins
UserID: ptr.To(users[1].ID),
},
tag: "tag:admin",
want: false, // user2 is not in group:admins
},
{
name: "node-without-ip-no-user",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:test"): Owners{ptr.To(Username("user1@"))},
},
},
node: &types.Node{
// No IPv4, IPv6, or User - edge case
},
tag: "tag:test",
want: false, // No user means can't authorize via user-based fallback
},
{
name: "node-without-ip-mixed-owners-user-match",
policy: &Policy{
Groups: Groups{
"group:ops": Usernames{"user3@"},
},
TagOwners: TagOwners{
Tag("tag:server"): Owners{
ptr.To(Username("user1@")),
ptr.To(Group("group:ops")),
},
},
},
node: &types.Node{
User: &users[0], // user1 directly owns the tag
UserID: ptr.To(users[0].ID),
},
tag: "tag:server",
want: true,
},
{
name: "node-without-ip-mixed-owners-group-match",
policy: &Policy{
Groups: Groups{
"group:ops": Usernames{"user3@"},
},
TagOwners: TagOwners{
Tag("tag:server"): Owners{
ptr.To(Username("user1@")),
ptr.To(Group("group:ops")),
},
},
},
node: &types.Node{
User: &users[2], // user3 is in group:ops
UserID: ptr.To(users[2].ID),
},
tag: "tag:server",
want: true,
},
}
for _, tt := range tests {
@ -2649,6 +2863,106 @@ func TestNodeCanHaveTag(t *testing.T) {
}
}
func TestUserMatchesOwner(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
{Model: gorm.Model{ID: 2}, Name: "user2"},
{Model: gorm.Model{ID: 3}, Name: "user3"},
}
tests := []struct {
name string
policy *Policy
user types.User
owner Owner
want bool
}{
{
name: "username-match",
policy: &Policy{},
user: users[0],
owner: ptr.To(Username("user1@")),
want: true,
},
{
name: "username-no-match",
policy: &Policy{},
user: users[0],
owner: ptr.To(Username("user2@")),
want: false,
},
{
name: "group-match",
policy: &Policy{
Groups: Groups{
"group:admins": Usernames{"user1@", "user2@"},
},
},
user: users[1], // user2 is in group:admins
owner: ptr.To(Group("group:admins")),
want: true,
},
{
name: "group-no-match",
policy: &Policy{
Groups: Groups{
"group:admins": Usernames{"user1@"},
},
},
user: users[1], // user2 is NOT in group:admins
owner: ptr.To(Group("group:admins")),
want: false,
},
{
name: "group-not-defined",
policy: &Policy{
Groups: Groups{},
},
user: users[0],
owner: ptr.To(Group("group:undefined")),
want: false,
},
{
name: "nil-username-owner",
policy: &Policy{},
user: users[0],
owner: (*Username)(nil),
want: false,
},
{
name: "nil-group-owner",
policy: &Policy{},
user: users[0],
owner: (*Group)(nil),
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create a minimal PolicyManager for testing
// We need nodes with IPs to initialize the tagOwnerMap
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: &users[0],
},
}
b, err := json.Marshal(tt.policy)
require.NoError(t, err)
pm, err := NewPolicyManager(b, users, nodes.ViewSlice())
require.NoError(t, err)
got := pm.userMatchesOwner(tt.user.View(), tt.owner)
if got != tt.want {
t.Errorf("userMatchesOwner() = %v, want %v", got, tt.want)
}
})
}
}
func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) {
tests := []struct {
name string
@ -2936,3 +3250,147 @@ func mustParseAlias(s string) Alias {
}
return alias
}
func TestFlattenTagOwners(t *testing.T) {
tests := []struct {
name string
input TagOwners
want TagOwners
wantErr string
}{
{
name: "tag-owns-tag",
input: TagOwners{
Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))},
},
want: TagOwners{
Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:smallbrother"): Owners{ptr.To(Group("group:user1"))},
},
wantErr: "",
},
{
name: "circular-reference",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))},
Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))},
},
want: nil,
wantErr: "circular reference detected: tag:a -> tag:b",
},
{
name: "mixed-owners",
input: TagOwners{
Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))},
Tag("tag:y"): Owners{ptr.To(Username("user2@"))},
},
want: TagOwners{
Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))},
Tag("tag:y"): Owners{ptr.To(Username("user2@"))},
},
wantErr: "",
},
{
name: "mixed-dupe-owners",
input: TagOwners{
Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))},
Tag("tag:y"): Owners{ptr.To(Username("user1@"))},
},
want: TagOwners{
Tag("tag:x"): Owners{ptr.To(Username("user1@"))},
Tag("tag:y"): Owners{ptr.To(Username("user1@"))},
},
wantErr: "",
},
{
name: "no-tag-owners",
input: TagOwners{
Tag("tag:solo"): Owners{ptr.To(Username("user1@"))},
},
want: TagOwners{
Tag("tag:solo"): Owners{ptr.To(Username("user1@"))},
},
wantErr: "",
},
{
name: "tag-long-owner-chain",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))},
Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))},
Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))},
Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))},
Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))},
Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))},
},
want: TagOwners{
Tag("tag:a"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:b"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:c"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:d"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:e"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:f"): Owners{ptr.To(Group("group:user1"))},
Tag("tag:g"): Owners{ptr.To(Group("group:user1"))},
},
wantErr: "",
},
{
name: "tag-long-circular-chain",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Tag("tag:g"))},
Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))},
Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))},
Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))},
Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))},
Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))},
Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))},
},
wantErr: "circular reference detected: tag:a -> tag:b -> tag:c -> tag:d -> tag:e -> tag:f -> tag:g",
},
{
name: "undefined-tag-reference",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Tag("tag:nonexistent"))},
},
wantErr: `tag "tag:a" references undefined tag "tag:nonexistent"`,
},
{
name: "tag-with-empty-owners-is-valid",
input: TagOwners{
Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))},
Tag("tag:b"): Owners{}, // empty owners but exists
},
want: TagOwners{
Tag("tag:a"): nil,
Tag("tag:b"): nil,
},
wantErr: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := flattenTagOwners(tt.input)
if tt.wantErr != "" {
if err == nil {
t.Fatalf("flattenTagOwners() expected error %q, got nil", tt.wantErr)
}
if err.Error() != tt.wantErr {
t.Fatalf("flattenTagOwners() expected error %q, got %q", tt.wantErr, err.Error())
}
return
}
if err != nil {
t.Fatalf("flattenTagOwners() unexpected error: %v", err)
}
if diff := cmp.Diff(tt.want, got); diff != "" {
t.Errorf("flattenTagOwners() mismatch (-want +got):\n%s", diff)
}
})
}
}

View file

@ -319,41 +319,6 @@ var keepAlive = tailcfg.MapResponse{
KeepAlive: true,
}
func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcfg.PeerChange) {
trace := log.Trace().Caller().Uint64("node.id", uint64(peerChange.NodeID)).Str("hostname", hostname)
if peerChange.Key != nil {
trace = trace.Str("node.key", peerChange.Key.ShortString())
}
if peerChange.DiscoKey != nil {
trace = trace.Str("disco.key", peerChange.DiscoKey.ShortString())
}
if peerChange.Online != nil {
trace = trace.Bool("online", *peerChange.Online)
}
if peerChange.Endpoints != nil {
eps := make([]string, len(peerChange.Endpoints))
for idx, ep := range peerChange.Endpoints {
eps[idx] = ep.String()
}
trace = trace.Strs("endpoints", eps)
}
if hostinfoChange {
trace = trace.Bool("hostinfo_changed", hostinfoChange)
}
if peerChange.DERPRegion != 0 {
trace = trace.Int("derp_region", peerChange.DERPRegion)
}
trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received")
}
// logf adds common mapSession context to a zerolog event.
func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event {
return event.

View file

@ -991,8 +991,13 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
store.Start()
defer store.Stop()
time.Sleep(50 * time.Millisecond)
afterStartGoroutines := runtime.NumGoroutine()
// Wait for store to be ready
var afterStartGoroutines int
assert.EventuallyWithT(t, func(c *assert.CollectT) {
afterStartGoroutines = runtime.NumGoroutine()
assert.Positive(c, afterStartGoroutines) // Just ensure we have a valid count
}, time.Second, 10*time.Millisecond, "store should be running")
const ops = 100
for i := range ops {
@ -1010,11 +1015,13 @@ func TestNodeStoreResourceCleanup(t *testing.T) {
}
}
runtime.GC()
time.Sleep(100 * time.Millisecond)
finalGoroutines := runtime.NumGoroutine()
if finalGoroutines > afterStartGoroutines+2 {
t.Errorf("Potential goroutine leak: started with %d, ended with %d", afterStartGoroutines, finalGoroutines)
}
// Wait for goroutines to settle and check for leaks
assert.EventuallyWithT(t, func(c *assert.CollectT) {
finalGoroutines := runtime.NumGoroutine()
assert.LessOrEqual(c, finalGoroutines, afterStartGoroutines+2,
"Potential goroutine leak: started with %d, ended with %d", afterStartGoroutines, finalGoroutines)
}, time.Second, 10*time.Millisecond, "goroutines should not leak")
}
// --- Timeout/deadlock: operations complete within reasonable time ---
@ -1145,3 +1152,92 @@ func TestNodeStoreAllocationStats(t *testing.T) {
allocs := res.AllocsPerOp()
t.Logf("NodeStore allocations per op: %.2f", float64(allocs))
}
// TestRebuildPeerMapsWithChangedPeersFunc tests that RebuildPeerMaps correctly
// rebuilds the peer map when the peersFunc behavior changes.
// This simulates what happens when SetNodeTags changes node tags and the
// PolicyManager's matchers are updated, requiring the peer map to be rebuilt.
func TestRebuildPeerMapsWithChangedPeersFunc(t *testing.T) {
// Create a peersFunc whose behavior is controlled via a boolean flag.
// Initially it returns all other nodes as peers; we then flip it to return no peers.
allowPeers := true
// This simulates how PolicyManager.BuildPeerMap works - it reads state
// that can change between calls
dynamicPeersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
ret := make(map[types.NodeID][]types.NodeView, len(nodes))
if allowPeers {
// Allow all peers
for _, node := range nodes {
var peers []types.NodeView
for _, n := range nodes {
if n.ID() != node.ID() {
peers = append(peers, n)
}
}
ret[node.ID()] = peers
}
} else {
// Allow no peers
for _, node := range nodes {
ret[node.ID()] = []types.NodeView{}
}
}
return ret
}
// Create nodes
node1 := createTestNode(1, 1, "user1", "node1")
node2 := createTestNode(2, 2, "user2", "node2")
initialNodes := types.Nodes{&node1, &node2}
// Create store with dynamic peersFunc
store := NewNodeStore(initialNodes, dynamicPeersFunc, TestBatchSize, TestBatchTimeout)
store.Start()
defer store.Stop()
// Initially, nodes should see each other as peers
snapshot := store.data.Load()
require.Len(t, snapshot.peersByNode[1], 1, "node1 should have 1 peer initially")
require.Len(t, snapshot.peersByNode[2], 1, "node2 should have 1 peer initially")
require.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID())
require.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID())
// Now "change the policy" by disabling peers
allowPeers = false
// Call RebuildPeerMaps to rebuild with the new behavior
store.RebuildPeerMaps()
// After rebuild, nodes should have no peers
snapshot = store.data.Load()
assert.Empty(t, snapshot.peersByNode[1], "node1 should have no peers after rebuild")
assert.Empty(t, snapshot.peersByNode[2], "node2 should have no peers after rebuild")
// Verify that ListPeers returns the correct result
peers1 := store.ListPeers(1)
peers2 := store.ListPeers(2)
assert.Equal(t, 0, peers1.Len(), "ListPeers for node1 should return empty")
assert.Equal(t, 0, peers2.Len(), "ListPeers for node2 should return empty")
// Now re-enable peers and rebuild again
allowPeers = true
store.RebuildPeerMaps()
// Nodes should see each other again
snapshot = store.data.Load()
require.Len(t, snapshot.peersByNode[1], 1, "node1 should have 1 peer after re-enabling")
require.Len(t, snapshot.peersByNode[2], 1, "node2 should have 1 peer after re-enabling")
peers1 = store.ListPeers(1)
peers2 = store.ListPeers(2)
assert.Equal(t, 1, peers1.Len(), "ListPeers for node1 should return 1")
assert.Equal(t, 1, peers2.Len(), "ListPeers for node2 should return 1")
}

View file

@ -12,6 +12,7 @@ import (
"net/netip"
"os"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
@ -56,6 +57,15 @@ var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode")
// ErrNodeNotFound is returned when a node cannot be found by its ID.
var ErrNodeNotFound = errors.New("node not found")
// ErrInvalidNodeView is returned when an invalid node view is provided.
var ErrInvalidNodeView = errors.New("invalid node view provided")
// ErrNodeNotInNodeStore is returned when a node no longer exists in the NodeStore.
var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore")
// ErrNodeNameNotUnique is returned when a node name is not unique.
var ErrNodeNameNotUnique = errors.New("node name is not unique")
// State manages Headscale's core state, coordinating between database, policy management,
// IP allocation, and DERP routing. All methods are thread-safe.
type State struct {
@ -242,7 +252,7 @@ func (s *State) DERPMap() tailcfg.DERPMapView {
// ReloadPolicy reloads the access control policy and triggers auto-approval if changed.
// Returns true if the policy changed.
func (s *State) ReloadPolicy() ([]change.ChangeSet, error) {
func (s *State) ReloadPolicy() ([]change.Change, error) {
pol, err := policyBytes(s.db, s.cfg)
if err != nil {
return nil, fmt.Errorf("loading policy: %w", err)
@ -259,7 +269,7 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) {
// propagate correctly when switching between policy types.
s.nodeStore.RebuildPeerMaps()
cs := []change.ChangeSet{change.PolicyChange()}
cs := []change.Change{change.PolicyChange()}
// Always call autoApproveNodes during policy reload, regardless of whether
// the policy content has changed. This ensures that routes are re-evaluated
@ -288,16 +298,16 @@ func (s *State) ReloadPolicy() ([]change.ChangeSet, error) {
// CreateUser creates a new user and updates the policy manager.
// Returns the created user, change set, and any error.
func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, error) {
func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) {
if err := s.db.DB.Save(&user).Error; err != nil {
return nil, change.EmptySet, fmt.Errorf("creating user: %w", err)
return nil, change.Change{}, fmt.Errorf("creating user: %w", err)
}
// Check if policy manager needs updating
c, err := s.updatePolicyManagerUsers()
if err != nil {
// Log the error but don't fail the user creation
return &user, change.EmptySet, fmt.Errorf("failed to update policy manager after user creation: %w", err)
return &user, change.Change{}, fmt.Errorf("failed to update policy manager after user creation: %w", err)
}
// Even if the policy manager doesn't detect a filter change, SSH policies
@ -305,7 +315,7 @@ func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, erro
// nodes, we should send a policy change to ensure they get updated SSH policies.
// TODO(kradalby): detect this, or rebuild all SSH policies so we can determine
// this upstream.
if c.Empty() {
if c.IsEmpty() {
c = change.PolicyChange()
}
@ -316,7 +326,7 @@ func (s *State) CreateUser(user types.User) (*types.User, change.ChangeSet, erro
// UpdateUser modifies an existing user using the provided update function within a transaction.
// Returns the updated user, change set, and any error.
func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.ChangeSet, error) {
func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.Change, error) {
user, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.User, error) {
user, err := hsdb.GetUserByID(tx, userID)
if err != nil {
@ -336,13 +346,13 @@ func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error
return user, nil
})
if err != nil {
return nil, change.EmptySet, err
return nil, change.Change{}, err
}
// Check if policy manager needs updating
c, err := s.updatePolicyManagerUsers()
if err != nil {
return user, change.EmptySet, fmt.Errorf("failed to update policy manager after user update: %w", err)
return user, change.Change{}, fmt.Errorf("failed to update policy manager after user update: %w", err)
}
// TODO(kradalby): We might want to update nodestore with the user data
@ -357,7 +367,7 @@ func (s *State) DeleteUser(userID types.UserID) error {
}
// RenameUser changes a user's name. The new name must be unique.
func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.ChangeSet, error) {
func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.Change, error) {
return s.UpdateUser(userID, func(user *types.User) error {
user.Name = newName
return nil
@ -394,9 +404,9 @@ func (s *State) ListAllUsers() ([]types.User, error) {
// NodeStore and the database. It verifies the node still exists in NodeStore to prevent
// race conditions where a node might be deleted between UpdateNode returning and
// persistNodeToDB being called.
func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.ChangeSet, error) {
func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Change, error) {
if !node.Valid() {
return types.NodeView{}, change.EmptySet, fmt.Errorf("invalid node view provided")
return types.NodeView{}, change.Change{}, ErrInvalidNodeView
}
// Verify the node still exists in NodeStore before persisting to database.
@ -410,7 +420,8 @@ func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Cha
Str("node.name", node.Hostname()).
Bool("is_ephemeral", node.IsEphemeral()).
Msg("Node no longer exists in NodeStore, skipping database persist to prevent race condition")
return types.NodeView{}, change.EmptySet, fmt.Errorf("node %d no longer exists in NodeStore, skipping database persist", node.ID())
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, node.ID())
}
nodePtr := node.AsStruct()
@ -420,23 +431,23 @@ func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Cha
// See: https://github.com/juanfont/headscale/issues/2862
err := s.db.DB.Omit("expiry").Updates(nodePtr).Error
if err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("saving node: %w", err)
return types.NodeView{}, change.Change{}, fmt.Errorf("saving node: %w", err)
}
// Check if policy manager needs updating
c, err := s.updatePolicyManagerNodes()
if err != nil {
return nodePtr.View(), change.EmptySet, fmt.Errorf("failed to update policy manager after node save: %w", err)
return nodePtr.View(), change.Change{}, fmt.Errorf("failed to update policy manager after node save: %w", err)
}
if c.Empty() {
if c.IsEmpty() {
c = change.NodeAdded(node.ID())
}
return node, c, nil
}
func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet, error) {
func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.Change, error) {
// Update NodeStore first
nodePtr := node.AsStruct()
@ -448,12 +459,12 @@ func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.ChangeSet,
// DeleteNode permanently removes a node and cleans up associated resources.
// Returns whether policies changed and any error. This operation is irreversible.
func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) {
func (s *State) DeleteNode(node types.NodeView) (change.Change, error) {
s.nodeStore.DeleteNode(node.ID())
err := s.db.DeleteNode(node.AsStruct())
if err != nil {
return change.EmptySet, err
return change.Change{}, err
}
s.ipAlloc.FreeIPs(node.IPs())
@ -463,18 +474,20 @@ func (s *State) DeleteNode(node types.NodeView) (change.ChangeSet, error) {
// Check if policy manager needs updating after node deletion
policyChange, err := s.updatePolicyManagerNodes()
if err != nil {
return change.EmptySet, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
return change.Change{}, fmt.Errorf("failed to update policy manager after node deletion: %w", err)
}
if !policyChange.Empty() {
c = policyChange
if !policyChange.IsEmpty() {
// Merge policy change with NodeRemoved to preserve PeersRemoved info
// This ensures the batcher cleans up the deleted node from its state
c = c.Merge(policyChange)
}
return c, nil
}
// Connect marks a node as connected and updates its primary routes in the state.
func (s *State) Connect(id types.NodeID) []change.ChangeSet {
func (s *State) Connect(id types.NodeID) []change.Change {
// CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification
// This ensures that when the NodeCameOnline change is distributed and processed by other nodes,
// the NodeStore already reflects the correct online status for full map generation.
@ -487,7 +500,7 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet {
return nil
}
c := []change.ChangeSet{change.NodeOnline(node)}
c := []change.Change{change.NodeOnlineFor(node)}
log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected")
@ -504,7 +517,7 @@ func (s *State) Connect(id types.NodeID) []change.ChangeSet {
}
// Disconnect marks a node as disconnected and updates its primary routes in the state.
func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) {
func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
now := time.Now()
node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
@ -526,14 +539,15 @@ func (s *State) Disconnect(id types.NodeID) ([]change.ChangeSet, error) {
// Log error but don't fail the disconnection - NodeStore is already updated
// and we need to send change notifications to peers
log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Failed to update last seen in database")
c = change.EmptySet
c = change.Change{}
}
// The node is disconnecting so make sure that none of the routes it
// announced are served to any nodes.
routeChange := s.primaryRoutes.SetRoutes(id)
cs := []change.ChangeSet{change.NodeOffline(node), c}
cs := []change.Change{change.NodeOfflineFor(node), c}
// If we have a policy change or route change, return that as it's more comprehensive
// Otherwise, return the NodeOffline change to ensure nodes are notified
@ -636,7 +650,7 @@ func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] {
}
// SetNodeExpiry updates the expiration time for a node.
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.ChangeSet, error) {
func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.Change, error) {
// Update NodeStore before database to ensure consistency. The NodeStore update is
// blocking and will be the source of truth for the batcher. The database update must
// make the exact same change. If the database update fails, the NodeStore change will
@ -648,7 +662,7 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node
})
if !ok {
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
}
return s.persistNodeToDB(n)
@ -657,24 +671,39 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node
// SetNodeTags assigns tags to a node, making it a "tagged node".
// Once a node is tagged, it cannot be un-tagged (only tags can be changed).
// The UserID is preserved as "created by" information.
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) {
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.Change, error) {
// CANNOT REMOVE ALL TAGS
if len(tags) == 0 {
return types.NodeView{}, change.EmptySet, types.ErrCannotRemoveAllTags
return types.NodeView{}, change.Change{}, types.ErrCannotRemoveAllTags
}
// Get node for validation
existingNode, exists := s.nodeStore.GetNode(nodeID)
if !exists {
return types.NodeView{}, change.EmptySet, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID)
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID)
}
// Validate tags against policy
validatedTags, err := s.validateAndNormalizeTags(existingNode.AsStruct(), tags)
if err != nil {
return types.NodeView{}, change.EmptySet, err
// Validate tags: must have correct format and exist in policy
validatedTags := make([]string, 0, len(tags))
invalidTags := make([]string, 0)
for _, tag := range tags {
if !strings.HasPrefix(tag, "tag:") || !s.polMan.TagExists(tag) {
invalidTags = append(invalidTags, tag)
continue
}
validatedTags = append(validatedTags, tag)
}
if len(invalidTags) > 0 {
return types.NodeView{}, change.Change{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, invalidTags)
}
slices.Sort(validatedTags)
validatedTags = slices.Compact(validatedTags)
// Log the operation
logTagOperation(existingNode, validatedTags)
@ -687,14 +716,14 @@ func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView,
})
if !ok {
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
}
return s.persistNodeToDB(n)
}
// SetApprovedRoutes sets the network routes that a node is approved to advertise.
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.ChangeSet, error) {
func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.Change, error) {
// TODO(kradalby): In principle we should call the AutoApprove logic here
// because even if the CLI removes an auto-approved route, it will be added
// back automatically.
@ -703,13 +732,13 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t
})
if !ok {
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
}
// Persist the node changes to the database
nodeView, c, err := s.persistNodeToDB(n)
if err != nil {
return types.NodeView{}, change.EmptySet, err
return types.NodeView{}, change.Change{}, err
}
// Update primary routes table based on SubnetRoutes (intersection of announced and approved).
@ -727,9 +756,9 @@ func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (t
}
// RenameNode changes the display name of a node.
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.ChangeSet, error) {
func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.Change, error) {
if err := util.ValidateHostname(newName); err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("renaming node: %w", err)
return types.NodeView{}, change.Change{}, fmt.Errorf("renaming node: %w", err)
}
// Check name uniqueness against NodeStore
@ -737,7 +766,7 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView,
for i := 0; i < allNodes.Len(); i++ {
node := allNodes.At(i)
if node.ID() != nodeID && node.AsStruct().GivenName == newName {
return types.NodeView{}, change.EmptySet, fmt.Errorf("name is not unique: %s", newName)
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName)
}
}
@ -749,7 +778,7 @@ func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView,
})
if !ok {
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", nodeID)
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID)
}
return s.persistNodeToDB(n)
@ -794,12 +823,12 @@ func (s *State) BackfillNodeIPs() ([]string, error) {
// ExpireExpiredNodes finds and processes expired nodes since the last check.
// Returns next check time, state update with expired nodes, and whether any were found.
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.ChangeSet, bool) {
func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Change, bool) {
// Why capture start time: using a consistent timestamp for the next check ensures
// we don't miss nodes that expire while this function is running.
started := time.Now()
var updates []change.ChangeSet
var updates []change.Change
for _, node := range s.nodeStore.ListNodes().All() {
if !node.Valid() {
@ -809,7 +838,7 @@ func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Cha
// Why check After(lastCheck): We only want to notify about nodes that
// expired since the last check to avoid duplicate notifications
if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) {
updates = append(updates, change.KeyExpiry(node.ID(), node.Expiry().Get()))
updates = append(updates, change.KeyExpiryFor(node.ID(), node.Expiry().Get()))
}
}
@ -852,7 +881,7 @@ func (s *State) SetPolicy(pol []byte) (bool, error) {
// AutoApproveRoutes checks if a node's routes should be auto-approved.
// AutoApproveRoutes checks if any routes should be auto-approved for a node and updates them.
func (s *State) AutoApproveRoutes(nv types.NodeView) (change.ChangeSet, error) {
func (s *State) AutoApproveRoutes(nv types.NodeView) (change.Change, error) {
approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes())
if changed {
log.Debug().
@ -873,7 +902,7 @@ func (s *State) AutoApproveRoutes(nv types.NodeView) (change.ChangeSet, error) {
Err(err).
Msg("Failed to persist auto-approved routes")
return change.EmptySet, err
return change.Change{}, err
}
log.Info().Uint64("node.id", nv.ID().Uint64()).Str("node.name", nv.Hostname()).Strs("routes.approved", util.PrefixesToString(approved)).Msg("Routes approved")
@ -881,7 +910,7 @@ func (s *State) AutoApproveRoutes(nv types.NodeView) (change.ChangeSet, error) {
return c, nil
}
return change.EmptySet, nil
return change.Change{}, nil
}
// GetPolicy retrieves the current policy from the database.
@ -895,14 +924,14 @@ func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
}
// SetNodeRoutes sets the primary routes for a node.
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.ChangeSet {
func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.Change {
if s.primaryRoutes.SetRoutes(nodeID, routes...) {
// Route changes affect packet filters for all nodes, so trigger a policy change
// to ensure filters are regenerated across the entire network
return change.PolicyChange()
}
return change.EmptySet
return change.Change{}
}
// GetNodePrimaryRoutes returns the primary routes for a node.
@ -1128,6 +1157,41 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
nodeToRegister.Tags = nil
}
// Reject advertise-tags for PreAuthKey registrations early, before any resource allocation.
// PreAuthKey nodes get their tags from the key itself, not from client requests.
if params.PreAuthKey != nil && params.Hostinfo != nil && len(params.Hostinfo.RequestTags) > 0 {
return types.NodeView{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, params.Hostinfo.RequestTags)
}
// Process RequestTags (from tailscale up --advertise-tags) ONLY for non-PreAuthKey registrations.
// Validate early before IP allocation to avoid resource leaks on failure.
if params.PreAuthKey == nil && params.Hostinfo != nil && len(params.Hostinfo.RequestTags) > 0 {
var approvedTags, rejectedTags []string
for _, tag := range params.Hostinfo.RequestTags {
if s.polMan.NodeCanHaveTag(nodeToRegister.View(), tag) {
approvedTags = append(approvedTags, tag)
} else {
rejectedTags = append(rejectedTags, tag)
}
}
// Reject registration if any requested tags are unauthorized
if len(rejectedTags) > 0 {
return types.NodeView{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, rejectedTags)
}
if len(approvedTags) > 0 {
nodeToRegister.Tags = approvedTags
slices.Sort(nodeToRegister.Tags)
nodeToRegister.Tags = slices.Compact(nodeToRegister.Tags)
log.Info().
Str("node.name", nodeToRegister.Hostname).
Strs("tags", nodeToRegister.Tags).
Msg("approved advertise-tags during registration")
}
}
// Validate before saving
err := validateNodeOwnership(&nodeToRegister)
if err != nil {
@ -1181,17 +1245,17 @@ func (s *State) HandleNodeFromAuthPath(
userID types.UserID,
expiry *time.Time,
registrationMethod string,
) (types.NodeView, change.ChangeSet, error) {
) (types.NodeView, change.Change, error) {
// Get the registration entry from cache
regEntry, ok := s.GetRegistrationCacheEntry(registrationID)
if !ok {
return types.NodeView{}, change.EmptySet, hsdb.ErrNodeNotFoundRegistrationCache
return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache
}
// Get the user
user, err := s.db.GetUserByID(userID)
if err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("failed to find user: %w", err)
return types.NodeView{}, change.Change{}, fmt.Errorf("failed to find user: %w", err)
}
// Ensure we have a valid hostname from the registration cache entry
@ -1255,7 +1319,7 @@ func (s *State) HandleNodeFromAuthPath(
})
if !ok {
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID())
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeSameUser.ID())
}
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
@ -1267,7 +1331,7 @@ func (s *State) HandleNodeFromAuthPath(
return nil, nil
})
if err != nil {
return types.NodeView{}, change.EmptySet, err
return types.NodeView{}, change.Change{}, err
}
log.Trace().
@ -1325,7 +1389,7 @@ func (s *State) HandleNodeFromAuthPath(
ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}),
})
if err != nil {
return types.NodeView{}, change.EmptySet, err
return types.NodeView{}, change.Change{}, err
}
}
@ -1346,8 +1410,8 @@ func (s *State) HandleNodeFromAuthPath(
return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err)
}
var c change.ChangeSet
if !usersChange.Empty() || !nodesChange.Empty() {
var c change.Change
if !usersChange.IsEmpty() || !nodesChange.IsEmpty() {
c = change.PolicyChange()
} else {
c = change.NodeAdded(finalNode.ID())
@ -1360,10 +1424,10 @@ func (s *State) HandleNodeFromAuthPath(
func (s *State) HandleNodeFromPreAuthKey(
regReq tailcfg.RegisterRequest,
machineKey key.MachinePublic,
) (types.NodeView, change.ChangeSet, error) {
) (types.NodeView, change.Change, error) {
pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey)
if err != nil {
return types.NodeView{}, change.EmptySet, err
return types.NodeView{}, change.Change{}, err
}
// Check if node exists with same machine key before validating the key.
@ -1410,7 +1474,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// New node or NodeKey rotation: require valid auth key.
err = pak.Validate()
if err != nil {
return types.NodeView{}, change.EmptySet, err
return types.NodeView{}, change.Change{}, err
}
}
@ -1484,7 +1548,7 @@ func (s *State) HandleNodeFromPreAuthKey(
})
if !ok {
return types.NodeView{}, change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", existingNodeSameUser.ID())
return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeSameUser.ID())
}
_, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) {
@ -1504,7 +1568,7 @@ func (s *State) HandleNodeFromPreAuthKey(
return nil, nil
})
if err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("writing node to database: %w", err)
return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err)
}
log.Trace().
@ -1556,7 +1620,7 @@ func (s *State) HandleNodeFromPreAuthKey(
ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}),
})
if err != nil {
return types.NodeView{}, change.EmptySet, fmt.Errorf("creating new node: %w", err)
return types.NodeView{}, change.Change{}, fmt.Errorf("creating new node: %w", err)
}
}
@ -1571,8 +1635,8 @@ func (s *State) HandleNodeFromPreAuthKey(
return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err)
}
var c change.ChangeSet
if !usersChange.Empty() || !nodesChange.Empty() {
var c change.Change
if !usersChange.IsEmpty() || !nodesChange.IsEmpty() {
c = change.PolicyChange()
} else {
c = change.NodeAdded(finalNode.ID())
@ -1587,17 +1651,17 @@ func (s *State) HandleNodeFromPreAuthKey(
// have the list already available so it could go much quicker. Alternatively
// the policy manager could have a remove or add list for users.
// updatePolicyManagerUsers refreshes the policy manager with current user data.
func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
func (s *State) updatePolicyManagerUsers() (change.Change, error) {
users, err := s.ListAllUsers()
if err != nil {
return change.EmptySet, fmt.Errorf("listing users for policy update: %w", err)
return change.Change{}, fmt.Errorf("listing users for policy update: %w", err)
}
log.Debug().Caller().Int("user.count", len(users)).Msg("Policy manager user update initiated because user list modification detected")
changed, err := s.polMan.SetUsers(users)
if err != nil {
return change.EmptySet, fmt.Errorf("updating policy manager users: %w", err)
return change.Change{}, fmt.Errorf("updating policy manager users: %w", err)
}
log.Debug().Caller().Bool("policy.changed", changed).Msg("Policy manager user update completed because SetUsers operation finished")
@ -1606,7 +1670,7 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
return change.PolicyChange(), nil
}
return change.EmptySet, nil
return change.Change{}, nil
}
// updatePolicyManagerNodes updates the policy manager with current nodes.
@ -1615,19 +1679,22 @@ func (s *State) updatePolicyManagerUsers() (change.ChangeSet, error) {
// have the list already available so it could go much quicker. Alternatively
// the policy manager could have a remove or add list for nodes.
// updatePolicyManagerNodes refreshes the policy manager with current node data.
func (s *State) updatePolicyManagerNodes() (change.ChangeSet, error) {
func (s *State) updatePolicyManagerNodes() (change.Change, error) {
nodes := s.ListNodes()
changed, err := s.polMan.SetNodes(nodes)
if err != nil {
return change.EmptySet, fmt.Errorf("updating policy manager nodes: %w", err)
return change.Change{}, fmt.Errorf("updating policy manager nodes: %w", err)
}
if changed {
// Rebuild peer maps because policy-affecting node changes (tags, user, IPs)
// affect ACL visibility. Without this, cached peer relationships use stale data.
s.nodeStore.RebuildPeerMaps()
return change.PolicyChange(), nil
}
return change.EmptySet, nil
return change.Change{}, nil
}
// PingDB checks if the database connection is healthy.
@ -1641,14 +1708,16 @@ func (s *State) PingDB(ctx context.Context) error {
// TODO(kradalby): This is kind of messy, maybe this is another +1
// for an event bus. See example comments here.
// autoApproveNodes automatically approves nodes based on policy rules.
func (s *State) autoApproveNodes() ([]change.ChangeSet, error) {
func (s *State) autoApproveNodes() ([]change.Change, error) {
nodes := s.ListNodes()
// Approve routes concurrently; this makes it likely that the writes
// end up in the same batch of the NodeStore write.
var errg errgroup.Group
var cs []change.ChangeSet
var mu sync.Mutex
var (
errg errgroup.Group
cs []change.Change
mu sync.Mutex
)
for _, nv := range nodes.All() {
errg.Go(func() error {
approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes())
@ -1689,7 +1758,7 @@ func (s *State) autoApproveNodes() ([]change.ChangeSet, error) {
// - node.PeerChangeFromMapRequest
// - node.ApplyPeerChange
// - logTracePeerChange in poll.go.
func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.ChangeSet, error) {
func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.Change, error) {
log.Trace().
Caller().
Uint64("node.id", id.Uint64()).
@ -1802,7 +1871,7 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
})
if !ok {
return change.EmptySet, fmt.Errorf("node not found in NodeStore: %d", id)
return change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, id)
}
if routeChange {
@ -1814,80 +1883,67 @@ func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest
// SetApprovedRoutes will update both database and PrimaryRoutes table
_, c, err := s.SetApprovedRoutes(id, autoApprovedRoutes)
if err != nil {
return change.EmptySet, fmt.Errorf("persisting auto-approved routes: %w", err)
return change.Change{}, fmt.Errorf("persisting auto-approved routes: %w", err)
}
// If SetApprovedRoutes resulted in a policy change, return it
if !c.Empty() {
if !c.IsEmpty() {
return c, nil
}
} // Continue with the rest of the processing using the updated node
nodeRouteChange := change.EmptySet
// Handle route changes after NodeStore update
// We need to update node routes if either:
// 1. The approved routes changed (routeChange is true), OR
// 2. The announced routes changed (even if approved routes stayed the same)
// This is because SubnetRoutes is the intersection of announced AND approved routes.
needsRouteUpdate := false
var routesChangedButNotApproved bool
if hostinfoChanged && needsRouteApproval && !routeChange {
if hi := req.Hostinfo; hi != nil {
routesChangedButNotApproved = true
}
}
if routesChangedButNotApproved {
needsRouteUpdate = true
log.Debug().
Caller().
Uint64("node.id", id.Uint64()).
Msg("updating routes because announced routes changed but approved routes did not")
}
if needsRouteUpdate {
// SetNodeRoutes sets the active/distributed routes, so we must use AllApprovedRoutes()
// which returns only the intersection of announced AND approved routes.
// Using AnnouncedRoutes() would bypass the security model and auto-approve everything.
log.Debug().
Caller().
Uint64("node.id", id.Uint64()).
Strs("announcedRoutes", util.PrefixesToString(updatedNode.AnnouncedRoutes())).
Strs("approvedRoutes", util.PrefixesToString(updatedNode.ApprovedRoutes().AsSlice())).
Strs("allApprovedRoutes", util.PrefixesToString(updatedNode.AllApprovedRoutes())).
Msg("updating node routes for distribution")
nodeRouteChange = s.SetNodeRoutes(id, updatedNode.AllApprovedRoutes()...)
}
// Handle route changes after NodeStore update.
// Update routes if announced routes changed (even if approved routes stayed the same)
// because SubnetRoutes is the intersection of announced AND approved routes.
nodeRouteChange := s.maybeUpdateNodeRoutes(id, updatedNode, hostinfoChanged, needsRouteApproval, routeChange, req.Hostinfo)
_, policyChange, err := s.persistNodeToDB(updatedNode)
if err != nil {
return change.EmptySet, fmt.Errorf("saving to database: %w", err)
return change.Change{}, fmt.Errorf("saving to database: %w", err)
}
if policyChange.IsFull() {
return policyChange, nil
}
if !nodeRouteChange.Empty() {
if !nodeRouteChange.IsEmpty() {
return nodeRouteChange, nil
}
// Determine the most specific change type based on what actually changed.
// This allows us to send lightweight patch updates instead of full map responses.
return buildMapRequestChangeResponse(id, updatedNode, hostinfoChanged, endpointChanged, derpChanged)
}
// buildMapRequestChangeResponse determines the appropriate response type for a MapRequest update.
// Hostinfo changes require a full update, while endpoint/DERP changes can use lightweight patches.
func buildMapRequestChangeResponse(
id types.NodeID,
node types.NodeView,
hostinfoChanged, endpointChanged, derpChanged bool,
) (change.Change, error) {
// Hostinfo changes require NodeAdded (full update) as they may affect many fields.
if hostinfoChanged {
return change.NodeAdded(id), nil
}
// Return specific change types for endpoint and/or DERP updates.
// The batcher will query NodeStore for current state and include both in PeerChange if both changed.
// Prioritize endpoint changes as they're more common and important for connectivity.
if endpointChanged {
return change.EndpointUpdate(id), nil
}
if endpointChanged || derpChanged {
patch := &tailcfg.PeerChange{NodeID: id.NodeID()}
if derpChanged {
return change.DERPUpdate(id), nil
if endpointChanged {
patch.Endpoints = node.Endpoints().AsSlice()
}
if derpChanged {
if hi := node.Hostinfo(); hi.Valid() {
if ni := hi.NetInfo(); ni.Valid() {
patch.DERPRegion = ni.PreferredDERP()
}
}
}
return change.EndpointOrDERPUpdate(id, patch), nil
}
return change.NodeAdded(id), nil
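To see what the endpoint/DERP branch hands back to callers, a hedged in-package snippet (node ID and DERP region are invented; imports are the ones this file already uses):

```go
patch := &tailcfg.PeerChange{
	NodeID:     types.NodeID(7).NodeID(), // invented node ID
	DERPRegion: 2,                        // invented region
}
c := change.EndpointOrDERPUpdate(7, patch)
// c carries exactly one peer patch, records OriginNode == 7 so the mapper can
// detect self-updates, and c.IsEmpty() reports false.
```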
@ -1932,3 +1988,34 @@ func peerChangeEmpty(peerChange tailcfg.PeerChange) bool {
peerChange.LastSeen == nil &&
peerChange.KeyExpiry == nil
}
// maybeUpdateNodeRoutes updates node routes if announced routes changed but approved routes didn't.
// This is needed because SubnetRoutes is the intersection of announced AND approved routes.
func (s *State) maybeUpdateNodeRoutes(
id types.NodeID,
node types.NodeView,
hostinfoChanged, needsRouteApproval, routeChange bool,
hostinfo *tailcfg.Hostinfo,
) change.Change {
// Only update if announced routes changed without approval change
if !hostinfoChanged || !needsRouteApproval || routeChange || hostinfo == nil {
return change.Change{}
}
log.Debug().
Caller().
Uint64("node.id", id.Uint64()).
Msg("updating routes because announced routes changed but approved routes did not")
// SetNodeRoutes sets the active/distributed routes using AllApprovedRoutes()
// which returns only the intersection of announced AND approved routes.
log.Debug().
Caller().
Uint64("node.id", id.Uint64()).
Strs("announcedRoutes", util.PrefixesToString(node.AnnouncedRoutes())).
Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())).
Strs("allApprovedRoutes", util.PrefixesToString(node.AllApprovedRoutes())).
Msg("updating node routes for distribution")
return s.SetNodeRoutes(id, node.AllApprovedRoutes()...)
}
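The "intersection of announced AND approved" rule that this helper and SubnetRoutes rely on can be shown with a small self-contained sketch; the intersect helper below is illustrative, not the actual AllApprovedRoutes implementation.

```go
package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// intersect keeps only the announced prefixes that are also approved.
func intersect(announced, approved []netip.Prefix) []netip.Prefix {
	var out []netip.Prefix
	for _, p := range announced {
		if slices.Contains(approved, p) {
			out = append(out, p)
		}
	}
	return out
}

func main() {
	announced := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("192.168.1.0/24"),
	}
	approved := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}

	// Only 10.0.0.0/24 is both announced and approved, so only it is distributed.
	fmt.Println(intersect(announced, approved))
}
```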

View file

@ -3,8 +3,6 @@ package state
import (
"errors"
"fmt"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
@ -17,8 +15,9 @@ var (
// ErrNodeHasNeitherUserNorTags is returned when a node has neither a user nor tags.
ErrNodeHasNeitherUserNorTags = errors.New("node has neither user nor tags - must be owned by user or tagged")
// ErrInvalidOrUnauthorizedTags is returned when tags are invalid or unauthorized.
ErrInvalidOrUnauthorizedTags = errors.New("invalid or unauthorized tags")
// ErrRequestedTagsInvalidOrNotPermitted is returned when requested tags are invalid or not permitted.
// This message format matches Tailscale SaaS: "requested tags [tag:xxx] are invalid or not permitted".
ErrRequestedTagsInvalidOrNotPermitted = errors.New("requested tags")
)
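For reference, how this sentinel composes into the Tailscale-SaaS-style message at the call sites added in this diff (in-package sketch; fmt and errors are already imported here):

```go
err := fmt.Errorf("%w %v are invalid or not permitted",
	ErrRequestedTagsInvalidOrNotPermitted, []string{"tag:ci"})

fmt.Println(err)
// requested tags [tag:ci] are invalid or not permitted

fmt.Println(errors.Is(err, ErrRequestedTagsInvalidOrNotPermitted))
// true
```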
// validateNodeOwnership ensures proper node ownership model.
@ -44,44 +43,6 @@ func validateNodeOwnership(node *types.Node) error {
return nil
}
// validateAndNormalizeTags validates tags against policy and normalizes them.
// Returns validated and normalized tags, or an error if validation fails.
func (s *State) validateAndNormalizeTags(node *types.Node, requestedTags []string) ([]string, error) {
if len(requestedTags) == 0 {
return nil, nil
}
var (
validTags []string
invalidTags []string
)
for _, tag := range requestedTags {
// Validate format
if !strings.HasPrefix(tag, "tag:") {
invalidTags = append(invalidTags, tag)
continue
}
// Validate against policy
nodeView := node.View()
if s.polMan.NodeCanHaveTag(nodeView, tag) {
validTags = append(validTags, tag)
} else {
invalidTags = append(invalidTags, tag)
}
}
if len(invalidTags) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidOrUnauthorizedTags, invalidTags)
}
// Normalize: sort and deduplicate
slices.Sort(validTags)
return slices.Compact(validTags), nil
}
// logTagOperation logs tag assignment operations for audit purposes.
func logTagOperation(existingNode types.NodeView, newTags []string) {
if existingNode.IsTagged() {

View file

@ -1,241 +1,445 @@
//go:generate go tool stringer -type=Change
package change
import (
"errors"
"slices"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
)
type (
NodeID = types.NodeID
UserID = types.UserID
)
// Change declares what should be included in a MapResponse.
// The mapper uses this to build the response without guessing.
type Change struct {
// Reason is a human-readable description for logging/debugging.
Reason string
type Change int
// TargetNode, if set, means this response should only be sent to this node.
TargetNode types.NodeID
const (
ChangeUnknown Change = 0
// OriginNode is the node that triggered this change.
// Used for self-update detection and filtering.
OriginNode types.NodeID
// Deprecated: Use specific change instead
// Full is a legacy change to ensure places where we
// have not yet determined the specific update, can send.
Full Change = 9
// Content flags - what to include in the MapResponse.
IncludeSelf bool
IncludeDERPMap bool
IncludeDNS bool
IncludeDomain bool
IncludePolicy bool // PacketFilters and SSHPolicy - always sent together
// Server changes.
Policy Change = 11
DERP Change = 12
ExtraRecords Change = 13
// Peer changes.
PeersChanged []types.NodeID
PeersRemoved []types.NodeID
PeerPatches []*tailcfg.PeerChange
SendAllPeers bool
// Node changes.
NodeCameOnline Change = 21
NodeWentOffline Change = 22
NodeRemove Change = 23
NodeKeyExpiry Change = 24
NodeNewOrUpdate Change = 25
NodeEndpoint Change = 26
NodeDERP Change = 27
// RequiresRuntimePeerComputation indicates that peer visibility
// must be computed at runtime per-node. Used for policy changes
// where each node may have different peer visibility.
RequiresRuntimePeerComputation bool
}
// User changes.
UserNewOrUpdate Change = 51
UserRemove Change = 52
)
// boolFieldNames returns all boolean field names for exhaustive testing.
// When adding a new boolean field to Change, add it here.
// Tests use reflection to verify this matches the struct.
func (r Change) boolFieldNames() []string {
return []string{
"IncludeSelf",
"IncludeDERPMap",
"IncludeDNS",
"IncludeDomain",
"IncludePolicy",
"SendAllPeers",
"RequiresRuntimePeerComputation",
}
}
// AlsoSelf reports whether this change should also be sent to the node itself.
func (c Change) AlsoSelf() bool {
switch c {
case NodeRemove, NodeKeyExpiry, NodeNewOrUpdate:
return true
func (r Change) Merge(other Change) Change {
merged := r
merged.IncludeSelf = r.IncludeSelf || other.IncludeSelf
merged.IncludeDERPMap = r.IncludeDERPMap || other.IncludeDERPMap
merged.IncludeDNS = r.IncludeDNS || other.IncludeDNS
merged.IncludeDomain = r.IncludeDomain || other.IncludeDomain
merged.IncludePolicy = r.IncludePolicy || other.IncludePolicy
merged.SendAllPeers = r.SendAllPeers || other.SendAllPeers
merged.RequiresRuntimePeerComputation = r.RequiresRuntimePeerComputation || other.RequiresRuntimePeerComputation
merged.PeersChanged = uniqueNodeIDs(append(r.PeersChanged, other.PeersChanged...))
merged.PeersRemoved = uniqueNodeIDs(append(r.PeersRemoved, other.PeersRemoved...))
merged.PeerPatches = append(r.PeerPatches, other.PeerPatches...)
if r.Reason != "" && other.Reason != "" && r.Reason != other.Reason {
merged.Reason = r.Reason + "; " + other.Reason
} else if other.Reason != "" {
merged.Reason = other.Reason
}
return false
return merged
}
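A short usage sketch of Merge (values invented; types is already imported in this package):

```go
a := Change{Reason: "route change", PeersChanged: []types.NodeID{3, 1}}
b := Change{Reason: "tag change", IncludePolicy: true, PeersChanged: []types.NodeID{2, 1}}

m := a.Merge(b)
// m.Reason        == "route change; tag change"
// m.IncludePolicy == true
// m.PeersChanged  == []types.NodeID{1, 2, 3} (deduplicated and sorted)
```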
type ChangeSet struct {
Change Change
// SelfUpdateOnly indicates that this change should only be sent
// to the node itself, and not to other nodes.
// This is used for changes that are not relevant to other nodes.
// NodeID must be set if this is true.
SelfUpdateOnly bool
// NodeID if set, is the ID of the node that is being changed.
// It must be set if this is a node change.
NodeID types.NodeID
// UserID if set, is the ID of the user that is being changed.
// It must be set if this is a user change.
UserID types.UserID
// IsSubnetRouter indicates whether the node is a subnet router.
IsSubnetRouter bool
// NodeExpiry is set if the change is NodeKeyExpiry.
NodeExpiry *time.Time
}
func (c *ChangeSet) Validate() error {
if c.Change >= NodeCameOnline || c.Change <= NodeNewOrUpdate {
if c.NodeID == 0 {
return errors.New("ChangeSet.NodeID must be set for node updates")
}
func (r Change) IsEmpty() bool {
if r.IncludeSelf || r.IncludeDERPMap || r.IncludeDNS ||
r.IncludeDomain || r.IncludePolicy || r.SendAllPeers {
return false
}
if c.Change >= UserNewOrUpdate || c.Change <= UserRemove {
if c.UserID == 0 {
return errors.New("ChangeSet.UserID must be set for user updates")
}
if r.RequiresRuntimePeerComputation {
return false
}
return nil
return len(r.PeersChanged) == 0 &&
len(r.PeersRemoved) == 0 &&
len(r.PeerPatches) == 0
}
// Empty reports whether the ChangeSet is empty, meaning it does not
// represent any change.
func (c ChangeSet) Empty() bool {
return c.Change == ChangeUnknown && c.NodeID == 0 && c.UserID == 0
func (r Change) IsSelfOnly() bool {
if r.TargetNode == 0 || !r.IncludeSelf {
return false
}
if r.SendAllPeers || len(r.PeersChanged) > 0 || len(r.PeersRemoved) > 0 || len(r.PeerPatches) > 0 {
return false
}
return true
}
// IsFull reports whether the ChangeSet represents a full update.
func (c ChangeSet) IsFull() bool {
return c.Change == Full || c.Change == Policy
// IsTargetedToNode returns true if this response should only be sent to TargetNode.
func (r Change) IsTargetedToNode() bool {
return r.TargetNode != 0
}
func HasFull(cs []ChangeSet) bool {
for _, c := range cs {
if c.IsFull() {
// IsFull reports whether this is a full update response.
func (r Change) IsFull() bool {
return r.SendAllPeers && r.IncludeSelf && r.IncludeDERPMap &&
r.IncludeDNS && r.IncludeDomain && r.IncludePolicy
}
// Type returns a categorized type string for metrics.
// This provides a bounded set of values suitable for Prometheus labels,
// unlike Reason which is free-form text for logging.
func (r Change) Type() string {
if r.IsFull() {
return "full"
}
if r.IsSelfOnly() {
return "self"
}
if r.RequiresRuntimePeerComputation {
return "policy"
}
if len(r.PeerPatches) > 0 && len(r.PeersChanged) == 0 && len(r.PeersRemoved) == 0 && !r.SendAllPeers {
return "patch"
}
if len(r.PeersChanged) > 0 || len(r.PeersRemoved) > 0 || r.SendAllPeers {
return "peers"
}
if r.IncludeDERPMap || r.IncludeDNS || r.IncludeDomain || r.IncludePolicy {
return "config"
}
return "unknown"
}
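For example, Type() is intended to feed a bounded Prometheus label; the counter below is a hypothetical sketch, not an existing headscale metric:

```go
import "github.com/prometheus/client_golang/prometheus"

var changesByType = prometheus.NewCounterVec(
	prometheus.CounterOpts{
		Name: "headscale_change_responses_total", // hypothetical metric name
		Help: "Change responses processed, labelled by categorized type.",
	},
	[]string{"type"},
)

func recordChange(c Change) {
	changesByType.WithLabelValues(c.Type()).Inc()
}
```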
// ShouldSendToNode determines if this response should be sent to nodeID.
// It handles self-only targeting and filtering out self-updates for non-origin nodes.
func (r Change) ShouldSendToNode(nodeID types.NodeID) bool {
// If targeted to a specific node, only send to that node
if r.TargetNode != 0 {
return r.TargetNode == nodeID
}
return true
}
// HasFull returns true if any response in the slice is a full update.
func HasFull(rs []Change) bool {
for _, r := range rs {
if r.IsFull() {
return true
}
}
return false
}
func SplitAllAndSelf(cs []ChangeSet) (all []ChangeSet, self []ChangeSet) {
for _, c := range cs {
if c.SelfUpdateOnly {
self = append(self, c)
// SplitTargetedAndBroadcast separates responses into targeted (to a specific node) and broadcast.
func SplitTargetedAndBroadcast(rs []Change) ([]Change, []Change) {
var broadcast, targeted []Change
for _, r := range rs {
if r.IsTargetedToNode() {
targeted = append(targeted, r)
} else {
all = append(all, c)
broadcast = append(broadcast, r)
}
}
return all, self
return broadcast, targeted
}
func RemoveUpdatesForSelf(id types.NodeID, cs []ChangeSet) (ret []ChangeSet) {
for _, c := range cs {
if c.NodeID != id || c.Change.AlsoSelf() {
ret = append(ret, c)
// FilterForNode returns responses that should be sent to the given node.
func FilterForNode(nodeID types.NodeID, rs []Change) []Change {
var result []Change
for _, r := range rs {
if r.ShouldSendToNode(nodeID) {
result = append(result, r)
}
}
return ret
return result
}
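A small usage sketch (node IDs invented):

```go
rs := []Change{
	SelfUpdate(1),  // targeted: only node 1 should receive it
	PolicyChange(), // broadcast
	NodeAdded(2),   // broadcast (OriginNode is 2, but not targeted)
}

forNode1 := FilterForNode(1, rs) // all three changes
forNode3 := FilterForNode(3, rs) // PolicyChange and NodeAdded only
```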
// IsSelfUpdate reports whether this ChangeSet represents an update to the given node itself.
func (c ChangeSet) IsSelfUpdate(nodeID types.NodeID) bool {
return c.NodeID == nodeID
}
func (c ChangeSet) AlsoSelf() bool {
// If NodeID is 0, it means this ChangeSet is not related to a specific node,
// so we consider it as a change that should be sent to all nodes.
if c.NodeID == 0 {
return true
func uniqueNodeIDs(ids []types.NodeID) []types.NodeID {
if len(ids) == 0 {
return nil
}
return c.Change.AlsoSelf() || c.SelfUpdateOnly
slices.Sort(ids)
return slices.Compact(ids)
}
var (
EmptySet = ChangeSet{Change: ChangeUnknown}
FullSet = ChangeSet{Change: Full}
DERPSet = ChangeSet{Change: DERP}
PolicySet = ChangeSet{Change: Policy}
ExtraRecordsSet = ChangeSet{Change: ExtraRecords}
)
// Constructor functions
func FullSelf(id types.NodeID) ChangeSet {
return ChangeSet{
Change: Full,
SelfUpdateOnly: true,
NodeID: id,
func FullUpdate() Change {
return Change{
Reason: "full update",
IncludeSelf: true,
IncludeDERPMap: true,
IncludeDNS: true,
IncludeDomain: true,
IncludePolicy: true,
SendAllPeers: true,
}
}
func NodeAdded(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeNewOrUpdate,
NodeID: id,
// FullSelf returns a full update targeted at a specific node.
func FullSelf(nodeID types.NodeID) Change {
return Change{
Reason: "full self update",
TargetNode: nodeID,
IncludeSelf: true,
IncludeDERPMap: true,
IncludeDNS: true,
IncludeDomain: true,
IncludePolicy: true,
SendAllPeers: true,
}
}
func NodeRemoved(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeRemove,
NodeID: id,
func SelfUpdate(nodeID types.NodeID) Change {
return Change{
Reason: "self update",
TargetNode: nodeID,
IncludeSelf: true,
}
}
func NodeOnline(node types.NodeView) ChangeSet {
return ChangeSet{
Change: NodeCameOnline,
NodeID: node.ID(),
IsSubnetRouter: node.IsSubnetRouter(),
func PolicyOnly() Change {
return Change{
Reason: "policy update",
IncludePolicy: true,
}
}
func NodeOffline(node types.NodeView) ChangeSet {
return ChangeSet{
Change: NodeWentOffline,
NodeID: node.ID(),
IsSubnetRouter: node.IsSubnetRouter(),
func PolicyAndPeers(changedPeers ...types.NodeID) Change {
return Change{
Reason: "policy and peers update",
IncludePolicy: true,
PeersChanged: changedPeers,
}
}
func KeyExpiry(id types.NodeID, expiry time.Time) ChangeSet {
return ChangeSet{
Change: NodeKeyExpiry,
NodeID: id,
NodeExpiry: &expiry,
func VisibilityChange(reason string, added, removed []types.NodeID) Change {
return Change{
Reason: reason,
IncludePolicy: true,
PeersChanged: added,
PeersRemoved: removed,
}
}
func EndpointUpdate(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeEndpoint,
NodeID: id,
func PeersChanged(reason string, peerIDs ...types.NodeID) Change {
return Change{
Reason: reason,
PeersChanged: peerIDs,
}
}
func DERPUpdate(id types.NodeID) ChangeSet {
return ChangeSet{
Change: NodeDERP,
NodeID: id,
func PeersRemoved(peerIDs ...types.NodeID) Change {
return Change{
Reason: "peers removed",
PeersRemoved: peerIDs,
}
}
func UserAdded(id types.UserID) ChangeSet {
return ChangeSet{
Change: UserNewOrUpdate,
UserID: id,
func PeerPatched(reason string, patches ...*tailcfg.PeerChange) Change {
return Change{
Reason: reason,
PeerPatches: patches,
}
}
func UserRemoved(id types.UserID) ChangeSet {
return ChangeSet{
Change: UserRemove,
UserID: id,
func DERPMap() Change {
return Change{
Reason: "DERP map update",
IncludeDERPMap: true,
}
}
func PolicyChange() ChangeSet {
return ChangeSet{
Change: Policy,
// PolicyChange creates a response for policy changes.
// Policy changes require runtime peer visibility computation.
func PolicyChange() Change {
return Change{
Reason: "policy change",
IncludePolicy: true,
RequiresRuntimePeerComputation: true,
}
}
func DERPChange() ChangeSet {
return ChangeSet{
Change: DERP,
// DNSConfig creates a response for DNS configuration updates.
func DNSConfig() Change {
return Change{
Reason: "DNS config update",
IncludeDNS: true,
}
}
// NodeOnline creates a patch response for a node coming online.
func NodeOnline(nodeID types.NodeID) Change {
return Change{
Reason: "node online",
PeerPatches: []*tailcfg.PeerChange{
{
NodeID: nodeID.NodeID(),
Online: ptrTo(true),
},
},
}
}
// NodeOffline creates a patch response for a node going offline.
func NodeOffline(nodeID types.NodeID) Change {
return Change{
Reason: "node offline",
PeerPatches: []*tailcfg.PeerChange{
{
NodeID: nodeID.NodeID(),
Online: ptrTo(false),
},
},
}
}
// KeyExpiry creates a patch response for a node's key expiry change.
func KeyExpiry(nodeID types.NodeID, expiry *time.Time) Change {
return Change{
Reason: "key expiry",
PeerPatches: []*tailcfg.PeerChange{
{
NodeID: nodeID.NodeID(),
KeyExpiry: expiry,
},
},
}
}
// ptrTo returns a pointer to the given value.
func ptrTo[T any](v T) *T {
return &v
}
// High-level change constructors
// NodeAdded returns a Change for when a node is added or updated.
// The OriginNode field enables self-update detection by the mapper.
func NodeAdded(id types.NodeID) Change {
c := PeersChanged("node added", id)
c.OriginNode = id
return c
}
// NodeRemoved returns a Change for when a node is removed.
func NodeRemoved(id types.NodeID) Change {
return PeersRemoved(id)
}
// NodeOnlineFor returns a Change for when a node comes online.
// If the node is a subnet router, a full update is sent instead of a patch.
func NodeOnlineFor(node types.NodeView) Change {
if node.IsSubnetRouter() {
c := FullUpdate()
c.Reason = "subnet router online"
return c
}
return NodeOnline(node.ID())
}
// NodeOfflineFor returns a Change for when a node goes offline.
// If the node is a subnet router, a full update is sent instead of a patch.
func NodeOfflineFor(node types.NodeView) Change {
if node.IsSubnetRouter() {
c := FullUpdate()
c.Reason = "subnet router offline"
return c
}
return NodeOffline(node.ID())
}
// KeyExpiryFor returns a Change for when a node's key expiry changes.
// The OriginNode field enables self-update detection by the mapper.
func KeyExpiryFor(id types.NodeID, expiry time.Time) Change {
c := KeyExpiry(id, &expiry)
c.OriginNode = id
return c
}
// EndpointOrDERPUpdate returns a Change for when a node's endpoints or DERP region changes.
// The OriginNode field enables self-update detection by the mapper.
func EndpointOrDERPUpdate(id types.NodeID, patch *tailcfg.PeerChange) Change {
c := PeerPatched("endpoint/DERP update", patch)
c.OriginNode = id
return c
}
// UserAdded returns a Change for when a user is added or updated.
// A full update is sent to refresh user profiles on all nodes.
func UserAdded() Change {
c := FullUpdate()
c.Reason = "user added"
return c
}
// UserRemoved returns a Change for when a user is removed.
// A full update is sent to refresh user profiles on all nodes.
func UserRemoved() Change {
c := FullUpdate()
c.Reason = "user removed"
return c
}
// ExtraRecords returns a Change for when DNS extra records change.
func ExtraRecords() Change {
c := DNSConfig()
c.Reason = "extra records update"
return c
}
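Finally, a hedged sketch of how the state code in this diff consumes the high-level constructors; nodeView stands in for a valid types.NodeView:

```go
c := NodeOnlineFor(nodeView)

switch {
case c.IsFull():
	// subnet router: peers need a full map, not a patch
case len(c.PeerPatches) == 1:
	// ordinary node: peers receive a single online=true patch
}

cs := []Change{c, KeyExpiryFor(nodeView.ID(), time.Now())}
_ = HasFull(cs) // true only when at least one element is a full update
```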

View file

@ -1,59 +0,0 @@
// Code generated by "stringer -type=Change"; DO NOT EDIT.
package change
import "strconv"
func _() {
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
var x [1]struct{}
_ = x[ChangeUnknown-0]
_ = x[Full-9]
_ = x[Policy-11]
_ = x[DERP-12]
_ = x[ExtraRecords-13]
_ = x[NodeCameOnline-21]
_ = x[NodeWentOffline-22]
_ = x[NodeRemove-23]
_ = x[NodeKeyExpiry-24]
_ = x[NodeNewOrUpdate-25]
_ = x[NodeEndpoint-26]
_ = x[NodeDERP-27]
_ = x[UserNewOrUpdate-51]
_ = x[UserRemove-52]
}
const (
_Change_name_0 = "ChangeUnknown"
_Change_name_1 = "Full"
_Change_name_2 = "PolicyDERPExtraRecords"
_Change_name_3 = "NodeCameOnlineNodeWentOfflineNodeRemoveNodeKeyExpiryNodeNewOrUpdateNodeEndpointNodeDERP"
_Change_name_4 = "UserNewOrUpdateUserRemove"
)
var (
_Change_index_2 = [...]uint8{0, 6, 10, 22}
_Change_index_3 = [...]uint8{0, 14, 29, 39, 52, 67, 79, 87}
_Change_index_4 = [...]uint8{0, 15, 25}
)
func (i Change) String() string {
switch {
case i == 0:
return _Change_name_0
case i == 9:
return _Change_name_1
case 11 <= i && i <= 13:
i -= 11
return _Change_name_2[_Change_index_2[i]:_Change_index_2[i+1]]
case 21 <= i && i <= 27:
i -= 21
return _Change_name_3[_Change_index_3[i]:_Change_index_3[i+1]]
case 51 <= i && i <= 52:
i -= 51
return _Change_name_4[_Change_index_4[i]:_Change_index_4[i+1]]
default:
return "Change(" + strconv.FormatInt(int64(i), 10) + ")"
}
}

View file

@ -0,0 +1,449 @@
package change
import (
"reflect"
"testing"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"tailscale.com/tailcfg"
)
func TestChange_FieldSync(t *testing.T) {
r := Change{}
fieldNames := r.boolFieldNames()
typ := reflect.TypeFor[Change]()
boolCount := 0
for i := range typ.NumField() {
if typ.Field(i).Type.Kind() == reflect.Bool {
boolCount++
}
}
if len(fieldNames) != boolCount {
t.Fatalf("boolFieldNames() returns %d fields but struct has %d bool fields; "+
"update boolFieldNames() when adding new bool fields", len(fieldNames), boolCount)
}
}
func TestChange_IsEmpty(t *testing.T) {
tests := []struct {
name string
response Change
want bool
}{
{
name: "zero value is empty",
response: Change{},
want: true,
},
{
name: "only reason is still empty",
response: Change{Reason: "test"},
want: true,
},
{
name: "IncludeSelf not empty",
response: Change{IncludeSelf: true},
want: false,
},
{
name: "IncludeDERPMap not empty",
response: Change{IncludeDERPMap: true},
want: false,
},
{
name: "IncludeDNS not empty",
response: Change{IncludeDNS: true},
want: false,
},
{
name: "IncludeDomain not empty",
response: Change{IncludeDomain: true},
want: false,
},
{
name: "IncludePolicy not empty",
response: Change{IncludePolicy: true},
want: false,
},
{
name: "SendAllPeers not empty",
response: Change{SendAllPeers: true},
want: false,
},
{
name: "PeersChanged not empty",
response: Change{PeersChanged: []types.NodeID{1}},
want: false,
},
{
name: "PeersRemoved not empty",
response: Change{PeersRemoved: []types.NodeID{1}},
want: false,
},
{
name: "PeerPatches not empty",
response: Change{PeerPatches: []*tailcfg.PeerChange{{}}},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.response.IsEmpty()
assert.Equal(t, tt.want, got)
})
}
}
func TestChange_IsSelfOnly(t *testing.T) {
tests := []struct {
name string
response Change
want bool
}{
{
name: "empty is not self only",
response: Change{},
want: false,
},
{
name: "IncludeSelf without TargetNode is not self only",
response: Change{IncludeSelf: true},
want: false,
},
{
name: "TargetNode without IncludeSelf is not self only",
response: Change{TargetNode: 1},
want: false,
},
{
name: "TargetNode with IncludeSelf is self only",
response: Change{TargetNode: 1, IncludeSelf: true},
want: true,
},
{
name: "self only with SendAllPeers is not self only",
response: Change{TargetNode: 1, IncludeSelf: true, SendAllPeers: true},
want: false,
},
{
name: "self only with PeersChanged is not self only",
response: Change{TargetNode: 1, IncludeSelf: true, PeersChanged: []types.NodeID{2}},
want: false,
},
{
name: "self only with PeersRemoved is not self only",
response: Change{TargetNode: 1, IncludeSelf: true, PeersRemoved: []types.NodeID{2}},
want: false,
},
{
name: "self only with PeerPatches is not self only",
response: Change{TargetNode: 1, IncludeSelf: true, PeerPatches: []*tailcfg.PeerChange{{}}},
want: false,
},
{
name: "self only with other include flags is still self only",
response: Change{
TargetNode: 1,
IncludeSelf: true,
IncludePolicy: true,
IncludeDNS: true,
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.response.IsSelfOnly()
assert.Equal(t, tt.want, got)
})
}
}
func TestChange_Merge(t *testing.T) {
tests := []struct {
name string
r1 Change
r2 Change
want Change
}{
{
name: "empty merge",
r1: Change{},
r2: Change{},
want: Change{},
},
{
name: "bool fields OR together",
r1: Change{IncludeSelf: true, IncludePolicy: true},
r2: Change{IncludeDERPMap: true, IncludePolicy: true},
want: Change{IncludeSelf: true, IncludeDERPMap: true, IncludePolicy: true},
},
{
name: "all bool fields merge",
r1: Change{IncludeSelf: true, IncludeDNS: true, IncludePolicy: true},
r2: Change{IncludeDERPMap: true, IncludeDomain: true, SendAllPeers: true},
want: Change{
IncludeSelf: true,
IncludeDERPMap: true,
IncludeDNS: true,
IncludeDomain: true,
IncludePolicy: true,
SendAllPeers: true,
},
},
{
name: "peers deduplicated and sorted",
r1: Change{PeersChanged: []types.NodeID{3, 1}},
r2: Change{PeersChanged: []types.NodeID{2, 1}},
want: Change{PeersChanged: []types.NodeID{1, 2, 3}},
},
{
name: "peers removed deduplicated",
r1: Change{PeersRemoved: []types.NodeID{1, 2}},
r2: Change{PeersRemoved: []types.NodeID{2, 3}},
want: Change{PeersRemoved: []types.NodeID{1, 2, 3}},
},
{
name: "peer patches concatenated",
r1: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 1}}},
r2: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 2}}},
want: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 1}, {NodeID: 2}}},
},
{
name: "reasons combined when different",
r1: Change{Reason: "route change"},
r2: Change{Reason: "tag change"},
want: Change{Reason: "route change; tag change"},
},
{
name: "same reason not duplicated",
r1: Change{Reason: "policy"},
r2: Change{Reason: "policy"},
want: Change{Reason: "policy"},
},
{
name: "empty reason takes other",
r1: Change{},
r2: Change{Reason: "update"},
want: Change{Reason: "update"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.r1.Merge(tt.r2)
assert.Equal(t, tt.want, got)
})
}
}
func TestChange_Constructors(t *testing.T) {
tests := []struct {
name string
constructor func() Change
wantReason string
want Change
}{
{
name: "FullUpdateResponse",
constructor: FullUpdate,
wantReason: "full update",
want: Change{
Reason: "full update",
IncludeSelf: true,
IncludeDERPMap: true,
IncludeDNS: true,
IncludeDomain: true,
IncludePolicy: true,
SendAllPeers: true,
},
},
{
name: "PolicyOnlyResponse",
constructor: PolicyOnly,
wantReason: "policy update",
want: Change{
Reason: "policy update",
IncludePolicy: true,
},
},
{
name: "DERPMapResponse",
constructor: DERPMap,
wantReason: "DERP map update",
want: Change{
Reason: "DERP map update",
IncludeDERPMap: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
r := tt.constructor()
assert.Equal(t, tt.wantReason, r.Reason)
assert.Equal(t, tt.want, r)
})
}
}
func TestSelfUpdate(t *testing.T) {
r := SelfUpdate(42)
assert.Equal(t, "self update", r.Reason)
assert.Equal(t, types.NodeID(42), r.TargetNode)
assert.True(t, r.IncludeSelf)
assert.True(t, r.IsSelfOnly())
}
func TestPolicyAndPeers(t *testing.T) {
r := PolicyAndPeers(1, 2, 3)
assert.Equal(t, "policy and peers update", r.Reason)
assert.True(t, r.IncludePolicy)
assert.Equal(t, []types.NodeID{1, 2, 3}, r.PeersChanged)
}
func TestVisibilityChange(t *testing.T) {
r := VisibilityChange("tag change", []types.NodeID{1}, []types.NodeID{2, 3})
assert.Equal(t, "tag change", r.Reason)
assert.True(t, r.IncludePolicy)
assert.Equal(t, []types.NodeID{1}, r.PeersChanged)
assert.Equal(t, []types.NodeID{2, 3}, r.PeersRemoved)
}
func TestPeersChanged(t *testing.T) {
r := PeersChanged("routes approved", 1, 2)
assert.Equal(t, "routes approved", r.Reason)
assert.Equal(t, []types.NodeID{1, 2}, r.PeersChanged)
assert.False(t, r.IncludePolicy)
}
func TestPeersRemoved(t *testing.T) {
r := PeersRemoved(1, 2, 3)
assert.Equal(t, "peers removed", r.Reason)
assert.Equal(t, []types.NodeID{1, 2, 3}, r.PeersRemoved)
}
func TestPeerPatched(t *testing.T) {
patch := &tailcfg.PeerChange{NodeID: 1}
r := PeerPatched("endpoint change", patch)
assert.Equal(t, "endpoint change", r.Reason)
assert.Equal(t, []*tailcfg.PeerChange{patch}, r.PeerPatches)
}
func TestChange_Type(t *testing.T) {
tests := []struct {
name string
response Change
want string
}{
{
name: "full update",
response: FullUpdate(),
want: "full",
},
{
name: "self only",
response: SelfUpdate(1),
want: "self",
},
{
name: "policy with runtime computation",
response: PolicyChange(),
want: "policy",
},
{
name: "patch only",
response: PeerPatched("test", &tailcfg.PeerChange{NodeID: 1}),
want: "patch",
},
{
name: "peers changed",
response: PeersChanged("test", 1, 2),
want: "peers",
},
{
name: "peers removed",
response: PeersRemoved(1, 2),
want: "peers",
},
{
name: "config - DERP map",
response: DERPMap(),
want: "config",
},
{
name: "config - DNS",
response: DNSConfig(),
want: "config",
},
{
name: "config - policy only (no runtime)",
response: PolicyOnly(),
want: "config",
},
{
name: "empty is unknown",
response: Change{},
want: "unknown",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.response.Type()
assert.Equal(t, tt.want, got)
})
}
}
func TestUniqueNodeIDs(t *testing.T) {
tests := []struct {
name string
input []types.NodeID
want []types.NodeID
}{
{
name: "nil input",
input: nil,
want: nil,
},
{
name: "empty input",
input: []types.NodeID{},
want: nil,
},
{
name: "single element",
input: []types.NodeID{1},
want: []types.NodeID{1},
},
{
name: "no duplicates",
input: []types.NodeID{1, 2, 3},
want: []types.NodeID{1, 2, 3},
},
{
name: "with duplicates",
input: []types.NodeID{3, 1, 2, 1, 3},
want: []types.NodeID{1, 2, 3},
},
{
name: "all same",
input: []types.NodeID{5, 5, 5, 5},
want: []types.NodeID{5},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := uniqueNodeIDs(tt.input)
assert.Equal(t, tt.want, got)
})
}
}

View file

@ -94,6 +94,7 @@ type Config struct {
LogTail LogTailConfig
RandomizeClientPort bool
Taildrop TaildropConfig
CLI CLIConfig
@ -185,6 +186,7 @@ type OIDCConfig struct {
AllowedDomains []string
AllowedUsers []string
AllowedGroups []string
EmailVerifiedRequired bool
Expiry time.Duration
UseExpiryFromToken bool
PKCE PKCEConfig
@ -212,6 +214,10 @@ type LogTailConfig struct {
Enabled bool
}
type TaildropConfig struct {
Enabled bool
}
type CLIConfig struct {
Address string
APIKey string
@ -380,9 +386,11 @@ func LoadConfig(path string, isFile bool) error {
viper.SetDefault("oidc.use_expiry_from_token", false)
viper.SetDefault("oidc.pkce.enabled", false)
viper.SetDefault("oidc.pkce.method", "S256")
viper.SetDefault("oidc.email_verified_required", true)
viper.SetDefault("logtail.enabled", false)
viper.SetDefault("randomize_client_port", false)
viper.SetDefault("taildrop.enabled", true)
viper.SetDefault("ephemeral_node_inactivity_timeout", "120s")
@ -1097,14 +1105,15 @@ func LoadServerConfig() (*Config, error) {
OnlyStartIfOIDCIsAvailable: viper.GetBool(
"oidc.only_start_if_oidc_is_available",
),
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
Issuer: viper.GetString("oidc.issuer"),
ClientID: viper.GetString("oidc.client_id"),
ClientSecret: oidcClientSecret,
Scope: viper.GetStringSlice("oidc.scope"),
ExtraParams: viper.GetStringMapString("oidc.extra_params"),
AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"),
AllowedUsers: viper.GetStringSlice("oidc.allowed_users"),
AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"),
EmailVerifiedRequired: viper.GetBool("oidc.email_verified_required"),
Expiry: func() time.Duration {
// if set to 0, we assume no expiry
if value := viper.GetString("oidc.expiry"); value == "0" {
@ -1129,6 +1138,9 @@ func LoadServerConfig() (*Config, error) {
LogTail: logTailConfig,
RandomizeClientPort: randomizeClientPort,
Taildrop: TaildropConfig{
Enabled: viper.GetBool("taildrop.enabled"),
},
Policy: policyConfig(),

View file

@ -28,10 +28,15 @@ var (
ErrNodeHasNoGivenName = errors.New("node has no given name")
ErrNodeUserHasNoName = errors.New("node user has no name")
ErrCannotRemoveAllTags = errors.New("cannot remove all tags from node")
ErrInvalidNodeView = errors.New("cannot convert invalid NodeView to tailcfg.Node")
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
// RouteFunc is a function that takes a node ID and returns a list of
// netip.Prefixes representing the primary routes for that node.
type RouteFunc func(id NodeID) []netip.Prefix
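A minimal way to satisfy RouteFunc, assuming callers use it to look up a node's primary routes without depending on the concrete routes table (data invented; place inside any function):

```go
primary := map[NodeID][]netip.Prefix{
	1: {netip.MustParsePrefix("10.0.0.0/24")},
}

var lookup RouteFunc = func(id NodeID) []netip.Prefix {
	return primary[id]
}

_ = lookup(1) // [10.0.0.0/24]
```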
type (
NodeID uint64
NodeIDs []NodeID
@ -714,80 +719,88 @@ func (node Node) DebugString() string {
return sb.String()
}
func (v NodeView) UserView() UserView {
return v.User()
func (nv NodeView) UserView() UserView {
return nv.User()
}
func (v NodeView) IPs() []netip.Addr {
if !v.Valid() {
func (nv NodeView) IPs() []netip.Addr {
if !nv.Valid() {
return nil
}
return v.ж.IPs()
return nv.ж.IPs()
}
func (v NodeView) InIPSet(set *netipx.IPSet) bool {
if !v.Valid() {
return false
}
return v.ж.InIPSet(set)
}
func (v NodeView) CanAccess(matchers []matcher.Match, node2 NodeView) bool {
if !v.Valid() {
func (nv NodeView) InIPSet(set *netipx.IPSet) bool {
if !nv.Valid() {
return false
}
return v.ж.CanAccess(matchers, node2.AsStruct())
return nv.ж.InIPSet(set)
}
func (v NodeView) CanAccessRoute(matchers []matcher.Match, route netip.Prefix) bool {
if !v.Valid() {
func (nv NodeView) CanAccess(matchers []matcher.Match, node2 NodeView) bool {
if !nv.Valid() {
return false
}
return v.ж.CanAccessRoute(matchers, route)
return nv.ж.CanAccess(matchers, node2.AsStruct())
}
func (v NodeView) AnnouncedRoutes() []netip.Prefix {
if !v.Valid() {
return nil
}
return v.ж.AnnouncedRoutes()
}
func (v NodeView) SubnetRoutes() []netip.Prefix {
if !v.Valid() {
return nil
}
return v.ж.SubnetRoutes()
}
func (v NodeView) IsSubnetRouter() bool {
if !v.Valid() {
func (nv NodeView) CanAccessRoute(matchers []matcher.Match, route netip.Prefix) bool {
if !nv.Valid() {
return false
}
return v.ж.IsSubnetRouter()
return nv.ж.CanAccessRoute(matchers, route)
}
func (v NodeView) AllApprovedRoutes() []netip.Prefix {
if !v.Valid() {
func (nv NodeView) AnnouncedRoutes() []netip.Prefix {
if !nv.Valid() {
return nil
}
return v.ж.AllApprovedRoutes()
return nv.ж.AnnouncedRoutes()
}
func (v NodeView) AppendToIPSet(build *netipx.IPSetBuilder) {
if !v.Valid() {
func (nv NodeView) SubnetRoutes() []netip.Prefix {
if !nv.Valid() {
return nil
}
return nv.ж.SubnetRoutes()
}
func (nv NodeView) IsSubnetRouter() bool {
if !nv.Valid() {
return false
}
return nv.ж.IsSubnetRouter()
}
func (nv NodeView) AllApprovedRoutes() []netip.Prefix {
if !nv.Valid() {
return nil
}
return nv.ж.AllApprovedRoutes()
}
func (nv NodeView) AppendToIPSet(build *netipx.IPSetBuilder) {
if !nv.Valid() {
return
}
v.ж.AppendToIPSet(build)
nv.ж.AppendToIPSet(build)
}
func (v NodeView) RequestTagsSlice() views.Slice[string] {
if !v.Valid() || !v.Hostinfo().Valid() {
func (nv NodeView) RequestTagsSlice() views.Slice[string] {
if !nv.Valid() || !nv.Hostinfo().Valid() {
return views.Slice[string]{}
}
return v.Hostinfo().RequestTags()
return nv.Hostinfo().RequestTags()
}
// IsTagged reports if a device is tagged
@ -795,154 +808,293 @@ func (v NodeView) RequestTagsSlice() views.Slice[string] {
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys).
func (v NodeView) IsTagged() bool {
if !v.Valid() {
func (nv NodeView) IsTagged() bool {
if !nv.Valid() {
return false
}
return v.ж.IsTagged()
return nv.ж.IsTagged()
}
// IsExpired returns whether the node registration has expired.
func (v NodeView) IsExpired() bool {
if !v.Valid() {
func (nv NodeView) IsExpired() bool {
if !nv.Valid() {
return true
}
return v.ж.IsExpired()
return nv.ж.IsExpired()
}
// IsEphemeral returns if the node is registered as an Ephemeral node.
// https://tailscale.com/kb/1111/ephemeral-nodes/
func (v NodeView) IsEphemeral() bool {
if !v.Valid() {
func (nv NodeView) IsEphemeral() bool {
if !nv.Valid() {
return false
}
return v.ж.IsEphemeral()
return nv.ж.IsEphemeral()
}
// PeerChangeFromMapRequest takes a MapRequest and compares it to the node
// to produce a PeerChange struct that can be used to update the node and
// inform peers about smaller changes to the node.
func (v NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
if !v.Valid() {
func (nv NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange {
if !nv.Valid() {
return tailcfg.PeerChange{}
}
return v.ж.PeerChangeFromMapRequest(req)
return nv.ж.PeerChangeFromMapRequest(req)
}
// GetFQDN returns the fully qualified domain name for the node.
func (v NodeView) GetFQDN(baseDomain string) (string, error) {
if !v.Valid() {
func (nv NodeView) GetFQDN(baseDomain string) (string, error) {
if !nv.Valid() {
return "", errors.New("failed to create valid FQDN: node view is invalid")
}
return v.ж.GetFQDN(baseDomain)
return nv.ж.GetFQDN(baseDomain)
}
// ExitRoutes returns a list of both exit routes if the
// node has any exit routes enabled.
// If none are enabled, it will return nil.
func (v NodeView) ExitRoutes() []netip.Prefix {
if !v.Valid() {
func (nv NodeView) ExitRoutes() []netip.Prefix {
if !nv.Valid() {
return nil
}
return v.ж.ExitRoutes()
return nv.ж.ExitRoutes()
}
func (v NodeView) IsExitNode() bool {
if !v.Valid() {
func (nv NodeView) IsExitNode() bool {
if !nv.Valid() {
return false
}
return v.ж.IsExitNode()
return nv.ж.IsExitNode()
}
// RequestTags returns the ACL tags that the node is requesting.
func (v NodeView) RequestTags() []string {
if !v.Valid() || !v.Hostinfo().Valid() {
func (nv NodeView) RequestTags() []string {
if !nv.Valid() || !nv.Hostinfo().Valid() {
return []string{}
}
return v.Hostinfo().RequestTags().AsSlice()
return nv.Hostinfo().RequestTags().AsSlice()
}
// Proto converts the NodeView to a protobuf representation.
func (v NodeView) Proto() *v1.Node {
if !v.Valid() {
func (nv NodeView) Proto() *v1.Node {
if !nv.Valid() {
return nil
}
return v.ж.Proto()
return nv.ж.Proto()
}
// HasIP reports if a node has a given IP address.
func (v NodeView) HasIP(i netip.Addr) bool {
if !v.Valid() {
func (nv NodeView) HasIP(i netip.Addr) bool {
if !nv.Valid() {
return false
}
return v.ж.HasIP(i)
return nv.ж.HasIP(i)
}
// HasTag reports if a node has a given tag.
func (v NodeView) HasTag(tag string) bool {
if !v.Valid() {
func (nv NodeView) HasTag(tag string) bool {
if !nv.Valid() {
return false
}
return v.ж.HasTag(tag)
return nv.ж.HasTag(tag)
}
// TypedUserID returns the UserID as a typed UserID type.
// Returns 0 if UserID is nil or node is invalid.
func (v NodeView) TypedUserID() UserID {
if !v.Valid() {
func (nv NodeView) TypedUserID() UserID {
if !nv.Valid() {
return 0
}
return v.ж.TypedUserID()
return nv.ж.TypedUserID()
}
// TailscaleUserID returns the user ID to use in Tailscale protocol.
// Tagged nodes always return TaggedDevices.ID, user-owned nodes return their actual UserID.
func (v NodeView) TailscaleUserID() tailcfg.UserID {
if !v.Valid() {
func (nv NodeView) TailscaleUserID() tailcfg.UserID {
if !nv.Valid() {
return 0
}
if v.IsTagged() {
if nv.IsTagged() {
//nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64
return tailcfg.UserID(int64(TaggedDevices.ID))
}
//nolint:gosec // G115: UserID values are within int64 range
return tailcfg.UserID(int64(v.UserID().Get()))
return tailcfg.UserID(int64(nv.UserID().Get()))
}
// Prefixes returns the node IPs as netip.Prefix.
func (v NodeView) Prefixes() []netip.Prefix {
if !v.Valid() {
func (nv NodeView) Prefixes() []netip.Prefix {
if !nv.Valid() {
return nil
}
return v.ж.Prefixes()
return nv.ж.Prefixes()
}
// IPsAsString returns the node IPs as strings.
func (v NodeView) IPsAsString() []string {
if !v.Valid() {
func (nv NodeView) IPsAsString() []string {
if !nv.Valid() {
return nil
}
return v.ж.IPsAsString()
return nv.ж.IPsAsString()
}
// HasNetworkChanges checks if the node has network-related changes.
// Returns true if IPs, announced routes, or approved routes changed.
// This is primarily used for policy cache invalidation.
func (v NodeView) HasNetworkChanges(other NodeView) bool {
if !slices.Equal(v.IPs(), other.IPs()) {
func (nv NodeView) HasNetworkChanges(other NodeView) bool {
if !slices.Equal(nv.IPs(), other.IPs()) {
return true
}
if !slices.Equal(v.AnnouncedRoutes(), other.AnnouncedRoutes()) {
if !slices.Equal(nv.AnnouncedRoutes(), other.AnnouncedRoutes()) {
return true
}
if !slices.Equal(v.SubnetRoutes(), other.SubnetRoutes()) {
if !slices.Equal(nv.SubnetRoutes(), other.SubnetRoutes()) {
return true
}
return false
}
// HasPolicyChange reports whether the node has changes that affect policy evaluation.
func (nv NodeView) HasPolicyChange(other NodeView) bool {
if nv.UserID() != other.UserID() {
return true
}
if !views.SliceEqual(nv.Tags(), other.Tags()) {
return true
}
if !slices.Equal(nv.IPs(), other.IPs()) {
return true
}
return false
}
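As a rough illustration of how comparisons like HasNetworkChanges and HasPolicyChange can gate policy cache invalidation, here is a standalone sketch; nodeSnapshot and policyCache are hypothetical stand-ins for illustration, not headscale types:

package main

import (
	"fmt"
	"net/netip"
	"slices"
)

type nodeSnapshot struct {
	UserID int
	Tags   []string
	IPs    []netip.Addr
}

// hasPolicyChange mirrors the idea above: only fields that affect policy
// evaluation (owner, tags, IPs) are compared.
func hasPolicyChange(prev, curr nodeSnapshot) bool {
	return prev.UserID != curr.UserID ||
		!slices.Equal(prev.Tags, curr.Tags) ||
		!slices.Equal(prev.IPs, curr.IPs)
}

type policyCache struct{ valid bool }

// maybeInvalidate drops the cached policy evaluation only when a
// policy-relevant field actually changed.
func (c *policyCache) maybeInvalidate(prev, curr nodeSnapshot) {
	if hasPolicyChange(prev, curr) {
		c.valid = false
	}
}

func main() {
	prev := nodeSnapshot{UserID: 1, Tags: []string{"tag:server"}, IPs: []netip.Addr{netip.MustParseAddr("100.64.0.1")}}
	curr := nodeSnapshot{UserID: 1, Tags: []string{"tag:server", "tag:prod"}, IPs: prev.IPs}

	cache := policyCache{valid: true}
	cache.maybeInvalidate(prev, curr)
	fmt.Println("cache still valid:", cache.valid) // false: tags changed
}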
// TailNodes converts a slice of NodeViews into Tailscale tailcfg.Nodes.
func TailNodes(
nodes views.Slice[NodeView],
capVer tailcfg.CapabilityVersion,
primaryRouteFunc RouteFunc,
cfg *Config,
) ([]*tailcfg.Node, error) {
tNodes := make([]*tailcfg.Node, 0, nodes.Len())
for _, node := range nodes.All() {
tNode, err := node.TailNode(capVer, primaryRouteFunc, cfg)
if err != nil {
return nil, err
}
tNodes = append(tNodes, tNode)
}
return tNodes, nil
}
// TailNode converts a NodeView into a Tailscale tailcfg.Node.
func (nv NodeView) TailNode(
capVer tailcfg.CapabilityVersion,
primaryRouteFunc RouteFunc,
cfg *Config,
) (*tailcfg.Node, error) {
if !nv.Valid() {
return nil, ErrInvalidNodeView
}
hostname, err := nv.GetFQDN(cfg.BaseDomain)
if err != nil {
return nil, err
}
var derp int
// TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077
// and should be removed after 111 is the minimum capver.
legacyDERP := "127.3.3.40:0" // Zero means disconnected or unknown.
if nv.Hostinfo().Valid() && nv.Hostinfo().NetInfo().Valid() {
legacyDERP = fmt.Sprintf("127.3.3.40:%d", nv.Hostinfo().NetInfo().PreferredDERP())
derp = nv.Hostinfo().NetInfo().PreferredDERP()
}
var keyExpiry time.Time
if nv.Expiry().Valid() {
keyExpiry = nv.Expiry().Get()
}
primaryRoutes := primaryRouteFunc(nv.ID())
allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes())
tsaddr.SortPrefixes(allowedIPs)
capMap := tailcfg.NodeCapMap{
tailcfg.CapabilityAdmin: []tailcfg.RawMessage{},
tailcfg.CapabilitySSH: []tailcfg.RawMessage{},
}
if cfg.RandomizeClientPort {
capMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
}
if cfg.Taildrop.Enabled {
capMap[tailcfg.CapabilityFileSharing] = []tailcfg.RawMessage{}
}
tNode := tailcfg.Node{
//nolint:gosec // G115: NodeID values are within int64 range
ID: tailcfg.NodeID(nv.ID()),
StableID: nv.ID().StableID(),
Name: hostname,
Cap: capVer,
CapMap: capMap,
User: nv.TailscaleUserID(),
Key: nv.NodeKey(),
KeyExpiry: keyExpiry.UTC(),
Machine: nv.MachineKey(),
DiscoKey: nv.DiscoKey(),
Addresses: nv.Prefixes(),
PrimaryRoutes: primaryRoutes,
AllowedIPs: allowedIPs,
Endpoints: nv.Endpoints().AsSlice(),
HomeDERP: derp,
LegacyDERPString: legacyDERP,
Hostinfo: nv.Hostinfo(),
Created: nv.CreatedAt().UTC(),
Online: nv.IsOnline().Clone(),
Tags: nv.Tags().AsSlice(),
MachineAuthorized: !nv.IsExpired(),
Expired: nv.IsExpired(),
}
// Set LastSeen only for offline nodes to avoid confusing Tailscale clients
// during rapid reconnection cycles. Online nodes should not have LastSeen set
// as this can make clients interpret them as "not online" despite Online=true.
if nv.LastSeen().Valid() && nv.IsOnline().Valid() && !nv.IsOnline().Get() {
lastSeen := nv.LastSeen().Get()
tNode.LastSeen = &lastSeen
}
return &tNode, nil
}
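The LastSeen handling above is easy to get backwards. As a minimal standalone sketch of the same rule, where lastSeenForClient and its parameters are illustrative names rather than the API used in this change:

package main

import (
	"fmt"
	"time"
)

// lastSeenForClient returns a LastSeen pointer only for nodes known to be
// offline, mirroring the rule in TailNode above: online nodes get a nil
// LastSeen so clients do not misread them as "not online".
func lastSeenForClient(lastSeen *time.Time, online *bool) *time.Time {
	if lastSeen == nil || online == nil || *online {
		return nil
	}
	ls := *lastSeen
	return &ls
}

func main() {
	seen := time.Now().Add(-5 * time.Minute)
	on, off := true, false

	fmt.Println(lastSeenForClient(&seen, &on))  // <nil>: online node
	fmt.Println(lastSeenForClient(&seen, &off)) // timestamp: offline node
	fmt.Println(lastSeenForClient(&seen, nil))  // <nil>: unknown state
}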

View file

@ -353,7 +353,7 @@ type OIDCUserInfo struct {
// FromClaim overrides a User from OIDC claims.
// All fields will be updated, except for the ID.
func (u *User) FromClaim(claims *OIDCClaims) {
func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) {
err := util.ValidateUsername(claims.Username)
if err == nil {
u.Name = claims.Username
@ -361,7 +361,7 @@ func (u *User) FromClaim(claims *OIDCClaims) {
log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username)
}
if claims.EmailVerified {
if claims.EmailVerified || !FlexibleBoolean(emailVerifiedRequired) {
_, err = mail.ParseAddress(claims.Email)
if err == nil {
u.Email = claims.Email
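The verified-email gate above amounts to a small predicate: the claimed address is copied onto the user when the provider marks it verified, or when verification is not required by configuration, and it parses as a valid address. A standalone sketch of that logic, where shouldAcceptEmail is an illustrative helper rather than the function name used in the change:

package main

import (
	"fmt"
	"net/mail"
)

// shouldAcceptEmail reports whether an OIDC-claimed email may be used:
// it must parse as an address, and either be marked verified by the
// provider or verification must not be required.
func shouldAcceptEmail(email string, emailVerified, verifiedRequired bool) bool {
	if !emailVerified && verifiedRequired {
		return false
	}
	_, err := mail.ParseAddress(email)
	return err == nil
}

func main() {
	fmt.Println(shouldAcceptEmail("test4@test.no", false, false)) // true: verification not required
	fmt.Println(shouldAcceptEmail("test4@test.no", false, true))  // false: unverified but required
	fmt.Println(shouldAcceptEmail("not-an-email", true, true))    // false: does not parse
}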

View file

@ -291,12 +291,14 @@ func TestCleanIdentifier(t *testing.T) {
func TestOIDCClaimsJSONToUser(t *testing.T) {
tests := []struct {
name string
jsonstr string
want User
name string
jsonstr string
emailVerifiedRequired bool
want User
}{
{
name: "normal-bool",
name: "normal-bool",
emailVerifiedRequired: true,
jsonstr: `
{
"sub": "test",
@ -314,7 +316,8 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
},
},
{
name: "string-bool-true",
name: "string-bool-true",
emailVerifiedRequired: true,
jsonstr: `
{
"sub": "test2",
@ -332,7 +335,8 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
},
},
{
name: "string-bool-false",
name: "string-bool-false",
emailVerifiedRequired: true,
jsonstr: `
{
"sub": "test3",
@ -348,9 +352,29 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
},
},
},
{
name: "allow-unverified-email",
emailVerifiedRequired: false,
jsonstr: `
{
"sub": "test4",
"email": "test4@test.no",
"email_verified": "false"
}
`,
want: User{
Provider: util.RegisterMethodOIDC,
Email: "test4@test.no",
ProviderIdentifier: sql.NullString{
String: "/test4",
Valid: true,
},
},
},
{
// From https://github.com/juanfont/headscale/issues/2333
name: "okta-oidc-claim-20250121",
name: "okta-oidc-claim-20250121",
emailVerifiedRequired: true,
jsonstr: `
{
"sub": "00u7dr4qp7XXXXXXXXXX",
@ -375,6 +399,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
want: User{
Provider: util.RegisterMethodOIDC,
DisplayName: "Tim Horton",
Email: "",
Name: "tim.horton@company.com",
ProviderIdentifier: sql.NullString{
String: "https://sso.company.com/oauth2/default/00u7dr4qp7XXXXXXXXXX",
@ -384,7 +409,8 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
},
{
// From https://github.com/juanfont/headscale/issues/2333
name: "okta-oidc-claim-20250121",
name: "okta-oidc-claim-20250121",
emailVerifiedRequired: true,
jsonstr: `
{
"aud": "79xxxxxx-xxxx-xxxx-xxxx-892146xxxxxx",
@ -409,6 +435,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
Provider: util.RegisterMethodOIDC,
DisplayName: "XXXXXX XXXX",
Name: "user@domain.com",
Email: "",
ProviderIdentifier: sql.NullString{
String: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ",
Valid: true,
@ -417,7 +444,8 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
},
{
// From https://github.com/juanfont/headscale/issues/2333
name: "casby-oidc-claim-20250513",
name: "casby-oidc-claim-20250513",
emailVerifiedRequired: true,
jsonstr: `
{
"sub": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx",
@ -458,7 +486,7 @@ func TestOIDCClaimsJSONToUser(t *testing.T) {
var user User
user.FromClaim(&got)
user.FromClaim(&got, tt.emailVerifiedRequired)
if diff := cmp.Diff(user, tt.want); diff != "" {
t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff)
}

View file

@ -22,10 +22,7 @@ const (
LabelHostnameLength = 63
)
var (
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
var ErrInvalidHostName = errors.New("invalid hostname")
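The remaining regexp matches runs of characters outside lowercase letters, digits, dashes, and dots. A minimal sketch of the kind of sanitisation such a pattern enables; the sanitize helper is illustrative and not necessarily the exact transformation headscale applies:

package main

import (
	"fmt"
	"regexp"
	"strings"
)

var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")

// sanitize lowercases the input and collapses every run of characters
// that are invalid in a DNS name into a single dash.
func sanitize(hostname string) string {
	return invalidDNSRegex.ReplaceAllString(strings.ToLower(hostname), "-")
}

func main() {
	fmt.Println(sanitize("My Laptop (work)")) // my-laptop-work-
}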

File diff suppressed because it is too large

View file

@ -126,6 +126,7 @@ func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) {
// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38
// https://github.com/juanfont/headscale/issues/2164
if !https {
//nolint:forbidigo // Intentional delay: Tailscale client requires 5 min wait before reconnecting over non-HTTPS
time.Sleep(5 * time.Minute)
}
@ -427,6 +428,7 @@ func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) {
// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38
// https://github.com/juanfont/headscale/issues/2164
if !https {
//nolint:forbidigo // Intentional delay: Tailscale client requires 5 min wait before reconnecting over non-HTTPS
time.Sleep(5 * time.Minute)
}
@ -538,7 +540,12 @@ func TestAuthKeyDeleteKey(t *testing.T) {
err = client.Down()
require.NoError(t, err)
time.Sleep(3 * time.Second)
// Wait for client to fully stop before bringing it back up
assert.EventuallyWithT(t, func(c *assert.CollectT) {
status, err := client.Status()
assert.NoError(c, err)
assert.Equal(c, "Stopped", status.BackendState)
}, 10*time.Second, 200*time.Millisecond, "client should be stopped")
err = client.Up()
require.NoError(t, err)

View file

@ -901,7 +901,8 @@ func TestOIDCFollowUpUrl(t *testing.T) {
require.NoError(t, err)
// wait for the registration cache to expire
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION
// a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION (1m30s)
//nolint:forbidigo // Intentional delay: must wait for real-time cache expiration (HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION=1m30s)
time.Sleep(2 * time.Minute)
var newUrl *url.URL

View file

@ -833,602 +833,6 @@ func TestApiKeyCommand(t *testing.T) {
assert.Len(t, listedAPIKeysAfterDelete, 4)
}
func TestNodeTagCommand(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins"))
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Test 1: Verify that tags require authorization via ACL policy
// The tags-as-identity model allows conversion from user-owned to tagged, but only
// if the tag is authorized via tagOwners in the ACL policy.
regID := types.MustRegistrationID().String()
_, err = headscale.Execute(
[]string{
"headscale",
"debug",
"create-node",
"--name",
"user-owned-node",
"--user",
"user1",
"--key",
regID,
"--output",
"json",
},
)
assert.NoError(t, err)
var userOwnedNode v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"--user",
"user1",
"register",
"--key",
regID,
"--output",
"json",
},
&userOwnedNode,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Waiting for user-owned node registration")
// Verify node is user-owned (no tags)
assert.Empty(t, userOwnedNode.GetValidTags(), "User-owned node should not have tags")
assert.Empty(t, userOwnedNode.GetForcedTags(), "User-owned node should not have forced tags")
// Attempt to set tags on user-owned node should FAIL because there's no ACL policy
// authorizing the tag. The tags-as-identity model allows conversion from user-owned
// to tagged, but only if the tag is authorized via tagOwners in the ACL policy.
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
"tag",
"-i", strconv.FormatUint(userOwnedNode.GetId(), 10),
"-t", "tag:test",
"--output", "json",
},
)
require.ErrorContains(t, err, "invalid or unauthorized tags", "Setting unauthorized tags should fail")
// Test 2: Verify tag format validation
// Create a PreAuthKey with tags to create a tagged node
// Get the user ID from the node
userID := userOwnedNode.GetUser().GetId()
var preAuthKey v1.PreAuthKey
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"preauthkeys",
"--user", strconv.FormatUint(userID, 10),
"create",
"--reusable",
"--tags", "tag:integration-test",
"--output", "json",
},
&preAuthKey,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Creating PreAuthKey with tags")
// Verify PreAuthKey has tags
assert.Contains(t, preAuthKey.GetAclTags(), "tag:integration-test", "PreAuthKey should have tags")
// Test 3: Verify invalid tag format is rejected
_, err = headscale.Execute(
[]string{
"headscale",
"preauthkeys",
"--user", strconv.FormatUint(userID, 10),
"create",
"--tags", "wrong-tag", // Missing "tag:" prefix
"--output", "json",
},
)
assert.ErrorContains(t, err, "tag must start with the string 'tag:'", "Invalid tag format should be rejected")
}
func TestTaggedNodeRegistration(t *testing.T) {
IntegrationSkip(t)
// ACL policy that authorizes the tags used in tagged PreAuthKeys
// user1 and user2 can assign these tags when creating PreAuthKeys
policy := &policyv2.Policy{
TagOwners: policyv2.TagOwners{
"tag:server": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")},
"tag:prod": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")},
"tag:forbidden": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")},
},
ACLs: []policyv2.ACL{
{
Action: "accept",
Sources: []policyv2.Alias{policyv2.Wildcard},
Destinations: []policyv2.AliasWithPorts{{Alias: policyv2.Wildcard, Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}}},
},
},
}
spec := ScenarioSpec{
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithACLPolicy(policy),
hsic.WithTestName("tagged-reg"),
)
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Get users (they were already created by ScenarioSpec)
users, err := headscale.ListUsers()
require.NoError(t, err)
require.Len(t, users, 2, "Should have 2 users")
var user1, user2 *v1.User
for _, u := range users {
if u.GetName() == "user1" {
user1 = u
} else if u.GetName() == "user2" {
user2 = u
}
}
require.NotNil(t, user1, "Should find user1")
require.NotNil(t, user2, "Should find user2")
// Test 1: Create a PreAuthKey with tags
var taggedKey v1.PreAuthKey
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"preauthkeys",
"--user", strconv.FormatUint(user1.GetId(), 10),
"create",
"--reusable",
"--tags", "tag:server,tag:prod",
"--output", "json",
},
&taggedKey,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Creating tagged PreAuthKey")
// Verify PreAuthKey has both tags
assert.Contains(t, taggedKey.GetAclTags(), "tag:server", "PreAuthKey should have tag:server")
assert.Contains(t, taggedKey.GetAclTags(), "tag:prod", "PreAuthKey should have tag:prod")
assert.Len(t, taggedKey.GetAclTags(), 2, "PreAuthKey should have exactly 2 tags")
// Test 2: Register a node using the tagged PreAuthKey
err = scenario.CreateTailscaleNodesInUser("user1", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0]))
require.NoError(t, err)
err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey())
require.NoError(t, err)
// Wait for the node to be registered
var registeredNode *v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.GreaterOrEqual(c, len(nodes), 1, "Should have at least 1 node")
// Find the tagged node - it will have user "tagged-devices" per tags-as-identity model
for _, node := range nodes {
if node.GetUser().GetName() == "tagged-devices" && len(node.GetValidTags()) > 0 {
registeredNode = node
break
}
}
assert.NotNil(c, registeredNode, "Should find a tagged node")
}, 30*time.Second, 500*time.Millisecond, "Waiting for tagged node registration")
// Test 3: Verify the registered node has the tags from the PreAuthKey
assert.Contains(t, registeredNode.GetValidTags(), "tag:server", "Node should have tag:server")
assert.Contains(t, registeredNode.GetValidTags(), "tag:prod", "Node should have tag:prod")
assert.Len(t, registeredNode.GetValidTags(), 2, "Node should have exactly 2 tags")
// Test 4: Verify the node shows as TaggedDevices user (tags-as-identity model)
// Tagged nodes always show as "tagged-devices" in API responses, even though
// internally UserID may be set for "created by" tracking
assert.Equal(t, "tagged-devices", registeredNode.GetUser().GetName(), "Tagged node should show as tagged-devices user")
// Test 5: Verify the node is identified as tagged
assert.NotEmpty(t, registeredNode.GetValidTags(), "Tagged node should have tags")
// Test 6: Verify tag modification on tagged nodes
// NOTE: Changing tags requires complex ACL authorization where the node's IP
// must be authorized for the new tags via tagOwners. For simplicity, we skip
// this test and instead verify that tags cannot be arbitrarily changed without
// proper ACL authorization.
//
// This is expected behavior - tag changes must be authorized by ACL policy.
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
"tag",
"-i", strconv.FormatUint(registeredNode.GetId(), 10),
"-t", "tag:unauthorized",
"--output", "json",
},
)
// This SHOULD fail because tag:unauthorized is not in our ACL policy
require.ErrorContains(t, err, "invalid or unauthorized tags", "Unauthorized tag should be rejected")
// Test 7: Create a user-owned node for comparison
var userOwnedKey v1.PreAuthKey
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"preauthkeys",
"--user", strconv.FormatUint(user2.GetId(), 10),
"create",
"--reusable",
"--output", "json",
},
&userOwnedKey,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Creating user-owned PreAuthKey")
// Verify this PreAuthKey has NO tags
assert.Empty(t, userOwnedKey.GetAclTags(), "User-owned PreAuthKey should have no tags")
err = scenario.CreateTailscaleNodesInUser("user2", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0]))
require.NoError(t, err)
err = scenario.RunTailscaleUp("user2", headscale.GetEndpoint(), userOwnedKey.GetKey())
require.NoError(t, err)
// Wait for the user-owned node to be registered
var userOwnedNode *v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.GreaterOrEqual(c, len(nodes), 2, "Should have at least 2 nodes")
// Find the node registered with user2
for _, node := range nodes {
if node.GetUser().GetName() == "user2" {
userOwnedNode = node
break
}
}
assert.NotNil(c, userOwnedNode, "Should find a node for user2")
}, 30*time.Second, 500*time.Millisecond, "Waiting for user-owned node registration")
// Test 8: Verify user-owned node has NO tags
assert.Empty(t, userOwnedNode.GetValidTags(), "User-owned node should have no tags")
assert.NotZero(t, userOwnedNode.GetUser().GetId(), "User-owned node should have UserID")
// Test 9: Verify attempting to set UNAUTHORIZED tags on user-owned node fails
// Note: Under tags-as-identity model, user-owned nodes CAN be converted to tagged nodes
// if the tags are authorized. We use an unauthorized tag to test rejection.
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
"tag",
"-i", strconv.FormatUint(userOwnedNode.GetId(), 10),
"-t", "tag:not-in-policy",
"--output", "json",
},
)
require.ErrorContains(t, err, "invalid or unauthorized tags", "Setting unauthorized tags should fail")
// Test 10: Verify basic connectivity - wait for sync
err = scenario.WaitForTailscaleSync()
require.NoError(t, err, "Clients should be able to sync")
}
// TestTagPersistenceAcrossRestart validates that tags persist across container
// restarts and that re-authentication doesn't re-apply tags from PreAuthKey.
// This is a regression test for issue #2830.
func TestTagPersistenceAcrossRestart(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("tag-persist"))
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Get user
users, err := headscale.ListUsers()
require.NoError(t, err)
require.Len(t, users, 1)
user1 := users[0]
// Create a reusable PreAuthKey with tags
var taggedKey v1.PreAuthKey
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"preauthkeys",
"--user", strconv.FormatUint(user1.GetId(), 10),
"create",
"--reusable", // Critical: key must be reusable for container restart
"--tags", "tag:server,tag:prod",
"--output", "json",
},
&taggedKey,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Creating reusable tagged PreAuthKey")
require.True(t, taggedKey.GetReusable(), "PreAuthKey must be reusable for restart scenario")
require.Contains(t, taggedKey.GetAclTags(), "tag:server")
require.Contains(t, taggedKey.GetAclTags(), "tag:prod")
// Register initial node with tagged PreAuthKey
err = scenario.CreateTailscaleNodesInUser("user1", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0]))
require.NoError(t, err)
err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey())
require.NoError(t, err)
// Wait for node registration and get initial node state
var initialNode *v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.GreaterOrEqual(c, len(nodes), 1, "Should have at least 1 node")
for _, node := range nodes {
if node.GetUser().GetId() == user1.GetId() || node.GetUser().GetName() == "tagged-devices" {
initialNode = node
break
}
}
assert.NotNil(c, initialNode, "Should find the registered node")
}, 30*time.Second, 500*time.Millisecond, "Waiting for initial node registration")
// Verify initial tags
require.Contains(t, initialNode.GetValidTags(), "tag:server", "Initial node should have tag:server")
require.Contains(t, initialNode.GetValidTags(), "tag:prod", "Initial node should have tag:prod")
require.Len(t, initialNode.GetValidTags(), 2, "Initial node should have exactly 2 tags")
initialNodeID := initialNode.GetId()
t.Logf("Initial node registered with ID %d and tags %v", initialNodeID, initialNode.GetValidTags())
// Simulate container restart by shutting down and restarting Tailscale client
allClients, err := scenario.ListTailscaleClients()
require.NoError(t, err)
require.Len(t, allClients, 1, "Should have exactly 1 client")
client := allClients[0]
// Stop the client (simulates container stop)
err = client.Down()
require.NoError(t, err)
// Wait a bit to ensure the client is fully stopped
time.Sleep(2 * time.Second)
// Restart the client with the SAME PreAuthKey (container restart scenario)
// This simulates what happens when a Docker container restarts with a reusable PreAuthKey
err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey())
require.NoError(t, err)
// Wait for re-authentication
var nodeAfterRestart *v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
for _, node := range nodes {
if node.GetId() == initialNodeID {
nodeAfterRestart = node
break
}
}
assert.NotNil(c, nodeAfterRestart, "Should find the same node after restart")
}, 30*time.Second, 500*time.Millisecond, "Waiting for node re-authentication")
// CRITICAL ASSERTION: Tags should NOT be re-applied from PreAuthKey
// Tags are only applied during INITIAL authentication, not re-authentication
// The node should keep its existing tags (which happen to be the same in this case)
assert.Contains(t, nodeAfterRestart.GetValidTags(), "tag:server", "Node should still have tag:server after restart")
assert.Contains(t, nodeAfterRestart.GetValidTags(), "tag:prod", "Node should still have tag:prod after restart")
assert.Len(t, nodeAfterRestart.GetValidTags(), 2, "Node should still have exactly 2 tags after restart")
// Verify it's the SAME node (same ID), not a new registration
assert.Equal(t, initialNodeID, nodeAfterRestart.GetId(), "Should be the same node, not a new registration")
// Verify node count hasn't increased (no duplicate nodes)
finalNodes, err := headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, finalNodes, 1, "Should still have exactly 1 node (no duplicates from restart)")
t.Logf("Container restart validation complete - node %d maintained tags across restart", initialNodeID)
}
func TestNodeAdvertiseTagCommand(t *testing.T) {
IntegrationSkip(t)
tests := []struct {
name string
policy *policyv2.Policy
wantTag bool
}{
{
name: "no-policy",
wantTag: false,
},
{
name: "with-policy-email",
policy: &policyv2.Policy{
ACLs: []policyv2.ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: []policyv2.Alias{wildcard()},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
},
},
},
TagOwners: policyv2.TagOwners{
policyv2.Tag("tag:test"): policyv2.Owners{usernameOwner("user1@test.no")},
},
},
wantTag: true,
},
{
name: "with-policy-username",
policy: &policyv2.Policy{
ACLs: []policyv2.ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: []policyv2.Alias{wildcard()},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
},
},
},
TagOwners: policyv2.TagOwners{
policyv2.Tag("tag:test"): policyv2.Owners{usernameOwner("user1@")},
},
},
wantTag: true,
},
{
name: "with-policy-groups",
policy: &policyv2.Policy{
Groups: policyv2.Groups{
policyv2.Group("group:admins"): []policyv2.Username{policyv2.Username("user1@")},
},
ACLs: []policyv2.ACL{
{
Action: "accept",
Protocol: "tcp",
Sources: []policyv2.Alias{wildcard()},
Destinations: []policyv2.AliasWithPorts{
aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
},
},
},
TagOwners: policyv2.TagOwners{
policyv2.Tag("tag:test"): policyv2.Owners{groupOwner("group:admins")},
},
},
wantTag: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
spec := ScenarioSpec{
NodesPerUser: 1,
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{tsic.WithTags([]string{"tag:test"})},
hsic.WithTestName("cliadvtags"),
hsic.WithACLPolicy(tt.policy),
)
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Test list all nodes after added seconds
var resultMachines []*v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
resultMachines = make([]*v1.Node, spec.NodesPerUser)
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--tags",
"--output", "json",
},
&resultMachines,
)
assert.NoError(c, err)
found := false
for _, node := range resultMachines {
if tags := node.GetValidTags(); tags != nil {
found = slices.Contains(tags, "tag:test")
}
}
assert.Equalf(
c,
tt.wantTag,
found,
"'tag:test' found(%t) is the list of nodes, expected %t", found, tt.wantTag,
)
}, 10*time.Second, 200*time.Millisecond, "Waiting for tag propagation to nodes")
})
}
}
func TestNodeCommand(t *testing.T) {
IntegrationSkip(t)

View file

@ -24,6 +24,7 @@ type ControlServer interface {
WaitForRunning() error
CreateUser(user string) (*v1.User, error)
CreateAuthKey(user uint64, reusable bool, ephemeral bool) (*v1.PreAuthKey, error)
CreateAuthKeyWithTags(user uint64, reusable bool, ephemeral bool, tags []string) (*v1.PreAuthKey, error)
DeleteAuthKey(user uint64, key string) error
ListNodes(users ...string) ([]*v1.Node, error)
DeleteNode(nodeID uint64) error
@ -32,6 +33,7 @@ type ControlServer interface {
ListUsers() ([]*v1.User, error)
MapUsers() (map[string]*v1.User, error)
ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error)
SetNodeTags(nodeID uint64, tags []string) error
GetCert() []byte
GetHostname() string
GetIPInNetwork(network *dockertest.Network) string

View file

@ -23,6 +23,7 @@ func TestResolveMagicDNS(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -79,6 +80,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
@ -94,11 +96,7 @@ func TestResolveMagicDNSExtraRecordsPath(t *testing.T) {
b, _ := json.Marshal(extraRecords)
err = scenario.CreateHeadscaleEnv([]tsic.Option{
tsic.WithDockerEntrypoint([]string{
"/bin/sh",
"-c",
"/bin/sleep 3 ; apk add python3 curl bind-tools ; update-ca-certificates ; tailscaled --tun=tsdev",
}),
tsic.WithPackages("python3", "curl", "bind-tools"),
},
hsic.WithTestName("extrarecords"),
hsic.WithConfigEnv(map[string]string{

View file

@ -103,6 +103,38 @@ func WithExtraHosts(hosts []string) Option {
}
}
// buildEntrypoint builds the container entrypoint command based on configuration.
// It constructs proper wait conditions instead of fixed sleeps:
// 1. Wait for network to be ready
// 2. Wait for TLS cert to be written (always written after container start)
// 3. Wait for CA certs if configured
// 4. Update CA certificates
// 5. Run derper with provided arguments.
func (dsic *DERPServerInContainer) buildEntrypoint(derperArgs string) []string {
var commands []string
// Wait for network to be ready
commands = append(commands, "while ! ip route show default >/dev/null 2>&1; do sleep 0.1; done")
// Wait for TLS cert to be written (always written after container start)
commands = append(commands,
fmt.Sprintf("while [ ! -f %s/%s.crt ]; do sleep 0.1; done", DERPerCertRoot, dsic.hostname))
// If CA certs are configured, wait for them to be written
if len(dsic.caCerts) > 0 {
commands = append(commands,
fmt.Sprintf("while [ ! -f %s/user-0.crt ]; do sleep 0.1; done", caCertRoot))
}
// Update CA certificates
commands = append(commands, "update-ca-certificates")
// Run derper
commands = append(commands, "derper "+derperArgs)
return []string{"/bin/sh", "-c", strings.Join(commands, " ; ")}
}
// New returns a new DERPServerInContainer instance.
func New(
pool *dockertest.Pool,
@ -150,8 +182,7 @@ func New(
Name: hostname,
Networks: dsic.networks,
ExtraHosts: dsic.withExtraHosts,
// we currently need to give us some time to inject the certificate further down.
Entrypoint: []string{"/bin/sh", "-c", "/bin/sleep 3 ; update-ca-certificates ; derper " + cmdArgs.String()},
Entrypoint: dsic.buildEntrypoint(cmdArgs.String()),
ExposedPorts: []string{
"80/tcp",
fmt.Sprintf("%d/tcp", dsic.derpPort),

View file

@ -178,7 +178,8 @@ func derpServerScenario(
t.Logf("Run 1: %d successful pings out of %d", success, len(allClients)*len(allHostnames))
// Let the DERP updater run a couple of times to ensure it does not
// break the DERPMap.
// break the DERPMap. The updater runs on a 10s interval by default.
//nolint:forbidigo // Intentional delay: must wait for DERP updater to run multiple times (interval-based)
time.Sleep(30 * time.Second)
success = pingDerpAllHelper(t, allClients, allHostnames)

View file

@ -14,6 +14,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/juanfont/headscale/integration/hsic"
"github.com/juanfont/headscale/integration/integrationutil"
"github.com/juanfont/headscale/integration/tsic"
"github.com/rs/zerolog/log"
"github.com/samber/lo"
@ -366,12 +367,18 @@ func TestPingAllByHostname(t *testing.T) {
// This might mean we approach setup slightly wrong, but for now, ignore
// the linter
// nolint:tparallel
// TestTaildrop tests the Taildrop file sharing functionality across multiple scenarios:
// 1. Same-user transfers: Nodes owned by the same user can send files to each other
// 2. Cross-user transfers: Nodes owned by different users cannot send files to each other
// 3. Tagged device transfers: Tagged devices can neither send nor receive files
//
// Each user gets len(MustTestVersions) nodes to ensure compatibility across all supported versions.
func TestTaildrop(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
NodesPerUser: len(MustTestVersions),
Users: []string{"user1"},
NodesPerUser: 0, // We'll create nodes manually to control tags
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
@ -385,16 +392,99 @@ func TestTaildrop(t *testing.T) {
)
requireNoErrHeadscaleEnv(t, err)
headscale, err := scenario.Headscale()
requireNoErrGetHeadscale(t, err)
userMap, err := headscale.MapUsers()
require.NoError(t, err)
networks := scenario.Networks()
require.NotEmpty(t, networks, "scenario should have at least one network")
network := networks[0]
// Create untagged nodes for user1 using all test versions
user1Key, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), true, false)
require.NoError(t, err)
var user1Clients []TailscaleClient
for i, version := range MustTestVersions {
t.Logf("Creating user1 client %d with version %s", i, version)
client, err := scenario.CreateTailscaleNode(
version,
tsic.WithNetwork(network),
)
require.NoError(t, err)
err = client.Login(headscale.GetEndpoint(), user1Key.GetKey())
require.NoError(t, err)
err = client.WaitForRunning(integrationutil.PeerSyncTimeout())
require.NoError(t, err)
user1Clients = append(user1Clients, client)
scenario.GetOrCreateUser("user1").Clients[client.Hostname()] = client
}
// Create untagged nodes for user2 using all test versions
user2Key, err := scenario.CreatePreAuthKey(userMap["user2"].GetId(), true, false)
require.NoError(t, err)
var user2Clients []TailscaleClient
for i, version := range MustTestVersions {
t.Logf("Creating user2 client %d with version %s", i, version)
client, err := scenario.CreateTailscaleNode(
version,
tsic.WithNetwork(network),
)
require.NoError(t, err)
err = client.Login(headscale.GetEndpoint(), user2Key.GetKey())
require.NoError(t, err)
err = client.WaitForRunning(integrationutil.PeerSyncTimeout())
require.NoError(t, err)
user2Clients = append(user2Clients, client)
scenario.GetOrCreateUser("user2").Clients[client.Hostname()] = client
}
// Create a tagged device (tags-as-identity: tags come from PreAuthKey)
// Use "head" version to test latest behavior
taggedKey, err := scenario.CreatePreAuthKeyWithTags(userMap["user1"].GetId(), true, false, []string{"tag:server"})
require.NoError(t, err)
taggedClient, err := scenario.CreateTailscaleNode(
"head",
tsic.WithNetwork(network),
)
require.NoError(t, err)
err = taggedClient.Login(headscale.GetEndpoint(), taggedKey.GetKey())
require.NoError(t, err)
err = taggedClient.WaitForRunning(integrationutil.PeerSyncTimeout())
require.NoError(t, err)
// Add tagged client to user1 for tracking (though it's tagged, not user-owned)
scenario.GetOrCreateUser("user1").Clients[taggedClient.Hostname()] = taggedClient
allClients, err := scenario.ListTailscaleClients()
requireNoErrListClients(t, err)
// Expected: len(MustTestVersions) for user1 + len(MustTestVersions) for user2 + 1 tagged
expectedClientCount := len(MustTestVersions)*2 + 1
require.Len(t, allClients, expectedClientCount,
"should have %d clients: %d user1 + %d user2 + 1 tagged",
expectedClientCount, len(MustTestVersions), len(MustTestVersions))
err = scenario.WaitForTailscaleSync()
requireNoErrSync(t, err)
// This will essentially fetch and cache all the FQDNs
// Cache FQDNs
_, err = scenario.ListTailscaleClientsFQDNs()
requireNoErrListFQDN(t, err)
// Install curl on all clients
for _, client := range allClients {
if !strings.Contains(client.Hostname(), "head") {
command := []string{"apk", "add", "curl"}
@ -403,110 +493,269 @@ func TestTaildrop(t *testing.T) {
t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err)
}
}
}
// Helper to get FileTargets for a client.
getFileTargets := func(client TailscaleClient) ([]apitype.FileTarget, error) {
curlCommand := []string{
"curl",
"--unix-socket",
"/var/run/tailscale/tailscaled.sock",
"http://local-tailscaled.sock/localapi/v0/file-targets",
}
result, _, err := client.Execute(curlCommand)
if err != nil {
return nil, err
}
var fts []apitype.FileTarget
if err := json.Unmarshal([]byte(result), &fts); err != nil {
return nil, fmt.Errorf("failed to parse file-targets response: %w (response: %s)", err, result)
}
return fts, nil
}
// Helper to check if a client is in the FileTargets list
isInFileTargets := func(fts []apitype.FileTarget, targetHostname string) bool {
for _, ft := range fts {
if strings.Contains(ft.Node.Name, targetHostname) {
return true
}
}
return false
}
// Test 1: Verify user1 nodes can see each other in FileTargets but not user2 nodes or tagged node
t.Run("FileTargets-user1", func(t *testing.T) {
for _, client := range user1Clients {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
fts, err := getFileTargets(client)
assert.NoError(ct, err)
// Should see the other user1 clients
for _, peer := range user1Clients {
if peer.Hostname() == client.Hostname() {
continue
}
assert.True(ct, isInFileTargets(fts, peer.Hostname()),
"user1 client %s should see user1 peer %s in FileTargets", client.Hostname(), peer.Hostname())
}
// Should NOT see user2 clients
for _, peer := range user2Clients {
assert.False(ct, isInFileTargets(fts, peer.Hostname()),
"user1 client %s should NOT see user2 peer %s in FileTargets", client.Hostname(), peer.Hostname())
}
// Should NOT see tagged client
assert.False(ct, isInFileTargets(fts, taggedClient.Hostname()),
"user1 client %s should NOT see tagged client %s in FileTargets", client.Hostname(), taggedClient.Hostname())
}, 10*time.Second, 1*time.Second)
}
})
// Test 2: Verify user2 nodes can see each other in FileTargets but not user1 nodes or tagged node
t.Run("FileTargets-user2", func(t *testing.T) {
for _, client := range user2Clients {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
fts, err := getFileTargets(client)
assert.NoError(ct, err)
// Should see the other user2 clients
for _, peer := range user2Clients {
if peer.Hostname() == client.Hostname() {
continue
}
assert.True(ct, isInFileTargets(fts, peer.Hostname()),
"user2 client %s should see user2 peer %s in FileTargets", client.Hostname(), peer.Hostname())
}
// Should NOT see user1 clients
for _, peer := range user1Clients {
assert.False(ct, isInFileTargets(fts, peer.Hostname()),
"user2 client %s should NOT see user1 peer %s in FileTargets", client.Hostname(), peer.Hostname())
}
// Should NOT see tagged client
assert.False(ct, isInFileTargets(fts, taggedClient.Hostname()),
"user2 client %s should NOT see tagged client %s in FileTargets", client.Hostname(), taggedClient.Hostname())
}, 10*time.Second, 1*time.Second)
}
})
// Test 3: Verify tagged device has no FileTargets (empty list)
t.Run("FileTargets-tagged", func(t *testing.T) {
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
result, _, err := client.Execute(curlCommand)
fts, err := getFileTargets(taggedClient)
assert.NoError(ct, err)
var fts []apitype.FileTarget
err = json.Unmarshal([]byte(result), &fts)
assert.NoError(ct, err)
if len(fts) != len(allClients)-1 {
ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname())
for _, ft := range fts {
ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name)
}
assert.Failf(ct, "client %s does not have all its peers as FileTargets",
"got %d, want: %d\n%s",
len(fts),
len(allClients)-1,
ftStr,
)
}
assert.Empty(ct, fts, "tagged client %s should have no FileTargets", taggedClient.Hostname())
}, 10*time.Second, 1*time.Second)
}
})
for _, client := range allClients {
command := []string{"touch", fmt.Sprintf("/tmp/file_from_%s", client.Hostname())}
// Test 4: Same-user file transfer works (user1 -> user1) for all version combinations
t.Run("SameUserTransfer", func(t *testing.T) {
for _, sender := range user1Clients {
// Create file on sender
filename := fmt.Sprintf("file_from_%s", sender.Hostname())
command := []string{"touch", fmt.Sprintf("/tmp/%s", filename)}
_, _, err := sender.Execute(command)
require.NoError(t, err, "failed to create taildrop file on %s", sender.Hostname())
if _, _, err := client.Execute(command); err != nil {
t.Fatalf("failed to create taildrop file on %s, err: %s", client.Hostname(), err)
}
for _, receiver := range user1Clients {
if sender.Hostname() == receiver.Hostname() {
continue
}
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
receiverFQDN, _ := receiver.FQDN()
t.Run(fmt.Sprintf("%s->%s", sender.Hostname(), receiver.Hostname()), func(t *testing.T) {
sendCommand := []string{
"tailscale", "file", "cp",
fmt.Sprintf("/tmp/%s", filename),
fmt.Sprintf("%s:", receiverFQDN),
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
t.Logf("Sending file from %s to %s", sender.Hostname(), receiver.Hostname())
_, _, err := sender.Execute(sendCommand)
assert.NoError(ct, err)
}, 10*time.Second, 1*time.Second)
})
}
}
// It is safe to ignore this error as we handled it when caching it
peerFQDN, _ := peer.FQDN()
// Receive files on all user1 clients
for _, client := range user1Clients {
getCommand := []string{"tailscale", "file", "get", "/tmp/"}
_, _, err := client.Execute(getCommand)
require.NoError(t, err, "failed to get taildrop file on %s", client.Hostname())
t.Run(fmt.Sprintf("%s-%s", client.Hostname(), peer.Hostname()), func(t *testing.T) {
command := []string{
"tailscale", "file", "cp",
fmt.Sprintf("/tmp/file_from_%s", client.Hostname()),
fmt.Sprintf("%s:", peerFQDN),
// Verify files from all other user1 clients exist
for _, peer := range user1Clients {
if client.Hostname() == peer.Hostname() {
continue
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
t.Logf(
"Sending file from %s to %s\n",
client.Hostname(),
peer.Hostname(),
)
_, _, err := client.Execute(command)
assert.NoError(ct, err)
}, 10*time.Second, 1*time.Second)
})
}
}
for _, client := range allClients {
command := []string{
"tailscale", "file",
"get",
"/tmp/",
}
if _, _, err := client.Execute(command); err != nil {
t.Fatalf("failed to get taildrop file on %s, err: %s", client.Hostname(), err)
}
for _, peer := range allClients {
if client.Hostname() == peer.Hostname() {
continue
t.Run(fmt.Sprintf("verify-%s-received-from-%s", client.Hostname(), peer.Hostname()), func(t *testing.T) {
lsCommand := []string{"ls", fmt.Sprintf("/tmp/file_from_%s", peer.Hostname())}
result, _, err := client.Execute(lsCommand)
require.NoErrorf(t, err, "failed to ls taildrop file from %s", peer.Hostname())
assert.Equal(t, fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()), result)
})
}
t.Run(fmt.Sprintf("%s-%s", client.Hostname(), peer.Hostname()), func(t *testing.T) {
command := []string{
"ls",
fmt.Sprintf("/tmp/file_from_%s", peer.Hostname()),
}
log.Printf(
"Checking file in %s from %s\n",
client.Hostname(),
peer.Hostname(),
)
result, _, err := client.Execute(command)
require.NoErrorf(t, err, "failed to execute command to ls taildrop")
log.Printf("Result for %s: %s\n", peer.Hostname(), result)
if fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()) != result {
t.Fatalf(
"taildrop result is not correct %s, wanted %s",
result,
fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()),
)
}
})
}
}
})
// Test 5: Cross-user file transfer fails (user1 -> user2)
t.Run("CrossUserTransferBlocked", func(t *testing.T) {
sender := user1Clients[0]
receiver := user2Clients[0]
// Create file on sender
filename := fmt.Sprintf("cross_user_file_from_%s", sender.Hostname())
command := []string{"touch", fmt.Sprintf("/tmp/%s", filename)}
_, _, err := sender.Execute(command)
require.NoError(t, err, "failed to create taildrop file on %s", sender.Hostname())
// Attempt to send file - this should fail
receiverFQDN, _ := receiver.FQDN()
sendCommand := []string{
"tailscale", "file", "cp",
fmt.Sprintf("/tmp/%s", filename),
fmt.Sprintf("%s:", receiverFQDN),
}
t.Logf("Attempting cross-user file send from %s to %s (should fail)", sender.Hostname(), receiver.Hostname())
_, stderr, err := sender.Execute(sendCommand)
// The file transfer should fail because user2 is not in user1's FileTargets
// Either the command errors, or it silently fails (check stderr for error message)
if err != nil {
t.Logf("Cross-user transfer correctly failed with error: %v", err)
} else if strings.Contains(stderr, "not a valid peer") || strings.Contains(stderr, "unknown target") {
t.Logf("Cross-user transfer correctly rejected: %s", stderr)
} else {
// Even if command succeeded, verify the file was NOT received
getCommand := []string{"tailscale", "file", "get", "/tmp/"}
receiver.Execute(getCommand)
lsCommand := []string{"ls", fmt.Sprintf("/tmp/%s", filename)}
_, _, lsErr := receiver.Execute(lsCommand)
assert.Error(t, lsErr, "Cross-user file should NOT have been received")
}
})
// Test 6: Tagged device cannot send files
t.Run("TaggedCannotSend", func(t *testing.T) {
// Create file on tagged client
filename := fmt.Sprintf("file_from_tagged_%s", taggedClient.Hostname())
command := []string{"touch", fmt.Sprintf("/tmp/%s", filename)}
_, _, err := taggedClient.Execute(command)
require.NoError(t, err, "failed to create taildrop file on tagged client")
// Attempt to send to user1 client - should fail because tagged client has no FileTargets
receiver := user1Clients[0]
receiverFQDN, _ := receiver.FQDN()
sendCommand := []string{
"tailscale", "file", "cp",
fmt.Sprintf("/tmp/%s", filename),
fmt.Sprintf("%s:", receiverFQDN),
}
t.Logf("Attempting tagged->user file send from %s to %s (should fail)", taggedClient.Hostname(), receiver.Hostname())
_, stderr, err := taggedClient.Execute(sendCommand)
if err != nil {
t.Logf("Tagged client send correctly failed with error: %v", err)
} else if strings.Contains(stderr, "not a valid peer") || strings.Contains(stderr, "unknown target") || strings.Contains(stderr, "no matches for") {
t.Logf("Tagged client send correctly rejected: %s", stderr)
} else {
// Verify file was NOT received
getCommand := []string{"tailscale", "file", "get", "/tmp/"}
receiver.Execute(getCommand)
lsCommand := []string{"ls", fmt.Sprintf("/tmp/%s", filename)}
_, _, lsErr := receiver.Execute(lsCommand)
assert.Error(t, lsErr, "Tagged client's file should NOT have been received")
}
})
// Test 7: Tagged device cannot receive files (user1 tries to send to tagged)
t.Run("TaggedCannotReceive", func(t *testing.T) {
sender := user1Clients[0]
// Create file on sender
filename := fmt.Sprintf("file_to_tagged_from_%s", sender.Hostname())
command := []string{"touch", fmt.Sprintf("/tmp/%s", filename)}
_, _, err := sender.Execute(command)
require.NoError(t, err, "failed to create taildrop file on %s", sender.Hostname())
// Attempt to send to tagged client - should fail because tagged is not in user1's FileTargets
taggedFQDN, _ := taggedClient.FQDN()
sendCommand := []string{
"tailscale", "file", "cp",
fmt.Sprintf("/tmp/%s", filename),
fmt.Sprintf("%s:", taggedFQDN),
}
t.Logf("Attempting user->tagged file send from %s to %s (should fail)", sender.Hostname(), taggedClient.Hostname())
_, stderr, err := sender.Execute(sendCommand)
if err != nil {
t.Logf("Send to tagged client correctly failed with error: %v", err)
} else if strings.Contains(stderr, "not a valid peer") || strings.Contains(stderr, "unknown target") || strings.Contains(stderr, "no matches for") {
t.Logf("Send to tagged client correctly rejected: %s", stderr)
} else {
// Verify file was NOT received by tagged client
getCommand := []string{"tailscale", "file", "get", "/tmp/"}
taggedClient.Execute(getCommand)
lsCommand := []string{"ls", fmt.Sprintf("/tmp/%s", filename)}
_, _, lsErr := taggedClient.Execute(lsCommand)
assert.Error(t, lsErr, "File to tagged client should NOT have been received")
}
})
}
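For reference, the FileTargets check in the test shells out to curl against tailscaled's unix socket. The same LocalAPI query can be made from Go with an http.Client whose transport dials the socket directly; this is a sketch that mirrors the raw HTTP call the test performs, using the socket path and endpoint from the test, not a drop-in test helper:

package main

import (
	"context"
	"encoding/json"
	"fmt"
	"net"
	"net/http"
	"time"

	"tailscale.com/client/tailscale/apitype"
)

// fileTargets queries tailscaled's LocalAPI over its unix socket, the
// same endpoint the curl helper above hits, and decodes the response
// into apitype.FileTarget values.
func fileTargets(socketPath string) ([]apitype.FileTarget, error) {
	client := &http.Client{
		Transport: &http.Transport{
			DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
				var d net.Dialer
				return d.DialContext(ctx, "unix", socketPath)
			},
		},
		Timeout: 5 * time.Second,
	}

	resp, err := client.Get("http://local-tailscaled.sock/localapi/v0/file-targets")
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()

	var fts []apitype.FileTarget
	if err := json.NewDecoder(resp.Body).Decode(&fts); err != nil {
		return nil, err
	}
	return fts, nil
}

func main() {
	fts, err := fileTargets("/var/run/tailscale/tailscaled.sock")
	if err != nil {
		fmt.Println("localapi query failed:", err)
		return
	}
	for _, ft := range fts {
		fmt.Println(ft.Node.Name)
	}
}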
func TestUpdateHostnameFromClient(t *testing.T) {

View file

@ -33,7 +33,6 @@ import (
"github.com/ory/dockertest/v3"
"github.com/ory/dockertest/v3/docker"
"gopkg.in/yaml.v3"
"tailscale.com/envknob"
"tailscale.com/tailcfg"
"tailscale.com/util/mak"
)
@ -49,7 +48,12 @@ const (
IntegrationTestDockerFileName = "Dockerfile.integration"
)
var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok")
var (
errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok")
errInvalidHeadscaleImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_HEADSCALE_IMAGE format, expected repository:tag")
errHeadscaleImageRequiredInCI = errors.New("HEADSCALE_INTEGRATION_HEADSCALE_IMAGE must be set in CI")
errInvalidPostgresImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_POSTGRES_IMAGE format, expected repository:tag")
)
type fileInContainer struct {
path string
@ -70,7 +74,6 @@ type HeadscaleInContainer struct {
// optional config
port int
extraPorts []string
debugPort int
caCerts [][]byte
hostPortBindings map[string][]string
aclPolicy *policyv2.Policy
@ -281,26 +284,39 @@ func WithDERPAsIP() Option {
}
}
// WithDebugPort sets the debug port for delve debugging.
func WithDebugPort(port int) Option {
return func(hsic *HeadscaleInContainer) {
hsic.debugPort = port
}
}
// buildEntrypoint builds the container entrypoint command based on configuration.
// It constructs proper wait conditions instead of fixed sleeps:
// 1. Wait for network to be ready
// 2. Wait for config.yaml (always written after container start)
// 3. Wait for CA certs if configured
// 4. Update CA certificates
// 5. Run headscale serve
// 6. Sleep at end to keep container alive for log collection on shutdown.
func (hsic *HeadscaleInContainer) buildEntrypoint() []string {
debugCmd := fmt.Sprintf(
"/go/bin/dlv --listen=0.0.0.0:%d --headless=true --api-version=2 --accept-multiclient --allow-non-terminal-interactive=true exec /go/bin/headscale --continue -- serve",
hsic.debugPort,
)
var commands []string
entrypoint := fmt.Sprintf(
"/bin/sleep 3 ; update-ca-certificates ; %s ; /bin/sleep 30",
debugCmd,
)
// Wait for network to be ready
commands = append(commands, "while ! ip route show default >/dev/null 2>&1; do sleep 0.1; done")
return []string{"/bin/bash", "-c", entrypoint}
// Wait for config.yaml to be written (always written after container start)
commands = append(commands, "while [ ! -f /etc/headscale/config.yaml ]; do sleep 0.1; done")
// If CA certs are configured, wait for them to be written
if len(hsic.caCerts) > 0 {
commands = append(commands,
fmt.Sprintf("while [ ! -f %s/user-0.crt ]; do sleep 0.1; done", caCertRoot))
}
// Update CA certificates
commands = append(commands, "update-ca-certificates")
// Run headscale serve
commands = append(commands, "/usr/local/bin/headscale serve")
// Keep container alive after headscale exits for log collection
commands = append(commands, "/bin/sleep 30")
return []string{"/bin/bash", "-c", strings.Join(commands, " ; ")}
}
// New returns a new HeadscaleInContainer instance.
@ -316,18 +332,9 @@ func New(
hostname := "hs-" + hash
// Get debug port from environment or use default
debugPort := 40000
if envDebugPort := envknob.String("HEADSCALE_DEBUG_PORT"); envDebugPort != "" {
if port, err := strconv.Atoi(envDebugPort); err == nil {
debugPort = port
}
}
hsic := &HeadscaleInContainer{
hostname: hostname,
port: headscaleDefaultPort,
debugPort: debugPort,
hostname: hostname,
port: headscaleDefaultPort,
pool: pool,
networks: networks,
@ -344,7 +351,6 @@ func New(
log.Println("NAME: ", hsic.hostname)
portProto := fmt.Sprintf("%d/tcp", hsic.port)
debugPortProto := fmt.Sprintf("%d/tcp", hsic.debugPort)
headscaleBuildOptions := &dockertest.BuildOptions{
Dockerfile: IntegrationTestDockerFileName,
@ -359,10 +365,24 @@ func New(
hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale"
delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH")
// Determine postgres image - use prebuilt if available, otherwise pull from registry
pgRepo := "postgres"
pgTag := "latest"
if prebuiltImage := os.Getenv("HEADSCALE_INTEGRATION_POSTGRES_IMAGE"); prebuiltImage != "" {
repo, tag, found := strings.Cut(prebuiltImage, ":")
if !found {
return nil, errInvalidPostgresImageFormat
}
pgRepo = repo
pgTag = tag
}
pgRunOptions := &dockertest.RunOptions{
Name: "postgres-" + hash,
Repository: "postgres",
Tag: "latest",
Repository: pgRepo,
Tag: pgTag,
Networks: networks,
Env: []string{
"POSTGRES_USER=headscale",
@ -409,7 +429,7 @@ func New(
runOptions := &dockertest.RunOptions{
Name: hsic.hostname,
ExposedPorts: append([]string{portProto, debugPortProto, "9090/tcp"}, hsic.extraPorts...),
ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...),
Networks: networks,
// Cmd: []string{"headscale", "serve"},
// TODO(kradalby): Get rid of this hack, we currently need to give us some
@ -418,13 +438,11 @@ func New(
Env: env,
}
// Always bind debug port and metrics port to predictable host ports
// Bind metrics port to predictable host port
if runOptions.PortBindings == nil {
runOptions.PortBindings = map[docker.Port][]docker.PortBinding{}
}
runOptions.PortBindings[docker.Port(debugPortProto)] = []docker.PortBinding{
{HostPort: strconv.Itoa(hsic.debugPort)},
}
runOptions.PortBindings["9090/tcp"] = []docker.PortBinding{
{HostPort: "49090"},
}
@ -451,52 +469,80 @@ func New(
// Add integration test labels if running under hi tool
dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale")
container, err := pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
runOptions,
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
)
if err != nil {
// Try to get more detailed build output
log.Printf("Docker build failed, attempting to get detailed output...")
var container *dockertest.Resource
buildOutput, buildErr := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, IntegrationTestDockerFileName)
// Check if a pre-built image is available via environment variable
prebuiltImage := os.Getenv("HEADSCALE_INTEGRATION_HEADSCALE_IMAGE")
// Show the last 100 lines of build output to avoid overwhelming the logs
lines := strings.Split(buildOutput, "\n")
const maxLines = 100
startLine := 0
if len(lines) > maxLines {
startLine = len(lines) - maxLines
if prebuiltImage != "" {
log.Printf("Using pre-built headscale image: %s", prebuiltImage)
// Parse image into repository and tag
repo, tag, ok := strings.Cut(prebuiltImage, ":")
if !ok {
return nil, errInvalidHeadscaleImageFormat
}
relevantOutput := strings.Join(lines[startLine:], "\n")
runOptions.Repository = repo
runOptions.Tag = tag
if buildErr != nil {
// The diagnostic build also failed - this is the real error
return nil, fmt.Errorf("could not start headscale container: %w\n\nDocker build failed. Last %d lines of output:\n%s", err, maxLines, relevantOutput)
container, err = pool.RunWithOptions(
runOptions,
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
)
if err != nil {
return nil, fmt.Errorf("could not run pre-built headscale container %q: %w", prebuiltImage, err)
}
} else if util.IsCI() {
return nil, errHeadscaleImageRequiredInCI
} else {
container, err = pool.BuildAndRunWithBuildOptions(
headscaleBuildOptions,
runOptions,
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
)
if err != nil {
// Try to get more detailed build output
log.Printf("Docker build/run failed, attempting to get detailed output...")
if buildOutput != "" {
// Build succeeded on retry but container creation still failed
return nil, fmt.Errorf("could not start headscale container: %w\n\nDocker build succeeded on retry, but container creation failed. Last %d lines of build output:\n%s", err, maxLines, relevantOutput)
buildOutput, buildErr := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, IntegrationTestDockerFileName)
// Show the last 100 lines of build output to avoid overwhelming the logs
lines := strings.Split(buildOutput, "\n")
const maxLines = 100
startLine := 0
if len(lines) > maxLines {
startLine = len(lines) - maxLines
}
relevantOutput := strings.Join(lines[startLine:], "\n")
if buildErr != nil {
// The diagnostic build also failed - this is the real error
return nil, fmt.Errorf("could not start headscale container: %w\n\nDocker build failed. Last %d lines of output:\n%s", err, maxLines, relevantOutput)
}
if buildOutput != "" {
// Build succeeded on retry but container creation still failed
return nil, fmt.Errorf("could not start headscale container: %w\n\nDocker build succeeded on retry, but container creation failed. Last %d lines of build output:\n%s", err, maxLines, relevantOutput)
}
// No output at all - diagnostic build command may have failed
return nil, fmt.Errorf("could not start headscale container: %w\n\nUnable to get diagnostic build output (command may have failed silently)", err)
}
// No output at all - diagnostic build command may have failed
return nil, fmt.Errorf("could not start headscale container: %w\n\nUnable to get diagnostic build output (command may have failed silently)", err)
}
log.Printf("Created %s container\n", hsic.hostname)
hsic.container = container
log.Printf(
"Debug ports for %s: delve=%s, metrics/pprof=49090\n",
"Ports for %s: metrics/pprof=49090\n",
hsic.hostname,
hsic.GetHostDebugPort(),
)
// Write the CA certificates to the container
@ -886,16 +932,6 @@ func (t *HeadscaleInContainer) GetPort() string {
return strconv.Itoa(t.port)
}
// GetDebugPort returns the debug port as a string.
func (t *HeadscaleInContainer) GetDebugPort() string {
return strconv.Itoa(t.debugPort)
}
// GetHostDebugPort returns the host port mapped to the debug port.
func (t *HeadscaleInContainer) GetHostDebugPort() string {
return strconv.Itoa(t.debugPort)
}
// GetHealthEndpoint returns a health endpoint for the HeadscaleInContainer
// instance.
func (t *HeadscaleInContainer) GetHealthEndpoint() string {
@ -1052,6 +1088,57 @@ func (t *HeadscaleInContainer) CreateAuthKey(
return &preAuthKey, nil
}
// CreateAuthKeyWithTags creates a new "authorisation key" for a User with the specified tags.
// This is used to create tagged PreAuthKeys for testing the tags-as-identity model.
func (t *HeadscaleInContainer) CreateAuthKeyWithTags(
user uint64,
reusable bool,
ephemeral bool,
tags []string,
) (*v1.PreAuthKey, error) {
command := []string{
"headscale",
"--user",
strconv.FormatUint(user, 10),
"preauthkeys",
"create",
"--expiration",
"24h",
"--output",
"json",
}
if reusable {
command = append(command, "--reusable")
}
if ephemeral {
command = append(command, "--ephemeral")
}
if len(tags) > 0 {
command = append(command, "--tags", strings.Join(tags, ","))
}
result, _, err := dockertestutil.ExecuteCommand(
t.container,
command,
[]string{},
)
if err != nil {
return nil, fmt.Errorf("failed to execute create auth key with tags command: %w", err)
}
var preAuthKey v1.PreAuthKey
err = json.Unmarshal([]byte(result), &preAuthKey)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal auth key: %w", err)
}
return &preAuthKey, nil
}
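// Illustrative usage (sketch, not part of this change): creating a reusable,
// non-ephemeral tagged PreAuthKey in a test, assuming `headscale` is a
// *HeadscaleInContainer, `client` is a TailscaleClient, and user ID 1 exists:
//
//	key, err := headscale.CreateAuthKeyWithTags(1, true, false, []string{"tag:approve"})
//	if err != nil {
//		t.Fatalf("creating tagged preauth key: %v", err)
//	}
//	client.Login(headscale.GetEndpoint(), key.GetKey())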
// DeleteAuthKey deletes an "authorisation key" for a User.
func (t *HeadscaleInContainer) DeleteAuthKey(
user uint64,
@ -1369,6 +1456,36 @@ func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (
return node, nil
}
// SetNodeTags sets tags on a node via the headscale CLI.
// This simulates what the Tailscale admin console UI does - it calls the headscale
// SetTags API which is exposed via the CLI command: headscale nodes tag -i <id> -t <tags>.
func (t *HeadscaleInContainer) SetNodeTags(nodeID uint64, tags []string) error {
command := []string{
"headscale", "nodes", "tag",
"--identifier", strconv.FormatUint(nodeID, 10),
"--output", "json",
}
// Add tags - the CLI accepts a comma-separated list via --tags
if len(tags) > 0 {
command = append(command, "--tags", strings.Join(tags, ","))
} else {
// Empty tags to clear all tags
command = append(command, "--tags", "")
}
_, _, err := dockertestutil.ExecuteCommand(
t.container,
command,
[]string{},
)
if err != nil {
return fmt.Errorf("failed to execute set tags command (node %d, tags %v): %w", nodeID, tags, err)
}
return nil
}
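// Illustrative usage (sketch, not part of this change): re-tagging node 1 and
// then clearing its tags again, assuming `headscale` is a *HeadscaleInContainer:
//
//	err := headscale.SetNodeTags(1, []string{"tag:approve"})
//	// ... assertions ...
//	err = headscale.SetNodeTags(1, nil) // nil/empty slice clears all tags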
// WriteFile save file inside the Headscale container.
func (t *HeadscaleInContainer) WriteFile(path string, data []byte) error {
return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)

View file

@ -4,6 +4,7 @@ import (
"cmp"
"encoding/json"
"fmt"
"maps"
"net/netip"
"slices"
"sort"
@ -25,7 +26,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
xmaps "golang.org/x/exp/maps"
"tailscale.com/envknob"
"tailscale.com/ipn/ipnstate"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
@ -1979,6 +1979,11 @@ func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node {
// - Verify that routes can now be seen by peers.
func TestAutoApproveMultiNetwork(t *testing.T) {
IntegrationSkip(t)
// Timeout for EventuallyWithT assertions.
// Set generously to account for CI infrastructure variability.
assertTimeout := 60 * time.Second
bigRoute := netip.MustParsePrefix("10.42.0.0/16")
subRoute := netip.MustParsePrefix("10.42.7.0/24")
notApprovedRoute := netip.MustParsePrefix("192.168.0.0/24")
@ -2217,31 +2222,24 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
},
}
// Check if we should run the full matrix of tests
// By default, we only run a minimal subset to avoid overwhelming Docker/disk
// Set HEADSCALE_INTEGRATION_FULL_MATRIX=1 to run all combinations
fullMatrix := envknob.Bool("HEADSCALE_INTEGRATION_FULL_MATRIX")
// Minimal test set: 3 tests covering all key dimensions
// - Both auth methods (authkey, webauth)
// - All 3 approver types (tag, user, group)
// - Both policy modes (database, file)
// - Both advertiseDuringUp values (true, false)
minimalTestSet := map[string]bool{
"authkey-tag-advertiseduringup-false-pol-database": true, // authkey + database + tag + false
"webauth-user-advertiseduringup-true-pol-file": true, // webauth + file + user + true
"authkey-group-advertiseduringup-false-pol-file": true, // authkey + file + group + false
}
for _, tt := range tests {
for _, polMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} {
for _, advertiseDuringUp := range []bool{false, true} {
name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, polMode)
t.Run(name, func(t *testing.T) {
// Skip tests not in minimal set unless full matrix is enabled
if !fullMatrix && !minimalTestSet[name] {
t.Skip("Skipping to reduce test matrix size. Set HEADSCALE_INTEGRATION_FULL_MATRIX=1 to run all tests.")
// Create a deep copy of the policy to avoid mutating the shared test case.
// Each subtest modifies AutoApprovers.Routes (add then delete), so we need
// an isolated copy to prevent state leakage between sequential test runs.
pol := &policyv2.Policy{
ACLs: slices.Clone(tt.pol.ACLs),
Groups: maps.Clone(tt.pol.Groups),
TagOwners: maps.Clone(tt.pol.TagOwners),
AutoApprovers: policyv2.AutoApproverPolicy{
ExitNode: slices.Clone(tt.pol.AutoApprovers.ExitNode),
Routes: maps.Clone(tt.pol.AutoApprovers.Routes),
},
}
scenario, err := NewScenario(tt.spec)
require.NoErrorf(t, err, "failed to create scenario: %s", err)
defer scenario.ShutdownAssertNoPanics(t)
@ -2251,7 +2249,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
hsic.WithTestName("autoapprovemulti"),
hsic.WithEmbeddedDERPServerOnly(),
hsic.WithTLS(),
hsic.WithACLPolicy(tt.pol),
hsic.WithACLPolicy(pol),
hsic.WithPolicyMode(polMode),
}
@ -2259,16 +2257,25 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
tsic.WithAcceptRoutes(),
}
if tt.approver == "tag:approve" {
tsOpts = append(tsOpts,
tsic.WithTags([]string{"tag:approve"}),
)
}
route, err := scenario.SubnetOfNetwork("usernet1")
require.NoError(t, err)
err = scenario.createHeadscaleEnv(tt.withURL, tsOpts,
// For tag-based approvers, nodes must be tagged with that tag
// (tags-as-identity model: tagged nodes are identified by their tags)
var (
preAuthKeyTags []string
webauthTagUser string
)
if strings.HasPrefix(tt.approver, "tag:") {
preAuthKeyTags = []string{tt.approver}
if tt.withURL {
// For webauth, only user1 can request tags (per tagOwners policy)
webauthTagUser = "user1"
}
}
err = scenario.createHeadscaleEnvWithTags(tt.withURL, tsOpts, preAuthKeyTags, webauthTagUser,
opts...,
)
requireNoErrHeadscaleEnv(t, err)
@ -2301,12 +2308,10 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
default:
approvers = append(approvers, usernameApprover(tt.approver))
}
if tt.pol.AutoApprovers.Routes == nil {
tt.pol.AutoApprovers.Routes = make(map[netip.Prefix]policyv2.AutoApprovers)
}
// pol.AutoApprovers.Routes is already initialized in the deep copy above
prefix := *route
tt.pol.AutoApprovers.Routes[prefix] = approvers
err = headscale.SetPolicy(tt.pol)
pol.AutoApprovers.Routes[prefix] = approvers
err = headscale.SetPolicy(pol)
require.NoError(t, err)
if advertiseDuringUp {
@ -2315,6 +2320,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
)
}
// For webauth with tag approver, the node needs to advertise the tag during registration
// (tags-as-identity model: webauth nodes can use --advertise-tags if authorized by tagOwners)
if tt.withURL && strings.HasPrefix(tt.approver, "tag:") {
tsOpts = append(tsOpts, tsic.WithTags([]string{tt.approver}))
}
tsOpts = append(tsOpts, tsic.WithNetwork(usernet1))
// This whole dance is to add a node _after_ all the other nodes
@ -2349,7 +2360,14 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
userMap, err := headscale.MapUsers()
require.NoError(t, err)
pak, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
// If the approver is a tag, create a tagged PreAuthKey
// (tags-as-identity model: tags come from PreAuthKey, not --advertise-tags)
var pak *v1.PreAuthKey
if strings.HasPrefix(tt.approver, "tag:") {
pak, err = scenario.CreatePreAuthKeyWithTags(userMap["user1"].GetId(), false, false, []string{tt.approver})
} else {
pak, err = scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false)
}
require.NoError(t, err)
err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey())
@ -2362,6 +2380,21 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
err = routerUsernet1.WaitForRunning(30 * time.Second)
require.NoError(t, err)
// Wait for bidirectional peer synchronization.
// Both the router and all existing clients must see each other.
// This is critical for connectivity - without this, the WireGuard
// tunnels may not be established despite peers appearing in netmaps.
// Router waits for all existing clients
err = routerUsernet1.WaitForPeers(len(allClients), 60*time.Second, 1*time.Second)
require.NoError(t, err, "router failed to see all peers")
// All clients wait for the router (each should see len(allClients) peers: the other clients plus the router)
for _, existingClient := range allClients {
err = existingClient.WaitForPeers(len(allClients), 60*time.Second, 1*time.Second)
require.NoErrorf(t, err, "client %s failed to see all peers including router", existingClient.Hostname())
}
routerUsernet1ID := routerUsernet1.MustID()
web := services[0]
@ -2396,7 +2429,11 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
require.NoErrorf(t, err, "failed to advertise route: %s", err)
}
// Wait for route state changes to propagate
// Wait for route state changes to propagate.
// Use the longer assertTimeout to account for CI infrastructure variability -
// when advertiseDuringUp=true, routes are sent during registration and may
// take longer to propagate through the server's auto-approval logic in slow
// environments.
assert.EventuallyWithT(t, func(c *assert.CollectT) {
// These route should auto approve, so the node is expected to have a route
// for all counts.
@ -2411,7 +2448,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
routerNode.GetSubnetRoutes())
requireNodeRouteCountWithCollect(c, routerNode, 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "Initial route auto-approval: Route should be approved via policy")
}, assertTimeout, 500*time.Millisecond, "Initial route auto-approval: Route should be approved via policy")
// Verify that the routes have been sent to the client.
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@ -2444,7 +2481,22 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
}
assert.True(c, routerPeerFound, "Client should see the router peer")
}, 5*time.Second, 200*time.Millisecond, "Verifying routes sent to client after auto-approval")
}, assertTimeout, 200*time.Millisecond, "Verifying routes sent to client after auto-approval")
// Verify WireGuard tunnel connectivity to the router before testing route.
// The client may have the route in its netmap but the actual tunnel may not
// be established yet, especially in CI environments with higher latency.
routerIPv4, err := routerUsernet1.IPv4()
require.NoError(t, err, "failed to get router IPv4")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err := client.Ping(
routerIPv4.String(),
tsic.WithPingUntilDirect(false), // DERP relay is fine
tsic.WithPingCount(1),
tsic.WithPingTimeout(5*time.Second),
)
assert.NoError(c, err, "ping to router should succeed")
}, assertTimeout, 200*time.Millisecond, "Verifying WireGuard tunnel to router is established")
url := fmt.Sprintf("http://%s/etc/hostname", webip)
t.Logf("url from %s to %s", client.Hostname(), url)
@ -2453,7 +2505,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
result, err := client.Curl(url)
assert.NoError(c, err)
assert.Len(c, result, 13)
}, 20*time.Second, 200*time.Millisecond, "Verifying client can reach webservice through auto-approved route")
}, assertTimeout, 200*time.Millisecond, "Verifying client can reach webservice through auto-approved route")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
@ -2463,12 +2515,12 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, 20*time.Second, 200*time.Millisecond, "Verifying traceroute goes through auto-approved router")
}, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through auto-approved router")
// Remove the auto approval from the policy, any routes already enabled should be allowed.
prefix = *route
delete(tt.pol.AutoApprovers.Routes, prefix)
err = headscale.SetPolicy(tt.pol)
delete(pol.AutoApprovers.Routes, prefix)
err = headscale.SetPolicy(pol)
require.NoError(t, err)
t.Logf("Policy updated: removed auto-approver for route %s", prefix)
@ -2486,7 +2538,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
routerNode.GetSubnetRoutes())
requireNodeRouteCountWithCollect(c, routerNode, 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "Routes should remain approved after auto-approver removal")
}, assertTimeout, 500*time.Millisecond, "Routes should remain approved after auto-approver removal")
// Verify that the routes have been sent to the client.
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@ -2506,7 +2558,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
}
}
}, 5*time.Second, 200*time.Millisecond, "Verifying routes remain after policy change")
}, assertTimeout, 200*time.Millisecond, "Verifying routes remain after policy change")
url = fmt.Sprintf("http://%s/etc/hostname", webip)
t.Logf("url from %s to %s", client.Hostname(), url)
@ -2515,7 +2567,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
result, err := client.Curl(url)
assert.NoError(c, err)
assert.Len(c, result, 13)
}, 20*time.Second, 200*time.Millisecond, "Verifying client can still reach webservice after policy change")
}, assertTimeout, 200*time.Millisecond, "Verifying client can still reach webservice after policy change")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
@ -2525,7 +2577,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, 20*time.Second, 200*time.Millisecond, "Verifying traceroute still goes through router after policy change")
}, assertTimeout, 200*time.Millisecond, "Verifying traceroute still goes through router after policy change")
// Disable the route, making it unavailable since it is no longer auto-approved
_, err = headscale.ApproveRoutes(
@ -2541,7 +2593,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
}, assertTimeout, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@ -2552,7 +2604,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
peerStatus := status.Peer[peerKey]
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
}
}, 5*time.Second, 200*time.Millisecond, "Verifying routes disabled after route removal")
}, assertTimeout, 200*time.Millisecond, "Verifying routes disabled after route removal")
// Add the route back to the auto approver in the policy, the route should
// now become available again.
@ -2565,12 +2617,10 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
default:
newApprovers = append(newApprovers, usernameApprover(tt.approver))
}
if tt.pol.AutoApprovers.Routes == nil {
tt.pol.AutoApprovers.Routes = make(map[netip.Prefix]policyv2.AutoApprovers)
}
// pol.AutoApprovers.Routes is already initialized in the deep copy above
prefix = *route
tt.pol.AutoApprovers.Routes[prefix] = newApprovers
err = headscale.SetPolicy(tt.pol)
pol.AutoApprovers.Routes[prefix] = newApprovers
err = headscale.SetPolicy(pol)
require.NoError(t, err)
// Wait for route state changes to propagate
@ -2580,7 +2630,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
nodes, err = headscale.ListNodes()
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
}, assertTimeout, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@ -2600,7 +2650,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
}
}
}, 5*time.Second, 200*time.Millisecond, "Verifying routes re-enabled after policy re-approval")
}, assertTimeout, 200*time.Millisecond, "Verifying routes re-enabled after policy re-approval")
url = fmt.Sprintf("http://%s/etc/hostname", webip)
t.Logf("url from %s to %s", client.Hostname(), url)
@ -2609,7 +2659,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
result, err := client.Curl(url)
assert.NoError(c, err)
assert.Len(c, result, 13)
}, 20*time.Second, 200*time.Millisecond, "Verifying client can reach webservice after route re-approval")
}, assertTimeout, 200*time.Millisecond, "Verifying client can reach webservice after route re-approval")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := client.Traceroute(webip)
@ -2619,7 +2669,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, 20*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router after re-approval")
}, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through router after re-approval")
// Advertise and validate a subnet of an auto approved route, /24 inside the
// auto approved /16.
@ -2639,7 +2689,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
assert.NoError(c, err)
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 1)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
}, assertTimeout, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@ -2663,7 +2713,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
}
}
}, 5*time.Second, 200*time.Millisecond, "Verifying sub-route propagated to client")
}, assertTimeout, 200*time.Millisecond, "Verifying sub-route propagated to client")
// Advertise a not approved route will not end up anywhere
command = []string{
@ -2683,7 +2733,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
requireNodeRouteCountWithCollect(c, nodes[2], 0, 0, 0)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
}, assertTimeout, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@ -2703,7 +2753,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
}
}
}, 5*time.Second, 200*time.Millisecond, "Verifying unapproved route not propagated")
}, assertTimeout, 200*time.Millisecond, "Verifying unapproved route not propagated")
// Exit routes are also automatically approved
command = []string{
@ -2721,7 +2771,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
}, assertTimeout, 500*time.Millisecond, "route state changes should propagate")
// Verify that the routes have been sent to the client.
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@ -2742,7 +2792,7 @@ func TestAutoApproveMultiNetwork(t *testing.T) {
requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
}
}
}, 5*time.Second, 200*time.Millisecond, "Verifying exit node routes propagated to client")
}, assertTimeout, 200*time.Millisecond, "Verifying exit node routes propagated to client")
})
}
}
@ -2985,7 +3035,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
// Check that the router has 3 routes now approved and available
requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3)
}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate")
}, 15*time.Second, 500*time.Millisecond, "route state changes should propagate")
// Now check the client node status
assert.EventuallyWithT(t, func(c *assert.CollectT) {
@ -3006,7 +3056,7 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
result, err := nodeClient.Curl(weburl)
assert.NoError(c, err)
assert.Len(c, result, 13)
}, 20*time.Second, 200*time.Millisecond, "Verifying node can reach webservice through allowed route")
}, 60*time.Second, 200*time.Millisecond, "Verifying node can reach webservice through allowed route")
assert.EventuallyWithT(t, func(c *assert.CollectT) {
tr, err := nodeClient.Traceroute(webip)
@ -3016,5 +3066,5 @@ func TestSubnetRouteACLFiltering(t *testing.T) {
return
}
assertTracerouteViaIPWithCollect(c, tr, ip)
}, 20*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router")
}, 60*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router")
}

View file

@ -473,6 +473,27 @@ func (s *Scenario) CreatePreAuthKey(
return nil, fmt.Errorf("failed to create user: %w", errNoHeadscaleAvailable)
}
// CreatePreAuthKeyWithTags creates a "pre authorised key" with the specified tags
// to be created in the Headscale instance on behalf of the Scenario.
func (s *Scenario) CreatePreAuthKeyWithTags(
user uint64,
reusable bool,
ephemeral bool,
tags []string,
) (*v1.PreAuthKey, error) {
headscale, err := s.Headscale()
if err != nil {
return nil, fmt.Errorf("failed to create preauth key with tags: %w", errNoHeadscaleAvailable)
}
key, err := headscale.CreateAuthKeyWithTags(user, reusable, ephemeral, tags)
if err != nil {
return nil, fmt.Errorf("failed to create preauth key with tags: %w", err)
}
return key, nil
}
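// Illustrative usage (sketch, not part of this change), mirroring the
// auto-approve test above, assuming `scenario` is a *Scenario and userID is known:
//
//	pak, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:approve"})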
// CreateUser creates a User to be created in the
// Headscale instance on behalf of the Scenario.
func (s *Scenario) CreateUser(user string) (*v1.User, error) {
@ -767,6 +788,25 @@ func (s *Scenario) createHeadscaleEnv(
withURL bool,
tsOpts []tsic.Option,
opts ...hsic.Option,
) error {
return s.createHeadscaleEnvWithTags(withURL, tsOpts, nil, "", opts...)
}
// createHeadscaleEnvWithTags starts the headscale environment and the clients
// according to the ScenarioSpec passed to the Scenario. If preAuthKeyTags is
// non-empty and withURL is false, the tags will be applied to the PreAuthKey
// (tags-as-identity model).
//
// For webauth (withURL=true), if webauthTagUser is non-empty and preAuthKeyTags
// is non-empty, only nodes belonging to that user will request tags via
// --advertise-tags. This is necessary because tagOwners ACL controls which
// users can request specific tags.
func (s *Scenario) createHeadscaleEnvWithTags(
withURL bool,
tsOpts []tsic.Option,
preAuthKeyTags []string,
webauthTagUser string,
opts ...hsic.Option,
) error {
headscale, err := s.Headscale(opts...)
if err != nil {
@ -779,14 +819,20 @@ func (s *Scenario) createHeadscaleEnv(
return err
}
var opts []tsic.Option
var userOpts []tsic.Option
if s.userToNetwork != nil {
opts = append(tsOpts, tsic.WithNetwork(s.userToNetwork[user]))
userOpts = append(tsOpts, tsic.WithNetwork(s.userToNetwork[user]))
} else {
opts = append(tsOpts, tsic.WithNetwork(s.networks[s.testDefaultNetwork]))
userOpts = append(tsOpts, tsic.WithNetwork(s.networks[s.testDefaultNetwork]))
}
err = s.CreateTailscaleNodesInUser(user, "all", s.spec.NodesPerUser, opts...)
// For webauth with tags, only apply tags to the specified webauthTagUser
// (other users may not be authorized via tagOwners)
if withURL && webauthTagUser != "" && len(preAuthKeyTags) > 0 && user == webauthTagUser {
userOpts = append(userOpts, tsic.WithTags(preAuthKeyTags))
}
err = s.CreateTailscaleNodesInUser(user, "all", s.spec.NodesPerUser, userOpts...)
if err != nil {
return err
}
@ -797,7 +843,13 @@ func (s *Scenario) createHeadscaleEnv(
return err
}
} else {
key, err := s.CreatePreAuthKey(u.GetId(), true, false)
// Use tagged PreAuthKey if tags are provided (tags-as-identity model)
var key *v1.PreAuthKey
if len(preAuthKeyTags) > 0 {
key, err = s.CreatePreAuthKeyWithTags(u.GetId(), true, false, preAuthKeyTags)
} else {
key, err = s.CreatePreAuthKey(u.GetId(), true, false)
}
if err != nil {
return err
}

View file

@ -42,11 +42,8 @@ func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Sce
// tailscaled to stop configuring the wgengine, causing it
// to not configure DNS.
tsic.WithNetfilter("off"),
tsic.WithDockerEntrypoint([]string{
"/bin/sh",
"-c",
"/bin/sleep 3 ; apk add openssh ; adduser ssh-it-user ; update-ca-certificates ; tailscaled --tun=tsdev",
}),
tsic.WithPackages("openssh"),
tsic.WithExtraCommands("adduser ssh-it-user"),
tsic.WithDockerWorkdir("/"),
},
hsic.WithACLPolicy(policy),
@ -395,8 +392,10 @@ func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient,
log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname())
log.Printf("Command: %s", strings.Join(command, " "))
var result, stderr string
var err error
var (
result, stderr string
err error
)
if retry {
// Use assert.EventuallyWithT to retry SSH connections for success cases
@ -455,6 +454,7 @@ func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient
func assertSSHNoAccessStdError(t *testing.T, err error, stderr string) {
t.Helper()
assert.Error(t, err)
if !isSSHNoAccessStdError(stderr) {
t.Errorf("expected stderr output suggesting access denied, got: %s", stderr)
}
@ -462,7 +462,7 @@ func assertSSHNoAccessStdError(t *testing.T, err error, stderr string) {
// TestSSHAutogroupSelf tests that SSH with autogroup:self works correctly:
// - Users can SSH to their own devices
// - Users cannot SSH to other users' devices
// - Users cannot SSH to other users' devices.
func TestSSHAutogroupSelf(t *testing.T) {
IntegrationSkip(t)

2465
integration/tags_test.go Normal file

File diff suppressed because it is too large

View file

@ -14,6 +14,7 @@ import (
"os"
"reflect"
"runtime/debug"
"slices"
"strconv"
"strings"
"time"
@ -54,6 +55,10 @@ var (
errTailscaleNotConnected = errors.New("tailscale not connected")
errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login")
errInvalidClientConfig = errors.New("verifiably invalid client config requested")
errInvalidTailscaleImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_TAILSCALE_IMAGE format, expected repository:tag")
errTailscaleImageRequiredInCI = errors.New("HEADSCALE_INTEGRATION_TAILSCALE_IMAGE must be set in CI for HEAD version")
errContainerNotInitialized = errors.New("container not initialized")
errFQDNNotYetAvailable = errors.New("FQDN not yet available")
)
const (
@ -90,6 +95,9 @@ type TailscaleInContainer struct {
netfilter string
extraLoginArgs []string
withAcceptRoutes bool
withPackages []string // Alpine packages to install at container start
withWebserverPort int // Port for built-in HTTP server (0 = disabled)
withExtraCommands []string // Extra shell commands to run before tailscaled
// build options, solely for HEAD
buildConfig TailscaleInContainerBuildConfig
@ -212,6 +220,82 @@ func WithAcceptRoutes() Option {
}
}
// WithPackages specifies Alpine packages to install when the container starts.
// This requires internet access and uses `apk add`. Common packages:
// - "python3" for HTTP server
// - "curl" for HTTP client
// - "bind-tools" for dig command
// - "iptables", "ip6tables" for firewall rules
// Note: Tests using this option require internet access and cannot use
// the built-in DERP server in offline mode.
func WithPackages(packages ...string) Option {
return func(tsic *TailscaleInContainer) {
tsic.withPackages = append(tsic.withPackages, packages...)
}
}
// WithWebserver starts a Python HTTP server on the specified port
// alongside tailscaled. This is useful for testing subnet routing
// and ACL connectivity. Automatically adds "python3" to packages if needed.
// The server serves files from the root directory (/).
func WithWebserver(port int) Option {
return func(tsic *TailscaleInContainer) {
tsic.withWebserverPort = port
}
}
// WithExtraCommands adds extra shell commands to run before tailscaled starts.
// Commands are run after package installation and CA certificate updates.
func WithExtraCommands(commands ...string) Option {
return func(tsic *TailscaleInContainer) {
tsic.withExtraCommands = append(tsic.withExtraCommands, commands...)
}
}
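// Illustrative usage (sketch, not part of this change): composing these options
// for a client container, passed wherever []tsic.Option is accepted:
//
//	tsOpts := []tsic.Option{
//		tsic.WithPackages("curl", "bind-tools"),
//		tsic.WithWebserver(80),
//		tsic.WithExtraCommands("adduser ssh-it-user"),
//	}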
// buildEntrypoint constructs the container entrypoint command based on
// configured options (packages, webserver, etc.).
func (t *TailscaleInContainer) buildEntrypoint() []string {
var commands []string
// Wait for network to be ready
commands = append(commands, "while ! ip route show default >/dev/null 2>&1; do sleep 0.1; done")
// If CA certs are configured, wait for them to be written by the Go code
// (certs are written after container start via tsic.WriteFile)
if len(t.caCerts) > 0 {
commands = append(commands,
fmt.Sprintf("while [ ! -f %s/user-0.crt ]; do sleep 0.1; done", caCertRoot))
}
// Install packages if requested (requires internet access)
packages := t.withPackages
if t.withWebserverPort > 0 && !slices.Contains(packages, "python3") {
packages = append(packages, "python3")
}
if len(packages) > 0 {
commands = append(commands, "apk add --no-cache "+strings.Join(packages, " "))
}
// Update CA certificates
commands = append(commands, "update-ca-certificates")
// Run extra commands if any
commands = append(commands, t.withExtraCommands...)
// Start webserver in background if requested
// Use subshell to avoid & interfering with command joining
if t.withWebserverPort > 0 {
commands = append(commands,
fmt.Sprintf("(python3 -m http.server --bind :: %d &)", t.withWebserverPort))
}
// Start tailscaled (must be last as it's the foreground process)
commands = append(commands, "tailscaled --tun=tsdev --verbose=10")
return []string{"/bin/sh", "-c", strings.Join(commands, " ; ")}
}
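// For example (illustrative, assuming WithPackages("curl"), WithWebserver(80),
// no CA certificates and no extra commands), the generated entrypoint is roughly:
//
//	/bin/sh -c 'while ! ip route show default >/dev/null 2>&1; do sleep 0.1; done ; apk add --no-cache curl python3 ; update-ca-certificates ; (python3 -m http.server --bind :: 80 &) ; tailscaled --tun=tsdev --verbose=10'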
// New returns a new TailscaleInContainer instance.
func New(
pool *dockertest.Pool,
@ -230,18 +314,18 @@ func New(
hostname: hostname,
pool: pool,
withEntrypoint: []string{
"/bin/sh",
"-c",
"/bin/sleep 3 ; update-ca-certificates ; tailscaled --tun=tsdev --verbose=10",
},
}
for _, opt := range opts {
opt(tsic)
}
// Build the entrypoint command dynamically based on options.
// Only build if no custom entrypoint was provided via WithDockerEntrypoint.
if len(tsic.withEntrypoint) == 0 {
tsic.withEntrypoint = tsic.buildEntrypoint()
}
if tsic.network == nil {
return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack()))
}
@ -291,6 +375,7 @@ func New(
// build options are not meaningful with pre-existing images,
// let's not lead anyone astray by pretending otherwise.
defaultBuildConfig := TailscaleInContainerBuildConfig{}
hasBuildConfig := !reflect.DeepEqual(defaultBuildConfig, tsic.buildConfig)
if hasBuildConfig {
return tsic, errInvalidClientConfig
@ -299,80 +384,119 @@ func New(
switch version {
case VersionHead:
buildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.tailscale-HEAD",
ContextDir: dockerContextPath,
BuildArgs: []docker.BuildArg{},
// Check if a pre-built image is available via environment variable
prebuiltImage := os.Getenv("HEADSCALE_INTEGRATION_TAILSCALE_IMAGE")
// If custom build tags are required (e.g., for websocket DERP), we cannot use
// the pre-built image as it won't have the necessary code compiled in.
hasBuildTags := len(tsic.buildConfig.tags) > 0
if hasBuildTags && prebuiltImage != "" {
log.Printf("Ignoring pre-built image %s because custom build tags are required: %v",
prebuiltImage, tsic.buildConfig.tags)
prebuiltImage = ""
}
buildTags := strings.Join(tsic.buildConfig.tags, ",")
if len(buildTags) > 0 {
buildOptions.BuildArgs = append(
buildOptions.BuildArgs,
docker.BuildArg{
Name: "BUILD_TAGS",
Value: buildTags,
},
)
}
if prebuiltImage != "" {
log.Printf("Using pre-built tailscale image: %s", prebuiltImage)
container, err = pool.BuildAndRunWithBuildOptions(
buildOptions,
tailscaleOptions,
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
if err != nil {
// Try to get more detailed build output
log.Printf("Docker build failed for %s, attempting to get detailed output...", hostname)
buildOutput, buildErr := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, "Dockerfile.tailscale-HEAD")
// Show the last 100 lines of build output to avoid overwhelming the logs
lines := strings.Split(buildOutput, "\n")
const maxLines = 100
startLine := 0
if len(lines) > maxLines {
startLine = len(lines) - maxLines
// Parse image into repository and tag
repo, tag, ok := strings.Cut(prebuiltImage, ":")
if !ok {
return nil, errInvalidTailscaleImageFormat
}
relevantOutput := strings.Join(lines[startLine:], "\n")
tailscaleOptions.Repository = repo
tailscaleOptions.Tag = tag
if buildErr != nil {
// The diagnostic build also failed - this is the real error
return nil, fmt.Errorf(
"%s could not start tailscale container (version: %s): %w\n\nDocker build failed. Last %d lines of output:\n%s",
hostname,
version,
err,
maxLines,
relevantOutput,
container, err = pool.RunWithOptions(
tailscaleOptions,
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
if err != nil {
return nil, fmt.Errorf("could not run pre-built tailscale container %q: %w", prebuiltImage, err)
}
} else if util.IsCI() && !hasBuildTags {
// In CI, we require a pre-built image unless custom build tags are needed
return nil, errTailscaleImageRequiredInCI
} else {
buildOptions := &dockertest.BuildOptions{
Dockerfile: "Dockerfile.tailscale-HEAD",
ContextDir: dockerContextPath,
BuildArgs: []docker.BuildArg{},
}
buildTags := strings.Join(tsic.buildConfig.tags, ",")
if len(buildTags) > 0 {
buildOptions.BuildArgs = append(
buildOptions.BuildArgs,
docker.BuildArg{
Name: "BUILD_TAGS",
Value: buildTags,
},
)
}
if buildOutput != "" {
// Build succeeded on retry but container creation still failed
container, err = pool.BuildAndRunWithBuildOptions(
buildOptions,
tailscaleOptions,
dockertestutil.DockerRestartPolicy,
dockertestutil.DockerAllowLocalIPv6,
dockertestutil.DockerAllowNetworkAdministration,
dockertestutil.DockerMemoryLimit,
)
if err != nil {
// Try to get more detailed build output
log.Printf("Docker build failed for %s, attempting to get detailed output...", hostname)
buildOutput, buildErr := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, "Dockerfile.tailscale-HEAD")
// Show the last 100 lines of build output to avoid overwhelming the logs
lines := strings.Split(buildOutput, "\n")
const maxLines = 100
startLine := 0
if len(lines) > maxLines {
startLine = len(lines) - maxLines
}
relevantOutput := strings.Join(lines[startLine:], "\n")
if buildErr != nil {
// The diagnostic build also failed - this is the real error
return nil, fmt.Errorf(
"%s could not start tailscale container (version: %s): %w\n\nDocker build failed. Last %d lines of output:\n%s",
hostname,
version,
err,
maxLines,
relevantOutput,
)
}
if buildOutput != "" {
// Build succeeded on retry but container creation still failed
return nil, fmt.Errorf(
"%s could not start tailscale container (version: %s): %w\n\nDocker build succeeded on retry, but container creation failed. Last %d lines of build output:\n%s",
hostname,
version,
err,
maxLines,
relevantOutput,
)
}
// No output at all - diagnostic build command may have failed
return nil, fmt.Errorf(
"%s could not start tailscale container (version: %s): %w\n\nDocker build succeeded on retry, but container creation failed. Last %d lines of build output:\n%s",
"%s could not start tailscale container (version: %s): %w\n\nUnable to get diagnostic build output (command may have failed silently)",
hostname,
version,
err,
maxLines,
relevantOutput,
)
}
// No output at all - diagnostic build command may have failed
return nil, fmt.Errorf(
"%s could not start tailscale container (version: %s): %w\n\nUnable to get diagnostic build output (command may have failed silently)",
hostname,
version,
err,
)
}
case "unstable":
tailscaleOptions.Repository = "tailscale/tailscale"
@ -412,6 +536,7 @@ func New(
err,
)
}
log.Printf("Created %s container\n", hostname)
tsic.container = container
@ -471,7 +596,6 @@ func (t *TailscaleInContainer) Execute(
if err != nil {
// log.Printf("command issued: %s", strings.Join(command, " "))
// log.Printf("command stderr: %s\n", stderr)
if stdout != "" {
log.Printf("command stdout: %s\n", stdout)
}
@ -597,7 +721,7 @@ func (t *TailscaleInContainer) Logout() error {
// "tailscale up" with any auth keys stored in environment variables.
func (t *TailscaleInContainer) Restart() error {
if t.container == nil {
return fmt.Errorf("container not initialized")
return errContainerNotInitialized
}
// Use Docker API to restart the container
@ -614,6 +738,7 @@ func (t *TailscaleInContainer) Restart() error {
if err != nil {
return struct{}{}, fmt.Errorf("container not ready: %w", err)
}
return struct{}{}, nil
}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(30*time.Second))
if err != nil {
@ -680,15 +805,18 @@ func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
}
ips := make([]netip.Addr, 0)
for address := range strings.SplitSeq(result, "\n") {
address = strings.TrimSuffix(address, "\n")
if len(address) < 1 {
continue
}
ip, err := netip.ParseAddr(address)
if err != nil {
return nil, fmt.Errorf("failed to parse IP %s: %w", address, err)
}
ips = append(ips, ip)
}
@ -710,6 +838,7 @@ func (t *TailscaleInContainer) MustIPs() []netip.Addr {
if err != nil {
panic(err)
}
return ips
}
@ -734,6 +863,7 @@ func (t *TailscaleInContainer) MustIPv4() netip.Addr {
if err != nil {
panic(err)
}
return ip
}
@ -743,6 +873,7 @@ func (t *TailscaleInContainer) MustIPv6() netip.Addr {
return ip
}
}
panic("no ipv6 found")
}
@ -760,6 +891,7 @@ func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) {
}
var status ipnstate.Status
err = json.Unmarshal([]byte(result), &status)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale status: %w", err)
@ -819,6 +951,7 @@ func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) {
}
var nm netmap.NetworkMap
err = json.Unmarshal([]byte(result), &nm)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err)
@ -864,6 +997,7 @@ func (t *TailscaleInContainer) watchIPN(ctx context.Context) (*ipn.Notify, error
notify *ipn.Notify
err error
}
resultChan := make(chan result, 1)
// There is no good way to kill the goroutine with watch-ipn,
@ -895,7 +1029,9 @@ func (t *TailscaleInContainer) watchIPN(ctx context.Context) (*ipn.Notify, error
decoder := json.NewDecoder(pr)
for decoder.More() {
var notify ipn.Notify
if err := decoder.Decode(&notify); err != nil {
err := decoder.Decode(&notify)
if err != nil {
resultChan <- result{nil, fmt.Errorf("parse notify: %w", err)}
}
@ -942,6 +1078,7 @@ func (t *TailscaleInContainer) DebugDERPRegion(region string) (*ipnstate.DebugDE
}
var report ipnstate.DebugDERPRegionReport
err = json.Unmarshal([]byte(result), &report)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale derp region report: %w", err)
@ -965,6 +1102,7 @@ func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) {
}
var nm netcheck.Report
err = json.Unmarshal([]byte(result), &nm)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal tailscale netcheck: %w", err)
@ -987,7 +1125,7 @@ func (t *TailscaleInContainer) FQDN() (string, error) {
}
if status.Self.DNSName == "" {
return "", fmt.Errorf("FQDN not yet available")
return "", errFQDNNotYetAvailable
}
return status.Self.DNSName, nil
@ -1005,6 +1143,7 @@ func (t *TailscaleInContainer) MustFQDN() string {
if err != nil {
panic(err)
}
return fqdn
}
@ -1098,12 +1237,14 @@ func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval
defer cancel()
var lastErrs []error
for {
select {
case <-ctx.Done():
if len(lastErrs) > 0 {
return fmt.Errorf("timeout waiting for %d peers on %s after %v, errors: %w", expected, t.hostname, timeout, multierr.New(lastErrs...))
}
return fmt.Errorf("timeout waiting for %d peers on %s after %v", expected, t.hostname, timeout)
case <-ticker.C:
status, err := t.Status()
@ -1127,6 +1268,7 @@ func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval
// Verify that the peers of a given node is Online
// has a hostname and a DERP relay.
var peerErrors []error
for _, peerKey := range status.Peers() {
peer := status.Peer[peerKey]
@ -1320,6 +1462,7 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
}
var result string
result, _, err := t.Execute(command)
if err != nil {
log.Printf(
@ -1353,6 +1496,7 @@ func (t *TailscaleInContainer) Traceroute(ip netip.Addr) (util.Traceroute, error
}
var result util.Traceroute
stdout, stderr, err := t.Execute(command)
if err != nil {
return result, err
@ -1398,12 +1542,14 @@ func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) {
}
var out bytes.Buffer
tr := tar.NewReader(bytes.NewReader(tarBytes))
for {
hdr, err := tr.Next()
if err == io.EOF {
break // End of archive
}
if err != nil {
return nil, fmt.Errorf("reading tar header: %w", err)
}
@ -1432,6 +1578,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
if err != nil {
return nil, fmt.Errorf("failed to read state file: %w", err)
}
store := &mem.Store{}
if err = store.LoadFromJSON(state); err != nil {
return nil, fmt.Errorf("failed to unmarshal state file: %w", err)
@ -1441,6 +1588,7 @@ func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
if err != nil {
return nil, fmt.Errorf("failed to read current profile state key: %w", err)
}
currentProfile, err := store.ReadState(ipn.StateKey(currentProfileKey))
if err != nil {
return nil, fmt.Errorf("failed to read current profile state: %w", err)

View file

@ -3,7 +3,9 @@ package main
//go:generate go run main.go
import (
"context"
"encoding/json"
"errors"
"fmt"
"go/format"
"io"
@ -21,64 +23,211 @@ import (
)
const (
releasesURL = "https://api.github.com/repos/tailscale/tailscale/releases"
rawFileURL = "https://github.com/tailscale/tailscale/raw/refs/tags/%s/tailcfg/tailcfg.go"
outputFile = "../../hscontrol/capver/capver_generated.go"
testFile = "../../hscontrol/capver/capver_test_data.go"
minVersionParts = 2
fallbackCapVer = 90
maxTestCases = 4
// TODO(https://github.com/tailscale/tailscale/issues/12849): Restore to 10 when v1.92 is released.
supportedMajorMinorVersions = 9
ghcrTokenURL = "https://ghcr.io/token?service=ghcr.io&scope=repository:tailscale/tailscale:pull" //nolint:gosec
ghcrTagsURL = "https://ghcr.io/v2/tailscale/tailscale/tags/list?n=10000"
rawFileURL = "https://github.com/tailscale/tailscale/raw/refs/tags/%s/tailcfg/tailcfg.go"
outputFile = "../../hscontrol/capver/capver_generated.go"
testFile = "../../hscontrol/capver/capver_test_data.go"
fallbackCapVer = 90
maxTestCases = 4
supportedMajorMinorVersions = 10
filePermissions = 0o600
semverMatchGroups = 4
latest3Count = 3
latest2Count = 2
)
type Release struct {
Name string `json:"name"`
var errUnexpectedStatusCode = errors.New("unexpected status code")
// GHCRTokenResponse represents the response from GHCR token endpoint.
type GHCRTokenResponse struct {
Token string `json:"token"`
}
func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
// Fetch the releases
resp, err := http.Get(releasesURL)
// GHCRTagsResponse represents the response from GHCR tags list endpoint.
type GHCRTagsResponse struct {
Name string `json:"name"`
Tags []string `json:"tags"`
}
// getGHCRToken fetches an anonymous token from GHCR for accessing public container images.
func getGHCRToken(ctx context.Context) (string, error) {
client := &http.Client{}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ghcrTokenURL, nil)
if err != nil {
return nil, fmt.Errorf("error fetching releases: %w", err)
return "", fmt.Errorf("error creating token request: %w", err)
}
resp, err := client.Do(req)
if err != nil {
return "", fmt.Errorf("error fetching GHCR token: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("%w: %d", errUnexpectedStatusCode, resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading response body: %w", err)
return "", fmt.Errorf("error reading token response: %w", err)
}
var releases []Release
var tokenResp GHCRTokenResponse
err = json.Unmarshal(body, &releases)
err = json.Unmarshal(body, &tokenResp)
if err != nil {
return nil, fmt.Errorf("error unmarshalling JSON: %w", err)
return "", fmt.Errorf("error parsing token response: %w", err)
}
return tokenResp.Token, nil
}
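// For reference (illustrative, not part of this change), the anonymous token
// request above is equivalent to:
//
//	curl 'https://ghcr.io/token?service=ghcr.io&scope=repository:tailscale/tailscale:pull'
//
// which returns a JSON body of the form {"token":"..."}.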
// getGHCRTags fetches all available tags from GHCR for tailscale/tailscale.
func getGHCRTags(ctx context.Context) ([]string, error) {
token, err := getGHCRToken(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get GHCR token: %w", err)
}
client := &http.Client{}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, ghcrTagsURL, nil)
if err != nil {
return nil, fmt.Errorf("error creating tags request: %w", err)
}
req.Header.Set("Authorization", "Bearer "+token)
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("error fetching tags: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("%w: %d", errUnexpectedStatusCode, resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("error reading tags response: %w", err)
}
var tagsResp GHCRTagsResponse
err = json.Unmarshal(body, &tagsResp)
if err != nil {
return nil, fmt.Errorf("error parsing tags response: %w", err)
}
return tagsResp.Tags, nil
}
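// For reference (illustrative, not part of this change), the tag listing above
// is equivalent to:
//
//	curl -H "Authorization: Bearer $TOKEN" 'https://ghcr.io/v2/tailscale/tailscale/tags/list?n=10000'
//
// returning JSON of the form {"name":"tailscale/tailscale","tags":["v1.90.0", ...]}.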
// semverRegex matches semantic version tags like v1.90.0 or v1.90.1.
var semverRegex = regexp.MustCompile(`^v(\d+)\.(\d+)\.(\d+)$`)
// parseSemver extracts major, minor, patch from a semver tag.
// Returns -1 for all values if not a valid semver.
func parseSemver(tag string) (int, int, int) {
matches := semverRegex.FindStringSubmatch(tag)
if len(matches) != semverMatchGroups {
return -1, -1, -1
}
major, _ := strconv.Atoi(matches[1])
minor, _ := strconv.Atoi(matches[2])
patch, _ := strconv.Atoi(matches[3])
return major, minor, patch
}
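// For example (illustrative): parseSemver("v1.90.2") returns (1, 90, 2), while
// parseSemver("v1.90.2-rc1") and parseSemver("latest") both return (-1, -1, -1)
// because they do not match the strict vMAJOR.MINOR.PATCH pattern.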
// getMinorVersionsFromTags processes container tags and returns a map of minor versions
// to the lowest available patch version for each minor.
// For example: {"v1.90": "v1.90.0", "v1.92": "v1.92.0"}.
func getMinorVersionsFromTags(tags []string) map[string]string {
// Map minor version (e.g., "v1.90") to lowest patch version available
minorToLowestPatch := make(map[string]struct {
patch int
fullVer string
})
for _, tag := range tags {
major, minor, patch := parseSemver(tag)
if major < 0 {
continue // Not a semver tag
}
minorKey := fmt.Sprintf("v%d.%d", major, minor)
existing, exists := minorToLowestPatch[minorKey]
if !exists || patch < existing.patch {
minorToLowestPatch[minorKey] = struct {
patch int
fullVer string
}{
patch: patch,
fullVer: tag,
}
}
}
// Convert to simple map
result := make(map[string]string)
for minorVer, info := range minorToLowestPatch {
result[minorVer] = info.fullVer
}
return result
}
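// For example (illustrative): given the tags ["v1.88.3", "v1.90.1", "v1.90.0", "latest"],
// the result is map[string]string{"v1.88": "v1.88.3", "v1.90": "v1.90.0"};
// non-semver tags such as "latest" are ignored.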
// getCapabilityVersions fetches container tags from GHCR, identifies minor versions,
// and fetches the capability version for each from the Tailscale source.
func getCapabilityVersions(ctx context.Context) (map[string]tailcfg.CapabilityVersion, error) {
// Fetch container tags from GHCR
tags, err := getGHCRTags(ctx)
if err != nil {
return nil, fmt.Errorf("failed to get container tags: %w", err)
}
log.Printf("Found %d container tags", len(tags))
// Get minor versions with their representative patch versions
minorVersions := getMinorVersionsFromTags(tags)
log.Printf("Found %d minor versions", len(minorVersions))
// Regular expression to find the CurrentCapabilityVersion line
re := regexp.MustCompile(`const CurrentCapabilityVersion CapabilityVersion = (\d+)`)
versions := make(map[string]tailcfg.CapabilityVersion)
client := &http.Client{}
for _, release := range releases {
version := strings.TrimSpace(release.Name)
if !strings.HasPrefix(version, "v") {
version = "v" + version
for minorVer, patchVer := range minorVersions {
// Fetch the raw Go file for the patch version
rawURL := fmt.Sprintf(rawFileURL, patchVer)
req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) //nolint:gosec
if err != nil {
log.Printf("Warning: failed to create request for %s: %v", patchVer, err)
continue
}
// Fetch the raw Go file
rawURL := fmt.Sprintf(rawFileURL, version)
resp, err := http.Get(rawURL)
resp, err := client.Do(req)
if err != nil {
log.Printf("Warning: failed to fetch %s: %v", patchVer, err)
continue
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
log.Printf("Warning: got status %d for %s", resp.StatusCode, patchVer)
continue
}
body, err := io.ReadAll(resp.Body)
if err != nil {
log.Printf("Warning: failed to read response for %s: %v", patchVer, err)
continue
}
@ -87,7 +236,8 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
if len(matches) > 1 {
capabilityVersionStr := matches[1]
capabilityVersion, _ := strconv.Atoi(capabilityVersionStr)
versions[version] = tailcfg.CapabilityVersion(capabilityVersion)
versions[minorVer] = tailcfg.CapabilityVersion(capabilityVersion)
log.Printf(" %s (from %s): capVer %d", minorVer, patchVer, capabilityVersion)
}
}
@ -95,38 +245,20 @@ func getCapabilityVersions() (map[string]tailcfg.CapabilityVersion, error) {
}
func calculateMinSupportedCapabilityVersion(versions map[string]tailcfg.CapabilityVersion) tailcfg.CapabilityVersion {
// Get unique major.minor versions
majorMinorToCapVer := make(map[string]tailcfg.CapabilityVersion)
// Since we now store minor versions directly, just sort and take the oldest of the latest N
minorVersions := xmaps.Keys(versions)
sort.Strings(minorVersions)
for version, capVer := range versions {
// Remove 'v' prefix and split by '.'
cleanVersion := strings.TrimPrefix(version, "v")
parts := strings.Split(cleanVersion, ".")
if len(parts) >= minVersionParts {
majorMinor := parts[0] + "." + parts[1]
// Keep the earliest (lowest) capver for each major.minor
if existing, exists := majorMinorToCapVer[majorMinor]; !exists || capVer < existing {
majorMinorToCapVer[majorMinor] = capVer
}
}
}
// Sort major.minor versions
majorMinors := xmaps.Keys(majorMinorToCapVer)
sort.Strings(majorMinors)
// Take the latest 10 versions
supportedCount := min(len(majorMinors), supportedMajorMinorVersions)
supportedCount := min(len(minorVersions), supportedMajorMinorVersions)
if supportedCount == 0 {
return fallbackCapVer
}
// The minimum supported version is the oldest of the latest 10
oldestSupportedMajorMinor := majorMinors[len(majorMinors)-supportedCount]
oldestSupportedMinor := minorVersions[len(minorVersions)-supportedCount]
return majorMinorToCapVer[oldestSupportedMajorMinor]
return versions[oldestSupportedMinor]
}
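// For example (illustrative): with entries for the 12 minor versions v1.68,
// v1.70, ..., v1.90 and supportedMajorMinorVersions = 10, the last 10 sorted
// keys start at v1.72, so the function returns versions["v1.72"].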
func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion, minSupportedCapVer tailcfg.CapabilityVersion) error {
@ -156,8 +288,8 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
capabilityVersion := versions[v]
// If it is already set, skip and continue,
// we only want the first tailscale vsion per
// capability vsion.
// we only want the first tailscale version per
// capability version.
if _, ok := capVarToTailscaleVer[capabilityVersion]; ok {
continue
}
@ -199,31 +331,16 @@ func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion
}
func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupportedCapVer tailcfg.CapabilityVersion) error {
// Get unique major.minor versions for test generation
majorMinorToCapVer := make(map[string]tailcfg.CapabilityVersion)
// Sort minor versions
minorVersions := xmaps.Keys(versions)
sort.Strings(minorVersions)
for version, capVer := range versions {
cleanVersion := strings.TrimPrefix(version, "v")
// Take latest N
supportedCount := min(len(minorVersions), supportedMajorMinorVersions)
parts := strings.Split(cleanVersion, ".")
if len(parts) >= minVersionParts {
majorMinor := parts[0] + "." + parts[1]
if existing, exists := majorMinorToCapVer[majorMinor]; !exists || capVer < existing {
majorMinorToCapVer[majorMinor] = capVer
}
}
}
// Sort major.minor versions
majorMinors := xmaps.Keys(majorMinorToCapVer)
sort.Strings(majorMinors)
// Take latest 10
supportedCount := min(len(majorMinors), supportedMajorMinorVersions)
latest10 := majorMinors[len(majorMinors)-supportedCount:]
latest3 := majorMinors[len(majorMinors)-3:]
latest2 := majorMinors[len(majorMinors)-2:]
latest10 := minorVersions[len(minorVersions)-supportedCount:]
latest3 := minorVersions[len(minorVersions)-min(latest3Count, len(minorVersions)):]
latest2 := minorVersions[len(minorVersions)-min(latest2Count, len(minorVersions)):]
// Generate test data file content
var content strings.Builder
@ -242,7 +359,7 @@ func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupport
content.WriteString("\t{3, false, []string{")
for i, version := range latest3 {
content.WriteString(fmt.Sprintf("\"v%s\"", version))
content.WriteString(fmt.Sprintf("\"%s\"", version))
if i < len(latest3)-1 {
content.WriteString(", ")
@ -255,7 +372,9 @@ func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupport
content.WriteString("\t{2, true, []string{")
for i, version := range latest2 {
content.WriteString(fmt.Sprintf("\"%s\"", version))
// Strip v prefix for this test case
verNoV := strings.TrimPrefix(version, "v")
content.WriteString(fmt.Sprintf("\"%s\"", verNoV))
if i < len(latest2)-1 {
content.WriteString(", ")
@ -268,7 +387,8 @@ func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupport
content.WriteString(fmt.Sprintf("\t{%d, true, []string{\n", supportedMajorMinorVersions))
for _, version := range latest10 {
content.WriteString(fmt.Sprintf("\t\t\"%s\",\n", version))
verNoV := strings.TrimPrefix(version, "v")
content.WriteString(fmt.Sprintf("\t\t\"%s\",\n", verNoV))
}
content.WriteString("\t}},\n")
@ -338,7 +458,9 @@ func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupport
}
func main() {
versions, err := getCapabilityVersions()
ctx := context.Background()
versions, err := getCapabilityVersions(ctx)
if err != nil {
log.Println("Error:", err)
return