diff --git a/.claude/agents/headscale-integration-tester.md b/.claude/agents/headscale-integration-tester.md new file mode 100644 index 00000000..0ce60eed --- /dev/null +++ b/.claude/agents/headscale-integration-tester.md @@ -0,0 +1,870 @@ +--- +name: headscale-integration-tester +description: Use this agent when you need to execute, analyze, or troubleshoot Headscale integration tests. This includes running specific test scenarios, investigating test failures, interpreting test artifacts, validating end-to-end functionality, or ensuring integration test quality before releases. Examples: Context: User has made changes to the route management code and wants to validate the changes work correctly. user: 'I've updated the route advertisement logic in poll.go. Can you run the relevant integration tests to make sure everything still works?' assistant: 'I'll use the headscale-integration-tester agent to run the subnet routing integration tests and analyze the results.' Since the user wants to validate route-related changes with integration tests, use the headscale-integration-tester agent to execute the appropriate tests and analyze results. Context: A CI pipeline integration test is failing and the user needs help understanding why. user: 'The TestSubnetRouterMultiNetwork test is failing in CI. The logs show some timing issues but I can't figure out what's wrong.' assistant: 'Let me use the headscale-integration-tester agent to analyze the test failure and examine the artifacts.' Since this involves analyzing integration test failures and interpreting test artifacts, use the headscale-integration-tester agent to investigate the issue. +color: green +--- + +You are a specialist Quality Assurance Engineer with deep expertise in Headscale's integration testing system. You understand the Docker-based test infrastructure, real Tailscale client interactions, and the complex timing considerations involved in end-to-end network testing. + +## Integration Test System Overview + +The Headscale integration test system uses Docker containers running real Tailscale clients against a Headscale server. Tests validate end-to-end functionality including routing, ACLs, node lifecycle, and network coordination. The system is built around the `hi` (Headscale Integration) test runner in `cmd/hi/`. 
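+
+A typical session follows a verify → run → inspect loop. The sketch below is illustrative only; the test name is just an example used elsewhere in this guide, and the timestamped artifact directory is a placeholder (each run creates its own):
+
+```bash
+# 1. Verify Docker, Go, container images, and disk space first
+go run ./cmd/hi doctor
+
+# 2. Run a single test with the built-in --timeout flag (never bash `timeout`)
+go run ./cmd/hi run "TestSubnetRouterMultiNetwork" --timeout=900s
+
+# 3. On failure, inspect that run's artifacts, starting with server stderr
+ls control_logs/20250713-213106-iajsux/
+less control_logs/20250713-213106-iajsux/hs-*.stderr.log
+```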
+ +## Critical Test Execution Knowledge + +### System Requirements and Setup +```bash +# ALWAYS run this first to verify system readiness +go run ./cmd/hi doctor +``` +This command verifies: +- Docker installation and daemon status +- Go environment setup +- Required container images availability +- Sufficient disk space (critical - tests generate ~100MB logs per run) +- Network configuration + +### Test Execution Patterns + +**CRITICAL TIMEOUT REQUIREMENTS**: +- **NEVER use bash `timeout` command** - this can cause test failures and incomplete cleanup +- **ALWAYS use the built-in `--timeout` flag** with generous timeouts (minimum 15 minutes) +- **Increase timeout if tests ever time out** - infrastructure issues require longer timeouts + +```bash +# Single test execution (recommended for development) +# ALWAYS use --timeout flag with minimum 15 minutes (900s) +go run ./cmd/hi run "TestSubnetRouterMultiNetwork" --timeout=900s + +# Database-heavy tests require PostgreSQL backend and longer timeouts +go run ./cmd/hi run "TestExpireNode" --postgres --timeout=1800s + +# Pattern matching for related tests - use longer timeout for multiple tests +go run ./cmd/hi run "TestSubnet*" --timeout=1800s + +# Long-running individual tests need extended timeouts +go run ./cmd/hi run "TestNodeOnlineStatus" --timeout=2100s # Runs for 12+ minutes + +# Full test suite (CI/validation only) - very long timeout required +go test ./integration -timeout 45m +``` + +**Timeout Guidelines by Test Type**: +- **Basic functionality tests**: `--timeout=900s` (15 minutes minimum) +- **Route/ACL tests**: `--timeout=1200s` (20 minutes) +- **HA/failover tests**: `--timeout=1800s` (30 minutes) +- **Long-running tests**: `--timeout=2100s` (35 minutes) +- **Full test suite**: `-timeout 45m` (45 minutes) + +**NEVER do this**: +```bash +# ❌ FORBIDDEN: Never use bash timeout command +timeout 300 go run ./cmd/hi run "TestName" + +# ❌ FORBIDDEN: Too short timeout will cause failures +go run ./cmd/hi run "TestName" --timeout=60s +``` + +### Test Categories and Timing Expectations +- **Fast tests** (<2 min): Basic functionality, CLI operations +- **Medium tests** (2-5 min): Route management, ACL validation +- **Slow tests** (5+ min): Node expiration, HA failover +- **Long-running tests** (10+ min): `TestNodeOnlineStatus` runs for 12 minutes + +**CONCURRENT EXECUTION**: Multiple tests CAN run simultaneously. Each test run gets a unique Run ID for isolation. See "Concurrent Execution and Run ID Isolation" section below. + +## Test Artifacts and Log Analysis + +### Artifact Structure +All test runs save comprehensive artifacts to `control_logs/TIMESTAMP-ID/`: +``` +control_logs/20250713-213106-iajsux/ +├── hs-testname-abc123.stderr.log # Headscale server error logs +├── hs-testname-abc123.stdout.log # Headscale server output logs +├── hs-testname-abc123.db # Database snapshot for post-mortem +├── hs-testname-abc123_metrics.txt # Prometheus metrics dump +├── hs-testname-abc123-mapresponses/ # Protocol-level debug data +├── ts-client-xyz789.stderr.log # Tailscale client error logs +├── ts-client-xyz789.stdout.log # Tailscale client output logs +└── ts-client-xyz789_status.json # Client network status dump +``` + +### Log Analysis Priority Order +When tests fail, examine artifacts in this specific order: + +1. **Headscale server stderr logs** (`hs-*.stderr.log`): Look for errors, panics, database issues, policy evaluation failures +2. 
**Tailscale client stderr logs** (`ts-*.stderr.log`): Check for authentication failures, network connectivity issues +3. **MapResponse JSON files**: Protocol-level debugging for network map generation issues +4. **Client status dumps** (`*_status.json`): Network state and peer connectivity information +5. **Database snapshots** (`.db` files): For data consistency and state persistence issues + +## Concurrent Execution and Run ID Isolation + +### Overview + +The integration test system supports running multiple tests concurrently on the same Docker daemon. Each test run is isolated through a unique Run ID that ensures containers, networks, and cleanup operations don't interfere with each other. + +### Run ID Format and Usage + +Each test run generates a unique Run ID in the format: `YYYYMMDD-HHMMSS-{6-char-hash}` +- Example: `20260109-104215-mdjtzx` + +The Run ID is used for: +- **Container naming**: `ts-{runIDShort}-{version}-{hash}` (e.g., `ts-mdjtzx-1-74-fgdyls`) +- **Docker labels**: All containers get `hi.run-id={runID}` label +- **Log directories**: `control_logs/{runID}/` +- **Cleanup isolation**: Only containers with matching run ID are cleaned up + +### Container Isolation Mechanisms + +1. **Unique Container Names**: Each container includes the run ID for identification +2. **Docker Labels**: `hi.run-id` and `hi.test-type` labels on all containers +3. **Dynamic Port Allocation**: All ports use `{HostPort: "0"}` to let kernel assign free ports +4. **Per-Run Networks**: Network names include scenario hash for isolation +5. **Isolated Cleanup**: `killTestContainersByRunID()` only removes containers matching the run ID + +### ⚠️ CRITICAL: Never Interfere with Other Test Runs + +**FORBIDDEN OPERATIONS** when other tests may be running: + +```bash +# ❌ NEVER do global container cleanup while tests are running +docker rm -f $(docker ps -q --filter "name=hs-") +docker rm -f $(docker ps -q --filter "name=ts-") + +# ❌ NEVER kill all test containers +# This will destroy other agents' test sessions! + +# ❌ NEVER prune all Docker resources during active tests +docker system prune -f # Only safe when NO tests are running +``` + +**SAFE OPERATIONS**: + +```bash +# ✅ Clean up only YOUR test run's containers (by run ID) +# The test runner does this automatically via cleanup functions + +# ✅ Clean stale (stopped/exited) containers only +# Pre-test cleanup only removes stopped containers, not running ones + +# ✅ Check what's running before cleanup +docker ps --filter "name=headscale-test-suite" --format "{{.Names}}" +``` + +### Running Concurrent Tests + +```bash +# Start multiple tests in parallel - each gets unique run ID +go run ./cmd/hi run "TestPingAllByIP" & +go run ./cmd/hi run "TestACLAllowUserDst" & +go run ./cmd/hi run "TestOIDCAuthenticationPingAll" & + +# Monitor running test suites +docker ps --filter "name=headscale-test-suite" --format "table {{.Names}}\t{{.Status}}" +``` + +### Agent Session Isolation Rules + +When working as an agent: + +1. **Your run ID is unique**: Each test you start gets its own run ID +2. **Never clean up globally**: Only use run ID-specific cleanup +3. **Check before cleanup**: Verify no other tests are running if you need to prune resources +4. **Respect other sessions**: Other agents may have tests running concurrently +5. 
**Log directories are isolated**: Your artifacts are in `control_logs/{your-run-id}/` + +### Identifying Your Containers + +Your test containers can be identified by: +- The run ID in the container name +- The `hi.run-id` Docker label +- The test suite container: `headscale-test-suite-{your-run-id}` + +```bash +# List containers for a specific run ID +docker ps --filter "label=hi.run-id=20260109-104215-mdjtzx" + +# Get your run ID from the test output +# Look for: "Run ID: 20260109-104215-mdjtzx" +``` + +## Common Failure Patterns and Root Cause Analysis + +### CRITICAL MINDSET: Code Issues vs Infrastructure Issues + +**⚠️ IMPORTANT**: When tests fail, it is ALMOST ALWAYS a code issue with Headscale, NOT infrastructure problems. Do not immediately blame disk space, Docker issues, or timing unless you have thoroughly investigated the actual error logs first. + +### Systematic Debugging Process + +1. **Read the actual error message**: Don't assume - read the stderr logs completely +2. **Check Headscale server logs first**: Most issues originate from server-side logic +3. **Verify client connectivity**: Only after ruling out server issues +4. **Check timing patterns**: Use proper `EventuallyWithT` patterns +5. **Infrastructure as last resort**: Only blame infrastructure after code analysis + +### Real Failure Patterns + +#### 1. Timing Issues (Common but fixable) +```go +// ❌ Wrong: Immediate assertions after async operations +client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"}) +nodes, _ := headscale.ListNodes() +require.Len(t, nodes[0].GetAvailableRoutes(), 1) // WILL FAIL + +// ✅ Correct: Wait for async operations +client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"}) +require.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes[0].GetAvailableRoutes(), 1) +}, 10*time.Second, 100*time.Millisecond, "route should be advertised") +``` + +**Timeout Guidelines**: +- Route operations: 3-5 seconds +- Node state changes: 5-10 seconds +- Complex scenarios: 10-15 seconds +- Policy recalculation: 5-10 seconds + +#### 2. NodeStore Synchronization Issues +Route advertisements must propagate through poll requests (`poll.go:420`). NodeStore updates happen at specific synchronization points after Hostinfo changes. + +#### 3. Test Data Management Issues +```go +// ❌ Wrong: Assuming array ordering +require.Len(t, nodes[0].GetAvailableRoutes(), 1) + +// ✅ Correct: Identify nodes by properties +expectedRoutes := map[string]string{"1": "10.33.0.0/16"} +for _, node := range nodes { + nodeIDStr := fmt.Sprintf("%d", node.GetId()) + if route, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute { + // Test the specific node that should have the route + } +} +``` + +#### 4. 
Database Backend Differences +SQLite vs PostgreSQL have different timing characteristics: +- Use `--postgres` flag for database-intensive tests +- PostgreSQL generally has more consistent timing +- Some race conditions only appear with specific backends + +## Resource Management and Cleanup + +### Disk Space Management +Tests consume significant disk space (~100MB per run): +```bash +# Check available space before running tests +df -h + +# Clean up test artifacts periodically +rm -rf control_logs/older-timestamp-dirs/ + +# Clean Docker resources +docker system prune -f +docker volume prune -f +``` + +### Container Cleanup +- Successful tests clean up automatically +- Failed tests may leave containers running +- Manually clean if needed: `docker ps -a` and `docker rm -f ` + +## Advanced Debugging Techniques + +### Protocol-Level Debugging +MapResponse JSON files in `control_logs/*/hs-*-mapresponses/` contain: +- Network topology as sent to clients +- Peer relationships and visibility +- Route distribution and primary route selection +- Policy evaluation results + +### Database State Analysis +Use the database snapshots for post-mortem analysis: +```bash +# SQLite examination +sqlite3 control_logs/TIMESTAMP/hs-*.db +.tables +.schema nodes +SELECT * FROM nodes WHERE name LIKE '%problematic%'; +``` + +### Performance Analysis +Prometheus metrics dumps show: +- Request latencies and error rates +- NodeStore operation timing +- Database query performance +- Memory usage patterns + +## Test Development and Quality Guidelines + +### Proper Test Patterns +```go +// Always use EventuallyWithT for async operations +require.EventuallyWithT(t, func(c *assert.CollectT) { + // Test condition that may take time to become true +}, timeout, interval, "descriptive failure message") + +// Handle node identification correctly +var targetNode *v1.Node +for _, node := range nodes { + if node.GetName() == expectedNodeName { + targetNode = node + break + } +} +require.NotNil(t, targetNode, "should find expected node") +``` + +### Quality Validation Checklist +- ✅ Tests use `EventuallyWithT` for asynchronous operations +- ✅ Tests don't rely on array ordering for node identification +- ✅ Proper cleanup and resource management +- ✅ Tests handle both success and failure scenarios +- ✅ Timing assumptions are realistic for operations being tested +- ✅ Error messages are descriptive and actionable + +## Real-World Test Failure Patterns from HA Debugging + +### Infrastructure vs Code Issues - Detailed Examples + +**INFRASTRUCTURE FAILURES (Rare but Real)**: +1. **DNS Resolution in Auth Tests**: `failed to resolve "hs-pingallbyip-jax97k": no DNS fallback candidates remain` + - **Pattern**: Client containers can't resolve headscale server hostname during logout + - **Detection**: Error messages specifically mention DNS/hostname resolution + - **Solution**: Docker networking reset, not code changes + +2. **Container Creation Timeouts**: Test gets stuck during client container setup + - **Pattern**: Tests hang indefinitely at container startup phase + - **Detection**: No progress in logs for >2 minutes during initialization + - **Solution**: `docker system prune -f` and retry + +3. 
**Docker Resource Exhaustion**: Too many concurrent tests overwhelming system + - **Pattern**: Container creation timeouts, OOM kills, slow test execution + - **Detection**: System load high, Docker daemon slow to respond + - **Solution**: Reduce number of concurrent tests, wait for completion before starting more + +**CODE ISSUES (99% of failures)**: +1. **Route Approval Process Failures**: Routes not getting approved when they should be + - **Pattern**: Tests expecting approved routes but finding none + - **Detection**: `SubnetRoutes()` returns empty when `AnnouncedRoutes()` shows routes + - **Root Cause**: Auto-approval logic bugs, policy evaluation issues + +2. **NodeStore Synchronization Issues**: State updates not propagating correctly + - **Pattern**: Route changes not reflected in NodeStore or Primary Routes + - **Detection**: Logs show route announcements but no tracking updates + - **Root Cause**: Missing synchronization points in `poll.go:420` area + +3. **HA Failover Architecture Issues**: Routes removed when nodes go offline + - **Pattern**: `TestHASubnetRouterFailover` fails because approved routes disappear + - **Detection**: Routes available on online nodes but lost when nodes disconnect + - **Root Cause**: Conflating route approval with node connectivity + +### Critical Test Environment Setup + +**Pre-Test Cleanup**: + +The test runner automatically handles cleanup: +- **Before test**: Removes only stale (stopped/exited) containers - does NOT affect running tests +- **After test**: Removes only containers belonging to the specific run ID + +```bash +# Only clean old log directories if disk space is low +rm -rf control_logs/202507* +df -h # Verify sufficient disk space + +# SAFE: Clean only stale/stopped containers (does not affect running tests) +# The test runner does this automatically via cleanupStaleTestContainers() + +# ⚠️ DANGEROUS: Only use when NO tests are running +docker system prune -f +``` + +**Environment Verification**: +```bash +# Verify system readiness +go run ./cmd/hi doctor + +# Check what tests are currently running (ALWAYS check before global cleanup) +docker ps --filter "name=headscale-test-suite" --format "{{.Names}}" +``` + +### Specific Test Categories and Known Issues + +#### Route-Related Tests (Primary Focus) +```bash +# Core route functionality - these should work first +# Note: Generous timeouts are required for reliable execution +go run ./cmd/hi run "TestSubnetRouteACL" --timeout=1200s +go run ./cmd/hi run "TestAutoApproveMultiNetwork" --timeout=1800s +go run ./cmd/hi run "TestHASubnetRouterFailover" --timeout=1800s +``` + +**Common Route Test Patterns**: +- Tests validate route announcement, approval, and distribution workflows +- Route state changes are asynchronous - may need `EventuallyWithT` wrappers +- Route approval must respect ACL policies - test expectations encode security requirements +- HA tests verify route persistence during node connectivity changes + +#### Authentication Tests (Infrastructure-Prone) +```bash +# These tests are more prone to infrastructure issues +# Require longer timeouts due to auth flow complexity +go run ./cmd/hi run "TestAuthKeyLogoutAndReloginSameUser" --timeout=1200s +go run ./cmd/hi run "TestAuthWebFlowLogoutAndRelogin" --timeout=1200s +go run ./cmd/hi run "TestOIDCExpireNodesBasedOnTokenExpiry" --timeout=1800s +``` + +**Common Auth Test Infrastructure Failures**: +- DNS resolution during logout operations +- Container creation timeouts +- HTTP/2 stream errors (often symptoms, not root cause) + +### 
Security-Critical Debugging Rules + +**❌ FORBIDDEN CHANGES (Security & Test Integrity)**: +1. **Never change expected test outputs** - Tests define correct behavior contracts + - Changing `require.Len(t, routes, 3)` to `require.Len(t, routes, 2)` because test fails + - Modifying expected status codes, node counts, or route counts + - Removing assertions that are "inconvenient" + - **Why forbidden**: Test expectations encode business requirements and security policies + +2. **Never bypass security mechanisms** - Security must never be compromised for convenience + - Using `AnnouncedRoutes()` instead of `SubnetRoutes()` in production code + - Skipping authentication or authorization checks + - **Why forbidden**: Security bypasses create vulnerabilities in production + +3. **Never reduce test coverage** - Tests prevent regressions + - Removing test cases or assertions + - Commenting out "problematic" test sections + - **Why forbidden**: Reduced coverage allows bugs to slip through + +**✅ ALLOWED CHANGES (Timing & Observability)**: +1. **Fix timing issues with proper async patterns** + ```go + // ✅ GOOD: Add EventuallyWithT for async operations + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, expectedCount) // Keep original expectation + }, 10*time.Second, 100*time.Millisecond, "nodes should reach expected count") + ``` + - **Why allowed**: Fixes race conditions without changing business logic + +2. **Add MORE observability and debugging** + - Additional logging statements + - More detailed error messages + - Extra assertions that verify intermediate states + - **Why allowed**: Better observability helps debug without changing behavior + +3. **Improve test documentation** + - Add godoc comments explaining test purpose and business logic + - Document timing requirements and async behavior + - **Why encouraged**: Helps future maintainers understand intent + +### Advanced Debugging Workflows + +#### Route Tracking Debug Flow +```bash +# Run test with detailed logging and proper timeout +go run ./cmd/hi run "TestSubnetRouteACL" --timeout=1200s > test_output.log 2>&1 + +# Check route approval process +grep -E "(auto-approval|ApproveRoutesWithPolicy|PolicyManager)" test_output.log + +# Check route tracking +tail -50 control_logs/*/hs-*.stderr.log | grep -E "(announced|tracking|SetNodeRoutes)" + +# Check for security violations +grep -E "(AnnouncedRoutes.*SetNodeRoutes|bypass.*approval)" test_output.log +``` + +#### HA Failover Debug Flow +```bash +# Test HA failover specifically with adequate timeout +go run ./cmd/hi run "TestHASubnetRouterFailover" --timeout=1800s + +# Check route persistence during disconnect +grep -E "(Disconnect|NodeWentOffline|PrimaryRoutes)" control_logs/*/hs-*.stderr.log + +# Verify routes don't disappear inappropriately +grep -E "(removing.*routes|SetNodeRoutes.*empty)" control_logs/*/hs-*.stderr.log +``` + +### Test Result Interpretation Guidelines + +#### Success Patterns to Look For +- `"updating node routes for tracking"` in logs +- Routes appearing in `announcedRoutes` logs +- Proper `ApproveRoutesWithPolicy` calls for auto-approval +- Routes persisting through node connectivity changes (HA tests) + +#### Failure Patterns to Investigate +- `SubnetRoutes()` returning empty when `AnnouncedRoutes()` has routes +- Routes disappearing when nodes go offline (HA architectural issue) +- Missing `EventuallyWithT` causing timing race conditions +- Security bypass attempts using wrong 
route methods + +### Critical Testing Methodology + +**Phase-Based Testing Approach**: +1. **Phase 1**: Core route tests (ACL, auto-approval, basic functionality) +2. **Phase 2**: HA and complex route scenarios +3. **Phase 3**: Auth tests (infrastructure-sensitive, test last) + +**Per-Test Process**: +1. Clean environment before each test +2. Monitor logs for route tracking and approval messages +3. Check artifacts in `control_logs/` if test fails +4. Focus on actual error messages, not assumptions +5. Document results and patterns discovered + +## Test Documentation and Code Quality Standards + +### Adding Missing Test Documentation +When you understand a test's purpose through debugging, always add comprehensive godoc: + +```go +// TestSubnetRoutes validates the complete subnet route lifecycle including +// advertisement from clients, policy-based approval, and distribution to peers. +// This test ensures that route security policies are properly enforced and that +// only approved routes are distributed to the network. +// +// The test verifies: +// - Route announcements are received and tracked +// - ACL policies control route approval correctly +// - Only approved routes appear in peer network maps +// - Route state persists correctly in the database +func TestSubnetRoutes(t *testing.T) { + // Test implementation... +} +``` + +**Why add documentation**: Future maintainers need to understand business logic and security requirements encoded in tests. + +### Comment Guidelines - Focus on WHY, Not WHAT + +```go +// ✅ GOOD: Explains reasoning and business logic +// Wait for route propagation because NodeStore updates are asynchronous +// and happen after poll requests complete processing +require.EventuallyWithT(t, func(c *assert.CollectT) { + // Check that security policies are enforced... +}, timeout, interval, "route approval must respect ACL policies") + +// ❌ BAD: Just describes what the code does +// Wait for routes +require.EventuallyWithT(t, func(c *assert.CollectT) { + // Get routes and check length +}, timeout, interval, "checking routes") +``` + +**Why focus on WHY**: Helps maintainers understand architectural decisions and security requirements. + +## EventuallyWithT Pattern for External Calls + +### Overview +EventuallyWithT is a testing pattern used to handle eventual consistency in distributed systems. In Headscale integration tests, many operations are asynchronous - clients advertise routes, the server processes them, updates propagate through the network. EventuallyWithT allows tests to wait for these operations to complete while making assertions. + +### External Calls That Must Be Wrapped +The following operations are **external calls** that interact with the headscale server or tailscale clients and MUST be wrapped in EventuallyWithT: +- `headscale.ListNodes()` - Queries server state +- `client.Status()` - Gets client network status +- `client.Curl()` - Makes HTTP requests through the network +- `client.Traceroute()` - Performs network diagnostics +- `client.Execute()` when running commands that query state +- Any operation that reads from the headscale server or tailscale client + +### Five Key Rules for EventuallyWithT + +1. **One External Call Per EventuallyWithT Block** + - Each EventuallyWithT should make ONE external call (e.g., ListNodes OR Status) + - Related assertions based on that single call can be grouped together + - Unrelated external calls must be in separate EventuallyWithT blocks + +2. 
**Variable Scoping** + - Declare variables that need to be shared across EventuallyWithT blocks at function scope + - Use `=` for assignment inside EventuallyWithT, not `:=` (unless the variable is only used within that block) + - Variables declared with `:=` inside EventuallyWithT are not accessible outside + +3. **No Nested EventuallyWithT** + - NEVER put an EventuallyWithT inside another EventuallyWithT + - This is a critical anti-pattern that must be avoided + +4. **Use CollectT for Assertions** + - Inside EventuallyWithT, use `assert` methods with the CollectT parameter + - Helper functions called within EventuallyWithT must accept `*assert.CollectT` + +5. **Descriptive Messages** + - Always provide a descriptive message as the last parameter + - Message should explain what condition is being waited for + +### Correct Pattern Examples + +```go +// CORRECT: Single external call with related assertions +var nodes []*v1.Node +var err error + +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + // These assertions are all based on the ListNodes() call + requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 1) +}, 10*time.Second, 500*time.Millisecond, "nodes should have expected route counts") + +// CORRECT: Separate EventuallyWithT for different external call +assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + // All these assertions are based on the single Status() call + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + requirePeerSubnetRoutesWithCollect(c, peerStatus, expectedPrefixes) + } +}, 10*time.Second, 500*time.Millisecond, "client should see expected routes") + +// CORRECT: Variable scoping for sharing between blocks +var routeNode *v1.Node +var nodeKey key.NodePublic + +// First EventuallyWithT to get the node +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + + for _, node := range nodes { + if node.GetName() == "router" { + routeNode = node + nodeKey, _ = key.ParseNodePublicUntyped(mem.S(node.GetNodeKey())) + break + } + } + assert.NotNil(c, routeNode, "should find router node") +}, 10*time.Second, 100*time.Millisecond, "router node should exist") + +// Second EventuallyWithT using the nodeKey from first block +assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + peerStatus, ok := status.Peer[nodeKey] + assert.True(c, ok, "peer should exist in status") + requirePeerSubnetRoutesWithCollect(c, peerStatus, expectedPrefixes) +}, 10*time.Second, 100*time.Millisecond, "routes should be visible to client") +``` + +### Incorrect Patterns to Avoid + +```go +// INCORRECT: Multiple unrelated external calls in same EventuallyWithT +assert.EventuallyWithT(t, func(c *assert.CollectT) { + // First external call + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + + // Second unrelated external call - WRONG! + status, err := client.Status() + assert.NoError(c, err) + assert.NotNil(c, status) +}, 10*time.Second, 500*time.Millisecond, "mixed operations") + +// INCORRECT: Nested EventuallyWithT +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + + // NEVER do this! 
+ assert.EventuallyWithT(t, func(c2 *assert.CollectT) { + status, _ := client.Status() + assert.NotNil(c2, status) + }, 5*time.Second, 100*time.Millisecond, "nested") +}, 10*time.Second, 500*time.Millisecond, "outer") + +// INCORRECT: Variable scoping error +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() // This shadows outer 'nodes' variable + assert.NoError(c, err) +}, 10*time.Second, 500*time.Millisecond, "get nodes") + +// This will fail - nodes is nil because := created a new variable inside the block +require.Len(t, nodes, 2) // COMPILATION ERROR or nil pointer + +// INCORRECT: Not wrapping external calls +nodes, err := headscale.ListNodes() // External call not wrapped! +require.NoError(t, err) +``` + +### Helper Functions for EventuallyWithT + +When creating helper functions for use within EventuallyWithT: + +```go +// Helper function that accepts CollectT +func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, available, approved, primary int) { + assert.Len(c, node.GetAvailableRoutes(), available, "available routes for node %s", node.GetName()) + assert.Len(c, node.GetApprovedRoutes(), approved, "approved routes for node %s", node.GetName()) + assert.Len(c, node.GetPrimaryRoutes(), primary, "primary routes for node %s", node.GetName()) +} + +// Usage within EventuallyWithT +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2) +}, 10*time.Second, 500*time.Millisecond, "route counts should match expected") +``` + +### Operations That Must NOT Be Wrapped + +**CRITICAL**: The following operations are **blocking/mutating operations** that change state and MUST NOT be wrapped in EventuallyWithT: +- `tailscale set` commands (e.g., `--advertise-routes`, `--accept-routes`) +- `headscale.ApproveRoute()` - Approves routes on server +- `headscale.CreateUser()` - Creates users +- `headscale.CreatePreAuthKey()` - Creates authentication keys +- `headscale.RegisterNode()` - Registers new nodes +- Any `client.Execute()` that modifies configuration +- Any operation that creates, updates, or deletes resources + +These operations: +1. Complete synchronously or fail immediately +2. Should not be retried automatically +3. Need explicit error handling with `require.NoError()` + +### Correct Pattern for Blocking Operations + +```go +// CORRECT: Blocking operation NOT wrapped +status := client.MustStatus() +command := []string{"tailscale", "set", "--advertise-routes=" + expectedRoutes[string(status.Self.ID)]} +_, _, err = client.Execute(command) +require.NoErrorf(t, err, "failed to advertise route: %s", err) + +// Then wait for the result with EventuallyWithT +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Contains(c, nodes[0].GetAvailableRoutes(), expectedRoutes[string(status.Self.ID)]) +}, 10*time.Second, 100*time.Millisecond, "route should be advertised") + +// INCORRECT: Blocking operation wrapped (DON'T DO THIS) +assert.EventuallyWithT(t, func(c *assert.CollectT) { + _, _, err = client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"}) + assert.NoError(c, err) // This might retry the command multiple times! 
+}, 10*time.Second, 100*time.Millisecond, "advertise routes") +``` + +### Assert vs Require Pattern + +When working within EventuallyWithT blocks where you need to prevent panics: + +```go +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + + // For array bounds - use require with t to prevent panic + assert.Len(c, nodes, 6) // Test expectation + require.GreaterOrEqual(t, len(nodes), 3, "need at least 3 nodes to avoid panic") + + // For nil pointer access - use require with t before dereferencing + assert.NotNil(c, srs1PeerStatus.PrimaryRoutes) // Test expectation + require.NotNil(t, srs1PeerStatus.PrimaryRoutes, "primary routes must be set to avoid panic") + assert.Contains(c, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) +}, 5*time.Second, 200*time.Millisecond, "checking route state") +``` + +**Key Principle**: +- Use `assert` with `c` (*assert.CollectT) for test expectations that can be retried +- Use `require` with `t` (*testing.T) for MUST conditions that prevent panics +- Within EventuallyWithT, both are available - choose based on whether failure would cause a panic + +### Common Scenarios + +1. **Waiting for route advertisement**: +```go +client.Execute([]string{"tailscale", "set", "--advertise-routes=10.0.0.0/24"}) + +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Contains(c, nodes[0].GetAvailableRoutes(), "10.0.0.0/24") +}, 10*time.Second, 100*time.Millisecond, "route should be advertised") +``` + +2. **Checking client sees routes**: +```go +assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + // Check all peers have expected routes + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + assert.Contains(c, peerStatus.AllowedIPs, expectedPrefix) + } +}, 10*time.Second, 100*time.Millisecond, "all peers should see route") +``` + +3. **Sequential operations**: +```go +// First wait for node to appear +var nodeID uint64 +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 1) + nodeID = nodes[0].GetId() +}, 10*time.Second, 100*time.Millisecond, "node should register") + +// Then perform operation +_, err := headscale.ApproveRoute(nodeID, "10.0.0.0/24") +require.NoError(t, err) + +// Then wait for result +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Contains(c, nodes[0].GetApprovedRoutes(), "10.0.0.0/24") +}, 10*time.Second, 100*time.Millisecond, "route should be approved") +``` + +## Your Core Responsibilities + +1. **Test Execution Strategy**: Execute integration tests with appropriate configurations, understanding when to use `--postgres` and timing requirements for different test categories. Follow phase-based testing approach prioritizing route tests. + - **Why this priority**: Route tests are less infrastructure-sensitive and validate core security logic + +2. **Systematic Test Analysis**: When tests fail, systematically examine artifacts starting with Headscale server logs, then client logs, then protocol data. Focus on CODE ISSUES first (99% of cases), not infrastructure. Use real-world failure patterns to guide investigation. + - **Why this approach**: Most failures are logic bugs, not environment issues - efficient debugging saves time + +3. 
**Timing & Synchronization Expertise**: Understand asynchronous Headscale operations, particularly route advertisements, NodeStore synchronization at `poll.go:420`, and policy propagation. Fix timing with `EventuallyWithT` while preserving original test expectations. + - **Why preserve expectations**: Test assertions encode business requirements and security policies + - **Key Pattern**: Apply the EventuallyWithT pattern correctly for all external calls as documented above + +4. **Root Cause Analysis**: Distinguish between actual code regressions (route approval logic, HA failover architecture), timing issues requiring `EventuallyWithT` patterns, and genuine infrastructure problems (DNS, Docker, container issues). + - **Why this distinction matters**: Different problem types require completely different solution approaches + - **EventuallyWithT Issues**: Often manifest as flaky tests or immediate assertion failures after async operations + +5. **Security-Aware Quality Validation**: Ensure tests properly validate end-to-end functionality with realistic timing expectations and proper error handling. Never suggest security bypasses or test expectation changes. Add comprehensive godoc when you understand test business logic. + - **Why security focus**: Integration tests are the last line of defense against security regressions + - **EventuallyWithT Usage**: Proper use prevents race conditions without weakening security assertions + +6. **Concurrent Execution Awareness**: Respect run ID isolation and never interfere with other agents' test sessions. Each test run has a unique run ID - only clean up YOUR containers (by run ID label), never perform global cleanup while tests may be running. + - **Why this matters**: Multiple agents/users may run tests concurrently on the same Docker daemon + - **Key Rule**: NEVER use global container cleanup commands - the test runner handles cleanup automatically per run ID + +**CRITICAL PRINCIPLE**: Test expectations are sacred contracts that define correct system behavior. When tests fail, fix the code to match the test, never change the test to match broken code. Only timing and observability improvements are allowed - business logic expectations are immutable. + +**ISOLATION PRINCIPLE**: Each test run is isolated by its unique Run ID. Never interfere with other test sessions. The system handles cleanup automatically - manual global cleanup commands are forbidden when other tests may be running. + +**EventuallyWithT PRINCIPLE**: Every external call to headscale server or tailscale client must be wrapped in EventuallyWithT. Follow the five key rules strictly: one external call per block, proper variable scoping, no nesting, use CollectT for assertions, and provide descriptive messages. + +**Remember**: Test failures are usually code issues in Headscale that need to be fixed, not infrastructure problems to be ignored. Use the specific debugging workflows and failure patterns documented above to efficiently identify root causes. Infrastructure issues have very specific signatures - everything else is code-related. 
diff --git a/.dockerignore b/.dockerignore index e3acf996..9ea3e4a4 100644 --- a/.dockerignore +++ b/.dockerignore @@ -17,3 +17,7 @@ LICENSE .vscode *.sock + +node_modules/ +package-lock.json +package.json diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 00000000..d91a81d8 --- /dev/null +++ b/.editorconfig @@ -0,0 +1,16 @@ +root = true + +[*] +charset = utf-8 +end_of_line = lf +indent_size = 2 +indent_style = space +insert_final_newline = true +trim_trailing_whitespace = true +max_line_length = 120 + +[*.go] +indent_style = tab + +[Makefile] +indent_style = tab diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index fa1c06da..4eb9c2d2 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -1,10 +1,10 @@ * @juanfont @kradalby -*.md @ohdearaugustin -*.yml @ohdearaugustin -*.yaml @ohdearaugustin -Dockerfile* @ohdearaugustin -.goreleaser.yaml @ohdearaugustin -/docs/ @ohdearaugustin -/.github/workflows/ @ohdearaugustin -/.github/renovate.json @ohdearaugustin +*.md @ohdearaugustin @nblock +*.yml @ohdearaugustin @nblock +*.yaml @ohdearaugustin @nblock +Dockerfile* @ohdearaugustin @nblock +.goreleaser.yaml @ohdearaugustin @nblock +/docs/ @ohdearaugustin @nblock +/.github/workflows/ @ohdearaugustin @nblock +/.github/renovate.json @ohdearaugustin @nblock diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md deleted file mode 100644 index 02e47425..00000000 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ /dev/null @@ -1,52 +0,0 @@ ---- -name: "Bug report" -about: "Create a bug report to help us improve" -title: "" -labels: ["bug"] -assignees: "" ---- - - - -## Bug description - - - -## Environment - - - -- OS: -- Headscale version: -- Tailscale version: - - - -- [ ] Headscale is behind a (reverse) proxy -- [ ] Headscale runs in a container - -## To Reproduce - - diff --git a/.github/ISSUE_TEMPLATE/bug_report.yaml b/.github/ISSUE_TEMPLATE/bug_report.yaml new file mode 100644 index 00000000..4b05f11f --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.yaml @@ -0,0 +1,108 @@ +name: 🐞 Bug +description: File a bug/issue +title: "[Bug] " +labels: ["bug", "needs triage"] +body: + - type: checkboxes + attributes: + label: Is this a support request? + description: + This issue tracker is for bugs and feature requests only. If you need + help, please use ask in our Discord community + options: + - label: This is not a support request + required: true + - type: checkboxes + attributes: + label: Is there an existing issue for this? + description: + Please search to see if an issue already exists for the bug you + encountered. + options: + - label: I have searched the existing issues + required: true + - type: textarea + attributes: + label: Current Behavior + description: A concise description of what you're experiencing. + validations: + required: true + - type: textarea + attributes: + label: Expected Behavior + description: A concise description of what you expected to happen. + validations: + required: true + - type: textarea + attributes: + label: Steps To Reproduce + description: Steps to reproduce the behavior. + placeholder: | + 1. In this environment... + 1. With this config... + 1. Run '...' + 1. See error... + validations: + required: true + - type: textarea + attributes: + label: Environment + description: | + Please provide information about your environment. + If you are using a container, always provide the headscale version and not only the Docker image version. + Please do not put "latest". + + Describe your "headscale network". 
Is there a lot of nodes, are the nodes all interconnected, are some subnet routers? + + If you are experiencing a problem during an upgrade, please provide the versions of the old and new versions of Headscale and Tailscale. + + examples: + - **OS**: Ubuntu 24.04 + - **Headscale version**: 0.24.3 + - **Tailscale version**: 1.80.0 + - **Number of nodes**: 20 + value: | + - OS: + - Headscale version: + - Tailscale version: + render: markdown + validations: + required: true + - type: checkboxes + attributes: + label: Runtime environment + options: + - label: Headscale is behind a (reverse) proxy + required: false + - label: Headscale runs in a container + required: false + - type: textarea + attributes: + label: Debug information + description: | + Please have a look at our [Debugging and troubleshooting + guide](https://headscale.net/development/ref/debug/) to learn about + common debugging techniques. + + Links? References? Anything that will give us more context about the issue you are encountering. + If **any** of these are omitted we will likely close your issue, do **not** ignore them. + + - Client netmap dump (see below) + - Policy configuration + - Headscale configuration + - Headscale log (with `trace` enabled) + + Dump the netmap of tailscale clients: + `tailscale debug netmap > DESCRIPTIVE_NAME.json` + + Dump the status of tailscale clients: + `tailscale status --json > DESCRIPTIVE_NAME.json` + + Get the logs of a Tailscale client that is not working as expected. + `tailscale debug daemon-logs` + + Tip: You can attach images or log files by clicking this area to highlight it and then dragging files in. + **Ensure** you use formatting for files you attach. + Do **not** paste in long files. + validations: + required: true diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md deleted file mode 100644 index 92c51b8f..00000000 --- a/.github/ISSUE_TEMPLATE/feature_request.md +++ /dev/null @@ -1,26 +0,0 @@ ---- -name: "Feature request" -about: "Suggest an idea for headscale" -title: "" -labels: ["enhancement"] -assignees: "" ---- - -<!-- -We typically have a clear roadmap for what we want to improve and reserve the right -to close feature requests that does not fit in the roadmap, or fit with the scope -of the project, or we actually want to implement ourselves. - -Headscale is a multinational community across the globe. Our language is English. -All bug reports needs to be in English. ---> - -## Why - -<!-- Include the reason, why you would need the feature. E.g. what problem - does it solve? Or which workflow is currently frustrating and will be improved by - this? --> - -## Description - -<!-- A clear and precise description of what new or changed feature you want. --> diff --git a/.github/ISSUE_TEMPLATE/feature_request.yaml b/.github/ISSUE_TEMPLATE/feature_request.yaml new file mode 100644 index 00000000..70f1a146 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.yaml @@ -0,0 +1,36 @@ +name: 🚀 Feature Request +description: Suggest an idea for Headscale +title: "[Feature] <title>" +labels: [enhancement] +body: + - type: textarea + attributes: + label: Use case + description: Please describe the use case for this feature. + placeholder: | + <!-- Include the reason, why you would need the feature. E.g. what problem + does it solve? Or which workflow is currently frustrating and will be improved by + this? 
--> + validations: + required: true + - type: textarea + attributes: + label: Description + description: A clear and precise description of what new or changed feature you want. + validations: + required: true + - type: checkboxes + attributes: + label: Contribution + description: Are you willing to contribute to the implementation of this feature? + options: + - label: I can write the design doc for this feature + required: false + - label: I can contribute this feature + required: false + - type: textarea + attributes: + label: How can it be implemented? + description: Free text for your ideas on how this feature could be implemented. + validations: + required: false diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index d4e4f4f9..9d8e731d 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -12,7 +12,7 @@ If you find mistakes in the documentation, please submit a fix to the documentat <!-- Please tick if the following things apply. You… --> -- [ ] read the [CONTRIBUTING guidelines](README.md#contributing) +- [ ] have read the [CONTRIBUTING.md](./CONTRIBUTING.md) file - [ ] raised a GitHub issue or discussed it on the projects chat beforehand - [ ] added unit tests - [ ] added integration tests diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 9d4b9925..594829f9 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -5,42 +5,42 @@ on: branches: - main pull_request: - branches: - - main concurrency: group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} cancel-in-progress: true jobs: - build: + build-nix: runs-on: ubuntu-latest permissions: write-all - steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 2 - - name: Get changed files id: changed-files - uses: tj-actions/changed-files@v34 + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml + filters: | + files: + - '*.nix' + - 'go.*' + - '**/*.go' + - 'integration_test/' + - 'config-example.yaml' + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + if: steps.changed-files.outputs.files == 'true' + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + if: steps.changed-files.outputs.files == 'true' + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', + '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} - - uses: DeterminateSystems/nix-installer-action@main - if: steps.changed-files.outputs.any_changed == 'true' - - uses: DeterminateSystems/magic-nix-cache-action@main - if: steps.changed-files.outputs.any_changed == 'true' - - - name: Run build + - name: Run nix build id: build - if: steps.changed-files.outputs.any_changed == 'true' + if: steps.changed-files.outputs.files == 'true' run: | nix build |& tee build-result BUILD_STATUS="${PIPESTATUS[0]}" @@ -54,7 +54,7 @@ jobs: exit $BUILD_STATUS - name: Nix gosum diverging - uses: actions/github-script@v6 + uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0 if: failure() && steps.build.outcome == 'failure' with: github-token: ${{secrets.GITHUB_TOKEN}} @@ -66,8 +66,35 @@ jobs: body: 'Nix build failed with wrong gosum, please update "vendorSha256" (${{ steps.build.outputs.OLD_HASH }}) for the "headscale" 
package in flake.nix with the new SHA: ${{ steps.build.outputs.NEW_HASH }}' }) - - uses: actions/upload-artifact@v3 - if: steps.changed-files.outputs.any_changed == 'true' + - uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + if: steps.changed-files.outputs.files == 'true' with: name: headscale-linux path: result/bin/headscale + build-cross: + runs-on: ubuntu-latest + strategy: + matrix: + env: + - "GOARCH=arm64 GOOS=linux" + - "GOARCH=amd64 GOOS=linux" + - "GOARCH=arm64 GOOS=darwin" + - "GOARCH=amd64 GOOS=darwin" + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', + '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} + + - name: Run go cross compile + env: + CGO_ENABLED: 0 + run: env ${{ matrix.env }} nix develop --command -- go build -o "headscale" + ./cmd/headscale + - uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + with: + name: "headscale-${{ matrix.env }}" + path: "headscale" diff --git a/.github/workflows/check-generated.yml b/.github/workflows/check-generated.yml new file mode 100644 index 00000000..43f1d62d --- /dev/null +++ b/.github/workflows/check-generated.yml @@ -0,0 +1,55 @@ +name: Check Generated Files + +on: + push: + branches: + - main + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + check-generated: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + fetch-depth: 2 + - name: Get changed files + id: changed-files + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 + with: + filters: | + files: + - '*.nix' + - 'go.*' + - '**/*.go' + - '**/*.proto' + - 'buf.gen.yaml' + - 'tools/**' + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + if: steps.changed-files.outputs.files == 'true' + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + if: steps.changed-files.outputs.files == 'true' + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} + + - name: Run make generate + if: steps.changed-files.outputs.files == 'true' + run: nix develop --command -- make generate + + - name: Check for uncommitted changes + if: steps.changed-files.outputs.files == 'true' + run: | + if ! git diff --exit-code; then + echo "❌ Generated files are not up to date!" + echo "Please run 'make generate' and commit the changes." + exit 1 + else + echo "✅ All generated files are up to date." 
+ fi diff --git a/.github/workflows/check-tests.yaml b/.github/workflows/check-tests.yaml new file mode 100644 index 00000000..63a18141 --- /dev/null +++ b/.github/workflows/check-tests.yaml @@ -0,0 +1,45 @@ +name: Check integration tests workflow + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + check-tests: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + fetch-depth: 2 + - name: Get changed files + id: changed-files + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 + with: + filters: | + files: + - '*.nix' + - 'go.*' + - '**/*.go' + - 'integration_test/' + - 'config-example.yaml' + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + if: steps.changed-files.outputs.files == 'true' + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + if: steps.changed-files.outputs.files == 'true' + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', + '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} + + - name: Generate and check integration tests + if: steps.changed-files.outputs.files == 'true' + run: | + nix develop --command bash -c "cd .github/workflows && go generate" + git diff --exit-code .github/workflows/test-integration.yaml + + - name: Show missing tests + if: failure() + run: | + git diff .github/workflows/test-integration.yaml diff --git a/.github/workflows/contributors.yml b/.github/workflows/contributors.yml deleted file mode 100644 index 4b05ffd2..00000000 --- a/.github/workflows/contributors.yml +++ /dev/null @@ -1,35 +0,0 @@ -name: Contributors - -on: - push: - branches: - - main - workflow_dispatch: -jobs: - add-contributors: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@v3 - - name: Delete upstream contributor branch - # Allow continue on failure to account for when the - # upstream branch is deleted or does not exist. 
- continue-on-error: true - run: git push origin --delete update-contributors - - name: Create up-to-date contributors branch - run: git checkout -B update-contributors - - name: Push empty contributors branch - run: git push origin update-contributors - - name: Switch back to main - run: git checkout main - - uses: BobAnkh/add-contributors@v0.2.2 - with: - CONTRIBUTOR: "## Contributors" - COLUMN_PER_ROW: "6" - ACCESS_TOKEN: ${{secrets.GITHUB_TOKEN}} - IMG_WIDTH: "100" - FONT_SIZE: "14" - PATH: "/README.md" - COMMIT_MESSAGE: "docs(README): update contributors" - AVATAR_SHAPE: "round" - BRANCH: "update-contributors" - PULL_REQUEST: "main" diff --git a/.github/workflows/docs-deploy.yml b/.github/workflows/docs-deploy.yml new file mode 100644 index 00000000..0a8be5c1 --- /dev/null +++ b/.github/workflows/docs-deploy.yml @@ -0,0 +1,51 @@ +name: Deploy docs + +on: + push: + branches: + # Main branch for development docs + - main + + # Doc maintenance branches + - doc/[0-9]+.[0-9]+.[0-9]+ + tags: + # Stable release tags + - v[0-9]+.[0-9]+.[0-9]+ + paths: + - "docs/**" + - "mkdocs.yml" + workflow_dispatch: + +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + fetch-depth: 0 + - name: Install python + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 + with: + python-version: 3.x + - name: Setup cache + uses: actions/cache@a7833574556fa59680c1b7cb190c1735db73ebf0 # v5.0.0 + with: + key: ${{ github.ref }} + path: .cache + - name: Setup dependencies + run: pip install -r docs/requirements.txt + - name: Configure git + run: | + git config user.name github-actions + git config user.email github-actions@github.com + - name: Deploy development docs + if: github.ref == 'refs/heads/main' + run: mike deploy --push development unstable + - name: Deploy stable docs from doc branches + if: startsWith(github.ref, 'refs/heads/doc/') + run: mike deploy --push ${GITHUB_REF_NAME##*/} + - name: Deploy stable docs from tag + if: startsWith(github.ref, 'refs/tags/v') + # This assumes that only newer tags are pushed + run: mike deploy --push --update-aliases ${GITHUB_REF_NAME#v} stable latest diff --git a/.github/workflows/docs-test.yml b/.github/workflows/docs-test.yml new file mode 100644 index 00000000..cab8f95c --- /dev/null +++ b/.github/workflows/docs-test.yml @@ -0,0 +1,27 @@ +name: Test documentation build + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - name: Install python + uses: actions/setup-python@83679a892e2d95755f2dac6acb0bfd1e9ac5d548 # v6.1.0 + with: + python-version: 3.x + - name: Setup cache + uses: actions/cache@a7833574556fa59680c1b7cb190c1735db73ebf0 # v5.0.0 + with: + key: ${{ github.ref }} + path: .cache + - name: Setup dependencies + run: pip install -r docs/requirements.txt + - name: Build docs + run: mkdocs build --strict diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml deleted file mode 100644 index 1d19ed3d..00000000 --- a/.github/workflows/docs.yml +++ /dev/null @@ -1,45 +0,0 @@ -name: Build documentation -on: - push: - branches: - - main - workflow_dispatch: - -permissions: - contents: read - pages: write - id-token: write - -jobs: - build: - runs-on: ubuntu-latest - steps: - - name: 
Checkout repository - uses: actions/checkout@v3 - - name: Install python - uses: actions/setup-python@v4 - with: - python-version: 3.x - - name: Setup cache - uses: actions/cache@v2 - with: - key: ${{ github.ref }} - path: .cache - - name: Setup dependencies - run: pip install -r docs/requirements.txt - - name: Build docs - run: mkdocs build --strict - - name: Upload artifact - uses: actions/upload-pages-artifact@v1 - with: - path: ./site - deploy: - environment: - name: github-pages - url: ${{ steps.deployment.outputs.page_url }} - runs-on: ubuntu-latest - needs: build - steps: - - name: Deploy to GitHub Pages - id: deployment - uses: actions/deploy-pages@v1 diff --git a/.github/workflows/gh-action-integration-generator.go b/.github/workflows/gh-action-integration-generator.go new file mode 100644 index 00000000..c0a3d6aa --- /dev/null +++ b/.github/workflows/gh-action-integration-generator.go @@ -0,0 +1,143 @@ +package main + +//go:generate go run ./gh-action-integration-generator.go + +import ( + "bytes" + "fmt" + "log" + "os/exec" + "strings" +) + +// testsToSplit defines tests that should be split into multiple CI jobs. +// Key is the test function name, value is a list of subtest prefixes. +// Each prefix becomes a separate CI job as "TestName/prefix". +// +// Example: TestAutoApproveMultiNetwork has subtests like: +// - TestAutoApproveMultiNetwork/authkey-tag-advertiseduringup-false-pol-database +// - TestAutoApproveMultiNetwork/webauth-user-advertiseduringup-true-pol-file +// +// Splitting by approver type (tag, user, group) creates 6 CI jobs with 4 tests each: +// - TestAutoApproveMultiNetwork/authkey-tag.* (4 tests) +// - TestAutoApproveMultiNetwork/authkey-user.* (4 tests) +// - TestAutoApproveMultiNetwork/authkey-group.* (4 tests) +// - TestAutoApproveMultiNetwork/webauth-tag.* (4 tests) +// - TestAutoApproveMultiNetwork/webauth-user.* (4 tests) +// - TestAutoApproveMultiNetwork/webauth-group.* (4 tests) +// +// This reduces load per CI job (4 tests instead of 12) to avoid infrastructure +// flakiness when running many sequential Docker-based integration tests. +var testsToSplit = map[string][]string{ + "TestAutoApproveMultiNetwork": { + "authkey-tag", + "authkey-user", + "authkey-group", + "webauth-tag", + "webauth-user", + "webauth-group", + }, +} + +// expandTests takes a list of test names and expands any that need splitting +// into multiple subtest patterns. +func expandTests(tests []string) []string { + var expanded []string + for _, test := range tests { + if prefixes, ok := testsToSplit[test]; ok { + // This test should be split into multiple jobs. + // We append ".*" to each prefix because the CI runner wraps patterns + // with ^...$ anchors. Without ".*", a pattern like "authkey$" wouldn't + // match "authkey-tag-advertiseduringup-false-pol-database". + for _, prefix := range prefixes { + expanded = append(expanded, fmt.Sprintf("%s/%s.*", test, prefix)) + } + } else { + expanded = append(expanded, test) + } + } + return expanded +} + +func findTests() []string { + rgBin, err := exec.LookPath("rg") + if err != nil { + log.Fatalf("failed to find rg (ripgrep) binary") + } + + args := []string{ + "--regexp", "func (Test.+)\\(.*", + "../../integration/", + "--replace", "$1", + "--sort", "path", + "--no-line-number", + "--no-filename", + "--no-heading", + } + + cmd := exec.Command(rgBin, args...) 
+ var out bytes.Buffer + cmd.Stdout = &out + err = cmd.Run() + if err != nil { + log.Fatalf("failed to run command: %s", err) + } + + tests := strings.Split(strings.TrimSpace(out.String()), "\n") + return tests +} + +func updateYAML(tests []string, jobName string, testPath string) { + testsForYq := fmt.Sprintf("[%s]", strings.Join(tests, ", ")) + + yqCommand := fmt.Sprintf( + "yq eval '.jobs.%s.strategy.matrix.test = %s' %s -i", + jobName, + testsForYq, + testPath, + ) + cmd := exec.Command("bash", "-c", yqCommand) + + var stdout bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + err := cmd.Run() + if err != nil { + log.Printf("stdout: %s", stdout.String()) + log.Printf("stderr: %s", stderr.String()) + log.Fatalf("failed to run yq command: %s", err) + } + + fmt.Printf("YAML file (%s) job %s updated successfully\n", testPath, jobName) +} + +func main() { + tests := findTests() + + // Expand tests that should be split into multiple jobs + expandedTests := expandTests(tests) + + quotedTests := make([]string, len(expandedTests)) + for i, test := range expandedTests { + quotedTests[i] = fmt.Sprintf("\"%s\"", test) + } + + // Define selected tests for PostgreSQL + postgresTestNames := []string{ + "TestACLAllowUserDst", + "TestPingAllByIP", + "TestEphemeral2006DeletedTooQuickly", + "TestPingAllByIPManyUpDown", + "TestSubnetRouterMultiNetwork", + } + + quotedPostgresTests := make([]string, len(postgresTestNames)) + for i, test := range postgresTestNames { + quotedPostgresTests[i] = fmt.Sprintf("\"%s\"", test) + } + + // Update both SQLite and PostgreSQL job matrices + updateYAML(quotedTests, "sqlite", "./test-integration.yaml") + updateYAML(quotedPostgresTests, "postgres", "./test-integration.yaml") +} diff --git a/.github/workflows/gh-actions-updater.yaml b/.github/workflows/gh-actions-updater.yaml index 6b44051a..647e27dc 100644 --- a/.github/workflows/gh-actions-updater.yaml +++ b/.github/workflows/gh-actions-updater.yaml @@ -1,6 +1,5 @@ name: GitHub Actions Version Updater -# Controls when the action will run. on: schedule: # Automatically run on every Sunday @@ -8,16 +7,17 @@ on: jobs: build: + if: github.repository == 'juanfont/headscale' runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: # [Required] Access token with `workflow` scope. token: ${{ secrets.WORKFLOW_SECRET }} - name: Run GitHub Actions Version Updater - uses: saadmk11/github-actions-version-updater@v0.7.1 + uses: saadmk11/github-actions-version-updater@d8781caf11d11168579c8e5e94f62b068038f442 # v0.9.0 with: # [Required] Access token with `workflow` scope. token: ${{ secrets.WORKFLOW_SECRET }} diff --git a/.github/workflows/integration-test-template.yml b/.github/workflows/integration-test-template.yml new file mode 100644 index 00000000..0a884814 --- /dev/null +++ b/.github/workflows/integration-test-template.yml @@ -0,0 +1,112 @@ +name: Integration Test Template + +on: + workflow_call: + inputs: + test: + required: true + type: string + postgres_flag: + required: false + type: string + default: "" + database_name: + required: true + type: string + +jobs: + test: + runs-on: ubuntu-latest + env: + # Github does not allow us to access secrets in pull requests, + # so this env var is used to check if we have the secret or not. + # If we have the secrets, meaning we are running on push in a fork, + # there might be secrets available for more debugging. 
+ # If TS_OAUTH_CLIENT_ID and TS_OAUTH_SECRET is set, then the job + # will join a debug tailscale network, set up SSH and a tmux session. + # The SSH will be configured to use the SSH key of the Github user + # that triggered the build. + HAS_TAILSCALE_SECRET: ${{ secrets.TS_OAUTH_CLIENT_ID }} + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + fetch-depth: 2 + - name: Tailscale + if: ${{ env.HAS_TAILSCALE_SECRET }} + uses: tailscale/github-action@a392da0a182bba0e9613b6243ebd69529b1878aa # v4.1.0 + with: + oauth-client-id: ${{ secrets.TS_OAUTH_CLIENT_ID }} + oauth-secret: ${{ secrets.TS_OAUTH_SECRET }} + tags: tag:gh + - name: Setup SSH server for Actor + if: ${{ env.HAS_TAILSCALE_SECRET }} + uses: alexellis/setup-sshd-actor@master + - name: Download headscale image + uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 + with: + name: headscale-image + path: /tmp/artifacts + - name: Download tailscale HEAD image + uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 + with: + name: tailscale-head-image + path: /tmp/artifacts + - name: Download hi binary + uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 + with: + name: hi-binary + path: /tmp/artifacts + - name: Download Go cache + uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 + with: + name: go-cache + path: /tmp/artifacts + - name: Download postgres image + if: ${{ inputs.postgres_flag == '--postgres=1' }} + uses: actions/download-artifact@018cc2cf5baa6db3ef3c5f8a56943fffe632ef53 # v6.0.0 + with: + name: postgres-image + path: /tmp/artifacts + - name: Load Docker images, Go cache, and prepare binary + run: | + gunzip -c /tmp/artifacts/headscale-image.tar.gz | docker load + gunzip -c /tmp/artifacts/tailscale-head-image.tar.gz | docker load + if [ -f /tmp/artifacts/postgres-image.tar.gz ]; then + gunzip -c /tmp/artifacts/postgres-image.tar.gz | docker load + fi + chmod +x /tmp/artifacts/hi + docker images + # Extract Go cache to host directories for bind mounting + mkdir -p /tmp/go-cache + tar -xzf /tmp/artifacts/go-cache.tar.gz -C /tmp/go-cache + ls -la /tmp/go-cache/ /tmp/go-cache/.cache/ + - name: Run Integration Test + env: + HEADSCALE_INTEGRATION_HEADSCALE_IMAGE: headscale:${{ github.sha }} + HEADSCALE_INTEGRATION_TAILSCALE_IMAGE: tailscale-head:${{ github.sha }} + HEADSCALE_INTEGRATION_POSTGRES_IMAGE: ${{ inputs.postgres_flag == '--postgres=1' && format('postgres:{0}', github.sha) || '' }} + HEADSCALE_INTEGRATION_GO_CACHE: /tmp/go-cache/go + HEADSCALE_INTEGRATION_GO_BUILD_CACHE: /tmp/go-cache/.cache/go-build + run: /tmp/artifacts/hi run --stats --ts-memory-limit=300 --hs-memory-limit=1500 "^${{ inputs.test }}$" \ + --timeout=120m \ + ${{ inputs.postgres_flag }} + # Sanitize test name for artifact upload (replace invalid characters: " : < > | * ? 
\ / with -) + - name: Sanitize test name for artifacts + if: always() + id: sanitize + run: echo "name=${TEST_NAME//[\":<>|*?\\\/]/-}" >> $GITHUB_OUTPUT + env: + TEST_NAME: ${{ inputs.test }} + - uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + if: always() + with: + name: ${{ inputs.database_name }}-${{ steps.sanitize.outputs.name }}-logs + path: "control_logs/*/*.log" + - uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + if: always() + with: + name: ${{ inputs.database_name }}-${{ steps.sanitize.outputs.name }}-artifacts + path: control_logs/ + - name: Setup a blocking tmux session + if: ${{ env.HAS_TAILSCALE_SECRET }} + uses: alexellis/block-with-tmux-action@master diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 662a4cf4..75088b38 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,7 +1,6 @@ ---- name: Lint -on: [push, pull_request] +on: [pull_request] concurrency: group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} @@ -11,70 +10,84 @@ jobs: golangci-lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 2 - - name: Get changed files id: changed-files - uses: tj-actions/changed-files@v34 + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml + filters: | + files: + - '*.nix' + - 'go.*' + - '**/*.go' + - 'integration_test/' + - 'config-example.yaml' + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + if: steps.changed-files.outputs.files == 'true' + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + if: steps.changed-files.outputs.files == 'true' + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', + '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} - name: golangci-lint - if: steps.changed-files.outputs.any_changed == 'true' - uses: golangci/golangci-lint-action@v2 - with: - version: v1.51.2 - - # Only block PRs on new problems. - # If this is not enabled, we will end up having PRs - # blocked because new linters has appared and other - # parts of the code is affected. 
- only-new-issues: true + if: steps.changed-files.outputs.files == 'true' + run: nix develop --command -- golangci-lint run + --new-from-rev=${{github.event.pull_request.base.sha}} + --output.text.path=stdout + --output.text.print-linter-name + --output.text.print-issued-lines + --output.text.colors prettier-lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 2 - - name: Get changed files id: changed-files - uses: tj-actions/changed-files@v14.1 + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 with: - files: | - *.nix - **/*.md - **/*.yml - **/*.yaml - **/*.ts - **/*.js - **/*.sass - **/*.css - **/*.scss - **/*.html + filters: | + files: + - '*.nix' + - '**/*.md' + - '**/*.yml' + - '**/*.yaml' + - '**/*.ts' + - '**/*.js' + - '**/*.sass' + - '**/*.css' + - '**/*.scss' + - '**/*.html' + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + if: steps.changed-files.outputs.files == 'true' + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + if: steps.changed-files.outputs.files == 'true' + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', + '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} - name: Prettify code - if: steps.changed-files.outputs.any_changed == 'true' - uses: creyD/prettier_action@v4.3 - with: - prettier_options: >- - --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html} - only_changed: false - dry: true + if: steps.changed-files.outputs.files == 'true' + run: nix develop --command -- prettier --no-error-on-unmatched-pattern + --ignore-unknown --check **/*.{ts,js,md,yaml,yml,sass,css,scss,html} proto-lint: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v2 - - uses: bufbuild/buf-setup-action@v1.7.0 - - uses: bufbuild/buf-lint-action@v1 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 with: - input: "proto" + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', + '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} + + - name: Buf lint + run: nix develop --command -- buf lint proto diff --git a/.github/workflows/nix-module-test.yml b/.github/workflows/nix-module-test.yml new file mode 100644 index 00000000..68ad9545 --- /dev/null +++ b/.github/workflows/nix-module-test.yml @@ -0,0 +1,55 @@ +name: NixOS Module Tests + +on: + push: + branches: + - main + pull_request: + branches: + - main + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + nix-module-check: + runs-on: ubuntu-latest + permissions: + contents: read + + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + fetch-depth: 2 + + - name: Get changed files + id: changed-files + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 + with: + filters: | + nix: + - 'nix/**' + - 'flake.nix' + - 'flake.lock' + go: + - 'go.*' + - '**/*.go' + - 'cmd/**' + - 'hscontrol/**' + + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + if: steps.changed-files.outputs.nix == 'true' || 
steps.changed-files.outputs.go == 'true' + + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + if: steps.changed-files.outputs.nix == 'true' || steps.changed-files.outputs.go == 'true' + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', + '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} + + - name: Run NixOS module tests + if: steps.changed-files.outputs.nix == 'true' || steps.changed-files.outputs.go == 'true' + run: | + echo "Running NixOS module integration test..." + nix build .#checks.x86_64-linux.headscale -L diff --git a/.github/workflows/release-docker.yml b/.github/workflows/release-docker.yml deleted file mode 100644 index d82f3268..00000000 --- a/.github/workflows/release-docker.yml +++ /dev/null @@ -1,138 +0,0 @@ ---- -name: Release Docker - -on: - push: - tags: - - "*" # triggers only if push new tag version - workflow_dispatch: - -jobs: - docker-release: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 - - name: Set up QEMU for multiple platforms - uses: docker/setup-qemu-action@master - with: - platforms: arm64,amd64 - - name: Cache Docker layers - uses: actions/cache@v2 - with: - path: /tmp/.buildx-cache - key: ${{ runner.os }}-buildx-${{ github.sha }} - restore-keys: | - ${{ runner.os }}-buildx- - - name: Docker meta - id: meta - uses: docker/metadata-action@v3 - with: - # list of Docker images to use as base name for tags - images: | - ${{ secrets.DOCKERHUB_USERNAME }}/headscale - ghcr.io/${{ github.repository_owner }}/headscale - tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=semver,pattern={{major}} - type=sha - type=raw,value=develop - - name: Login to DockerHub - uses: docker/login-action@v1 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Login to GHCR - uses: docker/login-action@v1 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: Build and push - id: docker_build - uses: docker/build-push-action@v2 - with: - push: true - context: . 
- tags: ${{ steps.meta.outputs.tags }} - labels: ${{ steps.meta.outputs.labels }} - platforms: linux/amd64,linux/arm64 - cache-from: type=local,src=/tmp/.buildx-cache - cache-to: type=local,dest=/tmp/.buildx-cache-new - build-args: | - VERSION=${{ steps.meta.outputs.version }} - - name: Prepare cache for next build - run: | - rm -rf /tmp/.buildx-cache - mv /tmp/.buildx-cache-new /tmp/.buildx-cache - - docker-debug-release: - runs-on: ubuntu-latest - steps: - - name: Checkout - uses: actions/checkout@v3 - with: - fetch-depth: 0 - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@v1 - - name: Set up QEMU for multiple platforms - uses: docker/setup-qemu-action@master - with: - platforms: arm64,amd64 - - name: Cache Docker layers - uses: actions/cache@v2 - with: - path: /tmp/.buildx-cache-debug - key: ${{ runner.os }}-buildx-debug-${{ github.sha }} - restore-keys: | - ${{ runner.os }}-buildx-debug- - - name: Docker meta - id: meta-debug - uses: docker/metadata-action@v3 - with: - # list of Docker images to use as base name for tags - images: | - ${{ secrets.DOCKERHUB_USERNAME }}/headscale - ghcr.io/${{ github.repository_owner }}/headscale - flavor: | - suffix=-debug,onlatest=true - tags: | - type=semver,pattern={{version}} - type=semver,pattern={{major}}.{{minor}} - type=semver,pattern={{major}} - type=sha - type=raw,value=develop - - name: Login to DockerHub - uses: docker/login-action@v1 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - name: Login to GHCR - uses: docker/login-action@v1 - with: - registry: ghcr.io - username: ${{ github.repository_owner }} - password: ${{ secrets.GITHUB_TOKEN }} - - name: Build and push - id: docker_build - uses: docker/build-push-action@v2 - with: - push: true - context: . 
- file: Dockerfile.debug - tags: ${{ steps.meta-debug.outputs.tags }} - labels: ${{ steps.meta-debug.outputs.labels }} - platforms: linux/amd64,linux/arm64 - cache-from: type=local,src=/tmp/.buildx-cache-debug - cache-to: type=local,dest=/tmp/.buildx-cache-debug-new - build-args: | - VERSION=${{ steps.meta-debug.outputs.version }} - - name: Prepare cache for next build - run: | - rm -rf /tmp/.buildx-cache-debug - mv /tmp/.buildx-cache-debug-new /tmp/.buildx-cache-debug diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 72eddbcb..4835e255 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -9,15 +9,33 @@ on: jobs: goreleaser: + if: github.repository == 'juanfont/headscale' runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v3 + uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 0 - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main + - name: Login to DockerHub + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 + with: + username: ${{ secrets.DOCKERHUB_USERNAME }} + password: ${{ secrets.DOCKERHUB_TOKEN }} + + - name: Login to GHCR + uses: docker/login-action@5e57cd118135c172c3672efd75eb46360885c0ef # v3.6.0 + with: + registry: ghcr.io + username: ${{ github.repository_owner }} + password: ${{ secrets.GITHUB_TOKEN }} + + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', + '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} - name: Run goreleaser run: nix develop --command -- goreleaser release --clean diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index c30571c4..0915ec2c 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -1,22 +1,27 @@ name: Close inactive issues + on: schedule: - cron: "30 1 * * *" jobs: close-issues: + if: github.repository == 'juanfont/headscale' runs-on: ubuntu-latest permissions: issues: write pull-requests: write steps: - - uses: actions/stale@v5 + - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # v10.1.1 with: days-before-issue-stale: 90 days-before-issue-close: 7 stale-issue-label: "stale" - stale-issue-message: "This issue is stale because it has been open for 90 days with no activity." - close-issue-message: "This issue was closed because it has been inactive for 14 days since being marked as stale." + stale-issue-message: "This issue is stale because it has been open for 90 days with no + activity." + close-issue-message: "This issue was closed because it has been inactive for 14 days + since being marked as stale." 
days-before-pr-stale: -1 days-before-pr-close: -1 + exempt-issue-labels: "no-stale-bot" repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/test-integration-v2-TestACLAllowStarDst.yaml b/.github/workflows/test-integration-v2-TestACLAllowStarDst.yaml deleted file mode 100644 index 63017ac6..00000000 --- a/.github/workflows/test-integration-v2-TestACLAllowStarDst.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestACLAllowStarDst - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestACLAllowStarDst: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestACLAllowStarDst - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... \ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestACLAllowStarDst$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestACLAllowUser80Dst.yaml b/.github/workflows/test-integration-v2-TestACLAllowUser80Dst.yaml deleted file mode 100644 index e3d5d293..00000000 --- a/.github/workflows/test-integration-v2-TestACLAllowUser80Dst.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestACLAllowUser80Dst - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestACLAllowUser80Dst: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestACLAllowUser80Dst - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - 
--name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... \ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestACLAllowUser80Dst$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestACLAllowUserDst.yaml b/.github/workflows/test-integration-v2-TestACLAllowUserDst.yaml deleted file mode 100644 index dc328ede..00000000 --- a/.github/workflows/test-integration-v2-TestACLAllowUserDst.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestACLAllowUserDst - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestACLAllowUserDst: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestACLAllowUserDst - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestACLAllowUserDst$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestACLDenyAllPort80.yaml b/.github/workflows/test-integration-v2-TestACLDenyAllPort80.yaml deleted file mode 100644 index 396994a6..00000000 --- a/.github/workflows/test-integration-v2-TestACLDenyAllPort80.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestACLDenyAllPort80 - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestACLDenyAllPort80: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestACLDenyAllPort80 - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestACLDenyAllPort80$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestACLDevice1CanAccessDevice2.yaml b/.github/workflows/test-integration-v2-TestACLDevice1CanAccessDevice2.yaml deleted file mode 100644 index 9af861f7..00000000 --- a/.github/workflows/test-integration-v2-TestACLDevice1CanAccessDevice2.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestACLDevice1CanAccessDevice2 - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestACLDevice1CanAccessDevice2: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestACLDevice1CanAccessDevice2 - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestACLDevice1CanAccessDevice2$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestACLHostsInNetMapTable.yaml b/.github/workflows/test-integration-v2-TestACLHostsInNetMapTable.yaml deleted file mode 100644 index cac45ba1..00000000 --- a/.github/workflows/test-integration-v2-TestACLHostsInNetMapTable.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestACLHostsInNetMapTable - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestACLHostsInNetMapTable: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestACLHostsInNetMapTable - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestACLHostsInNetMapTable$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestACLNamedHostsCanReach.yaml b/.github/workflows/test-integration-v2-TestACLNamedHostsCanReach.yaml deleted file mode 100644 index f0985228..00000000 --- a/.github/workflows/test-integration-v2-TestACLNamedHostsCanReach.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestACLNamedHostsCanReach - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestACLNamedHostsCanReach: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestACLNamedHostsCanReach - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestACLNamedHostsCanReach$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestACLNamedHostsCanReachBySubnet.yaml b/.github/workflows/test-integration-v2-TestACLNamedHostsCanReachBySubnet.yaml deleted file mode 100644 index cee0e35c..00000000 --- a/.github/workflows/test-integration-v2-TestACLNamedHostsCanReachBySubnet.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestACLNamedHostsCanReachBySubnet - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestACLNamedHostsCanReachBySubnet: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestACLNamedHostsCanReachBySubnet - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestACLNamedHostsCanReachBySubnet$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestApiKeyCommand.yaml b/.github/workflows/test-integration-v2-TestApiKeyCommand.yaml deleted file mode 100644 index b495b9b3..00000000 --- a/.github/workflows/test-integration-v2-TestApiKeyCommand.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestApiKeyCommand - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestApiKeyCommand: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestApiKeyCommand - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestApiKeyCommand$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestAuthKeyLogoutAndRelogin.yaml b/.github/workflows/test-integration-v2-TestAuthKeyLogoutAndRelogin.yaml deleted file mode 100644 index fcdceeb0..00000000 --- a/.github/workflows/test-integration-v2-TestAuthKeyLogoutAndRelogin.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestAuthKeyLogoutAndRelogin - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestAuthKeyLogoutAndRelogin: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestAuthKeyLogoutAndRelogin - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestAuthKeyLogoutAndRelogin$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestAuthWebFlowAuthenticationPingAll.yaml b/.github/workflows/test-integration-v2-TestAuthWebFlowAuthenticationPingAll.yaml deleted file mode 100644 index 9e24a7d1..00000000 --- a/.github/workflows/test-integration-v2-TestAuthWebFlowAuthenticationPingAll.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestAuthWebFlowAuthenticationPingAll - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestAuthWebFlowAuthenticationPingAll: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestAuthWebFlowAuthenticationPingAll - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestAuthWebFlowAuthenticationPingAll$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestAuthWebFlowLogoutAndRelogin.yaml b/.github/workflows/test-integration-v2-TestAuthWebFlowLogoutAndRelogin.yaml deleted file mode 100644 index e1ff6c3c..00000000 --- a/.github/workflows/test-integration-v2-TestAuthWebFlowLogoutAndRelogin.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestAuthWebFlowLogoutAndRelogin - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestAuthWebFlowLogoutAndRelogin: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestAuthWebFlowLogoutAndRelogin - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestAuthWebFlowLogoutAndRelogin$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestCreateTailscale.yaml b/.github/workflows/test-integration-v2-TestCreateTailscale.yaml deleted file mode 100644 index eaf829c5..00000000 --- a/.github/workflows/test-integration-v2-TestCreateTailscale.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestCreateTailscale - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestCreateTailscale: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestCreateTailscale - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestCreateTailscale$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestDERPServerScenario.yaml b/.github/workflows/test-integration-v2-TestDERPServerScenario.yaml deleted file mode 100644 index 41c7db50..00000000 --- a/.github/workflows/test-integration-v2-TestDERPServerScenario.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestDERPServerScenario - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestDERPServerScenario: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestDERPServerScenario - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestDERPServerScenario$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestEnablingRoutes.yaml b/.github/workflows/test-integration-v2-TestEnablingRoutes.yaml deleted file mode 100644 index 750ea9ff..00000000 --- a/.github/workflows/test-integration-v2-TestEnablingRoutes.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestEnablingRoutes - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestEnablingRoutes: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestEnablingRoutes - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestEnablingRoutes$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestEphemeral.yaml b/.github/workflows/test-integration-v2-TestEphemeral.yaml deleted file mode 100644 index df037ee6..00000000 --- a/.github/workflows/test-integration-v2-TestEphemeral.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestEphemeral - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestEphemeral: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestEphemeral - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestEphemeral$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestExpireNode.yaml b/.github/workflows/test-integration-v2-TestExpireNode.yaml deleted file mode 100644 index 48e5e368..00000000 --- a/.github/workflows/test-integration-v2-TestExpireNode.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestExpireNode - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestExpireNode: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestExpireNode - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestExpireNode$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestHASubnetRouterFailover.yaml b/.github/workflows/test-integration-v2-TestHASubnetRouterFailover.yaml deleted file mode 100644 index 4ffe4640..00000000 --- a/.github/workflows/test-integration-v2-TestHASubnetRouterFailover.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestHASubnetRouterFailover - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestHASubnetRouterFailover: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestHASubnetRouterFailover - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestHASubnetRouterFailover$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestHeadscale.yaml b/.github/workflows/test-integration-v2-TestHeadscale.yaml deleted file mode 100644 index ff7dbb16..00000000 --- a/.github/workflows/test-integration-v2-TestHeadscale.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestHeadscale - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestHeadscale: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestHeadscale - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestHeadscale$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestNodeCommand.yaml b/.github/workflows/test-integration-v2-TestNodeCommand.yaml deleted file mode 100644 index 4398672f..00000000 --- a/.github/workflows/test-integration-v2-TestNodeCommand.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestNodeCommand - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestNodeCommand: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestNodeCommand - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestNodeCommand$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestNodeExpireCommand.yaml b/.github/workflows/test-integration-v2-TestNodeExpireCommand.yaml deleted file mode 100644 index f953a1c4..00000000 --- a/.github/workflows/test-integration-v2-TestNodeExpireCommand.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestNodeExpireCommand - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestNodeExpireCommand: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestNodeExpireCommand - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestNodeExpireCommand$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestNodeMoveCommand.yaml b/.github/workflows/test-integration-v2-TestNodeMoveCommand.yaml deleted file mode 100644 index ce5f5b90..00000000 --- a/.github/workflows/test-integration-v2-TestNodeMoveCommand.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestNodeMoveCommand - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestNodeMoveCommand: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestNodeMoveCommand - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestNodeMoveCommand$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestNodeOnlineLastSeenStatus.yaml b/.github/workflows/test-integration-v2-TestNodeOnlineLastSeenStatus.yaml deleted file mode 100644 index e3a30f83..00000000 --- a/.github/workflows/test-integration-v2-TestNodeOnlineLastSeenStatus.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestNodeOnlineLastSeenStatus - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestNodeOnlineLastSeenStatus: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestNodeOnlineLastSeenStatus - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestNodeOnlineLastSeenStatus$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestNodeRenameCommand.yaml b/.github/workflows/test-integration-v2-TestNodeRenameCommand.yaml deleted file mode 100644 index e3ac56a1..00000000 --- a/.github/workflows/test-integration-v2-TestNodeRenameCommand.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestNodeRenameCommand - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestNodeRenameCommand: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestNodeRenameCommand - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestNodeRenameCommand$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestNodeTagCommand.yaml b/.github/workflows/test-integration-v2-TestNodeTagCommand.yaml deleted file mode 100644 index 5e1e5782..00000000 --- a/.github/workflows/test-integration-v2-TestNodeTagCommand.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestNodeTagCommand - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestNodeTagCommand: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestNodeTagCommand - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestNodeTagCommand$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestOIDCAuthenticationPingAll.yaml b/.github/workflows/test-integration-v2-TestOIDCAuthenticationPingAll.yaml deleted file mode 100644 index e333be2e..00000000 --- a/.github/workflows/test-integration-v2-TestOIDCAuthenticationPingAll.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestOIDCAuthenticationPingAll - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestOIDCAuthenticationPingAll: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestOIDCAuthenticationPingAll - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestOIDCAuthenticationPingAll$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestOIDCExpireNodesBasedOnTokenExpiry.yaml b/.github/workflows/test-integration-v2-TestOIDCExpireNodesBasedOnTokenExpiry.yaml deleted file mode 100644 index 1f148c79..00000000 --- a/.github/workflows/test-integration-v2-TestOIDCExpireNodesBasedOnTokenExpiry.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestOIDCExpireNodesBasedOnTokenExpiry - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestOIDCExpireNodesBasedOnTokenExpiry: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestOIDCExpireNodesBasedOnTokenExpiry - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestOIDCExpireNodesBasedOnTokenExpiry$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestPingAllByHostname.yaml b/.github/workflows/test-integration-v2-TestPingAllByHostname.yaml deleted file mode 100644 index fe9ad76c..00000000 --- a/.github/workflows/test-integration-v2-TestPingAllByHostname.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestPingAllByHostname - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestPingAllByHostname: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestPingAllByHostname - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestPingAllByHostname$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestPingAllByIP.yaml b/.github/workflows/test-integration-v2-TestPingAllByIP.yaml deleted file mode 100644 index 156ef734..00000000 --- a/.github/workflows/test-integration-v2-TestPingAllByIP.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestPingAllByIP - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestPingAllByIP: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestPingAllByIP - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestPingAllByIP$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestPreAuthKeyCommand.yaml b/.github/workflows/test-integration-v2-TestPreAuthKeyCommand.yaml deleted file mode 100644 index 11f10b08..00000000 --- a/.github/workflows/test-integration-v2-TestPreAuthKeyCommand.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestPreAuthKeyCommand - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestPreAuthKeyCommand: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestPreAuthKeyCommand - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestPreAuthKeyCommand$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestPreAuthKeyCommandReusableEphemeral.yaml b/.github/workflows/test-integration-v2-TestPreAuthKeyCommandReusableEphemeral.yaml deleted file mode 100644 index 1be71ac7..00000000 --- a/.github/workflows/test-integration-v2-TestPreAuthKeyCommandReusableEphemeral.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestPreAuthKeyCommandReusableEphemeral - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestPreAuthKeyCommandReusableEphemeral: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestPreAuthKeyCommandReusableEphemeral - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestPreAuthKeyCommandReusableEphemeral$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestPreAuthKeyCommandWithoutExpiry.yaml b/.github/workflows/test-integration-v2-TestPreAuthKeyCommandWithoutExpiry.yaml deleted file mode 100644 index 7d290cd4..00000000 --- a/.github/workflows/test-integration-v2-TestPreAuthKeyCommandWithoutExpiry.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestPreAuthKeyCommandWithoutExpiry - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestPreAuthKeyCommandWithoutExpiry: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestPreAuthKeyCommandWithoutExpiry - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestPreAuthKeyCommandWithoutExpiry$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestResolveMagicDNS.yaml b/.github/workflows/test-integration-v2-TestResolveMagicDNS.yaml deleted file mode 100644 index fbcf8081..00000000 --- a/.github/workflows/test-integration-v2-TestResolveMagicDNS.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestResolveMagicDNS - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestResolveMagicDNS: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestResolveMagicDNS - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestResolveMagicDNS$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestSSHIsBlockedInACL.yaml b/.github/workflows/test-integration-v2-TestSSHIsBlockedInACL.yaml deleted file mode 100644 index bd19c8d5..00000000 --- a/.github/workflows/test-integration-v2-TestSSHIsBlockedInACL.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestSSHIsBlockedInACL - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestSSHIsBlockedInACL: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestSSHIsBlockedInACL - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestSSHIsBlockedInACL$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestSSHMultipleUsersAllToAll.yaml b/.github/workflows/test-integration-v2-TestSSHMultipleUsersAllToAll.yaml deleted file mode 100644 index 00748aa2..00000000 --- a/.github/workflows/test-integration-v2-TestSSHMultipleUsersAllToAll.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestSSHMultipleUsersAllToAll - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestSSHMultipleUsersAllToAll: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestSSHMultipleUsersAllToAll - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestSSHMultipleUsersAllToAll$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestSSHNoSSHConfigured.yaml b/.github/workflows/test-integration-v2-TestSSHNoSSHConfigured.yaml deleted file mode 100644 index be8f38a3..00000000 --- a/.github/workflows/test-integration-v2-TestSSHNoSSHConfigured.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestSSHNoSSHConfigured - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestSSHNoSSHConfigured: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestSSHNoSSHConfigured - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestSSHNoSSHConfigured$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestSSHOneUserToAll.yaml b/.github/workflows/test-integration-v2-TestSSHOneUserToAll.yaml deleted file mode 100644 index 62ab49be..00000000 --- a/.github/workflows/test-integration-v2-TestSSHOneUserToAll.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestSSHOneUserToAll - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestSSHOneUserToAll: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestSSHOneUserToAll - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestSSHOneUserToAll$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestSSHUserOnlyIsolation.yaml b/.github/workflows/test-integration-v2-TestSSHUserOnlyIsolation.yaml deleted file mode 100644 index 86264536..00000000 --- a/.github/workflows/test-integration-v2-TestSSHUserOnlyIsolation.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestSSHUserOnlyIsolation - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestSSHUserOnlyIsolation: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestSSHUserOnlyIsolation - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestSSHUserOnlyIsolation$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestTaildrop.yaml b/.github/workflows/test-integration-v2-TestTaildrop.yaml deleted file mode 100644 index e64eedec..00000000 --- a/.github/workflows/test-integration-v2-TestTaildrop.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestTaildrop - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestTaildrop: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestTaildrop - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestTaildrop$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestTailscaleNodesJoiningHeadcale.yaml b/.github/workflows/test-integration-v2-TestTailscaleNodesJoiningHeadcale.yaml deleted file mode 100644 index c406b2b2..00000000 --- a/.github/workflows/test-integration-v2-TestTailscaleNodesJoiningHeadcale.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestTailscaleNodesJoiningHeadcale - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestTailscaleNodesJoiningHeadcale: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestTailscaleNodesJoiningHeadcale - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... 
\ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestTailscaleNodesJoiningHeadcale$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestUserCommand.yaml b/.github/workflows/test-integration-v2-TestUserCommand.yaml deleted file mode 100644 index 667ad43e..00000000 --- a/.github/workflows/test-integration-v2-TestUserCommand.yaml +++ /dev/null @@ -1,67 +0,0 @@ -# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - TestUserCommand - -on: [pull_request] - -concurrency: - group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} - cancel-in-progress: true - -jobs: - TestUserCommand: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run TestUserCommand - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... \ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^TestUserCommand$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration.yaml b/.github/workflows/test-integration.yaml new file mode 100644 index 00000000..82b40044 --- /dev/null +++ b/.github/workflows/test-integration.yaml @@ -0,0 +1,273 @@ +name: integration +# To debug locally on a branch, and when needing secrets +# change this to include `push` so the build is ran on +# the main repository. +on: [pull_request] +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true +jobs: + # build: Builds binaries and Docker images once, uploads as artifacts for reuse. + # build-postgres: Pulls postgres image separately to avoid Docker Hub rate limits. + # sqlite: Runs all integration tests with SQLite backend. + # postgres: Runs a subset of tests with PostgreSQL to verify database compatibility. 
+ build: + runs-on: ubuntu-latest + outputs: + files-changed: ${{ steps.changed-files.outputs.files }} + steps: + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 + with: + fetch-depth: 2 + - name: Get changed files + id: changed-files + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 + with: + filters: | + files: + - '*.nix' + - 'go.*' + - '**/*.go' + - 'integration/**' + - 'config-example.yaml' + - '.github/workflows/test-integration.yaml' + - '.github/workflows/integration-test-template.yml' + - 'Dockerfile.*' + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + if: steps.changed-files.outputs.files == 'true' + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + if: steps.changed-files.outputs.files == 'true' + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} + - name: Build binaries and warm Go cache + if: steps.changed-files.outputs.files == 'true' + run: | + # Build all Go binaries in one nix shell to maximize cache reuse + nix develop --command -- bash -c ' + go build -o hi ./cmd/hi + CGO_ENABLED=0 GOOS=linux go build -o headscale ./cmd/headscale + # Build integration test binary to warm the cache with all dependencies + go test -c ./integration -o /dev/null 2>/dev/null || true + ' + - name: Upload hi binary + if: steps.changed-files.outputs.files == 'true' + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + with: + name: hi-binary + path: hi + retention-days: 10 + - name: Package Go cache + if: steps.changed-files.outputs.files == 'true' + run: | + # Package Go module cache and build cache + tar -czf go-cache.tar.gz -C ~ go .cache/go-build + - name: Upload Go cache + if: steps.changed-files.outputs.files == 'true' + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + with: + name: go-cache + path: go-cache.tar.gz + retention-days: 10 + - name: Build headscale image + if: steps.changed-files.outputs.files == 'true' + run: | + docker build \ + --file Dockerfile.integration-ci \ + --tag headscale:${{ github.sha }} \ + . + docker save headscale:${{ github.sha }} | gzip > headscale-image.tar.gz + - name: Build tailscale HEAD image + if: steps.changed-files.outputs.files == 'true' + run: | + docker build \ + --file Dockerfile.tailscale-HEAD \ + --tag tailscale-head:${{ github.sha }} \ + . 
+ docker save tailscale-head:${{ github.sha }} | gzip > tailscale-head-image.tar.gz + - name: Upload headscale image + if: steps.changed-files.outputs.files == 'true' + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + with: + name: headscale-image + path: headscale-image.tar.gz + retention-days: 10 + - name: Upload tailscale HEAD image + if: steps.changed-files.outputs.files == 'true' + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + with: + name: tailscale-head-image + path: tailscale-head-image.tar.gz + retention-days: 10 + build-postgres: + runs-on: ubuntu-latest + needs: build + if: needs.build.outputs.files-changed == 'true' + steps: + - name: Pull and save postgres image + run: | + docker pull postgres:latest + docker tag postgres:latest postgres:${{ github.sha }} + docker save postgres:${{ github.sha }} | gzip > postgres-image.tar.gz + - name: Upload postgres image + uses: actions/upload-artifact@330a01c490aca151604b8cf639adc76d48f6c5d4 # v5.0.0 + with: + name: postgres-image + path: postgres-image.tar.gz + retention-days: 10 + sqlite: + needs: build + if: needs.build.outputs.files-changed == 'true' + strategy: + fail-fast: false + matrix: + test: + - TestACLHostsInNetMapTable + - TestACLAllowUser80Dst + - TestACLDenyAllPort80 + - TestACLAllowUserDst + - TestACLAllowStarDst + - TestACLNamedHostsCanReachBySubnet + - TestACLNamedHostsCanReach + - TestACLDevice1CanAccessDevice2 + - TestPolicyUpdateWhileRunningWithCLIInDatabase + - TestACLAutogroupMember + - TestACLAutogroupTagged + - TestACLAutogroupSelf + - TestACLPolicyPropagationOverTime + - TestACLTagPropagation + - TestACLTagPropagationPortSpecific + - TestACLGroupWithUnknownUser + - TestACLGroupAfterUserDeletion + - TestACLGroupDeletionExactReproduction + - TestACLDynamicUnknownUserAddition + - TestACLDynamicUnknownUserRemoval + - TestAPIAuthenticationBypass + - TestAPIAuthenticationBypassCurl + - TestGRPCAuthenticationBypass + - TestCLIWithConfigAuthenticationBypass + - TestAuthKeyLogoutAndReloginSameUser + - TestAuthKeyLogoutAndReloginNewUser + - TestAuthKeyLogoutAndReloginSameUserExpiredKey + - TestAuthKeyDeleteKey + - TestAuthKeyLogoutAndReloginRoutesPreserved + - TestOIDCAuthenticationPingAll + - TestOIDCExpireNodesBasedOnTokenExpiry + - TestOIDC024UserCreation + - TestOIDCAuthenticationWithPKCE + - TestOIDCReloginSameNodeNewUser + - TestOIDCFollowUpUrl + - TestOIDCMultipleOpenedLoginUrls + - TestOIDCReloginSameNodeSameUser + - TestOIDCExpiryAfterRestart + - TestOIDCACLPolicyOnJoin + - TestOIDCReloginSameUserRoutesPreserved + - TestAuthWebFlowAuthenticationPingAll + - TestAuthWebFlowLogoutAndReloginSameUser + - TestAuthWebFlowLogoutAndReloginNewUser + - TestUserCommand + - TestPreAuthKeyCommand + - TestPreAuthKeyCommandWithoutExpiry + - TestPreAuthKeyCommandReusableEphemeral + - TestPreAuthKeyCorrectUserLoggedInCommand + - TestTaggedNodesCLIOutput + - TestApiKeyCommand + - TestNodeCommand + - TestNodeExpireCommand + - TestNodeRenameCommand + - TestPolicyCommand + - TestPolicyBrokenConfigCommand + - TestDERPVerifyEndpoint + - TestResolveMagicDNS + - TestResolveMagicDNSExtraRecordsPath + - TestDERPServerScenario + - TestDERPServerWebsocketScenario + - TestPingAllByIP + - TestPingAllByIPPublicDERP + - TestEphemeral + - TestEphemeralInAlternateTimezone + - TestEphemeral2006DeletedTooQuickly + - TestPingAllByHostname + - TestTaildrop + - TestUpdateHostnameFromClient + - TestExpireNode + - TestSetNodeExpiryInFuture + - TestNodeOnlineStatus + - 
TestPingAllByIPManyUpDown + - Test2118DeletingOnlineNodePanics + - TestEnablingRoutes + - TestHASubnetRouterFailover + - TestSubnetRouteACL + - TestEnablingExitRoutes + - TestSubnetRouterMultiNetwork + - TestSubnetRouterMultiNetworkExitNode + - TestAutoApproveMultiNetwork/authkey-tag.* + - TestAutoApproveMultiNetwork/authkey-user.* + - TestAutoApproveMultiNetwork/authkey-group.* + - TestAutoApproveMultiNetwork/webauth-tag.* + - TestAutoApproveMultiNetwork/webauth-user.* + - TestAutoApproveMultiNetwork/webauth-group.* + - TestSubnetRouteACLFiltering + - TestHeadscale + - TestTailscaleNodesJoiningHeadcale + - TestSSHOneUserToAll + - TestSSHMultipleUsersAllToAll + - TestSSHNoSSHConfigured + - TestSSHIsBlockedInACL + - TestSSHUserOnlyIsolation + - TestSSHAutogroupSelf + - TestTagsAuthKeyWithTagRequestDifferentTag + - TestTagsAuthKeyWithTagNoAdvertiseFlag + - TestTagsAuthKeyWithTagCannotAddViaCLI + - TestTagsAuthKeyWithTagCannotChangeViaCLI + - TestTagsAuthKeyWithTagAdminOverrideReauthPreserves + - TestTagsAuthKeyWithTagCLICannotModifyAdminTags + - TestTagsAuthKeyWithoutTagCannotRequestTags + - TestTagsAuthKeyWithoutTagRegisterNoTags + - TestTagsAuthKeyWithoutTagCannotAddViaCLI + - TestTagsAuthKeyWithoutTagCLINoOpAfterAdminWithReset + - TestTagsAuthKeyWithoutTagCLINoOpAfterAdminWithEmptyAdvertise + - TestTagsAuthKeyWithoutTagCLICannotReduceAdminMultiTag + - TestTagsUserLoginOwnedTagAtRegistration + - TestTagsUserLoginNonExistentTagAtRegistration + - TestTagsUserLoginUnownedTagAtRegistration + - TestTagsUserLoginAddTagViaCLIReauth + - TestTagsUserLoginRemoveTagViaCLIReauth + - TestTagsUserLoginCLINoOpAfterAdminAssignment + - TestTagsUserLoginCLICannotRemoveAdminTags + - TestTagsAuthKeyWithTagRequestNonExistentTag + - TestTagsAuthKeyWithTagRequestUnownedTag + - TestTagsAuthKeyWithoutTagRequestNonExistentTag + - TestTagsAuthKeyWithoutTagRequestUnownedTag + - TestTagsAdminAPICannotSetNonExistentTag + - TestTagsAdminAPICanSetUnownedTag + - TestTagsAdminAPICannotRemoveAllTags + - TestTagsIssue2978ReproTagReplacement + - TestTagsAdminAPICannotSetInvalidFormat + - TestTagsUserLoginReauthWithEmptyTagsRemovesAllTags + - TestTagsAuthKeyWithoutUserInheritsTags + - TestTagsAuthKeyWithoutUserRejectsAdvertisedTags + uses: ./.github/workflows/integration-test-template.yml + secrets: inherit + with: + test: ${{ matrix.test }} + postgres_flag: "--postgres=0" + database_name: "sqlite" + postgres: + needs: [build, build-postgres] + if: needs.build.outputs.files-changed == 'true' + strategy: + fail-fast: false + matrix: + test: + - TestACLAllowUserDst + - TestPingAllByIP + - TestEphemeral2006DeletedTooQuickly + - TestPingAllByIPManyUpDown + - TestSubnetRouterMultiNetwork + uses: ./.github/workflows/integration-test-template.yml + secrets: inherit + with: + test: ${{ matrix.test }} + postgres_flag: "--postgres=1" + database_name: "postgres" diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index c2700d17..31eb431b 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -11,26 +11,37 @@ jobs: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@8e8c483db84b4bee98b60c0593521ed34d9990e8 # v6.0.1 with: fetch-depth: 2 - name: Get changed files id: changed-files - uses: tj-actions/changed-files@v34 + uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2 with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml + filters: | + files: + - '*.nix' + - 'go.*' + - '**/*.go' + - 'integration_test/' + - 
'config-example.yaml' - - uses: DeterminateSystems/nix-installer-action@main - if: steps.changed-files.outputs.any_changed == 'true' - - uses: DeterminateSystems/magic-nix-cache-action@main - if: steps.changed-files.outputs.any_changed == 'true' + - uses: nixbuild/nix-quick-install-action@2c9db80fb984ceb1bcaa77cdda3fdf8cfba92035 # v34 + if: steps.changed-files.outputs.files == 'true' + - uses: nix-community/cache-nix-action@135667ec418502fa5a3598af6fb9eb733888ce6a # v6.1.3 + if: steps.changed-files.outputs.files == 'true' + with: + primary-key: nix-${{ runner.os }}-${{ runner.arch }}-${{ hashFiles('**/*.nix', + '**/flake.lock') }} + restore-prefixes-first-match: nix-${{ runner.os }}-${{ runner.arch }} - name: Run tests - if: steps.changed-files.outputs.any_changed == 'true' - run: nix develop --check + if: steps.changed-files.outputs.files == 'true' + env: + # As of 2025-01-06, these env vars was not automatically + # set anymore which breaks the initdb for postgres on + # some of the database migration tests. + LC_ALL: "en_US.UTF-8" + LC_CTYPE: "en_US.UTF-8" + run: nix develop --command -- gotestsum diff --git a/.github/workflows/update-flake.yml b/.github/workflows/update-flake.yml index 6fcea23e..1c8b262e 100644 --- a/.github/workflows/update-flake.yml +++ b/.github/workflows/update-flake.yml @@ -6,13 +6,14 @@ on: jobs: lockfile: + if: github.repository == 'juanfont/headscale' runs-on: ubuntu-latest steps: - name: Checkout repository - uses: actions/checkout@v3 + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install Nix - uses: DeterminateSystems/nix-installer-action@main + uses: DeterminateSystems/nix-installer-action@21a544727d0c62386e78b4befe52d19ad12692e3 # v17 - name: Update flake.lock - uses: DeterminateSystems/update-flake-lock@main + uses: DeterminateSystems/update-flake-lock@428c2b58a4b7414dabd372acb6a03dba1084d3ab # v25 with: pr-title: "Update flake.lock" diff --git a/.gitignore b/.gitignore index f6e506bc..4fec4f53 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,10 @@ ignored/ tailscale/ .vscode/ +.claude/ +logs/ + +*.prof # Binaries for programs and plugins *.exe @@ -20,8 +24,9 @@ vendor/ dist/ /headscale -config.json config.yaml +config*.yaml +!config-example.yaml derp.yaml *.hujson *.key @@ -45,3 +50,7 @@ integration_test/etc/config.dump.yaml /site __debug_bin + +node_modules/ +package-lock.json +package.json diff --git a/.golangci.yaml b/.golangci.yaml index 65a88511..eda3bed4 100644 --- a/.golangci.yaml +++ b/.golangci.yaml @@ -1,77 +1,90 @@ --- -run: - timeout: 10m - build-tags: - - ts2019 - -issues: - skip-dirs: - - gen +version: "2" linters: - enable-all: true + default: all disable: + - cyclop - depguard - - - exhaustivestruct - - revive - - lll - - interfacer - - scopelint - - maligned - - golint - - gofmt + - dupl + - exhaustruct + - funcorder + - funlen - gochecknoglobals - gochecknoinits - gocognit - - funlen - - exhaustivestruct - - tagliatelle - godox - - ireturn - - execinquery - - exhaustruct - - nolintlint - - musttag # causes issues with imported libs - - depguard - - # deprecated - - structcheck # replaced by unused - - ifshort # deprecated by the owner - - varcheck # replaced by unused - - nosnakecase # replaced by revive - - deadcode # replaced by unused - - # We should strive to enable these: - - wrapcheck - - dupl - - makezero - - maintidx - - # Limits the methods of an interface to 10. 
We have more in integration tests - interfacebloat - - # We might want to enable this, but it might be a lot of work - - cyclop + - ireturn + - lll + - maintidx + - makezero + - musttag - nestif - - wsl # might be incompatible with gofumpt - - testpackage + - nolintlint - paralleltest + - revive + - tagliatelle + - testpackage + - varnamelen + - wrapcheck + - wsl + settings: + forbidigo: + forbid: + # Forbid time.Sleep everywhere with context-appropriate alternatives + - pattern: 'time\.Sleep' + msg: >- + time.Sleep is forbidden. + In tests: use assert.EventuallyWithT for polling/waiting patterns. + In production code: use a backoff strategy (e.g., cenkalti/backoff) or proper synchronization primitives. + analyze-types: true + gocritic: + disabled-checks: + - appendAssign + - ifElseChain + nlreturn: + block-size: 4 + varnamelen: + ignore-names: + - err + - db + - id + - ip + - ok + - c + - tt + - tx + - rx + - sb + - wg + - pr + - p + - p2 + ignore-type-assert-ok: true + ignore-map-index-ok: true + exclusions: + generated: lax + presets: + - comments + - common-false-positives + - legacy + - std-error-handling + paths: + - third_party$ + - builtin$ + - examples$ + - gen -linters-settings: - varnamelen: - ignore-type-assert-ok: true - ignore-map-index-ok: true - ignore-names: - - err - - db - - id - - ip - - ok - - c - - tt - - gocritic: - disabled-checks: - - appendAssign - # TODO(kradalby): Remove this - - ifElseChain +formatters: + enable: + - gci + - gofmt + - gofumpt + - goimports + exclusions: + generated: lax + paths: + - third_party$ + - builtin$ + - examples$ + - gen diff --git a/.goreleaser.yml b/.goreleaser.yml index 07efe6f7..f77dfe38 100644 --- a/.goreleaser.yml +++ b/.goreleaser.yml @@ -1,15 +1,44 @@ --- +version: 2 before: hooks: - - go mod tidy -compat=1.20 + - go mod tidy -compat=1.25 - go mod vendor release: prerelease: auto + draft: true + header: | + ## Upgrade + + Please follow the steps outlined in the [upgrade guide](https://headscale.net/stable/setup/upgrade/) to update your existing Headscale installation. + + **It's best to update from one stable version to the next** (e.g., 0.24.0 → 0.25.1 → 0.26.1) in case you are multiple releases behind. You should always pick the latest available patch release. + + Be sure to check the changelog above for version-specific upgrade instructions and breaking changes. + + ### Backup Your Database + + **Always backup your database before upgrading.** Here's how to backup a SQLite database: + + ```bash + # Stop headscale + systemctl stop headscale + + # Backup sqlite database + cp /var/lib/headscale/db.sqlite /var/lib/headscale/db.sqlite.backup + + # Backup sqlite WAL/SHM files (if they exist) + cp /var/lib/headscale/db.sqlite-wal /var/lib/headscale/db.sqlite-wal.backup + cp /var/lib/headscale/db.sqlite-shm /var/lib/headscale/db.sqlite-shm.backup + + # Start headscale (migration will run automatically) + systemctl start headscale + ``` builds: - id: headscale - main: ./cmd/headscale/headscale.go + main: ./cmd/headscale mod_timestamp: "{{ .CommitTimestamp }}" env: - CGO_ENABLED=0 @@ -17,23 +46,18 @@ builds: - darwin_amd64 - darwin_arm64 - freebsd_amd64 - - linux_386 - linux_amd64 - linux_arm64 - - linux_arm_5 - - linux_arm_6 - - linux_arm_7 flags: - -mod=readonly - ldflags: - - -s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=v{{.Version}} tags: - ts2019 archives: - id: golang-cross name_template: '{{ .ProjectName }}_{{ .Version }}_{{ .Os }}_{{ .Arch }}{{ with .Arm }}v{{ . }}{{ end }}{{ with .Mips }}_{{ . 
}}{{ end }}{{ if not (eq .Amd64 "v1") }}{{ .Amd64 }}{{ end }}' - format: binary + formats: + - binary source: enabled: true @@ -52,38 +76,109 @@ nfpms: # List file contents: dpkg -c dist/headscale...deb # Package metadata: dpkg --info dist/headscale....deb # - - builds: + - ids: - headscale package_name: headscale priority: optional vendor: headscale maintainer: Kristoffer Dalby <kristoffer@dalby.cc> homepage: https://github.com/juanfont/headscale - license: BSD + description: |- + Open source implementation of the Tailscale control server. + Headscale aims to implement a self-hosted, open source alternative to the + Tailscale control server. Headscale's goal is to provide self-hosters and + hobbyists with an open-source server they can use for their projects and + labs. It implements a narrow scope, a single Tailscale network (tailnet), + suitable for a personal use, or a small open-source organisation. bindir: /usr/bin + section: net formats: - deb - # - rpm contents: - src: ./config-example.yaml dst: /etc/headscale/config.yaml type: config|noreplace file_info: mode: 0644 - - src: ./docs/packaging/headscale.systemd.service + - src: ./packaging/systemd/headscale.service dst: /usr/lib/systemd/system/headscale.service - dst: /var/lib/headscale type: dir - - dst: /var/run/headscale - type: dir + - src: LICENSE + dst: /usr/share/doc/headscale/copyright scripts: - postinstall: ./docs/packaging/postinstall.sh - postremove: ./docs/packaging/postremove.sh + postinstall: ./packaging/deb/postinst + postremove: ./packaging/deb/postrm + preremove: ./packaging/deb/prerm + deb: + lintian_overrides: + - no-changelog # Our CHANGELOG.md uses a different formatting + - no-manual-page + - statically-linked-binary + +kos: + - id: ghcr + repositories: + - ghcr.io/juanfont/headscale + - headscale/headscale + + # bare tells KO to only use the repository + # for tagging and naming the container. 
+ bare: true + base_image: gcr.io/distroless/base-debian13 + build: headscale + main: ./cmd/headscale + env: + - CGO_ENABLED=0 + platforms: + - linux/amd64 + - linux/arm64 + tags: + - "{{ if not .Prerelease }}latest{{ end }}" + - "{{ if not .Prerelease }}{{ .Major }}.{{ .Minor }}.{{ .Patch }}{{ end }}" + - "{{ if not .Prerelease }}{{ .Major }}.{{ .Minor }}{{ end }}" + - "{{ if not .Prerelease }}{{ .Major }}{{ end }}" + - "{{ if not .Prerelease }}v{{ .Major }}.{{ .Minor }}.{{ .Patch }}{{ end }}" + - "{{ if not .Prerelease }}v{{ .Major }}.{{ .Minor }}{{ end }}" + - "{{ if not .Prerelease }}v{{ .Major }}{{ end }}" + - "{{ if not .Prerelease }}stable{{ else }}unstable{{ end }}" + - "{{ .Tag }}" + - '{{ trimprefix .Tag "v" }}' + - "sha-{{ .ShortCommit }}" + creation_time: "{{.CommitTimestamp}}" + ko_data_creation_time: "{{.CommitTimestamp}}" + + - id: ghcr-debug + repositories: + - ghcr.io/juanfont/headscale + - headscale/headscale + + bare: true + base_image: gcr.io/distroless/base-debian13:debug + build: headscale + main: ./cmd/headscale + env: + - CGO_ENABLED=0 + platforms: + - linux/amd64 + - linux/arm64 + tags: + - "{{ if not .Prerelease }}latest-debug{{ end }}" + - "{{ if not .Prerelease }}{{ .Major }}.{{ .Minor }}.{{ .Patch }}-debug{{ end }}" + - "{{ if not .Prerelease }}{{ .Major }}.{{ .Minor }}-debug{{ end }}" + - "{{ if not .Prerelease }}{{ .Major }}-debug{{ end }}" + - "{{ if not .Prerelease }}v{{ .Major }}.{{ .Minor }}.{{ .Patch }}-debug{{ end }}" + - "{{ if not .Prerelease }}v{{ .Major }}.{{ .Minor }}-debug{{ end }}" + - "{{ if not .Prerelease }}v{{ .Major }}-debug{{ end }}" + - "{{ if not .Prerelease }}stable-debug{{ else }}unstable-debug{{ end }}" + - "{{ .Tag }}-debug" + - '{{ trimprefix .Tag "v" }}-debug' + - "sha-{{ .ShortCommit }}-debug" checksum: name_template: "checksums.txt" snapshot: - name_template: "{{ .Tag }}-next" + version_template: "{{ .Tag }}-next" changelog: sort: asc filters: diff --git a/.mcp.json b/.mcp.json new file mode 100644 index 00000000..71554002 --- /dev/null +++ b/.mcp.json @@ -0,0 +1,34 @@ +{ + "mcpServers": { + "claude-code-mcp": { + "type": "stdio", + "command": "npx", + "args": ["-y", "@steipete/claude-code-mcp@latest"], + "env": {} + }, + "sequential-thinking": { + "type": "stdio", + "command": "npx", + "args": ["-y", "@modelcontextprotocol/server-sequential-thinking"], + "env": {} + }, + "nixos": { + "type": "stdio", + "command": "uvx", + "args": ["mcp-nixos"], + "env": {} + }, + "context7": { + "type": "stdio", + "command": "npx", + "args": ["-y", "@upstash/context7-mcp"], + "env": {} + }, + "git": { + "type": "stdio", + "command": "npx", + "args": ["-y", "@cyanheads/git-mcp-server"], + "env": {} + } + } +} diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 00000000..ed869775 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,68 @@ +# prek/pre-commit configuration for headscale +# See: https://prek.j178.dev/quickstart/ +# See: https://prek.j178.dev/builtin/ + +# Global exclusions - ignore generated code +exclude: ^gen/ + +repos: + # Built-in hooks from pre-commit/pre-commit-hooks + # prek will use fast-path optimized versions automatically + # See: https://prek.j178.dev/builtin/ + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 + hooks: + - id: check-added-large-files + - id: check-case-conflict + - id: check-executables-have-shebangs + - id: check-json + - id: check-merge-conflict + - id: check-symlinks + - id: check-toml + - id: check-xml + - id: check-yaml + - id: 
detect-private-key + - id: end-of-file-fixer + - id: fix-byte-order-marker + - id: mixed-line-ending + - id: trailing-whitespace + + # Local hooks for project-specific tooling + - repo: local + hooks: + # nixpkgs-fmt for Nix files + - id: nixpkgs-fmt + name: nixpkgs-fmt + entry: nixpkgs-fmt + language: system + files: \.nix$ + + # Prettier for formatting + - id: prettier + name: prettier + entry: prettier --write --list-different + language: system + exclude: ^docs/ + types_or: + [ + javascript, + jsx, + ts, + tsx, + yaml, + json, + toml, + html, + css, + scss, + sass, + markdown, + ] + + # golangci-lint for Go code quality + - id: golangci-lint + name: golangci-lint + entry: nix develop --command golangci-lint run --new-from-rev=HEAD~1 --timeout=5m --fix + language: system + types: [go] + pass_filenames: false diff --git a/.prettierignore b/.prettierignore index 146ae4dd..ebb727cc 100644 --- a/.prettierignore +++ b/.prettierignore @@ -1 +1,5 @@ .github/workflows/test-integration-v2* +docs/about/features.md +docs/ref/api.md +docs/ref/configuration.md +docs/ref/oidc.md diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..2432ea28 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,1051 @@ +# AGENTS.md + +This file provides guidance to AI agents when working with code in this repository. + +## Overview + +Headscale is an open-source implementation of the Tailscale control server written in Go. It provides self-hosted coordination for Tailscale networks (tailnets), managing node registration, IP allocation, policy enforcement, and DERP routing. + +## Development Commands + +### Quick Setup + +```bash +# Recommended: Use Nix for dependency management +nix develop + +# Full development workflow +make dev # runs fmt + lint + test + build +``` + +### Essential Commands + +```bash +# Build headscale binary +make build + +# Run tests +make test +go test ./... # All unit tests +go test -race ./... # With race detection + +# Run specific integration test +go run ./cmd/hi run "TestName" --postgres + +# Code formatting and linting +make fmt # Format all code (Go, docs, proto) +make lint # Lint all code (Go, proto) +make fmt-go # Format Go code only +make lint-go # Lint Go code only + +# Protocol buffer generation (after modifying proto/) +make generate + +# Clean build artifacts +make clean +``` + +### Integration Testing + +```bash +# Use the hi (Headscale Integration) test runner +go run ./cmd/hi doctor # Check system requirements +go run ./cmd/hi run "TestPattern" # Run specific test +go run ./cmd/hi run "TestPattern" --postgres # With PostgreSQL backend + +# Test artifacts are saved to control_logs/ with logs and debug data +``` + +## Pre-Commit Quality Checks + +### **MANDATORY: Automated Pre-Commit Hooks with prek** + +**CRITICAL REQUIREMENT**: This repository uses [prek](https://prek.j178.dev/) for automated pre-commit hooks. All commits are automatically validated for code quality, formatting, and common issues. + +### Initial Setup + +When you first clone the repository or enter the nix shell, install the git hooks: + +```bash +# Enter nix development environment +nix develop + +# Install prek git hooks (one-time setup) +prek install +``` + +This installs the pre-commit hook at `.git/hooks/pre-commit` which automatically runs all configured checks before each commit. 
+ +### Configured Hooks + +The repository uses `.pre-commit-config.yaml` with the following hooks: + +**Built-in Checks** (optimized fast-path execution): + +- `check-added-large-files` - Prevents accidentally committing large files +- `check-case-conflict` - Checks for files that would conflict in case-insensitive filesystems +- `check-executables-have-shebangs` - Ensures executables have proper shebangs +- `check-json` - Validates JSON syntax +- `check-merge-conflict` - Prevents committing files with merge conflict markers +- `check-symlinks` - Checks for broken symlinks +- `check-toml` - Validates TOML syntax +- `check-xml` - Validates XML syntax +- `check-yaml` - Validates YAML syntax +- `detect-private-key` - Detects accidentally committed private keys +- `end-of-file-fixer` - Ensures files end with a newline +- `fix-byte-order-marker` - Removes UTF-8 byte order markers +- `mixed-line-ending` - Prevents mixed line endings +- `trailing-whitespace` - Removes trailing whitespace + +**Project-Specific Hooks**: + +- `nixpkgs-fmt` - Formats Nix files +- `prettier` - Formats markdown, YAML, JSON, and TOML files +- `golangci-lint` - Runs Go linter with auto-fix on changed files only + +### Manual Hook Execution + +Run hooks manually without making a commit: + +```bash +# Run hooks on staged files only +prek run + +# Run hooks on all files in the repository +prek run --all-files + +# Run a specific hook +prek run golangci-lint + +# Run hooks on specific files +prek run --files path/to/file1.go path/to/file2.go +``` + +### Workflow Pattern + +With prek installed, your normal workflow becomes: + +```bash +# 1. Make your code changes +vim hscontrol/state/state.go + +# 2. Stage your changes +git add . + +# 3. Commit - hooks run automatically +git commit -m "feat: add new feature" + +# If hooks fail, they will show which checks failed +# Fix the issues and try committing again +``` + +### Manual golangci-lint + +While golangci-lint runs automatically via prek, you can also run it manually: + +```bash +# If you have upstream remote configured (recommended) +golangci-lint run --new-from-rev=upstream/main --timeout=5m --fix + +# If you only have origin remote +golangci-lint run --new-from-rev=main --timeout=5m --fix +``` + +**Important**: Always use `--new-from-rev` to only lint changed files. This prevents formatting the entire repository and keeps changes focused on your actual modifications. + +### Skipping Hooks (Not Recommended) + +In rare cases where you need to skip hooks (e.g., work-in-progress commits), use: + +```bash +git commit --no-verify -m "WIP: work in progress" +``` + +**WARNING**: Only use `--no-verify` for temporary WIP commits on feature branches. All commits to main must pass all hooks. 
+ +### Troubleshooting + +**Hook installation issues**: + +```bash +# Check if hooks are installed +ls -la .git/hooks/pre-commit + +# Reinstall hooks +prek install +``` + +**Hooks running slow**: + +```bash +# prek uses optimized fast-path for built-in hooks +# If running slow, check which hook is taking time with verbose output +prek run -v +``` + +**Update hook configuration**: + +```bash +# After modifying .pre-commit-config.yaml, hooks will automatically use new config +# No reinstallation needed +``` + +## Project Structure & Architecture + +### Top-Level Organization + +``` +headscale/ +├── cmd/ # Command-line applications +│ ├── headscale/ # Main headscale server binary +│ └── hi/ # Headscale Integration test runner +├── hscontrol/ # Core control plane logic +├── integration/ # End-to-end Docker-based tests +├── proto/ # Protocol buffer definitions +├── gen/ # Generated code (protobuf) +├── docs/ # Documentation +└── packaging/ # Distribution packaging +``` + +### Core Packages (`hscontrol/`) + +**Main Server (`hscontrol/`)** + +- `app.go`: Application setup, dependency injection, server lifecycle +- `handlers.go`: HTTP/gRPC API endpoints for management operations +- `grpcv1.go`: gRPC service implementation for headscale API +- `poll.go`: **Critical** - Handles Tailscale MapRequest/MapResponse protocol +- `noise.go`: Noise protocol implementation for secure client communication +- `auth.go`: Authentication flows (web, OIDC, command-line) +- `oidc.go`: OpenID Connect integration for user authentication + +**State Management (`hscontrol/state/`)** + +- `state.go`: Central coordinator for all subsystems (database, policy, IP allocation, DERP) +- `node_store.go`: **Performance-critical** - In-memory cache with copy-on-write semantics +- Thread-safe operations with deadlock detection +- Coordinates between database persistence and real-time operations + +**Database Layer (`hscontrol/db/`)** + +- `db.go`: Database abstraction, GORM setup, migration management +- `node.go`: Node lifecycle, registration, expiration, IP assignment +- `users.go`: User management, namespace isolation +- `api_key.go`: API authentication tokens +- `preauth_keys.go`: Pre-authentication keys for automated node registration +- `ip.go`: IP address allocation and management +- `policy.go`: Policy storage and retrieval +- Schema migrations in `schema.sql` with extensive test data coverage + +**CRITICAL DATABASE MIGRATION RULES**: + +1. **NEVER reorder existing migrations** - Migration order is immutable once committed +2. **ONLY add new migrations to the END** of the migrations array +3. **NEVER disable foreign keys** in new migrations - no new migrations should be added to `migrationsRequiringFKDisabled` +4. **Migration ID format**: `YYYYMMDDHHSS-short-description` (timestamp + descriptive suffix) + - Example: `202511131500-add-user-roles` + - The timestamp must be chronologically ordered +5. **New migrations go after the comment** "As of 2025-07-02, no new IDs should be added here" +6. 
If you need to rename a column that other migrations depend on: + - Accept that the old column name will exist in intermediate migration states + - Update code to work with the new column name + - Let AutoMigrate create the new column if needed + - Do NOT try to rename columns that later migrations reference + +**Policy Engine (`hscontrol/policy/`)** + +- `policy.go`: Core ACL evaluation logic, HuJSON parsing +- `v2/`: Next-generation policy system with improved filtering +- `matcher/`: ACL rule matching and evaluation engine +- Determines peer visibility, route approval, and network access rules +- Supports both file-based and database-stored policies + +**Network Management (`hscontrol/`)** + +- `derp/`: DERP (Designated Encrypted Relay for Packets) server implementation + - NAT traversal when direct connections fail + - Fallback relay for firewall-restricted environments +- `mapper/`: Converts internal Headscale state to Tailscale's wire protocol format + - `tail.go`: Tailscale-specific data structure generation +- `routes/`: Subnet route management and primary route selection +- `dns/`: DNS record management and MagicDNS implementation + +**Utilities & Support (`hscontrol/`)** + +- `types/`: Core data structures, configuration, validation +- `util/`: Helper functions for networking, DNS, key management +- `templates/`: Client configuration templates (Apple, Windows, etc.) +- `notifier/`: Event notification system for real-time updates +- `metrics.go`: Prometheus metrics collection +- `capver/`: Tailscale capability version management + +### Key Subsystem Interactions + +**Node Registration Flow** + +1. **Client Connection**: `noise.go` handles secure protocol handshake +2. **Authentication**: `auth.go` validates credentials (web/OIDC/preauth) +3. **State Creation**: `state.go` coordinates IP allocation via `db/ip.go` +4. **Storage**: `db/node.go` persists node, `NodeStore` caches in memory +5. **Network Setup**: `mapper/` generates initial Tailscale network map + +**Ongoing Operations** + +1. **Poll Requests**: `poll.go` receives periodic client updates +2. **State Updates**: `NodeStore` maintains real-time node information +3. **Policy Application**: `policy/` evaluates ACL rules for peer relationships +4. **Map Distribution**: `mapper/` sends network topology to all affected clients + +**Route Management** + +1. **Advertisement**: Clients announce routes via `poll.go` Hostinfo updates +2. **Storage**: `db/` persists routes, `NodeStore` caches for performance +3. **Approval**: `policy/` auto-approves routes based on ACL rules +4. 
**Distribution**: `routes/` selects primary routes, `mapper/` distributes to peers + +### Command-Line Tools (`cmd/`) + +**Main Server (`cmd/headscale/`)** + +- `headscale.go`: CLI parsing, configuration loading, server startup +- Supports daemon mode, CLI operations (user/node management), database operations + +**Integration Test Runner (`cmd/hi/`)** + +- `main.go`: Test execution framework with Docker orchestration +- `run.go`: Individual test execution with artifact collection +- `doctor.go`: System requirements validation +- `docker.go`: Container lifecycle management +- Essential for validating changes against real Tailscale clients + +### Generated & External Code + +**Protocol Buffers (`proto/` → `gen/`)** + +- Defines gRPC API for headscale management operations +- Client libraries can generate from these definitions +- Run `make generate` after modifying `.proto` files + +**Integration Testing (`integration/`)** + +- `scenario.go`: Docker test environment setup +- `tailscale.go`: Tailscale client container management +- Individual test files for specific functionality areas +- Real end-to-end validation with network isolation + +### Critical Performance Paths + +**High-Frequency Operations** + +1. **MapRequest Processing** (`poll.go`): Every 15-60 seconds per client +2. **NodeStore Reads** (`node_store.go`): Every operation requiring node data +3. **Policy Evaluation** (`policy/`): On every peer relationship calculation +4. **Route Lookups** (`routes/`): During network map generation + +**Database Write Patterns** + +- **Frequent**: Node heartbeats, endpoint updates, route changes +- **Moderate**: User operations, policy updates, API key management +- **Rare**: Schema migrations, bulk operations + +### Configuration & Deployment + +**Configuration** (`hscontrol/types/config.go`)\*\* + +- Database connection settings (SQLite/PostgreSQL) +- Network configuration (IP ranges, DNS settings) +- Policy mode (file vs database) +- DERP relay configuration +- OIDC provider settings + +**Key Dependencies** + +- **GORM**: Database ORM with migration support +- **Tailscale Libraries**: Core networking and protocol code +- **Zerolog**: Structured logging throughout the application +- **Buf**: Protocol buffer toolchain for code generation + +### Development Workflow Integration + +The architecture supports incremental development: + +- **Unit Tests**: Focus on individual packages (`*_test.go` files) +- **Integration Tests**: Validate cross-component interactions +- **Database Tests**: Extensive migration and data integrity validation +- **Policy Tests**: ACL rule evaluation and edge cases +- **Performance Tests**: NodeStore and high-frequency operation validation + +## Integration Testing System + +### Overview + +Headscale uses Docker-based integration tests with real Tailscale clients to validate end-to-end functionality. The integration test system is complex and requires specialized knowledge for effective execution and debugging. + +### **MANDATORY: Use the headscale-integration-tester Agent** + +**CRITICAL REQUIREMENT**: For ANY integration test execution, analysis, troubleshooting, or validation, you MUST use the `headscale-integration-tester` agent. 
This agent contains specialized knowledge about: + +- Test execution strategies and timing requirements +- Infrastructure vs code issue distinction (99% vs 1% failure patterns) +- Security-critical debugging rules and forbidden practices +- Comprehensive artifact analysis workflows +- Real-world failure patterns from HA debugging experiences + +### Quick Reference Commands + +```bash +# Check system requirements (always run first) +go run ./cmd/hi doctor + +# Run single test (recommended for development) +go run ./cmd/hi run "TestName" + +# Use PostgreSQL for database-heavy tests +go run ./cmd/hi run "TestName" --postgres + +# Pattern matching for related tests +go run ./cmd/hi run "TestPattern*" + +# Run multiple tests concurrently (each gets isolated run ID) +go run ./cmd/hi run "TestPingAllByIP" & +go run ./cmd/hi run "TestACLAllowUserDst" & +go run ./cmd/hi run "TestOIDCAuthenticationPingAll" & +``` + +**Concurrent Execution Support**: + +The test runner supports running multiple tests concurrently on the same Docker daemon: + +- Each test run gets a **unique Run ID** (format: `YYYYMMDD-HHMMSS-{6-char-hash}`) +- All containers are labeled with `hi.run-id` for isolation +- Container names include the run ID for easy identification (e.g., `ts-{runID}-1-74-{hash}`) +- Dynamic port allocation prevents port conflicts between concurrent runs +- Cleanup only affects containers belonging to the specific run ID +- Log directories are isolated per run: `control_logs/{runID}/` + +**Critical Notes**: + +- Tests generate ~100MB of logs per run in `control_logs/` +- Running many tests concurrently may cause resource contention (CPU/memory) +- Clean stale containers periodically: `docker system prune -f` + +### Test Artifacts Location + +All test runs save comprehensive debugging artifacts to `control_logs/TIMESTAMP-ID/` including server logs, client logs, database dumps, MapResponse protocol data, and Prometheus metrics. + +**For all integration test work, use the headscale-integration-tester agent - it contains the complete knowledge needed for effective testing and debugging.** + +## NodeStore Implementation Details + +**Key Insight from Recent Work**: The NodeStore is a critical performance optimization that caches node data in memory while ensuring consistency with the database. When working with route advertisements or node state changes: + +1. **Timing Considerations**: Route advertisements need time to propagate from clients to server. Use `require.EventuallyWithT()` patterns in tests instead of immediate assertions. + +2. **Synchronization Points**: NodeStore updates happen at specific points like `poll.go:420` after Hostinfo changes. Ensure these are maintained when modifying the polling logic. + +3. **Peer Visibility**: The NodeStore's `peersFunc` determines which nodes are visible to each other. Policy-based filtering is separate from monitoring visibility - expired nodes should remain visible for debugging but marked as expired. + +## Testing Guidelines + +### Integration Test Patterns + +#### **CRITICAL: EventuallyWithT Pattern for External Calls** + +**All external calls in integration tests MUST be wrapped in EventuallyWithT blocks** to handle eventual consistency in distributed systems. 
External calls include: + +- `client.Status()` - Getting Tailscale client status +- `client.Curl()` - Making HTTP requests through clients +- `client.Traceroute()` - Running network diagnostics +- `headscale.ListNodes()` - Querying headscale server state +- Any other calls that interact with external systems or network operations + +**Key Rules**: + +1. **Never use bare `require.NoError(t, err)` with external calls** - Always wrap in EventuallyWithT +2. **Keep related assertions together** - If multiple assertions depend on the same external call, keep them in the same EventuallyWithT block +3. **Split unrelated external calls** - Different external calls should be in separate EventuallyWithT blocks +4. **Never nest EventuallyWithT calls** - Each EventuallyWithT should be at the same level +5. **Declare shared variables at function scope** - Variables used across multiple EventuallyWithT blocks must be declared before first use + +**Examples**: + +```go +// CORRECT: External call wrapped in EventuallyWithT +assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + // Related assertions using the same status call + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + assert.NotNil(c, peerStatus.PrimaryRoutes) + requirePeerSubnetRoutesWithCollect(c, peerStatus, expectedRoutes) + } +}, 5*time.Second, 200*time.Millisecond, "Verifying client status and routes") + +// INCORRECT: Bare external call without EventuallyWithT +status, err := client.Status() // ❌ Will fail intermittently +require.NoError(t, err) + +// CORRECT: Separate EventuallyWithT for different external calls +// First external call - headscale.ListNodes() +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2) +}, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes") + +// Second external call - client.Status() +assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}) + } +}, 10*time.Second, 500*time.Millisecond, "routes should be visible to client") + +// INCORRECT: Multiple unrelated external calls in same EventuallyWithT +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() // ❌ First external call + assert.NoError(c, err) + + status, err := client.Status() // ❌ Different external call - should be separate + assert.NoError(c, err) +}, 10*time.Second, 500*time.Millisecond, "mixed calls") + +// CORRECT: Variable scoping for shared data +var ( + srs1, srs2, srs3 *ipnstate.Status + clientStatus *ipnstate.Status + srs1PeerStatus *ipnstate.PeerStatus +) + +assert.EventuallyWithT(t, func(c *assert.CollectT) { + srs1 = subRouter1.MustStatus() // = not := + srs2 = subRouter2.MustStatus() + clientStatus = client.MustStatus() + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + // assertions... 
+}, 5*time.Second, 200*time.Millisecond, "checking router status") + +// CORRECT: Wrapping client operations +assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) +}, 5*time.Second, 200*time.Millisecond, "Verifying HTTP connectivity") + +assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + assertTracerouteViaIPWithCollect(c, tr, expectedRouter.MustIPv4()) +}, 5*time.Second, 200*time.Millisecond, "Verifying network path") +``` + +**Helper Functions**: + +- Use `requirePeerSubnetRoutesWithCollect` instead of `requirePeerSubnetRoutes` inside EventuallyWithT +- Use `requireNodeRouteCountWithCollect` instead of `requireNodeRouteCount` inside EventuallyWithT +- Use `assertTracerouteViaIPWithCollect` instead of `assertTracerouteViaIP` inside EventuallyWithT + +```go +// Node route checking by actual node properties, not array position +var routeNode *v1.Node +for _, node := range nodes { + if nodeIDStr := fmt.Sprintf("%d", node.GetId()); expectedRoutes[nodeIDStr] != "" { + routeNode = node + break + } +} +``` + +### Running Problematic Tests + +- Some tests require significant time (e.g., `TestNodeOnlineStatus` runs for 12 minutes) +- Infrastructure issues like disk space can cause test failures unrelated to code changes +- Use `--postgres` flag when testing database-heavy scenarios + +## Quality Assurance and Testing Requirements + +### **MANDATORY: Always Use Specialized Testing Agents** + +**CRITICAL REQUIREMENT**: For ANY task involving testing, quality assurance, review, or validation, you MUST use the appropriate specialized agent at the END of your task list. This ensures comprehensive quality validation and prevents regressions. + +**Required Agents for Different Task Types**: + +1. **Integration Testing**: Use `headscale-integration-tester` agent for: + - Running integration tests with `cmd/hi` + - Analyzing test failures and artifacts + - Troubleshooting Docker-based test infrastructure + - Validating end-to-end functionality changes + +2. **Quality Control**: Use `quality-control-enforcer` agent for: + - Code review and validation + - Ensuring best practices compliance + - Preventing common pitfalls and anti-patterns + - Validating architectural decisions + +**Agent Usage Pattern**: Always add the appropriate agent as the FINAL step in any task list to ensure quality validation occurs after all work is complete. + +### Integration Test Debugging Reference + +Test artifacts are preserved in `control_logs/TIMESTAMP-ID/` including: + +- Headscale server logs (stderr/stdout) +- Tailscale client logs and status +- Database dumps and network captures +- MapResponse JSON files for protocol debugging + +**For integration test issues, ALWAYS use the headscale-integration-tester agent - do not attempt manual debugging.** + +## EventuallyWithT Pattern for Integration Tests + +### Overview + +EventuallyWithT is a testing pattern used to handle eventual consistency in distributed systems. In Headscale integration tests, many operations are asynchronous - clients advertise routes, the server processes them, updates propagate through the network. EventuallyWithT allows tests to wait for these operations to complete while making assertions. 
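+
+For orientation before the detailed rules below, a minimal sketch of the call shape (the timeout, tick, and assertions here are illustrative placeholders, not taken from an existing test):
+
+```go
+// assert.EventuallyWithT (from testify) retries the condition function until it
+// passes or the timeout elapses. Assertions inside must use the *assert.CollectT
+// parameter, not t, so failures are collected and retried instead of aborting.
+assert.EventuallyWithT(t, func(c *assert.CollectT) {
+    // exactly one external call per block, e.g. querying client status
+    status, err := client.Status()
+    assert.NoError(c, err)
+    assert.NotNil(c, status.Self)
+}, 10*time.Second, 500*time.Millisecond, "client should report a valid status")
+```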
+ +### External Calls That Must Be Wrapped + +The following operations are **external calls** that interact with the headscale server or tailscale clients and MUST be wrapped in EventuallyWithT: + +- `headscale.ListNodes()` - Queries server state +- `client.Status()` - Gets client network status +- `client.Curl()` - Makes HTTP requests through the network +- `client.Traceroute()` - Performs network diagnostics +- `client.Execute()` when running commands that query state +- Any operation that reads from the headscale server or tailscale client + +### Operations That Must NOT Be Wrapped + +The following are **blocking operations** that modify state and should NOT be wrapped in EventuallyWithT: + +- `tailscale set` commands (e.g., `--advertise-routes`, `--exit-node`) +- Any command that changes configuration or state +- Use `client.MustStatus()` instead of `client.Status()` when you just need the ID for a blocking operation + +### Five Key Rules for EventuallyWithT + +1. **One External Call Per EventuallyWithT Block** + - Each EventuallyWithT should make ONE external call (e.g., ListNodes OR Status) + - Related assertions based on that single call can be grouped together + - Unrelated external calls must be in separate EventuallyWithT blocks + +2. **Variable Scoping** + - Declare variables that need to be shared across EventuallyWithT blocks at function scope + - Use `=` for assignment inside EventuallyWithT, not `:=` (unless the variable is only used within that block) + - Variables declared with `:=` inside EventuallyWithT are not accessible outside + +3. **No Nested EventuallyWithT** + - NEVER put an EventuallyWithT inside another EventuallyWithT + - This is a critical anti-pattern that must be avoided + +4. **Use CollectT for Assertions** + - Inside EventuallyWithT, use `assert` methods with the CollectT parameter + - Helper functions called within EventuallyWithT must accept `*assert.CollectT` + +5. **Descriptive Messages** + - Always provide a descriptive message as the last parameter + - Message should explain what condition is being waited for + +### Correct Pattern Examples + +```go +// CORRECT: Blocking operation NOT wrapped +for _, client := range allClients { + status := client.MustStatus() + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + expectedRoutes[string(status.Self.ID)], + } + _, _, err = client.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) +} + +// CORRECT: Single external call with related assertions +var nodes []*v1.Node +assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2) +}, 10*time.Second, 500*time.Millisecond, "nodes should have expected route counts") + +// CORRECT: Separate EventuallyWithT for different external call +assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + requirePeerSubnetRoutesWithCollect(c, peerStatus, expectedPrefixes) + } +}, 10*time.Second, 500*time.Millisecond, "client should see expected routes") +``` + +### Incorrect Patterns to Avoid + +```go +// INCORRECT: Blocking operation wrapped in EventuallyWithT +assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + // This is a blocking operation - should NOT be in EventuallyWithT! 
+ command := []string{ + "tailscale", + "set", + "--advertise-routes=" + expectedRoutes[string(status.Self.ID)], + } + _, _, err = client.Execute(command) + assert.NoError(c, err) +}, 5*time.Second, 200*time.Millisecond, "wrong pattern") + +// INCORRECT: Multiple unrelated external calls in same EventuallyWithT +assert.EventuallyWithT(t, func(c *assert.CollectT) { + // First external call + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + + // Second unrelated external call - WRONG! + status, err := client.Status() + assert.NoError(c, err) + assert.NotNil(c, status) +}, 10*time.Second, 500*time.Millisecond, "mixed operations") +``` + +## Tags-as-Identity Architecture + +### Overview + +Headscale implements a **tags-as-identity** model where tags and user ownership are mutually exclusive ways to identify nodes. This is a fundamental architectural principle that affects node registration, ownership, ACL evaluation, and API behavior. + +### Core Principle: Tags XOR User Ownership + +Every node in Headscale is **either** tagged **or** user-owned, never both: + +- **Tagged Nodes**: Ownership is defined by tags (e.g., `tag:server`, `tag:database`) + - Tags are set during registration via tagged PreAuthKey + - Tags are immutable after registration (cannot be changed via API) + - May have `UserID` set for "created by" tracking, but ownership is via tags + - Identified by: `node.IsTagged()` returns `true` + +- **User-Owned Nodes**: Ownership is defined by user assignment + - Registered via OIDC, web auth, or untagged PreAuthKey + - Node belongs to a specific user's namespace + - No tags (empty tags array) + - Identified by: `node.UserID().Valid() && !node.IsTagged()` + +### Critical Implementation Details + +#### Node Identification Methods + +```go +// Primary methods for determining node ownership +node.IsTagged() // Returns true if node has tags OR AuthKey.Tags +node.HasTag(tag) // Returns true if node has specific tag +node.IsUserOwned() // Returns true if UserID set AND not tagged + +// IMPORTANT: UserID can be set on tagged nodes for tracking! +// Always use IsTagged() to determine actual ownership, not just UserID.Valid() +``` + +#### UserID Field Semantics + +**Critical distinction**: `UserID` has different meanings depending on node type: + +- **Tagged nodes**: `UserID` is optional "created by" tracking + - Indicates which user created the tagged PreAuthKey + - Does NOT define ownership (tags define ownership) + - Example: User "alice" creates tagged PreAuthKey with `tag:server`, node gets `UserID=alice.ID` + `Tags=["tag:server"]` + +- **User-owned nodes**: `UserID` defines ownership + - Required field for non-tagged nodes + - Defines which user namespace the node belongs to + - Example: User "bob" registers via OIDC, node gets `UserID=bob.ID` + `Tags=[]` + +#### Mapper Behavior (mapper/tail.go) + +The mapper converts internal nodes to Tailscale protocol format, handling the TaggedDevices special user: + +```go +// From mapper/tail.go:102-116 +User: func() tailcfg.UserID { + // IMPORTANT: Tags-as-identity model + // Tagged nodes ALWAYS use TaggedDevices user, even if UserID is set + if node.IsTagged() { + return tailcfg.UserID(int64(types.TaggedDevices.ID)) + } + // User-owned nodes: use the actual user ID + return tailcfg.UserID(int64(node.UserID().Get())) +}() +``` + +**TaggedDevices constant** (`types.TaggedDevices.ID = 2147455555`): Special user ID for all tagged nodes in MapResponse protocol. 
+ +#### Registration Flow + +**Tagged Node Registration** (via tagged PreAuthKey): + +1. User creates PreAuthKey with tags: `pak.Tags = ["tag:server"]` +2. Node registers with PreAuthKey +3. Node gets: `Tags = ["tag:server"]`, `UserID = user.ID` (optional tracking), `AuthKeyID = pak.ID` +4. `IsTagged()` returns `true` (ownership via tags) +5. MapResponse sends `User = TaggedDevices.ID` + +**User-Owned Node Registration** (via OIDC/web/untagged PreAuthKey): + +1. User authenticates or uses untagged PreAuthKey +2. Node registers +3. Node gets: `Tags = []`, `UserID = user.ID` (required) +4. `IsTagged()` returns `false` (ownership via user) +5. MapResponse sends `User = user.ID` + +#### API Validation (SetTags) + +The SetTags gRPC API enforces tags-as-identity rules: + +```go +// From grpcv1.go:340-347 +// User-owned nodes are nodes with UserID that are NOT tagged +isUserOwned := nodeView.UserID().Valid() && !nodeView.IsTagged() +if isUserOwned && len(request.GetTags()) > 0 { + return error("cannot set tags on user-owned nodes") +} +``` + +**Key validation rules**: + +- ✅ Can call SetTags on tagged nodes (tags already define ownership) +- ❌ Cannot set tags on user-owned nodes (would violate XOR rule) +- ❌ Cannot remove all tags from tagged nodes (would orphan the node) + +#### Database Layer (db/node.go) + +**Tag storage**: Tags are stored in PostgreSQL ARRAY column and SQLite JSON column: + +```sql +-- From schema.sql +tags TEXT[] DEFAULT '{}' NOT NULL, -- PostgreSQL +tags TEXT DEFAULT '[]' NOT NULL, -- SQLite (JSON array) +``` + +**Validation** (`state/tags.go`): + +- `validateNodeOwnership()`: Enforces tags XOR user rule +- `validateAndNormalizeTags()`: Validates tag format (`tag:name`) and uniqueness + +#### Policy Layer + +**Tag Ownership** (policy/v2/policy.go): + +```go +func NodeCanHaveTag(node types.NodeView, tag string) bool { + // Checks if node's IP is in the tagOwnerMap IP set + // This is IP-based authorization, not UserID-based + if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok { + if slices.ContainsFunc(node.IPs(), ips.Contains) { + return true + } + } + return false +} +``` + +**Important**: Tag authorization is based on IP ranges in ACL, not UserID. Tags define identity, ACL authorizes that identity. + +### Testing Tags-as-Identity + +**Unit Tests** (`hscontrol/types/node_tags_test.go`): + +- `TestNodeIsTagged`: Validates IsTagged() for various scenarios +- `TestNodeOwnershipModel`: Tests tags XOR user ownership +- `TestUserTypedID`: Helper method validation + +**API Tests** (`hscontrol/grpcv1_test.go`): + +- `TestSetTags_UserXORTags`: Validates rejection of setting tags on user-owned nodes +- `TestSetTags_TaggedNode`: Validates that tagged nodes (even with UserID) are not rejected + +**Auth Tests** (`hscontrol/auth_test.go:890-928`): + +- Tests node registration with tagged PreAuthKey +- Validates tags are applied during registration + +### Common Pitfalls + +1. **Don't check only `UserID.Valid()` to determine user ownership** + - ❌ Wrong: `if node.UserID().Valid() { /* user-owned */ }` + - ✅ Correct: `if node.UserID().Valid() && !node.IsTagged() { /* user-owned */ }` + +2. **Don't assume tagged nodes never have UserID set** + - Tagged nodes MAY have UserID for "created by" tracking + - Always use `IsTagged()` to determine ownership type + +3. **Don't allow setting tags on user-owned nodes** + - This violates the tags XOR user principle + - Use API validation to prevent this + +4. 
**Don't forget TaggedDevices in mapper** + - All tagged nodes MUST use `TaggedDevices.ID` in MapResponse + - User ID is only for actual user-owned nodes + +### Migration Considerations + +When nodes transition between ownership models: + +- **No automatic migration**: Tags-as-identity is set at registration and immutable +- **Re-registration required**: To change from user-owned to tagged (or vice versa), node must be deleted and re-registered +- **UserID persistence**: UserID on tagged nodes is informational and not cleared + +### Architecture Benefits + +The tags-as-identity model provides: + +1. **Clear ownership semantics**: No ambiguity about who/what owns a node +2. **ACL simplicity**: Tag-based access control without user conflicts +3. **API safety**: Validation prevents invalid ownership states +4. **Protocol compatibility**: TaggedDevices special user aligns with Tailscale's model + +## Logging Patterns + +### Incremental Log Event Building + +When building log statements with multiple fields, especially with conditional fields, use the **incremental log event pattern** instead of long single-line chains. This improves readability and allows conditional field addition. + +**Pattern:** + +```go +// GOOD: Incremental building with conditional fields +logEvent := log.Debug(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Str("node_key", node.NodeKey.ShortString()) + +if node.User != nil { + logEvent = logEvent.Str("user", node.User.Username()) +} else if node.UserID != nil { + logEvent = logEvent.Uint("user_id", *node.UserID) +} else { + logEvent = logEvent.Str("user", "none") +} + +logEvent.Msg("Registering node") +``` + +**Key rules:** + +1. **Assign chained calls back to the variable**: `logEvent = logEvent.Str(...)` - zerolog methods return a new event, so you must capture the return value +2. **Use for conditional fields**: When fields depend on runtime conditions, build incrementally +3. **Use for long log lines**: When a log line exceeds ~100 characters, split it for readability +4. **Call `.Msg()` at the end**: The final `.Msg()` or `.Msgf()` sends the log event + +**Anti-pattern to avoid:** + +```go +// BAD: Long single-line chains are hard to read and can't have conditional fields +log.Debug().Caller().Str("node", node.Hostname).Str("machine_key", node.MachineKey.ShortString()).Str("node_key", node.NodeKey.ShortString()).Str("user", node.User.Username()).Msg("Registering node") + +// BAD: Forgetting to assign the return value (field is lost!) +logEvent := log.Debug().Str("node", node.Hostname) +logEvent.Str("user", username) // This field is LOST - not assigned back +logEvent.Msg("message") // Only has "node" field +``` + +**When to use this pattern:** + +- Log statements with 4+ fields +- Any log with conditional fields +- Complex logging in loops or error handling +- When you need to add context incrementally + +**Example from codebase** (`hscontrol/db/node.go`): + +```go +logEvent := log.Debug(). + Str("node", node.Hostname). + Str("machine_key", node.MachineKey.ShortString()). + Str("node_key", node.NodeKey.ShortString()) + +if node.User != nil { + logEvent = logEvent.Str("user", node.User.Username()) +} else if node.UserID != nil { + logEvent = logEvent.Uint("user_id", *node.UserID) +} else { + logEvent = logEvent.Str("user", "none") +} + +logEvent.Msg("Registering test node") +``` + +### Avoiding Log Helper Functions + +Prefer the incremental log event pattern over creating helper functions that return multiple logging closures. 
Helper functions like `logPollFunc` create unnecessary indirection and allocate closures. + +**Instead of:** + +```go +// AVOID: Helper function returning closures +func logPollFunc(req tailcfg.MapRequest, node *types.Node) ( + func(string, ...any), // warnf + func(string, ...any), // infof + func(string, ...any), // tracef + func(error, string, ...any), // errf +) { + return func(msg string, a ...any) { + log.Warn(). + Caller(). + Bool("omitPeers", req.OmitPeers). + Bool("stream", req.Stream). + Uint64("node.id", node.ID.Uint64()). + Str("node.name", node.Hostname). + Msgf(msg, a...) + }, + // ... more closures +} +``` + +**Prefer:** + +```go +// BETTER: Build log events inline with shared context +func (m *mapSession) logTrace(msg string) { + log.Trace(). + Caller(). + Bool("omitPeers", m.req.OmitPeers). + Bool("stream", m.req.Stream). + Uint64("node.id", m.node.ID.Uint64()). + Str("node.name", m.node.Hostname). + Msg(msg) +} + +// Or use incremental building for complex cases +logEvent := log.Trace(). + Caller(). + Bool("omitPeers", m.req.OmitPeers). + Bool("stream", m.req.Stream). + Uint64("node.id", m.node.ID.Uint64()). + Str("node.name", m.node.Hostname) + +if additionalContext { + logEvent = logEvent.Str("extra", value) +} + +logEvent.Msg("Operation completed") +``` + +## Important Notes + +- **Dependencies**: Use `nix develop` for consistent toolchain (Go, buf, protobuf tools, linting) +- **Protocol Buffers**: Changes to `proto/` require `make generate` and should be committed separately +- **Code Style**: Enforced via golangci-lint with golines (width 88) and gofumpt formatting +- **Linting**: ALL code must pass `golangci-lint run --new-from-rev=upstream/main --timeout=5m --fix` before commit +- **Database**: Supports both SQLite (development) and PostgreSQL (production/testing) +- **Integration Tests**: Require Docker and can consume significant disk space - use headscale-integration-tester agent +- **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management +- **Quality Assurance**: Always use appropriate specialized agents for testing and validation tasks +- **Tags-as-Identity**: Tags and user ownership are mutually exclusive - always use `IsTagged()` to determine ownership diff --git a/CHANGELOG.md b/CHANGELOG.md index 1a7b0569..13a4e321 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,48 +1,815 @@ # CHANGELOG -## 0.23.0 (2023-XX-XX) +## 0.28.0 (202x-xx-xx) -This release is mainly a code reorganisation and refactoring, significantly improving the maintainability of the codebase. This should allow us to improve further and make it easier for the maintainers to keep on top of the project. +**Minimum supported Tailscale client version: v1.74.0** + +### Tags as identity + +Tags are now implemented following the Tailscale model where tags and user ownership are mutually exclusive. Devices can be either +user-owned (authenticated via web/OIDC) or tagged (authenticated via tagged PreAuthKeys). Tagged devices receive their identity from +tags rather than users, making them suitable for servers and infrastructure. Applying a tag to a device removes user-based +ownership. See the [Tailscale tags documentation](https://tailscale.com/kb/1068/tags) for details on how tags work. + +User-owned nodes can now request tags during registration using `--advertise-tags`. Tags are validated against the `tagOwners` policy +and applied at registration time. Tags can be managed via the CLI or API after registration. 
Tagged nodes can return to user-owned +by re-authenticating with `tailscale up --advertise-tags= --force-reauth`. + +A one-time migration will validate and migrate any `RequestTags` (stored in hostinfo) to the tags column. Tags are validated against +your policy's `tagOwners` rules during migration. [#3011](https://github.com/juanfont/headscale/pull/3011) + +### Smarter map updates + +The map update system has been rewritten to send smaller, partial updates instead of full network maps whenever possible. This reduces bandwidth usage and improves performance, especially for large networks. The system now properly tracks peer +changes and can send removal notifications when nodes are removed due to policy changes. +[#2856](https://github.com/juanfont/headscale/pull/2856) [#2961](https://github.com/juanfont/headscale/pull/2961) + +### Pre-authentication key security improvements + +Pre-authentication keys now use bcrypt hashing for improved security [#2853](https://github.com/juanfont/headscale/pull/2853). Keys +are stored as a prefix and bcrypt hash instead of plaintext. The full key is only displayed once at creation time. When listing keys, +only the prefix is shown (e.g., `hskey-auth-{prefix}-***`). All new keys use the format `hskey-auth-{prefix}-{secret}`. Legacy plaintext keys in the format `{secret}` will continue to work for backwards compatibility. + +### Web registration templates redesign + +The OIDC callback and device registration web pages have been updated to use the Material for MkDocs design system from the official +documentation. The templates now use consistent typography, spacing, and colours across all registration flows. + +### Database migration support removed for pre-0.25.0 databases + +Headscale no longer supports direct upgrades from databases created before version 0.25.0. Users on older versions must upgrade +sequentially through each stable release, selecting the latest patch version available for each minor release. + +### BREAKING + +- **API**: The Node message in the gRPC/REST API has been simplified - the `ForcedTags`, `InvalidTags`, and `ValidTags` fields have been removed and replaced with a single `Tags` field that contains the node's applied tags [#2993](https://github.com/juanfont/headscale/pull/2993) + - API clients should use the `Tags` field instead of `ValidTags` + - The `headscale nodes list` CLI command now always shows a Tags column and the `--tags` flag has been removed +- **PreAuthKey CLI**: Commands now use ID-based operations instead of user+key combinations [#2992](https://github.com/juanfont/headscale/pull/2992) + - `headscale preauthkeys create` no longer requires `--user` flag (optional for tracking creation) + - `headscale preauthkeys list` lists all keys (no longer filtered by user) + - `headscale preauthkeys expire --id <ID>` replaces `--user <USER> <KEY>` + - `headscale preauthkeys delete --id <ID>` replaces `--user <USER> <KEY>` + + **Before:** + + ```bash + headscale preauthkeys create --user 1 --reusable --tags tag:server + headscale preauthkeys list --user 1 + headscale preauthkeys expire --user 1 <KEY> + headscale preauthkeys delete --user 1 <KEY> + ``` + + **After:** + + ```bash + headscale preauthkeys create --reusable --tags tag:server + headscale preauthkeys list + headscale preauthkeys expire --id 123 + headscale preauthkeys delete --id 123 + ``` + +- **Tags**: The gRPC `SetTags` endpoint now allows converting user-owned nodes to tagged nodes by setting tags. 
[#2885](https://github.com/juanfont/headscale/pull/2885) +- **Tags**: Tags are now resolved from the node's stored Tags field only [#2931](https://github.com/juanfont/headscale/pull/2931) + - `--advertise-tags` is processed during registration, not on every policy evaluation + - PreAuthKey tagged devices ignore `--advertise-tags` from clients + - User-owned nodes can use `--advertise-tags` if authorized by `tagOwners` policy + - Tags can be managed via CLI (`headscale nodes tag`) or the SetTags API after registration +- Database migration support removed for pre-0.25.0 databases [#2883](https://github.com/juanfont/headscale/pull/2883) + - If you are running a version older than 0.25.0, you must upgrade to 0.25.1 first, then upgrade to this release + - See the [upgrade path documentation](https://headscale.net/stable/about/faq/#what-is-the-recommended-update-path-can-i-skip-multiple-versions-while-updating) for detailed guidance + - In version 0.29, all migrations before 0.28.0 will also be removed +- Remove ability to move nodes between users [#2922](https://github.com/juanfont/headscale/pull/2922) + - The `headscale nodes move` CLI command has been removed + - The `MoveNode` API endpoint has been removed + - Nodes are permanently associated with their user or tag at registration time +- Add `oidc.email_verified_required` config option to control email verification requirement [#2860](https://github.com/juanfont/headscale/pull/2860) + - When `true` (default), only verified emails can authenticate via OIDC in conjunction with `oidc.allowed_domains` or + `oidc.allowed_users`. Previous versions allowed to authenticate with an unverified email but did not store the email + address in the user profile. This is now rejected during authentication with an `unverified email` error. + - When `false`, unverified emails are allowed for OIDC authentication and the email address is stored in the user + profile regardless of its verification state. +- **SSH Policy**: Wildcard (`*`) is no longer supported as an SSH destination [#3009](https://github.com/juanfont/headscale/issues/3009) + - Use `autogroup:member` for user-owned devices + - Use `autogroup:tagged` for tagged devices + - Use specific tags (e.g., `tag:server`) for targeted access + + **Before:** + + ```json + { "action": "accept", "src": ["group:admins"], "dst": ["*"], "users": ["root"] } + ``` + + **After:** + + ```json + { "action": "accept", "src": ["group:admins"], "dst": ["autogroup:member", "autogroup:tagged"], "users": ["root"] } + ``` + +- **SSH Policy**: SSH source/destination validation now enforces Tailscale's security model [#3010](https://github.com/juanfont/headscale/issues/3010) + + Per [Tailscale SSH documentation](https://tailscale.com/kb/1193/tailscale-ssh), the following rules are now enforced: + 1. **Tags cannot SSH to user-owned devices**: SSH rules with `tag:*` or `autogroup:tagged` as source cannot have username destinations (e.g., `alice@`) or `autogroup:member`/`autogroup:self` as destination + 2. **Username destinations require same-user source**: If destination is a specific username (e.g., `alice@`), the source must be that exact same user only. 
Use `autogroup:self` for same-user SSH access instead + + **Invalid policies now rejected at load time:** + + ```json + // INVALID: tag source to user destination + {"src": ["tag:server"], "dst": ["alice@"], ...} + + // INVALID: autogroup:tagged to autogroup:member + {"src": ["autogroup:tagged"], "dst": ["autogroup:member"], ...} + + // INVALID: group to specific user (use autogroup:self instead) + {"src": ["group:admins"], "dst": ["alice@"], ...} + ``` + + **Valid patterns:** + + ```json + // Users/groups can SSH to their own devices via autogroup:self + {"src": ["group:admins"], "dst": ["autogroup:self"], ...} + + // Users/groups can SSH to tagged devices + {"src": ["group:admins"], "dst": ["autogroup:tagged"], ...} + + // Tagged devices can SSH to other tagged devices + {"src": ["autogroup:tagged"], "dst": ["autogroup:tagged"], ...} + + // Same user can SSH to their own devices + {"src": ["alice@"], "dst": ["alice@"], ...} + ``` + +### Changes + +- Smarter change notifications send partial map updates and node removals instead of full maps [#2961](https://github.com/juanfont/headscale/pull/2961) + - Send lightweight endpoint and DERP region updates instead of full maps [#2856](https://github.com/juanfont/headscale/pull/2856) +- Add NixOS module in repository for faster iteration [#2857](https://github.com/juanfont/headscale/pull/2857) +- Add favicon to webpages [#2858](https://github.com/juanfont/headscale/pull/2858) +- Redesign OIDC callback and registration web templates [#2832](https://github.com/juanfont/headscale/pull/2832) +- Reclaim IPs from the IP allocator when nodes are deleted [#2831](https://github.com/juanfont/headscale/pull/2831) +- Add bcrypt hashing for pre-authentication keys [#2853](https://github.com/juanfont/headscale/pull/2853) +- Add prefix to API keys (`hskey-api-{prefix}-{secret}`) [#2853](https://github.com/juanfont/headscale/pull/2853) +- Add prefix to registration keys for web authentication tracking (`hskey-reg-{random}`) [#2853](https://github.com/juanfont/headscale/pull/2853) +- Tags can now be tagOwner of other tags [#2930](https://github.com/juanfont/headscale/pull/2930) +- Add `taildrop.enabled` configuration option to enable/disable Taildrop file sharing [#2955](https://github.com/juanfont/headscale/pull/2955) +- Allow disabling the metrics server by setting empty `metrics_listen_addr` [#2914](https://github.com/juanfont/headscale/pull/2914) +- Log ACME/autocert errors for easier debugging [#2933](https://github.com/juanfont/headscale/pull/2933) +- Improve CLI list output formatting [#2951](https://github.com/juanfont/headscale/pull/2951) +- Use Debian 13 distroless base images for containers [#2944](https://github.com/juanfont/headscale/pull/2944) +- Fix ACL policy not applied to new OIDC nodes until client restart [#2890](https://github.com/juanfont/headscale/pull/2890) +- Fix autogroup:self preventing visibility of nodes matched by other ACL rules [#2882](https://github.com/juanfont/headscale/pull/2882) +- Fix nodes being rejected after pre-authentication key expiration [#2917](https://github.com/juanfont/headscale/pull/2917) +- Fix list-routes command respecting identifier filter with JSON output [#2927](https://github.com/juanfont/headscale/pull/2927) +- **API Key CLI**: Add `--id` flag to expire/delete commands as alternative to `--prefix` [#3016](https://github.com/juanfont/headscale/pull/3016) + - `headscale apikeys expire --id <ID>` or `--prefix <PREFIX>` + - `headscale apikeys delete --id <ID>` or `--prefix <PREFIX>` + +## 0.27.1 (2025-11-11) 
+ +**Minimum supported Tailscale client version: v1.64.0** + +### Changes + +- Expire nodes with a custom timestamp [#2828](https://github.com/juanfont/headscale/pull/2828) +- Fix issue where node expiry was reset when tailscaled restarts [#2875](https://github.com/juanfont/headscale/pull/2875) +- Fix OIDC authentication when multiple login URLs are opened [#2861](https://github.com/juanfont/headscale/pull/2861) +- Fix node re-registration failing with expired auth keys [#2859](https://github.com/juanfont/headscale/pull/2859) +- Remove old unused database tables and indices [#2844](https://github.com/juanfont/headscale/pull/2844) [#2872](https://github.com/juanfont/headscale/pull/2872) +- Ignore litestream tables during database validation [#2843](https://github.com/juanfont/headscale/pull/2843) +- Fix exit node visibility to respect ACL rules [#2855](https://github.com/juanfont/headscale/pull/2855) +- Fix SSH policy becoming empty when unknown user is referenced [#2874](https://github.com/juanfont/headscale/pull/2874) +- Fix policy validation when using bypass-grpc mode [#2854](https://github.com/juanfont/headscale/pull/2854) +- Fix autogroup:self interaction with other ACL rules [#2842](https://github.com/juanfont/headscale/pull/2842) +- Fix flaky DERP map shuffle test [#2848](https://github.com/juanfont/headscale/pull/2848) +- Use current stable base images for Debian and Alpine containers [#2827](https://github.com/juanfont/headscale/pull/2827) + +## 0.27.0 (2025-10-27) + +**Minimum supported Tailscale client version: v1.64.0** + +### Database integrity improvements + +This release includes a significant database migration that addresses +longstanding issues with the database schema and data integrity that have +accumulated over the years. The migration introduces a `schema.sql` file as the +source of truth for the expected database schema, to ensure that new migrations which +would cause divergence do not occur again. + +These issues arose from a combination of factors discovered over time: SQLite +foreign keys not being enforced for many early versions, all migrations being +run in one large function until version 0.23.0, and inconsistent use of GORM's +AutoMigrate feature. Moving forward, all new migrations will be explicit SQL +operations rather than relying on GORM AutoMigrate, and foreign keys will be +enforced throughout the migration process. + +We are only improving SQLite databases with this change; PostgreSQL databases +are not affected. + +Please read the +[PR description](https://github.com/juanfont/headscale/pull/2617) for more +technical details about the issues and solutions. + +**SQLite Database Backup Example:** + +```bash +# Stop headscale +systemctl stop headscale + +# Backup sqlite database +cp /var/lib/headscale/db.sqlite /var/lib/headscale/db.sqlite.backup + +# Backup sqlite WAL/SHM files (if they exist) +cp /var/lib/headscale/db.sqlite-wal /var/lib/headscale/db.sqlite-wal.backup +cp /var/lib/headscale/db.sqlite-shm /var/lib/headscale/db.sqlite-shm.backup + +# Start headscale (migration will run automatically) +systemctl start headscale +``` + +### DERPMap update frequency + +The default DERPMap update frequency has been changed from 24 hours to 3 hours. +If you set the `derp.update_frequency` configuration option, it is recommended +to change it to `3h` to ensure that the headscale instance gets the latest +DERPMap updates when upstream is changed.
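+
+A minimal configuration sketch of that change (assuming the `derp:` nesting used in `config-example.yaml`; adjust it to match your existing file):
+
+```yaml
+# config.yaml (fragment)
+derp:
+  update_frequency: 3h
+```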
+ +### Autogroups + +This release adds support for the three missing autogroups: `self` +(experimental), `member`, and `tagged`. Please refer to the +[documentation](https://tailscale.com/kb/1018/autogroups/) for a detailed +explanation. + +`autogroup:self` is marked as experimental and should be used with caution, but +we need help testing it. Experimental here means two things: first, generating +the packet filter from policies that use `autogroup:self` is very expensive, and +it might perform poorly, or not work at all, on Headscale installations with a +large number of nodes. Second, the implementation might have bugs or edge cases +we are not aware of, meaning that nodes or users might gain _more_ access than +expected. Please report bugs. + +### Node store (in-memory database) + +Under the hood, we have added a new data structure to store nodes in memory. This +data structure is called `NodeStore` and aims to reduce the reading and writing +of nodes to the database layer. We have not benchmarked it, but expect it to +improve performance for read-heavy workloads. We think of it as: "worst case", we +have moved the bottleneck somewhere else, and "best case", we should see a good +improvement in compute resource usage at the expense of memory usage. We are +quite excited for this change and think it will make it easier for us to improve +the code base over time and make it more correct and efficient. + +### BREAKING + +- Remove support for 32-bit binaries [#2692](https://github.com/juanfont/headscale/pull/2692) +- Policy: Zero or empty destination port is no longer allowed [#2606](https://github.com/juanfont/headscale/pull/2606) +- Stricter hostname validation [#2383](https://github.com/juanfont/headscale/pull/2383) +  - Hostnames must be valid DNS labels (2-63 characters, alphanumeric and +    hyphens only, cannot start/end with hyphen) +  - **Client Registration (New Nodes)**: Invalid hostnames are automatically +    renamed to `invalid-XXXXXX` format +    - `my-laptop` → accepted as-is +    - `My-Laptop` → `my-laptop` (lowercased) +    - `my_laptop` → `invalid-a1b2c3` (underscore not allowed) +    - `test@host` → `invalid-d4e5f6` (@ not allowed) +    - `laptop-🚀` → `invalid-j1k2l3` (emoji not allowed) +  - **Hostinfo Updates / CLI**: Invalid hostnames are rejected with an error +    - Valid names are accepted or lowercased +    - Names with invalid characters, too short (<2), too long (>63), or +      starting/ending with hyphen are rejected + +### Changes + +- **Database schema migration improvements for SQLite** [#2617](https://github.com/juanfont/headscale/pull/2617) +  - **IMPORTANT: Back up your SQLite database before upgrading** +  - Introduces safer table renaming migration strategy +  - Addresses longstanding database integrity issues +- Add flag to directly manipulate the policy in the database [#2765](https://github.com/juanfont/headscale/pull/2765) +- DERPmap update frequency default changed from 24h to 3h [#2741](https://github.com/juanfont/headscale/pull/2741) +- DERPmap update mechanism has been improved with retry, and is now failing +  conservatively, preserving the old map upon failure.
+ [#2741](https://github.com/juanfont/headscale/pull/2741) +- Add support for `autogroup:member`, `autogroup:tagged` [#2572](https://github.com/juanfont/headscale/pull/2572) +- Fix bug where return routes were being removed by policy [#2767](https://github.com/juanfont/headscale/pull/2767) +- Remove policy v1 code [#2600](https://github.com/juanfont/headscale/pull/2600) +- Refactor Debian/Ubuntu packaging and drop support for Ubuntu 20.04. [#2614](https://github.com/juanfont/headscale/pull/2614) +- Remove redundant check regarding `noise` config [#2658](https://github.com/juanfont/headscale/pull/2658) +- Refactor OpenID Connect documentation [#2625](https://github.com/juanfont/headscale/pull/2625) +- Don't crash if config file is missing [#2656](https://github.com/juanfont/headscale/pull/2656) +- Adds `/robots.txt` endpoint to avoid crawlers [#2643](https://github.com/juanfont/headscale/pull/2643) +- OIDC: Use group claim from UserInfo [#2663](https://github.com/juanfont/headscale/pull/2663) +- OIDC: Update user with claims from UserInfo _before_ comparing with allowed + groups, email and domain + [#2663](https://github.com/juanfont/headscale/pull/2663) +- Policy will now reject invalid fields, making it easier to spot spelling + errors [#2764](https://github.com/juanfont/headscale/pull/2764) +- Add FAQ entry on how to recover from an invalid policy in the database [#2776](https://github.com/juanfont/headscale/pull/2776) +- EXPERIMENTAL: Add support for `autogroup:self` [#2789](https://github.com/juanfont/headscale/pull/2789) +- Add healthcheck command [#2659](https://github.com/juanfont/headscale/pull/2659) + +## 0.26.1 (2025-06-06) + +### Changes + +- Ensure nodes are matching both node key and machine key when connecting. [#2642](https://github.com/juanfont/headscale/pull/2642) + +## 0.26.0 (2025-05-14) + +### BREAKING + +#### Routes + +Route internals have been rewritten, removing the dedicated route table in the +database. This was done to simplify the codebase, which had grown unnecessarily +complex after the routes were split into separate tables. The overhead of having +to go via the database and keeping the state in sync made the code very hard to +reason about and prone to errors. The majority of the route state is only +relevant when headscale is running, and is now only kept in memory. As part of +this, the CLI and API has been simplified to reflect the changes; + +```console +$ headscale nodes list-routes +ID | Hostname | Approved | Available | Serving (Primary) +1 | ts-head-ruqsg8 | | 0.0.0.0/0, ::/0 | +2 | ts-unstable-fq7ob4 | | 0.0.0.0/0, ::/0 | + +$ headscale nodes approve-routes --identifier 1 --routes 0.0.0.0/0,::/0 +Node updated + +$ headscale nodes list-routes +ID | Hostname | Approved | Available | Serving (Primary) +1 | ts-head-ruqsg8 | 0.0.0.0/0, ::/0 | 0.0.0.0/0, ::/0 | 0.0.0.0/0, ::/0 +2 | ts-unstable-fq7ob4 | | 0.0.0.0/0, ::/0 | +``` + +Note that if an exit route is approved (0.0.0.0/0 or ::/0), both IPv4 and IPv6 +will be approved. + +- Route API and CLI has been removed [#2422](https://github.com/juanfont/headscale/pull/2422) +- Routes are now managed via the Node API [#2422](https://github.com/juanfont/headscale/pull/2422) +- Only routes accessible to the node will be sent to the node [#2561](https://github.com/juanfont/headscale/pull/2561) + +#### Policy v2 + +This release introduces a new policy implementation. The new policy is a +complete rewrite, and it introduces some significant quality and consistency +improvements. 
In principle, there are not really any new features, but some long +standing bugs should have been resolved, or be easier to fix in the future. The +new policy code passes all of our tests. + +**Changes** + +- The policy is validated and "resolved" when loading, providing errors for +  invalid rules and conditions. +  - Previously this was done as a mix between load and runtime (when it was +    applied to a node). +  - This means that when you convert the first time, what was previously a +    policy that loaded, but failed at runtime, will now fail at load time. +- Error messages should be more descriptive and informative. +  - There is still work to be done here, but it is already improved with "typing" +    (e.g. only Users can be put in Groups) +- All users in the policy must contain an `@` character. +  - If your user naturally contains an `@`, like an email, this will just work. +  - If it's based on usernames, or other identifiers not containing an `@`, an +    `@` should be appended at the end. For example, if your user is `john`, it +    must be written as `john@` in the policy. + +<details> + +<summary>Migration notes when the policy is stored in the database.</summary> + +This section **only** applies if the policy is stored in the database and +Headscale 0.26 doesn't start due to a policy error +(`failed to load ACL policy`). + +- Start Headscale 0.26 with the environment variable `HEADSCALE_POLICY_V1=1` +  set. You can check that Headscale picked up the environment variable by +  observing this message during startup: `Using policy manager version: 1` +- Dump the policy to a file: `headscale policy get > policy.json` +- Edit `policy.json` and migrate to policy V2. Use the command +  `headscale policy check --file policy.json` to check for policy errors. +- Load the modified policy: `headscale policy set --file policy.json` +- Restart Headscale **without** the environment variable `HEADSCALE_POLICY_V1`. +  Headscale should now print the message `Using policy manager version: 2` and +  start up successfully. These steps are also collected into a single command +  sketch near the end of this section. + +</details> + +**SSH** + +The SSH policy has been reworked to be more consistent with the rest of the +policy. In addition, several inconsistencies between our implementation and +Tailscale's upstream have been closed and this might be a breaking change for +some users. Please refer to the +[upstream documentation](https://tailscale.com/kb/1337/acl-syntax#tailscale-ssh) +for more information on which types are allowed in `src`, `dst` and `users`. + +There is one large inconsistency left: we allow `*` as a destination, as we +currently do not support `autogroup:self`, `autogroup:member` and +`autogroup:tagged`. The support for `*` will be removed when we have support for +the autogroups. + +**Current state** + +The new policy is passing all tests, both integration and unit tests. This does +not mean it is perfect, but it is a good start. Corner cases that are currently +working in v1 and not tested might be broken in v2 (and vice versa).
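+
+For convenience, here are the database-policy migration steps from the notes above collected into one command sequence. This is only a sketch: it assumes the policy is stored in the database, and how the environment variable is set depends on your deployment (systemd unit, container, etc.).
+
+```bash
+# Run 0.26 with the old (v1) policy manager so the existing policy still loads
+HEADSCALE_POLICY_V1=1 headscale serve
+
+# In another shell: dump the policy, edit it to the v2 syntax, and re-check
+headscale policy get > policy.json
+headscale policy check --file policy.json
+
+# Load the migrated policy, then restart headscale without HEADSCALE_POLICY_V1
+headscale policy set --file policy.json
+```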
+ +**We do need help testing this code** + +#### Other breaking changes + +- Disallow `server_url` and `base_domain` to be equal [#2544](https://github.com/juanfont/headscale/pull/2544) +- Return full user in API for pre auth keys instead of string [#2542](https://github.com/juanfont/headscale/pull/2542) +- Pre auth key API/CLI now uses ID over username [#2542](https://github.com/juanfont/headscale/pull/2542) +- A non-empty list of global nameservers needs to be specified via +  `dns.nameservers.global` if the configuration option `dns.override_local_dns` +  is enabled or is not specified in the configuration file. This aligns with +  behaviour of tailscale.com. +  [#2438](https://github.com/juanfont/headscale/pull/2438) + +### Changes + +- Use Go 1.24 [#2427](https://github.com/juanfont/headscale/pull/2427) +- Add `headscale policy check` command to check policy [#2553](https://github.com/juanfont/headscale/pull/2553) +- `oidc.map_legacy_users` and `oidc.strip_email_domain` have been removed [#2411](https://github.com/juanfont/headscale/pull/2411) +- Add more information to `/debug` endpoint [#2420](https://github.com/juanfont/headscale/pull/2420) +  - It is now possible to inspect running goroutines and take profiles +  - View of config, policy, filter, ssh policy per node, connected nodes and +    DERPmap +- OIDC: Fetch UserInfo to get EmailVerified if necessary [#2493](https://github.com/juanfont/headscale/pull/2493) +  - If an OIDC provider doesn't include the `email_verified` claim in its ID +    tokens, Headscale will attempt to get it from the UserInfo endpoint. +- OIDC: Try to populate name, email and username from UserInfo [#2545](https://github.com/juanfont/headscale/pull/2545) +- Improve performance by only querying relevant nodes from the database for node +  updates [#2509](https://github.com/juanfont/headscale/pull/2509) +- node FQDNs in the netmap will now contain a dot (".") at the end. This aligns +  with behaviour of tailscale.com +  [#2503](https://github.com/juanfont/headscale/pull/2503) +- Restore support for "Override local DNS" [#2438](https://github.com/juanfont/headscale/pull/2438) +- Add documentation for routes [#2496](https://github.com/juanfont/headscale/pull/2496) + +## 0.25.1 (2025-02-25) + +### Changes + +- Fix issue where registration errors are sent correctly [#2435](https://github.com/juanfont/headscale/pull/2435) +- Fix issue where routes passed on registration were not saved [#2444](https://github.com/juanfont/headscale/pull/2444) +- Fix issue where registration page was displayed twice [#2445](https://github.com/juanfont/headscale/pull/2445) + +## 0.25.0 (2025-02-11) + +### BREAKING + +- Authentication flow has been rewritten [#2374](https://github.com/juanfont/headscale/pull/2374) This change should be +  transparent to users with the exception of some bugfixes that have been +  discovered and fixed as part of the rewrite. +  - When a node is registered with _a new user_, it will be registered as a new +    node ([#2327](https://github.com/juanfont/headscale/issues/2327) and +    [#1310](https://github.com/juanfont/headscale/issues/1310)). +  - A logged out node logging in with the same user will replace the existing +    node.
+- Remove support for Tailscale clients older than 1.62 (Capability version 87) [#2405](https://github.com/juanfont/headscale/pull/2405) + +### Changes + +- `oidc.map_legacy_users` is now `false` by default [#2350](https://github.com/juanfont/headscale/pull/2350) +- Print Tailscale version instead of capability versions for outdated nodes [#2391](https://github.com/juanfont/headscale/pull/2391) +- Do not allow renaming of users from OIDC [#2393](https://github.com/juanfont/headscale/pull/2393) +- Change minimum hostname length to 2 [#2393](https://github.com/juanfont/headscale/pull/2393) +- Fix migration error caused by nodes having invalid auth keys [#2412](https://github.com/juanfont/headscale/pull/2412) +- Pre auth keys belonging to a user are no longer deleted with the user [#2396](https://github.com/juanfont/headscale/pull/2396) +- Pre auth keys that are used by a node can no longer be deleted [#2396](https://github.com/juanfont/headscale/pull/2396) +- Rehaul HTTP errors, return better status code and errors to users [#2398](https://github.com/juanfont/headscale/pull/2398) +- Print headscale version and commit on server startup [#2415](https://github.com/juanfont/headscale/pull/2415) + +## 0.24.3 (2025-02-07) + +### Changes + +- Fix migration error caused by nodes having invalid auth keys [#2412](https://github.com/juanfont/headscale/pull/2412) +- Pre auth keys belonging to a user are no longer deleted with the user [#2396](https://github.com/juanfont/headscale/pull/2396) +- Pre auth keys that are used by a node can no longer be deleted [#2396](https://github.com/juanfont/headscale/pull/2396) + +## 0.24.2 (2025-01-30) + +### Changes + +- Fix issue where email and username being equal fails to match in Policy [#2388](https://github.com/juanfont/headscale/pull/2388) +- Delete invalid routes before adding a NOT NULL constraint on node_id [#2386](https://github.com/juanfont/headscale/pull/2386) + +## 0.24.1 (2025-01-23) + +### Changes + +- Fix migration issue with user table for PostgreSQL [#2367](https://github.com/juanfont/headscale/pull/2367) +- Relax username validation to allow emails [#2364](https://github.com/juanfont/headscale/pull/2364) +- Remove invalid routes and add stronger constraints for routes to avoid API + panic [#2371](https://github.com/juanfont/headscale/pull/2371) +- Fix panic when `derp.update_frequency` is 0 [#2368](https://github.com/juanfont/headscale/pull/2368) + +## 0.24.0 (2025-01-17) + +### Security fix: OIDC changes in Headscale 0.24.0 + +The following issue _only_ affects Headscale installations which authenticate +with OIDC. + +_Headscale v0.23.0 and earlier_ identified OIDC users by the "username" part of +their email address (when `strip_email_domain: true`, the default) or whole +email address (when `strip_email_domain: false`). + +Depending on how Headscale and your Identity Provider (IdP) were configured, +only using the `email` claim could allow a malicious user with an IdP account to +take over another Headscale user's account, even when +`strip_email_domain: false`. + +This would also cause a user to lose access to their Headscale account if they +changed their email address. + +_Headscale v0.24.0_ now identifies OIDC users by the `iss` and `sub` claims. +[These are guaranteed by the OIDC specification to be stable and unique](https://openid.net/specs/openid-connect-core-1_0.html#ClaimStability), +even if a user changes email address. 
A well-designed IdP will typically set +`sub` to an opaque identifier like a UUID or numeric ID, which has no relation +to the user's name or email address. + +Headscale v0.24.0 and later will also automatically update profile fields with +OIDC data on login. This means that users can change those details in your IdP, +and have it populate to Headscale automatically the next time they log in. +However, this may affect the way you reference users in policies. + +Headscale v0.23.0 and earlier never recorded the `iss` and `sub` fields, so all +legacy (existing) OIDC accounts _need to be migrated_ to be properly secured. + +#### What do I need to do to migrate? + +Headscale v0.24.0 has an automatic migration feature, which is enabled by +default (`map_legacy_users: true`). **This will be disabled by default in a +future version of Headscale – any unmigrated users will get new accounts.** + +The migration will mostly be done automatically, with one exception. If your +OIDC does not provide an `email_verified` claim, Headscale will ignore the +`email`. This means that either the administrator will have to mark the user +emails as verified, or ensure the users verify their emails. Any unverified +emails will be ignored, meaning that the users will get new accounts instead of +being migrated. + +After this exception is ensured, make all users log into Headscale with their +account, and Headscale will automatically update the account record. This will +be transparent to the users. + +When all users have logged in, you can disable the automatic migration by +setting `map_legacy_users: false` in your configuration file. + +Please note that `map_legacy_users` will be set to `false` by default in v0.25.0 +and the migration mechanism will be removed in v0.26.0. + +<details> + +<summary>What does automatic migration do?</summary> + +##### What does automatic migration do? + +When automatic migration is enabled (`map_legacy_users: true`), Headscale will +first match an OIDC account to a Headscale account by `iss` and `sub`, and then +fall back to matching OIDC users similarly to how Headscale v0.23.0 did: + +- If `strip_email_domain: true` (the default): the Headscale username matches + the "username" part of their email address. +- If `strip_email_domain: false`: the Headscale username matches the _whole_ + email address. + +On migration, Headscale will change the account's username to their +`preferred_username`. **This could break any ACLs or policies which are +configured to match by username.** + +Like with Headscale v0.23.0 and earlier, this migration only works for users who +haven't changed their email address since their last Headscale login. + +A _successful_ automated migration should otherwise be transparent to users. + +Once a Headscale account has been migrated, it will be _unavailable_ to be +matched by the legacy process. An OIDC login with a matching username, but +_non-matching_ `iss` and `sub` will instead get a _new_ Headscale account. + +Because of the way OIDC works, Headscale's automated migration process can +_only_ work when a user tries to log in after the update. + +Legacy account migration should have no effect on new installations where all +users have a recorded `sub` and `iss`. + +</details> + +<details> + +<summary>What happens when automatic migration is disabled?</summary> + +##### What happens when automatic migration is disabled? 
+ +When automatic migration is disabled (`map_legacy_users: false`), Headscale will +only try to match an OIDC account to a Headscale account by `iss` and `sub`. + +If there is no match, it will get a _new_ Headscale account – even if there was +a legacy account which _could_ have matched and migrated. + +We recommend new Headscale users explicitly disable automatic migration – but it +should otherwise have no effect if every account has a recorded `iss` and `sub`. + +When automatic migration is disabled, the `strip_email_domain` setting will have +no effect. + +</details> + +Special thanks to @micolous for reviewing, proposing and working with us on +these changes. + +#### Other OIDC changes + +Headscale now uses +[the standard OIDC claims](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) +to populate and update user information every time they log in: + +| Headscale profile field | OIDC claim | Notes / examples | +| ----------------------- | -------------------- | --------------------------------------------------------------------------------------------------------- | +| email address | `email` | Only used when `"email_verified": true` | +| display name | `name` | eg: `Sam Smith` | +| username | `preferred_username` | Varies depending on IdP and configuration, eg: `ssmith`, `ssmith@idp.example.com`, `\\example.com\ssmith` | +| profile picture | `picture` | URL to a profile picture or avatar | + +These should show up nicely in the Tailscale client. + +This will also affect the way you +[reference users in policies](https://github.com/juanfont/headscale/pull/2205). + +### BREAKING + +- Remove `dns.use_username_in_magic_dns` configuration option [#2020](https://github.com/juanfont/headscale/pull/2020), + [#2279](https://github.com/juanfont/headscale/pull/2279) + - Having usernames in magic DNS is no longer possible. +- Remove versions older than 1.56 [#2149](https://github.com/juanfont/headscale/pull/2149) + - Clean up old code required by old versions +- User gRPC/API [#2261](https://github.com/juanfont/headscale/pull/2261): + - If you depend on a Headscale Web UI, you should wait with this update until + the UI have been updated to match the new API. + - `GET /api/v1/user/{name}` and `GetUser` have been removed in favour of + `ListUsers` with an ID parameter + - `RenameUser` and `DeleteUser` now require an ID instead of a name. + +### Changes + +- Improved compatibility of built-in DERP server with clients connecting over + WebSocket [#2132](https://github.com/juanfont/headscale/pull/2132) +- Allow nodes to use SSH agent forwarding [#2145](https://github.com/juanfont/headscale/pull/2145) +- Fixed processing of fields in post request in MoveNode rpc [#2179](https://github.com/juanfont/headscale/pull/2179) +- Added conversion of 'Hostname' to 'givenName' in a node with FQDN rules + applied [#2198](https://github.com/juanfont/headscale/pull/2198) +- Fixed updating of hostname and givenName when it is updated in HostInfo [#2199](https://github.com/juanfont/headscale/pull/2199) +- Fixed missing `stable-debug` container tag [#2232](https://github.com/juanfont/headscale/pull/2232) +- Loosened up `server_url` and `base_domain` check. It was overly strict in some + cases. 
[#2248](https://github.com/juanfont/headscale/pull/2248) +- CLI for managing users now accepts `--identifier` in addition to `--name`, + usage of `--identifier` is recommended + [#2261](https://github.com/juanfont/headscale/pull/2261) +- Add `dns.extra_records_path` configuration option [#2262](https://github.com/juanfont/headscale/issues/2262) +- Support client verify for DERP [#2046](https://github.com/juanfont/headscale/pull/2046) +- Add PKCE Verifier for OIDC [#2314](https://github.com/juanfont/headscale/pull/2314) + +## 0.23.0 (2024-09-18) + +This release was intended to be mainly a code reorganisation and refactoring, +significantly improving the maintainability of the codebase. This should allow +us to improve further and make it easier for the maintainers to keep on top of +the project. However, as you all have noticed, it turned out to become a much +larger, much longer release cycle than anticipated. It has ended up to be a +release with a lot of rewrites and changes to the code base and functionality of +Headscale, cleaning up a lot of technical debt and introducing a lot of +improvements. This does come with some breaking changes, **Please remember to always back up your database between versions** #### Here is a short summary of the broad topics of changes: -Code has been organised into modules, reducing use of global variables/objects, isolating concerns and “putting the right things in the logical place”. +Code has been organised into modules, reducing use of global variables/objects, +isolating concerns and “putting the right things in the logical place”. -The new [policy](https://github.com/juanfont/headscale/tree/main/hscontrol/policy) and [mapper](https://github.com/juanfont/headscale/tree/main/hscontrol/mapper) package, containing the ACL/Policy logic and the logic for creating the data served to clients (the network “map”) has been rewritten and improved. This change has allowed us to finish SSH support and add additional tests throughout the code to ensure correctness. +The new +[policy](https://github.com/juanfont/headscale/tree/main/hscontrol/policy) and +[mapper](https://github.com/juanfont/headscale/tree/main/hscontrol/mapper) +package, containing the ACL/Policy logic and the logic for creating the data +served to clients (the network “map”) has been rewritten and improved. This +change has allowed us to finish SSH support and add additional tests throughout +the code to ensure correctness. -The [“poller”, or streaming logic](https://github.com/juanfont/headscale/blob/main/hscontrol/poll.go) has been rewritten and instead of keeping track of the latest updates, checking at a fixed interval, it now uses go channels, implemented in our new [notifier](https://github.com/juanfont/headscale/tree/main/hscontrol/notifier) package and it allows us to send updates to connected clients immediately. This should both improve performance and potential latency before a client picks up an update. +The +[“poller”, or streaming logic](https://github.com/juanfont/headscale/blob/main/hscontrol/poll.go) +has been rewritten and instead of keeping track of the latest updates, checking +at a fixed interval, it now uses go channels, implemented in our new +[notifier](https://github.com/juanfont/headscale/tree/main/hscontrol/notifier) +package and it allows us to send updates to connected clients immediately. This +should both improve performance and potential latency before a client picks up +an update. 
-Headscale now supports sending “delta” updates, thanks to the new mapper and poller logic, allowing us to only inform nodes about new nodes, changed nodes and removed nodes. Previously we sent the entire state of the network every time an update was due. +Headscale now supports sending “delta” updates, thanks to the new mapper and +poller logic, allowing us to only inform nodes about new nodes, changed nodes +and removed nodes. Previously we sent the entire state of the network every time +an update was due. -While we have a pretty good [test harness](https://github.com/search?q=repo%3Ajuanfont%2Fheadscale+path%3A_test.go&type=code) for validating our changes, we have rewritten over [10000 lines of code](https://github.com/juanfont/headscale/compare/b01f1f1867136d9b2d7b1392776eb363b482c525...main) and bugs are expected. We need help testing this release. In addition, while we think the performance should in general be better, there might be regressions in parts of the platform, particularly where we prioritised correctness over speed. +While we have a pretty good +[test harness](https://github.com/search?q=repo%3Ajuanfont%2Fheadscale+path%3A_test.go&type=code) +for validating our changes, the changes came down to +[284 changed files with 32,316 additions and 24,245 deletions](https://github.com/juanfont/headscale/compare/b01f1f1867136d9b2d7b1392776eb363b482c525...ed78ecd) +and bugs are expected. We need help testing this release. In addition, while we +think the performance should in general be better, there might be regressions in +parts of the platform, particularly where we prioritised correctness over speed. -There are also several bugfixes that has been encountered and fixed as part of implementing these changes, particularly -after improving the test harness as part of adopting [#1460](https://github.com/juanfont/headscale/pull/1460). +There are also several bugfixes that has been encountered and fixed as part of +implementing these changes, particularly after improving the test harness as +part of adopting [#1460](https://github.com/juanfont/headscale/pull/1460). ### BREAKING -- Code reorganisation, a lot of code has moved, please review the following PRs accordingly [#1473](https://github.com/juanfont/headscale/pull/1473) +- Code reorganisation, a lot of code has moved, please review the following PRs + accordingly [#1473](https://github.com/juanfont/headscale/pull/1473) +- Change the structure of database configuration, see + [config-example.yaml](./config-example.yaml) for the new structure. + [#1700](https://github.com/juanfont/headscale/pull/1700) + - Old structure has been remove and the configuration _must_ be converted. + - Adds additional configuration for PostgreSQL for setting max open, idle + connection and idle connection lifetime. - API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553) - Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611) - - The latest supported client is 1.32 + - The oldest supported client is 1.42 - Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564) - - If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url. + - If no DERP is configured, the server will fail to start, this can be because + it cannot load the DERPMap from file or url. 
- Embedded DERP server requires a private key [#1611](https://github.com/juanfont/headscale/pull/1611) - - Add a filepath entry to [`derp.server.private_key_path`](https://github.com/juanfont/headscale/blob/b35993981297e18393706b2c963d6db882bba6aa/config-example.yaml#L95) + - Add a filepath entry to + [`derp.server.private_key_path`](https://github.com/juanfont/headscale/blob/b35993981297e18393706b2c963d6db882bba6aa/config-example.yaml#L95) +- Docker images are now built with goreleaser (ko) [#1716](https://github.com/juanfont/headscale/pull/1716) + [#1763](https://github.com/juanfont/headscale/pull/1763) + - Entrypoint of container image has changed from shell to headscale, require + change from `headscale serve` to `serve` + - `/var/lib/headscale` and `/var/run/headscale` is no longer created + automatically, see [container docs](./docs/setup/install/container.md) +- Prefixes are now defined per v4 and v6 range. [#1756](https://github.com/juanfont/headscale/pull/1756) + - `ip_prefixes` option is now `prefixes.v4` and `prefixes.v6` + - `prefixes.allocation` can be set to assign IPs at `sequential` or `random`. + [#1869](https://github.com/juanfont/headscale/pull/1869) +- MagicDNS domains no longer contain usernames []() + - This is in preparation to fix Headscales implementation of tags which + currently does not correctly remove the link between a tagged device and a + user. As tagged devices will not have a user, this will require a change to + the DNS generation, removing the username, see + [#1369](https://github.com/juanfont/headscale/issues/1369) for more + information. + - `use_username_in_magic_dns` can be used to turn this behaviour on again, but + note that this option _will be removed_ when tags are fixed. + - dns.base_domain can no longer be the same as (or part of) server_url. + - This option brings Headscales behaviour in line with Tailscale. +- YAML files are no longer supported for headscale policy. [#1792](https://github.com/juanfont/headscale/pull/1792) + - HuJSON is now the only supported format for policy. +- DNS configuration has been restructured [#2034](https://github.com/juanfont/headscale/pull/2034) + - Please review the new [config-example.yaml](./config-example.yaml) for the + new structure. 
### Changes -Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644) -Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) -SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) -State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) -Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on [#1460](https://github.com/juanfont/headscale/pull/1460) -Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) taken from [#1480](https://github.com/juanfont/headscale/pull/1480) -Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) -Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) -Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) fixes [#1259](https://github.com/juanfont/headscale/issues/1259) +- Use versioned migrations [#1644](https://github.com/juanfont/headscale/pull/1644) +- Make the OIDC callback page better [#1484](https://github.com/juanfont/headscale/pull/1484) +- SSH support [#1487](https://github.com/juanfont/headscale/pull/1487) +- State management has been improved [#1492](https://github.com/juanfont/headscale/pull/1492) +- Use error group handling to ensure tests actually pass [#1535](https://github.com/juanfont/headscale/pull/1535) based on + [#1460](https://github.com/juanfont/headscale/pull/1460) +- Fix hang on SIGTERM [#1492](https://github.com/juanfont/headscale/pull/1492) + taken from [#1480](https://github.com/juanfont/headscale/pull/1480) +- Send logs to stderr by default [#1524](https://github.com/juanfont/headscale/pull/1524) +- Fix [TS-2023-006](https://tailscale.com/security-bulletins/#ts-2023-006) + security UPnP issue [#1563](https://github.com/juanfont/headscale/pull/1563) +- Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) + fixes [#1259](https://github.com/juanfont/headscale/issues/1259) +- Added the possibility to manually create a DERP-map entry which can be + customized, instead of automatically creating it. + [#1565](https://github.com/juanfont/headscale/pull/1565) +- Add support for deleting api keys [#1702](https://github.com/juanfont/headscale/pull/1702) +- Add command to backfill IP addresses for nodes missing IPs from configured + prefixes. [#1869](https://github.com/juanfont/headscale/pull/1869) +- Log available update as warning [#1877](https://github.com/juanfont/headscale/pull/1877) +- Add `autogroup:internet` to Policy [#1917](https://github.com/juanfont/headscale/pull/1917) +- Restore foreign keys and add constraints [#1562](https://github.com/juanfont/headscale/pull/1562) +- Make registration page easier to use on mobile devices +- Make write-ahead-log default on and configurable for SQLite [#1985](https://github.com/juanfont/headscale/pull/1985) +- Add APIs for managing headscale policy. [#1792](https://github.com/juanfont/headscale/pull/1792) +- Fix for registering nodes using preauthkeys when running on a postgres + database in a non-UTC timezone. 
+ [#764](https://github.com/juanfont/headscale/issues/764) +- Make sure integration tests cover postgres for all scenarios +- CLI commands (all except `serve`) only requires minimal configuration, no more + errors or warnings from unset settings + [#2109](https://github.com/juanfont/headscale/pull/2109) +- CLI results are now concistently sent to stdout and errors to stderr [#2109](https://github.com/juanfont/headscale/pull/2109) +- Fix issue where shutting down headscale would hang [#2113](https://github.com/juanfont/headscale/pull/2113) ## 0.22.3 (2023-05-12) @@ -55,12 +822,13 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f ### Changes - Add environment flags to enable pprof (profiling) [#1382](https://github.com/juanfont/headscale/pull/1382) - - Profiles are continously generated in our integration tests. + - Profiles are continuously generated in our integration tests. - Fix systemd service file location in `.deb` packages [#1391](https://github.com/juanfont/headscale/pull/1391) - Improvements on Noise implementation [#1379](https://github.com/juanfont/headscale/pull/1379) -- Replace node filter logic, ensuring nodes with access can see eachother [#1381](https://github.com/juanfont/headscale/pull/1381) +- Replace node filter logic, ensuring nodes with access can see each other [#1381](https://github.com/juanfont/headscale/pull/1381) - Disable (or delete) both exit routes at the same time [#1428](https://github.com/juanfont/headscale/pull/1428) -- Ditch distroless for Docker image, create default socket dir in `/var/run/headscale` [#1450](https://github.com/juanfont/headscale/pull/1450) +- Ditch distroless for Docker image, create default socket dir in + `/var/run/headscale` [#1450](https://github.com/juanfont/headscale/pull/1450) ## 0.22.1 (2023-04-20) @@ -75,8 +843,11 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f - Add `.deb` packages to release process [#1297](https://github.com/juanfont/headscale/pull/1297) - Update and simplify the documentation to use new `.deb` packages [#1349](https://github.com/juanfont/headscale/pull/1349) - Add 32-bit Arm platforms to release process [#1297](https://github.com/juanfont/headscale/pull/1297) -- Fix longstanding bug that would prevent "\*" from working properly in ACLs (issue [#699](https://github.com/juanfont/headscale/issues/699)) [#1279](https://github.com/juanfont/headscale/pull/1279) -- Fix issue where IPv6 could not be used in, or while using ACLs (part of [#809](https://github.com/juanfont/headscale/issues/809)) [#1339](https://github.com/juanfont/headscale/pull/1339) +- Fix longstanding bug that would prevent "\*" from working properly in ACLs + (issue [#699](https://github.com/juanfont/headscale/issues/699)) + [#1279](https://github.com/juanfont/headscale/pull/1279) +- Fix issue where IPv6 could not be used in, or while using ACLs (part of [#809](https://github.com/juanfont/headscale/issues/809)) + [#1339](https://github.com/juanfont/headscale/pull/1339) - Target Go 1.20 and Tailscale 1.38 for Headscale [#1323](https://github.com/juanfont/headscale/pull/1323) ## 0.21.0 (2023-03-20) @@ -96,7 +867,8 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f - Align behaviour of `dns_config.restricted_nameservers` to tailscale [#1162](https://github.com/juanfont/headscale/pull/1162) - Make OpenID Connect authenticated client expiry time configurable [#1191](https://github.com/juanfont/headscale/pull/1191) - defaults to 180 days like Tailscale 
SaaS - - adds option to use the expiry time from the OpenID token for the node (see config-example.yaml) + - adds option to use the expiry time from the OpenID token for the node (see + config-example.yaml) - Set ControlTime in Map info sent to nodes [#1195](https://github.com/juanfont/headscale/pull/1195) - Populate Tags field on Node updates sent [#1195](https://github.com/juanfont/headscale/pull/1195) @@ -106,7 +878,8 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f - Rename Namespace to User [#1144](https://github.com/juanfont/headscale/pull/1144) - **BACKUP your database before upgrading** -- Command line flags previously taking `--namespace` or `-n` will now require `--user` or `-u` +- Command line flags previously taking `--namespace` or `-n` will now require + `--user` or `-u` ## 0.18.0 (2023-01-14) @@ -134,33 +907,45 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f ### BREAKING -- `noise.private_key_path` has been added and is required for the new noise protocol. -- Log level option `log_level` was moved to a distinct `log` config section and renamed to `level` [#768](https://github.com/juanfont/headscale/pull/768) +- `noise.private_key_path` has been added and is required for the new noise + protocol. +- Log level option `log_level` was moved to a distinct `log` config section and + renamed to `level` [#768](https://github.com/juanfont/headscale/pull/768) - Removed Alpine Linux container image [#962](https://github.com/juanfont/headscale/pull/962) ### Important Changes - Added support for Tailscale TS2021 protocol [#738](https://github.com/juanfont/headscale/pull/738) -- Add experimental support for [SSH ACL](https://tailscale.com/kb/1018/acls/#tailscale-ssh) (see docs for limitations) [#847](https://github.com/juanfont/headscale/pull/847) +- Add experimental support for + [SSH ACL](https://tailscale.com/kb/1018/acls/#tailscale-ssh) (see docs for + limitations) [#847](https://github.com/juanfont/headscale/pull/847) - Please note that this support should be considered _partially_ implemented - SSH ACLs status: - - Support `accept` and `check` (SSH can be enabled and used for connecting and authentication) - - Rejecting connections **are not supported**, meaning that if you enable SSH, then assume that _all_ `ssh` connections **will be allowed**. - - If you decied to try this feature, please carefully managed permissions by blocking port `22` with regular ACLs or do _not_ set `--ssh` on your clients. - - We are currently improving our testing of the SSH ACLs, help us get an overview by testing and giving feedback. - - This feature should be considered dangerous and it is disabled by default. Enable by setting `HEADSCALE_EXPERIMENTAL_FEATURE_SSH=1`. + - Support `accept` and `check` (SSH can be enabled and used for connecting + and authentication) + - Rejecting connections **are not supported**, meaning that if you enable + SSH, then assume that _all_ `ssh` connections **will be allowed**. + - If you decided to try this feature, please carefully managed permissions + by blocking port `22` with regular ACLs or do _not_ set `--ssh` on your + clients. + - We are currently improving our testing of the SSH ACLs, help us get an + overview by testing and giving feedback. + - This feature should be considered dangerous and it is disabled by default. + Enable by setting `HEADSCALE_EXPERIMENTAL_FEATURE_SSH=1`. 
### Changes - Add ability to specify config location via env var `HEADSCALE_CONFIG` [#674](https://github.com/juanfont/headscale/issues/674) - Target Go 1.19 for Headscale [#778](https://github.com/juanfont/headscale/pull/778) - Target Tailscale v1.30.0 to build Headscale [#780](https://github.com/juanfont/headscale/pull/780) -- Give a warning when running Headscale with reverse proxy improperly configured for WebSockets [#788](https://github.com/juanfont/headscale/pull/788) +- Give a warning when running Headscale with reverse proxy improperly configured + for WebSockets [#788](https://github.com/juanfont/headscale/pull/788) - Fix subnet routers with Primary Routes [#811](https://github.com/juanfont/headscale/pull/811) - Added support for JSON logs [#653](https://github.com/juanfont/headscale/issues/653) - Sanitise the node key passed to registration url [#823](https://github.com/juanfont/headscale/pull/823) - Add support for generating pre-auth keys with tags [#767](https://github.com/juanfont/headscale/pull/767) -- Add support for evaluating `autoApprovers` ACL entries when a machine is registered [#763](https://github.com/juanfont/headscale/pull/763) +- Add support for evaluating `autoApprovers` ACL entries when a machine is + registered [#763](https://github.com/juanfont/headscale/pull/763) - Add config flag to allow Headscale to start if OIDC provider is down [#829](https://github.com/juanfont/headscale/pull/829) - Fix prefix length comparison bug in AutoApprovers route evaluation [#862](https://github.com/juanfont/headscale/pull/862) - Random node DNS suffix only applied if names collide in namespace. [#766](https://github.com/juanfont/headscale/issues/766) @@ -168,7 +953,8 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f - Add `dns_config.override_local_dns` option [#905](https://github.com/juanfont/headscale/pull/905) - Fix some DNS config issues [#660](https://github.com/juanfont/headscale/issues/660) - Make it possible to disable TS2019 with build flag [#928](https://github.com/juanfont/headscale/pull/928) -- Fix OIDC registration issues [#960](https://github.com/juanfont/headscale/pull/960) and [#971](https://github.com/juanfont/headscale/pull/971) +- Fix OIDC registration issues [#960](https://github.com/juanfont/headscale/pull/960) and + [#971](https://github.com/juanfont/headscale/pull/971) - Add support for specifying NextDNS DNS-over-HTTPS resolver [#940](https://github.com/juanfont/headscale/pull/940) - Make more sslmode available for postgresql connection [#927](https://github.com/juanfont/headscale/pull/927) @@ -196,8 +982,9 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f ### Changes - Updated dependencies (including the library that lacked armhf support) [#722](https://github.com/juanfont/headscale/pull/722) -- Fix missing group expansion in function `excludeCorretlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563) -- Improve registration protocol implementation and switch to NodeKey as main identifier [#725](https://github.com/juanfont/headscale/pull/725) +- Fix missing group expansion in function `excludeCorrectlyTaggedNodes` [#563](https://github.com/juanfont/headscale/issues/563) +- Improve registration protocol implementation and switch to NodeKey as main + identifier [#725](https://github.com/juanfont/headscale/pull/725) - Add ability to connect to PostgreSQL via unix socket [#734](https://github.com/juanfont/headscale/pull/734) ## 0.16.0 (2022-07-25) @@ -206,7 +993,8 @@ 
Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f ### BREAKING -- Old ACL syntax is no longer supported ("users" & "ports" -> "src" & "dst"). Please check [the new syntax](https://tailscale.com/kb/1018/acls/). +- Old ACL syntax is no longer supported ("users" & "ports" -> "src" & "dst"). + Please check [the new syntax](https://tailscale.com/kb/1018/acls/). ### Changes @@ -216,28 +1004,40 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f - Fix send on closed channel crash in polling [#542](https://github.com/juanfont/headscale/pull/542) - Fixed spurious calls to setLastStateChangeToNow from ephemeral nodes [#566](https://github.com/juanfont/headscale/pull/566) - Add command for moving nodes between namespaces [#362](https://github.com/juanfont/headscale/issues/362) -- Added more configuration parameters for OpenID Connect (scopes, free-form paramters, domain and user allowlist) +- Added more configuration parameters for OpenID Connect (scopes, free-form + parameters, domain and user allowlist) - Add command to set tags on a node [#525](https://github.com/juanfont/headscale/issues/525) - Add command to view tags of nodes [#356](https://github.com/juanfont/headscale/issues/356) - Add --all (-a) flag to enable routes command [#360](https://github.com/juanfont/headscale/issues/360) - Fix issue where nodes was not updated across namespaces [#560](https://github.com/juanfont/headscale/pull/560) - Add the ability to rename a nodes name [#560](https://github.com/juanfont/headscale/pull/560) - - Node DNS names are now unique, a random suffix will be added when a node joins - - This change contains database changes, remember to **backup** your database before upgrading + - Node DNS names are now unique, a random suffix will be added when a node + joins + - This change contains database changes, remember to **backup** your database + before upgrading - Add option to enable/disable logtail (Tailscale's logging infrastructure) [#596](https://github.com/juanfont/headscale/pull/596) - This change disables the logs by default -- Use [Prometheus]'s duration parser, supporting days (`d`), weeks (`w`) and years (`y`) [#598](https://github.com/juanfont/headscale/pull/598) +- Use [Prometheus]'s duration parser, supporting days (`d`), weeks (`w`) and + years (`y`) [#598](https://github.com/juanfont/headscale/pull/598) - Add support for reloading ACLs with SIGHUP [#601](https://github.com/juanfont/headscale/pull/601) - Use new ACL syntax [#618](https://github.com/juanfont/headscale/pull/618) -- Add -c option to specify config file from command line [#285](https://github.com/juanfont/headscale/issues/285) [#612](https://github.com/juanfont/headscale/pull/601) -- Add configuration option to allow Tailscale clients to use a random WireGuard port. [kb/1181/firewalls](https://tailscale.com/kb/1181/firewalls) [#624](https://github.com/juanfont/headscale/pull/624) -- Improve obtuse UX regarding missing configuration (`ephemeral_node_inactivity_timeout` not set) [#639](https://github.com/juanfont/headscale/pull/639) +- Add -c option to specify config file from command line [#285](https://github.com/juanfont/headscale/issues/285) + [#612](https://github.com/juanfont/headscale/pull/601) +- Add configuration option to allow Tailscale clients to use a random WireGuard + port. 
[kb/1181/firewalls](https://tailscale.com/kb/1181/firewalls) + [#624](https://github.com/juanfont/headscale/pull/624) +- Improve obtuse UX regarding missing configuration + (`ephemeral_node_inactivity_timeout` not set) + [#639](https://github.com/juanfont/headscale/pull/639) - Fix nodes being shown as 'offline' in `tailscale status` [#648](https://github.com/juanfont/headscale/pull/648) - Improve shutdown behaviour [#651](https://github.com/juanfont/headscale/pull/651) -- Drop Gin as web framework in Headscale [648](https://github.com/juanfont/headscale/pull/648) [677](https://github.com/juanfont/headscale/pull/677) +- Drop Gin as web framework in Headscale + [648](https://github.com/juanfont/headscale/pull/648) + [677](https://github.com/juanfont/headscale/pull/677) - Make tailnet node updates check interval configurable [#675](https://github.com/juanfont/headscale/pull/675) - Fix regression with HTTP API [#684](https://github.com/juanfont/headscale/pull/684) -- nodes ls now print both Hostname and Name(Issue [#647](https://github.com/juanfont/headscale/issues/647) PR [#687](https://github.com/juanfont/headscale/pull/687)) +- nodes ls now print both Hostname and Name(Issue [#647](https://github.com/juanfont/headscale/issues/647) PR + [#687](https://github.com/juanfont/headscale/pull/687)) ## 0.15.0 (2022-03-20) @@ -245,9 +1045,11 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f ### BREAKING -- Boundaries between Namespaces has been removed and all nodes can communicate by default [#357](https://github.com/juanfont/headscale/pull/357) - - To limit access between nodes, use [ACLs](./docs/acls.md). -- `/metrics` is now a configurable host:port endpoint: [#344](https://github.com/juanfont/headscale/pull/344). You must update your `config.yaml` file to include: +- Boundaries between Namespaces has been removed and all nodes can communicate + by default [#357](https://github.com/juanfont/headscale/pull/357) + - To limit access between nodes, use [ACLs](./docs/ref/acls.md). +- `/metrics` is now a configurable host:port endpoint: [#344](https://github.com/juanfont/headscale/pull/344). 
You must update your + `config.yaml` file to include: ```yaml metrics_listen_addr: 127.0.0.1:9090 ``` @@ -257,41 +1059,51 @@ Turn off gRPC logging [#1640](https://github.com/juanfont/headscale/pull/1640) f - Add support for writing ACL files with YAML [#359](https://github.com/juanfont/headscale/pull/359) - Users can now use emails in ACL's groups [#372](https://github.com/juanfont/headscale/issues/372) - Add shorthand aliases for commands and subcommands [#376](https://github.com/juanfont/headscale/pull/376) -- Add `/windows` endpoint for Windows configuration instructions + registry file download [#392](https://github.com/juanfont/headscale/pull/392) +- Add `/windows` endpoint for Windows configuration instructions + registry file + download [#392](https://github.com/juanfont/headscale/pull/392) - Added embedded DERP (and STUN) server into Headscale [#388](https://github.com/juanfont/headscale/pull/388) ### Changes -- Fix a bug were the same IP could be assigned to multiple hosts if joined in quick succession [#346](https://github.com/juanfont/headscale/pull/346) +- Fix a bug were the same IP could be assigned to multiple hosts if joined in + quick succession [#346](https://github.com/juanfont/headscale/pull/346) - Simplify the code behind registration of machines [#366](https://github.com/juanfont/headscale/pull/366) - - Nodes are now only written to database if they are registrated successfully -- Fix a limitation in the ACLs that prevented users to write rules with `*` as source [#374](https://github.com/juanfont/headscale/issues/374) -- Reduce the overhead of marshal/unmarshal for Hostinfo, routes and endpoints by using specific types in Machine [#371](https://github.com/juanfont/headscale/pull/371) -- Apply normalization function to FQDN on hostnames when hosts registers and retrieve informations [#363](https://github.com/juanfont/headscale/issues/363) + - Nodes are now only written to database if they are registered successfully +- Fix a limitation in the ACLs that prevented users to write rules with `*` as + source [#374](https://github.com/juanfont/headscale/issues/374) +- Reduce the overhead of marshal/unmarshal for Hostinfo, routes and endpoints by + using specific types in Machine + [#371](https://github.com/juanfont/headscale/pull/371) +- Apply normalization function to FQDN on hostnames when hosts registers and + retrieve information [#363](https://github.com/juanfont/headscale/issues/363) - Fix a bug that prevented the use of `tailscale logout` with OIDC [#508](https://github.com/juanfont/headscale/issues/508) -- Added Tailscale repo HEAD and unstable releases channel to the integration tests targets [#513](https://github.com/juanfont/headscale/pull/513) +- Added Tailscale repo HEAD and unstable releases channel to the integration + tests targets [#513](https://github.com/juanfont/headscale/pull/513) ## 0.14.0 (2022-02-24) -**UPCOMING ### BREAKING -From the **next\*\* version (`0.15.0`), all machines will be able to communicate regardless of -if they are in the same namespace. This means that the behaviour currently limited to ACLs -will become default. From version `0.15.0`, all limitation of communications must be done -with ACLs. +**UPCOMING ### BREAKING From the **next\*\* version (`0.15.0`), all machines +will be able to communicate regardless of if they are in the same namespace. +This means that the behaviour currently limited to ACLs will become default. +From version `0.15.0`, all limitation of communications must be done with ACLs. 
-This is a part of aligning `headscale`'s behaviour with Tailscale's upstream behaviour. +This is a part of aligning `headscale`'s behaviour with Tailscale's upstream +behaviour. ### BREAKING -- ACLs have been rewritten to align with the bevaviour Tailscale Control Panel provides. **NOTE:** This is only active if you use ACLs +- ACLs have been rewritten to align with the bevaviour Tailscale Control Panel + provides. **NOTE:** This is only active if you use ACLs - Namespaces are now treated as Users - All machines can communicate with all machines by default - - Tags should now work correctly and adding a host to Headscale should now reload the rules. - - The documentation have a [fictional example](docs/acls.md) that should cover some use cases of the ACLs features + - Tags should now work correctly and adding a host to Headscale should now + reload the rules. + - The documentation have a [fictional example](./docs/ref/acls.md) that should + cover some use cases of the ACLs features ### Features -- Add support for configurable mTLS [docs](docs/tls.md#configuring-mutual-tls-authentication-mtls) [#297](https://github.com/juanfont/headscale/pull/297) +- Add support for configurable mTLS [docs](./docs/ref/tls.md) [#297](https://github.com/juanfont/headscale/pull/297) ### Changes @@ -303,12 +1115,14 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh - Add IPv6 support to the prefix assigned to namespaces - Add API Key support - - Enable remote control of `headscale` via CLI [docs](docs/remote-cli.md) + - Enable remote control of `headscale` via CLI + [docs](./docs/ref/api.md#grpc) - Enable HTTP API (beta, subject to change) - OpenID Connect users will be mapped per namespaces - Each user will get its own namespace, created if it does not exist - `oidc.domain_map` option has been removed - - `strip_email_domain` option has been added (see [config-example.yaml](./config-example.yaml)) + - `strip_email_domain` option has been added (see + [config-example.yaml](./config-example.yaml)) ### Changes @@ -324,7 +1138,9 @@ This is a part of aligning `headscale`'s behaviour with Tailscale's upstream beh - Make gRPC Unix Socket permissions configurable [#292](https://github.com/juanfont/headscale/pull/292) - Trim whitespace before reading Private Key from file [#289](https://github.com/juanfont/headscale/pull/289) - Add new command to generate a private key for `headscale` [#290](https://github.com/juanfont/headscale/pull/290) -- Fixed issue where hosts deleted from control server may be written back to the database, as long as they are connected to the control server [#278](https://github.com/juanfont/headscale/pull/278) +- Fixed issue where hosts deleted from control server may be written back to the + database, as long as they are connected to the control server + [#278](https://github.com/juanfont/headscale/pull/278) ## 0.12.3 (2022-01-13) @@ -345,12 +1161,14 @@ Happy New Year! ## 0.12.1 (2021-12-24) -(We are skipping 0.12.0 to correct a mishap done weeks ago with the version tagging) +(We are skipping 0.12.0 to correct a mishap done weeks ago with the version +tagging) ### BREAKING - Upgrade to Tailscale 1.18 [#229](https://github.com/juanfont/headscale/pull/229) - - This change requires a new format for private key, private keys are now generated automatically: + - This change requires a new format for private key, private keys are now + generated automatically: 1. Delete your current key 2. Restart `headscale`, a new key will be generated. 3. 
Restart all Tailscale clients to fetch the new key @@ -363,8 +1181,10 @@ Happy New Year! ### Features - Add gRPC and HTTP API (HTTP API is currently disabled) [#204](https://github.com/juanfont/headscale/pull/204) -- Use gRPC between the CLI and the server [#206](https://github.com/juanfont/headscale/pull/206), [#212](https://github.com/juanfont/headscale/pull/212) -- Beta OpenID Connect support [#126](https://github.com/juanfont/headscale/pull/126), [#227](https://github.com/juanfont/headscale/pull/227) +- Use gRPC between the CLI and the server [#206](https://github.com/juanfont/headscale/pull/206), + [#212](https://github.com/juanfont/headscale/pull/212) +- Beta OpenID Connect support [#126](https://github.com/juanfont/headscale/pull/126), + [#227](https://github.com/juanfont/headscale/pull/227) ## 0.11.0 (2021-10-25) diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 00000000..43c994c2 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1 @@ +@AGENTS.md diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md index 14844982..722a543e 100644 --- a/CODE_OF_CONDUCT.md +++ b/CODE_OF_CONDUCT.md @@ -62,7 +62,7 @@ event. Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement -at our Discord channel. All complaints +on our [Discord server](https://discord.gg/c84AZQhmpx). All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 00000000..4c3ca130 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,34 @@ +# Contributing + +Headscale is "Open Source, acknowledged contribution", which means that any contribution has to be discussed with the maintainers before being added to the project. +This model has been chosen to reduce the risk of burnout by limiting the maintenance overhead of reviewing and validating third-party code. + +## Why do we have this model? + +Headscale has a small maintainer team that tries to balance working on the project, fixing bugs and reviewing contributions. + +When we work on issues ourselves, we develop first-hand knowledge of the code, which makes it possible for us to maintain and own the code as the project develops. + +Code contributions are seen as a positive thing. People enjoy and engage with our project, but it also comes with some challenges; we have to understand the code, we have to understand the feature, we might have to become familiar with external libraries or services, and we have to think about security implications. All those steps are required during the reviewing process. After the code has been merged, the feature has to be maintained. Any changes reliant on external services must be updated and expanded accordingly. + +The review and day-one maintenance place a significant burden on the maintainers. We often hope that the contributor will help out, but we have found that, most of the time, they disappear after their new feature has been added. + +This means that when someone contributes, we are mostly happy about it, but we do have to run it through a series of checks to establish whether we can actually maintain the feature. + +## What do we require? + +A general description is provided here and an explicit list is provided in our pull request template. + +All new features have to start out with a design document, which should be discussed on the issue tracker (not Discord).
It should include a use case for the feature, how it can be implemented, who will implement it, and a plan for maintaining it. + +All features have to be end-to-end tested (integration tests) and have good unit test coverage to ensure that they work as expected. This will also ensure that the feature continues to work as expected over time. If a change cannot be tested, a strong case for why this is not possible needs to be presented. + +The contributor should help to maintain the feature over time. If the feature is not maintained properly, the maintainers reserve the right to remove features they deem unmaintainable. This should help to improve the quality of the software and keep it in a maintainable state. + +## Bug fixes + +Headscale is open to code contributions for bug fixes without discussion. + +## Documentation + +If you find mistakes in the documentation, please submit a fix. diff --git a/Dockerfile b/Dockerfile deleted file mode 100644 index 367afe94..00000000 --- a/Dockerfile +++ /dev/null @@ -1,30 +0,0 @@ -# Builder image -FROM docker.io/golang:1.21-bookworm AS build -ARG VERSION=dev -ENV GOPATH /go -WORKDIR /go/src/headscale - -COPY go.mod go.sum /go/src/headscale/ -RUN go mod download - -COPY . . - -RUN CGO_ENABLED=0 GOOS=linux go install -ldflags="-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$VERSION" -a ./cmd/headscale -RUN strip /go/bin/headscale -RUN test -e /go/bin/headscale - -# Production image -FROM docker.io/debian:bookworm-slim - -RUN apt-get update \ - && apt-get install -y ca-certificates \ - && rm -rf /var/lib/apt/lists/* \ - && apt-get clean - -COPY --from=build /go/bin/headscale /bin/headscale -ENV TZ UTC - -RUN mkdir -p /var/run/headscale - -EXPOSE 8080/tcp -CMD ["headscale"] diff --git a/Dockerfile.debug b/Dockerfile.debug deleted file mode 100644 index 8f49d2bc..00000000 --- a/Dockerfile.debug +++ /dev/null @@ -1,30 +0,0 @@ -# Builder image -FROM docker.io/golang:1.21-bookworm AS build -ARG VERSION=dev -ENV GOPATH /go -WORKDIR /go/src/headscale - -COPY go.mod go.sum /go/src/headscale/ -RUN go mod download - -COPY . .
- -RUN CGO_ENABLED=0 GOOS=linux go install -ldflags="-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$VERSION" -a ./cmd/headscale -RUN test -e /go/bin/headscale - -# Debug image -FROM docker.io/golang:1.21-bookworm - -COPY --from=build /go/bin/headscale /bin/headscale -ENV TZ UTC - -RUN apt-get update \ - && apt-get install --no-install-recommends --yes less jq \ - && rm -rf /var/lib/apt/lists/* \ - && apt-get clean -RUN mkdir -p /var/run/headscale - -# Need to reset the entrypoint or everything will run as a busybox script -ENTRYPOINT [] -EXPOSE 8080/tcp -CMD ["headscale"] diff --git a/Dockerfile.derper b/Dockerfile.derper new file mode 100644 index 00000000..395d9586 --- /dev/null +++ b/Dockerfile.derper @@ -0,0 +1,19 @@ +# For testing purposes only + +FROM golang:alpine AS build-env + +WORKDIR /go/src + +RUN apk add --no-cache git +ARG VERSION_BRANCH=main +RUN git clone https://github.com/tailscale/tailscale.git --branch=$VERSION_BRANCH --depth=1 +WORKDIR /go/src/tailscale + +ARG TARGETARCH +RUN GOARCH=$TARGETARCH go install -v ./cmd/derper + +FROM alpine:3.22 +RUN apk add --no-cache ca-certificates iptables iproute2 ip6tables curl + +COPY --from=build-env /go/bin/* /usr/local/bin/ +ENTRYPOINT [ "/usr/local/bin/derper" ] diff --git a/Dockerfile.integration b/Dockerfile.integration new file mode 100644 index 00000000..341067e5 --- /dev/null +++ b/Dockerfile.integration @@ -0,0 +1,44 @@ +# This Dockerfile and the images produced are for testing headscale, +# and are in no way endorsed by Headscale's maintainers as an +# official nor supported release or distribution. + +FROM docker.io/golang:1.25-trixie AS builder +ARG VERSION=dev +ENV GOPATH /go +WORKDIR /go/src/headscale + +# Install delve debugger first - rarely changes, good cache candidate +RUN go install github.com/go-delve/delve/cmd/dlv@latest + +# Download dependencies - only invalidated when go.mod/go.sum change +COPY go.mod go.sum /go/src/headscale/ +RUN go mod download + +# Copy source and build - invalidated on any source change +COPY . . 
+ +# Build debug binary with debug symbols for delve +RUN CGO_ENABLED=0 GOOS=linux go build -gcflags="all=-N -l" -o /go/bin/headscale ./cmd/headscale + +# Runtime stage +FROM debian:trixie-slim + +RUN apt-get --update install --no-install-recommends --yes \ + bash ca-certificates curl dnsutils findutils iproute2 jq less procps python3 sqlite3 \ + && apt-get dist-clean + +RUN mkdir -p /var/run/headscale + +# Copy binaries from builder +COPY --from=builder /go/bin/headscale /usr/local/bin/headscale +COPY --from=builder /go/bin/dlv /usr/local/bin/dlv + +# Copy source code for delve source-level debugging +COPY --from=builder /go/src/headscale /go/src/headscale + +WORKDIR /go/src/headscale + +# Need to reset the entrypoint or everything will run as a busybox script +ENTRYPOINT [] +EXPOSE 8080/tcp 40000/tcp +CMD ["dlv", "--listen=0.0.0.0:40000", "--headless=true", "--api-version=2", "--accept-multiclient", "exec", "/usr/local/bin/headscale", "--"] diff --git a/Dockerfile.integration-ci b/Dockerfile.integration-ci new file mode 100644 index 00000000..e55ab7b9 --- /dev/null +++ b/Dockerfile.integration-ci @@ -0,0 +1,17 @@ +# Minimal CI image - expects pre-built headscale binary in build context +# For local development with delve debugging, use Dockerfile.integration instead + +FROM debian:trixie-slim + +RUN apt-get --update install --no-install-recommends --yes \ + bash ca-certificates curl dnsutils findutils iproute2 jq less procps python3 sqlite3 \ + && apt-get dist-clean + +RUN mkdir -p /var/run/headscale + +# Copy pre-built headscale binary from build context +COPY headscale /usr/local/bin/headscale + +ENTRYPOINT [] +EXPOSE 8080/tcp +CMD ["/usr/local/bin/headscale"] diff --git a/Dockerfile.tailscale-HEAD b/Dockerfile.tailscale-HEAD index 2a3aac76..96edf72c 100644 --- a/Dockerfile.tailscale-HEAD +++ b/Dockerfile.tailscale-HEAD @@ -1,17 +1,47 @@ -FROM golang:latest +# Copyright (c) Tailscale Inc & AUTHORS +# SPDX-License-Identifier: BSD-3-Clause -RUN apt-get update \ - && apt-get install -y dnsutils git iptables ssh ca-certificates \ - && rm -rf /var/lib/apt/lists/* +# This Dockerfile is more or less lifted from tailscale/tailscale +# to ensure a similar build process when testing the HEAD of tailscale. -RUN useradd --shell=/bin/bash --create-home ssh-it-user +FROM golang:1.25-alpine AS build-env +WORKDIR /go/src + +RUN apk add --no-cache git + +# Replace `RUN git...` with `COPY` and a local checked out version of Tailscale in `./tailscale` +# to test specific commits of the Tailscale client. This is useful when trying to find out why +# something specific broke between two versions of Tailscale with for example `git bisect`. +# COPY ./tailscale . 
RUN git clone https://github.com/tailscale/tailscale.git -WORKDIR /go/tailscale +WORKDIR /go/src/tailscale -RUN git checkout main \ - && sh build_dist.sh tailscale.com/cmd/tailscale \ - && sh build_dist.sh tailscale.com/cmd/tailscaled \ - && cp tailscale /usr/local/bin/ \ - && cp tailscaled /usr/local/bin/ + +# see build_docker.sh +ARG VERSION_LONG="" +ENV VERSION_LONG=$VERSION_LONG +ARG VERSION_SHORT="" +ENV VERSION_SHORT=$VERSION_SHORT +ARG VERSION_GIT_HASH="" +ENV VERSION_GIT_HASH=$VERSION_GIT_HASH +ARG TARGETARCH + +ARG BUILD_TAGS="" + +RUN GOARCH=$TARGETARCH go install -tags="${BUILD_TAGS}" -ldflags="\ + -X tailscale.com/version.longStamp=$VERSION_LONG \ + -X tailscale.com/version.shortStamp=$VERSION_SHORT \ + -X tailscale.com/version.gitCommitStamp=$VERSION_GIT_HASH" \ + -v ./cmd/tailscale ./cmd/tailscaled ./cmd/containerboot + +FROM alpine:3.22 +# Upstream: ca-certificates ip6tables iptables iproute2 +# Tests: curl python3 (traceroute via BusyBox) +RUN apk add --no-cache ca-certificates curl ip6tables iptables iproute2 python3 + +COPY --from=build-env /go/bin/* /usr/local/bin/ +# For compat with the previous run.sh, although ideally you should be +# using build_docker.sh which sets an entrypoint for the image. +RUN mkdir /tailscale && ln -s /usr/local/bin/containerboot /tailscale/run.sh diff --git a/Makefile b/Makefile index 442690ed..1e08cda9 100644 --- a/Makefile +++ b/Makefile @@ -1,53 +1,128 @@ -# Calculate version -version ?= $(shell git describe --always --tags --dirty) +# Headscale Makefile +# Modern Makefile following best practices -rwildcard=$(foreach d,$(wildcard $1*),$(call rwildcard,$d/,$2) $(filter $(subst *,%,$2),$d)) +# Version calculation +VERSION ?= $(shell git describe --always --tags --dirty) -# Determine if OS supports pie +# Build configuration GOOS ?= $(shell uname | tr '[:upper:]' '[:lower:]') -ifeq ($(filter $(GOOS), openbsd netbsd soloaris plan9), ) - pieflags = -buildmode=pie -else +ifeq ($(filter $(GOOS), openbsd netbsd solaris plan9), ) + PIE_FLAGS = -buildmode=pie endif -# GO_SOURCES = $(wildcard *.go) -# PROTO_SOURCES = $(wildcard **/*.proto) -GO_SOURCES = $(call rwildcard,,*.go) -PROTO_SOURCES = $(call rwildcard,,*.proto) +# Tool availability check with nix warning +define check_tool + @command -v $(1) >/dev/null 2>&1 || { \ + echo "Warning: $(1) not found. Run 'nix develop' to ensure all dependencies are available."; \ + exit 1; \ + } +endef + +# Source file collections using shell find for better performance +GO_SOURCES := $(shell find . -name '*.go' -not -path './gen/*' -not -path './vendor/*') +PROTO_SOURCES := $(shell find . -name '*.proto' -not -path './gen/*' -not -path './vendor/*') +DOC_SOURCES := $(shell find . \( -name '*.md' -o -name '*.yaml' -o -name '*.yml' -o -name '*.ts' -o -name '*.js' -o -name '*.html' -o -name '*.css' -o -name '*.scss' -o -name '*.sass' \) -not -path './gen/*' -not -path './vendor/*' -not -path './node_modules/*') + +# Default target +.PHONY: all +all: lint test build + +# Dependency checking +.PHONY: check-deps +check-deps: + $(call check_tool,go) + $(call check_tool,golangci-lint) + $(call check_tool,gofumpt) + $(call check_tool,prettier) + $(call check_tool,clang-format) + $(call check_tool,buf) + +# Build targets +.PHONY: build +build: check-deps $(GO_SOURCES) go.mod go.sum + @echo "Building headscale..." + go build $(PIE_FLAGS) -ldflags "-X main.version=$(VERSION)" -o headscale ./cmd/headscale + +# Test targets +.PHONY: test +test: check-deps $(GO_SOURCES) go.mod go.sum + @echo "Running Go tests..." 
+ go test -race ./... -build: - nix build +# Formatting targets +.PHONY: fmt +fmt: fmt-go fmt-prettier fmt-proto -dev: lint test build +.PHONY: fmt-go +fmt-go: check-deps $(GO_SOURCES) + @echo "Formatting Go code..." + gofumpt -l -w . + golangci-lint run --fix -test: - gotestsum -- -short -coverprofile=coverage.out ./... +.PHONY: fmt-prettier +fmt-prettier: check-deps $(DOC_SOURCES) + @echo "Formatting documentation and config files..." + prettier --write '**/*.{ts,js,md,yaml,yml,sass,css,scss,html}' -test_integration: - docker run \ - -t --rm \ - -v ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - -v $$PWD:$$PWD -w $$PWD/integration \ - -v /var/run/docker.sock:/var/run/docker.sock \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- -failfast ./... -timeout 120m -parallel 8 +.PHONY: fmt-proto +fmt-proto: check-deps $(PROTO_SOURCES) + @echo "Formatting Protocol Buffer files..." + clang-format -i $(PROTO_SOURCES) -lint: - golangci-lint run --fix --timeout 10m +# Linting targets +.PHONY: lint +lint: lint-go lint-proto -fmt: - prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}' - golines --max-len=88 --base-formatter=gofumpt -w $(GO_SOURCES) - clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i $(PROTO_SOURCES) +.PHONY: lint-go +lint-go: check-deps $(GO_SOURCES) go.mod go.sum + @echo "Linting Go code..." + golangci-lint run --timeout 10m -proto-lint: - cd proto/ && go run github.com/bufbuild/buf/cmd/buf lint +.PHONY: lint-proto +lint-proto: check-deps $(PROTO_SOURCES) + @echo "Linting Protocol Buffer files..." + cd proto/ && buf lint -compress: build - upx --brute headscale +# Code generation +.PHONY: generate +generate: check-deps + @echo "Generating code..." + go generate ./... -generate: - rm -rf gen - buf generate proto +# Clean targets +.PHONY: clean +clean: + rm -rf headscale gen + +# Development workflow +.PHONY: dev +dev: fmt lint test build + +# Help target +.PHONY: help +help: + @echo "Headscale Development Makefile" + @echo "" + @echo "Main targets:" + @echo " all - Run lint, test, and build (default)" + @echo " build - Build headscale binary" + @echo " test - Run Go tests" + @echo " fmt - Format all code (Go, docs, proto)" + @echo " lint - Lint all code (Go, proto)" + @echo " generate - Generate code from Protocol Buffers" + @echo " dev - Full development workflow (fmt + lint + test + build)" + @echo " clean - Clean build artifacts" + @echo "" + @echo "Specific targets:" + @echo " fmt-go - Format Go code only" + @echo " fmt-prettier - Format documentation only" + @echo " fmt-proto - Format Protocol Buffer files only" + @echo " lint-go - Lint Go code only" + @echo " lint-proto - Lint Protocol Buffer files only" + @echo "" + @echo "Dependencies:" + @echo " check-deps - Verify required tools are available" + @echo "" + @echo "Note: If not running in a nix shell, ensure dependencies are available:" + @echo " nix develop" diff --git a/README.md b/README.md index 457e56ff..61eb68c5 100644 --- a/README.md +++ b/README.md @@ -1,14 +1,18 @@ -![headscale logo](./docs/logo/headscale3_header_stacked_left.png) +![headscale logo](./docs/assets/logo/headscale3_header_stacked_left.png) ![ci](https://github.com/juanfont/headscale/actions/workflows/test.yml/badge.svg) An open source, self-hosted implementation of the Tailscale control server. -Join our [Discord](https://discord.gg/c84AZQhmpx) server for a chat. 
+Join our [Discord server](https://discord.gg/c84AZQhmpx) for a chat. **Note:** Always select the same GitHub tag as the released version you use -to ensure you have the correct example configuration and documentation. -The `main` branch might contain unreleased changes. +to ensure you have the correct example configuration. The `main` branch might +contain unreleased changes. The documentation is available for stable and +development versions: + +- [Documentation for the stable version](https://headscale.net/stable/) +- [Documentation for the development version](https://headscale.net/development/) ## What is Tailscale @@ -32,12 +36,12 @@ organisation. ## Design goal -Headscale aims to implement a self-hosted, open source alternative to the Tailscale -control server. -Headscale's goal is to provide self-hosters and hobbyists with an open-source -server they can use for their projects and labs. -It implements a narrow scope, a single Tailnet, suitable for a personal use, or a small -open-source organisation. +Headscale aims to implement a self-hosted, open source alternative to the +[Tailscale](https://tailscale.com/) control server. Headscale's goal is to +provide self-hosters and hobbyists with an open-source server they can use for +their projects and labs. It implements a narrow scope, a _single_ Tailscale +network (tailnet), suitable for a personal use, or a small open-source +organisation. ## Supporting Headscale @@ -46,39 +50,20 @@ buttons available in the repo. ## Features -- Full "base" support of Tailscale's features -- Configurable DNS - - [Split DNS](https://tailscale.com/kb/1054/dns/#using-dns-settings-in-the-admin-console) -- Node registration - - Single-Sign-On (via Open ID Connect) - - Pre authenticated key -- Taildrop (File Sharing) -- [Access control lists](https://tailscale.com/kb/1018/acls/) -- [MagicDNS](https://tailscale.com/kb/1081/magicdns) -- Support for multiple IP ranges in the tailnet -- Dual stack (IPv4 and IPv6) -- Routing advertising (including exit nodes) -- Ephemeral nodes -- Embedded [DERP server](https://tailscale.com/blog/how-tailscale-works/#encrypted-tcp-relays-derp) +Please see ["Features" in the documentation](https://headscale.net/stable/about/features/). ## Client OS support -| OS | Supports headscale | -| ------- | --------------------------------------------------------- | -| Linux | Yes | -| OpenBSD | Yes | -| FreeBSD | Yes | -| macOS | Yes (see `/apple` on your headscale for more information) | -| Windows | Yes [docs](./docs/windows-client.md) | -| Android | Yes [docs](./docs/android-client.md) | -| iOS | Yes [docs](./docs/iOS-client.md) | +Please see ["Client and operating system support" in the documentation](https://headscale.net/stable/about/clients/). ## Running headscale **Please note that we do not support nor encourage the use of reverse proxies and container to run Headscale.** -Please have a look at the [`documentation`](https://headscale.net/). +Please have a look at the [`documentation`](https://headscale.net/stable/). + +For NixOS users, a module is available in [`nix/`](./nix/). ## Talks @@ -87,25 +72,20 @@ Please have a look at the [`documentation`](https://headscale.net/). ## Disclaimer -1. This project is not associated with Tailscale Inc. -2. The purpose of Headscale is maintaining a working, self-hosted Tailscale control panel. +This project is not associated with Tailscale Inc. 
+ +However, one of the active maintainers for Headscale [is employed by Tailscale](https://tailscale.com/blog/opensource) and he is allowed to spend work hours contributing to the project. Contributions from this maintainer are reviewed by other maintainers. + +The maintainers work together on setting the direction for the project. The underlying principle is to serve the community of self-hosters, enthusiasts and hobbyists - while having a sustainable project. ## Contributing -Headscale is "Open Source, acknowledged contribution", this means that any -contribution will have to be discussed with the Maintainers before being submitted. - -This model has been chosen to reduce the risk of burnout by limiting the -maintenance overhead of reviewing and validating third-party code. - -Headscale is open to code contributions for bug fixes without discussion. - -If you find mistakes in the documentation, please submit a fix to the documentation. +Please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file. ### Requirements -To contribute to headscale you would need the lastest version of [Go](https://golang.org) -and [Buf](https://buf.build)(Protobuf generator). +To contribute to headscale you would need the latest version of [Go](https://golang.org) +and [Buf](https://buf.build) (Protobuf generator). We recommend using [Nix](https://nixos.org/) to setup a development environment. This can be done with `nix develop`, which will install the tools and give you a shell. @@ -160,950 +140,35 @@ make test To build the program: -```shell -nix build -``` - -or - ```shell make build ``` +### Development workflow + +We recommend using Nix for dependency management to ensure you have all required tools. If you prefer to manage dependencies yourself, you can use Make directly: + +**With Nix (recommended):** + +```shell +nix develop +make test +make build +``` + +**With your own dependencies:** + +```shell +make test +make build +``` + +The Makefile will warn you if any required tools are missing and suggest running `nix develop`. Run `make help` to see all available targets. 
+ ## Contributors -<table> -<tr> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/kradalby> - <img src=https://avatars.githubusercontent.com/u/98431?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Kristoffer Dalby/> - <br /> - <sub style="font-size:14px"><b>Kristoffer Dalby</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/juanfont> - <img src=https://avatars.githubusercontent.com/u/181059?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Juan Font/> - <br /> - <sub style="font-size:14px"><b>Juan Font</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/restanrm> - <img src=https://avatars.githubusercontent.com/u/4344371?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Adrien Raffin-Caboisse/> - <br /> - <sub style="font-size:14px"><b>Adrien Raffin-Caboisse</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/cure> - <img src=https://avatars.githubusercontent.com/u/149135?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Ward Vandewege/> - <br /> - <sub style="font-size:14px"><b>Ward Vandewege</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/huskyii> - <img src=https://avatars.githubusercontent.com/u/5499746?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Jiang Zhu/> - <br /> - <sub style="font-size:14px"><b>Jiang Zhu</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/tsujamin> - <img src=https://avatars.githubusercontent.com/u/2435619?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Benjamin Roberts/> - <br /> - <sub style="font-size:14px"><b>Benjamin Roberts</b></sub> - </a> - </td> -</tr> -<tr> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/reynico> - <img src=https://avatars.githubusercontent.com/u/715768?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Nico/> - <br /> - <sub style="font-size:14px"><b>Nico</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/evenh> - <img src=https://avatars.githubusercontent.com/u/2701536?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Even Holthe/> - <br /> - <sub style="font-size:14px"><b>Even Holthe</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/e-zk> - <img src=https://avatars.githubusercontent.com/u/58356365?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=e-zk/> - <br /> - <sub 
style="font-size:14px"><b>e-zk</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/ImpostorKeanu> - <img src=https://avatars.githubusercontent.com/u/11574161?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Justin Angel/> - <br /> - <sub style="font-size:14px"><b>Justin Angel</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/ItalyPaleAle> - <img src=https://avatars.githubusercontent.com/u/43508?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Alessandro (Ale) Segala/> - <br /> - <sub style="font-size:14px"><b>Alessandro (Ale) Segala</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/ohdearaugustin> - <img src=https://avatars.githubusercontent.com/u/14001491?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=ohdearaugustin/> - <br /> - <sub style="font-size:14px"><b>ohdearaugustin</b></sub> - </a> - </td> -</tr> -<tr> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/mpldr> - <img src=https://avatars.githubusercontent.com/u/33086936?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Moritz Poldrack/> - <br /> - <sub style="font-size:14px"><b>Moritz Poldrack</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/Orhideous> - <img src=https://avatars.githubusercontent.com/u/2265184?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Andriy Kushnir/> - <br /> - <sub style="font-size:14px"><b>Andriy Kushnir</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/GrigoriyMikhalkin> - <img src=https://avatars.githubusercontent.com/u/3637857?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=GrigoriyMikhalkin/> - <br /> - <sub style="font-size:14px"><b>GrigoriyMikhalkin</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/christian-heusel> - <img src=https://avatars.githubusercontent.com/u/26827864?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Christian Heusel/> - <br /> - <sub style="font-size:14px"><b>Christian Heusel</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/mike-lloyd03> - <img src=https://avatars.githubusercontent.com/u/49411532?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Mike Lloyd/> - <br /> - <sub style="font-size:14px"><b>Mike Lloyd</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/iSchluff> - <img src=https://avatars.githubusercontent.com/u/1429641?v=4 width="100;" 
style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Anton Schubert/> - <br /> - <sub style="font-size:14px"><b>Anton Schubert</b></sub> - </a> - </td> -</tr> -<tr> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/Niek> - <img src=https://avatars.githubusercontent.com/u/213140?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Niek van der Maas/> - <br /> - <sub style="font-size:14px"><b>Niek van der Maas</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/negbie> - <img src=https://avatars.githubusercontent.com/u/20154956?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Eugen Biegler/> - <br /> - <sub style="font-size:14px"><b>Eugen Biegler</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/617a7a> - <img src=https://avatars.githubusercontent.com/u/67651251?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Azz/> - <br /> - <sub style="font-size:14px"><b>Azz</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/qbit> - <img src=https://avatars.githubusercontent.com/u/68368?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Aaron Bieber/> - <br /> - <sub style="font-size:14px"><b>Aaron Bieber</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/kazauwa> - <img src=https://avatars.githubusercontent.com/u/12330159?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Igor Perepilitsyn/> - <br /> - <sub style="font-size:14px"><b>Igor Perepilitsyn</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/Aluxima> - <img src=https://avatars.githubusercontent.com/u/16262531?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Laurent Marchaud/> - <br /> - <sub style="font-size:14px"><b>Laurent Marchaud</b></sub> - </a> - </td> -</tr> -<tr> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/majst01> - <img src=https://avatars.githubusercontent.com/u/410110?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Stefan Majer/> - <br /> - <sub style="font-size:14px"><b>Stefan Majer</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/fdelucchijr> - <img src=https://avatars.githubusercontent.com/u/69133647?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Fernando De Lucchi/> - <br /> - <sub style="font-size:14px"><b>Fernando De Lucchi</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/OrvilleQ> - <img 
src=https://avatars.githubusercontent.com/u/21377465?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Orville Q. Song/> - <br /> - <sub style="font-size:14px"><b>Orville Q. Song</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/hdhoang> - <img src=https://avatars.githubusercontent.com/u/12537?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=hdhoang/> - <br /> - <sub style="font-size:14px"><b>hdhoang</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/bravechamp> - <img src=https://avatars.githubusercontent.com/u/48980452?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=bravechamp/> - <br /> - <sub style="font-size:14px"><b>bravechamp</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/deonthomasgy> - <img src=https://avatars.githubusercontent.com/u/150036?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Deon Thomas/> - <br /> - <sub style="font-size:14px"><b>Deon Thomas</b></sub> - </a> - </td> -</tr> -<tr> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/madjam002> - <img src=https://avatars.githubusercontent.com/u/679137?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Jamie Greeff/> - <br /> - <sub style="font-size:14px"><b>Jamie Greeff</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/jonathanspw> - <img src=https://avatars.githubusercontent.com/u/8390543?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Jonathan Wright/> - <br /> - <sub style="font-size:14px"><b>Jonathan Wright</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/ChibangLW> - <img src=https://avatars.githubusercontent.com/u/22293464?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=ChibangLW/> - <br /> - <sub style="font-size:14px"><b>ChibangLW</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/majabojarska> - <img src=https://avatars.githubusercontent.com/u/33836570?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Maja Bojarska/> - <br /> - <sub style="font-size:14px"><b>Maja Bojarska</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/mevansam> - <img src=https://avatars.githubusercontent.com/u/403630?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Mevan Samaratunga/> - <br /> - <sub style="font-size:14px"><b>Mevan Samaratunga</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a 
href=https://github.com/dragetd> - <img src=https://avatars.githubusercontent.com/u/3639577?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Michael G./> - <br /> - <sub style="font-size:14px"><b>Michael G.</b></sub> - </a> - </td> -</tr> -<tr> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/ptman> - <img src=https://avatars.githubusercontent.com/u/24669?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Paul Tötterman/> - <br /> - <sub style="font-size:14px"><b>Paul Tötterman</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/samson4649> - <img src=https://avatars.githubusercontent.com/u/12725953?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Samuel Lock/> - <br /> - <sub style="font-size:14px"><b>Samuel Lock</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/loprima-l> - <img src=https://avatars.githubusercontent.com/u/69201633?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=loprima-l/> - <br /> - <sub style="font-size:14px"><b>loprima-l</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/unreality> - <img src=https://avatars.githubusercontent.com/u/352522?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=unreality/> - <br /> - <sub style="font-size:14px"><b>unreality</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/vsychov> - <img src=https://avatars.githubusercontent.com/u/2186303?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=MichaelKo/> - <br /> - <sub style="font-size:14px"><b>MichaelKo</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/kevin1sMe> - <img src=https://avatars.githubusercontent.com/u/6886076?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=kevinlin/> - <br /> - <sub style="font-size:14px"><b>kevinlin</b></sub> - </a> - </td> -</tr> -<tr> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/QZAiXH> - <img src=https://avatars.githubusercontent.com/u/23068780?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Snack/> - <br /> - <sub style="font-size:14px"><b>Snack</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a href=https://github.com/artemklevtsov> - <img src=https://avatars.githubusercontent.com/u/603798?v=4 width="100;" style="border-radius:50%;align-items:center;justify-content:center;overflow:hidden;padding-top:10px" alt=Artem Klevtsov/> - <br /> - <sub style="font-size:14px"><b>Artem Klevtsov</b></sub> - </a> - </td> - <td align="center" style="word-wrap: break-word; width: 150.0; height: 150.0"> - <a 
+<a href="https://github.com/juanfont/headscale/graphs/contributors"> + <img src="https://contrib.rocks/image?repo=juanfont/headscale" /> +</a> + +Made with [contrib.rocks](https://contrib.rocks). diff --git a/cmd/gh-action-integration-generator/main.go b/cmd/gh-action-integration-generator/main.go deleted file mode 100644 index d5798a95..00000000 --- a/cmd/gh-action-integration-generator/main.go +++ /dev/null @@ -1,173 +0,0 @@ -package main - -//go:generate go run ./main.go - -import ( - "bytes" - "fmt" - "log" - "os" - "os/exec" - "path" - "path/filepath" - "strings" - "text/template" -) - -var ( - githubWorkflowPath = "../../.github/workflows/" - jobFileNameTemplate = `test-integration-v2-%s.yaml` - jobTemplate = template.Must( - template.New("jobTemplate").
- Parse(`# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go -# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ - -name: Integration Test v2 - {{.Name}} - -on: [pull_request] - -concurrency: - group: {{ "${{ github.workflow }}-$${{ github.head_ref || github.run_id }}" }} - cancel-in-progress: true - -jobs: - {{.Name}}: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@v3 - with: - fetch-depth: 2 - - - uses: DeterminateSystems/nix-installer-action@main - - uses: DeterminateSystems/magic-nix-cache-action@main - - uses: satackey/action-docker-layer-caching@main - continue-on-error: true - - - name: Get changed files - id: changed-files - uses: tj-actions/changed-files@v34 - with: - files: | - *.nix - go.* - **/*.go - integration_test/ - config-example.yaml - - - name: Run {{.Name}} - uses: Wandalen/wretry.action@master - if: steps.changed-files.outputs.any_changed == 'true' - with: - attempt_limit: 5 - command: | - nix develop --command -- docker run \ - --tty --rm \ - --volume ~/.cache/hs-integration-go:/go \ - --name headscale-test-suite \ - --volume $PWD:$PWD -w $PWD/integration \ - --volume /var/run/docker.sock:/var/run/docker.sock \ - --volume $PWD/control_logs:/tmp/control \ - golang:1 \ - go run gotest.tools/gotestsum@latest -- ./... \ - -failfast \ - -timeout 120m \ - -parallel 1 \ - -run "^{{.Name}}$" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: logs - path: "control_logs/*.log" - - - uses: actions/upload-artifact@v3 - if: always() && steps.changed-files.outputs.any_changed == 'true' - with: - name: pprof - path: "control_logs/*.pprof.tar" -`), - ) -) - -const workflowFilePerm = 0o600 - -func removeTests() { - glob := fmt.Sprintf(jobFileNameTemplate, "*") - - files, err := filepath.Glob(filepath.Join(githubWorkflowPath, glob)) - if err != nil { - log.Fatalf("failed to find test files") - } - - for _, file := range files { - err := os.Remove(file) - if err != nil { - log.Printf("failed to remove: %s", err) - } - } -} - -func findTests() []string { - rgBin, err := exec.LookPath("rg") - if err != nil { - log.Fatalf("failed to find rg (ripgrep) binary") - } - - args := []string{ - "--regexp", "func (Test.+)\\(.*", - "../../integration/", - "--replace", "$1", - "--sort", "path", - "--no-line-number", - "--no-filename", - "--no-heading", - } - - log.Printf("executing: %s %s", rgBin, strings.Join(args, " ")) - - ripgrep := exec.Command( - rgBin, - args..., - ) - - result, err := ripgrep.CombinedOutput() - if err != nil { - log.Printf("out: %s", result) - log.Fatalf("failed to run ripgrep: %s", err) - } - - tests := strings.Split(string(result), "\n") - tests = tests[:len(tests)-1] - - return tests -} - -func main() { - type testConfig struct { - Name string - } - - tests := findTests() - - removeTests() - - for _, test := range tests { - log.Printf("generating workflow for %s", test) - - var content bytes.Buffer - - if err := jobTemplate.Execute(&content, testConfig{ - Name: test, - }); err != nil { - log.Fatalf("failed to render template: %s", err) - } - - testPath := path.Join(githubWorkflowPath, fmt.Sprintf(jobFileNameTemplate, test)) - - err := os.WriteFile(testPath, content.Bytes(), workflowFilePerm) - if err != nil { - log.Fatalf("failed to write github job: %s", err) - } - } -} diff --git a/cmd/headscale/cli/api_key.go b/cmd/headscale/cli/api_key.go index 14293aee..d821b290 100644 --- a/cmd/headscale/cli/api_key.go +++ 
b/cmd/headscale/cli/api_key.go @@ -9,7 +9,6 @@ import ( "github.com/juanfont/headscale/hscontrol/util" "github.com/prometheus/common/model" "github.com/pterm/pterm" - "github.com/rs/zerolog/log" "github.com/spf13/cobra" "google.golang.org/protobuf/types/known/timestamppb" ) @@ -29,11 +28,12 @@ func init() { apiKeysCmd.AddCommand(createAPIKeyCmd) expireAPIKeyCmd.Flags().StringP("prefix", "p", "", "ApiKey prefix") - err := expireAPIKeyCmd.MarkFlagRequired("prefix") - if err != nil { - log.Fatal().Err(err).Msg("") - } + expireAPIKeyCmd.Flags().Uint64P("id", "i", 0, "ApiKey ID") apiKeysCmd.AddCommand(expireAPIKeyCmd) + + deleteAPIKeyCmd.Flags().StringP("prefix", "p", "", "ApiKey prefix") + deleteAPIKeyCmd.Flags().Uint64P("id", "i", 0, "ApiKey ID") + apiKeysCmd.AddCommand(deleteAPIKeyCmd) } var apiKeysCmd = &cobra.Command{ @@ -49,7 +49,7 @@ var listAPIKeys = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -62,14 +62,10 @@ var listAPIKeys = &cobra.Command{ fmt.Sprintf("Error getting the list of keys: %s", err), output, ) - - return } if output != "" { SuccessOutput(response.GetApiKeys(), "", output) - - return } tableData := pterm.TableData{ @@ -97,8 +93,6 @@ var listAPIKeys = &cobra.Command{ fmt.Sprintf("Failed to render pterm table: %s", err), output, ) - - return } }, } @@ -114,9 +108,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - log.Trace(). - Msg("Preparing to create ApiKey") - request := &v1.CreateApiKeyRequest{} durationStr, _ := cmd.Flags().GetString("expiration") @@ -128,19 +119,13 @@ If you loose a key, create a new one and revoke (expire) the old one.`, fmt.Sprintf("Could not parse duration: %s\n", err), output, ) - - return } expiration := time.Now().UTC().Add(time.Duration(duration)) - log.Trace(). - Dur("expiration", time.Duration(duration)). 
- Msg("expiration has been set") - request.Expiration = timestamppb.New(expiration) - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -151,8 +136,6 @@ If you loose a key, create a new one and revoke (expire) the old one.`, fmt.Sprintf("Cannot create Api Key: %s\n", err), output, ) - - return } SuccessOutput(response.GetApiKey(), response.GetApiKey(), output) @@ -166,23 +149,33 @@ var expireAPIKeyCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - prefix, err := cmd.Flags().GetString("prefix") - if err != nil { + id, _ := cmd.Flags().GetUint64("id") + prefix, _ := cmd.Flags().GetString("prefix") + + switch { + case id == 0 && prefix == "": ErrorOutput( - err, - fmt.Sprintf("Error getting prefix from CLI flag: %s", err), + errMissingParameter, + "Either --id or --prefix must be provided", + output, + ) + case id != 0 && prefix != "": + ErrorOutput( + errMissingParameter, + "Only one of --id or --prefix can be provided", output, ) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() - request := &v1.ExpireApiKeyRequest{ - Prefix: prefix, + request := &v1.ExpireApiKeyRequest{} + if id != 0 { + request.Id = id + } else { + request.Prefix = prefix } response, err := client.ExpireApiKey(ctx, request) @@ -192,10 +185,57 @@ var expireAPIKeyCmd = &cobra.Command{ fmt.Sprintf("Cannot expire Api Key: %s\n", err), output, ) - - return } SuccessOutput(response, "Key expired", output) }, } + +var deleteAPIKeyCmd = &cobra.Command{ + Use: "delete", + Short: "Delete an ApiKey", + Aliases: []string{"remove", "del"}, + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + + id, _ := cmd.Flags().GetUint64("id") + prefix, _ := cmd.Flags().GetString("prefix") + + switch { + case id == 0 && prefix == "": + ErrorOutput( + errMissingParameter, + "Either --id or --prefix must be provided", + output, + ) + case id != 0 && prefix != "": + ErrorOutput( + errMissingParameter, + "Only one of --id or --prefix can be provided", + output, + ) + } + + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() + + request := &v1.DeleteApiKeyRequest{} + if id != 0 { + request.Id = id + } else { + request.Prefix = prefix + } + + response, err := client.DeleteApiKey(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot delete Api Key: %s\n", err), + output, + ) + } + + SuccessOutput(response, "Key deleted", output) + }, +} diff --git a/cmd/headscale/cli/configtest.go b/cmd/headscale/cli/configtest.go index 72744a7b..d469885b 100644 --- a/cmd/headscale/cli/configtest.go +++ b/cmd/headscale/cli/configtest.go @@ -14,7 +14,7 @@ var configTestCmd = &cobra.Command{ Short: "Test the configuration.", Long: "Run a test of the configuration and exit.", Run: func(cmd *cobra.Command, args []string) { - _, err := getHeadscaleApp() + _, err := newHeadscaleServerWithConfig() if err != nil { log.Fatal().Caller().Err(err).Msg("Error initializing") } diff --git a/cmd/headscale/cli/debug.go b/cmd/headscale/cli/debug.go index 054fc07f..75187ddd 100644 --- a/cmd/headscale/cli/debug.go +++ b/cmd/headscale/cli/debug.go @@ -4,14 +4,10 @@ import ( "fmt" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" 
"github.com/spf13/cobra" "google.golang.org/grpc/status" - "tailscale.com/types/key" -) - -const ( - errPreAuthKeyMalformed = Error("key is malformed. expected 64 hex characters with `nodekey` prefix") ) // Error is used to compare errors as per https://dave.cheney.net/2016/04/07/constant-errors @@ -64,11 +60,9 @@ var createNodeCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -79,31 +73,24 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting node from flag: %s", err), output, ) - - return } - machineKey, err := cmd.Flags().GetString("key") + registrationID, err := cmd.Flags().GetString("key") if err != nil { ErrorOutput( err, fmt.Sprintf("Error getting key from flag: %s", err), output, ) - - return } - var mkey key.MachinePublic - err = mkey.UnmarshalText([]byte(machineKey)) + _, err = types.RegistrationIDFromString(registrationID) if err != nil { ErrorOutput( err, fmt.Sprintf("Failed to parse machine key from flag: %s", err), output, ) - - return } routes, err := cmd.Flags().GetStringSlice("route") @@ -113,12 +100,10 @@ var createNodeCmd = &cobra.Command{ fmt.Sprintf("Error getting routes from flag: %s", err), output, ) - - return } request := &v1.DebugCreateNodeRequest{ - Key: machineKey, + Key: registrationID, Name: name, User: user, Routes: routes, @@ -128,11 +113,9 @@ var createNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Cannot create node: %s", status.Convert(err).Message()), + "Cannot create node: "+status.Convert(err).Message(), output, ) - - return } SuccessOutput(response.GetNode(), "Node created", output) diff --git a/cmd/headscale/cli/health.go b/cmd/headscale/cli/health.go new file mode 100644 index 00000000..864724cc --- /dev/null +++ b/cmd/headscale/cli/health.go @@ -0,0 +1,29 @@ +package cli + +import ( + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/spf13/cobra" +) + +func init() { + rootCmd.AddCommand(healthCmd) +} + +var healthCmd = &cobra.Command{ + Use: "health", + Short: "Check the health of the Headscale server", + Long: "Check the health of the Headscale server. 
This command will return an exit code of 0 if the server is healthy, or 1 if it is not.", + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() + + response, err := client.Health(ctx, &v1.HealthRequest{}) + if err != nil { + ErrorOutput(err, "Error checking health", output) + } + + SuccessOutput(response, "", output) + }, +} diff --git a/cmd/headscale/cli/mockoidc.go b/cmd/headscale/cli/mockoidc.go index 568a2a03..9969f7c6 100644 --- a/cmd/headscale/cli/mockoidc.go +++ b/cmd/headscale/cli/mockoidc.go @@ -1,8 +1,11 @@ package cli import ( + "encoding/json" + "errors" "fmt" "net" + "net/http" "os" "strconv" "time" @@ -64,6 +67,19 @@ func mockOIDC() error { accessTTL = newTTL } + userStr := os.Getenv("MOCKOIDC_USERS") + if userStr == "" { + return errors.New("MOCKOIDC_USERS not defined") + } + + var users []mockoidc.MockUser + err := json.Unmarshal([]byte(userStr), &users) + if err != nil { + return fmt.Errorf("unmarshalling users: %w", err) + } + + log.Info().Interface("users", users).Msg("loading users from JSON") + log.Info().Msgf("Access token TTL: %s", accessTTL) port, err := strconv.Atoi(portStr) @@ -71,7 +87,7 @@ func mockOIDC() error { return err } - mock, err := getMockOIDC(clientID, clientSecret) + mock, err := getMockOIDC(clientID, clientSecret, users) if err != nil { return err } @@ -93,12 +109,18 @@ func mockOIDC() error { return nil } -func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, error) { +func getMockOIDC(clientID string, clientSecret string, users []mockoidc.MockUser) (*mockoidc.MockOIDC, error) { keypair, err := mockoidc.NewKeypair(nil) if err != nil { return nil, err } + userQueue := mockoidc.UserQueue{} + + for _, user := range users { + userQueue.Push(&user) + } + mock := mockoidc.MockOIDC{ ClientID: clientID, ClientSecret: clientSecret, @@ -107,9 +129,19 @@ func getMockOIDC(clientID string, clientSecret string) (*mockoidc.MockOIDC, erro CodeChallengeMethodsSupported: []string{"plain", "S256"}, Keypair: keypair, SessionStore: mockoidc.NewSessionStore(), - UserQueue: &mockoidc.UserQueue{}, + UserQueue: &userQueue, ErrorQueue: &mockoidc.ErrorQueue{}, } + mock.AddMiddleware(func(h http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + log.Info().Msgf("Request: %+v", r) + h.ServeHTTP(w, r) + if r.Response != nil { + log.Info().Msgf("Response: %+v", r.Response) + } + }) + }) + return &mock, nil } diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index ac996245..882460dd 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -8,27 +8,29 @@ import ( "strings" "time" - survey "github.com/AlecAivazis/survey/v2" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" "github.com/pterm/pterm" + "github.com/samber/lo" "github.com/spf13/cobra" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" "tailscale.com/types/key" ) func init() { rootCmd.AddCommand(nodeCmd) listNodesCmd.Flags().StringP("user", "u", "", "Filter by user") - listNodesCmd.Flags().BoolP("tags", "t", false, "Show tags") listNodesCmd.Flags().StringP("namespace", "n", "", "User") listNodesNamespaceFlag := listNodesCmd.Flags().Lookup("namespace") listNodesNamespaceFlag.Deprecated = deprecateNamespaceMessage listNodesNamespaceFlag.Hidden = true - nodeCmd.AddCommand(listNodesCmd) + 
listNodeRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") + nodeCmd.AddCommand(listNodeRoutesCmd) + registerNodeCmd.Flags().StringP("user", "u", "", "User") registerNodeCmd.Flags().StringP("namespace", "n", "", "User") @@ -38,65 +40,48 @@ func init() { err := registerNodeCmd.MarkFlagRequired("user") if err != nil { - log.Fatalf(err.Error()) + log.Fatal(err.Error()) } registerNodeCmd.Flags().StringP("key", "k", "", "Key") err = registerNodeCmd.MarkFlagRequired("key") if err != nil { - log.Fatalf(err.Error()) + log.Fatal(err.Error()) } nodeCmd.AddCommand(registerNodeCmd) expireNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") + expireNodeCmd.Flags().StringP("expiry", "e", "", "Set expire to (RFC3339 format, e.g. 2025-08-27T10:00:00Z), or leave empty to expire immediately.") err = expireNodeCmd.MarkFlagRequired("identifier") if err != nil { - log.Fatalf(err.Error()) + log.Fatal(err.Error()) } nodeCmd.AddCommand(expireNodeCmd) renameNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") err = renameNodeCmd.MarkFlagRequired("identifier") if err != nil { - log.Fatalf(err.Error()) + log.Fatal(err.Error()) } nodeCmd.AddCommand(renameNodeCmd) deleteNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") err = deleteNodeCmd.MarkFlagRequired("identifier") if err != nil { - log.Fatalf(err.Error()) + log.Fatal(err.Error()) } nodeCmd.AddCommand(deleteNodeCmd) - moveNodeCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - - err = moveNodeCmd.MarkFlagRequired("identifier") - if err != nil { - log.Fatalf(err.Error()) - } - - moveNodeCmd.Flags().StringP("user", "u", "", "New user") - - moveNodeCmd.Flags().StringP("namespace", "n", "", "User") - moveNodeNamespaceFlag := moveNodeCmd.Flags().Lookup("namespace") - moveNodeNamespaceFlag.Deprecated = deprecateNamespaceMessage - moveNodeNamespaceFlag.Hidden = true - - err = moveNodeCmd.MarkFlagRequired("user") - if err != nil { - log.Fatalf(err.Error()) - } - nodeCmd.AddCommand(moveNodeCmd) - tagCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - - err = tagCmd.MarkFlagRequired("identifier") - if err != nil { - log.Fatalf(err.Error()) - } - tagCmd.Flags(). - StringSliceP("tags", "t", []string{}, "List of tags to add to the node") + tagCmd.MarkFlagRequired("identifier") + tagCmd.Flags().StringSliceP("tags", "t", []string{}, "List of tags to add to the node") nodeCmd.AddCommand(tagCmd) + + approveRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") + approveRoutesCmd.MarkFlagRequired("identifier") + approveRoutesCmd.Flags().StringSliceP("routes", "r", []string{}, `List of routes that will be approved (comma-separated, e.g. 
"10.0.0.0/8,192.168.0.0/24" or empty string to remove all approved routes)`) + nodeCmd.AddCommand(approveRoutesCmd) + + nodeCmd.AddCommand(backfillNodeIPsCmd) } var nodeCmd = &cobra.Command{ @@ -113,27 +98,23 @@ var registerNodeCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() - machineKey, err := cmd.Flags().GetString("key") + registrationID, err := cmd.Flags().GetString("key") if err != nil { ErrorOutput( err, fmt.Sprintf("Error getting node key from flag: %s", err), output, ) - - return } request := &v1.RegisterNodeRequest{ - Key: machineKey, + Key: registrationID, User: user, } @@ -147,8 +128,6 @@ var registerNodeCmd = &cobra.Command{ ), output, ) - - return } SuccessOutput( @@ -166,17 +145,9 @@ var listNodesCmd = &cobra.Command{ user, err := cmd.Flags().GetString("user") if err != nil { ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return - } - showTags, err := cmd.Flags().GetBool("tags") - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting tags flag: %s", err), output) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -188,24 +159,18 @@ var listNodesCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), + "Cannot get nodes: "+status.Convert(err).Message(), output, ) - - return } if output != "" { SuccessOutput(response.GetNodes(), "", output) - - return } - tableData, err := nodesToPtables(user, showTags, response.GetNodes()) + tableData, err := nodesToPtables(user, response.GetNodes()) if err != nil { ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - - return } err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() @@ -215,9 +180,72 @@ var listNodesCmd = &cobra.Command{ fmt.Sprintf("Failed to render pterm table: %s", err), output, ) + } + }, +} +var listNodeRoutesCmd = &cobra.Command{ + Use: "list-routes", + Short: "List routes available on nodes", + Aliases: []string{"lsr", "routes"}, + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + identifier, err := cmd.Flags().GetUint64("identifier") + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error converting ID to integer: %s", err), + output, + ) + } + + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() + + request := &v1.ListNodesRequest{} + + response, err := client.ListNodes(ctx, request) + if err != nil { + ErrorOutput( + err, + "Cannot get nodes: "+status.Convert(err).Message(), + output, + ) + } + + nodes := response.GetNodes() + if identifier != 0 { + for _, node := range response.GetNodes() { + if node.GetId() == identifier { + nodes = []*v1.Node{node} + break + } + } + } + + nodes = lo.Filter(nodes, func(n *v1.Node, _ int) bool { + return (n.GetSubnetRoutes() != nil && len(n.GetSubnetRoutes()) > 0) || (n.GetApprovedRoutes() != nil && len(n.GetApprovedRoutes()) > 0) || (n.GetAvailableRoutes() != nil && len(n.GetAvailableRoutes()) > 0) + }) + + if output != "" { + SuccessOutput(nodes, "", output) return } + + tableData, err := nodeRoutesToPtables(nodes) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error 
converting to table: %s", err), output) + } + + err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Failed to render pterm table: %s", err), + output, + ) + } }, } @@ -236,16 +264,39 @@ var expireNodeCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) + } + + expiry, err := cmd.Flags().GetString("expiry") + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error converting expiry to string: %s", err), + output, + ) return } + expiryTime := time.Now() + if expiry != "" { + expiryTime, err = time.Parse(time.RFC3339, expiry) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error converting expiry to string: %s", err), + output, + ) - ctx, client, conn, cancel := getHeadscaleCLIClient() + return + } + } + + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() request := &v1.ExpireNodeRequest{ NodeId: identifier, + Expiry: timestamppb.New(expiryTime), } response, err := client.ExpireNode(ctx, request) @@ -258,8 +309,6 @@ var expireNodeCmd = &cobra.Command{ ), output, ) - - return } SuccessOutput(response.GetNode(), "Node expired", output) @@ -279,11 +328,9 @@ var renameNodeCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -306,8 +353,6 @@ var renameNodeCmd = &cobra.Command{ ), output, ) - - return } SuccessOutput(response.GetNode(), "Node renamed", output) @@ -328,11 +373,9 @@ var deleteNodeCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - - return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -344,14 +387,9 @@ var deleteNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error getting node node: %s", - status.Convert(err).Message(), - ), + "Error getting node node: "+status.Convert(err).Message(), output, ) - - return } deleteRequest := &v1.DeleteNodeRequest{ @@ -361,16 +399,10 @@ var deleteNodeCmd = &cobra.Command{ confirm := false force, _ := cmd.Flags().GetBool("force") if !force { - prompt := &survey.Confirm{ - Message: fmt.Sprintf( - "Do you want to remove the node %s?", - getResponse.GetNode().GetName(), - ), - } - err = survey.AskOne(prompt, &confirm) - if err != nil { - return - } + confirm = util.YesNo(fmt.Sprintf( + "Do you want to remove the node %s?", + getResponse.GetNode().GetName(), + )) } if confirm || force { @@ -383,14 +415,9 @@ var deleteNodeCmd = &cobra.Command{ if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Error deleting node: %s", - status.Convert(err).Message(), - ), + "Error deleting node: "+status.Convert(err).Message(), output, ) - - return } SuccessOutput( map[string]string{"Result": "Node deleted"}, @@ -403,83 +430,52 @@ var deleteNodeCmd = &cobra.Command{ }, } -var moveNodeCmd = &cobra.Command{ - Use: "move", - Short: "Move node to another user", - Aliases: []string{"mv"}, +var backfillNodeIPsCmd = &cobra.Command{ + Use: "backfillips", + Short: "Backfill IPs missing from nodes", + Long: ` +Backfill IPs can be used to add/remove IPs from nodes +based on the current configuration of Headscale. 
+ +If there are nodes that does not have IPv4 or IPv6 +even if prefixes for both are configured in the config, +this command can be used to assign IPs of the sort to +all nodes that are missing. + +If you remove IPv4 or IPv6 prefixes from the config, +it can be run to remove the IPs that should no longer +be assigned to nodes.`, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - identifier, err := cmd.Flags().GetUint64("identifier") - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error converting ID to integer: %s", err), - output, - ) + confirm := false - return + force, _ := cmd.Flags().GetBool("force") + if !force { + confirm = util.YesNo("Are you sure that you want to assign/remove IPs to/from nodes?") } - user, err := cmd.Flags().GetString("user") - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error getting user: %s", err), - output, - ) + if confirm || force { + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() - return + changes, err := client.BackfillNodeIPs(ctx, &v1.BackfillNodeIPsRequest{Confirmed: confirm || force}) + if err != nil { + ErrorOutput( + err, + "Error backfilling IPs: "+status.Convert(err).Message(), + output, + ) + } + + SuccessOutput(changes, "Node IPs backfilled successfully", output) } - - ctx, client, conn, cancel := getHeadscaleCLIClient() - defer cancel() - defer conn.Close() - - getRequest := &v1.GetNodeRequest{ - NodeId: identifier, - } - - _, err = client.GetNode(ctx, getRequest) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf( - "Error getting node: %s", - status.Convert(err).Message(), - ), - output, - ) - - return - } - - moveRequest := &v1.MoveNodeRequest{ - NodeId: identifier, - User: user, - } - - moveResponse, err := client.MoveNode(ctx, moveRequest) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf( - "Error moving node: %s", - status.Convert(err).Message(), - ), - output, - ) - - return - } - - SuccessOutput(moveResponse.GetNode(), "Node moved to another user", output) }, } func nodesToPtables( currentUser string, - showTags bool, nodes []*v1.Node, ) (pterm.TableData, error) { tableHeader := []string{ @@ -489,6 +485,7 @@ func nodesToPtables( "MachineKey", "NodeKey", "User", + "Tags", "IP addresses", "Ephemeral", "Last seen", @@ -496,13 +493,6 @@ func nodesToPtables( "Connected", "Expired", } - if showTags { - tableHeader = append(tableHeader, []string{ - "ForcedTags", - "InvalidTags", - "ValidTags", - }...) 
- } tableData := pterm.TableData{tableHeader} for _, node := range nodes { @@ -557,25 +547,17 @@ func nodesToPtables( expired = pterm.LightRed("yes") } - var forcedTags string - for _, tag := range node.GetForcedTags() { - forcedTags += "," + tag + // TODO(kradalby): as part of CLI rework, we should add the posibility to show "unusable" tags as mentioned in + // https://github.com/juanfont/headscale/issues/2981 + var tagsBuilder strings.Builder + + for _, tag := range node.GetTags() { + tagsBuilder.WriteString("\n" + tag) } - forcedTags = strings.TrimLeft(forcedTags, ",") - var invalidTags string - for _, tag := range node.GetInvalidTags() { - if !contains(node.GetForcedTags(), tag) { - invalidTags += "," + pterm.LightRed(tag) - } - } - invalidTags = strings.TrimLeft(invalidTags, ",") - var validTags string - for _, tag := range node.GetValidTags() { - if !contains(node.GetForcedTags(), tag) { - validTags += "," + pterm.LightGreen(tag) - } - } - validTags = strings.TrimLeft(validTags, ",") + + tags := tagsBuilder.String() + + tags = strings.TrimLeft(tags, "\n") var user string if currentUser == "" || (currentUser == node.GetUser().GetName()) { @@ -602,6 +584,7 @@ func nodesToPtables( machineKey.ShortString(), nodeKey.ShortString(), user, + tags, strings.Join([]string{IPV4Address, IPV6Address}, ", "), strconv.FormatBool(ephemeral), lastSeenTime, @@ -609,8 +592,34 @@ func nodesToPtables( online, expired, } - if showTags { - nodeData = append(nodeData, []string{forcedTags, invalidTags, validTags}...) + tableData = append( + tableData, + nodeData, + ) + } + + return tableData, nil +} + +func nodeRoutesToPtables( + nodes []*v1.Node, +) (pterm.TableData, error) { + tableHeader := []string{ + "ID", + "Hostname", + "Approved", + "Available", + "Serving (Primary)", + } + tableData := pterm.TableData{tableHeader} + + for _, node := range nodes { + nodeData := []string{ + strconv.FormatUint(node.GetId(), util.Base10), + node.GetGivenName(), + strings.Join(node.GetApprovedRoutes(), "\n"), + strings.Join(node.GetAvailableRoutes(), "\n"), + strings.Join(node.GetSubnetRoutes(), "\n"), } tableData = append( tableData, @@ -627,7 +636,7 @@ var tagCmd = &cobra.Command{ Aliases: []string{"tags", "t"}, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -639,8 +648,6 @@ var tagCmd = &cobra.Command{ fmt.Sprintf("Error converting ID to integer: %s", err), output, ) - - return } tagsToSet, err := cmd.Flags().GetStringSlice("tags") if err != nil { @@ -649,8 +656,6 @@ var tagCmd = &cobra.Command{ fmt.Sprintf("Error retrieving list of tags to add to node, %v", err), output, ) - - return } // Sending tags to node @@ -665,8 +670,57 @@ var tagCmd = &cobra.Command{ fmt.Sprintf("Error while sending tags to headscale: %s", err), output, ) - - return + } + + if resp != nil { + SuccessOutput( + resp.GetNode(), + "Node updated", + output, + ) + } + }, +} + +var approveRoutesCmd = &cobra.Command{ + Use: "approve-routes", + Short: "Manage the approved routes of a node", + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() + + // retrieve flags from CLI + identifier, err := cmd.Flags().GetUint64("identifier") + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error converting ID to integer: %s", 
err), + output, + ) + } + routes, err := cmd.Flags().GetStringSlice("routes") + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error retrieving list of routes to add to node, %v", err), + output, + ) + } + + // Sending routes to node + request := &v1.SetApprovedRoutesRequest{ + NodeId: identifier, + Routes: routes, + } + resp, err := client.SetApprovedRoutes(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Error while sending routes to headscale: %s", err), + output, + ) } if resp != nil { diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go new file mode 100644 index 00000000..2aaebcfa --- /dev/null +++ b/cmd/headscale/cli/policy.go @@ -0,0 +1,210 @@ +package cli + +import ( + "fmt" + "io" + "os" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + "github.com/spf13/cobra" + "tailscale.com/types/views" +) + +const ( + bypassFlag = "bypass-grpc-and-access-database-directly" +) + +func init() { + rootCmd.AddCommand(policyCmd) + + getPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running") + policyCmd.AddCommand(getPolicy) + + setPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") + if err := setPolicy.MarkFlagRequired("file"); err != nil { + log.Fatal().Err(err).Msg("") + } + setPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running") + policyCmd.AddCommand(setPolicy) + + checkPolicy.Flags().StringP("file", "f", "", "Path to a policy file in HuJSON format") + if err := checkPolicy.MarkFlagRequired("file"); err != nil { + log.Fatal().Err(err).Msg("") + } + policyCmd.AddCommand(checkPolicy) +} + +var policyCmd = &cobra.Command{ + Use: "policy", + Short: "Manage the Headscale ACL Policy", +} + +var getPolicy = &cobra.Command{ + Use: "get", + Short: "Print the current ACL Policy", + Aliases: []string{"show", "view", "fetch"}, + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + var policy string + if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass { + confirm := false + force, _ := cmd.Flags().GetBool("force") + if !force { + confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?") + } + + if !confirm && !force { + ErrorOutput(nil, "Aborting command", output) + return + } + + cfg, err := types.LoadServerConfig() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed loading config: %s", err), output) + } + + d, err := db.NewHeadscaleDatabase( + cfg, + nil, + ) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output) + } + + pol, err := d.GetPolicy() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed loading Policy from database: %s", err), output) + } + + policy = pol.Data + } else { + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() + + request := &v1.GetPolicyRequest{} + + response, err := client.GetPolicy(ctx, request) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed loading ACL Policy: %s", err), output) + } + + policy = response.GetPolicy() + } + + // 
TODO(pallabpain): Maybe print this better? + // This does not pass output as we dont support yaml, json or json-line + // output for this command. It is HuJSON already. + SuccessOutput("", policy, "") + }, +} + +var setPolicy = &cobra.Command{ + Use: "set", + Short: "Updates the ACL Policy", + Long: ` + Updates the existing ACL Policy with the provided policy. The policy must be a valid HuJSON object. + This command only works when the acl.policy_mode is set to "db", and the policy will be stored in the database.`, + Aliases: []string{"put", "update"}, + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + policyPath, _ := cmd.Flags().GetString("file") + + f, err := os.Open(policyPath) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output) + } + defer f.Close() + + policyBytes, err := io.ReadAll(f) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output) + } + + if bypass, _ := cmd.Flags().GetBool(bypassFlag); bypass { + confirm := false + force, _ := cmd.Flags().GetBool("force") + if !force { + confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?") + } + + if !confirm && !force { + ErrorOutput(nil, "Aborting command", output) + return + } + + cfg, err := types.LoadServerConfig() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed loading config: %s", err), output) + } + + d, err := db.NewHeadscaleDatabase( + cfg, + nil, + ) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output) + } + + users, err := d.ListUsers() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to load users for policy validation: %s", err), output) + } + + _, err = policy.NewPolicyManager(policyBytes, users, views.Slice[types.NodeView]{}) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output) + return + } + + _, err = d.SetPolicy(string(policyBytes)) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) + } + } else { + request := &v1.SetPolicyRequest{Policy: string(policyBytes)} + + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() + + if _, err := client.SetPolicy(ctx, request); err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to set ACL Policy: %s", err), output) + } + } + + SuccessOutput(nil, "Policy updated.", "") + }, +} + +var checkPolicy = &cobra.Command{ + Use: "check", + Short: "Check the Policy file for errors", + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + policyPath, _ := cmd.Flags().GetString("file") + + f, err := os.Open(policyPath) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error opening the policy file: %s", err), output) + } + defer f.Close() + + policyBytes, err := io.ReadAll(f) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error reading the policy file: %s", err), output) + } + + _, err = policy.NewPolicyManager(policyBytes, nil, views.Slice[types.NodeView]{}) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error parsing the policy file: %s", err), output) + } + + SuccessOutput(nil, "Policy is valid", "") + }, +} diff --git a/cmd/headscale/cli/preauthkeys.go b/cmd/headscale/cli/preauthkeys.go index c8dd2adc..51133200 100644 --- a/cmd/headscale/cli/preauthkeys.go +++ b/cmd/headscale/cli/preauthkeys.go @@ -20,20 +20,10 @@ const ( func init() { 
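+	// Illustrative usage of the reworked, ID-based preauthkey commands
+	// (a reviewer-facing sketch, not generated CLI help; the example values are made up):
+	//
+	//   headscale preauthkeys create --user 1 --expiration 24h --reusable
+	//   headscale preauthkeys expire --id 42
+	//   headscale preauthkeys delete --id 42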
rootCmd.AddCommand(preauthkeysCmd) - preauthkeysCmd.PersistentFlags().StringP("user", "u", "", "User") - - preauthkeysCmd.PersistentFlags().StringP("namespace", "n", "", "User") - pakNamespaceFlag := preauthkeysCmd.PersistentFlags().Lookup("namespace") - pakNamespaceFlag.Deprecated = deprecateNamespaceMessage - pakNamespaceFlag.Hidden = true - - err := preauthkeysCmd.MarkPersistentFlagRequired("user") - if err != nil { - log.Fatal().Err(err).Msg("") - } preauthkeysCmd.AddCommand(listPreAuthKeys) preauthkeysCmd.AddCommand(createPreAuthKeyCmd) preauthkeysCmd.AddCommand(expirePreAuthKeyCmd) + preauthkeysCmd.AddCommand(deletePreAuthKeyCmd) createPreAuthKeyCmd.PersistentFlags(). Bool("reusable", false, "Make the preauthkey reusable") createPreAuthKeyCmd.PersistentFlags(). @@ -42,6 +32,9 @@ func init() { StringP("expiration", "e", DefaultPreAuthKeyExpiry, "Human-readable expiration of the key (e.g. 30m, 24h)") createPreAuthKeyCmd.Flags(). StringSlice("tags", []string{}, "Tags to automatically assign to node") + createPreAuthKeyCmd.PersistentFlags().Uint64P("user", "u", 0, "User identifier (ID)") + expirePreAuthKeyCmd.PersistentFlags().Uint64P("id", "i", 0, "Authkey ID") + deletePreAuthKeyCmd.PersistentFlags().Uint64P("id", "i", 0, "Authkey ID") } var preauthkeysCmd = &cobra.Command{ @@ -52,27 +45,16 @@ var preauthkeysCmd = &cobra.Command{ var listPreAuthKeys = &cobra.Command{ Use: "list", - Short: "List the preauthkeys for this user", + Short: "List all preauthkeys", Aliases: []string{"ls", "show"}, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - user, err := cmd.Flags().GetString("user") - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return - } - - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() - request := &v1.ListPreAuthKeysRequest{ - User: user, - } - - response, err := client.ListPreAuthKeys(ctx, request) + response, err := client.ListPreAuthKeys(ctx, &v1.ListPreAuthKeysRequest{}) if err != nil { ErrorOutput( err, @@ -85,20 +67,18 @@ var listPreAuthKeys = &cobra.Command{ if output != "" { SuccessOutput(response.GetPreAuthKeys(), "", output) - - return } tableData := pterm.TableData{ { "ID", - "Key", + "Key/Prefix", "Reusable", "Ephemeral", "Used", "Expiration", "Created", - "Tags", + "Owner", }, } for _, key := range response.GetPreAuthKeys() { @@ -107,30 +87,24 @@ var listPreAuthKeys = &cobra.Command{ expiration = ColourTime(key.GetExpiration().AsTime()) } - var reusable string - if key.GetEphemeral() { - reusable = "N/A" + var owner string + if len(key.GetAclTags()) > 0 { + owner = strings.Join(key.GetAclTags(), "\n") + } else if key.GetUser() != nil { + owner = key.GetUser().GetName() } else { - reusable = fmt.Sprintf("%v", key.GetReusable()) + owner = "-" } - aclTags := "" - - for _, tag := range key.GetAclTags() { - aclTags += "," + tag - } - - aclTags = strings.TrimLeft(aclTags, ",") - tableData = append(tableData, []string{ - key.GetId(), + strconv.FormatUint(key.GetId(), 10), key.GetKey(), - reusable, + strconv.FormatBool(key.GetReusable()), strconv.FormatBool(key.GetEphemeral()), strconv.FormatBool(key.GetUsed()), expiration, key.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), - aclTags, + owner, }) } @@ -141,36 +115,22 @@ var listPreAuthKeys = &cobra.Command{ fmt.Sprintf("Failed to render pterm table: %s", err), output, ) - - return } }, } var createPreAuthKeyCmd = &cobra.Command{ Use: 
"create", - Short: "Creates a new preauthkey in the specified user", + Short: "Creates a new preauthkey", Aliases: []string{"c", "new"}, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - user, err := cmd.Flags().GetString("user") - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) - - return - } - + user, _ := cmd.Flags().GetUint64("user") reusable, _ := cmd.Flags().GetBool("reusable") ephemeral, _ := cmd.Flags().GetBool("ephemeral") tags, _ := cmd.Flags().GetStringSlice("tags") - log.Trace(). - Bool("reusable", reusable). - Bool("ephemeral", ephemeral). - Str("user", user). - Msg("Preparing to create preauthkey") - request := &v1.CreatePreAuthKeyRequest{ User: user, Reusable: reusable, @@ -187,8 +147,6 @@ var createPreAuthKeyCmd = &cobra.Command{ fmt.Sprintf("Could not parse duration: %s\n", err), output, ) - - return } expiration := time.Now().UTC().Add(time.Duration(duration)) @@ -199,7 +157,7 @@ var createPreAuthKeyCmd = &cobra.Command{ request.Expiration = timestamppb.New(expiration) - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -210,8 +168,6 @@ var createPreAuthKeyCmd = &cobra.Command{ fmt.Sprintf("Cannot create Pre Auth Key: %s\n", err), output, ) - - return } SuccessOutput(response.GetPreAuthKey(), response.GetPreAuthKey().GetKey(), output) @@ -219,32 +175,29 @@ var createPreAuthKeyCmd = &cobra.Command{ } var expirePreAuthKeyCmd = &cobra.Command{ - Use: "expire KEY", + Use: "expire", Short: "Expire a preauthkey", Aliases: []string{"revoke", "exp", "e"}, - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return errMissingParameter - } - - return nil - }, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - user, err := cmd.Flags().GetString("user") - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error getting user: %s", err), output) + id, _ := cmd.Flags().GetUint64("id") + + if id == 0 { + ErrorOutput( + errMissingParameter, + "Error: missing --id parameter", + output, + ) return } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() request := &v1.ExpirePreAuthKeyRequest{ - User: user, - Key: args[0], + Id: id, } response, err := client.ExpirePreAuthKey(ctx, request) @@ -254,10 +207,47 @@ var expirePreAuthKeyCmd = &cobra.Command{ fmt.Sprintf("Cannot expire Pre Auth Key: %s\n", err), output, ) - - return } SuccessOutput(response, "Key expired", output) }, } + +var deletePreAuthKeyCmd = &cobra.Command{ + Use: "delete", + Short: "Delete a preauthkey", + Aliases: []string{"del", "rm", "d"}, + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + id, _ := cmd.Flags().GetUint64("id") + + if id == 0 { + ErrorOutput( + errMissingParameter, + "Error: missing --id parameter", + output, + ) + + return + } + + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() + + request := &v1.DeletePreAuthKeyRequest{ + Id: id, + } + + response, err := client.DeletePreAuthKey(ctx, request) + if err != nil { + ErrorOutput( + err, + fmt.Sprintf("Cannot delete Pre Auth Key: %s\n", err), + output, + ) + } + + SuccessOutput(response, "Key deleted", output) + }, +} diff --git a/cmd/headscale/cli/root.go b/cmd/headscale/cli/root.go index 40a9b18a..d7cdabb6 100644 --- 
a/cmd/headscale/cli/root.go +++ b/cmd/headscale/cli/root.go @@ -4,11 +4,14 @@ import ( "fmt" "os" "runtime" + "slices" + "strings" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog" "github.com/rs/zerolog/log" "github.com/spf13/cobra" + "github.com/spf13/viper" "github.com/tcnksm/go-latest" ) @@ -24,6 +27,11 @@ func init() { return } + if slices.Contains(os.Args, "policy") && slices.Contains(os.Args, "check") { + zerolog.SetGlobalLevel(zerolog.Disabled) + return + } + cobra.OnInitialize(initConfig) rootCmd.PersistentFlags(). StringVarP(&cfgFile, "config", "c", "", "config file (default is /etc/headscale/config.yaml)") @@ -49,45 +57,79 @@ func initConfig() { } } - cfg, err := types.GetHeadscaleConfig() - if err != nil { - log.Fatal().Caller().Err(err).Msg("Failed to get headscale configuration") - } - machineOutput := HasMachineOutputFlag() - zerolog.SetGlobalLevel(cfg.Log.Level) - // If the user has requested a "node" readable format, // then disable login so the output remains valid. if machineOutput { zerolog.SetGlobalLevel(zerolog.Disabled) } - if cfg.Log.Format == types.JSONLogFormat { + logFormat := viper.GetString("log.format") + if logFormat == types.JSONLogFormat { log.Logger = log.Output(os.Stdout) } - if !cfg.DisableUpdateCheck && !machineOutput { + disableUpdateCheck := viper.GetBool("disable_check_updates") + if !disableUpdateCheck && !machineOutput { + versionInfo := types.GetVersionInfo() if (runtime.GOOS == "linux" || runtime.GOOS == "darwin") && - Version != "dev" { + !versionInfo.Dirty { githubTag := &latest.GithubTag{ - Owner: "juanfont", - Repository: "headscale", + Owner: "juanfont", + Repository: "headscale", + TagFilterFunc: filterPreReleasesIfStable(func() string { return versionInfo.Version }), } - res, err := latest.Check(githubTag, Version) + res, err := latest.Check(githubTag, versionInfo.Version) if err == nil && res.Outdated { //nolint - fmt.Printf( + log.Warn().Msgf( "An updated version of Headscale has been found (%s vs. your current %s). Check it out https://github.com/juanfont/headscale/releases\n", res.Current, - Version, + versionInfo.Version, ) } } } } +var prereleases = []string{"alpha", "beta", "rc", "dev"} + +func isPreReleaseVersion(version string) bool { + for _, unstable := range prereleases { + if strings.Contains(version, unstable) { + return true + } + } + return false +} + +// filterPreReleasesIfStable returns a function that filters out +// pre-release tags if the current version is stable. +// If the current version is a pre-release, it does not filter anything. +// versionFunc is a function that returns the current version string, it is +// a func for testability. +func filterPreReleasesIfStable(versionFunc func() string) func(string) bool { + return func(tag string) bool { + version := versionFunc() + + // If we are on a pre-release version, then we do not filter anything + // as we want to recommend the user the latest pre-release. + if isPreReleaseVersion(version) { + return false + } + + // If we are on a stable release, filter out pre-releases. 
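+		// For example, a user running v0.23.0 should not be pointed at a v0.24.0-rc.1 tag.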
+ for _, ignore := range prereleases { + if strings.Contains(tag, ignore) { + return true + } + } + + return false + } +} + var rootCmd = &cobra.Command{ Use: "headscale", Short: "headscale - a Tailscale control server", diff --git a/cmd/headscale/cli/root_test.go b/cmd/headscale/cli/root_test.go new file mode 100644 index 00000000..8d1b9c01 --- /dev/null +++ b/cmd/headscale/cli/root_test.go @@ -0,0 +1,293 @@ +package cli + +import ( + "testing" +) + +func TestFilterPreReleasesIfStable(t *testing.T) { + tests := []struct { + name string + currentVersion string + tag string + expectedFilter bool + description string + }{ + { + name: "stable version filters alpha tag", + currentVersion: "0.23.0", + tag: "v0.24.0-alpha.1", + expectedFilter: true, + description: "When on stable release, alpha tags should be filtered", + }, + { + name: "stable version filters beta tag", + currentVersion: "0.23.0", + tag: "v0.24.0-beta.2", + expectedFilter: true, + description: "When on stable release, beta tags should be filtered", + }, + { + name: "stable version filters rc tag", + currentVersion: "0.23.0", + tag: "v0.24.0-rc.1", + expectedFilter: true, + description: "When on stable release, rc tags should be filtered", + }, + { + name: "stable version allows stable tag", + currentVersion: "0.23.0", + tag: "v0.24.0", + expectedFilter: false, + description: "When on stable release, stable tags should not be filtered", + }, + { + name: "alpha version allows alpha tag", + currentVersion: "0.23.0-alpha.1", + tag: "v0.24.0-alpha.2", + expectedFilter: false, + description: "When on alpha release, alpha tags should not be filtered", + }, + { + name: "alpha version allows beta tag", + currentVersion: "0.23.0-alpha.1", + tag: "v0.24.0-beta.1", + expectedFilter: false, + description: "When on alpha release, beta tags should not be filtered", + }, + { + name: "alpha version allows rc tag", + currentVersion: "0.23.0-alpha.1", + tag: "v0.24.0-rc.1", + expectedFilter: false, + description: "When on alpha release, rc tags should not be filtered", + }, + { + name: "alpha version allows stable tag", + currentVersion: "0.23.0-alpha.1", + tag: "v0.24.0", + expectedFilter: false, + description: "When on alpha release, stable tags should not be filtered", + }, + { + name: "beta version allows alpha tag", + currentVersion: "0.23.0-beta.1", + tag: "v0.24.0-alpha.1", + expectedFilter: false, + description: "When on beta release, alpha tags should not be filtered", + }, + { + name: "beta version allows beta tag", + currentVersion: "0.23.0-beta.2", + tag: "v0.24.0-beta.3", + expectedFilter: false, + description: "When on beta release, beta tags should not be filtered", + }, + { + name: "beta version allows rc tag", + currentVersion: "0.23.0-beta.1", + tag: "v0.24.0-rc.1", + expectedFilter: false, + description: "When on beta release, rc tags should not be filtered", + }, + { + name: "beta version allows stable tag", + currentVersion: "0.23.0-beta.1", + tag: "v0.24.0", + expectedFilter: false, + description: "When on beta release, stable tags should not be filtered", + }, + { + name: "rc version allows alpha tag", + currentVersion: "0.23.0-rc.1", + tag: "v0.24.0-alpha.1", + expectedFilter: false, + description: "When on rc release, alpha tags should not be filtered", + }, + { + name: "rc version allows beta tag", + currentVersion: "0.23.0-rc.1", + tag: "v0.24.0-beta.1", + expectedFilter: false, + description: "When on rc release, beta tags should not be filtered", + }, + { + name: "rc version allows rc tag", + currentVersion: 
"0.23.0-rc.2", + tag: "v0.24.0-rc.3", + expectedFilter: false, + description: "When on rc release, rc tags should not be filtered", + }, + { + name: "rc version allows stable tag", + currentVersion: "0.23.0-rc.1", + tag: "v0.24.0", + expectedFilter: false, + description: "When on rc release, stable tags should not be filtered", + }, + { + name: "stable version with patch filters alpha", + currentVersion: "0.23.1", + tag: "v0.24.0-alpha.1", + expectedFilter: true, + description: "Stable version with patch number should filter alpha tags", + }, + { + name: "stable version with patch allows stable", + currentVersion: "0.23.1", + tag: "v0.24.0", + expectedFilter: false, + description: "Stable version with patch number should allow stable tags", + }, + { + name: "tag with alpha substring in version number", + currentVersion: "0.23.0", + tag: "v1.0.0-alpha.1", + expectedFilter: true, + description: "Tags with alpha in version string should be filtered on stable", + }, + { + name: "tag with beta substring in version number", + currentVersion: "0.23.0", + tag: "v1.0.0-beta.1", + expectedFilter: true, + description: "Tags with beta in version string should be filtered on stable", + }, + { + name: "tag with rc substring in version number", + currentVersion: "0.23.0", + tag: "v1.0.0-rc.1", + expectedFilter: true, + description: "Tags with rc in version string should be filtered on stable", + }, + { + name: "empty tag on stable version", + currentVersion: "0.23.0", + tag: "", + expectedFilter: false, + description: "Empty tags should not be filtered", + }, + { + name: "dev version allows all tags", + currentVersion: "0.23.0-dev", + tag: "v0.24.0-alpha.1", + expectedFilter: false, + description: "Dev versions should not filter any tags (pre-release allows all)", + }, + { + name: "stable version filters dev tag", + currentVersion: "0.23.0", + tag: "v0.24.0-dev", + expectedFilter: true, + description: "When on stable release, dev tags should be filtered", + }, + { + name: "dev version allows dev tag", + currentVersion: "0.23.0-dev", + tag: "v0.24.0-dev.1", + expectedFilter: false, + description: "When on dev release, dev tags should not be filtered", + }, + { + name: "dev version allows stable tag", + currentVersion: "0.23.0-dev", + tag: "v0.24.0", + expectedFilter: false, + description: "When on dev release, stable tags should not be filtered", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := filterPreReleasesIfStable(func() string { return tt.currentVersion })(tt.tag) + if result != tt.expectedFilter { + t.Errorf("%s: got %v, want %v\nDescription: %s\nCurrent version: %s, Tag: %s", + tt.name, + result, + tt.expectedFilter, + tt.description, + tt.currentVersion, + tt.tag, + ) + } + }) + } +} + +func TestIsPreReleaseVersion(t *testing.T) { + tests := []struct { + name string + version string + expected bool + description string + }{ + { + name: "stable version", + version: "0.23.0", + expected: false, + description: "Stable version should not be pre-release", + }, + { + name: "alpha version", + version: "0.23.0-alpha.1", + expected: true, + description: "Alpha version should be pre-release", + }, + { + name: "beta version", + version: "0.23.0-beta.1", + expected: true, + description: "Beta version should be pre-release", + }, + { + name: "rc version", + version: "0.23.0-rc.1", + expected: true, + description: "RC version should be pre-release", + }, + { + name: "version with alpha substring", + version: "0.23.0-alphabetical", + expected: true, + description: 
"Version containing 'alpha' should be pre-release", + }, + { + name: "version with beta substring", + version: "0.23.0-betamax", + expected: true, + description: "Version containing 'beta' should be pre-release", + }, + { + name: "dev version", + version: "0.23.0-dev", + expected: true, + description: "Dev version should be pre-release", + }, + { + name: "empty version", + version: "", + expected: false, + description: "Empty version should not be pre-release", + }, + { + name: "version with patch number", + version: "0.23.1", + expected: false, + description: "Stable version with patch should not be pre-release", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := isPreReleaseVersion(tt.version) + if result != tt.expected { + t.Errorf("%s: got %v, want %v\nDescription: %s\nVersion: %s", + tt.name, + result, + tt.expected, + tt.description, + tt.version, + ) + } + }) + } +} diff --git a/cmd/headscale/cli/routes.go b/cmd/headscale/cli/routes.go deleted file mode 100644 index 86ef295c..00000000 --- a/cmd/headscale/cli/routes.go +++ /dev/null @@ -1,298 +0,0 @@ -package cli - -import ( - "fmt" - "log" - "net/netip" - "strconv" - - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/pterm/pterm" - "github.com/spf13/cobra" - "google.golang.org/grpc/status" -) - -const ( - Base10 = 10 -) - -func init() { - rootCmd.AddCommand(routesCmd) - listRoutesCmd.Flags().Uint64P("identifier", "i", 0, "Node identifier (ID)") - routesCmd.AddCommand(listRoutesCmd) - - enableRouteCmd.Flags().Uint64P("route", "r", 0, "Route identifier (ID)") - err := enableRouteCmd.MarkFlagRequired("route") - if err != nil { - log.Fatalf(err.Error()) - } - routesCmd.AddCommand(enableRouteCmd) - - disableRouteCmd.Flags().Uint64P("route", "r", 0, "Route identifier (ID)") - err = disableRouteCmd.MarkFlagRequired("route") - if err != nil { - log.Fatalf(err.Error()) - } - routesCmd.AddCommand(disableRouteCmd) - - deleteRouteCmd.Flags().Uint64P("route", "r", 0, "Route identifier (ID)") - err = deleteRouteCmd.MarkFlagRequired("route") - if err != nil { - log.Fatalf(err.Error()) - } - routesCmd.AddCommand(deleteRouteCmd) -} - -var routesCmd = &cobra.Command{ - Use: "routes", - Short: "Manage the routes of Headscale", - Aliases: []string{"r", "route"}, -} - -var listRoutesCmd = &cobra.Command{ - Use: "list", - Short: "List all routes", - Aliases: []string{"ls", "show"}, - Run: func(cmd *cobra.Command, args []string) { - output, _ := cmd.Flags().GetString("output") - - machineID, err := cmd.Flags().GetUint64("identifier") - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error getting machine id from flag: %s", err), - output, - ) - - return - } - - ctx, client, conn, cancel := getHeadscaleCLIClient() - defer cancel() - defer conn.Close() - - var routes []*v1.Route - - if machineID == 0 { - response, err := client.GetRoutes(ctx, &v1.GetRoutesRequest{}) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot get nodes: %s", status.Convert(err).Message()), - output, - ) - - return - } - - if output != "" { - SuccessOutput(response.GetRoutes(), "", output) - - return - } - - routes = response.GetRoutes() - } else { - response, err := client.GetNodeRoutes(ctx, &v1.GetNodeRoutesRequest{ - NodeId: machineID, - }) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot get routes for node %d: %s", machineID, status.Convert(err).Message()), - output, - ) - - return - } - - if output != "" { - SuccessOutput(response.GetRoutes(), 
"", output) - - return - } - - routes = response.GetRoutes() - } - - tableData := routesToPtables(routes) - if err != nil { - ErrorOutput(err, fmt.Sprintf("Error converting to table: %s", err), output) - - return - } - - err = pterm.DefaultTable.WithHasHeader().WithData(tableData).Render() - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Failed to render pterm table: %s", err), - output, - ) - - return - } - }, -} - -var enableRouteCmd = &cobra.Command{ - Use: "enable", - Short: "Set a route as enabled", - Long: `This command will make as enabled a given route.`, - Run: func(cmd *cobra.Command, args []string) { - output, _ := cmd.Flags().GetString("output") - - routeID, err := cmd.Flags().GetUint64("route") - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error getting machine id from flag: %s", err), - output, - ) - - return - } - - ctx, client, conn, cancel := getHeadscaleCLIClient() - defer cancel() - defer conn.Close() - - response, err := client.EnableRoute(ctx, &v1.EnableRouteRequest{ - RouteId: routeID, - }) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot enable route %d: %s", routeID, status.Convert(err).Message()), - output, - ) - - return - } - - if output != "" { - SuccessOutput(response, "", output) - - return - } - }, -} - -var disableRouteCmd = &cobra.Command{ - Use: "disable", - Short: "Set as disabled a given route", - Long: `This command will make as disabled a given route.`, - Run: func(cmd *cobra.Command, args []string) { - output, _ := cmd.Flags().GetString("output") - - routeID, err := cmd.Flags().GetUint64("route") - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error getting machine id from flag: %s", err), - output, - ) - - return - } - - ctx, client, conn, cancel := getHeadscaleCLIClient() - defer cancel() - defer conn.Close() - - response, err := client.DisableRoute(ctx, &v1.DisableRouteRequest{ - RouteId: routeID, - }) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot disable route %d: %s", routeID, status.Convert(err).Message()), - output, - ) - - return - } - - if output != "" { - SuccessOutput(response, "", output) - - return - } - }, -} - -var deleteRouteCmd = &cobra.Command{ - Use: "delete", - Short: "Delete a given route", - Long: `This command will delete a given route.`, - Run: func(cmd *cobra.Command, args []string) { - output, _ := cmd.Flags().GetString("output") - - routeID, err := cmd.Flags().GetUint64("route") - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Error getting machine id from flag: %s", err), - output, - ) - - return - } - - ctx, client, conn, cancel := getHeadscaleCLIClient() - defer cancel() - defer conn.Close() - - response, err := client.DeleteRoute(ctx, &v1.DeleteRouteRequest{ - RouteId: routeID, - }) - if err != nil { - ErrorOutput( - err, - fmt.Sprintf("Cannot delete route %d: %s", routeID, status.Convert(err).Message()), - output, - ) - - return - } - - if output != "" { - SuccessOutput(response, "", output) - - return - } - }, -} - -// routesToPtables converts the list of routes to a nice table. 
-func routesToPtables(routes []*v1.Route) pterm.TableData { - tableData := pterm.TableData{{"ID", "Node", "Prefix", "Advertised", "Enabled", "Primary"}} - - for _, route := range routes { - var isPrimaryStr string - prefix, err := netip.ParsePrefix(route.GetPrefix()) - if err != nil { - log.Printf("Error parsing prefix %s: %s", route.GetPrefix(), err) - - continue - } - if prefix == types.ExitRouteV4 || prefix == types.ExitRouteV6 { - isPrimaryStr = "-" - } else { - isPrimaryStr = strconv.FormatBool(route.GetIsPrimary()) - } - - tableData = append(tableData, - []string{ - strconv.FormatUint(route.GetId(), Base10), - route.GetNode().GetGivenName(), - route.GetPrefix(), - strconv.FormatBool(route.GetAdvertised()), - strconv.FormatBool(route.GetEnabled()), - isPrimaryStr, - }) - } - - return tableData -} diff --git a/cmd/headscale/cli/server.go b/cmd/headscale/cli/serve.go similarity index 50% rename from cmd/headscale/cli/server.go rename to cmd/headscale/cli/serve.go index a1d19600..8f05f851 100644 --- a/cmd/headscale/cli/server.go +++ b/cmd/headscale/cli/serve.go @@ -1,8 +1,13 @@ package cli import ( + "errors" + "fmt" + "net/http" + "github.com/rs/zerolog/log" "github.com/spf13/cobra" + "github.com/tailscale/squibble" ) func init() { @@ -16,14 +21,20 @@ var serveCmd = &cobra.Command{ return nil }, Run: func(cmd *cobra.Command, args []string) { - app, err := getHeadscaleApp() + app, err := newHeadscaleServerWithConfig() if err != nil { + var squibbleErr squibble.ValidationError + if errors.As(err, &squibbleErr) { + fmt.Printf("SQLite schema failed to validate:\n") + fmt.Println(squibbleErr.Diff) + } + log.Fatal().Caller().Err(err).Msg("Error initializing") } err = app.Serve() - if err != nil { - log.Fatal().Caller().Err(err).Msg("Error starting server") + if err != nil && !errors.Is(err, http.ErrServerClosed) { + log.Fatal().Caller().Err(err).Msg("Headscale ran into an error and had to shut down.") } }, } diff --git a/cmd/headscale/cli/users.go b/cmd/headscale/cli/users.go index e6463d6f..9a816c78 100644 --- a/cmd/headscale/cli/users.go +++ b/cmd/headscale/cli/users.go @@ -3,21 +3,54 @@ package cli import ( "errors" "fmt" + "net/url" + "strconv" - survey "github.com/AlecAivazis/survey/v2" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/util" "github.com/pterm/pterm" "github.com/rs/zerolog/log" "github.com/spf13/cobra" "google.golang.org/grpc/status" ) +func usernameAndIDFlag(cmd *cobra.Command) { + cmd.Flags().Int64P("identifier", "i", -1, "User identifier (ID)") + cmd.Flags().StringP("name", "n", "", "Username") +} + +// usernameAndIDFromFlag returns the username and ID from the flags of the command. +// If both are empty, it will exit the program with an error. 
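+// Callers typically pass the returned pair into a v1.ListUsersRequest to
+// resolve the matching user before acting on it.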
+func usernameAndIDFromFlag(cmd *cobra.Command) (uint64, string) { + username, _ := cmd.Flags().GetString("name") + identifier, _ := cmd.Flags().GetInt64("identifier") + if username == "" && identifier < 0 { + err := errors.New("--name or --identifier flag is required") + ErrorOutput( + err, + "Cannot rename user: "+status.Convert(err).Message(), + "", + ) + } + + return uint64(identifier), username +} + func init() { rootCmd.AddCommand(userCmd) userCmd.AddCommand(createUserCmd) + createUserCmd.Flags().StringP("display-name", "d", "", "Display name") + createUserCmd.Flags().StringP("email", "e", "", "Email") + createUserCmd.Flags().StringP("picture-url", "p", "", "Profile picture URL") userCmd.AddCommand(listUsersCmd) + usernameAndIDFlag(listUsersCmd) + listUsersCmd.Flags().StringP("email", "e", "", "Email") userCmd.AddCommand(destroyUserCmd) + usernameAndIDFlag(destroyUserCmd) userCmd.AddCommand(renameUserCmd) + usernameAndIDFlag(renameUserCmd) + renameUserCmd.Flags().StringP("new-name", "r", "", "New username") + renameNodeCmd.MarkFlagRequired("new-name") } var errMissingParameter = errors.New("missing parameters") @@ -44,7 +77,7 @@ var createUserCmd = &cobra.Command{ userName := args[0] - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() @@ -52,19 +85,36 @@ var createUserCmd = &cobra.Command{ request := &v1.CreateUserRequest{Name: userName} + if displayName, _ := cmd.Flags().GetString("display-name"); displayName != "" { + request.DisplayName = displayName + } + + if email, _ := cmd.Flags().GetString("email"); email != "" { + request.Email = email + } + + if pictureURL, _ := cmd.Flags().GetString("picture-url"); pictureURL != "" { + if _, err := url.Parse(pictureURL); err != nil { + ErrorOutput( + err, + fmt.Sprintf( + "Invalid Picture URL: %s", + err, + ), + output, + ) + } + request.PictureUrl = pictureURL + } + log.Trace().Interface("request", request).Msg("Sending CreateUser request") response, err := client.CreateUser(ctx, request) if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot create user: %s", - status.Convert(err).Message(), - ), + "Cannot create user: "+status.Convert(err).Message(), output, ) - - return } SuccessOutput(response.GetUser(), "User created", output) @@ -72,70 +122,61 @@ var createUserCmd = &cobra.Command{ } var destroyUserCmd = &cobra.Command{ - Use: "destroy NAME", + Use: "destroy --identifier ID or --name NAME", Short: "Destroys a user", Aliases: []string{"delete"}, - Args: func(cmd *cobra.Command, args []string) error { - if len(args) < 1 { - return errMissingParameter - } - - return nil - }, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - userName := args[0] - - request := &v1.GetUserRequest{ - Name: userName, + id, username := usernameAndIDFromFlag(cmd) + request := &v1.ListUsersRequest{ + Name: username, + Id: id, } - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() - _, err := client.GetUser(ctx, request) + users, err := client.ListUsers(ctx, request) if err != nil { ErrorOutput( err, - fmt.Sprintf("Error: %s", status.Convert(err).Message()), + "Error: "+status.Convert(err).Message(), output, ) - - return } + if len(users.GetUsers()) != 1 { + err := errors.New("Unable to determine user to delete, query returned multiple users, use ID") + ErrorOutput( + err, + "Error: "+status.Convert(err).Message(), + output, + 
) + } + + user := users.GetUsers()[0] + confirm := false force, _ := cmd.Flags().GetBool("force") if !force { - prompt := &survey.Confirm{ - Message: fmt.Sprintf( - "Do you want to remove the user '%s' and any associated preauthkeys?", - userName, - ), - } - err := survey.AskOne(prompt, &confirm) - if err != nil { - return - } + confirm = util.YesNo(fmt.Sprintf( + "Do you want to remove the user %q (%d) and any associated preauthkeys?", + user.GetName(), user.GetId(), + )) } if confirm || force { - request := &v1.DeleteUserRequest{Name: userName} + request := &v1.DeleteUserRequest{Id: user.GetId()} response, err := client.DeleteUser(ctx, request) if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot destroy user: %s", - status.Convert(err).Message(), - ), + "Cannot destroy user: "+status.Convert(err).Message(), output, ) - - return } SuccessOutput(response, "User destroyed", output) } else { @@ -151,36 +192,48 @@ var listUsersCmd = &cobra.Command{ Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() request := &v1.ListUsersRequest{} + id, _ := cmd.Flags().GetInt64("identifier") + username, _ := cmd.Flags().GetString("name") + email, _ := cmd.Flags().GetString("email") + + // filter by one param at most + switch { + case id > 0: + request.Id = uint64(id) + case username != "": + request.Name = username + case email != "": + request.Email = email + } + response, err := client.ListUsers(ctx, request) if err != nil { ErrorOutput( err, - fmt.Sprintf("Cannot get users: %s", status.Convert(err).Message()), + "Cannot get users: "+status.Convert(err).Message(), output, ) - - return } if output != "" { SuccessOutput(response.GetUsers(), "", output) - - return } - tableData := pterm.TableData{{"ID", "Name", "Created"}} + tableData := pterm.TableData{{"ID", "Name", "Username", "Email", "Created"}} for _, user := range response.GetUsers() { tableData = append( tableData, []string{ - user.GetId(), + strconv.FormatUint(user.GetId(), 10), + user.GetDisplayName(), user.GetName(), + user.GetEmail(), user.GetCreatedAt().AsTime().Format("2006-01-02 15:04:05"), }, ) @@ -192,48 +245,59 @@ var listUsersCmd = &cobra.Command{ fmt.Sprintf("Failed to render pterm table: %s", err), output, ) - - return } }, } var renameUserCmd = &cobra.Command{ - Use: "rename OLD_NAME NEW_NAME", + Use: "rename", Short: "Renames a user", Aliases: []string{"mv"}, - Args: func(cmd *cobra.Command, args []string) error { - expectedArguments := 2 - if len(args) < expectedArguments { - return errMissingParameter - } - - return nil - }, Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - ctx, client, conn, cancel := getHeadscaleCLIClient() + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() defer cancel() defer conn.Close() - request := &v1.RenameUserRequest{ - OldName: args[0], - NewName: args[1], + id, username := usernameAndIDFromFlag(cmd) + listReq := &v1.ListUsersRequest{ + Name: username, + Id: id, } - response, err := client.RenameUser(ctx, request) + users, err := client.ListUsers(ctx, listReq) if err != nil { ErrorOutput( err, - fmt.Sprintf( - "Cannot rename user: %s", - status.Convert(err).Message(), - ), + "Error: "+status.Convert(err).Message(), output, ) + } - return + if len(users.GetUsers()) != 1 { + err := errors.New("Unable to determine user to delete, query returned multiple users, use 
ID") + ErrorOutput( + err, + "Error: "+status.Convert(err).Message(), + output, + ) + } + + newName, _ := cmd.Flags().GetString("new-name") + + renameReq := &v1.RenameUserRequest{ + OldId: id, + NewName: newName, + } + + response, err := client.RenameUser(ctx, renameReq) + if err != nil { + ErrorOutput( + err, + "Cannot rename user: "+status.Convert(err).Message(), + output, + ) } SuccessOutput(response.GetUser(), "User renamed", output) diff --git a/cmd/headscale/cli/utils.go b/cmd/headscale/cli/utils.go index a193d17d..0d0025d3 100644 --- a/cmd/headscale/cli/utils.go +++ b/cmd/headscale/cli/utils.go @@ -6,11 +6,9 @@ import ( "encoding/json" "fmt" "os" - "reflect" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol" - "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -25,40 +23,25 @@ const ( SocketWritePermissions = 0o666 ) -func getHeadscaleApp() (*hscontrol.Headscale, error) { - cfg, err := types.GetHeadscaleConfig() +func newHeadscaleServerWithConfig() (*hscontrol.Headscale, error) { + cfg, err := types.LoadServerConfig() if err != nil { return nil, fmt.Errorf( - "failed to load configuration while creating headscale instance: %w", + "loading configuration: %w", err, ) } app, err := hscontrol.NewHeadscale(cfg) if err != nil { - return nil, err - } - - // We are doing this here, as in the future could be cool to have it also hot-reload - - if cfg.ACL.PolicyPath != "" { - aclPath := util.AbsolutePathFromConfigPath(cfg.ACL.PolicyPath) - pol, err := policy.LoadACLPolicyFromPath(aclPath) - if err != nil { - log.Fatal(). - Str("path", aclPath). - Err(err). - Msg("Could not load the ACL policy") - } - - app.ACLPolicy = pol + return nil, fmt.Errorf("creating new headscale: %w", err) } return app, nil } -func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { - cfg, err := types.GetHeadscaleConfig() +func newHeadscaleCLIWithConfig() (context.Context, v1.HeadscaleServiceClient, *grpc.ClientConn, context.CancelFunc) { + cfg, err := types.LoadCLIConfig() if err != nil { log.Fatal(). Err(err). @@ -89,7 +72,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc. // Try to give the user better feedback if we cannot write to the headscale // socket. - socket, err := os.OpenFile(cfg.UnixSocket, os.O_WRONLY, SocketWritePermissions) //nolint + socket, err := os.OpenFile(cfg.UnixSocket, os.O_WRONLY, SocketWritePermissions) // nolint if err != nil { if os.IsPermission(err) { log.Fatal(). @@ -147,7 +130,7 @@ func getHeadscaleCLIClient() (context.Context, v1.HeadscaleServiceClient, *grpc. return ctx, client, conn, cancel } -func SuccessOutput(result interface{}, override string, outputFormat string) { +func output(result any, override string, outputFormat string) string { var jsonBytes []byte var err error switch outputFormat { @@ -167,22 +150,34 @@ func SuccessOutput(result interface{}, override string, outputFormat string) { log.Fatal().Err(err).Msg("failed to unmarshal output") } default: - //nolint - fmt.Println(override) - - return + // nolint + return override } - //nolint - fmt.Println(string(jsonBytes)) + return string(jsonBytes) } +// SuccessOutput prints the result to stdout and exits with status code 0. 
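+// Because it terminates the process with os.Exit, deferred calls in the
+// caller (such as a deferred conn.Close) will not run.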
+func SuccessOutput(result any, override string, outputFormat string) { + fmt.Println(output(result, override, outputFormat)) + os.Exit(0) +} + +// ErrorOutput prints an error message to stderr and exits with status code 1. func ErrorOutput(errResult error, override string, outputFormat string) { type errOutput struct { Error string `json:"error"` } - SuccessOutput(errOutput{errResult.Error()}, override, outputFormat) + var errorMessage string + if errResult != nil { + errorMessage = errResult.Error() + } else { + errorMessage = override + } + + fmt.Fprintf(os.Stderr, "%s\n", output(errOutput{errorMessage}, override, outputFormat)) + os.Exit(1) } func HasMachineOutputFlag() bool { @@ -212,13 +207,3 @@ func (t tokenAuth) GetRequestMetadata( func (tokenAuth) RequireTransportSecurity() bool { return true } - -func contains[T string](ts []T, t T) bool { - for _, v := range ts { - if reflect.DeepEqual(v, t) { - return true - } - } - - return false -} diff --git a/cmd/headscale/cli/version.go b/cmd/headscale/cli/version.go index 2b440af3..df8a0be4 100644 --- a/cmd/headscale/cli/version.go +++ b/cmd/headscale/cli/version.go @@ -1,13 +1,13 @@ package cli import ( + "github.com/juanfont/headscale/hscontrol/types" "github.com/spf13/cobra" ) -var Version = "dev" - func init() { rootCmd.AddCommand(versionCmd) + versionCmd.Flags().StringP("output", "o", "", "Output format. Empty for human-readable, 'json', 'json-line' or 'yaml'") } var versionCmd = &cobra.Command{ @@ -16,6 +16,9 @@ var versionCmd = &cobra.Command{ Long: "The version of headscale.", Run: func(cmd *cobra.Command, args []string) { output, _ := cmd.Flags().GetString("output") - SuccessOutput(map[string]string{"version": Version}, Version, output) + + info := types.GetVersionInfo() + + SuccessOutput(info, info.String(), output) }, } diff --git a/cmd/headscale/headscale.go b/cmd/headscale/headscale.go index dfaf512f..fa17bf6d 100644 --- a/cmd/headscale/headscale.go +++ b/cmd/headscale/headscale.go @@ -4,27 +4,13 @@ import ( "os" "time" - "github.com/efekarakus/termcolor" + "github.com/jagottsicher/termcolor" "github.com/juanfont/headscale/cmd/headscale/cli" - "github.com/pkg/profile" "github.com/rs/zerolog" "github.com/rs/zerolog/log" ) func main() { - if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { - if profilePath, ok := os.LookupEnv("HEADSCALE_PROFILING_PATH"); ok { - err := os.MkdirAll(profilePath, os.ModePerm) - if err != nil { - log.Fatal().Err(err).Msg("failed to create profiling directory") - } - - defer profile.Start(profile.ProfilePath(profilePath)).Stop() - } else { - defer profile.Start().Stop() - } - } - var colors bool switch l := termcolor.SupportLevel(os.Stderr); l { case termcolor.Level16M: diff --git a/cmd/headscale/headscale_test.go b/cmd/headscale/headscale_test.go index 897e2537..2a9fbce6 100644 --- a/cmd/headscale/headscale_test.go +++ b/cmd/headscale/headscale_test.go @@ -4,40 +4,22 @@ import ( "io/fs" "os" "path/filepath" - "strings" "testing" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/spf13/viper" - "gopkg.in/check.v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -func Test(t *testing.T) { - check.TestingT(t) -} - -var _ = check.Suite(&Suite{}) - -type Suite struct{} - -func (s *Suite) SetUpSuite(c *check.C) { -} - -func (s *Suite) TearDownSuite(c *check.C) { -} - -func (*Suite) TestConfigFileLoading(c *check.C) { +func TestConfigFileLoading(t *testing.T) { tmpDir, err := 
os.MkdirTemp("", "headscale") - if err != nil { - c.Fatal(err) - } + require.NoError(t, err) defer os.RemoveAll(tmpDir) path, err := os.Getwd() - if err != nil { - c.Fatal(err) - } + require.NoError(t, err) cfgFile := filepath.Join(tmpDir, "config.yaml") @@ -46,162 +28,54 @@ func (*Suite) TestConfigFileLoading(c *check.C) { filepath.Clean(path+"/../../config-example.yaml"), cfgFile, ) - if err != nil { - c.Fatal(err) - } + require.NoError(t, err) // Load example config, it should load without validation errors err = types.LoadConfig(cfgFile, true) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Test that config file was interpreted correctly - c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") - c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") - c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") - c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") - c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") - c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") - c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") - c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") - c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1") - c.Assert( - util.GetFileMode("unix_socket_permission"), - check.Equals, - fs.FileMode(0o770), - ) - c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false) + assert.Equal(t, "http://127.0.0.1:8080", viper.GetString("server_url")) + assert.Equal(t, "127.0.0.1:8080", viper.GetString("listen_addr")) + assert.Equal(t, "127.0.0.1:9090", viper.GetString("metrics_listen_addr")) + assert.Equal(t, "sqlite", viper.GetString("database.type")) + assert.Equal(t, "/var/lib/headscale/db.sqlite", viper.GetString("database.sqlite.path")) + assert.Empty(t, viper.GetString("tls_letsencrypt_hostname")) + assert.Equal(t, ":http", viper.GetString("tls_letsencrypt_listen")) + assert.Equal(t, "HTTP-01", viper.GetString("tls_letsencrypt_challenge_type")) + assert.Equal(t, fs.FileMode(0o770), util.GetFileMode("unix_socket_permission")) + assert.False(t, viper.GetBool("logtail.enabled")) } -func (*Suite) TestConfigLoading(c *check.C) { +func TestConfigLoading(t *testing.T) { tmpDir, err := os.MkdirTemp("", "headscale") - if err != nil { - c.Fatal(err) - } + require.NoError(t, err) defer os.RemoveAll(tmpDir) path, err := os.Getwd() - if err != nil { - c.Fatal(err) - } + require.NoError(t, err) // Symlink the example config file err = os.Symlink( filepath.Clean(path+"/../../config-example.yaml"), filepath.Join(tmpDir, "config.yaml"), ) - if err != nil { - c.Fatal(err) - } + require.NoError(t, err) // Load example config, it should load without validation errors err = types.LoadConfig(tmpDir, false) - c.Assert(err, check.IsNil) + require.NoError(t, err) // Test that config file was interpreted correctly - c.Assert(viper.GetString("server_url"), check.Equals, "http://127.0.0.1:8080") - c.Assert(viper.GetString("listen_addr"), check.Equals, "127.0.0.1:8080") - c.Assert(viper.GetString("metrics_listen_addr"), check.Equals, "127.0.0.1:9090") - c.Assert(viper.GetString("db_type"), check.Equals, "sqlite3") - c.Assert(viper.GetString("db_path"), check.Equals, "/var/lib/headscale/db.sqlite") - c.Assert(viper.GetString("tls_letsencrypt_hostname"), check.Equals, "") - c.Assert(viper.GetString("tls_letsencrypt_listen"), check.Equals, ":http") - 
c.Assert(viper.GetString("tls_letsencrypt_challenge_type"), check.Equals, "HTTP-01") - c.Assert(viper.GetStringSlice("dns_config.nameservers")[0], check.Equals, "1.1.1.1") - c.Assert( - util.GetFileMode("unix_socket_permission"), - check.Equals, - fs.FileMode(0o770), - ) - c.Assert(viper.GetBool("logtail.enabled"), check.Equals, false) - c.Assert(viper.GetBool("randomize_client_port"), check.Equals, false) -} - -func (*Suite) TestDNSConfigLoading(c *check.C) { - tmpDir, err := os.MkdirTemp("", "headscale") - if err != nil { - c.Fatal(err) - } - defer os.RemoveAll(tmpDir) - - path, err := os.Getwd() - if err != nil { - c.Fatal(err) - } - - // Symlink the example config file - err = os.Symlink( - filepath.Clean(path+"/../../config-example.yaml"), - filepath.Join(tmpDir, "config.yaml"), - ) - if err != nil { - c.Fatal(err) - } - - // Load example config, it should load without validation errors - err = types.LoadConfig(tmpDir, false) - c.Assert(err, check.IsNil) - - dnsConfig, baseDomain := types.GetDNSConfig() - - c.Assert(dnsConfig.Nameservers[0].String(), check.Equals, "1.1.1.1") - c.Assert(dnsConfig.Resolvers[0].Addr, check.Equals, "1.1.1.1") - c.Assert(dnsConfig.Proxied, check.Equals, true) - c.Assert(baseDomain, check.Equals, "example.com") -} - -func writeConfig(c *check.C, tmpDir string, configYaml []byte) { - // Populate a custom config file - configFile := filepath.Join(tmpDir, "config.yaml") - err := os.WriteFile(configFile, configYaml, 0o600) - if err != nil { - c.Fatalf("Couldn't write file %s", configFile) - } -} - -func (*Suite) TestTLSConfigValidation(c *check.C) { - tmpDir, err := os.MkdirTemp("", "headscale") - if err != nil { - c.Fatal(err) - } - // defer os.RemoveAll(tmpDir) - configYaml := []byte(`--- -tls_letsencrypt_hostname: example.com -tls_letsencrypt_challenge_type: "" -tls_cert_path: abc.pem -noise: - private_key_path: noise_private.key`) - writeConfig(c, tmpDir, configYaml) - - // Check configuration validation errors (1) - err = types.LoadConfig(tmpDir, false) - c.Assert(err, check.NotNil) - // check.Matches can not handle multiline strings - tmp := strings.ReplaceAll(err.Error(), "\n", "***") - c.Assert( - tmp, - check.Matches, - ".*Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both.*", - ) - c.Assert( - tmp, - check.Matches, - ".*Fatal config error: the only supported values for tls_letsencrypt_challenge_type are.*", - ) - c.Assert( - tmp, - check.Matches, - ".*Fatal config error: server_url must start with https:// or http://.*", - ) - - // Check configuration validation errors (2) - configYaml = []byte(`--- -noise: - private_key_path: noise_private.key -server_url: http://127.0.0.1:8080 -tls_letsencrypt_hostname: example.com -tls_letsencrypt_challenge_type: TLS-ALPN-01 -`) - writeConfig(c, tmpDir, configYaml) - err = types.LoadConfig(tmpDir, false) - c.Assert(err, check.IsNil) + assert.Equal(t, "http://127.0.0.1:8080", viper.GetString("server_url")) + assert.Equal(t, "127.0.0.1:8080", viper.GetString("listen_addr")) + assert.Equal(t, "127.0.0.1:9090", viper.GetString("metrics_listen_addr")) + assert.Equal(t, "sqlite", viper.GetString("database.type")) + assert.Equal(t, "/var/lib/headscale/db.sqlite", viper.GetString("database.sqlite.path")) + assert.Empty(t, viper.GetString("tls_letsencrypt_hostname")) + assert.Equal(t, ":http", viper.GetString("tls_letsencrypt_listen")) + assert.Equal(t, "HTTP-01", viper.GetString("tls_letsencrypt_challenge_type")) + assert.Equal(t, fs.FileMode(0o770), 
util.GetFileMode("unix_socket_permission")) + assert.False(t, viper.GetBool("logtail.enabled")) + assert.False(t, viper.GetBool("randomize_client_port")) } diff --git a/cmd/hi/README.md b/cmd/hi/README.md new file mode 100644 index 00000000..17324219 --- /dev/null +++ b/cmd/hi/README.md @@ -0,0 +1,6 @@ +# hi + +hi (headscale integration runner) is an entirely "vibe coded" wrapper around our +[integration test suite](../integration). It essentially runs the docker +commands for you with some added benefits of extracting resources like logs and +databases. diff --git a/cmd/hi/cleanup.go b/cmd/hi/cleanup.go new file mode 100644 index 00000000..7c5b5214 --- /dev/null +++ b/cmd/hi/cleanup.go @@ -0,0 +1,426 @@ +package main + +import ( + "context" + "fmt" + "log" + "os" + "path/filepath" + "strings" + "time" + + "github.com/cenkalti/backoff/v5" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/api/types/image" + "github.com/docker/docker/client" + "github.com/docker/docker/errdefs" +) + +// cleanupBeforeTest performs cleanup operations before running tests. +// Only removes stale (stopped/exited) test containers to avoid interfering with concurrent test runs. +func cleanupBeforeTest(ctx context.Context) error { + err := cleanupStaleTestContainers(ctx) + if err != nil { + return fmt.Errorf("failed to clean stale test containers: %w", err) + } + + if err := pruneDockerNetworks(ctx); err != nil { + return fmt.Errorf("failed to prune networks: %w", err) + } + + return nil +} + +// cleanupAfterTest removes the test container and all associated integration test containers for the run. +func cleanupAfterTest(ctx context.Context, cli *client.Client, containerID, runID string) error { + // Remove the main test container + err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{ + Force: true, + }) + if err != nil { + return fmt.Errorf("failed to remove test container: %w", err) + } + + // Clean up integration test containers for this run only + if runID != "" { + err := killTestContainersByRunID(ctx, runID) + if err != nil { + return fmt.Errorf("failed to clean up containers for run %s: %w", runID, err) + } + } + + return nil +} + +// killTestContainers terminates and removes all test containers. +func killTestContainers(ctx context.Context) error { + cli, err := createDockerClient() + if err != nil { + return fmt.Errorf("failed to create Docker client: %w", err) + } + defer cli.Close() + + containers, err := cli.ContainerList(ctx, container.ListOptions{ + All: true, + }) + if err != nil { + return fmt.Errorf("failed to list containers: %w", err) + } + + removed := 0 + for _, cont := range containers { + shouldRemove := false + for _, name := range cont.Names { + if strings.Contains(name, "headscale-test-suite") || + strings.Contains(name, "hs-") || + strings.Contains(name, "ts-") || + strings.Contains(name, "derp-") { + shouldRemove = true + break + } + } + + if shouldRemove { + // First kill the container if it's running + if cont.State == "running" { + _ = cli.ContainerKill(ctx, cont.ID, "KILL") + } + + // Then remove the container with retry logic + if removeContainerWithRetry(ctx, cli, cont.ID) { + removed++ + } + } + } + + if removed > 0 { + fmt.Printf("Removed %d test containers\n", removed) + } else { + fmt.Println("No test containers found to remove") + } + + return nil +} + +// killTestContainersByRunID terminates and removes all test containers for a specific run ID. 
+// This function filters containers by the hi.run-id label to only affect containers +// belonging to the specified test run, leaving other concurrent test runs untouched. +func killTestContainersByRunID(ctx context.Context, runID string) error { + cli, err := createDockerClient() + if err != nil { + return fmt.Errorf("failed to create Docker client: %w", err) + } + defer cli.Close() + + // Filter containers by hi.run-id label + containers, err := cli.ContainerList(ctx, container.ListOptions{ + All: true, + Filters: filters.NewArgs( + filters.Arg("label", "hi.run-id="+runID), + ), + }) + if err != nil { + return fmt.Errorf("failed to list containers for run %s: %w", runID, err) + } + + removed := 0 + + for _, cont := range containers { + // Kill the container if it's running + if cont.State == "running" { + _ = cli.ContainerKill(ctx, cont.ID, "KILL") + } + + // Remove the container with retry logic + if removeContainerWithRetry(ctx, cli, cont.ID) { + removed++ + } + } + + if removed > 0 { + fmt.Printf("Removed %d containers for run ID %s\n", removed, runID) + } + + return nil +} + +// cleanupStaleTestContainers removes stopped/exited test containers without affecting running tests. +// This is useful for cleaning up leftover containers from previous crashed or interrupted test runs +// without interfering with currently running concurrent tests. +func cleanupStaleTestContainers(ctx context.Context) error { + cli, err := createDockerClient() + if err != nil { + return fmt.Errorf("failed to create Docker client: %w", err) + } + defer cli.Close() + + // Only get stopped/exited containers + containers, err := cli.ContainerList(ctx, container.ListOptions{ + All: true, + Filters: filters.NewArgs( + filters.Arg("status", "exited"), + filters.Arg("status", "dead"), + ), + }) + if err != nil { + return fmt.Errorf("failed to list stopped containers: %w", err) + } + + removed := 0 + + for _, cont := range containers { + // Only remove containers that look like test containers + shouldRemove := false + + for _, name := range cont.Names { + if strings.Contains(name, "headscale-test-suite") || + strings.Contains(name, "hs-") || + strings.Contains(name, "ts-") || + strings.Contains(name, "derp-") { + shouldRemove = true + break + } + } + + if shouldRemove { + if removeContainerWithRetry(ctx, cli, cont.ID) { + removed++ + } + } + } + + if removed > 0 { + fmt.Printf("Removed %d stale test containers\n", removed) + } + + return nil +} + +const ( + containerRemoveInitialInterval = 100 * time.Millisecond + containerRemoveMaxElapsedTime = 2 * time.Second +) + +// removeContainerWithRetry attempts to remove a container with exponential backoff retry logic. +func removeContainerWithRetry(ctx context.Context, cli *client.Client, containerID string) bool { + expBackoff := backoff.NewExponentialBackOff() + expBackoff.InitialInterval = containerRemoveInitialInterval + + _, err := backoff.Retry(ctx, func() (struct{}, error) { + err := cli.ContainerRemove(ctx, containerID, container.RemoveOptions{ + Force: true, + }) + if err != nil { + return struct{}{}, err + } + + return struct{}{}, nil + }, backoff.WithBackOff(expBackoff), backoff.WithMaxElapsedTime(containerRemoveMaxElapsedTime)) + + return err == nil +} + +// pruneDockerNetworks removes unused Docker networks. 
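+// It prunes with an empty filter set, so every network not referenced by a container is removed, and the number of deleted networks is printed.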
+func pruneDockerNetworks(ctx context.Context) error { + cli, err := createDockerClient() + if err != nil { + return fmt.Errorf("failed to create Docker client: %w", err) + } + defer cli.Close() + + report, err := cli.NetworksPrune(ctx, filters.Args{}) + if err != nil { + return fmt.Errorf("failed to prune networks: %w", err) + } + + if len(report.NetworksDeleted) > 0 { + fmt.Printf("Removed %d unused networks\n", len(report.NetworksDeleted)) + } else { + fmt.Println("No unused networks found to remove") + } + + return nil +} + +// cleanOldImages removes test-related and old dangling Docker images. +func cleanOldImages(ctx context.Context) error { + cli, err := createDockerClient() + if err != nil { + return fmt.Errorf("failed to create Docker client: %w", err) + } + defer cli.Close() + + images, err := cli.ImageList(ctx, image.ListOptions{ + All: true, + }) + if err != nil { + return fmt.Errorf("failed to list images: %w", err) + } + + removed := 0 + for _, img := range images { + shouldRemove := false + for _, tag := range img.RepoTags { + if strings.Contains(tag, "hs-") || + strings.Contains(tag, "headscale-integration") || + strings.Contains(tag, "tailscale") { + shouldRemove = true + break + } + } + + if len(img.RepoTags) == 0 && time.Unix(img.Created, 0).Before(time.Now().Add(-7*24*time.Hour)) { + shouldRemove = true + } + + if shouldRemove { + _, err := cli.ImageRemove(ctx, img.ID, image.RemoveOptions{ + Force: true, + }) + if err == nil { + removed++ + } + } + } + + if removed > 0 { + fmt.Printf("Removed %d test images\n", removed) + } else { + fmt.Println("No test images found to remove") + } + + return nil +} + +// cleanCacheVolume removes the Docker volume used for Go module cache. +func cleanCacheVolume(ctx context.Context) error { + cli, err := createDockerClient() + if err != nil { + return fmt.Errorf("failed to create Docker client: %w", err) + } + defer cli.Close() + + volumeName := "hs-integration-go-cache" + err = cli.VolumeRemove(ctx, volumeName, true) + if err != nil { + if errdefs.IsNotFound(err) { + fmt.Printf("Go module cache volume not found: %s\n", volumeName) + } else if errdefs.IsConflict(err) { + fmt.Printf("Go module cache volume is in use and cannot be removed: %s\n", volumeName) + } else { + fmt.Printf("Failed to remove Go module cache volume %s: %v\n", volumeName, err) + } + } else { + fmt.Printf("Removed Go module cache volume: %s\n", volumeName) + } + + return nil +} + +// cleanupSuccessfulTestArtifacts removes artifacts from successful test runs to save disk space. +// This function removes large artifacts that are mainly useful for debugging failures: +// - Database dumps (.db files) +// - Profile data (pprof directories) +// - MapResponse data (mapresponses directories) +// - Prometheus metrics files +// +// It preserves: +// - Log files (.log) which are small and useful for verification. 
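+// Files are only removed when they carry an hs- or ts- prefix; unrelated files in the logs directory are left untouched.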
+func cleanupSuccessfulTestArtifacts(logsDir string, verbose bool) error { + entries, err := os.ReadDir(logsDir) + if err != nil { + return fmt.Errorf("failed to read logs directory: %w", err) + } + + var ( + removedFiles, removedDirs int + totalSize int64 + ) + + for _, entry := range entries { + name := entry.Name() + fullPath := filepath.Join(logsDir, name) + + if entry.IsDir() { + // Remove pprof and mapresponses directories (typically large) + // These directories contain artifacts from all containers in the test run + if name == "pprof" || name == "mapresponses" { + size, sizeErr := getDirSize(fullPath) + if sizeErr == nil { + totalSize += size + } + + err := os.RemoveAll(fullPath) + if err != nil { + if verbose { + log.Printf("Warning: failed to remove directory %s: %v", name, err) + } + } else { + removedDirs++ + + if verbose { + log.Printf("Removed directory: %s/", name) + } + } + } + } else { + // Only process test-related files (headscale and tailscale) + if !strings.HasPrefix(name, "hs-") && !strings.HasPrefix(name, "ts-") { + continue + } + + // Remove database, metrics, and status files, but keep logs + shouldRemove := strings.HasSuffix(name, ".db") || + strings.HasSuffix(name, "_metrics.txt") || + strings.HasSuffix(name, "_status.json") + + if shouldRemove { + info, infoErr := entry.Info() + if infoErr == nil { + totalSize += info.Size() + } + + err := os.Remove(fullPath) + if err != nil { + if verbose { + log.Printf("Warning: failed to remove file %s: %v", name, err) + } + } else { + removedFiles++ + + if verbose { + log.Printf("Removed file: %s", name) + } + } + } + } + } + + if removedFiles > 0 || removedDirs > 0 { + const bytesPerMB = 1024 * 1024 + log.Printf("Cleaned up %d files and %d directories (freed ~%.2f MB)", + removedFiles, removedDirs, float64(totalSize)/bytesPerMB) + } + + return nil +} + +// getDirSize calculates the total size of a directory. +func getDirSize(path string) (int64, error) { + var size int64 + + err := filepath.Walk(path, func(_ string, info os.FileInfo, err error) error { + if err != nil { + return err + } + + if !info.IsDir() { + size += info.Size() + } + + return nil + }) + + return size, err +} diff --git a/cmd/hi/docker.go b/cmd/hi/docker.go new file mode 100644 index 00000000..a6b94b25 --- /dev/null +++ b/cmd/hi/docker.go @@ -0,0 +1,767 @@ +package main + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "log" + "os" + "os/exec" + "path/filepath" + "strings" + "time" + + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/image" + "github.com/docker/docker/api/types/mount" + "github.com/docker/docker/client" + "github.com/docker/docker/pkg/stdcopy" + "github.com/juanfont/headscale/integration/dockertestutil" +) + +var ( + ErrTestFailed = errors.New("test failed") + ErrUnexpectedContainerWait = errors.New("unexpected end of container wait") + ErrNoDockerContext = errors.New("no docker context found") +) + +// runTestContainer executes integration tests in a Docker container. 
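+// It generates a unique run ID, prepares a per-run logs directory, builds the go test command, ensures the golang image is present, then creates and starts the test container, streams its output, extracts artifacts and cleans up according to the configuration.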
+func runTestContainer(ctx context.Context, config *RunConfig) error { + cli, err := createDockerClient() + if err != nil { + return fmt.Errorf("failed to create Docker client: %w", err) + } + defer cli.Close() + + runID := dockertestutil.GenerateRunID() + containerName := "headscale-test-suite-" + runID + logsDir := filepath.Join(config.LogsDir, runID) + + if config.Verbose { + log.Printf("Run ID: %s", runID) + log.Printf("Container name: %s", containerName) + log.Printf("Logs directory: %s", logsDir) + } + + absLogsDir, err := filepath.Abs(logsDir) + if err != nil { + return fmt.Errorf("failed to get absolute path for logs directory: %w", err) + } + + const dirPerm = 0o755 + if err := os.MkdirAll(absLogsDir, dirPerm); err != nil { + return fmt.Errorf("failed to create logs directory: %w", err) + } + + if config.CleanBefore { + if config.Verbose { + log.Printf("Running pre-test cleanup...") + } + if err := cleanupBeforeTest(ctx); err != nil && config.Verbose { + log.Printf("Warning: pre-test cleanup failed: %v", err) + } + } + + goTestCmd := buildGoTestCommand(config) + if config.Verbose { + log.Printf("Command: %s", strings.Join(goTestCmd, " ")) + } + + imageName := "golang:" + config.GoVersion + if err := ensureImageAvailable(ctx, cli, imageName, config.Verbose); err != nil { + return fmt.Errorf("failed to ensure image availability: %w", err) + } + + resp, err := createGoTestContainer(ctx, cli, config, containerName, absLogsDir, goTestCmd) + if err != nil { + return fmt.Errorf("failed to create container: %w", err) + } + + if config.Verbose { + log.Printf("Created container: %s", resp.ID) + } + + if err := cli.ContainerStart(ctx, resp.ID, container.StartOptions{}); err != nil { + return fmt.Errorf("failed to start container: %w", err) + } + + log.Printf("Starting test: %s", config.TestPattern) + log.Printf("Run ID: %s", runID) + log.Printf("Monitor with: docker logs -f %s", containerName) + log.Printf("Logs directory: %s", logsDir) + + // Start stats collection for container resource monitoring (if enabled) + var statsCollector *StatsCollector + if config.Stats { + var err error + statsCollector, err = NewStatsCollector() + if err != nil { + if config.Verbose { + log.Printf("Warning: failed to create stats collector: %v", err) + } + statsCollector = nil + } + + if statsCollector != nil { + defer statsCollector.Close() + + // Start stats collection immediately - no need for complex retry logic + // The new implementation monitors Docker events and will catch containers as they start + if err := statsCollector.StartCollection(ctx, runID, config.Verbose); err != nil { + if config.Verbose { + log.Printf("Warning: failed to start stats collection: %v", err) + } + } + defer statsCollector.StopCollection() + } + } + + exitCode, err := streamAndWait(ctx, cli, resp.ID) + + // Ensure all containers have finished and logs are flushed before extracting artifacts + if waitErr := waitForContainerFinalization(ctx, cli, resp.ID, config.Verbose); waitErr != nil && config.Verbose { + log.Printf("Warning: failed to wait for container finalization: %v", waitErr) + } + + // Extract artifacts from test containers before cleanup + if err := extractArtifactsFromContainers(ctx, resp.ID, logsDir, config.Verbose); err != nil && config.Verbose { + log.Printf("Warning: failed to extract artifacts from containers: %v", err) + } + + // Always list control files regardless of test outcome + listControlFiles(logsDir) + + // Print stats summary and check memory limits if enabled + if config.Stats && statsCollector != 
nil { + violations := statsCollector.PrintSummaryAndCheckLimits(config.HSMemoryLimit, config.TSMemoryLimit) + if len(violations) > 0 { + log.Printf("MEMORY LIMIT VIOLATIONS DETECTED:") + log.Printf("=================================") + for _, violation := range violations { + log.Printf("Container %s exceeded memory limit: %.1f MB > %.1f MB", + violation.ContainerName, violation.MaxMemoryMB, violation.LimitMB) + } + + return fmt.Errorf("test failed: %d container(s) exceeded memory limits", len(violations)) + } + } + + shouldCleanup := config.CleanAfter && (!config.KeepOnFailure || exitCode == 0) + if shouldCleanup { + if config.Verbose { + log.Printf("Running post-test cleanup for run %s...", runID) + } + + cleanErr := cleanupAfterTest(ctx, cli, resp.ID, runID) + + if cleanErr != nil && config.Verbose { + log.Printf("Warning: post-test cleanup failed: %v", cleanErr) + } + + // Clean up artifacts from successful tests to save disk space in CI + if exitCode == 0 { + if config.Verbose { + log.Printf("Test succeeded, cleaning up artifacts to save disk space...") + } + + cleanErr := cleanupSuccessfulTestArtifacts(logsDir, config.Verbose) + + if cleanErr != nil && config.Verbose { + log.Printf("Warning: artifact cleanup failed: %v", cleanErr) + } + } + } + + if err != nil { + return fmt.Errorf("test execution failed: %w", err) + } + + if exitCode != 0 { + return fmt.Errorf("%w: exit code %d", ErrTestFailed, exitCode) + } + + log.Printf("Test completed successfully!") + + return nil +} + +// buildGoTestCommand constructs the go test command arguments. +func buildGoTestCommand(config *RunConfig) []string { + cmd := []string{"go", "test", "./..."} + + if config.TestPattern != "" { + cmd = append(cmd, "-run", config.TestPattern) + } + + if config.FailFast { + cmd = append(cmd, "-failfast") + } + + cmd = append(cmd, "-timeout", config.Timeout.String()) + cmd = append(cmd, "-v") + + return cmd +} + +// createGoTestContainer creates a Docker container configured for running integration tests. 
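+// The container is labelled with hi.run-id, mounts the project root, the Docker socket and the logs directory, and forwards CI plus all HEADSCALE_INTEGRATION_* environment variables; Go caches come from bind mounts when configured, otherwise from named Docker volumes.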
+func createGoTestContainer(ctx context.Context, cli *client.Client, config *RunConfig, containerName, logsDir string, goTestCmd []string) (container.CreateResponse, error) { + pwd, err := os.Getwd() + if err != nil { + return container.CreateResponse{}, fmt.Errorf("failed to get working directory: %w", err) + } + + projectRoot := findProjectRoot(pwd) + + runID := dockertestutil.ExtractRunIDFromContainerName(containerName) + + env := []string{ + fmt.Sprintf("HEADSCALE_INTEGRATION_POSTGRES=%d", boolToInt(config.UsePostgres)), + "HEADSCALE_INTEGRATION_RUN_ID=" + runID, + } + + // Pass through CI environment variable for CI detection + if ci := os.Getenv("CI"); ci != "" { + env = append(env, "CI="+ci) + } + + // Pass through all HEADSCALE_INTEGRATION_* environment variables + for _, e := range os.Environ() { + if strings.HasPrefix(e, "HEADSCALE_INTEGRATION_") { + // Skip the ones we already set explicitly + if strings.HasPrefix(e, "HEADSCALE_INTEGRATION_POSTGRES=") || + strings.HasPrefix(e, "HEADSCALE_INTEGRATION_RUN_ID=") { + continue + } + + env = append(env, e) + } + } + + // Set GOCACHE to a known location (used by both bind mount and volume cases) + env = append(env, "GOCACHE=/cache/go-build") + + containerConfig := &container.Config{ + Image: "golang:" + config.GoVersion, + Cmd: goTestCmd, + Env: env, + WorkingDir: projectRoot + "/integration", + Tty: true, + Labels: map[string]string{ + "hi.run-id": runID, + "hi.test-type": "test-runner", + }, + } + + // Get the correct Docker socket path from the current context + dockerSocketPath := getDockerSocketPath() + + if config.Verbose { + log.Printf("Using Docker socket: %s", dockerSocketPath) + } + + binds := []string{ + fmt.Sprintf("%s:%s", projectRoot, projectRoot), + dockerSocketPath + ":/var/run/docker.sock", + logsDir + ":/tmp/control", + } + + // Use bind mounts for Go cache if provided via environment variables, + // otherwise fall back to Docker volumes for local development + var mounts []mount.Mount + + goCache := os.Getenv("HEADSCALE_INTEGRATION_GO_CACHE") + goBuildCache := os.Getenv("HEADSCALE_INTEGRATION_GO_BUILD_CACHE") + + if goCache != "" { + binds = append(binds, goCache+":/go") + } else { + mounts = append(mounts, mount.Mount{ + Type: mount.TypeVolume, + Source: "hs-integration-go-cache", + Target: "/go", + }) + } + + if goBuildCache != "" { + binds = append(binds, goBuildCache+":/cache/go-build") + } else { + mounts = append(mounts, mount.Mount{ + Type: mount.TypeVolume, + Source: "hs-integration-go-build-cache", + Target: "/cache/go-build", + }) + } + + hostConfig := &container.HostConfig{ + AutoRemove: false, // We'll remove manually for better control + Binds: binds, + Mounts: mounts, + } + + return cli.ContainerCreate(ctx, containerConfig, hostConfig, nil, nil, containerName) +} + +// streamAndWait streams container output and waits for completion. 
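+// The returned int is the container's exit status; -1 indicates that the wait itself failed before an exit status was available.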
+func streamAndWait(ctx context.Context, cli *client.Client, containerID string) (int, error) { + out, err := cli.ContainerLogs(ctx, containerID, container.LogsOptions{ + ShowStdout: true, + ShowStderr: true, + Follow: true, + }) + if err != nil { + return -1, fmt.Errorf("failed to get container logs: %w", err) + } + defer out.Close() + + go func() { + _, _ = io.Copy(os.Stdout, out) + }() + + statusCh, errCh := cli.ContainerWait(ctx, containerID, container.WaitConditionNotRunning) + select { + case err := <-errCh: + if err != nil { + return -1, fmt.Errorf("error waiting for container: %w", err) + } + case status := <-statusCh: + return int(status.StatusCode), nil + } + + return -1, ErrUnexpectedContainerWait +} + +// waitForContainerFinalization ensures all test containers have properly finished and flushed their output. +func waitForContainerFinalization(ctx context.Context, cli *client.Client, testContainerID string, verbose bool) error { + // First, get all related test containers + containers, err := cli.ContainerList(ctx, container.ListOptions{All: true}) + if err != nil { + return fmt.Errorf("failed to list containers: %w", err) + } + + testContainers := getCurrentTestContainers(containers, testContainerID, verbose) + + // Wait for all test containers to reach a final state + maxWaitTime := 10 * time.Second + checkInterval := 500 * time.Millisecond + timeout := time.After(maxWaitTime) + ticker := time.NewTicker(checkInterval) + defer ticker.Stop() + + for { + select { + case <-timeout: + if verbose { + log.Printf("Timeout waiting for container finalization, proceeding with artifact extraction") + } + return nil + case <-ticker.C: + allFinalized := true + + for _, testCont := range testContainers { + inspect, err := cli.ContainerInspect(ctx, testCont.ID) + if err != nil { + if verbose { + log.Printf("Warning: failed to inspect container %s: %v", testCont.name, err) + } + continue + } + + // Check if container is in a final state + if !isContainerFinalized(inspect.State) { + allFinalized = false + if verbose { + log.Printf("Container %s still finalizing (state: %s)", testCont.name, inspect.State.Status) + } + + break + } + } + + if allFinalized { + if verbose { + log.Printf("All test containers finalized, ready for artifact extraction") + } + return nil + } + } + } +} + +// isContainerFinalized checks if a container has reached a final state where logs are flushed. +func isContainerFinalized(state *container.State) bool { + // Container is finalized if it's not running and has a finish time + return !state.Running && state.FinishedAt != "" +} + +// findProjectRoot locates the project root by finding the directory containing go.mod. +func findProjectRoot(startPath string) string { + current := startPath + for { + if _, err := os.Stat(filepath.Join(current, "go.mod")); err == nil { + return current + } + parent := filepath.Dir(current) + if parent == current { + return startPath + } + current = parent + } +} + +// boolToInt converts a boolean to an integer for environment variables. +func boolToInt(b bool) int { + if b { + return 1 + } + return 0 +} + +// DockerContext represents Docker context information. +type DockerContext struct { + Name string `json:"Name"` + Metadata map[string]any `json:"Metadata"` + Endpoints map[string]any `json:"Endpoints"` + Current bool `json:"Current"` +} + +// createDockerClient creates a Docker client with context detection. 
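+// When a Docker context is configured its docker endpoint host is used; otherwise the client falls back to environment-based configuration, with API version negotiation enabled in both cases.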
+func createDockerClient() (*client.Client, error) { + contextInfo, err := getCurrentDockerContext() + if err != nil { + return client.NewClientWithOpts(client.FromEnv, client.WithAPIVersionNegotiation()) + } + + var clientOpts []client.Opt + clientOpts = append(clientOpts, client.WithAPIVersionNegotiation()) + + if contextInfo != nil { + if endpoints, ok := contextInfo.Endpoints["docker"]; ok { + if endpointMap, ok := endpoints.(map[string]any); ok { + if host, ok := endpointMap["Host"].(string); ok { + if runConfig.Verbose { + log.Printf("Using Docker host from context '%s': %s", contextInfo.Name, host) + } + clientOpts = append(clientOpts, client.WithHost(host)) + } + } + } + } + + if len(clientOpts) == 1 { + clientOpts = append(clientOpts, client.FromEnv) + } + + return client.NewClientWithOpts(clientOpts...) +} + +// getCurrentDockerContext retrieves the current Docker context information. +func getCurrentDockerContext() (*DockerContext, error) { + cmd := exec.Command("docker", "context", "inspect") + output, err := cmd.Output() + if err != nil { + return nil, fmt.Errorf("failed to get docker context: %w", err) + } + + var contexts []DockerContext + if err := json.Unmarshal(output, &contexts); err != nil { + return nil, fmt.Errorf("failed to parse docker context: %w", err) + } + + if len(contexts) > 0 { + return &contexts[0], nil + } + + return nil, ErrNoDockerContext +} + +// getDockerSocketPath returns the correct Docker socket path for the current context. +func getDockerSocketPath() string { + // Always use the default socket path for mounting since Docker handles + // the translation to the actual socket (e.g., colima socket) internally + return "/var/run/docker.sock" +} + +// checkImageAvailableLocally checks if the specified Docker image is available locally. +func checkImageAvailableLocally(ctx context.Context, cli *client.Client, imageName string) (bool, error) { + _, _, err := cli.ImageInspectWithRaw(ctx, imageName) + if err != nil { + if client.IsErrNotFound(err) { + return false, nil + } + return false, fmt.Errorf("failed to inspect image %s: %w", imageName, err) + } + + return true, nil +} + +// ensureImageAvailable checks if the image is available locally first, then pulls if needed. +func ensureImageAvailable(ctx context.Context, cli *client.Client, imageName string, verbose bool) error { + // First check if image is available locally + available, err := checkImageAvailableLocally(ctx, cli, imageName) + if err != nil { + return fmt.Errorf("failed to check local image availability: %w", err) + } + + if available { + if verbose { + log.Printf("Image %s is available locally", imageName) + } + return nil + } + + // Image not available locally, try to pull it + if verbose { + log.Printf("Image %s not found locally, pulling...", imageName) + } + + reader, err := cli.ImagePull(ctx, imageName, image.PullOptions{}) + if err != nil { + return fmt.Errorf("failed to pull image %s: %w", imageName, err) + } + defer reader.Close() + + if verbose { + _, err = io.Copy(os.Stdout, reader) + if err != nil { + return fmt.Errorf("failed to read pull output: %w", err) + } + } else { + _, err = io.Copy(io.Discard, reader) + if err != nil { + return fmt.Errorf("failed to read pull output: %w", err) + } + log.Printf("Image %s pulled successfully", imageName) + } + + return nil +} + +// listControlFiles displays the headscale test artifacts created in the control logs directory. 
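+// Only hs-* entries are listed, grouped into log files (stdout/stderr) and data (database dumps plus pprof and mapresponses directories).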
+func listControlFiles(logsDir string) { + entries, err := os.ReadDir(logsDir) + if err != nil { + log.Printf("Logs directory: %s", logsDir) + return + } + + var logFiles []string + var dataFiles []string + var dataDirs []string + + for _, entry := range entries { + name := entry.Name() + // Only show headscale (hs-*) files and directories + if !strings.HasPrefix(name, "hs-") { + continue + } + + if entry.IsDir() { + // Include directories (pprof, mapresponses) + if strings.Contains(name, "-pprof") || strings.Contains(name, "-mapresponses") { + dataDirs = append(dataDirs, name) + } + } else { + // Include files + switch { + case strings.HasSuffix(name, ".stderr.log") || strings.HasSuffix(name, ".stdout.log"): + logFiles = append(logFiles, name) + case strings.HasSuffix(name, ".db"): + dataFiles = append(dataFiles, name) + } + } + } + + log.Printf("Test artifacts saved to: %s", logsDir) + + if len(logFiles) > 0 { + log.Printf("Headscale logs:") + for _, file := range logFiles { + log.Printf(" %s", file) + } + } + + if len(dataFiles) > 0 || len(dataDirs) > 0 { + log.Printf("Headscale data:") + for _, file := range dataFiles { + log.Printf(" %s", file) + } + for _, dir := range dataDirs { + log.Printf(" %s/", dir) + } + } +} + +// extractArtifactsFromContainers collects container logs and files from the specific test run. +func extractArtifactsFromContainers(ctx context.Context, testContainerID, logsDir string, verbose bool) error { + cli, err := createDockerClient() + if err != nil { + return fmt.Errorf("failed to create Docker client: %w", err) + } + defer cli.Close() + + // List all containers + containers, err := cli.ContainerList(ctx, container.ListOptions{All: true}) + if err != nil { + return fmt.Errorf("failed to list containers: %w", err) + } + + // Get containers from the specific test run + currentTestContainers := getCurrentTestContainers(containers, testContainerID, verbose) + + extractedCount := 0 + for _, cont := range currentTestContainers { + // Extract container logs and tar files + if err := extractContainerArtifacts(ctx, cli, cont.ID, cont.name, logsDir, verbose); err != nil { + if verbose { + log.Printf("Warning: failed to extract artifacts from container %s (%s): %v", cont.name, cont.ID[:12], err) + } + } else { + if verbose { + log.Printf("Extracted artifacts from container %s (%s)", cont.name, cont.ID[:12]) + } + extractedCount++ + } + } + + if verbose && extractedCount > 0 { + log.Printf("Extracted artifacts from %d containers", extractedCount) + } + + return nil +} + +// testContainer represents a container from the current test run. +type testContainer struct { + ID string + name string +} + +// getCurrentTestContainers filters containers to only include those from the current test run. 
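+// The run ID is read from the test runner's hi.run-id label; hs-* and ts-* containers carrying the same label are returned.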
+func getCurrentTestContainers(containers []container.Summary, testContainerID string, verbose bool) []testContainer { + var testRunContainers []testContainer + + // Find the test container to get its run ID label + var runID string + for _, cont := range containers { + if cont.ID == testContainerID { + if cont.Labels != nil { + runID = cont.Labels["hi.run-id"] + } + break + } + } + + if runID == "" { + log.Printf("Error: test container %s missing required hi.run-id label", testContainerID[:12]) + return testRunContainers + } + + if verbose { + log.Printf("Looking for containers with run ID: %s", runID) + } + + // Find all containers with the same run ID + for _, cont := range containers { + for _, name := range cont.Names { + containerName := strings.TrimPrefix(name, "/") + if strings.HasPrefix(containerName, "hs-") || strings.HasPrefix(containerName, "ts-") { + // Check if container has matching run ID label + if cont.Labels != nil && cont.Labels["hi.run-id"] == runID { + testRunContainers = append(testRunContainers, testContainer{ + ID: cont.ID, + name: containerName, + }) + if verbose { + log.Printf("Including container %s (run ID: %s)", containerName, runID) + } + } + + break + } + } + } + + return testRunContainers +} + +// extractContainerArtifacts saves logs and tar files from a container. +func extractContainerArtifacts(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error { + // Ensure the logs directory exists + if err := os.MkdirAll(logsDir, 0o755); err != nil { + return fmt.Errorf("failed to create logs directory: %w", err) + } + + // Extract container logs + if err := extractContainerLogs(ctx, cli, containerID, containerName, logsDir, verbose); err != nil { + return fmt.Errorf("failed to extract logs: %w", err) + } + + // Extract tar files for headscale containers only + if strings.HasPrefix(containerName, "hs-") { + if err := extractContainerFiles(ctx, cli, containerID, containerName, logsDir, verbose); err != nil { + if verbose { + log.Printf("Warning: failed to extract files from %s: %v", containerName, err) + } + // Don't fail the whole extraction if files are missing + } + } + + return nil +} + +// extractContainerLogs saves the stdout and stderr logs from a container to files. 
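+// The multiplexed Docker log stream is split with stdcopy and written to <container>.stdout.log and <container>.stderr.log in the logs directory.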
+func extractContainerLogs(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error { + // Get container logs + logReader, err := cli.ContainerLogs(ctx, containerID, container.LogsOptions{ + ShowStdout: true, + ShowStderr: true, + Timestamps: false, + Follow: false, + Tail: "all", + }) + if err != nil { + return fmt.Errorf("failed to get container logs: %w", err) + } + defer logReader.Close() + + // Create log files following the headscale naming convention + stdoutPath := filepath.Join(logsDir, containerName+".stdout.log") + stderrPath := filepath.Join(logsDir, containerName+".stderr.log") + + // Create buffers to capture stdout and stderr separately + var stdoutBuf, stderrBuf bytes.Buffer + + // Demultiplex the Docker logs stream to separate stdout and stderr + _, err = stdcopy.StdCopy(&stdoutBuf, &stderrBuf, logReader) + if err != nil { + return fmt.Errorf("failed to demultiplex container logs: %w", err) + } + + // Write stdout logs + if err := os.WriteFile(stdoutPath, stdoutBuf.Bytes(), 0o644); err != nil { + return fmt.Errorf("failed to write stdout log: %w", err) + } + + // Write stderr logs + if err := os.WriteFile(stderrPath, stderrBuf.Bytes(), 0o644); err != nil { + return fmt.Errorf("failed to write stderr log: %w", err) + } + + if verbose { + log.Printf("Saved logs for %s: %s, %s", containerName, stdoutPath, stderrPath) + } + + return nil +} + +// extractContainerFiles extracts database file and directories from headscale containers. +// Note: The actual file extraction is now handled by the integration tests themselves +// via SaveProfile, SaveMapResponses, and SaveDatabase functions in hsic.go. +func extractContainerFiles(ctx context.Context, cli *client.Client, containerID, containerName, logsDir string, verbose bool) error { + // Files are now extracted directly by the integration tests + // This function is kept for potential future use or other file types + return nil +} diff --git a/cmd/hi/doctor.go b/cmd/hi/doctor.go new file mode 100644 index 00000000..8af6051f --- /dev/null +++ b/cmd/hi/doctor.go @@ -0,0 +1,374 @@ +package main + +import ( + "context" + "errors" + "fmt" + "log" + "os/exec" + "strings" +) + +var ErrSystemChecksFailed = errors.New("system checks failed") + +// DoctorResult represents the result of a single health check. +type DoctorResult struct { + Name string + Status string // "PASS", "FAIL", "WARN" + Message string + Suggestions []string +} + +// runDoctorCheck performs comprehensive pre-flight checks for integration testing. 
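+// It verifies the Docker binary, daemon, context, socket and golang image as well as the Go toolchain, the git repository and required project files, and returns an error if any check reports FAIL.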
+func runDoctorCheck(ctx context.Context) error { + results := []DoctorResult{} + + // Check 1: Docker binary availability + results = append(results, checkDockerBinary()) + + // Check 2: Docker daemon connectivity + dockerResult := checkDockerDaemon(ctx) + results = append(results, dockerResult) + + // If Docker is available, run additional checks + if dockerResult.Status == "PASS" { + results = append(results, checkDockerContext(ctx)) + results = append(results, checkDockerSocket(ctx)) + results = append(results, checkGolangImage(ctx)) + } + + // Check 3: Go installation + results = append(results, checkGoInstallation()) + + // Check 4: Git repository + results = append(results, checkGitRepository()) + + // Check 5: Required files + results = append(results, checkRequiredFiles()) + + // Display results + displayDoctorResults(results) + + // Return error if any critical checks failed + for _, result := range results { + if result.Status == "FAIL" { + return fmt.Errorf("%w - see details above", ErrSystemChecksFailed) + } + } + + log.Printf("✅ All system checks passed - ready to run integration tests!") + + return nil +} + +// checkDockerBinary verifies Docker binary is available. +func checkDockerBinary() DoctorResult { + _, err := exec.LookPath("docker") + if err != nil { + return DoctorResult{ + Name: "Docker Binary", + Status: "FAIL", + Message: "Docker binary not found in PATH", + Suggestions: []string{ + "Install Docker: https://docs.docker.com/get-docker/", + "For macOS: consider using colima or Docker Desktop", + "Ensure docker is in your PATH", + }, + } + } + + return DoctorResult{ + Name: "Docker Binary", + Status: "PASS", + Message: "Docker binary found", + } +} + +// checkDockerDaemon verifies Docker daemon is running and accessible. +func checkDockerDaemon(ctx context.Context) DoctorResult { + cli, err := createDockerClient() + if err != nil { + return DoctorResult{ + Name: "Docker Daemon", + Status: "FAIL", + Message: fmt.Sprintf("Cannot create Docker client: %v", err), + Suggestions: []string{ + "Start Docker daemon/service", + "Check Docker Desktop is running (if using Docker Desktop)", + "For colima: run 'colima start'", + "Verify DOCKER_HOST environment variable if set", + }, + } + } + defer cli.Close() + + _, err = cli.Ping(ctx) + if err != nil { + return DoctorResult{ + Name: "Docker Daemon", + Status: "FAIL", + Message: fmt.Sprintf("Cannot ping Docker daemon: %v", err), + Suggestions: []string{ + "Ensure Docker daemon is running", + "Check Docker socket permissions", + "Try: docker info", + }, + } + } + + return DoctorResult{ + Name: "Docker Daemon", + Status: "PASS", + Message: "Docker daemon is running and accessible", + } +} + +// checkDockerContext verifies Docker context configuration. +func checkDockerContext(_ context.Context) DoctorResult { + contextInfo, err := getCurrentDockerContext() + if err != nil { + return DoctorResult{ + Name: "Docker Context", + Status: "WARN", + Message: "Could not detect Docker context, using default settings", + Suggestions: []string{ + "Check: docker context ls", + "Consider setting up a specific context if needed", + }, + } + } + + if contextInfo == nil { + return DoctorResult{ + Name: "Docker Context", + Status: "PASS", + Message: "Using default Docker context", + } + } + + return DoctorResult{ + Name: "Docker Context", + Status: "PASS", + Message: "Using Docker context: " + contextInfo.Name, + } +} + +// checkDockerSocket verifies Docker socket accessibility. 
+func checkDockerSocket(ctx context.Context) DoctorResult { + cli, err := createDockerClient() + if err != nil { + return DoctorResult{ + Name: "Docker Socket", + Status: "FAIL", + Message: fmt.Sprintf("Cannot access Docker socket: %v", err), + Suggestions: []string{ + "Check Docker socket permissions", + "Add user to docker group: sudo usermod -aG docker $USER", + "For colima: ensure socket is accessible", + }, + } + } + defer cli.Close() + + info, err := cli.Info(ctx) + if err != nil { + return DoctorResult{ + Name: "Docker Socket", + Status: "FAIL", + Message: fmt.Sprintf("Cannot get Docker info: %v", err), + Suggestions: []string{ + "Check Docker daemon status", + "Verify socket permissions", + }, + } + } + + return DoctorResult{ + Name: "Docker Socket", + Status: "PASS", + Message: fmt.Sprintf("Docker socket accessible (Server: %s)", info.ServerVersion), + } +} + +// checkGolangImage verifies the golang Docker image is available locally or can be pulled. +func checkGolangImage(ctx context.Context) DoctorResult { + cli, err := createDockerClient() + if err != nil { + return DoctorResult{ + Name: "Golang Image", + Status: "FAIL", + Message: "Cannot create Docker client for image check", + } + } + defer cli.Close() + + goVersion := detectGoVersion() + imageName := "golang:" + goVersion + + // First check if image is available locally + available, err := checkImageAvailableLocally(ctx, cli, imageName) + if err != nil { + return DoctorResult{ + Name: "Golang Image", + Status: "FAIL", + Message: fmt.Sprintf("Cannot check golang image %s: %v", imageName, err), + Suggestions: []string{ + "Check Docker daemon status", + "Try: docker images | grep golang", + }, + } + } + + if available { + return DoctorResult{ + Name: "Golang Image", + Status: "PASS", + Message: fmt.Sprintf("Golang image %s is available locally", imageName), + } + } + + // Image not available locally, try to pull it + err = ensureImageAvailable(ctx, cli, imageName, false) + if err != nil { + return DoctorResult{ + Name: "Golang Image", + Status: "FAIL", + Message: fmt.Sprintf("Golang image %s not available locally and cannot pull: %v", imageName, err), + Suggestions: []string{ + "Check internet connectivity", + "Verify Docker Hub access", + "Try: docker pull " + imageName, + "Or run tests offline if image was pulled previously", + }, + } + } + + return DoctorResult{ + Name: "Golang Image", + Status: "PASS", + Message: fmt.Sprintf("Golang image %s is now available", imageName), + } +} + +// checkGoInstallation verifies Go is installed and working. +func checkGoInstallation() DoctorResult { + _, err := exec.LookPath("go") + if err != nil { + return DoctorResult{ + Name: "Go Installation", + Status: "FAIL", + Message: "Go binary not found in PATH", + Suggestions: []string{ + "Install Go: https://golang.org/dl/", + "Ensure go is in your PATH", + }, + } + } + + cmd := exec.Command("go", "version") + output, err := cmd.Output() + if err != nil { + return DoctorResult{ + Name: "Go Installation", + Status: "FAIL", + Message: fmt.Sprintf("Cannot get Go version: %v", err), + } + } + + version := strings.TrimSpace(string(output)) + + return DoctorResult{ + Name: "Go Installation", + Status: "PASS", + Message: version, + } +} + +// checkGitRepository verifies we're in a git repository. 
+func checkGitRepository() DoctorResult { + cmd := exec.Command("git", "rev-parse", "--git-dir") + err := cmd.Run() + if err != nil { + return DoctorResult{ + Name: "Git Repository", + Status: "FAIL", + Message: "Not in a Git repository", + Suggestions: []string{ + "Run from within the headscale git repository", + "Clone the repository: git clone https://github.com/juanfont/headscale.git", + }, + } + } + + return DoctorResult{ + Name: "Git Repository", + Status: "PASS", + Message: "Running in Git repository", + } +} + +// checkRequiredFiles verifies required files exist. +func checkRequiredFiles() DoctorResult { + requiredFiles := []string{ + "go.mod", + "integration/", + "cmd/hi/", + } + + var missingFiles []string + for _, file := range requiredFiles { + cmd := exec.Command("test", "-e", file) + if err := cmd.Run(); err != nil { + missingFiles = append(missingFiles, file) + } + } + + if len(missingFiles) > 0 { + return DoctorResult{ + Name: "Required Files", + Status: "FAIL", + Message: "Missing required files: " + strings.Join(missingFiles, ", "), + Suggestions: []string{ + "Ensure you're in the headscale project root directory", + "Check that integration/ directory exists", + "Verify this is a complete headscale repository", + }, + } + } + + return DoctorResult{ + Name: "Required Files", + Status: "PASS", + Message: "All required files found", + } +} + +// displayDoctorResults shows the results in a formatted way. +func displayDoctorResults(results []DoctorResult) { + log.Printf("🔍 System Health Check Results") + log.Printf("================================") + + for _, result := range results { + var icon string + switch result.Status { + case "PASS": + icon = "✅" + case "WARN": + icon = "⚠️" + case "FAIL": + icon = "❌" + default: + icon = "❓" + } + + log.Printf("%s %s: %s", icon, result.Name, result.Message) + + if len(result.Suggestions) > 0 { + for _, suggestion := range result.Suggestions { + log.Printf(" 💡 %s", suggestion) + } + } + } + + log.Printf("================================") +} diff --git a/cmd/hi/main.go b/cmd/hi/main.go new file mode 100644 index 00000000..baecc6f3 --- /dev/null +++ b/cmd/hi/main.go @@ -0,0 +1,93 @@ +package main + +import ( + "context" + "os" + + "github.com/creachadair/command" + "github.com/creachadair/flax" +) + +var runConfig RunConfig + +func main() { + root := command.C{ + Name: "hi", + Help: "Headscale Integration test runner", + Commands: []*command.C{ + { + Name: "run", + Help: "Run integration tests", + Usage: "run [test-pattern] [flags]", + SetFlags: command.Flags(flax.MustBind, &runConfig), + Run: runIntegrationTest, + }, + { + Name: "doctor", + Help: "Check system requirements for running integration tests", + Run: func(env *command.Env) error { + return runDoctorCheck(env.Context()) + }, + }, + { + Name: "clean", + Help: "Clean Docker resources", + Commands: []*command.C{ + { + Name: "networks", + Help: "Prune unused Docker networks", + Run: func(env *command.Env) error { + return pruneDockerNetworks(env.Context()) + }, + }, + { + Name: "images", + Help: "Clean old test images", + Run: func(env *command.Env) error { + return cleanOldImages(env.Context()) + }, + }, + { + Name: "containers", + Help: "Kill all test containers", + Run: func(env *command.Env) error { + return killTestContainers(env.Context()) + }, + }, + { + Name: "cache", + Help: "Clean Go module cache volume", + Run: func(env *command.Env) error { + return cleanCacheVolume(env.Context()) + }, + }, + { + Name: "all", + Help: "Run all cleanup operations", + Run: func(env 
*command.Env) error { + return cleanAll(env.Context()) + }, + }, + }, + }, + command.HelpCommand(nil), + }, + } + + env := root.NewEnv(nil).MergeFlags(true) + command.RunOrFail(env, os.Args[1:]) +} + +func cleanAll(ctx context.Context) error { + if err := killTestContainers(ctx); err != nil { + return err + } + if err := pruneDockerNetworks(ctx); err != nil { + return err + } + if err := cleanOldImages(ctx); err != nil { + return err + } + + return cleanCacheVolume(ctx) +} diff --git a/cmd/hi/run.go b/cmd/hi/run.go new file mode 100644 index 00000000..1694399d --- /dev/null +++ b/cmd/hi/run.go @@ -0,0 +1,125 @@ +package main + +import ( + "errors" + "fmt" + "log" + "os" + "path/filepath" + "time" + + "github.com/creachadair/command" +) + +var ErrTestPatternRequired = errors.New("test pattern is required as first argument or use --test flag") + +type RunConfig struct { + TestPattern string `flag:"test,Test pattern to run"` + Timeout time.Duration `flag:"timeout,default=120m,Test timeout"` + FailFast bool `flag:"failfast,default=true,Stop on first test failure"` + UsePostgres bool `flag:"postgres,default=false,Use PostgreSQL instead of SQLite"` + GoVersion string `flag:"go-version,Go version to use (auto-detected from go.mod)"` + CleanBefore bool `flag:"clean-before,default=true,Clean stale resources before test"` + CleanAfter bool `flag:"clean-after,default=true,Clean resources after test"` + KeepOnFailure bool `flag:"keep-on-failure,default=false,Keep containers on test failure"` + LogsDir string `flag:"logs-dir,default=control_logs,Control logs directory"` + Verbose bool `flag:"verbose,default=false,Verbose output"` + Stats bool `flag:"stats,default=false,Collect and display container resource usage statistics"` + HSMemoryLimit float64 `flag:"hs-memory-limit,default=0,Fail test if any Headscale container exceeds this memory limit in MB (0 = disabled)"` + TSMemoryLimit float64 `flag:"ts-memory-limit,default=0,Fail test if any Tailscale container exceeds this memory limit in MB (0 = disabled)"` +} + +// runIntegrationTest executes the integration test workflow. +func runIntegrationTest(env *command.Env) error { + args := env.Args + if len(args) > 0 && runConfig.TestPattern == "" { + runConfig.TestPattern = args[0] + } + + if runConfig.TestPattern == "" { + return ErrTestPatternRequired + } + + if runConfig.GoVersion == "" { + runConfig.GoVersion = detectGoVersion() + } + + // Run pre-flight checks + if runConfig.Verbose { + log.Printf("Running pre-flight system checks...") + } + if err := runDoctorCheck(env.Context()); err != nil { + return fmt.Errorf("pre-flight checks failed: %w", err) + } + + if runConfig.Verbose { + log.Printf("Running test: %s", runConfig.TestPattern) + log.Printf("Go version: %s", runConfig.GoVersion) + log.Printf("Timeout: %s", runConfig.Timeout) + log.Printf("Use PostgreSQL: %t", runConfig.UsePostgres) + } + + return runTestContainer(env.Context(), &runConfig) +} + +// detectGoVersion reads the Go version from go.mod file. 
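+// go.mod is looked up in the current directory first and then two levels up; if no go directive can be parsed the function falls back to "1.25".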
+func detectGoVersion() string { + goModPath := filepath.Join("..", "..", "go.mod") + + if _, err := os.Stat("go.mod"); err == nil { + goModPath = "go.mod" + } else if _, err := os.Stat("../../go.mod"); err == nil { + goModPath = "../../go.mod" + } + + content, err := os.ReadFile(goModPath) + if err != nil { + return "1.25" + } + + lines := splitLines(string(content)) + for _, line := range lines { + if len(line) > 3 && line[:3] == "go " { + version := line[3:] + if idx := indexOf(version, " "); idx != -1 { + version = version[:idx] + } + + return version + } + } + + return "1.25" +} + +// splitLines splits a string into lines without using strings.Split. +func splitLines(s string) []string { + var lines []string + var current string + + for _, char := range s { + if char == '\n' { + lines = append(lines, current) + current = "" + } else { + current += string(char) + } + } + + if current != "" { + lines = append(lines, current) + } + + return lines +} + +// indexOf finds the first occurrence of substr in s. +func indexOf(s, substr string) int { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return i + } + } + + return -1 +} diff --git a/cmd/hi/stats.go b/cmd/hi/stats.go new file mode 100644 index 00000000..b68215a6 --- /dev/null +++ b/cmd/hi/stats.go @@ -0,0 +1,471 @@ +package main + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "log" + "sort" + "strings" + "sync" + "time" + + "github.com/docker/docker/api/types" + "github.com/docker/docker/api/types/container" + "github.com/docker/docker/api/types/events" + "github.com/docker/docker/api/types/filters" + "github.com/docker/docker/client" +) + +// ContainerStats represents statistics for a single container. +type ContainerStats struct { + ContainerID string + ContainerName string + Stats []StatsSample + mutex sync.RWMutex +} + +// StatsSample represents a single stats measurement. +type StatsSample struct { + Timestamp time.Time + CPUUsage float64 // CPU usage percentage + MemoryMB float64 // Memory usage in MB +} + +// StatsCollector manages collection of container statistics. +type StatsCollector struct { + client *client.Client + containers map[string]*ContainerStats + stopChan chan struct{} + wg sync.WaitGroup + mutex sync.RWMutex + collectionStarted bool +} + +// NewStatsCollector creates a new stats collector instance. +func NewStatsCollector() (*StatsCollector, error) { + cli, err := createDockerClient() + if err != nil { + return nil, fmt.Errorf("failed to create Docker client: %w", err) + } + + return &StatsCollector{ + client: cli, + containers: make(map[string]*ContainerStats), + stopChan: make(chan struct{}), + }, nil +} + +// StartCollection begins monitoring all containers and collecting stats for hs- and ts- containers with matching run ID. +func (sc *StatsCollector) StartCollection(ctx context.Context, runID string, verbose bool) error { + sc.mutex.Lock() + defer sc.mutex.Unlock() + + if sc.collectionStarted { + return errors.New("stats collection already started") + } + + sc.collectionStarted = true + + // Start monitoring existing containers + sc.wg.Add(1) + go sc.monitorExistingContainers(ctx, runID, verbose) + + // Start Docker events monitoring for new containers + sc.wg.Add(1) + go sc.monitorDockerEvents(ctx, runID, verbose) + + if verbose { + log.Printf("Started container monitoring for run ID %s", runID) + } + + return nil +} + +// StopCollection stops all stats collection. 
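+// Calling it when collection was never started is a no-op; otherwise it signals all collector goroutines and blocks until they have exited.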
+func (sc *StatsCollector) StopCollection() { + // Check if already stopped without holding lock + sc.mutex.RLock() + if !sc.collectionStarted { + sc.mutex.RUnlock() + return + } + sc.mutex.RUnlock() + + // Signal stop to all goroutines + close(sc.stopChan) + + // Wait for all goroutines to finish + sc.wg.Wait() + + // Mark as stopped + sc.mutex.Lock() + sc.collectionStarted = false + sc.mutex.Unlock() +} + +// monitorExistingContainers checks for existing containers that match our criteria. +func (sc *StatsCollector) monitorExistingContainers(ctx context.Context, runID string, verbose bool) { + defer sc.wg.Done() + + containers, err := sc.client.ContainerList(ctx, container.ListOptions{}) + if err != nil { + if verbose { + log.Printf("Failed to list existing containers: %v", err) + } + return + } + + for _, cont := range containers { + if sc.shouldMonitorContainer(cont, runID) { + sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose) + } + } +} + +// monitorDockerEvents listens for container start events and begins monitoring relevant containers. +func (sc *StatsCollector) monitorDockerEvents(ctx context.Context, runID string, verbose bool) { + defer sc.wg.Done() + + filter := filters.NewArgs() + filter.Add("type", "container") + filter.Add("event", "start") + + eventOptions := events.ListOptions{ + Filters: filter, + } + + events, errs := sc.client.Events(ctx, eventOptions) + + for { + select { + case <-sc.stopChan: + return + case <-ctx.Done(): + return + case event := <-events: + if event.Type == "container" && event.Action == "start" { + // Get container details + containerInfo, err := sc.client.ContainerInspect(ctx, event.ID) + if err != nil { + continue + } + + // Convert to types.Container format for consistency + cont := types.Container{ + ID: containerInfo.ID, + Names: []string{containerInfo.Name}, + Labels: containerInfo.Config.Labels, + } + + if sc.shouldMonitorContainer(cont, runID) { + sc.startStatsForContainer(ctx, cont.ID, cont.Names[0], verbose) + } + } + case err := <-errs: + if verbose { + log.Printf("Error in Docker events stream: %v", err) + } + return + } + } +} + +// shouldMonitorContainer determines if a container should be monitored. +func (sc *StatsCollector) shouldMonitorContainer(cont types.Container, runID string) bool { + // Check if it has the correct run ID label + if cont.Labels == nil || cont.Labels["hi.run-id"] != runID { + return false + } + + // Check if it's an hs- or ts- container + for _, name := range cont.Names { + containerName := strings.TrimPrefix(name, "/") + if strings.HasPrefix(containerName, "hs-") || strings.HasPrefix(containerName, "ts-") { + return true + } + } + + return false +} + +// startStatsForContainer begins stats collection for a specific container. 
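+// Containers that are already being monitored are skipped; otherwise a dedicated goroutine is started that streams stats for the container from the Docker API.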
+func (sc *StatsCollector) startStatsForContainer(ctx context.Context, containerID, containerName string, verbose bool) { + containerName = strings.TrimPrefix(containerName, "/") + + sc.mutex.Lock() + // Check if we're already monitoring this container + if _, exists := sc.containers[containerID]; exists { + sc.mutex.Unlock() + return + } + + sc.containers[containerID] = &ContainerStats{ + ContainerID: containerID, + ContainerName: containerName, + Stats: make([]StatsSample, 0), + } + sc.mutex.Unlock() + + if verbose { + log.Printf("Starting stats collection for container %s (%s)", containerName, containerID[:12]) + } + + sc.wg.Add(1) + go sc.collectStatsForContainer(ctx, containerID, verbose) +} + +// collectStatsForContainer collects stats for a specific container using Docker API streaming. +func (sc *StatsCollector) collectStatsForContainer(ctx context.Context, containerID string, verbose bool) { + defer sc.wg.Done() + + // Use Docker API streaming stats - much more efficient than CLI + statsResponse, err := sc.client.ContainerStats(ctx, containerID, true) + if err != nil { + if verbose { + log.Printf("Failed to get stats stream for container %s: %v", containerID[:12], err) + } + return + } + defer statsResponse.Body.Close() + + decoder := json.NewDecoder(statsResponse.Body) + var prevStats *container.Stats + + for { + select { + case <-sc.stopChan: + return + case <-ctx.Done(): + return + default: + var stats container.Stats + if err := decoder.Decode(&stats); err != nil { + // EOF is expected when container stops or stream ends + if err.Error() != "EOF" && verbose { + log.Printf("Failed to decode stats for container %s: %v", containerID[:12], err) + } + return + } + + // Calculate CPU percentage (only if we have previous stats) + var cpuPercent float64 + if prevStats != nil { + cpuPercent = calculateCPUPercent(prevStats, &stats) + } + + // Calculate memory usage in MB + memoryMB := float64(stats.MemoryStats.Usage) / (1024 * 1024) + + // Store the sample (skip first sample since CPU calculation needs previous stats) + if prevStats != nil { + // Get container stats reference without holding the main mutex + var containerStats *ContainerStats + var exists bool + + sc.mutex.RLock() + containerStats, exists = sc.containers[containerID] + sc.mutex.RUnlock() + + if exists && containerStats != nil { + containerStats.mutex.Lock() + containerStats.Stats = append(containerStats.Stats, StatsSample{ + Timestamp: time.Now(), + CPUUsage: cpuPercent, + MemoryMB: memoryMB, + }) + containerStats.mutex.Unlock() + } + } + + // Save current stats for next iteration + prevStats = &stats + } + } +} + +// calculateCPUPercent calculates CPU usage percentage from Docker stats. +func calculateCPUPercent(prevStats, stats *container.Stats) float64 { + // CPU calculation based on Docker's implementation + cpuDelta := float64(stats.CPUStats.CPUUsage.TotalUsage) - float64(prevStats.CPUStats.CPUUsage.TotalUsage) + systemDelta := float64(stats.CPUStats.SystemUsage) - float64(prevStats.CPUStats.SystemUsage) + + if systemDelta > 0 && cpuDelta >= 0 { + // Calculate CPU percentage: (container CPU delta / system CPU delta) * number of CPUs * 100 + numCPUs := float64(len(stats.CPUStats.CPUUsage.PercpuUsage)) + if numCPUs == 0 { + // Fallback: if PercpuUsage is not available, assume 1 CPU + numCPUs = 1.0 + } + + return (cpuDelta / systemDelta) * numCPUs * 100.0 + } + + return 0.0 +} + +// ContainerStatsSummary represents summary statistics for a container. 
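+// Memory values are in MB and CPU values in percent, aggregated as min/max/average over all collected samples.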
+type ContainerStatsSummary struct { + ContainerName string + SampleCount int + CPU StatsSummary + Memory StatsSummary +} + +// MemoryViolation represents a container that exceeded the memory limit. +type MemoryViolation struct { + ContainerName string + MaxMemoryMB float64 + LimitMB float64 +} + +// StatsSummary represents min, max, and average for a metric. +type StatsSummary struct { + Min float64 + Max float64 + Average float64 +} + +// GetSummary returns a summary of collected statistics. +func (sc *StatsCollector) GetSummary() []ContainerStatsSummary { + // Take snapshot of container references without holding main lock long + sc.mutex.RLock() + containerRefs := make([]*ContainerStats, 0, len(sc.containers)) + for _, containerStats := range sc.containers { + containerRefs = append(containerRefs, containerStats) + } + sc.mutex.RUnlock() + + summaries := make([]ContainerStatsSummary, 0, len(containerRefs)) + + for _, containerStats := range containerRefs { + containerStats.mutex.RLock() + stats := make([]StatsSample, len(containerStats.Stats)) + copy(stats, containerStats.Stats) + containerName := containerStats.ContainerName + containerStats.mutex.RUnlock() + + if len(stats) == 0 { + continue + } + + summary := ContainerStatsSummary{ + ContainerName: containerName, + SampleCount: len(stats), + } + + // Calculate CPU stats + cpuValues := make([]float64, len(stats)) + memoryValues := make([]float64, len(stats)) + + for i, sample := range stats { + cpuValues[i] = sample.CPUUsage + memoryValues[i] = sample.MemoryMB + } + + summary.CPU = calculateStatsSummary(cpuValues) + summary.Memory = calculateStatsSummary(memoryValues) + + summaries = append(summaries, summary) + } + + // Sort by container name for consistent output + sort.Slice(summaries, func(i, j int) bool { + return summaries[i].ContainerName < summaries[j].ContainerName + }) + + return summaries +} + +// calculateStatsSummary calculates min, max, and average for a slice of values. +func calculateStatsSummary(values []float64) StatsSummary { + if len(values) == 0 { + return StatsSummary{} + } + + min := values[0] + max := values[0] + sum := 0.0 + + for _, value := range values { + if value < min { + min = value + } + if value > max { + max = value + } + sum += value + } + + return StatsSummary{ + Min: min, + Max: max, + Average: sum / float64(len(values)), + } +} + +// PrintSummary prints the statistics summary to the console. +func (sc *StatsCollector) PrintSummary() { + summaries := sc.GetSummary() + + if len(summaries) == 0 { + log.Printf("No container statistics collected") + return + } + + log.Printf("Container Resource Usage Summary:") + log.Printf("================================") + + for _, summary := range summaries { + log.Printf("Container: %s (%d samples)", summary.ContainerName, summary.SampleCount) + log.Printf(" CPU Usage: Min: %6.2f%% Max: %6.2f%% Avg: %6.2f%%", + summary.CPU.Min, summary.CPU.Max, summary.CPU.Average) + log.Printf(" Memory Usage: Min: %6.1f MB Max: %6.1f MB Avg: %6.1f MB", + summary.Memory.Min, summary.Memory.Max, summary.Memory.Average) + log.Printf("") + } +} + +// CheckMemoryLimits checks if any containers exceeded their memory limits. 
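+// A limit of 0 (or less) disables the corresponding check; hs-* containers are compared against hsLimitMB and ts-* containers against tsLimitMB using the maximum observed memory usage.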
+func (sc *StatsCollector) CheckMemoryLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation { + if hsLimitMB <= 0 && tsLimitMB <= 0 { + return nil + } + + summaries := sc.GetSummary() + var violations []MemoryViolation + + for _, summary := range summaries { + var limitMB float64 + if strings.HasPrefix(summary.ContainerName, "hs-") { + limitMB = hsLimitMB + } else if strings.HasPrefix(summary.ContainerName, "ts-") { + limitMB = tsLimitMB + } else { + continue // Skip containers that don't match our patterns + } + + if limitMB > 0 && summary.Memory.Max > limitMB { + violations = append(violations, MemoryViolation{ + ContainerName: summary.ContainerName, + MaxMemoryMB: summary.Memory.Max, + LimitMB: limitMB, + }) + } + } + + return violations +} + +// PrintSummaryAndCheckLimits prints the statistics summary and returns memory violations if any. +func (sc *StatsCollector) PrintSummaryAndCheckLimits(hsLimitMB, tsLimitMB float64) []MemoryViolation { + sc.PrintSummary() + return sc.CheckMemoryLimits(hsLimitMB, tsLimitMB) +} + +// Close closes the stats collector and cleans up resources. +func (sc *StatsCollector) Close() error { + sc.StopCollection() + return sc.client.Close() +} diff --git a/cmd/mapresponses/main.go b/cmd/mapresponses/main.go new file mode 100644 index 00000000..5d7ad07d --- /dev/null +++ b/cmd/mapresponses/main.go @@ -0,0 +1,61 @@ +package main + +import ( + "encoding/json" + "fmt" + "os" + + "github.com/creachadair/command" + "github.com/creachadair/flax" + "github.com/juanfont/headscale/hscontrol/mapper" + "github.com/juanfont/headscale/integration/integrationutil" +) + +type MapConfig struct { + Directory string `flag:"directory,Directory to read map responses from"` +} + +var mapConfig MapConfig + +func main() { + root := command.C{ + Name: "mapresponses", + Help: "MapResponses is a tool to map and compare map responses from a directory", + Commands: []*command.C{ + { + Name: "online", + Help: "Print the expected online map computed from recorded map responses", + Usage: "online [flags]", + SetFlags: command.Flags(flax.MustBind, &mapConfig), + Run: runOnline, + }, + command.HelpCommand(nil), + }, + } + + env := root.NewEnv(nil).MergeFlags(true) + command.RunOrFail(env, os.Args[1:]) +} + +// runOnline reads map responses from a directory and prints the expected online map. +func runOnline(env *command.Env) error { + if mapConfig.Directory == "" { + return fmt.Errorf("directory is required") + } + + resps, err := mapper.ReadMapResponsesFromDirectory(mapConfig.Directory) + if err != nil { + return fmt.Errorf("reading map responses from directory: %w", err) + } + + expected := integrationutil.BuildExpectedOnlineMap(resps) + + out, err := json.MarshalIndent(expected, "", " ") + if err != nil { + return fmt.Errorf("marshaling expected online map: %w", err) + } + + os.Stderr.Write(out) + os.Stderr.Write([]byte("\n")) + return nil +} diff --git a/config-example.yaml b/config-example.yaml index 5105dcd8..dbb08202 100644 --- a/config-example.yaml +++ b/config-example.yaml @@ -18,10 +18,9 @@ server_url: http://127.0.0.1:8080 # listen_addr: 0.0.0.0:8080 listen_addr: 127.0.0.1:8080 -# Address to listen to /metrics, you may want -# to keep this endpoint private to your internal -# network -# +# Address to listen to /metrics and /debug, you may want +# to keep this endpoint private to your internal network +# Use an empty value to disable the metrics listener. metrics_listen_addr: 127.0.0.1:9090 # Address to listen for gRPC.
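Returning briefly to the stats collector: the methods above are its read-out side. A minimal sketch of harness glue that prints the summary and fails a run on violations is shown below. It only uses methods and fields that appear in the diff (`PrintSummaryAndCheckLimits`, `Close`, and the `MemoryViolation` fields); how `cmd/hi` actually wires this up is not shown here, and the limit values are whatever the caller chooses. The sketch assumes it lives in the same package as the collector and imports the standard library `log` package.

```go
// failOnMemoryViolations is illustrative harness glue, not code from cmd/hi.
// It assumes it sits in the same package as the StatsCollector shown above.
func failOnMemoryViolations(sc *StatsCollector, hsLimitMB, tsLimitMB float64) {
	defer sc.Close()

	// Print per-container min/max/avg and collect any limit violations.
	violations := sc.PrintSummaryAndCheckLimits(hsLimitMB, tsLimitMB)
	for _, v := range violations {
		log.Printf("memory limit exceeded: %s peaked at %.1f MB (limit %.1f MB)",
			v.ContainerName, v.MaxMemoryMB, v.LimitMB)
	}
	if len(violations) > 0 {
		log.Fatal("integration run exceeded memory limits")
	}
}
```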
@@ -43,9 +42,9 @@ grpc_allow_insecure: false # The Noise section includes specific configuration for the # TS2021 Noise protocol noise: - # The Noise private key is used to encrypt the - # traffic between headscale and Tailscale clients when - # using the new Noise-based protocol. + # The Noise private key is used to encrypt the traffic between headscale and + # Tailscale clients when using the new Noise-based protocol. A missing key + # will be automatically generated. private_key_path: /var/lib/headscale/noise_private.key # List of IP prefixes to allocate tailaddresses from. @@ -57,9 +56,16 @@ noise: # IPv6: https://github.com/tailscale/tailscale/blob/22ebb25e833264f58d7c3f534a8b166894a89536/net/tsaddr/tsaddr.go#LL81C52-L81C71 # IPv4: https://github.com/tailscale/tailscale/blob/22ebb25e833264f58d7c3f534a8b166894a89536/net/tsaddr/tsaddr.go#L33 # Any other range is NOT supported, and it will cause unexpected issues. -ip_prefixes: - - fd7a:115c:a1e0::/48 - - 100.64.0.0/10 +prefixes: + v4: 100.64.0.0/10 + v6: fd7a:115c:a1e0::/48 + + # Strategy used for allocation of IPs to nodes, available options: + # - sequential (default): assigns the next free IP from the previous given + # IP. A best-effort approach is used and Headscale might leave holes in the + # IP range or fill up existing holes in the IP range. + # - random: assigns the next free IP from a pseudo-random IP generator (crypto/rand). + allocation: sequential # DERP is a relay system that Tailscale uses when a direct # connection cannot be established. @@ -82,18 +88,29 @@ derp: region_code: "headscale" region_name: "Headscale Embedded DERP" + # Only allow clients associated with this server access + verify_clients: true + # Listens over UDP at the configured address for STUN connections - to help with NAT traversal. # When the embedded DERP server is enabled stun_listen_addr MUST be defined. # # For more details on how this works, check this great article: https://tailscale.com/blog/how-tailscale-works/ stun_listen_addr: "0.0.0.0:3478" - # Private key used to encrypt the traffic between headscale DERP - # and Tailscale clients. - # The private key file will be autogenerated if it's missing. - # + # Private key used to encrypt the traffic between headscale DERP and + # Tailscale clients. A missing key will be automatically generated. private_key_path: /var/lib/headscale/derp_server_private.key + # This flag can be used, so the DERP map entry for the embedded DERP server is not written automatically, + # it enables the creation of your very own DERP map entry using a locally available file with the parameter DERP.paths + # If you enable the DERP server and set this to false, it is required to add the DERP server to the DERP map using DERP.paths + automatically_add_embedded_derp_region: true + + # For better connection stability (especially when using an Exit-Node and DNS is not working), + # it is possible to optionally add the public IPv4 and IPv6 address to the Derp-Map using: + ipv4: 198.51.100.1 + ipv6: 2001:db8::1 + # List of externally available DERP maps encoded in JSON urls: - https://controlplane.tailscale.com/derpmap/default @@ -114,7 +131,7 @@ derp: auto_update_enabled: true # How often should we check for DERP updates? - update_frequency: 24h + update_frequency: 3h # Disables the automatic check for headscale updates on startup disable_check_updates: false @@ -122,30 +139,59 @@ disable_check_updates: false # Time before an inactive ephemeral node is deleted?
ephemeral_node_inactivity_timeout: 30m -# Period to check for node updates within the tailnet. A value too low will severely affect -# CPU consumption of Headscale. A value too high (over 60s) will cause problems -# for the nodes, as they won't get updates or keep alive messages frequently enough. -# In case of doubts, do not touch the default 10s. -node_update_check_interval: 10s +database: + # Database type. Available options: sqlite, postgres + # Please note that using Postgres is highly discouraged as it is only supported for legacy reasons. + # All new development, testing and optimisations are done with SQLite in mind. + type: sqlite -# SQLite config -db_type: sqlite3 + # Enable debug mode. This setting requires the log.level to be set to "debug" or "trace". + debug: false -# For production: -db_path: /var/lib/headscale/db.sqlite + # GORM configuration settings. + gorm: + # Enable prepared statements. + prepare_stmt: true -# # Postgres config -# If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank. -# db_type: postgres -# db_host: localhost -# db_port: 5432 -# db_name: headscale -# db_user: foo -# db_pass: bar + # Enable parameterized queries. + parameterized_queries: true -# If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need -# in the 'db_ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. -# db_ssl: false + # Skip logging "record not found" errors. + skip_err_record_not_found: true + + # Threshold for slow queries in milliseconds. + slow_threshold: 1000 + + # SQLite config + sqlite: + path: /var/lib/headscale/db.sqlite + + # Enable WAL mode for SQLite. This is recommended for production environments. + # https://www.sqlite.org/wal.html + write_ahead_log: true + + # Maximum number of WAL file frames before the WAL file is automatically checkpointed. + # https://www.sqlite.org/c3ref/wal_autocheckpoint.html + # Set to 0 to disable automatic checkpointing. + wal_autocheckpoint: 1000 + + # # Postgres config + # Please note that using Postgres is highly discouraged as it is only supported for legacy reasons. + # See database.type for more information. + # postgres: + # # If using a Unix socket to connect to Postgres, set the socket path in the 'host' field and leave 'port' blank. + # host: localhost + # port: 5432 + # name: headscale + # user: foo + # pass: bar + # max_open_conns: 10 + # max_idle_conns: 10 + # conn_max_idle_time_secs: 3600 + + # # If other 'sslmode' is required instead of 'require(true)' and 'disabled(false)', set the 'sslmode' you need + # # in the 'ssl' field. Refers to https://www.postgresql.org/docs/current/libpq-ssl.html Table 34.1. + # ssl: false ### TLS configuration # @@ -170,7 +216,7 @@ tls_letsencrypt_cache_dir: /var/lib/headscale/cache # Type of ACME challenge to use, currently supported types: # HTTP-01 or TLS-ALPN-01 -# See [docs/tls.md](docs/tls.md) for more information +# See: docs/ref/tls.md for more information tls_letsencrypt_challenge_type: HTTP-01 # When HTTP-01 challenge is chosen, letsencrypt must set up a # verification endpoint, and it will be listening on: @@ -182,14 +228,23 @@ tls_cert_path: "" tls_key_path: "" log: - # Output formatting for logs: text or json - format: text + # Valid log levels: panic, fatal, error, warn, info, debug, trace level: info -# Path to a file containg ACL policies. -# ACLs can be defined as YAML or HUJSON. 
-# https://tailscale.com/kb/1018/acls/ -acl_policy_path: "" + # Output formatting for logs: text or json + format: text + +## Policy +# headscale supports Tailscale's ACL policies. +# Please have a look at their KB to better +# understand the concepts: https://tailscale.com/kb/1018/acls/ +policy: + # The mode can be "file" or "database" that defines + # where the ACL policies are stored and read from. + mode: file + # If the mode is set to "file", the path to a + # HuJSON file containing ACL policies. + path: "" ## DNS # @@ -200,122 +255,158 @@ acl_policy_path: "" # - https://tailscale.com/kb/1081/magicdns/ # - https://tailscale.com/blog/2021-09-private-dns-with-magicdns/ # -dns_config: - # Whether to prefer using Headscale provided DNS or use local. +# Please note that for the DNS configuration to have any effect, +# clients must have the `--accept-dns=true` option enabled. This is the +# default for the Tailscale client, so no extra client-side configuration +# is usually required. +# +# Setting _any_ of the configuration and `--accept-dns=true` on the +# clients will integrate with the DNS manager on the client or +# overwrite /etc/resolv.conf. +# https://tailscale.com/kb/1235/resolv-conf +# +# If you want to stop Headscale from managing the DNS configuration, +# all the fields under `dns` should be set to empty values. +dns: + # Whether to use [MagicDNS](https://tailscale.com/kb/1081/magicdns/). + magic_dns: true + + # Defines the base domain to create the hostnames for MagicDNS. + # This domain _must_ be different from the server_url domain. + # `base_domain` must be a FQDN, without the trailing dot. + # The FQDN of the hosts will be + # `hostname.base_domain` (e.g., _myhost.example.com_). + base_domain: example.com + + # Whether to use the local DNS settings of a node or override the local DNS + # settings (default) and force the use of Headscale's DNS configuration. override_local_dns: true # List of DNS servers to expose to clients. nameservers: - - 1.1.1.1 + global: + - 1.1.1.1 + - 1.0.0.1 + - 2606:4700:4700::1111 + - 2606:4700:4700::1001 - # NextDNS (see https://tailscale.com/kb/1218/nextdns/). - # "abc123" is example NextDNS ID, replace with yours. - # - # With metadata sharing: - # nameservers: - # - https://dns.nextdns.io/abc123 - # - # Without metadata sharing: - # nameservers: - # - 2a07:a8c0::ab:c123 - # - 2a07:a8c1::ab:c123 + # NextDNS (see https://tailscale.com/kb/1218/nextdns/). + # "abc123" is example NextDNS ID, replace with yours. + # - https://dns.nextdns.io/abc123 - # Split DNS (see https://tailscale.com/kb/1054/dns/), - # list of search domains and the DNS to query for each one. - # - # restricted_nameservers: - # foo.bar.com: - # - 1.1.1.1 - # darp.headscale.net: - # - 1.1.1.1 - # - 8.8.8.8 + # Split DNS (see https://tailscale.com/kb/1054/dns/), + # a map of domains and which DNS server to use for each. + split: {} + # foo.bar.com: + # - 1.1.1.1 + # darp.headscale.net: + # - 1.1.1.1 + # - 8.8.8.8 - # Search domains to inject. - domains: [] + # Set custom DNS search domains. With MagicDNS enabled, +# your tailnet base_domain is always the first search domain.
+ search_domains: [] # Extra DNS records - # so far only A-records are supported (on the tailscale side) - # See https://github.com/juanfont/headscale/blob/main/docs/dns-records.md#Limitations - # extra_records: + # so far only A and AAAA records are supported (on the tailscale side) + # See: docs/ref/dns.md + extra_records: [] # - name: "grafana.myvpn.example.com" # type: "A" # value: "100.64.0.3" # # # you can also put it in one line # - { name: "prometheus.myvpn.example.com", type: "A", value: "100.64.0.3" } - - # Whether to use [MagicDNS](https://tailscale.com/kb/1081/magicdns/). - # Only works if there is at least a nameserver defined. - magic_dns: true - - # Defines the base domain to create the hostnames for MagicDNS. - # `base_domain` must be a FQDNs, without the trailing dot. - # The FQDN of the hosts will be - # `hostname.user.base_domain` (e.g., _myhost.myuser.example.com_). - base_domain: example.com + # + # Alternatively, extra DNS records can be loaded from a JSON file. + # Headscale processes this file on each change. + # extra_records_path: /var/lib/headscale/extra-records.json # Unix socket used for the CLI to connect without authentication # Note: for production you will want to set this to something like: unix_socket: /var/run/headscale/headscale.sock unix_socket_permission: "0770" -# -# headscale supports experimental OpenID connect support, -# it is still being tested and might have some bugs, please -# help us test it. + # OpenID Connect # oidc: +# # Block startup until the identity provider is available and healthy. # only_start_if_oidc_is_available: true +# +# # OpenID Connect Issuer URL from the identity provider # issuer: "https://your-oidc.issuer.com/path" +# +# # Client ID from the identity provider # client_id: "your-oidc-client-id" +# +# # Client secret generated by the identity provider +# # Note: client_secret and client_secret_path are mutually exclusive. # client_secret: "your-oidc-client-secret" # # Alternatively, set `client_secret_path` to read the secret from the file. # # It resolves environment variables, making integration to systemd's # # `LoadCredential` straightforward: # client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret" -# # client_secret and client_secret_path are mutually exclusive. # -# # The amount of time from a node is authenticated with OpenID until it -# # expires and needs to reauthenticate. +# # The amount of time a node is authenticated with OpenID until it expires +# # and needs to reauthenticate. # # Setting the value to "0" will mean no expiry. # expiry: 180d # # # Use the expiry from the token received from OpenID when the user logged -# # in, this will typically lead to frequent need to reauthenticate and should -# # only been enabled if you know what you are doing. +# # in. This will typically lead to frequent need to reauthenticate and should +# # only be enabled if you know what you are doing. # # Note: enabling this will cause `oidc.expiry` to be ignored. # use_expiry_from_token: false # -# # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query -# # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email". +# # The OIDC scopes to use, defaults to "openid", "profile" and "email". +# # Custom scopes can be configured as needed, be sure to always include the +# # required "openid" scope. 
+# scope: ["openid", "profile", "email"] # -# scope: ["openid", "profile", "email", "custom"] +# # Only verified email addresses are synchronized to the user profile by +# # default. Unverified emails may be allowed in case an identity provider +# # does not send the "email_verified: true" claim or email verification is +# # not required. +# email_verified_required: true +# +# # Provide custom key/value pairs which get sent to the identity provider's +# # authorization endpoint. # extra_params: # domain_hint: example.com # -# # List allowed principal domains and/or users. If an authenticated user's domain is not in this list, the -# # authentication request will be rejected. -# +# # Only accept users whose email domain is part of the allowed_domains list. # allowed_domains: # - example.com -# # Note: Groups from keycloak have a leading '/' -# allowed_groups: -# - /headscale +# +# # Only accept users whose email address is part of the allowed_users list. # allowed_users: # - alice@example.com # -# # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. -# # This will transform `first-name.last-name@example.com` to the user `first-name.last-name` -# # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following -# user: `first-name.last-name.example.com` +# # Only accept users which are members of at least one group in the +# # allowed_groups list. +# allowed_groups: +# - /headscale # -# strip_email_domain: true +# # Optional: PKCE (Proof Key for Code Exchange) configuration +# # PKCE adds an additional layer of security to the OAuth 2.0 authorization code flow +# # by preventing authorization code interception attacks +# # See https://datatracker.ietf.org/doc/html/rfc7636 +# pkce: +# # Enable or disable PKCE support (default: false) +# enabled: false +# +# # PKCE method to use: +# # - plain: Use plain code verifier +# # - S256: Use SHA256 hashed code verifier (default, recommended) +# method: S256 # Logtail configuration -# Logtail is Tailscales logging and auditing infrastructure, it allows the control panel -# to instruct tailscale nodes to log their activity to a remote server. +# Logtail is Tailscales logging and auditing infrastructure, it allows the +# control panel to instruct tailscale nodes to log their activity to a remote +# server. To disable logging on the client side, please refer to: +# https://tailscale.com/kb/1011/log-mesh-traffic#opting-out-of-client-logging logtail: - # Enable logtail for this headscales clients. - # As there is currently no support for overriding the log server in headscale, this is + # Enable logtail for tailscale nodes of this Headscale instance. + # As there is currently no support for overriding the log server in Headscale, this is # disabled by default. Enabling this will make your clients send logs to Tailscale Inc. enabled: false @@ -323,3 +414,23 @@ logtail: # default static port 41641. This option is intended as a workaround for some buggy # firewall devices. See https://tailscale.com/kb/1181/firewalls/ for more information. randomize_client_port: false + +# Taildrop configuration +# Taildrop is the file sharing feature of Tailscale, allowing nodes to send files to each other. +# https://tailscale.com/kb/1106/taildrop/ +taildrop: + # Enable or disable Taildrop for all nodes. + # When enabled, nodes can send files to other nodes owned by the same user. + # Tagged devices and cross-user transfers are not permitted by Tailscale clients. 
+ enabled: true +# Advanced performance tuning parameters. +# The defaults are carefully chosen and should rarely need adjustment. +# Only modify these if you have identified a specific performance issue. +# +# tuning: +# # NodeStore write batching configuration. +# # The NodeStore batches write operations before rebuilding peer relationships, +# # which is computationally expensive. Batching reduces rebuild frequency. +# # +# # node_store_batch_size: 100 +# # node_store_batch_timeout: 500ms diff --git a/derp-example.yaml b/derp-example.yaml index 732c4ba0..ea93427c 100644 --- a/derp-example.yaml +++ b/derp-example.yaml @@ -1,5 +1,6 @@ # If you plan to somehow use headscale, please deploy your own DERP infra: https://tailscale.com/kb/1118/custom-derp-servers/ regions: + 1: null # Disable DERP region with ID 1 900: regionid: 900 regioncode: custom @@ -7,9 +8,9 @@ regions: nodes: - name: 900a regionid: 900 - hostname: myderp.mydomain.no - ipv4: 123.123.123.123 - ipv6: "2604:a880:400:d1::828:b001" + hostname: myderp.example.com + ipv4: 198.51.100.1 + ipv6: 2001:db8::1 stunport: 0 stunonly: false derpport: 0 diff --git a/docs/about/clients.md b/docs/about/clients.md new file mode 100644 index 00000000..7cbe6a1b --- /dev/null +++ b/docs/about/clients.md @@ -0,0 +1,16 @@ +# Client and operating system support + +We aim to support the [**last 10 releases** of the Tailscale client](https://tailscale.com/changelog#client) on all +provided operating systems and platforms. Some platforms might require additional configuration to connect with +headscale. + +| OS | Supports headscale | +| ------- | ----------------------------------------------------------------------------------------------------- | +| Linux | Yes | +| OpenBSD | Yes | +| FreeBSD | Yes | +| Windows | Yes (see [docs](../usage/connect/windows.md) and `/windows` on your headscale for more information) | +| Android | Yes (see [docs](../usage/connect/android.md) for more information) | +| macOS | Yes (see [docs](../usage/connect/apple.md#macos) and `/apple` on your headscale for more information) | +| iOS | Yes (see [docs](../usage/connect/apple.md#ios) and `/apple` on your headscale for more information) | +| tvOS | Yes (see [docs](../usage/connect/apple.md#tvos) and `/apple` on your headscale for more information) | diff --git a/docs/about/contributing.md b/docs/about/contributing.md new file mode 100644 index 00000000..4eeeef13 --- /dev/null +++ b/docs/about/contributing.md @@ -0,0 +1,3 @@ +{% + include-markdown "../../CONTRIBUTING.md" +%} diff --git a/docs/about/faq.md b/docs/about/faq.md new file mode 100644 index 00000000..f1361590 --- /dev/null +++ b/docs/about/faq.md @@ -0,0 +1,177 @@ +# Frequently Asked Questions + +## What is the design goal of headscale? + +Headscale aims to implement a self-hosted, open source alternative to the +[Tailscale](https://tailscale.com/) control server. Headscale's goal is to +provide self-hosters and hobbyists with an open-source server they can use for +their projects and labs. It implements a narrow scope, a _single_ Tailscale +network (tailnet), suitable for a personal use, or a small open-source +organisation. + +## How can I contribute? + +Headscale is "Open Source, acknowledged contribution", this means that any +contribution will have to be discussed with the Maintainers before being submitted. + +Please see [Contributing](contributing.md) for more information. + +## Why is 'acknowledged contribution' the chosen model? 
+ +Both maintainers have full-time jobs and families, and we want to avoid burnout. We also want to avoid frustration from contributors when their PRs are not accepted. + +We are more than happy to exchange emails, or to have dedicated calls before a PR is submitted. + +## When/Why is Feature X going to be implemented? + +We don't know. We might be working on it. If you're interested in contributing, please post a feature request about it. + +Please be aware that there are a number of reasons why we might not accept specific contributions: + +- It is not possible to implement the feature in a way that makes sense in a self-hosted environment. +- Given that we are reverse-engineering Tailscale to satisfy our own curiosity, we might be interested in implementing the feature ourselves. +- You are not sending unit and integration tests with it. + +## Do you support Y method of deploying headscale? + +We currently support deploying headscale using our binaries and the DEB packages. Visit our [installation guide using +official releases](../setup/install/official.md) for more information. + +In addition to that, you may use packages provided by the community or from distributions. Learn more in the +[installation guide using community packages](../setup/install/community.md). + +For convenience, we also [build container images with headscale](../setup/install/container.md). But **please be aware that +we don't officially support deploying headscale using Docker**. On our [Discord server](https://discord.gg/c84AZQhmpx) +we have a "docker-issues" channel where you can ask the community for Docker-specific help. + +## What is the recommended update path? Can I skip multiple versions while updating? + +Please follow the steps outlined in the [upgrade guide](../setup/upgrade.md) to update your existing Headscale +installation. It's best to update from one stable version to the next (e.g. 0.24.0 → 0.25.1 → 0.26.1) in case +you are multiple releases behind. You should always pick the latest available patch release. + +Be sure to check the [changelog](https://github.com/juanfont/headscale/blob/main/CHANGELOG.md) for version-specific +upgrade instructions and breaking changes. + +## Scaling / How many clients does Headscale support? + +It depends. As often stated, Headscale is not enterprise software and our focus +is homelabbers and self-hosters. Of course, we do not prevent people from using +it in a commercial/professional setting and often get questions about scaling. + +Please note that when Headscale is developed, performance is not part of the +consideration as the main audience is considered to be users with a modest +number of devices. We focus on correctness and feature parity with Tailscale +SaaS over time. + +To understand if you might be able to use Headscale for your use case, I will +describe two scenarios in an effort to explain what the central bottleneck +of Headscale is: + +1. An environment with 1000 servers + + - they rarely "move" (change their endpoints) + - new nodes are added rarely + +2. An environment with 80 laptops/phones (end user devices) + + - nodes move often, e.g. switching from home to office + +Headscale calculates a map of all nodes that need to talk to each other; +creating this "world map" requires a lot of CPU time. When an event that +requires changes to this map happens, the whole "world" is recalculated, and a +new "world map" is created for every node in the network.
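To put rough numbers on this: because each change event triggers a fresh map for every node, and each node's map considers every peer, the work grows roughly with nodes × nodes × events. The sketch below is purely illustrative Go, not Headscale's actual map generation code, and the event rates are invented for the example.

```go
package main

import "fmt"

func main() {
	// Rough model: every change event forces a per-node map rebuild, and each
	// rebuild looks at every other node, so work ~ nodes * nodes * events.
	nodes, eventsPerHour := 1000, 5 // scenario 1: mostly static servers
	fmt.Printf("static fleet: ~%d peer visits/hour\n", nodes*nodes*eventsPerHour)

	nodes, eventsPerHour = 80, 600 // scenario 2: roaming laptops/phones
	fmt.Printf("roaming fleet: ~%d peer visits/hour\n", nodes*nodes*eventsPerHour)
}
```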
+ +This means that under certain conditions, Headscale can likely handle 100s +of devices (maybe more), if there is _little to no change_ happening in the +network. For example, in Scenario 1, the process of computing the world map is +extremely demanding due to the size of the network, but when the map has been +created and the nodes are not changing, the Headscale instance will likely +return to very low resource usage until the next time there is an event +requiring a new map. + +In the case of Scenario 2, the process of computing the world map is less +demanding due to the smaller size of the network; however, these nodes +will likely change frequently, which leads to constant resource usage. + +Headscale will start to struggle when the two scenarios overlap, e.g. many nodes +with frequent changes will cause the resource usage to remain constantly high. +In the worst-case scenario, the queue of nodes waiting for their map will grow +to a point where Headscale will never be able to catch up, and nodes will never +learn about the current state of the world. + +We expect that the performance will improve over time as we improve the code +base, but it is not a focus. In general, we will never make the tradeoff to make +things faster at the cost of less maintainable or readable code. We are a small +team and have to optimise for maintainability. + +## Which database should I use? + +We recommend the use of SQLite as the database for headscale: + +- SQLite is simple to set up and easy to use +- It scales well for all of headscale's use cases +- Development and testing happens primarily on SQLite +- PostgreSQL is still supported, but is considered to be in "maintenance mode" + +The headscale project itself does not provide a tool to migrate from PostgreSQL to SQLite. Please have a look at [the +related tools documentation](../ref/integration/tools.md) for migration tooling provided by the community. + +The choice of database has little to no impact on the performance of the server; +see [Scaling / How many clients does Headscale support?](#scaling-how-many-clients-does-headscale-support) for understanding how Headscale spends its resources. + +## Why is my reverse proxy not working with headscale? + +We don't know. We don't use reverse proxies with headscale ourselves, so we don't have any experience with them. We have +[community documentation](../ref/integration/reverse-proxy.md) on how to configure various reverse proxies, and a +dedicated "reverse-proxy-issues" channel on our [Discord server](https://discord.gg/c84AZQhmpx) where you can ask the +community for help. + +## Can I use headscale and tailscale on the same machine? + +Running headscale on a machine that is also in the tailnet can cause problems with subnet routers, traffic relay nodes, and MagicDNS. It might work, but it is not supported. + +## Why do two nodes see each other in their status, even if an ACL allows traffic only in one direction? + +A frequent use case is to allow traffic only from one node to another, but not the other way around. For example, the +workstation of an administrator should be able to connect to all nodes but the nodes themselves shouldn't be able to +connect back to the administrator's node. Why do all nodes see the administrator's workstation in the output of +`tailscale status`? + +This is essentially how Tailscale works. If traffic is allowed to flow in one direction, then both nodes see each other +in their output of `tailscale status`.
Traffic is still filtered according to the ACL, with the exception of `tailscale +ping` which is always allowed in either direction. + +See also <https://tailscale.com/kb/1087/device-visibility>. + +## My policy is stored in the database and Headscale refuses to start due to an invalid policy. How can I recover? + +Headscale checks if the policy is valid during startup and refuses to start if it detects an error. The error message +indicates which part of the policy is invalid. Follow these steps to fix your policy: + +- Dump the policy to a file: `headscale policy get --bypass-grpc-and-access-database-directly > policy.json` +- Edit and fix up `policy.json`. Use the command `headscale policy check --file policy.json` to validate the policy. +- Load the modified policy: `headscale policy set --bypass-grpc-and-access-database-directly --file policy.json` +- Start Headscale as usual. + +!!! warning "Full server configuration required" + + The above commands to get/set the policy require a complete server configuration file including database settings. A + minimal config to [control Headscale via remote CLI](../ref/api.md#grpc) is not sufficient. You may use `headscale + -c /path/to/config.yaml` to specify the path to an alternative configuration file. + +## How can I avoid sending logs to Tailscale Inc? + +A Tailscale client [collects logs about its operation and connection attempts with other +clients](https://tailscale.com/kb/1011/log-mesh-traffic#client-logs) and sends them to a central log service operated by +Tailscale Inc. + +Headscale, by default, instructs clients to disable log submission to the central log service. This configuration is +applied by a client once it has successfully connected to Headscale. See the configuration option `logtail.enabled` in the +[configuration file](../ref/configuration.md) for details. + +Alternatively, logging can also be disabled on the client side. This is independent of Headscale and opting out of +client logging disables log submission early during client startup. The configuration is operating system specific and +is usually achieved by setting the environment variable `TS_NO_LOGS_NO_SUPPORT=true` or by passing the flag +`--no-logs-no-support` to `tailscaled`. See +<https://tailscale.com/kb/1011/log-mesh-traffic#opting-out-of-client-logging> for details. diff --git a/docs/about/features.md b/docs/about/features.md new file mode 100644 index 00000000..83197d64 --- /dev/null +++ b/docs/about/features.md @@ -0,0 +1,38 @@ +# Features + +Headscale aims to implement a self-hosted, open source alternative to the Tailscale control server. Headscale's goal is +to provide self-hosters and hobbyists with an open-source server they can use for their projects and labs. This page
This page +provides on overview of Headscale's feature and compatibility with the Tailscale control server: + +- [x] Full "base" support of Tailscale's features +- [x] Node registration + - [x] Interactive + - [x] Pre authenticated key +- [x] [DNS](../ref/dns.md) + - [x] [MagicDNS](https://tailscale.com/kb/1081/magicdns) + - [x] [Global and restricted nameservers (split DNS)](https://tailscale.com/kb/1054/dns#nameservers) + - [x] [search domains](https://tailscale.com/kb/1054/dns#search-domains) + - [x] [Extra DNS records (Headscale only)](../ref/dns.md#setting-extra-dns-records) +- [x] [Taildrop (File Sharing)](https://tailscale.com/kb/1106/taildrop) +- [x] [Tags](https://tailscale.com/kb/1068/tags) +- [x] [Routes](../ref/routes.md) + - [x] [Subnet routers](../ref/routes.md#subnet-router) + - [x] [Exit nodes](../ref/routes.md#exit-node) +- [x] Dual stack (IPv4 and IPv6) +- [x] Ephemeral nodes +- [x] Embedded [DERP server](../ref/derp.md) +- [x] Access control lists ([GitHub label "policy"](https://github.com/juanfont/headscale/labels/policy%20%F0%9F%93%9D)) + - [x] ACL management via API + - [x] Some [Autogroups](https://tailscale.com/kb/1396/targets#autogroups), currently: `autogroup:internet`, + `autogroup:nonroot`, `autogroup:member`, `autogroup:tagged`, `autogroup:self` + - [x] [Auto approvers](https://tailscale.com/kb/1337/acl-syntax#auto-approvers) for [subnet + routers](../ref/routes.md#automatically-approve-routes-of-a-subnet-router) and [exit + nodes](../ref/routes.md#automatically-approve-an-exit-node-with-auto-approvers) + - [x] [Tailscale SSH](https://tailscale.com/kb/1193/tailscale-ssh) +* [x] [Node registration using Single-Sign-On (OpenID Connect)](../ref/oidc.md) ([GitHub label "OIDC"](https://github.com/juanfont/headscale/labels/OIDC)) + - [x] Basic registration + - [x] Update user profile from identity provider + - [ ] OIDC groups cannot be used in ACLs +- [ ] [Funnel](https://tailscale.com/kb/1223/funnel) ([#1040](https://github.com/juanfont/headscale/issues/1040)) +- [ ] [Serve](https://tailscale.com/kb/1312/serve) ([#1234](https://github.com/juanfont/headscale/issues/1921)) +- [ ] [Network flow logs](https://tailscale.com/kb/1219/network-flow-logs) ([#1687](https://github.com/juanfont/headscale/issues/1687)) diff --git a/docs/about/help.md b/docs/about/help.md new file mode 100644 index 00000000..ec4adf6b --- /dev/null +++ b/docs/about/help.md @@ -0,0 +1,5 @@ +# Getting help + +Join our [Discord server](https://discord.gg/c84AZQhmpx) for announcements and community support. + +Please report bugs via [GitHub issues](https://github.com/juanfont/headscale/issues) diff --git a/docs/about/releases.md b/docs/about/releases.md new file mode 100644 index 00000000..a2d8f17a --- /dev/null +++ b/docs/about/releases.md @@ -0,0 +1,10 @@ +# Releases + +All headscale releases are available on the [GitHub release page](https://github.com/juanfont/headscale/releases). Those +releases are available as binaries for various platforms and architectures, packages for Debian based systems and source +code archives. Container images are available on [Docker Hub](https://hub.docker.com/r/headscale/headscale) and +[GitHub Container Registry](https://github.com/juanfont/headscale/pkgs/container/headscale). + +An Atom/RSS feed of headscale releases is available [here](https://github.com/juanfont/headscale/releases.atom). + +See the "announcements" channel on our [Discord server](https://discord.gg/c84AZQhmpx) for news about headscale. 
diff --git a/docs/about/sponsor.md b/docs/about/sponsor.md new file mode 100644 index 00000000..3fdb8e4b --- /dev/null +++ b/docs/about/sponsor.md @@ -0,0 +1,4 @@ +# Sponsor + +If you like to support the development of headscale, please consider a donation via +[ko-fi.com/headscale](https://ko-fi.com/headscale). Thank you! diff --git a/docs/android-client.md b/docs/android-client.md deleted file mode 100644 index d4f8129c..00000000 --- a/docs/android-client.md +++ /dev/null @@ -1,19 +0,0 @@ -# Connecting an Android client - -## Goal - -This documentation has the goal of showing how a user can use the official Android [Tailscale](https://tailscale.com) client with `headscale`. - -## Installation - -Install the official Tailscale Android client from the [Google Play Store](https://play.google.com/store/apps/details?id=com.tailscale.ipn) or [F-Droid](https://f-droid.org/packages/com.tailscale.ipn/). - -Ensure that the installed version is at least 1.30.0, as that is the first release to support custom URLs. - -## Configuring the headscale URL - -After opening the app, the kebab menu icon (three dots) on the top bar on the right must be repeatedly opened and closed until the _Change server_ option appears in the menu. This is where you can enter your headscale URL. - -A screen recording of this process can be seen in the `tailscale-android` PR which implemented this functionality: <https://github.com/tailscale/tailscale-android/pull/55> - -After saving and restarting the app, selecting the regular _Sign in_ option (non-SSO) should open up the headscale authentication page. diff --git a/docs/assets/favicon.png b/docs/assets/favicon.png new file mode 100644 index 00000000..4989810f Binary files /dev/null and b/docs/assets/favicon.png differ diff --git a/docs/images/headscale-acl-network.png b/docs/assets/images/headscale-acl-network.png similarity index 100% rename from docs/images/headscale-acl-network.png rename to docs/assets/images/headscale-acl-network.png diff --git a/docs/logo/headscale3-dots.pdf b/docs/assets/logo/headscale3-dots.pdf similarity index 100% rename from docs/logo/headscale3-dots.pdf rename to docs/assets/logo/headscale3-dots.pdf diff --git a/docs/logo/headscale3-dots.png b/docs/assets/logo/headscale3-dots.png similarity index 100% rename from docs/logo/headscale3-dots.png rename to docs/assets/logo/headscale3-dots.png diff --git a/docs/logo/headscale3-dots.svg b/docs/assets/logo/headscale3-dots.svg similarity index 97% rename from docs/logo/headscale3-dots.svg rename to docs/assets/logo/headscale3-dots.svg index 6a20973c..f7120395 100644 --- a/docs/logo/headscale3-dots.svg +++ b/docs/assets/logo/headscale3-dots.svg @@ -1 +1 @@ -<svg xmlns="http://www.w3.org/2000/svg" xml:space="preserve" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2" viewBox="0 0 1280 640"><circle cx="141.023" cy="338.36" r="117.472" style="fill:#f8b5cb" transform="matrix(.997276 0 0 1.00556 10.0024 -14.823)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(1.01749 0 0 1 -3.15847 0)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(1.01749 0 0 1 -3.15847 115.914)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(1.01749 0 0 1 148.43 115.914)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(1.01749 0 0 1 148.851 0)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.99196 0 0 1 3.36978 
-10.2458)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.99196 0 0 1 255.633 -10.2458)"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030" transform="matrix(-1 0 0 1 1857.19 0)"/></svg> \ No newline at end of file +<svg xmlns="http://www.w3.org/2000/svg" xml:space="preserve" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2" viewBox="0 0 1280 640"><circle cx="141.023" cy="338.36" r="117.472" style="fill:#f8b5cb" transform="matrix(.997276 0 0 1.00556 10.0024 -14.823)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(1.01749 0 0 1 -3.15847 0)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(1.01749 0 0 1 -3.15847 115.914)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(1.01749 0 0 1 148.43 115.914)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(1.01749 0 0 1 148.851 0)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.99196 0 0 1 3.36978 -10.2458)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.99196 0 0 1 255.633 -10.2458)"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030" transform="matrix(-1 0 0 1 1857.19 0)"/></svg> diff --git a/docs/logo/headscale3_header_stacked_left.pdf b/docs/assets/logo/headscale3_header_stacked_left.pdf similarity index 100% rename from docs/logo/headscale3_header_stacked_left.pdf rename to docs/assets/logo/headscale3_header_stacked_left.pdf diff --git a/docs/logo/headscale3_header_stacked_left.png b/docs/assets/logo/headscale3_header_stacked_left.png similarity index 100% rename from docs/logo/headscale3_header_stacked_left.png rename to docs/assets/logo/headscale3_header_stacked_left.png diff --git a/docs/logo/headscale3_header_stacked_left.svg b/docs/assets/logo/headscale3_header_stacked_left.svg similarity index 99% rename from docs/logo/headscale3_header_stacked_left.svg rename to docs/assets/logo/headscale3_header_stacked_left.svg index d00af00e..0c3702c6 100644 --- a/docs/logo/headscale3_header_stacked_left.svg +++ b/docs/assets/logo/headscale3_header_stacked_left.svg @@ -1 +1 @@ -<svg xmlns="http://www.w3.org/2000/svg" xml:space="preserve" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2" viewBox="0 0 1280 640"><path d="M.08 0v-.736h.068v.3C.203-.509.27-.545.347-.545c.029 0 .055.005.079.015.024.01.045.025.062.045.017.02.031.045.041.075.009.03.014.065.014.105V0H.475v-.289C.475-.352.464-.4.443-.433.422-.466.385-.483.334-.483c-.027 0-.052.006-.075.017C.236-.455.216-.439.2-.419c-.017.02-.029.044-.038.072-.009.028-.014.059-.014.093V0H.08Z" style="fill:#f8b5cb;fill-rule:nonzero" transform="translate(32.92220721 521.8022953) scale(235.3092)"/><path d="M.051-.264c0-.036.007-.071.02-.105.013-.034.031-.064.055-.09.023-.026.052-.047.086-.063.033-.015.071-.023.112-.023.039 0 .076.007.109.021.033.014.062.033.087.058.025.025.044.054.058.088.014.035.021.072.021.113v.005H.121c.001.031.007.059.018.084.01.025.024.047.042.065.018.019.04.033.065.043.025.01.052.015.082.015.026 0 
.049-.003.069-.01.02-.007.038-.016.054-.028C.466-.102.48-.115.492-.13c.011-.015.022-.03.032-.046l.057.03C.556-.097.522-.058.48-.03.437-.001.387.013.328.013.284.013.245.006.21-.01.175-.024.146-.045.123-.07.1-.095.082-.125.07-.159.057-.192.051-.227.051-.264ZM.128-.32h.396C.51-.375.485-.416.449-.441.412-.466.371-.479.325-.479c-.048 0-.089.013-.123.039-.034.026-.059.066-.074.12Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(177.16674681 521.8022953) scale(235.3092)"/><path d="M.051-.267c0-.038.007-.074.021-.108.014-.033.033-.063.058-.088.025-.025.054-.045.087-.06.033-.015.069-.022.108-.022.043 0 .083.009.119.027.035.019.066.047.093.084v-.097h.067V0H.537v-.091C.508-.056.475-.029.44-.013.404.005.365.013.323.013.284.013.248.006.215-.01.182-.024.153-.045.129-.071.104-.096.085-.126.072-.16.058-.193.051-.229.051-.267Zm.279.218c.027 0 .054-.005.079-.015.025-.01.048-.024.068-.043.019-.018.035-.04.047-.067.012-.027.018-.056.018-.089 0-.031-.005-.059-.016-.086C.515-.375.501-.398.482-.417.462-.436.44-.452.415-.463.389-.474.361-.479.331-.479c-.031 0-.059.006-.084.017C.221-.45.199-.434.18-.415c-.019.02-.033.043-.043.068-.011.026-.016.053-.016.082 0 .029.005.056.016.082.011.026.025.049.044.069.019.02.041.036.066.047.025.012.053.018.083.018Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(327.76463481 521.8022953) scale(235.3092)"/><path d="M.051-.267c0-.038.007-.074.021-.108.014-.033.033-.063.058-.088.025-.025.054-.045.087-.06.033-.015.069-.022.108-.022.043 0 .083.009.119.027.035.019.066.047.093.084v-.302h.068V0H.537v-.091C.508-.056.475-.029.44-.013.404.005.365.013.323.013.284.013.248.006.215-.01.182-.024.153-.045.129-.071.104-.096.085-.126.072-.16.058-.193.051-.229.051-.267Zm.279.218c.027 0 .054-.005.079-.015.025-.01.048-.024.068-.043.019-.018.035-.04.047-.067.011-.027.017-.056.017-.089 0-.031-.005-.059-.016-.086C.514-.375.5-.398.481-.417.462-.436.439-.452.414-.463.389-.474.361-.479.331-.479c-.031 0-.059.006-.084.017C.221-.45.199-.434.18-.415c-.019.02-.033.043-.043.068-.011.026-.016.053-.016.082 0 .029.005.056.016.082.011.026.025.049.044.069.019.02.041.036.066.047.025.012.053.018.083.018Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(488.71612761 521.8022953) scale(235.3092)"/><path d="m.034-.062.043-.049c.017.019.035.034.054.044.018.01.037.015.057.015.013 0 .026-.002.038-.007.011-.004.021-.01.031-.018.009-.008.016-.017.021-.028.005-.011.008-.022.008-.035 0-.019-.005-.034-.014-.047C.263-.199.248-.21.229-.221.205-.234.183-.247.162-.259.14-.271.122-.284.107-.298.092-.311.08-.327.071-.344.062-.361.058-.381.058-.404c0-.021.004-.04.012-.058.007-.016.018-.031.031-.044.013-.013.028-.022.046-.029.018-.007.037-.01.057-.01.029 0 .056.006.079.019s.045.031.068.053l-.044.045C.291-.443.275-.456.258-.465.241-.474.221-.479.2-.479c-.022 0-.041.007-.056.02C.128-.445.12-.428.12-.408c0 .019.006.035.017.048.011.013.027.026.048.037.027.015.05.028.071.04.021.013.038.026.052.039.014.013.025.028.032.044.007.016.011.035.011.057 0 .021-.004.041-.011.059-.008.019-.019.036-.033.05-.014.015-.031.026-.05.035C.237.01.215.014.191.014c-.03 0-.059-.006-.086-.02C.077-.019.053-.037.034-.062Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(649.90292961 521.8022953) scale(235.3092)"/><path d="M.051-.266c0-.04.007-.077.022-.111.014-.034.034-.063.059-.089.025-.025.054-.044.089-.058.035-.014.072-.021.113-.021.051 0 .098.01.139.03.041.021.075.049.1.085l-.05.043C.498-.418.47-.441.439-.456.408-.471.372-.479.331-.479c-.03 
0-.058.005-.083.016C.222-.452.2-.436.181-.418.162-.399.148-.376.137-.35c-.011.026-.016.054-.016.084 0 .031.005.06.016.086.011.027.025.049.044.068.019.019.041.034.067.044.025.011.053.016.084.016.077 0 .141-.03.191-.09l.051.04c-.028.036-.062.064-.103.085C.43.004.384.014.332.014.291.014.254.007.219-.008.184-.022.155-.042.13-.067.105-.092.086-.121.072-.156.058-.19.051-.227.051-.266Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(741.20289921 521.8022953) scale(235.3092)"/><path d="M.051-.267c0-.038.007-.074.021-.108.014-.033.033-.063.058-.088.025-.025.054-.045.087-.06.033-.015.069-.022.108-.022.043 0 .083.009.119.027.035.019.066.047.093.084v-.097h.067V0H.537v-.091C.508-.056.475-.029.44-.013.404.005.365.013.323.013.284.013.248.006.215-.01.182-.024.153-.045.129-.071.104-.096.085-.126.072-.16.058-.193.051-.229.051-.267Zm.279.218c.027 0 .054-.005.079-.015.025-.01.048-.024.068-.043.019-.018.035-.04.047-.067.012-.027.018-.056.018-.089 0-.031-.005-.059-.016-.086C.515-.375.501-.398.482-.417.462-.436.44-.452.415-.463.389-.474.361-.479.331-.479c-.031 0-.059.006-.084.017C.221-.45.199-.434.18-.415c-.019.02-.033.043-.043.068-.011.026-.016.053-.016.082 0 .029.005.056.016.082.011.026.025.049.044.069.019.02.041.036.066.047.025.012.053.018.083.018Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(884.27089281 521.8022953) scale(235.3092)"/><path d="M.066-.736h.068V0H.066z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(1045.22238561 521.8022953) scale(235.3092)"/><path d="M.051-.264c0-.036.007-.071.02-.105.013-.034.031-.064.055-.09.023-.026.052-.047.086-.063.033-.015.071-.023.112-.023.039 0 .076.007.109.021.033.014.062.033.087.058.025.025.044.054.058.088.014.035.021.072.021.113v.005H.121c.001.031.007.059.018.084.01.025.024.047.042.065.018.019.04.033.065.043.025.01.052.015.082.015.026 0 .049-.003.069-.01.02-.007.038-.016.054-.028C.466-.102.48-.115.492-.13c.011-.015.022-.03.032-.046l.057.03C.556-.097.522-.058.48-.03.437-.001.387.013.328.013.284.013.245.006.21-.01.175-.024.146-.045.123-.07.1-.095.082-.125.07-.159.057-.192.051-.227.051-.264ZM.128-.32h.396C.51-.375.485-.416.449-.441.412-.466.371-.479.325-.479c-.048 0-.089.013-.123.039-.034.026-.059.066-.074.12Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(1092.28422561 521.8022953) scale(235.3092)"/><circle cx="141.023" cy="338.36" r="117.472" style="fill:#f8b5cb" transform="matrix(.581302 0 0 .58613 40.06479894 12.59842153)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 32.39345942 21.2386)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 32.39345942 88.80371146)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 120.7528627 88.80371146)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 120.99825939 21.2386)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.5782 0 0 .58289 36.19871106 15.26642564)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.5782 0 0 .58289 183.24041937 15.26642564)"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030" transform="translate(34.2345 21.2386) scale(.58289)"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030" transform="matrix(-.58289 0 0 .58289 1116.7719791 21.2386)"/></svg> \ No 
newline at end of file +<svg xmlns="http://www.w3.org/2000/svg" xml:space="preserve" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2" viewBox="0 0 1280 640"><path d="M.08 0v-.736h.068v.3C.203-.509.27-.545.347-.545c.029 0 .055.005.079.015.024.01.045.025.062.045.017.02.031.045.041.075.009.03.014.065.014.105V0H.475v-.289C.475-.352.464-.4.443-.433.422-.466.385-.483.334-.483c-.027 0-.052.006-.075.017C.236-.455.216-.439.2-.419c-.017.02-.029.044-.038.072-.009.028-.014.059-.014.093V0H.08Z" style="fill:#f8b5cb;fill-rule:nonzero" transform="translate(32.92220721 521.8022953) scale(235.3092)"/><path d="M.051-.264c0-.036.007-.071.02-.105.013-.034.031-.064.055-.09.023-.026.052-.047.086-.063.033-.015.071-.023.112-.023.039 0 .076.007.109.021.033.014.062.033.087.058.025.025.044.054.058.088.014.035.021.072.021.113v.005H.121c.001.031.007.059.018.084.01.025.024.047.042.065.018.019.04.033.065.043.025.01.052.015.082.015.026 0 .049-.003.069-.01.02-.007.038-.016.054-.028C.466-.102.48-.115.492-.13c.011-.015.022-.03.032-.046l.057.03C.556-.097.522-.058.48-.03.437-.001.387.013.328.013.284.013.245.006.21-.01.175-.024.146-.045.123-.07.1-.095.082-.125.07-.159.057-.192.051-.227.051-.264ZM.128-.32h.396C.51-.375.485-.416.449-.441.412-.466.371-.479.325-.479c-.048 0-.089.013-.123.039-.034.026-.059.066-.074.12Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(177.16674681 521.8022953) scale(235.3092)"/><path d="M.051-.267c0-.038.007-.074.021-.108.014-.033.033-.063.058-.088.025-.025.054-.045.087-.06.033-.015.069-.022.108-.022.043 0 .083.009.119.027.035.019.066.047.093.084v-.097h.067V0H.537v-.091C.508-.056.475-.029.44-.013.404.005.365.013.323.013.284.013.248.006.215-.01.182-.024.153-.045.129-.071.104-.096.085-.126.072-.16.058-.193.051-.229.051-.267Zm.279.218c.027 0 .054-.005.079-.015.025-.01.048-.024.068-.043.019-.018.035-.04.047-.067.012-.027.018-.056.018-.089 0-.031-.005-.059-.016-.086C.515-.375.501-.398.482-.417.462-.436.44-.452.415-.463.389-.474.361-.479.331-.479c-.031 0-.059.006-.084.017C.221-.45.199-.434.18-.415c-.019.02-.033.043-.043.068-.011.026-.016.053-.016.082 0 .029.005.056.016.082.011.026.025.049.044.069.019.02.041.036.066.047.025.012.053.018.083.018Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(327.76463481 521.8022953) scale(235.3092)"/><path d="M.051-.267c0-.038.007-.074.021-.108.014-.033.033-.063.058-.088.025-.025.054-.045.087-.06.033-.015.069-.022.108-.022.043 0 .083.009.119.027.035.019.066.047.093.084v-.302h.068V0H.537v-.091C.508-.056.475-.029.44-.013.404.005.365.013.323.013.284.013.248.006.215-.01.182-.024.153-.045.129-.071.104-.096.085-.126.072-.16.058-.193.051-.229.051-.267Zm.279.218c.027 0 .054-.005.079-.015.025-.01.048-.024.068-.043.019-.018.035-.04.047-.067.011-.027.017-.056.017-.089 0-.031-.005-.059-.016-.086C.514-.375.5-.398.481-.417.462-.436.439-.452.414-.463.389-.474.361-.479.331-.479c-.031 0-.059.006-.084.017C.221-.45.199-.434.18-.415c-.019.02-.033.043-.043.068-.011.026-.016.053-.016.082 0 .029.005.056.016.082.011.026.025.049.044.069.019.02.041.036.066.047.025.012.053.018.083.018Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(488.71612761 521.8022953) scale(235.3092)"/><path d="m.034-.062.043-.049c.017.019.035.034.054.044.018.01.037.015.057.015.013 0 .026-.002.038-.007.011-.004.021-.01.031-.018.009-.008.016-.017.021-.028.005-.011.008-.022.008-.035 
0-.019-.005-.034-.014-.047C.263-.199.248-.21.229-.221.205-.234.183-.247.162-.259.14-.271.122-.284.107-.298.092-.311.08-.327.071-.344.062-.361.058-.381.058-.404c0-.021.004-.04.012-.058.007-.016.018-.031.031-.044.013-.013.028-.022.046-.029.018-.007.037-.01.057-.01.029 0 .056.006.079.019s.045.031.068.053l-.044.045C.291-.443.275-.456.258-.465.241-.474.221-.479.2-.479c-.022 0-.041.007-.056.02C.128-.445.12-.428.12-.408c0 .019.006.035.017.048.011.013.027.026.048.037.027.015.05.028.071.04.021.013.038.026.052.039.014.013.025.028.032.044.007.016.011.035.011.057 0 .021-.004.041-.011.059-.008.019-.019.036-.033.05-.014.015-.031.026-.05.035C.237.01.215.014.191.014c-.03 0-.059-.006-.086-.02C.077-.019.053-.037.034-.062Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(649.90292961 521.8022953) scale(235.3092)"/><path d="M.051-.266c0-.04.007-.077.022-.111.014-.034.034-.063.059-.089.025-.025.054-.044.089-.058.035-.014.072-.021.113-.021.051 0 .098.01.139.03.041.021.075.049.1.085l-.05.043C.498-.418.47-.441.439-.456.408-.471.372-.479.331-.479c-.03 0-.058.005-.083.016C.222-.452.2-.436.181-.418.162-.399.148-.376.137-.35c-.011.026-.016.054-.016.084 0 .031.005.06.016.086.011.027.025.049.044.068.019.019.041.034.067.044.025.011.053.016.084.016.077 0 .141-.03.191-.09l.051.04c-.028.036-.062.064-.103.085C.43.004.384.014.332.014.291.014.254.007.219-.008.184-.022.155-.042.13-.067.105-.092.086-.121.072-.156.058-.19.051-.227.051-.266Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(741.20289921 521.8022953) scale(235.3092)"/><path d="M.051-.267c0-.038.007-.074.021-.108.014-.033.033-.063.058-.088.025-.025.054-.045.087-.06.033-.015.069-.022.108-.022.043 0 .083.009.119.027.035.019.066.047.093.084v-.097h.067V0H.537v-.091C.508-.056.475-.029.44-.013.404.005.365.013.323.013.284.013.248.006.215-.01.182-.024.153-.045.129-.071.104-.096.085-.126.072-.16.058-.193.051-.229.051-.267Zm.279.218c.027 0 .054-.005.079-.015.025-.01.048-.024.068-.043.019-.018.035-.04.047-.067.012-.027.018-.056.018-.089 0-.031-.005-.059-.016-.086C.515-.375.501-.398.482-.417.462-.436.44-.452.415-.463.389-.474.361-.479.331-.479c-.031 0-.059.006-.084.017C.221-.45.199-.434.18-.415c-.019.02-.033.043-.043.068-.011.026-.016.053-.016.082 0 .029.005.056.016.082.011.026.025.049.044.069.019.02.041.036.066.047.025.012.053.018.083.018Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(884.27089281 521.8022953) scale(235.3092)"/><path d="M.066-.736h.068V0H.066z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(1045.22238561 521.8022953) scale(235.3092)"/><path d="M.051-.264c0-.036.007-.071.02-.105.013-.034.031-.064.055-.09.023-.026.052-.047.086-.063.033-.015.071-.023.112-.023.039 0 .076.007.109.021.033.014.062.033.087.058.025.025.044.054.058.088.014.035.021.072.021.113v.005H.121c.001.031.007.059.018.084.01.025.024.047.042.065.018.019.04.033.065.043.025.01.052.015.082.015.026 0 .049-.003.069-.01.02-.007.038-.016.054-.028C.466-.102.48-.115.492-.13c.011-.015.022-.03.032-.046l.057.03C.556-.097.522-.058.48-.03.437-.001.387.013.328.013.284.013.245.006.21-.01.175-.024.146-.045.123-.07.1-.095.082-.125.07-.159.057-.192.051-.227.051-.264ZM.128-.32h.396C.51-.375.485-.416.449-.441.412-.466.371-.479.325-.479c-.048 0-.089.013-.123.039-.034.026-.059.066-.074.12Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(1092.28422561 521.8022953) scale(235.3092)"/><circle cx="141.023" cy="338.36" r="117.472" style="fill:#f8b5cb" transform="matrix(.581302 0 0 .58613 40.06479894 12.59842153)"/><circle cx="352.014" cy="268.302" r="33.095" 
style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 32.39345942 21.2386)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 32.39345942 88.80371146)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 120.7528627 88.80371146)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 120.99825939 21.2386)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.5782 0 0 .58289 36.19871106 15.26642564)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.5782 0 0 .58289 183.24041937 15.26642564)"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030" transform="translate(34.2345 21.2386) scale(.58289)"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030" transform="matrix(-.58289 0 0 .58289 1116.7719791 21.2386)"/></svg> diff --git a/docs/dns-records.md b/docs/dns-records.md deleted file mode 100644 index c5a07fe9..00000000 --- a/docs/dns-records.md +++ /dev/null @@ -1,90 +0,0 @@ -# Setting custom DNS records - -!!! warning "Community documentation" - - This page is not actively maintained by the headscale authors and is - written by community members. It is _not_ verified by `headscale` developers. - - **It might be outdated and it might miss necessary steps**. - -## Goal - -This documentation has the goal of showing how a user can set custom DNS records with `headscale`s magic dns. -An example use case is to serve apps on the same host via a reverse proxy like NGINX, in this case a Prometheus monitoring stack. This allows to nicely access the service with "http://grafana.myvpn.example.com" instead of the hostname and portnum combination "http://hostname-in-magic-dns.myvpn.example.com:3000". - -## Setup - -### 1. Change the configuration - -1. Change the `config.yaml` to contain the desired records like so: - -```yaml -dns_config: - ... - extra_records: - - name: "prometheus.myvpn.example.com" - type: "A" - value: "100.64.0.3" - - - name: "grafana.myvpn.example.com" - type: "A" - value: "100.64.0.3" - ... -``` - -2. Restart your headscale instance. - -Beware of the limitations listed later on! - -### 2. Verify that the records are set - -You can use a DNS querying tool of your choice on one of your hosts to verify that your newly set records are actually available in MagicDNS, here we used [`dig`](https://man.archlinux.org/man/dig.1.en): - -``` -$ dig grafana.myvpn.example.com - -; <<>> DiG 9.18.10 <<>> grafana.myvpn.example.com -;; global options: +cmd -;; Got answer: -;; ->>HEADER<<- opcode: QUERY, status: NOERROR, id: 44054 -;; flags: qr rd ra; QUERY: 1, ANSWER: 1, AUTHORITY: 0, ADDITIONAL: 1 - -;; OPT PSEUDOSECTION: -; EDNS: version: 0, flags:; udp: 65494 -;; QUESTION SECTION: -;grafana.myvpn.example.com. IN A - -;; ANSWER SECTION: -grafana.myvpn.example.com. 593 IN A 100.64.0.3 - -;; Query time: 0 msec -;; SERVER: 127.0.0.53#53(127.0.0.53) (UDP) -;; WHEN: Sat Dec 31 11:46:55 CET 2022 -;; MSG SIZE rcvd: 66 -``` - -### 3. 
Optional: Setup the reverse proxy - -The motivating example here was to be able to access internal monitoring services on the same host without specifying a port: - -``` -server { - listen 80; - listen [::]:80; - - server_name grafana.myvpn.example.com; - - location / { - proxy_pass http://localhost:3000; - proxy_set_header Host $http_host; - proxy_set_header X-Real-IP $remote_addr; - proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $scheme; - } - -} -``` - -## Limitations - -[Not all types of records are supported](https://github.com/tailscale/tailscale/blob/6edf357b96b28ee1be659a70232c0135b2ffedfd/ipn/ipnlocal/local.go#L2989-L3007), especially no CNAME records. diff --git a/docs/exit-node.md b/docs/exit-node.md deleted file mode 100644 index 898b7811..00000000 --- a/docs/exit-node.md +++ /dev/null @@ -1,49 +0,0 @@ -# Exit Nodes - -## On the node - -Register the node and make it advertise itself as an exit node: - -```console -$ sudo tailscale up --login-server https://my-server.com --advertise-exit-node -``` - -If the node is already registered, it can advertise exit capabilities like this: - -```console -$ sudo tailscale set --advertise-exit-node -``` - -To use a node as an exit node, IP forwarding must be enabled on the node. Check the official [Tailscale documentation](https://tailscale.com/kb/1019/subnets/?tab=linux#enable-ip-forwarding) for how to enable IP fowarding. - -## On the control server - -```console -$ # list nodes -$ headscale routes list -ID | Machine | Prefix | Advertised | Enabled | Primary -1 | | 0.0.0.0/0 | false | false | - -2 | | ::/0 | false | false | - -3 | phobos | 0.0.0.0/0 | true | false | - -4 | phobos | ::/0 | true | false | - -$ # enable routes for phobos -$ headscale routes enable -r 3 -$ headscale routes enable -r 4 -$ # Check node list again. The routes are now enabled. -$ headscale routes list -ID | Machine | Prefix | Advertised | Enabled | Primary -1 | | 0.0.0.0/0 | false | false | - -2 | | ::/0 | false | false | - -3 | phobos | 0.0.0.0/0 | true | true | - -4 | phobos | ::/0 | true | true | - -``` - -## On the client - -The exit node can now be used with: - -```console -$ sudo tailscale set --exit-node phobos -``` - -Check the official [Tailscale documentation](https://tailscale.com/kb/1103/exit-nodes/?q=exit#step-3-use-the-exit-node) for how to do it on your device. diff --git a/docs/faq.md b/docs/faq.md deleted file mode 100644 index 6331c54a..00000000 --- a/docs/faq.md +++ /dev/null @@ -1,53 +0,0 @@ ---- -hide: - - navigation ---- - -# Frequently Asked Questions - -## What is the design goal of headscale? - -`headscale` aims to implement a self-hosted, open source alternative to the [Tailscale](https://tailscale.com/) -control server. -`headscale`'s goal is to provide self-hosters and hobbyists with an open-source -server they can use for their projects and labs. -It implements a narrow scope, a _single_ Tailnet, suitable for a personal use, or a small -open-source organisation. - -## How can I contribute? - -Headscale is "Open Source, acknowledged contribution", this means that any -contribution will have to be discussed with the Maintainers before being submitted. - -Headscale is open to code contributions for bug fixes without discussion. - -If you find mistakes in the documentation, please also submit a fix to the documentation. - -## Why is 'acknowledged contribution' the chosen model? - -Both maintainers have full-time jobs and families, and we want to avoid burnout. 
We also want to avoid frustration from contributors when their PRs are not accepted. - -We are more than happy to exchange emails, or to have dedicated calls before a PR is submitted. - -## When/Why is Feature X going to be implemented? - -We don't know. We might be working on it. If you want to help, please send us a PR. - -Please be aware that there are a number of reasons why we might not accept specific contributions: - -- It is not possible to implement the feature in a way that makes sense in a self-hosted environment. -- Given that we are reverse-engineering Tailscale to satify our own curiosity, we might be interested in implementing the feature ourselves. -- You are not sending unit and integration tests with it. - -## Do you support Y method of deploying Headscale? - -We currently support deploying `headscale` using our binaries and the DEB packages. Both can be found in the -[GitHub releases page](https://github.com/juanfont/headscale/releases). - -In addition to that, there are semi-official RPM packages by the Fedora infra team https://copr.fedorainfracloud.org/coprs/jonathanspw/headscale/ - -For convenience, we also build Docker images with `headscale`. But **please be aware that we don't officially support deploying `headscale` using Docker**. We have a [Discord channel](https://discord.com/channels/896711691637780480/1070619770942148618) where you can ask for Docker-specific help to the community. - -## Why is my reverse proxy not working with Headscale? - -We don't know. We don't use reverse proxies with `headscale` ourselves, so we don't have any experience with them. We have [community documentation](https://headscale.net/reverse-proxy/) on how to configure various reverse proxies, and a dedicated [Discord channel](https://discord.com/channels/896711691637780480/1070619818346164324) where you can ask for help to the community. diff --git a/docs/glossary.md b/docs/glossary.md deleted file mode 100644 index f42941a6..00000000 --- a/docs/glossary.md +++ /dev/null @@ -1,6 +0,0 @@ -# Glossary - -| Term | Description | -| --------- | ------------------------------------------------------------------------------------------------------------------------------------------- | -| Machine | A machine is a single entity connected to `headscale`, typically an installation of Tailscale. Also known as **Node** | -| Namespace | A namespace was a logical grouping of machines "owned" by the same entity, in Tailscale, this is typically a User (This is now called user) | diff --git a/docs/iOS-client.md b/docs/iOS-client.md deleted file mode 100644 index 761dfcf0..00000000 --- a/docs/iOS-client.md +++ /dev/null @@ -1,30 +0,0 @@ -# Connecting an iOS client - -## Goal - -This documentation has the goal of showing how a user can use the official iOS [Tailscale](https://tailscale.com) client with `headscale`. - -## Installation - -Install the official Tailscale iOS client from the [App Store](https://apps.apple.com/app/tailscale/id1470499037). - -Ensure that the installed version is at least 1.38.1, as that is the first release to support alternate control servers. - -## Configuring the headscale URL - -!!! info "Apple devices" - - An endpoint with information on how to connect your Apple devices - (currently macOS only) is available at `/apple` on your running instance. - -Ensure that the tailscale app is logged out before proceeding. - -Go to iOS settings, scroll down past game center and tv provider to the tailscale app and select it. 
The headscale URL can be entered into the _"ALTERNATE COORDINATION SERVER URL"_ box. - -> **Note** -> -> If the app was previously logged into tailscale, toggle on the _Reset Keychain_ switch. - -Restart the app by closing it from the iOS app switcher, open the app and select the regular _Sign in_ option (non-SSO), and it should open up to the headscale authentication page. - -Enter your credentials and log in. Headscale should now be working on your iOS device. diff --git a/docs/images/windows-registry.png b/docs/images/windows-registry.png deleted file mode 100644 index 1324ca6c..00000000 Binary files a/docs/images/windows-registry.png and /dev/null differ diff --git a/docs/index.md b/docs/index.md index d13339d8..890855b9 100644 --- a/docs/index.md +++ b/docs/index.md @@ -4,40 +4,34 @@ hide: - toc --- -# headscale +# Welcome to headscale -`headscale` is an open source, self-hosted implementation of the Tailscale control server. +Headscale is an open source, self-hosted implementation of the Tailscale control server. -This page contains the documentation for the latest version of headscale. Please also check our [FAQ](/faq/). +This page contains the documentation for the latest version of headscale. Please also check our [FAQ](./about/faq.md). -Join our [Discord](https://discord.gg/c84AZQhmpx) server for a chat and community support. +Join our [Discord server](https://discord.gg/c84AZQhmpx) for a chat and community support. ## Design goal -Headscale aims to implement a self-hosted, open source alternative to the Tailscale -control server. -Headscale's goal is to provide self-hosters and hobbyists with an open-source -server they can use for their projects and labs. -It implements a narrower scope, a single Tailnet, suitable for a personal use, or a small -open-source organisation. +Headscale aims to implement a self-hosted, open source alternative to the +[Tailscale](https://tailscale.com/) control server. Headscale's goal is to +provide self-hosters and hobbyists with an open-source server they can use for +their projects and labs. It implements a narrow scope, a _single_ Tailscale +network (tailnet), suitable for a personal use, or a small open-source +organisation. ## Supporting headscale -If you like `headscale` and find it useful, there is a sponsorship and donation -buttons available in the repo. +Please see [Sponsor](about/sponsor.md) for more information. ## Contributing Headscale is "Open Source, acknowledged contribution", this means that any contribution will have to be discussed with the Maintainers before being submitted. -This model has been chosen to reduce the risk of burnout by limiting the -maintenance overhead of reviewing and validating third-party code. - -Headscale is open to code contributions for bug fixes without discussion. - -If you find mistakes in the documentation, please submit a fix to the documentation. +Please see [Contributing](about/contributing.md) for more information. ## About -`headscale` is maintained by [Kristoffer Dalby](https://kradalby.no/) and [Juan Font](https://font.eu). +Headscale is maintained by [Kristoffer Dalby](https://kradalby.no/) and [Juan Font](https://font.eu). diff --git a/docs/oidc.md b/docs/oidc.md deleted file mode 100644 index 189d7cd7..00000000 --- a/docs/oidc.md +++ /dev/null @@ -1,172 +0,0 @@ -# Configuring Headscale to use OIDC authentication - -In order to authenticate users through a centralized solution one must enable the OIDC integration. 
- -Known limitations: - -- No dynamic ACL support -- OIDC groups cannot be used in ACLs - -## Basic configuration - -In your `config.yaml`, customize this to your liking: - -```yaml -oidc: - # Block further startup until the OIDC provider is healthy and available - only_start_if_oidc_is_available: true - # Specified by your OIDC provider - issuer: "https://your-oidc.issuer.com/path" - # Specified/generated by your OIDC provider - client_id: "your-oidc-client-id" - client_secret: "your-oidc-client-secret" - # alternatively, set `client_secret_path` to read the secret from the file. - # It resolves environment variables, making integration to systemd's - # `LoadCredential` straightforward: - #client_secret_path: "${CREDENTIALS_DIRECTORY}/oidc_client_secret" - - # Customize the scopes used in the OIDC flow, defaults to "openid", "profile" and "email" and add custom query - # parameters to the Authorize Endpoint request. Scopes default to "openid", "profile" and "email". - scope: ["openid", "profile", "email", "custom"] - # Optional: Passed on to the browser login request – used to tweak behaviour for the OIDC provider - extra_params: - domain_hint: example.com - - # Optional: List allowed principal domains and/or users. If an authenticated user's domain is not in this list, - # the authentication request will be rejected. - allowed_domains: - - example.com - # Optional. Note that groups from Keycloak have a leading '/'. - allowed_groups: - - /headscale - # Optional. - allowed_users: - - alice@example.com - - # If `strip_email_domain` is set to `true`, the domain part of the username email address will be removed. - # This will transform `first-name.last-name@example.com` to the user `first-name.last-name` - # If `strip_email_domain` is set to `false` the domain part will NOT be removed resulting to the following - # user: `first-name.last-name.example.com` - strip_email_domain: true -``` - -## Azure AD example - -In order to integrate Headscale with Azure Active Directory, we'll need to provision an App Registration with the correct scopes and redirect URI. 
Here with Terraform: - -```hcl -resource "azuread_application" "headscale" { - display_name = "Headscale" - - sign_in_audience = "AzureADMyOrg" - fallback_public_client_enabled = false - - required_resource_access { - // Microsoft Graph - resource_app_id = "00000003-0000-0000-c000-000000000000" - - resource_access { - // scope: profile - id = "14dad69e-099b-42c9-810b-d002981feec1" - type = "Scope" - } - resource_access { - // scope: openid - id = "37f7f235-527c-4136-accd-4a02d197296e" - type = "Scope" - } - resource_access { - // scope: email - id = "64a6cdd6-aab1-4aaf-94b8-3cc8405e90d0" - type = "Scope" - } - } - web { - # Points at your running Headscale instance - redirect_uris = ["https://headscale.example.com/oidc/callback"] - - implicit_grant { - access_token_issuance_enabled = false - id_token_issuance_enabled = true - } - } - - group_membership_claims = ["SecurityGroup"] - optional_claims { - # Expose group memberships - id_token { - name = "groups" - } - } -} - -resource "azuread_application_password" "headscale-application-secret" { - display_name = "Headscale Server" - application_object_id = azuread_application.headscale.object_id -} - -resource "azuread_service_principal" "headscale" { - application_id = azuread_application.headscale.application_id -} - -resource "azuread_service_principal_password" "headscale" { - service_principal_id = azuread_service_principal.headscale.id - end_date_relative = "44640h" -} - -output "headscale_client_id" { - value = azuread_application.headscale.application_id -} - -output "headscale_client_secret" { - value = azuread_application_password.headscale-application-secret.value -} -``` - -And in your Headscale `config.yaml`: - -```yaml -oidc: - issuer: "https://login.microsoftonline.com/<tenant-UUID>/v2.0" - client_id: "<client-id-from-terraform>" - client_secret: "<client-secret-from-terraform>" - - # Optional: add "groups" - scope: ["openid", "profile", "email"] - extra_params: - # Use your own domain, associated with Azure AD - domain_hint: example.com - # Optional: Force the Azure AD account picker - prompt: select_account -``` - -## Google OAuth Example - -In order to integrate Headscale with Google, you'll need to have a [Google Cloud Console](https://console.cloud.google.com) account. - -Google OAuth has a [verification process](https://support.google.com/cloud/answer/9110914?hl=en) if you need to have users authenticate who are outside of your domain. If you only need to authenticate users from your domain name (ie `@example.com`), you don't need to go through the verification process. - -However if you don't have a domain, or need to add users outside of your domain, you can manually add emails via Google Console. - -### Steps - -1. Go to [Google Console](https://console.cloud.google.com) and login or create an account if you don't have one. -2. Create a project (if you don't already have one). -3. On the left hand menu, go to `APIs and services` -> `Credentials` -4. Click `Create Credentials` -> `OAuth client ID` -5. Under `Application Type`, choose `Web Application` -6. For `Name`, enter whatever you like -7. Under `Authorised redirect URIs`, use `https://example.com/oidc/callback`, replacing example.com with your Headscale URL. -8. Click `Save` at the bottom of the form -9. Take note of the `Client ID` and `Client secret`, you can also download it for reference if you need it. -10. 
Edit your headscale config, under `oidc`, filling in your `client_id` and `client_secret`: - -```yaml -oidc: - issuer: "https://accounts.google.com" - client_id: "" - client_secret: "" - scope: ["openid", "profile", "email"] -``` - -You can also use `allowed_domains` and `allowed_users` to restrict the users who can authenticate. diff --git a/docs/packaging/README.md b/docs/packaging/README.md deleted file mode 100644 index c3a80893..00000000 --- a/docs/packaging/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Packaging - -We use [nFPM](https://nfpm.goreleaser.com/) for making `.deb`, `.rpm` and `.apk`. - -This folder contains files we need to package with these releases. diff --git a/docs/packaging/postinstall.sh b/docs/packaging/postinstall.sh deleted file mode 100644 index 7d934a9a..00000000 --- a/docs/packaging/postinstall.sh +++ /dev/null @@ -1,86 +0,0 @@ -#!/bin/sh -# Determine OS platform -# shellcheck source=/dev/null -. /etc/os-release - -HEADSCALE_EXE="/usr/bin/headscale" -BSD_HIER="" -HEADSCALE_RUN_DIR="/var/run/headscale" -HEADSCALE_USER="headscale" -HEADSCALE_GROUP="headscale" - -ensure_sudo() { - if [ "$(id -u)" = "0" ]; then - echo "Sudo permissions detected" - else - echo "No sudo permission detected, please run as sudo" - exit 1 - fi -} - -ensure_headscale_path() { - if [ ! -f "$HEADSCALE_EXE" ]; then - echo "headscale not in default path, exiting..." - exit 1 - fi - - printf "Found headscale %s\n" "$HEADSCALE_EXE" -} - -create_headscale_user() { - printf "PostInstall: Adding headscale user %s\n" "$HEADSCALE_USER" - useradd -s /bin/sh -c "headscale default user" headscale -} - -create_headscale_group() { - if command -V systemctl >/dev/null 2>&1; then - printf "PostInstall: Adding headscale group %s\n" "$HEADSCALE_GROUP" - groupadd "$HEADSCALE_GROUP" - - printf "PostInstall: Adding headscale user %s to group %s\n" "$HEADSCALE_USER" "$HEADSCALE_GROUP" - usermod -a -G "$HEADSCALE_GROUP" "$HEADSCALE_USER" - fi - - if [ "$ID" = "alpine" ]; then - printf "PostInstall: Adding headscale group %s\n" "$HEADSCALE_GROUP" - addgroup "$HEADSCALE_GROUP" - - printf "PostInstall: Adding headscale user %s to group %s\n" "$HEADSCALE_USER" "$HEADSCALE_GROUP" - addgroup "$HEADSCALE_USER" "$HEADSCALE_GROUP" - fi -} - -create_run_dir() { - printf "PostInstall: Creating headscale run directory \n" - mkdir -p "$HEADSCALE_RUN_DIR" - - printf "PostInstall: Modifying group ownership of headscale run directory \n" - chown "$HEADSCALE_USER":"$HEADSCALE_GROUP" "$HEADSCALE_RUN_DIR" -} - -summary() { - echo "----------------------------------------------------------------------" - echo " headscale package has been successfully installed." - echo "" - echo " Please follow the next steps to start the software:" - echo "" - echo " sudo systemctl enable headscale" - echo " sudo systemctl start headscale" - echo "" - echo " Configuration settings can be adjusted here:" - echo " ${BSD_HIER}/etc/headscale/config.yaml" - echo "" - echo "----------------------------------------------------------------------" -} - -# -# Main body of the script -# -{ - ensure_sudo - ensure_headscale_path - create_headscale_user - create_headscale_group - create_run_dir - summary -} diff --git a/docs/packaging/postremove.sh b/docs/packaging/postremove.sh deleted file mode 100644 index ed480bbf..00000000 --- a/docs/packaging/postremove.sh +++ /dev/null @@ -1,15 +0,0 @@ -#!/bin/sh -# Determine OS platform -# shellcheck source=/dev/null -. 
/etc/os-release - -if command -V systemctl >/dev/null 2>&1; then - echo "Stop and disable headscale service" - systemctl stop headscale >/dev/null 2>&1 || true - systemctl disable headscale >/dev/null 2>&1 || true - echo "Running daemon-reload" - systemctl daemon-reload || true -fi - -echo "Removing run directory" -rm -rf "/var/run/headscale.sock" diff --git a/docs/proposals/001-acls.md b/docs/proposals/001-acls.md deleted file mode 100644 index 8a02e836..00000000 --- a/docs/proposals/001-acls.md +++ /dev/null @@ -1,362 +0,0 @@ -# ACLs - -A key component of tailscale is the notion of Tailnet. This notion is hidden -but the implications that it have on how to use tailscale are not. - -For tailscale an [tailnet](https://tailscale.com/kb/1136/tailnet/) is the -following: - -> For personal users, you are a tailnet of many devices and one person. Each -> device gets a private Tailscale IP address in the CGNAT range and every -> device can talk directly to every other device, wherever they are on the -> internet. -> -> For businesses and organizations, a tailnet is many devices and many users. -> It can be based on your Microsoft Active Directory, your Google Workspace, a -> GitHub organization, Okta tenancy, or other identity provider namespace. All -> of the devices and users in your tailnet can be seen by the tailnet -> administrators in the Tailscale admin console. There you can apply -> tailnet-wide configuration, such as ACLs that affect visibility of devices -> inside your tailnet, DNS settings, and more. - -## Current implementation and issues - -Currently in headscale, the namespaces are used both as tailnet and users. The -issue is that if we want to use the ACL's we can't use both at the same time. - -Tailnet's cannot communicate with each others. So we can't have an ACL that -authorize tailnet (namespace) A to talk to tailnet (namespace) B. - -We also can't write ACLs based on the users (namespaces in headscale) since all -devices belong to the same user. - -With the current implementation the only ACL that we can user is to associate -each headscale IP to a host manually then write the ACLs according to this -manual mapping. - -```json -{ - "hosts": { - "host1": "100.64.0.1", - "server": "100.64.0.2" - }, - "acls": [ - { "action": "accept", "users": ["host1"], "ports": ["host2:80,443"] } - ] -} -``` - -While this works, it requires a lot of manual editing on the configuration and -to keep track of all devices IP address. - -## Proposition for a next implementation - -In order to ease the use of ACL's we need to split the tailnet and users -notion. - -A solution could be to consider a headscale server (in it's entirety) as a -tailnet. - -For personal users the default behavior could either allow all communications -between all namespaces (like tailscale) or dissallow all communications between -namespaces (current behavior). - -For businesses and organisations, viewing a headscale instance a single tailnet -would allow users (namespace) to talk to each other with the ACLs. As described -in tailscale's documentation [[1]], a server should be tagged and personnal -devices should be tied to a user. Translated in headscale's terms each user can -have multiple devices and all those devices should be in the same namespace. -The servers should be tagged and used as such. - -This implementation would render useless the sharing feature that is currently -implemented since an ACL could do the same. 
Simplifying to only one user -interface to do one thing is easier and less confusing for the users. - -To better suit the ACLs in this proposition, it's advised to consider that each -namespaces belong to one person. This person can have multiple devices, they -will all be considered as the same user in the ACLs. OIDC feature wouldn't need -to map people to namespace, just create a namespace if the person isn't -registered yet. - -As a sidenote, users would like to write ACLs as YAML. We should offer users -the ability to rules in either format (HuJSON or YAML). - -[1]: https://tailscale.com/kb/1068/acl-tags/ - -## Example - -Let's build an example use case for a small business (It may be the place where -ACL's are the most useful). - -We have a small company with a boss, an admin, two developper and an intern. - -The boss should have access to all servers but not to the users hosts. Admin -should also have access to all hosts except that their permissions should be -limited to maintaining the hosts (for example purposes). The developers can do -anything they want on dev hosts, but only watch on productions hosts. Intern -can only interact with the development servers. - -Each user have at least a device connected to the network and we have some -servers. - -- database.prod -- database.dev -- app-server1.prod -- app-server1.dev -- billing.internal - -### Current headscale implementation - -Let's create some namespaces - -```bash -headscale namespaces create prod -headscale namespaces create dev -headscale namespaces create internal -headscale namespaces create users - -headscale nodes register -n users boss-computer -headscale nodes register -n users admin1-computer -headscale nodes register -n users dev1-computer -headscale nodes register -n users dev1-phone -headscale nodes register -n users dev2-computer -headscale nodes register -n users intern1-computer - -headscale nodes register -n prod database -headscale nodes register -n prod app-server1 - -headscale nodes register -n dev database -headscale nodes register -n dev app-server1 - -headscale nodes register -n internal billing - -headscale nodes list -ID | Name | Namespace | IP address -1 | boss-computer | users | 100.64.0.1 -2 | admin1-computer | users | 100.64.0.2 -3 | dev1-computer | users | 100.64.0.3 -4 | dev1-phone | users | 100.64.0.4 -5 | dev2-computer | users | 100.64.0.5 -6 | intern1-computer | users | 100.64.0.6 -7 | database | prod | 100.64.0.7 -8 | app-server1 | prod | 100.64.0.8 -9 | database | dev | 100.64.0.9 -10 | app-server1 | dev | 100.64.0.10 -11 | internal | internal | 100.64.0.11 -``` - -In order to only allow the communications related to our description above we -need to add the following ACLs - -```json -{ - "hosts": { - "boss-computer": "100.64.0.1", - "admin1-computer": "100.64.0.2", - "dev1-computer": "100.64.0.3", - "dev1-phone": "100.64.0.4", - "dev2-computer": "100.64.0.5", - "intern1-computer": "100.64.0.6", - "prod-app-server1": "100.64.0.8" - }, - "groups": { - "group:dev": ["dev1-computer", "dev1-phone", "dev2-computer"], - "group:admin": ["admin1-computer"], - "group:boss": ["boss-computer"], - "group:intern": ["intern1-computer"] - }, - "acls": [ - // boss have access to all servers but no users hosts - { - "action": "accept", - "users": ["group:boss"], - "ports": ["prod:*", "dev:*", "internal:*"] - }, - - // admin have access to adminstration port (lets only consider port 22 here) - { - "action": "accept", - "users": ["group:admin"], - "ports": ["prod:22", "dev:22", "internal:22"] - }, - - // dev 
can do anything on dev servers and check access on prod servers - { - "action": "accept", - "users": ["group:dev"], - "ports": ["dev:*", "prod-app-server1:80,443"] - }, - - // interns only have access to port 80 and 443 on dev servers (lame internship) - { "action": "accept", "users": ["group:intern"], "ports": ["dev:80,443"] }, - - // users can access their own devices - { - "action": "accept", - "users": ["dev1-computer"], - "ports": ["dev1-phone:*"] - }, - { - "action": "accept", - "users": ["dev1-phone"], - "ports": ["dev1-computer:*"] - }, - - // internal namespace communications should still be allowed within the namespace - { "action": "accept", "users": ["dev"], "ports": ["dev:*"] }, - { "action": "accept", "users": ["prod"], "ports": ["prod:*"] }, - { "action": "accept", "users": ["internal"], "ports": ["internal:*"] } - ] -} -``` - -Since communications between namespace isn't possible we also have to share the -devices between the namespaces. - -```bash - -// add boss host to prod, dev and internal network -headscale nodes share -i 1 -n prod -headscale nodes share -i 1 -n dev -headscale nodes share -i 1 -n internal - -// add admin computer to prod, dev and internal network -headscale nodes share -i 2 -n prod -headscale nodes share -i 2 -n dev -headscale nodes share -i 2 -n internal - -// add all dev to prod and dev network -headscale nodes share -i 3 -n dev -headscale nodes share -i 4 -n dev -headscale nodes share -i 3 -n prod -headscale nodes share -i 4 -n prod -headscale nodes share -i 5 -n dev -headscale nodes share -i 5 -n prod - -headscale nodes share -i 6 -n dev -``` - -This fake network have not been tested but it should work. Operating it could -be quite tedious if the company grows. Each time a new user join we have to add -it to a group, and share it to the correct namespaces. If the user want -multiple devices we have to allow communication to each of them one by one. If -business conduct a change in the organisations we may have to rewrite all acls -and reorganise all namespaces. - -If we add servers in production we should also update the ACLs to allow dev -access to certain category of them (only app servers for example). - -### example based on the proposition in this document - -Let's create the namespaces - -```bash -headscale namespaces create boss -headscale namespaces create admin1 -headscale namespaces create dev1 -headscale namespaces create dev2 -headscale namespaces create intern1 -``` - -We don't need to create namespaces for the servers because the servers will be -tagged. When registering the servers we will need to add the flag -`--advertised-tags=tag:<tag1>,tag:<tag2>`, and the user (namespace) that is -registering the server should be allowed to do it. Since anyone can add tags to -a server they can register, the check of the tags is done on headscale server -and only valid tags are applied. A tag is valid if the namespace that is -registering it is allowed to do it. 
- -Here are the ACL's to implement the same permissions as above: - -```json -{ - // groups are simpler and only list the namespaces name - "groups": { - "group:boss": ["boss"], - "group:dev": ["dev1", "dev2"], - "group:admin": ["admin1"], - "group:intern": ["intern1"] - }, - "tagOwners": { - // the administrators can add servers in production - "tag:prod-databases": ["group:admin"], - "tag:prod-app-servers": ["group:admin"], - - // the boss can tag any server as internal - "tag:internal": ["group:boss"], - - // dev can add servers for dev purposes as well as admins - "tag:dev-databases": ["group:admin", "group:dev"], - "tag:dev-app-servers": ["group:admin", "group:dev"] - - // interns cannot add servers - }, - "acls": [ - // boss have access to all servers - { - "action": "accept", - "users": ["group:boss"], - "ports": [ - "tag:prod-databases:*", - "tag:prod-app-servers:*", - "tag:internal:*", - "tag:dev-databases:*", - "tag:dev-app-servers:*" - ] - }, - - // admin have only access to administrative ports of the servers - { - "action": "accept", - "users": ["group:admin"], - "ports": [ - "tag:prod-databases:22", - "tag:prod-app-servers:22", - "tag:internal:22", - "tag:dev-databases:22", - "tag:dev-app-servers:22" - ] - }, - - { - "action": "accept", - "users": ["group:dev"], - "ports": [ - "tag:dev-databases:*", - "tag:dev-app-servers:*", - "tag:prod-app-servers:80,443" - ] - }, - - // servers should be able to talk to database. Database should not be able to initiate connections to server - { - "action": "accept", - "users": ["tag:dev-app-servers"], - "ports": ["tag:dev-databases:5432"] - }, - { - "action": "accept", - "users": ["tag:prod-app-servers"], - "ports": ["tag:prod-databases:5432"] - }, - - // interns have access to dev-app-servers only in reading mode - { - "action": "accept", - "users": ["group:intern"], - "ports": ["tag:dev-app-servers:80,443"] - }, - - // we still have to allow internal namespaces communications since nothing guarantees that each user have their own namespaces. This could be talked over. - { "action": "accept", "users": ["boss"], "ports": ["boss:*"] }, - { "action": "accept", "users": ["dev1"], "ports": ["dev1:*"] }, - { "action": "accept", "users": ["dev2"], "ports": ["dev2:*"] }, - { "action": "accept", "users": ["admin1"], "ports": ["admin1:*"] }, - { "action": "accept", "users": ["intern1"], "ports": ["intern1:*"] } - ] -} -``` - -With this implementation, the sharing step is not necessary. Maintenance cost -of the ACL file is lower and less tedious (no need to map hostname and IP's -into it). diff --git a/docs/proposals/002-better-routing.md b/docs/proposals/002-better-routing.md deleted file mode 100644 index c56a38ff..00000000 --- a/docs/proposals/002-better-routing.md +++ /dev/null @@ -1,48 +0,0 @@ -# Better route management - -As of today, route management in Headscale is very basic and does not allow for much flexibility, including implementing subnet HA, 4via6 or more advanced features. We also have a number of bugs (e.g., routes exposed by ephemeral nodes) - -This proposal aims to improve the route management. - -## Current situation - -Routes advertised by the nodes are read from the Hostinfo struct. If approved from the the CLI or via autoApprovers, the route is added to the EnabledRoutes field in `Machine`. - -This means that the advertised routes are not persisted in the database, as Hostinfo is always replaced. In the same way, EnabledRoutes can get out of sync with the actual routes in the node. 
- -In case of colliding routes (i.e., subnets that are exposed from multiple nodes), we are currently just sending all of them in `PrimaryRoutes`... and hope for the best. (`PrimaryRoutes` is the field in `Node` used for subnet failover). - -## Proposal - -The core part is to create a new `Route` struct (and DB table), with the following fields: - -```go -type Route struct { - ID uint64 `gorm:"primary_key"` - - Machine *Machine - Prefix IPPrefix - - Advertised bool - Enabled bool - IsPrimary bool - - - CreatedAt *time.Time - UpdatedAt *time.Time - DeletedAt *time.Time -} -``` - -- The `Advertised` field is set to true if the route is being advertised by the node. It is set to false if the route is removed. This way we can indicate if a later enabled route has stopped being advertised. A similar behaviour happens in the Tailscale.com control panel. - -- The `Enabled` field is set to true if the route is enabled - via CLI or autoApprovers. - -- `IsPrimary` indicates if Headscale has selected this route as the primary route for that particular subnet. This allows us to implement subnet failover. This would be fully automatic if there is more than subnet routers advertising the same network - which is the behaviour of Tailscale.com. - -## Stuff to bear in mind - -- We need to make sure to migrate the current `EnabledRoutes` of `Machine` into the new table. -- When a node stops sharing a subnet, I reckon we should mark it both as not `Advertised` and not `Enabled`. Users should re-enable it if the node advertises it again. -- If only one subnet router is advertising a subnet, we should mark it as primary. -- Regarding subnet failover, the current behaviour of Tailscale.com is to perform the failover after 15 seconds from the node disconnecting from their control panel. I reckon we cannot do the same currently. Our maximum granularity is the keep alive period. diff --git a/docs/acls.md b/docs/ref/acls.md similarity index 50% rename from docs/acls.md rename to docs/ref/acls.md index 096dbea0..fff66715 100644 --- a/docs/acls.md +++ b/docs/ref/acls.md @@ -3,15 +3,44 @@ Headscale implements the same policy ACLs as Tailscale.com, adapted to the self- For instance, instead of referring to users when defining groups you must use users (which are the equivalent to user/logins in Tailscale.com). -Please check https://tailscale.com/kb/1018/acls/, and `./tests/acls/` in this repo for working examples. +Please check https://tailscale.com/kb/1018/acls/ for further information. When using ACL's the User borders are no longer applied. All machines whichever the User have the ability to communicate with other hosts as long as the ACL's permits this exchange. -## ACLs use case example +## ACL Setup -Let's build an example use case for a small business (It may be the place where +To enable and configure ACLs in Headscale, you need to specify the path to your ACL policy file in the `policy.path` key in `config.yaml`. + +Your ACL policy file must be formatted using [huJSON](https://github.com/tailscale/hujson). + +Info on how these policies are written can be found +[here](https://tailscale.com/kb/1018/acls/). + +Please reload or restart Headscale after updating the ACL file. Headscale may be reloaded either via its systemd service +(`sudo systemctl reload headscale`) or by sending a SIGHUP signal (`sudo kill -HUP $(pidof headscale)`) to the main +process. Headscale logs the result of ACL policy processing after each reload. 
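For orientation, a minimal sketch of the relevant `config.yaml` section might look like the following; the path `/etc/headscale/acl.json` is only an illustrative location, point it at wherever your huJSON policy file actually lives:

```yaml title="config.yaml"
policy:
  # Absolute path to the ACL policy file (huJSON format); illustrative example path
  path: /etc/headscale/acl.json
```

Remember to reload or restart Headscale after updating the policy file so the new rules take effect, as described above.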
+ +## Simple Examples + +- [**Allow All**](https://tailscale.com/kb/1192/acl-samples#allow-all-default-acl): If you define an ACL file but completely omit the `"acls"` field from its content, Headscale will default to an "allow all" policy. This means all devices connected to your tailnet will be able to communicate freely with each other. + + ```json + {} + ``` + +- [**Deny All**](https://tailscale.com/kb/1192/acl-samples#deny-all): To prevent all communication within your tailnet, you can include an empty array for the `"acls"` field in your policy file. + + ```json + { + "acls": [] + } + ``` + +## Complex Example + +Let's build a more complex example use case for a small business (It may be the place where ACL's are the most useful). We have a small company with a boss, an admin, two developers and an intern. @@ -36,36 +65,26 @@ servers. - billing.internal - router.internal -![ACL implementation example](images/headscale-acl-network.png) +![ACL implementation example](../assets/images/headscale-acl-network.png) -## ACL setup - -Note: Users will be created automatically when users authenticate with the -Headscale server. - -ACLs could be written either on [huJSON](https://github.com/tailscale/hujson) -or YAML. Check the [test ACLs](../tests/acls) for further information. - -When registering the servers we will need to add the flag -`--advertise-tags=tag:<tag1>,tag:<tag2>`, and the user that is -registering the server should be allowed to do it. Since anyone can add tags to -a server they can register, the check of the tags is done on headscale server -and only valid tags are applied. A tag is valid if the user that is +When [registering the servers](../usage/getting-started.md#register-a-node) we +will need to add the flag `--advertise-tags=tag:<tag1>,tag:<tag2>`, and the user +that is registering the server should be allowed to do it. Since anyone can add +tags to a server they can register, the check of the tags is done on headscale +server and only valid tags are applied. A tag is valid if the user that is registering it is allowed to do it. -To use ACLs in headscale, you must edit your config.yaml file. In there you will find a `acl_policy_path: ""` parameter. This will need to point to your ACL file. More info on how these policies are written can be found [here](https://tailscale.com/kb/1018/acls/). - Here are the ACL's to implement the same permissions as above: -```json +```json title="acl.json" { // groups are collections of users having a common scope. A user can be in multiple groups // groups cannot be composed of groups "groups": { - "group:boss": ["boss"], - "group:dev": ["dev1", "dev2"], - "group:admin": ["admin1"], - "group:intern": ["intern1"] + "group:boss": ["boss@"], + "group:dev": ["dev1@", "dev2@"], + "group:admin": ["admin1@"], + "group:intern": ["intern1@"] }, // tagOwners in tailscale is an association between a TAG and the people allowed to set this TAG on a server. // This is documented [here](https://tailscale.com/kb/1068/acl-tags#defining-a-tag) @@ -88,7 +107,7 @@ Here are the ACL's to implement the same permissions as above: // to define a single host, use a /32 mask. You cannot use DNS entries here, // as they're prone to be hijacked by replacing their IP addresses. // see https://github.com/tailscale/tailscale/issues/3800 for more information. 
- "Hosts": { + "hosts": { "postgresql.internal": "10.20.0.2/32", "webservers.internal": "10.20.10.1/29" }, @@ -147,13 +166,11 @@ Here are the ACL's to implement the same permissions as above: }, // developers have access to the internal network through the router. // the internal network is composed of HTTPS endpoints and Postgresql - // database servers. There's an additional rule to allow traffic to be - // forwarded to the internal subnet, 10.20.0.0/16. See this issue - // https://github.com/juanfont/headscale/issues/502 + // database servers. { "action": "accept", "src": ["group:dev"], - "dst": ["10.20.0.0/16:443,5432", "router.internal:0"] + "dst": ["10.20.0.0/16:443,5432"] }, // servers should be able to talk to database in tcp/5432. Database should not be able to initiate connections to @@ -177,13 +194,94 @@ Here are the ACL's to implement the same permissions as above: "dst": ["tag:dev-app-servers:80,443"] }, - // We still have to allow internal users communications since nothing guarantees that each user have - // their own users. - { "action": "accept", "src": ["boss"], "dst": ["boss:*"] }, - { "action": "accept", "src": ["dev1"], "dst": ["dev1:*"] }, - { "action": "accept", "src": ["dev2"], "dst": ["dev2:*"] }, - { "action": "accept", "src": ["admin1"], "dst": ["admin1:*"] }, - { "action": "accept", "src": ["intern1"], "dst": ["intern1:*"] } + // Allow users to access their own devices using autogroup:self (see below for more details about performance impact) + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + } ] } ``` + +## Autogroups + +Headscale supports several autogroups that automatically include users, destinations, or devices with specific properties. Autogroups provide a convenient way to write ACL rules without manually listing individual users or devices. + +### `autogroup:internet` + +Allows access to the internet through [exit nodes](routes.md#exit-node). Can only be used in ACL destinations. + +```json +{ + "action": "accept", + "src": ["group:users"], + "dst": ["autogroup:internet:*"] +} +``` + +### `autogroup:member` + +Includes all untagged devices. + +```json +{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["tag:prod-app-servers:80,443"] +} +``` + +### `autogroup:tagged` + +Includes all devices that have at least one tag. + +```json +{ + "action": "accept", + "src": ["autogroup:tagged"], + "dst": ["tag:monitoring:9090"] +} +``` + +### `autogroup:self` +**(EXPERIMENTAL)** + +!!! warning "The current implementation of `autogroup:self` is inefficient" + +Includes devices where the same user is authenticated on both the source and destination. Does not include tagged devices. Can only be used in ACL destinations. + +```json +{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] +} +``` +*Using `autogroup:self` may cause performance degradation on the Headscale coordinator server in large deployments, as filter rules must be compiled per-node rather than globally and the current implementation is not very efficient.* + +If you experience performance issues, consider using more specific ACL rules or limiting the use of `autogroup:self`. +```json +{ + // The following rules allow internal users to communicate with their + // own nodes in case autogroup:self is causing performance issues. 
+ { "action": "accept", "src": ["boss@"], "dst": ["boss@:*"] }, + { "action": "accept", "src": ["dev1@"], "dst": ["dev1@:*"] }, + { "action": "accept", "src": ["dev2@"], "dst": ["dev2@:*"] }, + { "action": "accept", "src": ["admin1@"], "dst": ["admin1@:*"] }, + { "action": "accept", "src": ["intern1@"], "dst": ["intern1@:*"] } +} +``` + +### `autogroup:nonroot` + +Used in Tailscale SSH rules to allow access to any user except root. Can only be used in the `users` field of SSH rules. + +```json +{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self"], + "users": ["autogroup:nonroot"] +} +``` diff --git a/docs/ref/api.md b/docs/ref/api.md new file mode 100644 index 00000000..a99e679c --- /dev/null +++ b/docs/ref/api.md @@ -0,0 +1,129 @@ +# API +Headscale provides a [HTTP REST API](#rest-api) and a [gRPC interface](#grpc) which may be used to integrate a [web +interface](integration/web-ui.md), [remote control Headscale](#setup-remote-control) or provide a base for custom +integration and tooling. + +Both interfaces require a valid API key before use. To create an API key, log into your Headscale server and generate +one with the default expiration of 90 days: + +```shell +headscale apikeys create +``` + +Copy the output of the command and save it for later. Please note that you can not retrieve an API key again. If the API +key is lost, expire the old one, and create a new one. + +To list the API keys currently associated with the server: + +```shell +headscale apikeys list +``` + +and to expire an API key: + +```shell +headscale apikeys expire --prefix <PREFIX> +``` + +## REST API + +- API endpoint: `/api/v1`, e.g. `https://headscale.example.com/api/v1` +- Documentation: `/swagger`, e.g. `https://headscale.example.com/swagger` +- Headscale Version: `/version`, e.g. `https://headscale.example.com/version` +- Authenticate using HTTP Bearer authentication by sending the [API key](#api) with the HTTP `Authorization: Bearer + <API_KEY>` header. + +Start by [creating an API key](#api) and test it with the examples below. Read the API documentation provided by your +Headscale server at `/swagger` for details. + +=== "Get details for all users" + + ```console + curl -H "Authorization: Bearer <API_KEY>" \ + https://headscale.example.com/api/v1/user + ``` + +=== "Get details for user 'bob'" + + ```console + curl -H "Authorization: Bearer <API_KEY>" \ + https://headscale.example.com/api/v1/user?name=bob + ``` + +=== "Register a node" + + ```console + curl -H "Authorization: Bearer <API_KEY>" \ + -d user=<USER> -d key=<KEY> \ + https://headscale.example.com/api/v1/node/register + ``` + +## gRPC + +The gRPC interface can be used to control a Headscale instance from a remote machine with the `headscale` binary. + +### Prerequisite + +- A workstation to run `headscale` (any supported platform, e.g. Linux). +- A Headscale server with gRPC enabled. +- Connections to the gRPC port (default: `50443`) are allowed. +- Remote access requires an encrypted connection via TLS. +- An [API key](#api) to authenticate with the Headscale server. + +### Setup remote control + +1. Download the [`headscale` binary from GitHub's release page](https://github.com/juanfont/headscale/releases). Make + sure to use the same version as on the server. + +1. Put the binary somewhere in your `PATH`, e.g. `/usr/local/bin/headscale` + +1. Make `headscale` executable: `chmod +x /usr/local/bin/headscale` + +1. [Create an API key](#api) on the Headscale server. + +1. 
Provide the connection parameters for the remote Headscale server either via a minimal YAML configuration file or + via environment variables: + + === "Minimal YAML configuration file" + + ```yaml title="config.yaml" + cli: + address: <HEADSCALE_ADDRESS>:<PORT> + api_key: <API_KEY> + ``` + + === "Environment variables" + + ```shell + export HEADSCALE_CLI_ADDRESS="<HEADSCALE_ADDRESS>:<PORT>" + export HEADSCALE_CLI_API_KEY="<API_KEY>" + ``` + + This instructs the `headscale` binary to connect to a remote instance at `<HEADSCALE_ADDRESS>:<PORT>`, instead of + connecting to the local instance. + +1. Test the connection by listing all nodes: + + ```shell + headscale nodes list + ``` + + You should now be able to see a list of your nodes from your workstation, and you can + now control the Headscale server from your workstation. + +### Behind a proxy + +It's possible to run the gRPC remote endpoint behind a reverse proxy, like Nginx, and have it run on the _same_ port as Headscale. + +While this is _not a supported_ feature, an example on how this can be set up on +[NixOS is shown here](https://github.com/kradalby/dotfiles/blob/4489cdbb19cddfbfae82cd70448a38fde5a76711/machines/headscale.oracldn/headscale.nix#L61-L91). + +### Troubleshooting + +- Make sure you have the _same_ Headscale version on your server and workstation. +- Ensure that connections to the gRPC port are allowed. +- Verify that your TLS certificate is valid and trusted. +- If you don't have access to a trusted certificate (e.g. from Let's Encrypt), either: + - Add your self-signed certificate to the trust store of your OS _or_ + - Disable certificate verification by either setting `cli.insecure: true` in the configuration file or by setting + `HEADSCALE_CLI_INSECURE=1` via an environment variable. We do **not** recommend to disable certificate validation. diff --git a/docs/ref/configuration.md b/docs/ref/configuration.md new file mode 100644 index 00000000..18c8502f --- /dev/null +++ b/docs/ref/configuration.md @@ -0,0 +1,41 @@ +# Configuration + +- Headscale loads its configuration from a YAML file +- It searches for `config.yaml` in the following paths: + - `/etc/headscale` + - `$HOME/.headscale` + - the current working directory +- To load the configuration from a different path, use: + - the command line flag `-c`, `--config` + - the environment variable `HEADSCALE_CONFIG` +- Validate the configuration file with: `headscale configtest` + +!!! example "Get the [example configuration from the GitHub repository](https://github.com/juanfont/headscale/blob/main/config-example.yaml)" + + Always select the [same GitHub tag](https://github.com/juanfont/headscale/tags) as the released version you use to + ensure you have the correct example configuration. The `main` branch might contain unreleased changes. 
+ + === "View on GitHub" + + * Development version: <https://github.com/juanfont/headscale/blob/main/config-example.yaml> + * Version {{ headscale.version }}: <https://github.com/juanfont/headscale/blob/v{{ headscale.version }}/config-example.yaml> + + === "Download with `wget`" + + ```shell + # Development version + wget -O config.yaml https://raw.githubusercontent.com/juanfont/headscale/main/config-example.yaml + + # Version {{ headscale.version }} + wget -O config.yaml https://raw.githubusercontent.com/juanfont/headscale/v{{ headscale.version }}/config-example.yaml + ``` + + === "Download with `curl`" + + ```shell + # Development version + curl -o config.yaml https://raw.githubusercontent.com/juanfont/headscale/main/config-example.yaml + + # Version {{ headscale.version }} + curl -o config.yaml https://raw.githubusercontent.com/juanfont/headscale/v{{ headscale.version }}/config-example.yaml + ``` diff --git a/docs/ref/debug.md b/docs/ref/debug.md new file mode 100644 index 00000000..f2899d69 --- /dev/null +++ b/docs/ref/debug.md @@ -0,0 +1,118 @@ +# Debugging and troubleshooting + +Headscale and Tailscale provide debug and introspection capabilities that can be helpful when things don't work as +expected. This page explains some debugging techniques to help pinpoint problems. + +Please also have a look at [Tailscale's Troubleshooting guide](https://tailscale.com/kb/1023/troubleshooting). It offers +a many tips and suggestions to troubleshoot common issues. + +## Tailscale + +The Tailscale client itself offers many commands to introspect its state as well as the state of the network: + +- [Check local network conditions](https://tailscale.com/kb/1080/cli#netcheck): `tailscale netcheck` +- [Get the client status](https://tailscale.com/kb/1080/cli#status): `tailscale status --json` +- [Get DNS status](https://tailscale.com/kb/1080/cli#dns): `tailscale dns status --all` +- Client logs: `tailscale debug daemon-logs` +- Client netmap: `tailscale debug netmap` +- Test DERP connection: `tailscale debug derp headscale` +- And many more, see: `tailscale debug --help` + +Many of the commands are helpful when trying to understand differences between Headscale and Tailscale SaaS. + +## Headscale + +### Application logging + +The log levels `debug` and `trace` can be useful to get more information from Headscale. + +```yaml hl_lines="3" +log: + # Valid log levels: panic, fatal, error, warn, info, debug, trace + level: debug +``` + +### Database logging + +The database debug mode logs all database queries. Enable it to see how Headscale interacts with its database. This also +requires the application log level to be set to either `debug` or `trace`. + +```yaml hl_lines="3 7" +database: + # Enable debug mode. This setting requires the log.level to be set to "debug" or "trace". + debug: false + +log: + # Valid log levels: panic, fatal, error, warn, info, debug, trace + level: debug +``` + +### Metrics and debug endpoint + +Headscale provides a metrics and debug endpoint. It allows to introspect different aspects such as: + +- Information about the Go runtime, memory usage and statistics +- Connected nodes and pending registrations +- Active ACLs, filters and SSH policy +- Current DERPMap +- Prometheus metrics + +!!! warning "Keep the metrics and debug endpoint private" + + The listen address and port can be configured with the `metrics_listen_addr` variable in the [configuration + file](./configuration.md). By default it listens on localhost, port 9090. 
+ + Keep the metrics and debug endpoint private to your internal network and don't expose it to the Internet. + + The metrics and debug interface can be disabled completely by setting `metrics_listen_addr: null` in the + [configuration file](./configuration.md). + +Query metrics via <http://localhost:9090/metrics> and get an overview of available debug information via +<http://localhost:9090/debug/>. Metrics may be queried from outside localhost but the debug interface is subject to +additional protection despite listening on all interfaces. + +=== "Direct access" + + Access the debug interface directly on the server where Headscale is installed. + + ```console + curl http://localhost:9090/debug/ + ``` + +=== "SSH port forwarding" + + Use SSH port forwarding to forward Headscale's metrics and debug port to your device. + + ```console + ssh <HEADSCALE_SERVER> -L 9090:localhost:9090 + ``` + + Access the debug interface on your device by opening <http://localhost:9090/debug/> in your web browser. + +=== "Via debug key" + + The access control of the debug interface supports the use of a debug key. Traffic is accepted if the path to a + debug key is set via the environment variable `TS_DEBUG_KEY_PATH` and the debug key sent as value for `debugkey` + parameter with each request. + + ```console + openssl rand -hex 32 | tee debugkey.txt + export TS_DEBUG_KEY_PATH=debugkey.txt + headscale serve + ``` + + Access the debug interface on your device by opening `http://<IP_OF_HEADSCALE>:9090/debug/?debugkey=<DEBUG_KEY>` in + your web browser. The `debugkey` parameter must be sent with every request. + +=== "Via debug IP address" + + The debug endpoint expects traffic from localhost. A different debug IP address may be configured by setting the + `TS_ALLOW_DEBUG_IP` environment variable before starting Headscale. The debug IP address is ignored when the HTTP + header `X-Forwarded-For` is present. + + ```console + export TS_ALLOW_DEBUG_IP=192.168.0.10 # IP address of your device + headscale serve + ``` + + Access the debug interface on your device by opening `http://<IP_OF_HEADSCALE>:9090/debug/` in your web browser. diff --git a/docs/ref/derp.md b/docs/ref/derp.md new file mode 100644 index 00000000..45fc4119 --- /dev/null +++ b/docs/ref/derp.md @@ -0,0 +1,175 @@ +# DERP + +A [DERP (Designated Encrypted Relay for Packets) server](https://tailscale.com/kb/1232/derp-servers) is mainly used to +relay traffic between two nodes in case a direct connection can't be established. Headscale provides an embedded DERP +server to ensure seamless connectivity between nodes. + +## Configuration + +DERP related settings are configured within the `derp` section of the [configuration file](./configuration.md). The +following sections only use a few of the available settings, check the [example configuration](./configuration.md) for +all available configuration options. + +### Enable embedded DERP + +Headscale ships with an embedded DERP server which allows to run your own self-hosted DERP server easily. The embedded +DERP server is disabled by default and needs to be enabled. In addition, you should configure the public IPv4 and public +IPv6 address of your Headscale server for improved connection stability: + +```yaml title="config.yaml" hl_lines="3-5" +derp: + server: + enabled: true + ipv4: 198.51.100.1 + ipv6: 2001:db8::1 +``` + +Keep in mind that [additional ports are needed to run a DERP server](../setup/requirements.md#ports-in-use). 
Besides +relaying traffic, it also uses STUN (udp/3478) to help clients discover their public IP addresses and perform NAT +traversal. [Check DERP server connectivity](#check-derp-server-connectivity) to see if everything works. + +### Remove Tailscale's DERP servers + +Once enabled, Headscale's embedded DERP is added to the list of free-to-use [DERP +servers](https://tailscale.com/kb/1232/derp-servers) offered by Tailscale Inc. To only use Headscale's embedded DERP +server, disable the loading of the default DERP map: + +```yaml title="config.yaml" hl_lines="6" +derp: + server: + enabled: true + ipv4: 198.51.100.1 + ipv6: 2001:db8::1 + urls: [] +``` + +!!! warning "Single point of failure" + + Removing Tailscale's DERP servers means that there is now just a single DERP server available for clients. This is a + single point of failure and could hamper connectivity. + + [Check DERP server connectivity](#check-derp-server-connectivity) with your embedded DERP server before removing + Tailscale's DERP servers. + +### Customize DERP map + +The DERP map offered to clients can be customized with a [dedicated YAML-configuration +file](https://github.com/juanfont/headscale/blob/main/derp-example.yaml). This allows to modify previously loaded DERP +maps fetched via URL or to offer your own, custom DERP servers to nodes. + +=== "Remove specific DERP regions" + + The free-to-use [DERP servers](https://tailscale.com/kb/1232/derp-servers) are organized into regions via a region + ID. You can explicitly disable a specific region by setting its region ID to `null`. The following sample + `derp.yaml` disables the New York DERP region (which has the region ID 1): + + ```yaml title="derp.yaml" + regions: + 1: null + ``` + + Use the following configuration to serve the default DERP map (excluding New York) to nodes: + + ```yaml title="config.yaml" hl_lines="6 7" + derp: + server: + enabled: false + urls: + - https://controlplane.tailscale.com/derpmap/default + paths: + - /etc/headscale/derp.yaml + ``` + +=== "Provide custom DERP servers" + + The following sample `derp.yaml` references two custom regions (`custom-east` with ID 900 and `custom-west` with ID 901) + with one custom DERP server in each region. Each DERP server offers DERP relay via HTTPS on tcp/443, support for captive + portal checks via HTTP on tcp/80 and STUN on udp/3478. See the definitions of + [DERPMap](https://pkg.go.dev/tailscale.com/tailcfg#DERPMap), + [DERPRegion](https://pkg.go.dev/tailscale.com/tailcfg#DERPRegion) and + [DERPNode](https://pkg.go.dev/tailscale.com/tailcfg#DERPNode) for all available options. + + ```yaml title="derp.yaml" + regions: + 900: + regionid: 900 + regioncode: custom-east + regionname: My region (east) + nodes: + - name: 900a + regionid: 900 + hostname: derp900a.example.com + ipv4: 198.51.100.1 + ipv6: 2001:db8::1 + canport80: true + 901: + regionid: 901 + regioncode: custom-west + regionname: My Region (west) + nodes: + - name: 901a + regionid: 901 + hostname: derp901a.example.com + ipv4: 198.51.100.2 + ipv6: 2001:db8::2 + canport80: true + ``` + + Use the following configuration to only serve the two DERP servers from the above `derp.yaml`: + + ```yaml title="config.yaml" hl_lines="5 6" + derp: + server: + enabled: false + urls: [] + paths: + - /etc/headscale/derp.yaml + ``` + +Independent of the custom DERP map, you may choose to [enable the embedded DERP server and have it automatically added +to the custom DERP map](#enable-embedded-derp). 
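+
+A configuration along the following lines combines both: it keeps the embedded DERP server enabled while additionally
+loading custom DERP servers from `derp.yaml`. This is only a sketch that reuses the example values from the sections
+above; adjust IP addresses and paths to your environment.
+
+```yaml title="config.yaml"
+derp:
+  server:
+    enabled: true
+    ipv4: 198.51.100.1
+    ipv6: 2001:db8::1
+  urls: []
+  paths:
+    - /etc/headscale/derp.yaml
+```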
+ +### Verify clients + +Access to DERP serves can be restricted to nodes that are members of your Tailnet. Relay access is denied for unknown +clients. + +=== "Embedded DERP" + + Client verification is enabled by default. + + ```yaml title="config.yaml" hl_lines="3" + derp: + server: + verify_clients: true + ``` + +=== "3rd-party DERP" + + Tailscale's `derper` provides two parameters to configure client verification: + + - Use the `-verify-client-url` parameter of the `derper` and point it towards the `/verify` endpoint of your + Headscale server (e.g `https://headscale.example.com/verify`). The DERP server will query your Headscale instance + as soon as a client connects with it to ask whether access should be allowed or denied. Access is allowed if + Headscale knows about the connecting client and denied otherwise. + - The parameter `-verify-client-url-fail-open` controls what should happen when the DERP server can't reach the + Headscale instance. By default, it will allow access if Headscale is unreachable. + +## Check DERP server connectivity + +Any Tailscale client may be used to introspect the DERP map and to check for connectivity issues with DERP servers. + +- Display DERP map: `tailscale debug derp-map` +- Check connectivity with the embedded DERP[^1]:`tailscale debug derp headscale` + +Additional DERP related metrics and information is available via the [metrics and debug +endpoint](./debug.md#metrics-and-debug-endpoint). + +[^1]: + This assumes that the default region code of the [configuration file](./configuration.md) is used. + +## Limitations + +- The embedded DERP server can't be used for Tailscale's captive portal checks as it doesn't support the `/generate_204` + endpoint via HTTP on port tcp/80. +- There are no speed or throughput optimisations, the main purpose is to assist in node connectivity. diff --git a/docs/ref/dns.md b/docs/ref/dns.md new file mode 100644 index 00000000..409a903c --- /dev/null +++ b/docs/ref/dns.md @@ -0,0 +1,111 @@ +# DNS + +Headscale supports [most DNS features](../about/features.md) from Tailscale. DNS related settings can be configured +within the `dns` section of the [configuration file](./configuration.md). + +## Setting extra DNS records + +Headscale allows to set extra DNS records which are made available via +[MagicDNS](https://tailscale.com/kb/1081/magicdns). Extra DNS records can be configured either via static entries in the +[configuration file](./configuration.md) or from a JSON file that Headscale continuously watches for changes: + +- Use the `dns.extra_records` option in the [configuration file](./configuration.md) for entries that are static and + don't change while Headscale is running. Those entries are processed when Headscale is starting up and changes to the + configuration require a restart of Headscale. +- For dynamic DNS records that may be added, updated or removed while Headscale is running or DNS records that are + generated by scripts the option `dns.extra_records_path` in the [configuration file](./configuration.md) is useful. + Set it to the absolute path of the JSON file containing DNS records and Headscale processes this file as it detects + changes. + +An example use case is to serve multiple apps on the same host via a reverse proxy like NGINX, in this case a Prometheus +monitoring stack. This allows to nicely access the service with "http://grafana.myvpn.example.com" instead of the +hostname and port combination "http://hostname-in-magic-dns.myvpn.example.com:3000". + +!!! 
warning "Limitations" + + Currently, [only A and AAAA records are processed by Tailscale](https://github.com/tailscale/tailscale/blob/v1.86.5/ipn/ipnlocal/node_backend.go#L662). + +1. Configure extra DNS records using one of the available configuration options: + + === "Static entries, via `dns.extra_records`" + + ```yaml title="config.yaml" + dns: + ... + extra_records: + - name: "grafana.myvpn.example.com" + type: "A" + value: "100.64.0.3" + + - name: "prometheus.myvpn.example.com" + type: "A" + value: "100.64.0.3" + ... + ``` + + Restart your headscale instance. + + === "Dynamic entries, via `dns.extra_records_path`" + + ```json title="extra-records.json" + [ + { + "name": "grafana.myvpn.example.com", + "type": "A", + "value": "100.64.0.3" + }, + { + "name": "prometheus.myvpn.example.com", + "type": "A", + "value": "100.64.0.3" + } + ] + ``` + + Headscale picks up changes to the above JSON file automatically. + + !!! tip "Good to know" + + * The `dns.extra_records_path` option in the [configuration file](./configuration.md) needs to reference the + JSON file containing extra DNS records. + * Be sure to "sort keys" and produce a stable output in case you generate the JSON file with a script. + Headscale uses a checksum to detect changes to the file and a stable output avoids unnecessary processing. + +1. Verify that DNS records are properly set using the DNS querying tool of your choice: + + === "Query with dig" + + ```console + dig +short grafana.myvpn.example.com + 100.64.0.3 + ``` + + === "Query with drill" + + ```console + drill -Q grafana.myvpn.example.com + 100.64.0.3 + ``` + +1. Optional: Setup the reverse proxy + + The motivating example here was to be able to access internal monitoring services on the same host without + specifying a port, depicted as NGINX configuration snippet: + + ```nginx title="nginx.conf" + server { + listen 80; + listen [::]:80; + + server_name grafana.myvpn.example.com; + + location / { + proxy_pass http://localhost:3000; + proxy_set_header Host $http_host; + proxy_set_header X-Real-IP $remote_addr; + proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; + proxy_set_header X-Forwarded-Proto $scheme; + } + + } + ``` diff --git a/docs/reverse-proxy.md b/docs/ref/integration/reverse-proxy.md similarity index 73% rename from docs/reverse-proxy.md rename to docs/ref/integration/reverse-proxy.md index aab9f848..3586171f 100644 --- a/docs/reverse-proxy.md +++ b/docs/ref/integration/reverse-proxy.md @@ -3,7 +3,7 @@ !!! warning "Community documentation" This page is not actively maintained by the headscale authors and is - written by community members. It is _not_ verified by `headscale` developers. + written by community members. It is _not_ verified by headscale developers. **It might be outdated and it might miss necessary steps**. @@ -11,15 +11,19 @@ Running headscale behind a reverse proxy is useful when running multiple applica ### WebSockets -The reverse proxy MUST be configured to support WebSockets, as it is needed for clients running Tailscale v1.30+. +The reverse proxy MUST be configured to support WebSockets to communicate with Tailscale clients. -WebSockets support is required when using the headscale embedded DERP server. In this case, you will also need to expose the UDP port used for STUN (by default, udp/3478). Please check our [config-example.yaml](https://github.com/juanfont/headscale/blob/main/config-example.yaml). +WebSockets support is also required when using the Headscale [embedded DERP server](../derp.md). 
In this case, you will also need to expose the UDP port used for STUN (by default, udp/3478). Please check our [config-example.yaml](https://github.com/juanfont/headscale/blob/main/config-example.yaml). + +### Cloudflare + +Running headscale behind a cloudflare proxy or cloudflare tunnel is not supported and will not work as Cloudflare does not support WebSocket POSTs as required by the Tailscale protocol. See [this issue](https://github.com/juanfont/headscale/issues/1468) ### TLS Headscale can be configured not to use TLS, leaving it to the reverse proxy to handle. Add the following configuration values to your headscale config file. -```yaml +```yaml title="config.yaml" server_url: https://<YOUR_SERVER_NAME> # This should be the FQDN at which headscale will be served listen_addr: 0.0.0.0:8080 metrics_listen_addr: 0.0.0.0:9090 @@ -31,10 +35,9 @@ tls_key_path: "" The following example configuration can be used in your nginx setup, substituting values as necessary. `<IP:PORT>` should be the IP address and port where headscale is running. In most cases, this will be `http://localhost:8080`. -```Nginx +```nginx title="nginx.conf" map $http_upgrade $connection_upgrade { - default keep-alive; - 'websocket' upgrade; + default upgrade; '' close; } @@ -61,7 +64,7 @@ server { proxy_buffering off; proxy_set_header X-Real-IP $remote_addr; proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for; - proxy_set_header X-Forwarded-Proto $http_x_forwarded_proto; + proxy_set_header X-Forwarded-Proto $scheme; add_header Strict-Transport-Security "max-age=15552000; includeSubDomains" always; } } @@ -77,7 +80,7 @@ Sending local reply with details upgrade_failed ### Envoy -You need add a new upgrade_type named `tailscale-control-protocol`. [see detail](https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto#extensions-filters-network-http-connection-manager-v3-httpconnectionmanager-upgradeconfig) +You need to add a new upgrade_type named `tailscale-control-protocol`. [see details](https://www.envoyproxy.io/docs/envoy/latest/api-v3/extensions/filters/network/http_connection_manager/v3/http_connection_manager.proto#extensions-filters-network-http-connection-manager-v3-httpconnectionmanager-upgradeconfig) ### Istio @@ -110,21 +113,21 @@ spec: The following Caddyfile is all that is necessary to use Caddy as a reverse proxy for headscale, in combination with the `config.yaml` specifications above to disable headscale's built in TLS. Replace values as necessary - `<YOUR_SERVER_NAME>` should be the FQDN at which headscale will be served, and `<IP:PORT>` should be the IP address and port where headscale is running. In most cases, this will be `localhost:8080`. -``` +```none title="Caddyfile" <YOUR_SERVER_NAME> { reverse_proxy <IP:PORT> } ``` -Caddy v2 will [automatically](https://caddyserver.com/docs/automatic-https) provision a certficate for your domain/subdomain, force HTTPS, and proxy websockets - no further configuration is necessary. +Caddy v2 will [automatically](https://caddyserver.com/docs/automatic-https) provision a certificate for your domain/subdomain, force HTTPS, and proxy websockets - no further configuration is necessary. -For a slightly more complex configuration which utilizes Docker containers to manage Caddy, Headscale, and Headscale-UI, [Guru Computing's guide](https://blog.gurucomputing.com.au/smart-vpns-with-headscale/) is an excellent reference. 
+For a slightly more complex configuration which utilizes Docker containers to manage Caddy, headscale, and Headscale-UI, [Guru Computing's guide](https://blog.gurucomputing.com.au/smart-vpns-with-headscale/) is an excellent reference. ## Apache -The following minimal Apache config will proxy traffic to the Headscale instance on `<IP:PORT>`. Note that `upgrade=any` is required as a parameter for `ProxyPass` so that WebSockets traffic whose `Upgrade` header value is not equal to `WebSocket` (i. e. Tailscale Control Protocol) is forwarded correctly. See the [Apache docs](https://httpd.apache.org/docs/2.4/mod/mod_proxy_wstunnel.html) for more information on this. +The following minimal Apache config will proxy traffic to the headscale instance on `<IP:PORT>`. Note that `upgrade=any` is required as a parameter for `ProxyPass` so that WebSockets traffic whose `Upgrade` header value is not equal to `WebSocket` (i. e. Tailscale Control Protocol) is forwarded correctly. See the [Apache docs](https://httpd.apache.org/docs/2.4/mod/mod_proxy_wstunnel.html) for more information on this. -``` +```apache title="apache.conf" <VirtualHost *:443> ServerName <YOUR_SERVER_NAME> diff --git a/docs/ref/integration/tools.md b/docs/ref/integration/tools.md new file mode 100644 index 00000000..2cf7d619 --- /dev/null +++ b/docs/ref/integration/tools.md @@ -0,0 +1,22 @@ +# Tools related to headscale + +!!! warning "Community contributions" + + This page contains community contributions. The projects listed here are not + maintained by the headscale authors and are written by community members. + +This page collects third-party tools, client libraries, and scripts related to headscale. + +- [headscale-operator](https://github.com/infradohq/headscale-operator) - Headscale Kubernetes Operator +- [tailscale-manager](https://github.com/singlestore-labs/tailscale-manager) - Dynamically manage Tailscale route + advertisements +- [headscalebacktosqlite](https://github.com/bigbozza/headscalebacktosqlite) - Migrate headscale from PostgreSQL back to + SQLite +- [headscale-pf](https://github.com/YouSysAdmin/headscale-pf) - Populates user groups based on user groups in Jumpcloud + or Authentik +- [headscale-client-go](https://github.com/hibare/headscale-client-go) - A Go client implementation for the Headscale + HTTP API. +- [headscale-zabbix](https://github.com/dblanque/headscale-zabbix) - A Zabbix Monitoring Template for the Headscale + Service. +- [tailscale-exporter](https://github.com/adinhodovic/tailscale-exporter) - A Prometheus exporter for Headscale that + provides network-level metrics using the Headscale API. diff --git a/docs/ref/integration/web-ui.md b/docs/ref/integration/web-ui.md new file mode 100644 index 00000000..12238b94 --- /dev/null +++ b/docs/ref/integration/web-ui.md @@ -0,0 +1,24 @@ +# Web interfaces for headscale + +!!! warning "Community contributions" + + This page contains community contributions. The projects listed here are not + maintained by the headscale authors and are written by community members. + +Headscale doesn't provide a built-in web interface but users may pick one from the available options. 
+ +- [headscale-ui](https://github.com/gurucomputing/headscale-ui) - A web frontend for the headscale Tailscale-compatible + coordination server +- [HeadscaleUi](https://github.com/simcu/headscale-ui) - A static headscale admin ui, no backend environment required +- [Headplane](https://github.com/tale/headplane) - An advanced Tailscale inspired frontend for headscale +- [headscale-admin](https://github.com/GoodiesHQ/headscale-admin) - Headscale-Admin is meant to be a simple, modern web + interface for headscale +- [ouroboros](https://github.com/yellowsink/ouroboros) - Ouroboros is designed for users to manage their own devices, + rather than for admins +- [unraid-headscale-admin](https://github.com/ich777/unraid-headscale-admin) - A simple headscale admin UI for Unraid, + it offers Local (`docker exec`) and API Mode +- [headscale-console](https://github.com/rickli-cloud/headscale-console) - WebAssembly-based client supporting SSH, VNC + and RDP with optional self-service capabilities +- [headscale-piying](https://github.com/wszgrcy/headscale-piying) - headscale web ui,support visual ACL configuration + +You can ask for support on our [Discord server](https://discord.gg/c84AZQhmpx) in the "web-interfaces" channel. diff --git a/docs/ref/oidc.md b/docs/ref/oidc.md new file mode 100644 index 00000000..f6ec1bcd --- /dev/null +++ b/docs/ref/oidc.md @@ -0,0 +1,334 @@ +# OpenID Connect + +Headscale supports authentication via external identity providers using OpenID Connect (OIDC). It features: + +- Auto configuration via OpenID Connect Discovery Protocol +- [Proof Key for Code Exchange (PKCE) code verification](#enable-pkce-recommended) +- [Authorization based on a user's domain, email address or group membership](#authorize-users-with-filters) +- Synchronization of [standard OIDC claims](#supported-oidc-claims) + +Please see [limitations](#limitations) for known issues and limitations. + +## Configuration + +OpenID requires configuration in Headscale and your identity provider: + +- Headscale: The `oidc` section of the Headscale [configuration](configuration.md) contains all available configuration + options along with a description and their default values. +- Identity provider: Please refer to the official documentation of your identity provider for specific instructions. + Additionally, there might be some useful hints in the [Identity provider specific + configuration](#identity-provider-specific-configuration) section below. + +### Basic configuration + +A basic configuration connects Headscale to an identity provider and typically requires: + +- OpenID Connect Issuer URL from the identity provider. Headscale uses the OpenID Connect Discovery Protocol 1.0 to + automatically obtain OpenID configuration parameters (example: `https://sso.example.com`). +- Client ID from the identity provider (example: `headscale`). +- Client secret generated by the identity provider (example: `generated-secret`). +- Redirect URI for your identity provider (example: `https://headscale.example.com/oidc/callback`). 
+ +=== "Headscale" + + ```yaml + oidc: + issuer: "https://sso.example.com" + client_id: "headscale" + client_secret: "generated-secret" + ``` + +=== "Identity provider" + + * Create a new confidential client (`Client ID`, `Client secret`) + * Add Headscale's OIDC callback URL as valid redirect URL: `https://headscale.example.com/oidc/callback` + * Configure additional parameters to improve user experience such as: name, description, logo, … + +### Enable PKCE (recommended) + +Proof Key for Code Exchange (PKCE) adds an additional layer of security to the OAuth 2.0 authorization code flow by +preventing authorization code interception attacks, see: <https://datatracker.ietf.org/doc/html/rfc7636>. PKCE is +recommended and needs to be configured for Headscale and the identity provider alike: + +=== "Headscale" + + ```yaml hl_lines="5-6" + oidc: + issuer: "https://sso.example.com" + client_id: "headscale" + client_secret: "generated-secret" + pkce: + enabled: true + ``` + +=== "Identity provider" + + * Enable PKCE for the headscale client + * Set the PKCE challenge method to "S256" + +### Authorize users with filters + +Headscale allows to filter for allowed users based on their domain, email address or group membership. These filters can +be helpful to apply additional restrictions and control which users are allowed to join. Filters are disabled by +default, users are allowed to join once the authentication with the identity provider succeeds. In case multiple filters +are configured, a user needs to pass all of them. + +=== "Allowed domains" + + * Check the email domain of each authenticating user against the list of allowed domains and only authorize users + whose email domain matches `example.com`. + * A verified email address is required [unless email verification is disabled](#control-email-verification). + * Access allowed: `alice@example.com` + * Access denied: `bob@example.net` + + ```yaml hl_lines="5-6" + oidc: + issuer: "https://sso.example.com" + client_id: "headscale" + client_secret: "generated-secret" + allowed_domains: + - "example.com" + ``` + +=== "Allowed users/emails" + + * Check the email address of each authenticating user against the list of allowed email addresses and only authorize + users whose email is part of the `allowed_users` list. + * A verified email address is required [unless email verification is disabled](#control-email-verification). + * Access allowed: `alice@example.com`, `bob@example.net` + * Access denied: `mallory@example.net` + + ```yaml hl_lines="5-7" + oidc: + issuer: "https://sso.example.com" + client_id: "headscale" + client_secret: "generated-secret" + allowed_users: + - "alice@example.com" + - "bob@example.net" + ``` + +=== "Allowed groups" + + * Use the OIDC `groups` claim of each authenticating user to get their group membership and only authorize users + which are members in at least one of the referenced groups. + * Access allowed: users in the `headscale_users` group + * Access denied: users without groups, users with other groups + + ```yaml hl_lines="5-7" + oidc: + issuer: "https://sso.example.com" + client_id: "headscale" + client_secret: "generated-secret" + scope: ["openid", "profile", "email", "groups"] + allowed_groups: + - "headscale_users" + ``` + +### Control email verification + +Headscale uses the `email` claim from the identity provider to synchronize the email address to its user profile. 
By +default, a user's email address is only synchronized when the identity provider reports the email address as verified +via the `email_verified: true` claim. + +Unverified emails may be allowed in case an identity provider does not send the `email_verified` claim or email +verification is not required. In that case, a user's email address is always synchronized to the user profile. + +```yaml hl_lines="5" +oidc: + issuer: "https://sso.example.com" + client_id: "headscale" + client_secret: "generated-secret" + email_verified_required: false +``` + +### Customize node expiration + +The node expiration is the amount of time a node is authenticated with OpenID Connect until it expires and needs to +reauthenticate. The default node expiration is 180 days. This can either be customized or set to the expiration from the +Access Token. + +=== "Customize node expiration" + + ```yaml hl_lines="5" + oidc: + issuer: "https://sso.example.com" + client_id: "headscale" + client_secret: "generated-secret" + expiry: 30d # Use 0 to disable node expiration + ``` + +=== "Use expiration from Access Token" + + Please keep in mind that the Access Token is typically a short-lived token that expires within a few minutes. You + will have to configure token expiration in your identity provider to avoid frequent re-authentication. + + + ```yaml hl_lines="5" + oidc: + issuer: "https://sso.example.com" + client_id: "headscale" + client_secret: "generated-secret" + use_expiry_from_token: true + ``` + +!!! tip "Expire a node and force re-authentication" + + A node can be expired immediately via: + ```console + headscale node expire -i <NODE_ID> + ``` + +### Reference a user in the policy + +You may refer to users in the Headscale policy via: + +- Email address +- Username +- Provider identifier (only available in the database or from your identity provider) + +!!! note "A user identifier in the policy must contain a single `@`" + + The Headscale policy requires a single `@` to reference a user. If the username or provider identifier doesn't + already contain a single `@`, it needs to be appended at the end. For example: the username `ssmith` has to be + written as `ssmith@` to be correctly identified as user within the policy. + +!!! warning "Email address or username might be updated by users" + + Many identity providers allow users to update their own profile. Depending on the identity provider and its + configuration, the values for username or email address might change over time. This might have unexpected + consequences for Headscale where a policy might no longer work or a user might obtain more access by hijacking an + existing username or email address. + +## Supported OIDC claims + +Headscale uses [the standard OIDC claims](https://openid.net/specs/openid-connect-core-1_0.html#StandardClaims) to +populate and update its local user profile on each login. OIDC claims are read from the ID Token and from the UserInfo +endpoint. 
+ +| Headscale profile | OIDC claim | Notes / examples | +| ------------------- | -------------------- | ------------------------------------------------------------------------------------------------- | +| email address | `email` | Only verified emails are synchronized, unless `email_verified_required: false` is configured | +| display name | `name` | eg: `Sam Smith` | +| username | `preferred_username` | Depends on identity provider, eg: `ssmith`, `ssmith@idp.example.com`, `\\example.com\ssmith` | +| profile picture | `picture` | URL to a profile picture or avatar | +| provider identifier | `iss`, `sub` | A stable and unique identifier for a user, typically a combination of `iss` and `sub` OIDC claims | +| | `groups` | [Only used to filter for allowed groups](#authorize-users-with-filters) | + +## Limitations + +- Support for OpenID Connect aims to be generic and vendor independent. It offers only limited support for quirks of + specific identity providers. +- OIDC groups cannot be used in ACLs. +- The username provided by the identity provider needs to adhere to this pattern: + - The username must be at least two characters long. + - It must only contain letters, digits, hyphens, dots, underscores, and up to a single `@`. + - The username must start with a letter. + +Please see the [GitHub label "OIDC"](https://github.com/juanfont/headscale/labels/OIDC) for OIDC related issues. + +## Identity provider specific configuration + +!!! warning "Third-party software and services" + + This section of the documentation is specific for third-party software and services. We recommend users read the + third-party documentation on how to configure and integrate an OIDC client. Please see the [Configuration + section](#configuration) for a description of Headscale's OIDC related configuration settings. + +Any identity provider with OpenID Connect support should "just work" with Headscale. The following identity providers +are known to work: + +- [Authelia](#authelia) +- [Authentik](#authentik) +- [Kanidm](#kanidm) +- [Keycloak](#keycloak) + +### Authelia + +Authelia is fully supported by Headscale. + +### Authentik + +- Authentik is fully supported by Headscale. +- [Headscale does not JSON Web Encryption](https://github.com/juanfont/headscale/issues/2446). Leave the field + `Encryption Key` in the providers section unset. + +### Google OAuth + +!!! warning "No username due to missing preferred_username" + + Google OAuth does not send the `preferred_username` claim when the scope `profile` is requested. The username in + Headscale will be blank/not set. + +In order to integrate Headscale with Google, you'll need to have a [Google Cloud +Console](https://console.cloud.google.com) account. + +Google OAuth has a [verification process](https://support.google.com/cloud/answer/9110914?hl=en) if you need to have +users authenticate who are outside of your domain. If you only need to authenticate users from your domain name (ie +`@example.com`), you don't need to go through the verification process. + +However if you don't have a domain, or need to add users outside of your domain, you can manually add emails via Google +Console. + +#### Steps + +1. Go to [Google Console](https://console.cloud.google.com) and login or create an account if you don't have one. +2. Create a project (if you don't already have one). +3. On the left hand menu, go to `APIs and services` -> `Credentials` +4. Click `Create Credentials` -> `OAuth client ID` +5. Under `Application Type`, choose `Web Application` +6. 
For `Name`, enter whatever you like +7. Under `Authorised redirect URIs`, add Headscale's OIDC callback URL: `https://headscale.example.com/oidc/callback` +8. Click `Save` at the bottom of the form +9. Take note of the `Client ID` and `Client secret`, you can also download it for reference if you need it. +10. [Configure Headscale following the "Basic configuration" steps](#basic-configuration). The issuer URL for Google + OAuth is: `https://accounts.google.com`. + +### Kanidm + +- Kanidm is fully supported by Headscale. +- Groups for the [allowed groups filter](#authorize-users-with-filters) need to be specified with their full SPN, for + example: `headscale_users@sso.example.com`. + +### Keycloak + +Keycloak is fully supported by Headscale. + +#### Additional configuration to use the allowed groups filter + +Keycloak has no built-in client scope for the OIDC `groups` claim. This extra configuration step is **only** needed if +you need to [authorize access based on group membership](#authorize-users-with-filters). + +- Create a new client scope `groups` for OpenID Connect: + - Configure a `Group Membership` mapper with name `groups` and the token claim name `groups`. + - Add the mapper to at least the UserInfo endpoint. +- Configure the new client scope for your Headscale client: + - Edit the Headscale client. + - Search for the client scope `group`. + - Add it with assigned type `Default`. +- [Configure the allowed groups in Headscale](#authorize-users-with-filters). How groups need to be specified depends on + Keycloak's `Full group path` option: + - `Full group path` is enabled: groups contain their full path, e.g. `/top/group1` + - `Full group path` is disabled: only the name of the group is used, e.g. `group1` + +### Microsoft Entra ID + +In order to integrate Headscale with Microsoft Entra ID, you'll need to provision an App Registration with the correct +scopes and redirect URI. + +[Configure Headscale following the "Basic configuration" steps](#basic-configuration). The issuer URL for Microsoft +Entra ID is: `https://login.microsoftonline.com/<tenant-UUID>/v2.0`. The following `extra_params` might be useful: + +- `domain_hint: example.com` to use your own domain +- `prompt: select_account` to force an account picker during login + +When using Microsoft Entra ID together with the [allowed groups filter](#authorize-users-with-filters), configure the +Headscale OIDC scope without the `groups` claim, for example: + +```yaml +oidc: + scope: ["openid", "profile", "email"] +``` + +Groups for the [allowed groups filter](#authorize-users-with-filters) need to be specified with their group ID(UUID) instead +of the group name. diff --git a/docs/ref/routes.md b/docs/ref/routes.md new file mode 100644 index 00000000..af8a3778 --- /dev/null +++ b/docs/ref/routes.md @@ -0,0 +1,307 @@ +# Routes + +Headscale supports route advertising and can be used to manage [subnet routers](https://tailscale.com/kb/1019/subnets) +and [exit nodes](https://tailscale.com/kb/1103/exit-nodes) for a tailnet. + +- [Subnet routers](#subnet-router) may be used to connect an existing network such as a virtual + private cloud or an on-premise network with your tailnet. Use a subnet router to access devices where Tailscale can't + be installed or to gradually rollout Tailscale. +- [Exit nodes](#exit-node) can be used to route all Internet traffic for another Tailscale + node. Use it to securely access the Internet on an untrusted Wi-Fi or to access online services that expect traffic + from a specific IP address. 
+ +## Subnet router + +The setup of a subnet router requires double opt-in, once from a subnet router and once on the control server to allow +its use within the tailnet. Optionally, use [`autoApprovers` to automatically approve routes from a subnet +router](#automatically-approve-routes-of-a-subnet-router). + +### Setup a subnet router + +#### Configure a node as subnet router + +Register a node and advertise the routes it should handle as comma separated list: + +```console +$ sudo tailscale up --login-server <YOUR_HEADSCALE_URL> --advertise-routes=10.0.0.0/8,192.168.0.0/24 +``` + +If the node is already registered, it can advertise new routes or update previously announced routes with: + +```console +$ sudo tailscale set --advertise-routes=10.0.0.0/8,192.168.0.0/24 +``` + +Finally, [enable IP forwarding](#enable-ip-forwarding) to route traffic. + +#### Enable the subnet router on the control server + +The routes of a tailnet can be displayed with the `headscale nodes list-routes` command. A subnet router with the +hostname `myrouter` announced the IPv4 networks `10.0.0.0/8` and `192.168.0.0/24`. Those need to be approved before they +can be used. + +```console +$ headscale nodes list-routes +ID | Hostname | Approved | Available | Serving (Primary) +1 | myrouter | | 10.0.0.0/8 | + | | | 192.168.0.0/24 | +``` + +Approve all desired routes of a subnet router by specifying them as comma separated list: + +```console +$ headscale nodes approve-routes --identifier 1 --routes 10.0.0.0/8,192.168.0.0/24 +Node updated +``` + +The node `myrouter` can now route the IPv4 networks `10.0.0.0/8` and `192.168.0.0/24` for the tailnet. + +```console +$ headscale nodes list-routes +ID | Hostname | Approved | Available | Serving (Primary) +1 | myrouter | 10.0.0.0/8 | 10.0.0.0/8 | 10.0.0.0/8 + | | 192.168.0.0/24 | 192.168.0.0/24 | 192.168.0.0/24 +``` + +#### Use the subnet router + +To accept routes advertised by a subnet router on a node: + +```console +$ sudo tailscale set --accept-routes +``` + +Please refer to the official [Tailscale +documentation](https://tailscale.com/kb/1019/subnets#use-your-subnet-routes-from-other-devices) for how to use a subnet +router on different operating systems. + +### Restrict the use of a subnet router with ACL + +The routes announced by subnet routers are available to the nodes in a tailnet. By default, without an ACL enabled, all +nodes can accept and use such routes. Configure an ACL to explicitly manage who can use routes. + +The ACL snippet below defines three hosts, a subnet router `router`, a regular node `node` and `service.example.net` as +internal service that can be reached via a route on the subnet router `router`. It allows the node `node` to access +`service.example.net` on port 80 and 443 which is reachable via the subnet router. Access to the subnet router itself is +denied. + +```json title="Access the routes of a subnet router without the subnet router itself" +{ + "hosts": { + // the router is not referenced but announces 192.168.0.0/24" + "router": "100.64.0.1/32", + "node": "100.64.0.2/32", + "service.example.net": "192.168.0.1/32" + }, + "acls": [ + { + "action": "accept", + "src": ["node"], + "dst": ["service.example.net:80,443"] + } + ] +} +``` + +### Automatically approve routes of a subnet router + +The initial setup of a subnet router usually requires manual approval of their announced routes on the control server +before they can be used by a node in a tailnet. 
Headscale supports the `autoApprovers` section of an ACL to automate the +approval of routes served with a subnet router. + +The ACL snippet below defines the tag `tag:router` owned by the user `alice`. This tag is used for `routes` in the +`autoApprovers` section. The IPv4 route `192.168.0.0/24` is automatically approved once announced by a subnet router +that advertises the tag `tag:router`. + +```json title="Subnet routers tagged with tag:router are automatically approved" +{ + "tagOwners": { + "tag:router": ["alice@"] + }, + "autoApprovers": { + "routes": { + "192.168.0.0/24": ["tag:router"] + } + }, + "acls": [ + // more rules + ] +} +``` + +Advertise the route `192.168.0.0/24` from a subnet router that also advertises the tag `tag:router` when joining the tailnet: + +```console +$ sudo tailscale up --login-server <YOUR_HEADSCALE_URL> --advertise-tags tag:router --advertise-routes 192.168.0.0/24 +``` + +Please see the [official Tailscale documentation](https://tailscale.com/kb/1337/acl-syntax#autoapprovers) for more +information on auto approvers. + +## Exit node + +The setup of an exit node requires double opt-in, once from an exit node and once on the control server to allow its use +within the tailnet. Optionally, use [`autoApprovers` to automatically approve an exit +node](#automatically-approve-an-exit-node-with-auto-approvers). + +### Setup an exit node + +#### Configure a node as exit node + +Register a node and make it advertise itself as an exit node: + +```console +$ sudo tailscale up --login-server <YOUR_HEADSCALE_URL> --advertise-exit-node +``` + +If the node is already registered, it can advertise exit capabilities like this: + +```console +$ sudo tailscale set --advertise-exit-node +``` + +Finally, [enable IP forwarding](#enable-ip-forwarding) to route traffic. + +#### Enable the exit node on the control server + +The routes of a tailnet can be displayed with the `headscale nodes list-routes` command. An exit node can be recognized +by its announced routes: `0.0.0.0/0` for IPv4 and `::/0` for IPv6. The exit node with the hostname `myexit` is already +available, but needs to be approved: + +```console +$ headscale nodes list-routes +ID | Hostname | Approved | Available | Serving (Primary) +1 | myexit | | 0.0.0.0/0 | + | | | ::/0 | +``` + +For exit nodes, it is sufficient to approve either the IPv4 or IPv6 route. The other will be approved automatically. + +```console +$ headscale nodes approve-routes --identifier 1 --routes 0.0.0.0/0 +Node updated +``` + +The node `myexit` is now approved as exit node for the tailnet: + +```console +$ headscale nodes list-routes +ID | Hostname | Approved | Available | Serving (Primary) +1 | myexit | 0.0.0.0/0 | 0.0.0.0/0 | 0.0.0.0/0 + | | ::/0 | ::/0 | ::/0 +``` + +#### Use the exit node + +The exit node can now be used on a node with: + +```console +$ sudo tailscale set --exit-node myexit +``` + +Please refer to the official [Tailscale documentation](https://tailscale.com/kb/1103/exit-nodes#use-the-exit-node) for +how to use an exit node on different operating systems. + +### Restrict the use of an exit node with ACL + +An exit node is offered to all nodes in a tailnet. By default, without an ACL enabled, all nodes in a tailnet can select +and use an exit node. Configure `autogroup:internet` in an ACL rule to restrict who can use _any_ of the available exit +nodes. 
+ +```json title="Example use of autogroup:internet" +{ + "acls": [ + { + "action": "accept", + "src": ["..."], + "dst": ["autogroup:internet:*"] + } + ] +} +``` + +### Restrict access to exit nodes per user or group + +A user can use _any_ of the available exit nodes with `autogroup:internet`. Alternatively, the ACL snippet below assigns +each user a specific exit node while hiding all other exit nodes. The user `alice` can only use exit node `exit1` while +user `bob` can only use exit node `exit2`. + +```json title="Assign each user a dedicated exit node" +{ + "hosts": { + "exit1": "100.64.0.1/32", + "exit2": "100.64.0.2/32" + }, + "acls": [ + { + "action": "accept", + "src": ["alice@"], + "dst": ["exit1:*"] + }, + { + "action": "accept", + "src": ["bob@"], + "dst": ["exit2:*"] + } + ] +} +``` + +!!! warning + + - The above implementation is Headscale specific and will likely be removed once [support for + `via`](https://github.com/juanfont/headscale/issues/2409) is available. + - Beware that a user can also connect to any port of the exit node itself. + +### Automatically approve an exit node with auto approvers + +The initial setup of an exit node usually requires manual approval on the control server before it can be used by a node +in a tailnet. Headscale supports the `autoApprovers` section of an ACL to automate the approval of a new exit node as +soon as it joins the tailnet. + +The ACL snippet below defines the tag `tag:exit` owned by the user `alice`. This tag is used for `exitNode` in the +`autoApprovers` section. A new exit node that advertises the tag `tag:exit` is automatically approved: + +```json title="Exit nodes tagged with tag:exit are automatically approved" +{ + "tagOwners": { + "tag:exit": ["alice@"] + }, + "autoApprovers": { + "exitNode": ["tag:exit"] + }, + "acls": [ + // more rules + ] +} +``` + +Advertise a node as exit node and also advertise the tag `tag:exit` when joining the tailnet: + +```console +$ sudo tailscale up --login-server <YOUR_HEADSCALE_URL> --advertise-tags tag:exit --advertise-exit-node +``` + +Please see the [official Tailscale documentation](https://tailscale.com/kb/1337/acl-syntax#autoapprovers) for more +information on auto approvers. + +## High availability + +Headscale has limited support for high availability routing. Multiple subnet routers with overlapping routes or multiple +exit nodes can be used to provide high availability for users. If one router node goes offline, another one can serve +the same routes to clients. Please see the official [Tailscale documentation on high +availability](https://tailscale.com/kb/1115/high-availability#subnet-router-high-availability) for details. + +!!! bug + + In certain situations it might take up to 16 minutes for Headscale to detect a node as offline. A failover node + might not be selected fast enough, if such a node is used as subnet router or exit node causing service + interruptions for clients. See [issue 2129](https://github.com/juanfont/headscale/issues/2129) for more information. + +## Troubleshooting + +### Enable IP forwarding + +A subnet router or exit node is routing traffic on behalf of other nodes and thus requires IP forwarding. Check the +official [Tailscale documentation](https://tailscale.com/kb/1019/subnets/?tab=linux#enable-ip-forwarding) for how to +enable IP forwarding. 
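+
+On most Linux distributions this boils down to enabling the corresponding sysctl settings, roughly as sketched below.
+The file name `/etc/sysctl.d/99-tailscale.conf` is just a convention, any sysctl drop-in file works:
+
+```console
+$ echo 'net.ipv4.ip_forward = 1' | sudo tee -a /etc/sysctl.d/99-tailscale.conf
+$ echo 'net.ipv6.conf.all.forwarding = 1' | sudo tee -a /etc/sysctl.d/99-tailscale.conf
+$ sudo sysctl -p /etc/sysctl.d/99-tailscale.conf
+```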
diff --git a/docs/ref/tls.md b/docs/ref/tls.md
new file mode 100644
index 00000000..527646b4
--- /dev/null
+++ b/docs/ref/tls.md
@@ -0,0 +1,78 @@
+# Running the service via TLS (optional)
+
+## Bring your own certificate
+
+Headscale can be configured to expose its web service via TLS. To configure the certificate and key file manually, set the `tls_cert_path` and `tls_key_path` configuration parameters. If the path is relative, it will be interpreted as relative to the directory the configuration file was read from.
+
+```yaml title="config.yaml"
+tls_cert_path: ""
+tls_key_path: ""
+```
+
+The certificate should contain the full chain, otherwise some clients, like the Tailscale Android client, will reject it.
+
+## Let's Encrypt / ACME
+
+To get a certificate automatically via [Let's Encrypt](https://letsencrypt.org/), set `tls_letsencrypt_hostname` to the desired certificate hostname. This name must resolve to the IP address(es) headscale is reachable on (i.e., it must correspond to the `server_url` configuration parameter). The certificate and Let's Encrypt account credentials will be stored in the directory configured in `tls_letsencrypt_cache_dir`. If the path is relative, it will be interpreted as relative to the directory the configuration file was read from.
+
+```yaml title="config.yaml"
+tls_letsencrypt_hostname: ""
+tls_letsencrypt_listen: ":http"
+tls_letsencrypt_cache_dir: ".cache"
+tls_letsencrypt_challenge_type: HTTP-01
+```
+
+### Challenge types
+
+Headscale only supports two values for `tls_letsencrypt_challenge_type`: `HTTP-01` (default) and `TLS-ALPN-01`.
+
+#### HTTP-01
+
+For `HTTP-01`, headscale must be reachable on port 80 for the Let's Encrypt automated validation, in addition to whatever port is configured in `listen_addr`. By default, headscale listens on port 80 on all local IPs for Let's Encrypt automated validation.
+
+If you need to change the IP address and/or port used by headscale for the Let's Encrypt validation process, set `tls_letsencrypt_listen` to the appropriate value. This can be handy if you are running headscale as a non-root user (or can't run `setcap`). Keep in mind, however, that Let's Encrypt will _only_ connect to port 80 for the validation callback, so if you change `tls_letsencrypt_listen` you will also need to configure something else (e.g. a firewall rule) to forward the traffic from port 80 to the ip:port combination specified in `tls_letsencrypt_listen`.
+
+#### TLS-ALPN-01
+
+For `TLS-ALPN-01`, headscale listens on the ip:port combination defined in `listen_addr`. Let's Encrypt will _only_ connect to port 443 for the validation callback, so if `listen_addr` is not set to port 443, something else (e.g. a firewall rule) will be required to forward the traffic from port 443 to the ip:port combination specified in `listen_addr`.
+
+### Technical description
+
+Headscale uses [autocert](https://pkg.go.dev/golang.org/x/crypto/acme/autocert), a Golang library providing [ACME protocol](https://en.wikipedia.org/wiki/Automatic_Certificate_Management_Environment) verification, to facilitate certificate renewals via [Let's Encrypt](https://letsencrypt.org/about/). Certificates will be renewed automatically, and the following can be expected:
+
+- Certificates provided by Let's Encrypt are valid for 3 months from the date of issue.
+- Renewals are only attempted by headscale when 30 days or less remains until certificate expiry.
+- Renewal attempts by autocert are triggered at a random interval of 30-60 minutes.
+- No log output is generated when renewals are skipped, or successful. + +#### Checking certificate expiry + +If you want to validate that certificate renewal completed successfully, this can be done either manually, or through external monitoring software. Two examples of doing this manually: + +1. Open the URL for your headscale server in your browser of choice, and manually inspecting the expiry date of the certificate you receive. +2. Or, check remotely from CLI using `openssl`: + +```console +$ openssl s_client -servername [hostname] -connect [hostname]:443 | openssl x509 -noout -dates +(...) +notBefore=Feb 8 09:48:26 2024 GMT +notAfter=May 8 09:48:25 2024 GMT +``` + +#### Log output from the autocert library + +As these log lines are from the autocert library, they are not strictly generated by headscale itself. + +```plaintext +acme/autocert: missing server name +``` + +Likely caused by an incoming connection that does not specify a hostname, for example a `curl` request directly against the IP of the server, or an unexpected hostname. + +```plaintext +acme/autocert: host "[foo]" not configured in HostWhitelist +``` + +Similarly to the above, this likely indicates an invalid incoming request for an incorrect hostname, commonly just the IP itself. + +The source code for autocert can be found [here](https://cs.opensource.google/go/x/crypto/+/refs/tags/v0.19.0:acme/autocert/autocert.go) diff --git a/docs/remote-cli.md b/docs/remote-cli.md deleted file mode 100644 index 96a6333a..00000000 --- a/docs/remote-cli.md +++ /dev/null @@ -1,100 +0,0 @@ -# Controlling `headscale` with remote CLI - -## Prerequisit - -- A workstation to run `headscale` (could be Linux, macOS, other supported platforms) -- A `headscale` server (version `0.13.0` or newer) -- Access to create API keys (local access to the `headscale` server) -- `headscale` _must_ be served over TLS/HTTPS - - Remote access does _not_ support unencrypted traffic. -- Port `50443` must be open in the firewall (or port overriden by `grpc_listen_addr` option) - -## Goal - -This documentation has the goal of showing a user how-to set control a `headscale` instance -from a remote machine with the `headscale` command line binary. - -## Create an API key - -We need to create an API key to authenticate our remote `headscale` when using it from our workstation. - -To create a API key, log into your `headscale` server and generate a key: - -```shell -headscale apikeys create --expiration 90d -``` - -Copy the output of the command and save it for later. Please note that you can not retrieve a key again, -if the key is lost, expire the old one, and create a new key. - -To list the keys currently assosicated with the server: - -```shell -headscale apikeys list -``` - -and to expire a key: - -```shell -headscale apikeys expire --prefix "<PREFIX>" -``` - -## Download and configure `headscale` - -1. Download the latest [`headscale` binary from GitHub's release page](https://github.com/juanfont/headscale/releases): - -2. Put the binary somewhere in your `PATH`, e.g. `/usr/local/bin/headscale` - -3. Make `headscale` executable: - -```shell -chmod +x /usr/local/bin/headscale -``` - -4. 
Configure the CLI through Environment Variables - -```shell -export HEADSCALE_CLI_ADDRESS="<HEADSCALE ADDRESS>:<PORT>" -export HEADSCALE_CLI_API_KEY="<API KEY FROM PREVIOUS STAGE>" -``` - -for example: - -```shell -export HEADSCALE_CLI_ADDRESS="headscale.example.com:50443" -export HEADSCALE_CLI_API_KEY="abcde12345" -``` - -This will tell the `headscale` binary to connect to a remote instance, instead of looking -for a local instance (which is what it does on the server). - -The API key is needed to make sure that your are allowed to access the server. The key is _not_ -needed when running directly on the server, as the connection is local. - -5. Test the connection - -Let us run the headscale command to verify that we can connect by listing our nodes: - -```shell -headscale nodes list -``` - -You should now be able to see a list of your nodes from your workstation, and you can -now control the `headscale` server from your workstation. - -## Behind a proxy - -It is possible to run the gRPC remote endpoint behind a reverse proxy, like Nginx, and have it run on the _same_ port as `headscale`. - -While this is _not a supported_ feature, an example on how this can be set up on -[NixOS is shown here](https://github.com/kradalby/dotfiles/blob/4489cdbb19cddfbfae82cd70448a38fde5a76711/machines/headscale.oracldn/headscale.nix#L61-L91). - -## Troubleshooting - -Checklist: - -- Make sure you have the _same_ `headscale` version on your server and workstation -- Make sure you use version `0.13.0` or newer. -- Verify that your TLS certificate is valid and trusted - - If you do not have access to a trusted certificate (e.g. from Let's Encrypt), add your self signed certificate to the trust store of your OS or - - Set `HEADSCALE_CLI_INSECURE` to 0 in your environement diff --git a/docs/requirements.txt b/docs/requirements.txt index 32bd08c1..65174cd4 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,5 +1,6 @@ -cairosvg~=2.7.1 -mkdocs-material~=9.4.14 -mkdocs-minify-plugin~=0.7.1 -pillow~=10.1.0 - +mike~=2.1 +mkdocs-include-markdown-plugin~=7.1 +mkdocs-macros-plugin~=1.3 +mkdocs-material[imaging]~=9.5 +mkdocs-minify-plugin~=0.7 +mkdocs-redirects~=1.2 diff --git a/docs/running-headscale-container.md b/docs/running-headscale-container.md deleted file mode 100644 index c2663581..00000000 --- a/docs/running-headscale-container.md +++ /dev/null @@ -1,171 +0,0 @@ -# Running headscale in a container - -!!! warning "Community documentation" - - This page is not actively maintained by the headscale authors and is - written by community members. It is _not_ verified by `headscale` developers. - - **It might be outdated and it might miss necessary steps**. - -## Goal - -This documentation has the goal of showing a user how-to set up and run `headscale` in a container. -[Docker](https://www.docker.com) is used as the reference container implementation, but there is no reason that it should -not work with alternatives like [Podman](https://podman.io). The Docker image can be found on Docker Hub [here](https://hub.docker.com/r/headscale/headscale). - -## Configure and run `headscale` - -1. Prepare a directory on the host Docker node in your directory of choice, used to hold `headscale` configuration and the [SQLite](https://www.sqlite.org/) database: - -```shell -mkdir -p ./headscale/config -cd ./headscale -``` - -2. Create an empty SQlite datebase in the headscale directory: - -```shell -touch ./config/db.sqlite -``` - -3. 
**(Strongly Recommended)** Download a copy of the [example configuration](https://github.com/juanfont/headscale/blob/main/config-example.yaml) from the headscale repository. - -Using wget: - -```shell -wget -O ./config/config.yaml https://raw.githubusercontent.com/juanfont/headscale/main/config-example.yaml -``` - -Using curl: - -```shell -curl https://raw.githubusercontent.com/juanfont/headscale/main/config-example.yaml -o ./config/config.yaml -``` - -**(Advanced)** If you would like to hand craft a config file **instead** of downloading the example config file, create a blank `headscale` configuration in the headscale directory to edit: - -```shell -touch ./config/config.yaml -``` - -Modify the config file to your preferences before launching Docker container. -Here are some settings that you likely want: - -```yaml -# Change to your hostname or host IP -server_url: http://your-host-name:8080 -# Listen to 0.0.0.0 so it's accessible outside the container -metrics_listen_addr: 0.0.0.0:9090 -# The default /var/lib/headscale path is not writable in the container -private_key_path: /etc/headscale/private.key -# The default /var/lib/headscale path is not writable in the container -noise: - private_key_path: /etc/headscale/noise_private.key -# The default /var/lib/headscale path is not writable in the container -db_type: sqlite3 -db_path: /etc/headscale/db.sqlite -``` - -4. Start the headscale server while working in the host headscale directory: - -```shell -docker run \ - --name headscale \ - --detach \ - --volume $(pwd)/config:/etc/headscale/ \ - --publish 127.0.0.1:8080:8080 \ - --publish 127.0.0.1:9090:9090 \ - headscale/headscale:<VERSION> \ - headscale serve - -``` - -Note: use `0.0.0.0:8080:8080` instead of `127.0.0.1:8080:8080` if you want to expose the container externally. - -This command will mount `config/` under `/etc/headscale`, forward port 8080 out of the container so the -`headscale` instance becomes available and then detach so headscale runs in the background. - -5. Verify `headscale` is running: - -Follow the container logs: - -```shell -docker logs --follow headscale -``` - -Verify running containers: - -```shell -docker ps -``` - -Verify `headscale` is available: - -```shell -curl http://127.0.0.1:9090/metrics -``` - -6. Create a user ([tailnet](https://tailscale.com/kb/1136/tailnet/)): - -```shell -docker exec headscale \ - headscale users create myfirstuser -``` - -### Register a machine (normal login) - -On a client machine, execute the `tailscale` login command: - -```shell -tailscale up --login-server YOUR_HEADSCALE_URL -``` - -To register a machine when running `headscale` in a container, take the headscale command and pass it to the container: - -```shell -docker exec headscale \ - headscale --user myfirstuser nodes register --key <YOU_+MACHINE_KEY> -``` - -### Register machine using a pre authenticated key - -Generate a key using the command line: - -```shell -docker exec headscale \ - headscale --user myfirstuser preauthkeys create --reusable --expiration 24h -``` - -This will return a pre-authenticated key that can be used to connect a node to `headscale` during the `tailscale` command: - -```shell -tailscale up --login-server <YOUR_HEADSCALE_URL> --authkey <YOUR_AUTH_KEY> -``` - -## Debugging headscale running in Docker - -The `headscale/headscale` Docker container is based on a "distroless" image that does not contain a shell or any other debug tools. 
If you need to debug your application running in the Docker container, you can use the `-debug` variant, for example `headscale/headscale:x.x.x-debug`. - -### Running the debug Docker container - -To run the debug Docker container, use the exact same commands as above, but replace `headscale/headscale:x.x.x` with `headscale/headscale:x.x.x-debug` (`x.x.x` is the version of headscale). The two containers are compatible with each other, so you can alternate between them. - -### Executing commands in the debug container - -The default command in the debug container is to run `headscale`, which is located at `/bin/headscale` inside the container. - -Additionally, the debug container includes a minimalist Busybox shell. - -To launch a shell in the container, use: - -``` -docker run -it headscale/headscale:x.x.x-debug sh -``` - -You can also execute commands directly, such as `ls /bin` in this example: - -``` -docker run headscale/headscale:x.x.x-debug ls /bin -``` - -Using `docker exec` allows you to run commands in an existing container. diff --git a/docs/running-headscale-linux-manual.md b/docs/running-headscale-linux-manual.md deleted file mode 100644 index 03619d7a..00000000 --- a/docs/running-headscale-linux-manual.md +++ /dev/null @@ -1,198 +0,0 @@ -# Running headscale on Linux - -## Note: Outdated and "advanced" - -This documentation is considered the "legacy"/advanced/manual version of the documentation, you most likely do not -want to use this documentation and rather look at the distro specific documentation (TODO LINK)[]. - -## Goal - -This documentation has the goal of showing a user how-to set up and run `headscale` on Linux. -In additional to the "get up and running section", there is an optional [SystemD section](#running-headscale-in-the-background-with-systemd) -describing how to make `headscale` run properly in a server environment. - -## Configure and run `headscale` - -1. Download the latest [`headscale` binary from GitHub's release page](https://github.com/juanfont/headscale/releases): - -```shell -wget --output-document=/usr/local/bin/headscale \ - https://github.com/juanfont/headscale/releases/download/v<HEADSCALE VERSION>/headscale_<HEADSCALE VERSION>_linux_<ARCH> -``` - -2. Make `headscale` executable: - -```shell -chmod +x /usr/local/bin/headscale -``` - -3. Prepare a directory to hold `headscale` configuration and the [SQLite](https://www.sqlite.org/) database: - -```shell -# Directory for configuration - -mkdir -p /etc/headscale - -# Directory for Database, and other variable data (like certificates) -mkdir -p /var/lib/headscale -# or if you create a headscale user: -useradd \ - --create-home \ - --home-dir /var/lib/headscale/ \ - --system \ - --user-group \ - --shell /usr/bin/nologin \ - headscale -``` - -4. Create an empty SQLite database: - -```shell -touch /var/lib/headscale/db.sqlite -``` - -5. Create a `headscale` configuration: - -```shell -touch /etc/headscale/config.yaml -``` - -**(Strongly Recommended)** Download a copy of the [example configuration][config-example.yaml](https://github.com/juanfont/headscale/blob/main/config-example.yaml) from the headscale repository. - -6. Start the headscale server: - -```shell -headscale serve -``` - -This command will start `headscale` in the current terminal session. - ---- - -To continue the tutorial, open a new terminal and let it run in the background. -Alternatively use terminal emulators like [tmux](https://github.com/tmux/tmux) or [screen](https://www.gnu.org/software/screen/). 
- -To run `headscale` in the background, please follow the steps in the [SystemD section](#running-headscale-in-the-background-with-systemd) before continuing. - -7. Verify `headscale` is running: - -Verify `headscale` is available: - -```shell -curl http://127.0.0.1:9090/metrics -``` - -8. Create a user ([tailnet](https://tailscale.com/kb/1136/tailnet/)): - -```shell -headscale users create myfirstuser -``` - -### Register a machine (normal login) - -On a client machine, execute the `tailscale` login command: - -```shell -tailscale up --login-server YOUR_HEADSCALE_URL -``` - -Register the machine: - -```shell -headscale --user myfirstuser nodes register --key <YOUR_MACHINE_KEY> -``` - -### Register machine using a pre authenticated key - -Generate a key using the command line: - -```shell -headscale --user myfirstuser preauthkeys create --reusable --expiration 24h -``` - -This will return a pre-authenticated key that can be used to connect a node to `headscale` during the `tailscale` command: - -```shell -tailscale up --login-server <YOUR_HEADSCALE_URL> --authkey <YOUR_AUTH_KEY> -``` - -## Running `headscale` in the background with SystemD - -:warning: **Deprecated**: This part is very outdated and you should use the [pre-packaged Headscale for this](./running-headscale-linux.md - -This section demonstrates how to run `headscale` as a service in the background with [SystemD](https://www.freedesktop.org/wiki/Software/systemd/). -This should work on most modern Linux distributions. - -1. Create a SystemD service configuration at `/etc/systemd/system/headscale.service` containing: - -```systemd -[Unit] -Description=headscale controller -After=syslog.target -After=network.target - -[Service] -Type=simple -User=headscale -Group=headscale -ExecStart=/usr/local/bin/headscale serve -Restart=always -RestartSec=5 - -# Optional security enhancements -NoNewPrivileges=yes -PrivateTmp=yes -ProtectSystem=strict -ProtectHome=yes -WorkingDirectory=/var/lib/headscale -ReadWritePaths=/var/lib/headscale /var/run/headscale -AmbientCapabilities=CAP_NET_BIND_SERVICE -RuntimeDirectory=headscale - -[Install] -WantedBy=multi-user.target -``` - -Note that when running as the headscale user ensure that, either you add your current user to the headscale group: - -```shell -usermod -a -G headscale current_user -``` - -or run all headscale commands as the headscale user: - -```shell -su - headscale -``` - -2. In `/etc/headscale/config.yaml`, override the default `headscale` unix socket with path that is writable by the `headscale` user or group: - -```yaml -unix_socket: /var/run/headscale/headscale.sock -``` - -3. Reload SystemD to load the new configuration file: - -```shell -systemctl daemon-reload -``` - -4. Enable and start the new `headscale` service: - -```shell -systemctl enable --now headscale -``` - -5. Verify the headscale service: - -```shell -systemctl status headscale -``` - -Verify `headscale` is available: - -```shell -curl http://127.0.0.1:9090/metrics -``` - -`headscale` will now run in the background and start at boot. diff --git a/docs/running-headscale-linux.md b/docs/running-headscale-linux.md deleted file mode 100644 index 66ccc3d3..00000000 --- a/docs/running-headscale-linux.md +++ /dev/null @@ -1,95 +0,0 @@ -# Running headscale on Linux - -## Requirements - -- Ubuntu 20.04 or newer, Debian 11 or newer. - -## Goal - -Get Headscale up and running. - -This includes running Headscale with SystemD. 
- -## Migrating from manual install - -If you are migrating from the old manual install, the best thing would be to remove -the files installed by following [the guide in reverse](./running-headscale-linux-manual.md). - -You should _not_ delete the database (`/var/headscale/db.sqlite`) and the -configuration (`/etc/headscale/config.yaml`). - -## Installation - -1. Download the lastest Headscale package for your platform (`.deb` for Ubuntu and Debian) from [Headscale's releases page](https://github.com/juanfont/headscale/releases): - -```shell -wget --output-document=headscale.deb \ - https://github.com/juanfont/headscale/releases/download/v<HEADSCALE VERSION>/headscale_<HEADSCALE VERSION>_linux_<ARCH>.deb -``` - -2. Install Headscale: - -```shell -sudo dpkg --install headscale.deb -``` - -3. Enable Headscale service, this will start Headscale at boot: - -```shell -sudo systemctl enable headscale -``` - -4. Configure Headscale by editing the configuration file: - -```shell -nano /etc/headscale/config.yaml -``` - -5. Start Headscale: - -```shell -sudo systemctl start headscale -``` - -6. Check that Headscale is running as intended: - -```shell -systemctl status headscale -``` - -## Using Headscale - -### Create a user - -```shell -headscale users create myfirstuser -``` - -### Register a machine (normal login) - -On a client machine, run the `tailscale` login command: - -```shell -tailscale up --login-server <YOUR_HEADSCALE_URL> -``` - -Register the machine: - -```shell -headscale --user myfirstuser nodes register --key <YOUR_MACHINE_KEY> -``` - -### Register machine using a pre authenticated key - -Generate a key using the command line: - -```shell -headscale --user myfirstuser preauthkeys create --reusable --expiration 24h -``` - -This will return a pre-authenticated key that is used to -connect a node to `headscale` during the `tailscale` command: - -```shell -tailscale up --login-server <YOUR_HEADSCALE_URL> --authkey <YOUR_AUTH_KEY> -``` diff --git a/docs/running-headscale-openbsd.md b/docs/running-headscale-openbsd.md deleted file mode 100644 index b76c9135..00000000 --- a/docs/running-headscale-openbsd.md +++ /dev/null @@ -1,209 +0,0 @@ -# Running headscale on OpenBSD - -!!! warning "Community documentation" - - This page is not actively maintained by the headscale authors and is - written by community members. It is _not_ verified by `headscale` developers. - - **It might be outdated and it might miss necessary steps**. - -## Goal - -This documentation has the goal of showing a user how-to install and run `headscale` on OpenBSD 7.1. -In additional to the "get up and running section", there is an optional [rc.d section](#running-headscale-in-the-background-with-rcd) -describing how to make `headscale` run properly in a server environment. - -## Install `headscale` - -1. Install from ports (Not Recommend) - - As of OpenBSD 7.2, there's a headscale in ports collection, however, it's severely outdated(v0.12.4). - You can install it via `pkg_add headscale`. - -2. Install from source on OpenBSD 7.2 - -```shell -# Install prerequistes -pkg_add go - -git clone https://github.com/juanfont/headscale.git - -cd headscale - -# optionally checkout a release -# option a. you can find offical relase at https://github.com/juanfont/headscale/releases/latest -# option b. 
get latest tag, this may be a beta release -latestTag=$(git describe --tags `git rev-list --tags --max-count=1`) - -git checkout $latestTag - -go build -ldflags="-s -w -X github.com/juanfont/headscale/cmd/headscale/cli.Version=$latestTag" github.com/juanfont/headscale - -# make it executable -chmod a+x headscale - -# copy it to /usr/local/sbin -cp headscale /usr/local/sbin -``` - -3. Install from source via cross compile - -```shell -# Install prerequistes -# 1. go v1.20+: headscale newer than 0.21 needs go 1.20+ to compile -# 2. gmake: Makefile in the headscale repo is written in GNU make syntax - -git clone https://github.com/juanfont/headscale.git - -cd headscale - -# optionally checkout a release -# option a. you can find offical relase at https://github.com/juanfont/headscale/releases/latest -# option b. get latest tag, this may be a beta release -latestTag=$(git describe --tags `git rev-list --tags --max-count=1`) - -git checkout $latestTag - -make build GOOS=openbsd - -# copy headscale to openbsd machine and put it in /usr/local/sbin -``` - -## Configure and run `headscale` - -1. Prepare a directory to hold `headscale` configuration and the [SQLite](https://www.sqlite.org/) database: - -```shell -# Directory for configuration - -mkdir -p /etc/headscale - -# Directory for Database, and other variable data (like certificates) -mkdir -p /var/lib/headscale -``` - -2. Create an empty SQLite database: - -```shell -touch /var/lib/headscale/db.sqlite -``` - -3. Create a `headscale` configuration: - -```shell -touch /etc/headscale/config.yaml -``` - -**(Strongly Recommended)** Download a copy of the [example configuration][config-example.yaml](https://github.com/juanfont/headscale/blob/main/config-example.yaml) from the headscale repository. - -4. Start the headscale server: - -```shell -headscale serve -``` - -This command will start `headscale` in the current terminal session. - ---- - -To continue the tutorial, open a new terminal and let it run in the background. -Alternatively use terminal emulators like [tmux](https://github.com/tmux/tmux). - -To run `headscale` in the background, please follow the steps in the [rc.d section](#running-headscale-in-the-background-with-rcd) before continuing. - -5. Verify `headscale` is running: - -Verify `headscale` is available: - -```shell -curl http://127.0.0.1:9090/metrics -``` - -6. Create a user ([tailnet](https://tailscale.com/kb/1136/tailnet/)): - -```shell -headscale users create myfirstuser -``` - -### Register a machine (normal login) - -On a client machine, execute the `tailscale` login command: - -```shell -tailscale up --login-server YOUR_HEADSCALE_URL -``` - -Register the machine: - -```shell -headscale --user myfirstuser nodes register --key <YOU_+MACHINE_KEY> -``` - -### Register machine using a pre authenticated key - -Generate a key using the command line: - -```shell -headscale --user myfirstuser preauthkeys create --reusable --expiration 24h -``` - -This will return a pre-authenticated key that can be used to connect a node to `headscale` during the `tailscale` command: - -```shell -tailscale up --login-server <YOUR_HEADSCALE_URL> --authkey <YOUR_AUTH_KEY> -``` - -## Running `headscale` in the background with rc.d - -This section demonstrates how to run `headscale` as a service in the background with [rc.d](https://man.openbsd.org/rc.d). - -1. 
Create a rc.d service at `/etc/rc.d/headscale` containing: - -```shell -#!/bin/ksh - -daemon="/usr/local/sbin/headscale" -daemon_logger="daemon.info" -daemon_user="root" -daemon_flags="serve" -daemon_timeout=60 - -. /etc/rc.d/rc.subr - -rc_bg=YES -rc_reload=NO - -rc_cmd $1 -``` - -2. `/etc/rc.d/headscale` needs execute permission: - -```shell -chmod a+x /etc/rc.d/headscale -``` - -3. Start `headscale` service: - -```shell -rcctl start headscale -``` - -4. Make `headscale` service start at boot: - -```shell -rcctl enable headscale -``` - -5. Verify the headscale service: - -```shell -rcctl check headscale -``` - -Verify `headscale` is available: - -```shell -curl http://127.0.0.1:9090/metrics -``` - -`headscale` will now run in the background and start at boot. diff --git a/docs/setup/install/community.md b/docs/setup/install/community.md new file mode 100644 index 00000000..f67725cd --- /dev/null +++ b/docs/setup/install/community.md @@ -0,0 +1,55 @@ +# Community packages + +Several Linux distributions and community members provide packages for headscale. Those packages may be used instead of +the [official releases](./official.md) provided by the headscale maintainers. Such packages offer improved integration +for their targeted operating system and usually: + +- setup a dedicated local user account to run headscale +- provide a default configuration +- install headscale as system service + +!!! warning "Community packages might be outdated" + + The packages mentioned on this page might be outdated or unmaintained. Use the [official releases](./official.md) to + get the current stable version or to test pre-releases. + + [![Packaging status](https://repology.org/badge/vertical-allrepos/headscale.svg)](https://repology.org/project/headscale/versions) + +## Arch Linux + +Arch Linux offers a package for headscale, install via: + +```shell +pacman -S headscale +``` + +The [AUR package `headscale-git`](https://aur.archlinux.org/packages/headscale-git) can be used to build the current +development version. + +## Fedora, RHEL, CentOS + +A third-party repository for various RPM based distributions is available at: +<https://copr.fedorainfracloud.org/coprs/jonathanspw/headscale/>. The site provides detailed setup and installation +instructions. + +## Nix, NixOS + +A Nix package is available as: `headscale`. See the [NixOS package site for installation +details](https://search.nixos.org/packages?show=headscale). + +## Gentoo + +```shell +emerge --ask net-vpn/headscale +``` + +Gentoo specific documentation is available [here](https://wiki.gentoo.org/wiki/User:Maffblaster/Drafts/Headscale). + +## OpenBSD + +Headscale is available in ports. The port installs headscale as system service with `rc.d` and provides usage +instructions upon installation. + +```shell +pkg_add headscale +``` diff --git a/docs/setup/install/container.md b/docs/setup/install/container.md new file mode 100644 index 00000000..dca22537 --- /dev/null +++ b/docs/setup/install/container.md @@ -0,0 +1,125 @@ +# Running headscale in a container + +!!! warning "Community documentation" + + This page is not actively maintained by the headscale authors and is + written by community members. It is _not_ verified by headscale developers. + + **It might be outdated and it might miss necessary steps**. + +This documentation has the goal of showing a user how-to set up and run headscale in a container. A container runtime +such as [Docker](https://www.docker.com) or [Podman](https://podman.io) is required. 
The container image can be found on +[Docker Hub](https://hub.docker.com/r/headscale/headscale) and [GitHub Container +Registry](https://github.com/juanfont/headscale/pkgs/container/headscale). The container image URLs are: + +- [Docker Hub](https://hub.docker.com/r/headscale/headscale): `docker.io/headscale/headscale:<VERSION>` +- [GitHub Container Registry](https://github.com/juanfont/headscale/pkgs/container/headscale): + `ghcr.io/juanfont/headscale:<VERSION>` + +## Configure and run headscale + +1. Create a directory on the container host to store headscale's [configuration](../../ref/configuration.md) and the [SQLite](https://www.sqlite.org/) database: + + ```shell + mkdir -p ./headscale/{config,lib} + cd ./headscale + ``` + +1. Download the example configuration for your chosen version and save it as: `$(pwd)/config/config.yaml`. Adjust the + configuration to suit your local environment. See [Configuration](../../ref/configuration.md) for details. + +1. Start headscale from within the previously created `./headscale` directory: + + ```shell + docker run \ + --name headscale \ + --detach \ + --read-only \ + --tmpfs /var/run/headscale \ + --volume "$(pwd)/config:/etc/headscale:ro" \ + --volume "$(pwd)/lib:/var/lib/headscale" \ + --publish 127.0.0.1:8080:8080 \ + --publish 127.0.0.1:9090:9090 \ + --health-cmd "CMD headscale health" \ + docker.io/headscale/headscale:<VERSION> \ + serve + ``` + + Note: use `0.0.0.0:8080:8080` instead of `127.0.0.1:8080:8080` if you want to expose the container externally. + + This command mounts the local directories inside the container, forwards port 8080 and 9090 out of the container so + the headscale instance becomes available and then detaches so headscale runs in the background. + + A similar configuration for `docker-compose`: + + ```yaml title="docker-compose.yaml" + services: + headscale: + image: docker.io/headscale/headscale:<VERSION> + restart: unless-stopped + container_name: headscale + read_only: true + tmpfs: + - /var/run/headscale + ports: + - "127.0.0.1:8080:8080" + - "127.0.0.1:9090:9090" + volumes: + # Please set <HEADSCALE_PATH> to the absolute path + # of the previously created headscale directory. + - <HEADSCALE_PATH>/config:/etc/headscale:ro + - <HEADSCALE_PATH>/lib:/var/lib/headscale + command: serve + healthcheck: + test: ["CMD", "headscale", "health"] + ``` + +1. Verify headscale is running: + + Follow the container logs: + + ```shell + docker logs --follow headscale + ``` + + Verify running containers: + + ```shell + docker ps + ``` + + Verify headscale is available: + + ```shell + curl http://127.0.0.1:8080/health + ``` + +Continue on the [getting started page](../../usage/getting-started.md) to register your first machine. + +## Debugging headscale running in Docker + +The Headscale container image is based on a "distroless" image that does not contain a shell or any other debug tools. If you need to debug headscale running in the Docker container, you can use the `-debug` variant, for example `docker.io/headscale/headscale:x.x.x-debug`. + +### Running the debug Docker container + +To run the debug Docker container, use the exact same commands as above, but replace `docker.io/headscale/headscale:x.x.x` with `docker.io/headscale/headscale:x.x.x-debug` (`x.x.x` is the version of headscale). The two containers are compatible with each other, so you can alternate between them. 
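+
+For example, a minimal sketch assuming the same `./headscale` working directory, mounts, ports and `<VERSION>` placeholder used in the `docker run` command above: only the image tag changes. Stop and remove the regular container first if it is still running.
+
+```shell
+# Identical to the regular invocation above, except for the -debug suffix on the image tag.
+docker run \
+  --name headscale \
+  --detach \
+  --read-only \
+  --tmpfs /var/run/headscale \
+  --volume "$(pwd)/config:/etc/headscale:ro" \
+  --volume "$(pwd)/lib:/var/lib/headscale" \
+  --publish 127.0.0.1:8080:8080 \
+  --publish 127.0.0.1:9090:9090 \
+  docker.io/headscale/headscale:<VERSION>-debug \
+  serve
+```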
+ +### Executing commands in the debug container + +The default command in the debug container is to run `headscale`, which is located at `/ko-app/headscale` inside the container. + +Additionally, the debug container includes a minimalist Busybox shell. + +To launch a shell in the container, use: + +```shell +docker run -it docker.io/headscale/headscale:x.x.x-debug sh +``` + +You can also execute commands directly, such as `ls /ko-app` in this example: + +```shell +docker run docker.io/headscale/headscale:x.x.x-debug ls /ko-app +``` + +Using `docker exec -it` allows you to run commands in an existing container. diff --git a/docs/setup/install/official.md b/docs/setup/install/official.md new file mode 100644 index 00000000..56fd0c9c --- /dev/null +++ b/docs/setup/install/official.md @@ -0,0 +1,121 @@ +# Official releases + +Official releases for headscale are available as binaries for various platforms and DEB packages for Debian and Ubuntu. +Both are available on the [GitHub releases page](https://github.com/juanfont/headscale/releases). + +## Using packages for Debian/Ubuntu (recommended) + +It is recommended to use our DEB packages to install headscale on a Debian based system as those packages configure a +local user to run headscale, provide a default configuration and ship with a systemd service file. Supported +distributions are Ubuntu 22.04 or newer, Debian 12 or newer. + +1. Download the [latest headscale package](https://github.com/juanfont/headscale/releases/latest) for your platform (`.deb` for Ubuntu and Debian). + + ```shell + HEADSCALE_VERSION="" # See above URL for latest version, e.g. "X.Y.Z" (NOTE: do not add the "v" prefix!) + HEADSCALE_ARCH="" # Your system architecture, e.g. "amd64" + wget --output-document=headscale.deb \ + "https://github.com/juanfont/headscale/releases/download/v${HEADSCALE_VERSION}/headscale_${HEADSCALE_VERSION}_linux_${HEADSCALE_ARCH}.deb" + ``` + +1. Install headscale: + + ```shell + sudo apt install ./headscale.deb + ``` + +1. [Configure headscale by editing the configuration file](../../ref/configuration.md): + + ```shell + sudo nano /etc/headscale/config.yaml + ``` + +1. Enable and start the headscale service: + + ```shell + sudo systemctl enable --now headscale + ``` + +1. Verify that headscale is running as intended: + + ```shell + sudo systemctl status headscale + ``` + +Continue on the [getting started page](../../usage/getting-started.md) to register your first machine. + +## Using standalone binaries (advanced) + +!!! warning "Advanced" + + This installation method is considered advanced as one needs to take care of the local user and the systemd + service themselves. If possible, use the [DEB packages](#using-packages-for-debianubuntu-recommended) or a + [community package](./community.md) instead. + +This section describes the installation of headscale according to the [Requirements and +assumptions](../requirements.md#assumptions). Headscale is run by a dedicated local user and the service itself is +managed by systemd. + +1. Download the latest [`headscale` binary from GitHub's release page](https://github.com/juanfont/headscale/releases): + + ```shell + sudo wget --output-document=/usr/bin/headscale \ + https://github.com/juanfont/headscale/releases/download/v<HEADSCALE VERSION>/headscale_<HEADSCALE VERSION>_linux_<ARCH> + ``` + +1. Make `headscale` executable: + + ```shell + sudo chmod +x /usr/bin/headscale + ``` + +1. 
Add a dedicated local user to run headscale: + + ```shell + sudo useradd \ + --create-home \ + --home-dir /var/lib/headscale/ \ + --system \ + --user-group \ + --shell /usr/sbin/nologin \ + headscale + ``` + +1. Download the example configuration for your chosen version and save it as: `/etc/headscale/config.yaml`. Adjust the + configuration to suit your local environment. See [Configuration](../../ref/configuration.md) for details. + + ```shell + sudo mkdir -p /etc/headscale + sudo nano /etc/headscale/config.yaml + ``` + +1. Copy [headscale's systemd service file](https://github.com/juanfont/headscale/blob/main/packaging/systemd/headscale.service) + to `/etc/systemd/system/headscale.service` and adjust it to suit your local setup. The following parameters likely need + to be modified: `ExecStart`, `WorkingDirectory`, `ReadWritePaths`. + +1. In `/etc/headscale/config.yaml`, override the default `headscale` unix socket with a path that is writable by the + `headscale` user or group: + + ```yaml title="config.yaml" + unix_socket: /var/run/headscale/headscale.sock + ``` + +1. Reload systemd to load the new configuration file: + + ```shell + systemctl daemon-reload + ``` + +1. Enable and start the new headscale service: + + ```shell + systemctl enable --now headscale + ``` + +1. Verify that headscale is running as intended: + + ```shell + systemctl status headscale + ``` + +Continue on the [getting started page](../../usage/getting-started.md) to register your first machine. diff --git a/docs/setup/install/source.md b/docs/setup/install/source.md new file mode 100644 index 00000000..b46931af --- /dev/null +++ b/docs/setup/install/source.md @@ -0,0 +1,63 @@ +# Build from source + +!!! warning "Community documentation" + + This page is not actively maintained by the headscale authors and is + written by community members. It is _not_ verified by headscale developers. + + **It might be outdated and it might miss necessary steps**. + +Headscale can be built from source using the latest version of [Go](https://golang.org) and [Buf](https://buf.build) +(Protobuf generator). See the [Contributing section in the GitHub +README](https://github.com/juanfont/headscale#contributing) for more information. + +## OpenBSD + +### Install from source + +```shell +# Install prerequisites +pkg_add go git + +git clone https://github.com/juanfont/headscale.git + +cd headscale + +# optionally checkout a release +# option a. you can find official release at https://github.com/juanfont/headscale/releases/latest +# option b. get latest tag, this may be a beta release +latestTag=$(git describe --tags `git rev-list --tags --max-count=1`) + +git checkout $latestTag + +go build -ldflags="-s -w -X github.com/juanfont/headscale/hscontrol/types.Version=$latestTag -X github.com/juanfont/headscale/hscontrol/types.GitCommitHash=HASH" github.com/juanfont/headscale + +# make it executable +chmod a+x headscale + +# copy it to /usr/local/sbin +cp headscale /usr/local/sbin +``` + +### Install from source via cross compile + +```shell +# Install prerequisites +# 1. go v1.20+: headscale newer than 0.21 needs go 1.20+ to compile +# 2. gmake: Makefile in the headscale repo is written in GNU make syntax + +git clone https://github.com/juanfont/headscale.git + +cd headscale + +# optionally checkout a release +# option a. you can find official release at https://github.com/juanfont/headscale/releases/latest +# option b. 
get latest tag, this may be a beta release +latestTag=$(git describe --tags `git rev-list --tags --max-count=1`) + +git checkout $latestTag + +make build GOOS=openbsd + +# copy headscale to openbsd machine and put it in /usr/local/sbin +``` diff --git a/docs/setup/requirements.md b/docs/setup/requirements.md new file mode 100644 index 00000000..ae1ea660 --- /dev/null +++ b/docs/setup/requirements.md @@ -0,0 +1,52 @@ +# Requirements + +Headscale should just work as long as the following requirements are met: + +- A server with a public IP address for headscale. A dual-stack setup with a public IPv4 and a public IPv6 address is + recommended. +- Headscale is served via HTTPS on port 443[^1] and [may use additional ports](#ports-in-use). +- A reasonably modern Linux or BSD based operating system. +- A dedicated local user account to run headscale. +- A little bit of command line knowledge to configure and operate headscale. + +## Ports in use + +The ports in use vary with the intended scenario and enabled features. Some of the listed ports may be changed via the +[configuration file](../ref/configuration.md) but we recommend to stick with the default values. + +- tcp/80 + - Expose publicly: yes + - HTTP, used by Let's Encrypt to verify ownership via the HTTP-01 challenge. + - Only required if the built-in Let's Enrypt client with the HTTP-01 challenge is used. See [TLS](../ref/tls.md) for + details. +- tcp/443 + - Expose publicly: yes + - HTTPS, required to make Headscale available to Tailscale clients[^1] + - Required if the [embedded DERP server](../ref/derp.md) is enabled +- udp/3478 + - Expose publicly: yes + - STUN, required if the [embedded DERP server](../ref/derp.md) is enabled +- tcp/50443 + - Expose publicly: yes + - Only required if the gRPC interface is used to [remote-control Headscale](../ref/api.md#grpc). +- tcp/9090 + - Expose publicly: no + - [Metrics and debug endpoint](../ref/debug.md#metrics-and-debug-endpoint) + +## Assumptions + +The headscale documentation and the provided examples are written with a few assumptions in mind: + +- Headscale is running as system service via a dedicated local user `headscale`. +- The [configuration](../ref/configuration.md) is loaded from `/etc/headscale/config.yaml`. +- SQLite is used as database. +- The data directory for headscale (used for private keys, ACLs, SQLite database, …) is located in `/var/lib/headscale`. +- URLs and values that need to be replaced by the user are either denoted as `<VALUE_TO_CHANGE>` or use placeholder + values such as `headscale.example.com`. + +Please adjust to your local environment accordingly. + +[^1]: + The Tailscale client assumes HTTPS on port 443 in certain situations. Serving headscale either via HTTP or via HTTPS + on a port other than 443 is possible but sticking with HTTPS on port 443 is strongly recommended for production + setups. See [issue 2164](https://github.com/juanfont/headscale/issues/2164) for more information. diff --git a/docs/setup/upgrade.md b/docs/setup/upgrade.md new file mode 100644 index 00000000..9c72eb4f --- /dev/null +++ b/docs/setup/upgrade.md @@ -0,0 +1,10 @@ +# Upgrade an existing installation + +Update an existing headscale installation to a new version: + +- Read the announcement on the [GitHub releases](https://github.com/juanfont/headscale/releases) page for the new + version. It lists the changes of the release along with possible breaking changes. 
+- **Create a backup of your database.** +- Update headscale to the new version, preferably by following the same installation method. +- Compare and update the [configuration](../ref/configuration.md) file. +- Restart headscale. diff --git a/docs/tls.md b/docs/tls.md deleted file mode 100644 index 557cdf01..00000000 --- a/docs/tls.md +++ /dev/null @@ -1,31 +0,0 @@ -# Running the service via TLS (optional) - -## Let's Encrypt / ACME - -To get a certificate automatically via [Let's Encrypt](https://letsencrypt.org/), set `tls_letsencrypt_hostname` to the desired certificate hostname. This name must resolve to the IP address(es) headscale is reachable on (i.e., it must correspond to the `server_url` configuration parameter). The certificate and Let's Encrypt account credentials will be stored in the directory configured in `tls_letsencrypt_cache_dir`. If the path is relative, it will be interpreted as relative to the directory the configuration file was read from. The certificate will automatically be renewed as needed. - -```yaml -tls_letsencrypt_hostname: "" -tls_letsencrypt_listen: ":http" -tls_letsencrypt_cache_dir: ".cache" -tls_letsencrypt_challenge_type: HTTP-01 -``` - -### Challenge type HTTP-01 - -The default challenge type `HTTP-01` requires that headscale is reachable on port 80 for the Let's Encrypt automated validation, in addition to whatever port is configured in `listen_addr`. By default, headscale listens on port 80 on all local IPs for Let's Encrypt automated validation. - -If you need to change the ip and/or port used by headscale for the Let's Encrypt validation process, set `tls_letsencrypt_listen` to the appropriate value. This can be handy if you are running headscale as a non-root user (or can't run `setcap`). Keep in mind, however, that Let's Encrypt will _only_ connect to port 80 for the validation callback, so if you change `tls_letsencrypt_listen` you will also need to configure something else (e.g. a firewall rule) to forward the traffic from port 80 to the ip:port combination specified in `tls_letsencrypt_listen`. - -### Challenge type TLS-ALPN-01 - -Alternatively, `tls_letsencrypt_challenge_type` can be set to `TLS-ALPN-01`. In this configuration, headscale listens on the ip:port combination defined in `listen_addr`. Let's Encrypt will _only_ connect to port 443 for the validation callback, so if `listen_addr` is not set to port 443, something else (e.g. a firewall rule) will be required to forward the traffic from port 443 to the ip:port combination specified in `listen_addr`. - -## Bring your own certificate - -headscale can also be configured to expose its web service via TLS. To configure the certificate and key file manually, set the `tls_cert_path` and `tls_cert_path` configuration parameters. If the path is relative, it will be interpreted as relative to the directory the configuration file was read from. - -```yaml -tls_cert_path: "" -tls_key_path: "" -``` diff --git a/docs/usage/connect/android.md b/docs/usage/connect/android.md new file mode 100644 index 00000000..b6fa3a66 --- /dev/null +++ b/docs/usage/connect/android.md @@ -0,0 +1,28 @@ +# Connecting an Android client + +This documentation has the goal of showing how a user can use the official Android [Tailscale](https://tailscale.com) client with headscale. + +## Installation + +Install the official Tailscale Android client from the [Google Play Store](https://play.google.com/store/apps/details?id=com.tailscale.ipn) or [F-Droid](https://f-droid.org/packages/com.tailscale.ipn/). 
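+
+If you plan to use the auth key flow described below, a key must first be created on the headscale server. A minimal sketch, assuming a user already exists and using `<USER>` as a placeholder (see `headscale preauthkeys --help` and the [getting started page](../getting-started.md#using-a-preauthkey) for details):
+
+```shell
+# Run on the headscale server, not on the Android device.
+# Creates a key that can be reused and stays valid for 24 hours.
+headscale preauthkeys create --user <USER> --reusable --expiration 24h
+```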
+ +## Connect via normal, interactive login + +- Open the app and select the settings menu in the upper-right corner +- Tap on `Accounts` +- In the kebab menu icon (three dots) in the upper-right corner select `Use an alternate server` +- Enter your server URL (e.g `https://headscale.example.com`) and follow the instructions +- The client connects automatically as soon as the node registration is complete on headscale. Until then, nothing is + visible in the server logs. + +## Connect using a preauthkey + +- Open the app and select the settings menu in the upper-right corner +- Tap on `Accounts` +- In the kebab menu icon (three dots) in the upper-right corner select `Use an alternate server` +- Enter your server URL (e.g `https://headscale.example.com`). If login prompts open, close it and continue +- Open the settings menu in the upper-right corner +- Tap on `Accounts` +- In the kebab menu icon (three dots) in the upper-right corner select `Use an auth key` +- Enter your [preauthkey generated from headscale](../getting-started.md#using-a-preauthkey) +- If needed, tap `Log in` on the main screen. You should now be connected to your headscale. diff --git a/docs/usage/connect/apple.md b/docs/usage/connect/apple.md new file mode 100644 index 00000000..d3a96688 --- /dev/null +++ b/docs/usage/connect/apple.md @@ -0,0 +1,65 @@ +# Connecting an Apple client + +This documentation has the goal of showing how a user can use the official iOS and macOS [Tailscale](https://tailscale.com) clients with headscale. + +!!! info "Instructions on your headscale instance" + + An endpoint with information on how to connect your Apple device + is also available at `/apple` on your running instance. + +## iOS + +### Installation + +Install the official Tailscale iOS client from the [App Store](https://apps.apple.com/app/tailscale/id1470499037). + +### Configuring the headscale URL + +- Open the Tailscale app +- Click the account icon in the top-right corner and select `Log in…`. +- Tap the top-right options menu button and select `Use custom coordination server`. +- Enter your instance url (e.g `https://headscale.example.com`) +- Enter your credentials and log in. Headscale should now be working on your iOS device. + +## macOS + +### Installation + +Choose one of the available [Tailscale clients for macOS](https://tailscale.com/kb/1065/macos-variants) and install it. + +### Configuring the headscale URL + +#### Command line + +Use Tailscale's login command to connect with your headscale instance (e.g `https://headscale.example.com`): + +``` +tailscale login --login-server <YOUR_HEADSCALE_URL> +``` + +#### GUI + +- Option + Click the Tailscale icon in the menu and hover over the Debug menu +- Under `Custom Login Server`, select `Add Account...` +- Enter the URL of your headscale instance (e.g `https://headscale.example.com`) and press `Add Account` +- Follow the login procedure in the browser + +## tvOS + +### Installation + +Install the official Tailscale tvOS client from the [App Store](https://apps.apple.com/app/tailscale/id1470499037). + +!!! danger + + **Don't** open the Tailscale App after installation! 
+ +### Configuring the headscale URL + +- Open Settings (the Apple tvOS settings) > Apps > Tailscale +- Under `ALTERNATE COORDINATION SERVER URL`, select `URL` +- Enter the URL of your headscale instance (e.g `https://headscale.example.com`) and press `OK` +- Return to the tvOS Home screen +- Open Tailscale +- Click the button `Install VPN configuration` and confirm the popup that appears by clicking the `Allow` button +- Scan the QR code and follow the login procedure diff --git a/docs/usage/connect/windows.md b/docs/usage/connect/windows.md new file mode 100644 index 00000000..2d073981 --- /dev/null +++ b/docs/usage/connect/windows.md @@ -0,0 +1,59 @@ +# Connecting a Windows client + +This documentation has the goal of showing how a user can use the official Windows [Tailscale](https://tailscale.com) client with headscale. + +!!! info "Instructions on your headscale instance" + + An endpoint with information on how to connect your Windows device + is also available at `/windows` on your running instance. + +## Installation + +Download the [Official Windows Client](https://tailscale.com/download/windows) and install it. + +## Configuring the headscale URL + +Open a Command Prompt or PowerShell and use Tailscale's login command to connect with your headscale instance (e.g +`https://headscale.example.com`): + +``` +tailscale login --login-server <YOUR_HEADSCALE_URL> +``` + +Follow the instructions in the opened browser window to finish the configuration. + +## Troubleshooting + +### Unattended mode + +By default, Tailscale's Windows client only runs while a user is logged in. If you want to keep Tailscale running +all the time, please enable "Unattended mode": + +- Click on the Tailscale tray icon and select `Preferences` +- Enable `Run unattended` +- Confirm the "Unattended mode" message + +See also [Keep Tailscale running when I'm not logged in to my computer](https://tailscale.com/kb/1088/run-unattended) + +### Failing node registration + +If you are seeing repeated messages like: + +``` +[GIN] 2022/02/10 - 16:39:34 | 200 | 1.105306ms | 127.0.0.1 | POST "/machine/redacted" +``` + +in your headscale output, turn on `DEBUG` logging and look for: + +``` +2022-02-11T00:59:29Z DBG Machine registration has expired. Sending a authurl to register machine=redacted +``` + +This typically means that the client's node registration has expired and it needs to log in again. + +To reset and try again, it is important to do the following: + +1. Shut down the Tailscale service (or the client running in the tray) +2. Delete the Tailscale application data folder, located at `C:\Users\<USERNAME>\AppData\Local\Tailscale` +3. Ensure the Windows node is deleted from headscale (to ensure fresh setup) +4. Start Tailscale on the Windows machine and retry the login. diff --git a/docs/usage/getting-started.md b/docs/usage/getting-started.md new file mode 100644 index 00000000..a69d89a3 --- /dev/null +++ b/docs/usage/getting-started.md @@ -0,0 +1,152 @@ +# Getting started + +This page helps you get started with headscale and provides a few usage examples for the headscale command line tool +`headscale`. + +!!! note "Prerequisites" + + * Headscale is installed and running as a system service. Read the [setup section](../setup/requirements.md) for + installation instructions. + * The configuration file exists and is adjusted to suit your environment, see + [Configuration](../ref/configuration.md) for details. + * Headscale is reachable from the Internet. 
Verify this by visiting the health endpoint: + https://headscale.example.com/health + * The Tailscale client is installed, see [Client and operating system support](../about/clients.md) for more + information. + +## Getting help + +The `headscale` command line tool provides built-in help. To show available commands along with their arguments and +options, run: + +=== "Native" + + ```shell + # Show help + headscale help + + # Show help for a specific command + headscale <COMMAND> --help + ``` + +=== "Container" + + ```shell + # Show help + docker exec -it headscale \ + headscale help + + # Show help for a specific command + docker exec -it headscale \ + headscale <COMMAND> --help + ``` + +!!! note "Manage headscale from another local user" + + By default only the user `headscale` or `root` will have the necessary permissions to access the unix socket + (`/var/run/headscale/headscale.sock`) that is used to communicate with the service. In order to be able to + communicate with the headscale service you have to make sure the unix socket is accessible by the user that runs + the commands. In general you can achieve this by any of the following methods: + + * using `sudo` + * run the commands as user `headscale` + * add your user to the `headscale` group + + To verify you can run the following command using your preferred method: + + ```shell + headscale users list + ``` + +## Manage headscale users + +In headscale, a node (also known as machine or device) is always assigned to a +headscale user. Such a headscale user may have many nodes assigned to them and +can be managed with the `headscale users` command. Invoke the built-in help for +more information: `headscale users --help`. + +### Create a headscale user + +=== "Native" + + ```shell + headscale users create <USER> + ``` + +=== "Container" + + ```shell + docker exec -it headscale \ + headscale users create <USER> + ``` + +### List existing headscale users + +=== "Native" + + ```shell + headscale users list + ``` + +=== "Container" + + ```shell + docker exec -it headscale \ + headscale users list + ``` + +## Register a node + +One has to register a node first to use headscale as coordination with Tailscale. The following examples work for the +Tailscale client on Linux/BSD operating systems. Alternatively, follow the instructions to connect +[Android](connect/android.md), [Apple](connect/apple.md) or [Windows](connect/windows.md) devices. + +### Normal, interactive login + +On a client machine, run the `tailscale up` command and provide the FQDN of your headscale instance as argument: + +```shell +tailscale up --login-server <YOUR_HEADSCALE_URL> +``` + +Usually, a browser window with further instructions is opened and contains the value for `<YOUR_MACHINE_KEY>`. Approve +and register the node on your headscale server: + +=== "Native" + + ```shell + headscale nodes register --user <USER> --key <YOUR_MACHINE_KEY> + ``` + +=== "Container" + + ```shell + docker exec -it headscale \ + headscale nodes register --user <USER> --key <YOUR_MACHINE_KEY> + ``` + +### Using a preauthkey + +It is also possible to generate a preauthkey and register a node non-interactively. First, generate a preauthkey on the +headscale instance. 
By default, the key is valid for one hour and can only be used once (see `headscale preauthkeys +--help` for other options): + +=== "Native" + + ```shell + headscale preauthkeys create --user <USER_ID> + ``` + +=== "Container" + + ```shell + docker exec -it headscale \ + headscale preauthkeys create --user <USER_ID> + ``` + +The command returns the preauthkey on success which is used to connect a node to the headscale instance via the +`tailscale up` command: + +```shell +tailscale up --login-server <YOUR_HEADSCALE_URL> --authkey <YOUR_AUTH_KEY> +``` diff --git a/docs/web-ui.md b/docs/web-ui.md deleted file mode 100644 index d018666e..00000000 --- a/docs/web-ui.md +++ /dev/null @@ -1,14 +0,0 @@ -# Headscale web interface - -!!! warning "Community contributions" - - This page contains community contributions. The projects listed here are not - maintained by the Headscale authors and are written by community members. - -| Name | Repository Link | Description | Status | -| --------------- | ------------------------------------------------------- | ------------------------------------------------------------------------- | ------ | -| headscale-webui | [Github](https://github.com/ifargle/headscale-webui) | A simple Headscale web UI for small-scale deployments. | Alpha | -| headscale-ui | [Github](https://github.com/gurucomputing/headscale-ui) | A web frontend for the headscale Tailscale-compatible coordination server | Alpha | -| HeadscaleUi | [GitHub](https://github.com/simcu/headscale-ui) | A static headscale admin ui, no backend enviroment required | Alpha | - -You can ask for support on our dedicated [Discord channel](https://discord.com/channels/896711691637780480/1105842846386356294). diff --git a/docs/windows-client.md b/docs/windows-client.md deleted file mode 100644 index fcb8c0e5..00000000 --- a/docs/windows-client.md +++ /dev/null @@ -1,59 +0,0 @@ -# Connecting a Windows client - -## Goal - -This documentation has the goal of showing how a user can use the official Windows [Tailscale](https://tailscale.com) client with `headscale`. - -## Add registry keys - -To make the Windows client behave as expected and to run well with `headscale`, two registry keys **must** be set: - -- `HKLM:\SOFTWARE\Tailscale IPN\UnattendedMode` must be set to `always` as a `string` type, to allow Tailscale to run properly in the background -- `HKLM:\SOFTWARE\Tailscale IPN\LoginURL` must be set to `<YOUR HEADSCALE URL>` as a `string` type, to ensure Tailscale contacts the correct control server. - -You can set these using the Windows Registry Editor: - -![windows-registry](./images/windows-registry.png) - -Or via the following Powershell commands (right click Powershell icon and select "Run as administrator"): - -``` -New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always -New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value https://YOUR-HEADSCALE-URL -``` - -The Tailscale Windows client has been observed to reset its configuration on logout/reboot and these two keys [resolves that issue](https://github.com/tailscale/tailscale/issues/2798). - -For a guide on how to edit registry keys, [check out Computer Hope](https://www.computerhope.com/issues/ch001348.htm). - -## Installation - -Download the [Official Windows Client](https://tailscale.com/download/windows) and install it. - -When the installation has finished, start Tailscale and log in (you might have to click the icon in the system tray). 
- -The log in should open a browser Window and direct you to your `headscale` instance. - -## Troubleshooting - -If you are seeing repeated messages like: - -``` -[GIN] 2022/02/10 - 16:39:34 | 200 | 1.105306ms | 127.0.0.1 | POST "/machine/redacted" -``` - -in your `headscale` output, turn on `DEBUG` logging and look for: - -``` -2022-02-11T00:59:29Z DBG Machine registration has expired. Sending a authurl to register machine=redacted -``` - -This typically means that the registry keys above was not set appropriately. - -To reset and try again, it is important to do the following: - -1. Ensure the registry keys from the previous guide is correctly set. -2. Shut down the Tailscale service (or the client running in the tray) -3. Delete Tailscale Application data folder, located at `C:\Users\<USERNAME>\AppData\Local\Tailscale` and try to connect again. -4. Ensure the Windows node is deleted from headscale (to ensure fresh setup) -5. Start Tailscale on the windows machine and retry the login. diff --git a/examples/README.md b/examples/README.md deleted file mode 100644 index f9e85ff3..00000000 --- a/examples/README.md +++ /dev/null @@ -1,5 +0,0 @@ -# Examples - -This directory contains examples on how to run `headscale` on different platforms. - -All examples are provided by the community and they are not verified by the `headscale` authors. diff --git a/examples/kustomize/.gitignore b/examples/kustomize/.gitignore deleted file mode 100644 index 229058d2..00000000 --- a/examples/kustomize/.gitignore +++ /dev/null @@ -1,2 +0,0 @@ -/**/site -/**/secrets diff --git a/examples/kustomize/README.md b/examples/kustomize/README.md deleted file mode 100644 index cc57f147..00000000 --- a/examples/kustomize/README.md +++ /dev/null @@ -1,100 +0,0 @@ -# Deploying headscale on Kubernetes - -**Note:** This is contributed by the community and not verified by the headscale authors. - -This directory contains [Kustomize](https://kustomize.io) templates that deploy -headscale in various configurations. - -These templates currently support Rancher k3s. Other clusters may require -adaptation, especially around volume claims and ingress. - -Commands below assume this directory is your current working directory. - -# Generate secrets and site configuration - -Run `./init.bash` to generate keys, passwords, and site configuration files. - -Edit `base/site/public.env`, changing `public-hostname` to the public DNS name -that will be used for your headscale deployment. - -Set `public-proto` to "https" if you're planning to use TLS & Let's Encrypt. - -Configure DERP servers by editing `base/site/derp.yaml` if needed. - -# Add the image to the registry - -You'll somehow need to get `headscale:latest` into your cluster image registry. - -An easy way to do this with k3s: - -- Reconfigure k3s to use docker instead of containerd (`k3s server --docker`) -- `docker build -t headscale:latest ..` from here - -# Create the namespace - -If it doesn't already exist, `kubectl create ns headscale`. - -# Deploy headscale - -## sqlite - -`kubectl -n headscale apply -k ./sqlite` - -## postgres - -`kubectl -n headscale apply -k ./postgres` - -# TLS & Let's Encrypt - -Test a staging certificate with your configured DNS name and Let's Encrypt. - -`kubectl -n headscale apply -k ./staging-tls` - -Replace with a production certificate. - -`kubectl -n headscale apply -k ./production-tls` - -## Static / custom TLS certificates - -Only Let's Encrypt is supported. If you need other TLS settings, modify or patch the ingress. 
- -# Administration - -Use the wrapper script to remotely operate headscale to perform administrative -tasks like creating namespaces, authkeys, etc. - -``` -[c@nix-slate:~/Projects/headscale/k8s]$ ./headscale.bash - -headscale is an open source implementation of the Tailscale control server - -https://github.com/juanfont/headscale - -Usage: - headscale [command] - -Available Commands: - help Help about any command - namespace Manage the namespaces of headscale - node Manage the nodes of headscale - preauthkey Handle the preauthkeys in headscale - routes Manage the routes of headscale - serve Launches the headscale server - version Print the version. - -Flags: - -h, --help help for headscale - -o, --output string Output format. Empty for human-readable, 'json' or 'json-line' - -Use "headscale [command] --help" for more information about a command. - -``` - -# TODO / Ideas - -- Interpolate `email:` option to the ClusterIssuer from site configuration. - This probably needs to be done with a transformer, kustomize vars don't seem to work. -- Add kustomize examples for cloud-native ingress, load balancer -- CockroachDB for the backend -- DERP server deployment -- Tor hidden service diff --git a/examples/kustomize/base/configmap.yaml b/examples/kustomize/base/configmap.yaml deleted file mode 100644 index 0ac2d563..00000000 --- a/examples/kustomize/base/configmap.yaml +++ /dev/null @@ -1,9 +0,0 @@ -apiVersion: v1 -kind: ConfigMap -metadata: - name: headscale-config -data: - server_url: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME) - listen_addr: "0.0.0.0:8080" - metrics_listen_addr: "127.0.0.1:9090" - ephemeral_node_inactivity_timeout: "30m" diff --git a/examples/kustomize/base/ingress.yaml b/examples/kustomize/base/ingress.yaml deleted file mode 100644 index 51da3427..00000000 --- a/examples/kustomize/base/ingress.yaml +++ /dev/null @@ -1,18 +0,0 @@ -apiVersion: networking.k8s.io/v1 -kind: Ingress -metadata: - name: headscale - annotations: - kubernetes.io/ingress.class: traefik -spec: - rules: - - host: $(PUBLIC_HOSTNAME) - http: - paths: - - backend: - service: - name: headscale - port: - number: 8080 - path: / - pathType: Prefix diff --git a/examples/kustomize/base/kustomization.yaml b/examples/kustomize/base/kustomization.yaml deleted file mode 100644 index 93278f7d..00000000 --- a/examples/kustomize/base/kustomization.yaml +++ /dev/null @@ -1,42 +0,0 @@ -namespace: headscale -resources: - - configmap.yaml - - ingress.yaml - - service.yaml -generatorOptions: - disableNameSuffixHash: true -configMapGenerator: - - name: headscale-site - files: - - derp.yaml=site/derp.yaml - envs: - - site/public.env - - name: headscale-etc - literals: - - config.json={} -secretGenerator: - - name: headscale - files: - - secrets/private-key -vars: - - name: PUBLIC_PROTO - objRef: - kind: ConfigMap - name: headscale-site - apiVersion: v1 - fieldRef: - fieldPath: data.public-proto - - name: PUBLIC_HOSTNAME - objRef: - kind: ConfigMap - name: headscale-site - apiVersion: v1 - fieldRef: - fieldPath: data.public-hostname - - name: CONTACT_EMAIL - objRef: - kind: ConfigMap - name: headscale-site - apiVersion: v1 - fieldRef: - fieldPath: data.contact-email diff --git a/examples/kustomize/base/service.yaml b/examples/kustomize/base/service.yaml deleted file mode 100644 index 39e67253..00000000 --- a/examples/kustomize/base/service.yaml +++ /dev/null @@ -1,13 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: headscale - labels: - app: headscale -spec: - selector: - app: headscale - ports: - - name: http - targetPort: 
http - port: 8080 diff --git a/examples/kustomize/headscale.bash b/examples/kustomize/headscale.bash deleted file mode 100755 index 66bfe92c..00000000 --- a/examples/kustomize/headscale.bash +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash -set -eu -exec kubectl -n headscale exec -ti pod/headscale-0 -- /go/bin/headscale "$@" diff --git a/examples/kustomize/init.bash b/examples/kustomize/init.bash deleted file mode 100755 index e5b7965c..00000000 --- a/examples/kustomize/init.bash +++ /dev/null @@ -1,22 +0,0 @@ -#!/usr/bin/env bash -set -eux -cd $(dirname $0) - -umask 022 -mkdir -p base/site/ -[ ! -e base/site/public.env ] && ( - cat >base/site/public.env <<EOF -public-hostname=localhost -public-proto=http -contact-email=headscale@example.com -EOF -) -[ ! -e base/site/derp.yaml ] && cp ../derp.yaml base/site/derp.yaml - -umask 077 -mkdir -p base/secrets/ -[ ! -e base/secrets/private-key ] && ( - wg genkey > base/secrets/private-key -) -mkdir -p postgres/secrets/ -[ ! -e postgres/secrets/password ] && (head -c 32 /dev/urandom | base64 -w0 > postgres/secrets/password) diff --git a/examples/kustomize/install-cert-manager.bash b/examples/kustomize/install-cert-manager.bash deleted file mode 100755 index 1a5ecacb..00000000 --- a/examples/kustomize/install-cert-manager.bash +++ /dev/null @@ -1,3 +0,0 @@ -#!/usr/bin/env bash -set -eux -kubectl apply -f https://github.com/jetstack/cert-manager/releases/download/v1.4.0/cert-manager.yaml diff --git a/examples/kustomize/postgres/deployment.yaml b/examples/kustomize/postgres/deployment.yaml deleted file mode 100644 index 1dd88b41..00000000 --- a/examples/kustomize/postgres/deployment.yaml +++ /dev/null @@ -1,81 +0,0 @@ -apiVersion: apps/v1 -kind: Deployment -metadata: - name: headscale -spec: - replicas: 2 - selector: - matchLabels: - app: headscale - template: - metadata: - labels: - app: headscale - spec: - containers: - - name: headscale - image: "headscale:latest" - imagePullPolicy: IfNotPresent - command: ["/go/bin/headscale", "serve"] - env: - - name: SERVER_URL - value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME) - - name: LISTEN_ADDR - valueFrom: - configMapKeyRef: - name: headscale-config - key: listen_addr - - name: METRICS_LISTEN_ADDR - valueFrom: - configMapKeyRef: - name: headscale-config - key: metrics_listen_addr - - name: DERP_MAP_PATH - value: /vol/config/derp.yaml - - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT - valueFrom: - configMapKeyRef: - name: headscale-config - key: ephemeral_node_inactivity_timeout - - name: DB_TYPE - value: postgres - - name: DB_HOST - value: postgres.headscale.svc.cluster.local - - name: DB_PORT - value: "5432" - - name: DB_USER - value: headscale - - name: DB_PASS - valueFrom: - secretKeyRef: - name: postgresql - key: password - - name: DB_NAME - value: headscale - ports: - - name: http - protocol: TCP - containerPort: 8080 - livenessProbe: - tcpSocket: - port: http - initialDelaySeconds: 30 - timeoutSeconds: 5 - periodSeconds: 15 - volumeMounts: - - name: config - mountPath: /vol/config - - name: secret - mountPath: /vol/secret - - name: etc - mountPath: /etc/headscale - volumes: - - name: config - configMap: - name: headscale-site - - name: etc - configMap: - name: headscale-etc - - name: secret - secret: - secretName: headscale diff --git a/examples/kustomize/postgres/kustomization.yaml b/examples/kustomize/postgres/kustomization.yaml deleted file mode 100644 index e732e3b9..00000000 --- a/examples/kustomize/postgres/kustomization.yaml +++ /dev/null @@ -1,13 +0,0 @@ -namespace: headscale -bases: - - ../base 
-resources: - - deployment.yaml - - postgres-service.yaml - - postgres-statefulset.yaml -generatorOptions: - disableNameSuffixHash: true -secretGenerator: - - name: postgresql - files: - - secrets/password diff --git a/examples/kustomize/postgres/postgres-service.yaml b/examples/kustomize/postgres/postgres-service.yaml deleted file mode 100644 index 6252e7f9..00000000 --- a/examples/kustomize/postgres/postgres-service.yaml +++ /dev/null @@ -1,13 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: postgres - labels: - app: postgres -spec: - selector: - app: postgres - ports: - - name: postgres - targetPort: postgres - port: 5432 diff --git a/examples/kustomize/postgres/postgres-statefulset.yaml b/examples/kustomize/postgres/postgres-statefulset.yaml deleted file mode 100644 index b81c9bf0..00000000 --- a/examples/kustomize/postgres/postgres-statefulset.yaml +++ /dev/null @@ -1,49 +0,0 @@ -apiVersion: apps/v1 -kind: StatefulSet -metadata: - name: postgres -spec: - serviceName: postgres - replicas: 1 - selector: - matchLabels: - app: postgres - template: - metadata: - labels: - app: postgres - spec: - containers: - - name: postgres - image: "postgres:13" - imagePullPolicy: IfNotPresent - env: - - name: POSTGRES_PASSWORD - valueFrom: - secretKeyRef: - name: postgresql - key: password - - name: POSTGRES_USER - value: headscale - ports: - - name: postgres - protocol: TCP - containerPort: 5432 - livenessProbe: - tcpSocket: - port: 5432 - initialDelaySeconds: 30 - timeoutSeconds: 5 - periodSeconds: 15 - volumeMounts: - - name: pgdata - mountPath: /var/lib/postgresql/data - volumeClaimTemplates: - - metadata: - name: pgdata - spec: - storageClassName: local-path - accessModes: ["ReadWriteOnce"] - resources: - requests: - storage: 1Gi diff --git a/examples/kustomize/production-tls/ingress-patch.yaml b/examples/kustomize/production-tls/ingress-patch.yaml deleted file mode 100644 index 9e6177fb..00000000 --- a/examples/kustomize/production-tls/ingress-patch.yaml +++ /dev/null @@ -1,11 +0,0 @@ -kind: Ingress -metadata: - name: headscale - annotations: - cert-manager.io/cluster-issuer: letsencrypt-production - traefik.ingress.kubernetes.io/router.tls: "true" -spec: - tls: - - hosts: - - $(PUBLIC_HOSTNAME) - secretName: production-cert diff --git a/examples/kustomize/production-tls/kustomization.yaml b/examples/kustomize/production-tls/kustomization.yaml deleted file mode 100644 index d3147f5f..00000000 --- a/examples/kustomize/production-tls/kustomization.yaml +++ /dev/null @@ -1,9 +0,0 @@ -namespace: headscale -bases: - - ../base -resources: - - production-issuer.yaml -patches: - - path: ingress-patch.yaml - target: - kind: Ingress diff --git a/examples/kustomize/production-tls/production-issuer.yaml b/examples/kustomize/production-tls/production-issuer.yaml deleted file mode 100644 index f436090b..00000000 --- a/examples/kustomize/production-tls/production-issuer.yaml +++ /dev/null @@ -1,16 +0,0 @@ -apiVersion: cert-manager.io/v1 -kind: ClusterIssuer -metadata: - name: letsencrypt-production -spec: - acme: - # TODO: figure out how to get kustomize to interpolate this, or use a transformer - #email: $(CONTACT_EMAIL) - server: https://acme-v02.api.letsencrypt.org/directory - privateKeySecretRef: - # Secret resource used to store the account's private key. 
- name: letsencrypt-production-acc-key - solvers: - - http01: - ingress: - class: traefik diff --git a/examples/kustomize/sqlite/kustomization.yaml b/examples/kustomize/sqlite/kustomization.yaml deleted file mode 100644 index ca799419..00000000 --- a/examples/kustomize/sqlite/kustomization.yaml +++ /dev/null @@ -1,5 +0,0 @@ -namespace: headscale -bases: - - ../base -resources: - - statefulset.yaml diff --git a/examples/kustomize/sqlite/statefulset.yaml b/examples/kustomize/sqlite/statefulset.yaml deleted file mode 100644 index 2321d39d..00000000 --- a/examples/kustomize/sqlite/statefulset.yaml +++ /dev/null @@ -1,82 +0,0 @@ -apiVersion: apps/v1 -kind: StatefulSet -metadata: - name: headscale -spec: - serviceName: headscale - replicas: 1 - selector: - matchLabels: - app: headscale - template: - metadata: - labels: - app: headscale - spec: - containers: - - name: headscale - image: "headscale:latest" - imagePullPolicy: IfNotPresent - command: ["/go/bin/headscale", "serve"] - env: - - name: SERVER_URL - value: $(PUBLIC_PROTO)://$(PUBLIC_HOSTNAME) - - name: LISTEN_ADDR - valueFrom: - configMapKeyRef: - name: headscale-config - key: listen_addr - - name: METRICS_LISTEN_ADDR - valueFrom: - configMapKeyRef: - name: headscale-config - key: metrics_listen_addr - - name: DERP_MAP_PATH - value: /vol/config/derp.yaml - - name: EPHEMERAL_NODE_INACTIVITY_TIMEOUT - valueFrom: - configMapKeyRef: - name: headscale-config - key: ephemeral_node_inactivity_timeout - - name: DB_TYPE - value: sqlite3 - - name: DB_PATH - value: /vol/data/db.sqlite - ports: - - name: http - protocol: TCP - containerPort: 8080 - livenessProbe: - tcpSocket: - port: http - initialDelaySeconds: 30 - timeoutSeconds: 5 - periodSeconds: 15 - volumeMounts: - - name: config - mountPath: /vol/config - - name: data - mountPath: /vol/data - - name: secret - mountPath: /vol/secret - - name: etc - mountPath: /etc/headscale - volumes: - - name: config - configMap: - name: headscale-site - - name: etc - configMap: - name: headscale-etc - - name: secret - secret: - secretName: headscale - volumeClaimTemplates: - - metadata: - name: data - spec: - storageClassName: local-path - accessModes: ["ReadWriteOnce"] - resources: - requests: - storage: 1Gi diff --git a/examples/kustomize/staging-tls/ingress-patch.yaml b/examples/kustomize/staging-tls/ingress-patch.yaml deleted file mode 100644 index 5a1daf0c..00000000 --- a/examples/kustomize/staging-tls/ingress-patch.yaml +++ /dev/null @@ -1,11 +0,0 @@ -kind: Ingress -metadata: - name: headscale - annotations: - cert-manager.io/cluster-issuer: letsencrypt-staging - traefik.ingress.kubernetes.io/router.tls: "true" -spec: - tls: - - hosts: - - $(PUBLIC_HOSTNAME) - secretName: staging-cert diff --git a/examples/kustomize/staging-tls/kustomization.yaml b/examples/kustomize/staging-tls/kustomization.yaml deleted file mode 100644 index 0900c583..00000000 --- a/examples/kustomize/staging-tls/kustomization.yaml +++ /dev/null @@ -1,9 +0,0 @@ -namespace: headscale -bases: - - ../base -resources: - - staging-issuer.yaml -patches: - - path: ingress-patch.yaml - target: - kind: Ingress diff --git a/examples/kustomize/staging-tls/staging-issuer.yaml b/examples/kustomize/staging-tls/staging-issuer.yaml deleted file mode 100644 index cf290415..00000000 --- a/examples/kustomize/staging-tls/staging-issuer.yaml +++ /dev/null @@ -1,16 +0,0 @@ -apiVersion: cert-manager.io/v1 -kind: ClusterIssuer -metadata: - name: letsencrypt-staging -spec: - acme: - # TODO: figure out how to get kustomize to interpolate this, or use a 
transformer - #email: $(CONTACT_EMAIL) - server: https://acme-staging-v02.api.letsencrypt.org/directory - privateKeySecretRef: - # Secret resource used to store the account's private key. - name: letsencrypt-staging-acc-key - solvers: - - http01: - ingress: - class: traefik diff --git a/flake.lock b/flake.lock index 91bd5055..9f77e322 100644 --- a/flake.lock +++ b/flake.lock @@ -5,11 +5,11 @@ "systems": "systems" }, "locked": { - "lastModified": 1701680307, - "narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=", + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", "owner": "numtide", "repo": "flake-utils", - "rev": "4022d587cbbfd70fe950c1e2083a02621806a725", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", "type": "github" }, "original": { @@ -20,11 +20,11 @@ }, "nixpkgs": { "locked": { - "lastModified": 1701998057, - "narHash": "sha256-gAJGhcTO9cso7XDfAScXUlPcva427AUT2q02qrmXPdo=", + "lastModified": 1768875095, + "narHash": "sha256-dYP3DjiL7oIiiq3H65tGIXXIT1Waiadmv93JS0sS+8A=", "owner": "NixOS", "repo": "nixpkgs", - "rev": "09dc04054ba2ff1f861357d0e7e76d021b273cd7", + "rev": "ed142ab1b3a092c4d149245d0c4126a5d7ea00b0", "type": "github" }, "original": { diff --git a/flake.nix b/flake.nix index 638bf7c1..3b5cff09 100644 --- a/flake.nix +++ b/flake.nix @@ -6,170 +6,232 @@ flake-utils.url = "github:numtide/flake-utils"; }; - outputs = { - self, - nixpkgs, - flake-utils, - ... - }: let - headscaleVersion = - if (self ? shortRev) - then self.shortRev - else "dev"; - in + outputs = + { self + , nixpkgs + , flake-utils + , ... + }: + let + headscaleVersion = self.shortRev or self.dirtyShortRev; + commitHash = self.rev or self.dirtyRev; + in { - overlay = _: prev: let - pkgs = nixpkgs.legacyPackages.${prev.system}; - in rec { - headscale = pkgs.buildGo121Module rec { - pname = "headscale"; - version = headscaleVersion; - src = pkgs.lib.cleanSource self; - - # Only run unit tests when testing a build - checkFlags = ["-short"]; - - # When updating go.mod or go.sum, a new sha will need to be calculated, - # update this if you have a mismatch after doing a change to thos files. - vendorHash = "sha256-u9AmJguQ5dnJpfhOeLN43apvMHuraOrJhvlEIp9RoIc="; - - ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"]; - }; - - golines = pkgs.buildGoModule rec { - pname = "golines"; - version = "0.11.0"; - - src = pkgs.fetchFromGitHub { - owner = "segmentio"; - repo = "golines"; - rev = "v${version}"; - sha256 = "sha256-2K9KAg8iSubiTbujyFGN3yggrL+EDyeUCs9OOta/19A="; - }; - - vendorHash = "sha256-rxYuzn4ezAxaeDhxd8qdOzt+CKYIh03A9zKNdzILq18="; - - nativeBuildInputs = [pkgs.installShellFiles]; - }; - - golangci-lint = prev.golangci-lint.override { - # Override https://github.com/NixOS/nixpkgs/pull/166801 which changed this - # to buildGo118Module because it does not build on Darwin. 
- inherit (prev) buildGoModule; - }; - - protoc-gen-grpc-gateway = pkgs.buildGoModule rec { - pname = "grpc-gateway"; - version = "2.14.0"; - - src = pkgs.fetchFromGitHub { - owner = "grpc-ecosystem"; - repo = "grpc-gateway"; - rev = "v${version}"; - sha256 = "sha256-lnNdsDCpeSHtl2lC1IhUw11t3cnGF+37qSM7HDvKLls="; - }; - - vendorHash = "sha256-dGdnDuRbwg8fU7uB5GaHEWa/zI3w06onqjturvooJQA="; - - nativeBuildInputs = [pkgs.installShellFiles]; - - subPackages = ["protoc-gen-grpc-gateway" "protoc-gen-openapiv2"]; - }; + # NixOS module + nixosModules = rec { + headscale = import ./nix/module.nix; + default = headscale; }; + + overlays.default = _: prev: + let + pkgs = nixpkgs.legacyPackages.${prev.stdenv.hostPlatform.system}; + buildGo = pkgs.buildGo125Module; + vendorHash = "sha256-dWsDgI5K+8mFw4PA5gfFBPCSqBJp5RcZzm0ML1+HsWw="; + in + { + headscale = buildGo { + pname = "headscale"; + version = headscaleVersion; + src = pkgs.lib.cleanSource self; + + # Only run unit tests when testing a build + checkFlags = [ "-short" ]; + + # When updating go.mod or go.sum, a new sha will need to be calculated, + # update this if you have a mismatch after doing a change to those files. + inherit vendorHash; + + subPackages = [ "cmd/headscale" ]; + + meta = { + mainProgram = "headscale"; + }; + }; + + hi = buildGo { + pname = "hi"; + version = headscaleVersion; + src = pkgs.lib.cleanSource self; + + checkFlags = [ "-short" ]; + inherit vendorHash; + + subPackages = [ "cmd/hi" ]; + }; + + protoc-gen-grpc-gateway = buildGo rec { + pname = "grpc-gateway"; + version = "2.27.4"; + + src = pkgs.fetchFromGitHub { + owner = "grpc-ecosystem"; + repo = "grpc-gateway"; + rev = "v${version}"; + sha256 = "sha256-4bhEQTVV04EyX/qJGNMIAQDcMWcDVr1tFkEjBHpc2CA="; + }; + + vendorHash = "sha256-ohZW/uPdt08Y2EpIQ2yeyGSjV9O58+QbQQqYrs6O8/g="; + + nativeBuildInputs = [ pkgs.installShellFiles ]; + + subPackages = [ "protoc-gen-grpc-gateway" "protoc-gen-openapiv2" ]; + }; + + protobuf-language-server = buildGo rec { + pname = "protobuf-language-server"; + version = "1cf777d"; + + src = pkgs.fetchFromGitHub { + owner = "lasorda"; + repo = "protobuf-language-server"; + rev = "1cf777de4d35a6e493a689e3ca1a6183ce3206b6"; + sha256 = "sha256-9MkBQPxr/TDr/sNz/Sk7eoZwZwzdVbE5u6RugXXk5iY="; + }; + + vendorHash = "sha256-4nTpKBe7ekJsfQf+P6edT/9Vp2SBYbKz1ITawD3bhkI="; + + subPackages = [ "." ]; + }; + + # Upstream does not override buildGoModule properly, + # importing a specific module, so comment out for now. + # golangci-lint = prev.golangci-lint.override { + # buildGoModule = buildGo; + # }; + # golangci-lint-langserver = prev.golangci-lint.override { + # buildGoModule = buildGo; + # }; + + # The package uses buildGo125Module, not the convention. 
+ # goreleaser = prev.goreleaser.override { + # buildGoModule = buildGo; + # }; + + gotestsum = prev.gotestsum.override { + buildGoModule = buildGo; + }; + + gotests = prev.gotests.override { + buildGoModule = buildGo; + }; + + gofumpt = prev.gofumpt.override { + buildGoModule = buildGo; + }; + + # gopls = prev.gopls.override { + # buildGoModule = buildGo; + # }; + }; } // flake-utils.lib.eachDefaultSystem - (system: let - pkgs = import nixpkgs { - overlays = [self.overlay]; - inherit system; - }; - buildDeps = with pkgs; [git go_1_21 gnumake]; - devDeps = with pkgs; - buildDeps - ++ [ - golangci-lint - golines - nodePackages.prettier - goreleaser - nfpm - gotestsum - gotests - ksh + (system: + let + pkgs = import nixpkgs { + overlays = [ self.overlays.default ]; + inherit system; + }; + buildDeps = with pkgs; [ git go_1_25 gnumake ]; + devDeps = with pkgs; + buildDeps + ++ [ + golangci-lint + golangci-lint-langserver + golines + nodePackages.prettier + nixpkgs-fmt + goreleaser + nfpm + gotestsum + gotests + gofumpt + gopls + ksh + ko + yq-go + ripgrep + postgresql + prek - # 'dot' is needed for pprof graphs - # go tool pprof -http=: <source> - graphviz + # 'dot' is needed for pprof graphs + # go tool pprof -http=: <source> + graphviz - # Protobuf dependencies - protobuf - protoc-gen-go - protoc-gen-go-grpc - protoc-gen-grpc-gateway - buf - clang-tools # clang-format - ]; + # Protobuf dependencies + protobuf + protoc-gen-go + protoc-gen-go-grpc + protoc-gen-grpc-gateway + buf + clang-tools # clang-format + protobuf-language-server + ] + ++ lib.optional pkgs.stdenv.isLinux [ traceroute ]; - # Add entry to build a docker image with headscale - # caveat: only works on Linux - # - # Usage: - # nix build .#headscale-docker - # docker load < result - headscale-docker = pkgs.dockerTools.buildLayeredImage { - name = "headscale"; - tag = headscaleVersion; - contents = [pkgs.headscale]; - config.Entrypoint = [(pkgs.headscale + "/bin/headscale")]; - }; - in rec { - # `nix develop` - devShell = pkgs.mkShell { - buildInputs = devDeps; + # Add entry to build a docker image with headscale + # caveat: only works on Linux + # + # Usage: + # nix build .#headscale-docker + # docker load < result + headscale-docker = pkgs.dockerTools.buildLayeredImage { + name = "headscale"; + tag = headscaleVersion; + contents = [ pkgs.headscale ]; + config.Entrypoint = [ (pkgs.headscale + "/bin/headscale") ]; + }; + in + { + # `nix develop` + devShells.default = pkgs.mkShell { + buildInputs = + devDeps + ++ [ + (pkgs.writeShellScriptBin + "nix-vendor-sri" + '' + set -eu - shellHook = '' - export PATH="$PWD/result/bin:$PATH" + OUT=$(mktemp -d -t nar-hash-XXXXXX) + rm -rf "$OUT" - mkdir -p ./ignored - export HEADSCALE_PRIVATE_KEY_PATH="./ignored/private.key" - export HEADSCALE_NOISE_PRIVATE_KEY_PATH="./ignored/noise_private.key" - export HEADSCALE_DB_PATH="./ignored/db.sqlite" - export HEADSCALE_TLS_LETSENCRYPT_CACHE_DIR="./ignored/cache" - export HEADSCALE_UNIX_SOCKET="./ignored/headscale.sock" - ''; - }; + go mod vendor -o "$OUT" + go run tailscale.com/cmd/nardump --sri "$OUT" + rm -rf "$OUT" + '') - # `nix build` - packages = with pkgs; { - inherit headscale; - inherit headscale-docker; - }; - defaultPackage = pkgs.headscale; - - # `nix run` - apps.headscale = flake-utils.lib.mkApp { - drv = packages.headscale; - }; - apps.default = apps.headscale; - - checks = { - format = - pkgs.runCommand "check-format" - { - buildInputs = with pkgs; [ - gnumake - nixpkgs-fmt - golangci-lint - nodePackages.prettier - golines - 
clang-tools + (pkgs.writeShellScriptBin + "go-mod-update-all" + '' + cat go.mod | ${pkgs.silver-searcher}/bin/ag "\t" | ${pkgs.silver-searcher}/bin/ag -v indirect | ${pkgs.gawk}/bin/awk '{print $1}' | ${pkgs.findutils}/bin/xargs go get -u + go mod tidy + '') ]; - } '' - ${pkgs.nixpkgs-fmt}/bin/nixpkgs-fmt ${./.} - ${pkgs.golangci-lint}/bin/golangci-lint run --fix --timeout 10m - ${pkgs.nodePackages.prettier}/bin/prettier --write '**/**.{ts,js,md,yaml,yml,sass,css,scss,html}' - ${pkgs.golines}/bin/golines --max-len=88 --base-formatter=gofumpt -w ${./.} - ${pkgs.clang-tools}/bin/clang-format -style="{BasedOnStyle: Google, IndentWidth: 4, AlignConsecutiveDeclarations: true, AlignConsecutiveAssignments: true, ColumnLimit: 0}" -i ${./.} + + shellHook = '' + export PATH="$PWD/result/bin:$PATH" + export CGO_ENABLED=0 ''; - }; - }); + }; + + # `nix build` + packages = with pkgs; { + inherit headscale; + inherit headscale-docker; + default = headscale; + }; + + # `nix run` + apps.headscale = flake-utils.lib.mkApp { + drv = pkgs.headscale; + }; + apps.default = flake-utils.lib.mkApp { + drv = pkgs.headscale; + }; + + checks = { + headscale = pkgs.testers.nixosTest (import ./nix/tests/headscale.nix); + }; + }); } diff --git a/gen/go/headscale/v1/apikey.pb.go b/gen/go/headscale/v1/apikey.pb.go index 3e1ebd9c..0c855738 100644 --- a/gen/go/headscale/v1/apikey.pb.go +++ b/gen/go/headscale/v1/apikey.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.31.0 +// protoc-gen-go v1.36.11 // protoc (unknown) // source: headscale/v1/apikey.proto @@ -12,6 +12,7 @@ import ( timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -22,24 +23,21 @@ const ( ) type ApiKey struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + Prefix string `protobuf:"bytes,2,opt,name=prefix,proto3" json:"prefix,omitempty"` + Expiration *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=expiration,proto3" json:"expiration,omitempty"` + CreatedAt *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + LastSeen *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=last_seen,json=lastSeen,proto3" json:"last_seen,omitempty"` unknownFields protoimpl.UnknownFields - - Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` - Prefix string `protobuf:"bytes,2,opt,name=prefix,proto3" json:"prefix,omitempty"` - Expiration *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=expiration,proto3" json:"expiration,omitempty"` - CreatedAt *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` - LastSeen *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=last_seen,json=lastSeen,proto3" json:"last_seen,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ApiKey) Reset() { *x = ApiKey{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_apikey_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_apikey_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ApiKey) String() string { @@ -50,7 +48,7 @@ func (*ApiKey) ProtoMessage() {} func (x *ApiKey) ProtoReflect() protoreflect.Message { mi := 
&file_headscale_v1_apikey_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -101,20 +99,17 @@ func (x *ApiKey) GetLastSeen() *timestamppb.Timestamp { } type CreateApiKeyRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Expiration *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=expiration,proto3" json:"expiration,omitempty"` unknownFields protoimpl.UnknownFields - - Expiration *timestamppb.Timestamp `protobuf:"bytes,1,opt,name=expiration,proto3" json:"expiration,omitempty"` + sizeCache protoimpl.SizeCache } func (x *CreateApiKeyRequest) Reset() { *x = CreateApiKeyRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_apikey_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_apikey_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *CreateApiKeyRequest) String() string { @@ -125,7 +120,7 @@ func (*CreateApiKeyRequest) ProtoMessage() {} func (x *CreateApiKeyRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_apikey_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -148,20 +143,17 @@ func (x *CreateApiKeyRequest) GetExpiration() *timestamppb.Timestamp { } type CreateApiKeyResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + ApiKey string `protobuf:"bytes,1,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` unknownFields protoimpl.UnknownFields - - ApiKey string `protobuf:"bytes,1,opt,name=api_key,json=apiKey,proto3" json:"api_key,omitempty"` + sizeCache protoimpl.SizeCache } func (x *CreateApiKeyResponse) Reset() { *x = CreateApiKeyResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_apikey_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_apikey_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *CreateApiKeyResponse) String() string { @@ -172,7 +164,7 @@ func (*CreateApiKeyResponse) ProtoMessage() {} func (x *CreateApiKeyResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_apikey_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -195,20 +187,18 @@ func (x *CreateApiKeyResponse) GetApiKey() string { } type ExpireApiKeyRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Prefix string `protobuf:"bytes,1,opt,name=prefix,proto3" json:"prefix,omitempty"` + Id uint64 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` unknownFields protoimpl.UnknownFields - - Prefix string `protobuf:"bytes,1,opt,name=prefix,proto3" json:"prefix,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ExpireApiKeyRequest) Reset() { *x = ExpireApiKeyRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_apikey_proto_msgTypes[3] - ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_apikey_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ExpireApiKeyRequest) String() string { @@ -219,7 +209,7 @@ func (*ExpireApiKeyRequest) ProtoMessage() {} func (x *ExpireApiKeyRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_apikey_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -241,19 +231,24 @@ func (x *ExpireApiKeyRequest) GetPrefix() string { return "" } +func (x *ExpireApiKeyRequest) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + type ExpireApiKeyResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExpireApiKeyResponse) Reset() { *x = ExpireApiKeyResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_apikey_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_apikey_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ExpireApiKeyResponse) String() string { @@ -264,7 +259,7 @@ func (*ExpireApiKeyResponse) ProtoMessage() {} func (x *ExpireApiKeyResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_apikey_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -280,18 +275,16 @@ func (*ExpireApiKeyResponse) Descriptor() ([]byte, []int) { } type ListApiKeysRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ListApiKeysRequest) Reset() { *x = ListApiKeysRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_apikey_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_apikey_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ListApiKeysRequest) String() string { @@ -302,7 +295,7 @@ func (*ListApiKeysRequest) ProtoMessage() {} func (x *ListApiKeysRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_apikey_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -318,20 +311,17 @@ func (*ListApiKeysRequest) Descriptor() ([]byte, []int) { } type ListApiKeysResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + ApiKeys []*ApiKey `protobuf:"bytes,1,rep,name=api_keys,json=apiKeys,proto3" json:"api_keys,omitempty"` unknownFields protoimpl.UnknownFields - - ApiKeys []*ApiKey `protobuf:"bytes,1,rep,name=api_keys,json=apiKeys,proto3" json:"api_keys,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ListApiKeysResponse) Reset() { *x = ListApiKeysResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_apikey_proto_msgTypes[6] - ms := 
protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_apikey_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ListApiKeysResponse) String() string { @@ -342,7 +332,7 @@ func (*ListApiKeysResponse) ProtoMessage() {} func (x *ListApiKeysResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_apikey_proto_msgTypes[6] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -364,66 +354,140 @@ func (x *ListApiKeysResponse) GetApiKeys() []*ApiKey { return nil } +type DeleteApiKeyRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Prefix string `protobuf:"bytes,1,opt,name=prefix,proto3" json:"prefix,omitempty"` + Id uint64 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeleteApiKeyRequest) Reset() { + *x = DeleteApiKeyRequest{} + mi := &file_headscale_v1_apikey_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeleteApiKeyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteApiKeyRequest) ProtoMessage() {} + +func (x *DeleteApiKeyRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_apikey_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteApiKeyRequest.ProtoReflect.Descriptor instead. +func (*DeleteApiKeyRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_apikey_proto_rawDescGZIP(), []int{7} +} + +func (x *DeleteApiKeyRequest) GetPrefix() string { + if x != nil { + return x.Prefix + } + return "" +} + +func (x *DeleteApiKeyRequest) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +type DeleteApiKeyResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeleteApiKeyResponse) Reset() { + *x = DeleteApiKeyResponse{} + mi := &file_headscale_v1_apikey_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeleteApiKeyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeleteApiKeyResponse) ProtoMessage() {} + +func (x *DeleteApiKeyResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_apikey_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeleteApiKeyResponse.ProtoReflect.Descriptor instead. 
+func (*DeleteApiKeyResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_apikey_proto_rawDescGZIP(), []int{8} +} + var File_headscale_v1_apikey_proto protoreflect.FileDescriptor -var file_headscale_v1_apikey_proto_rawDesc = []byte{ - 0x0a, 0x19, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x61, - 0x70, 0x69, 0x6b, 0x65, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xe0, 0x01, 0x0a, 0x06, 0x41, - 0x70, 0x69, 0x4b, 0x65, 0x79, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x04, 0x52, 0x02, 0x69, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x3a, 0x0a, - 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x03, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0a, 0x65, - 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, - 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, - 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, - 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, - 0x65, 0x64, 0x41, 0x74, 0x12, 0x37, 0x0a, 0x09, 0x6c, 0x61, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x65, - 0x6e, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, - 0x61, 0x6d, 0x70, 0x52, 0x08, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x65, 0x65, 0x6e, 0x22, 0x51, 0x0a, - 0x13, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x3a, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, - 0x6f, 0x6e, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, - 0x22, 0x2f, 0x0a, 0x14, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x17, 0x0a, 0x07, 0x61, 0x70, 0x69, 0x5f, - 0x6b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x61, 0x70, 0x69, 0x4b, 0x65, - 0x79, 0x22, 0x2d, 0x0a, 0x13, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x41, 0x70, 0x69, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x72, 0x65, 0x66, - 0x69, 0x78, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, - 0x22, 0x16, 0x0a, 0x14, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x14, 0x0a, 0x12, 0x4c, 0x69, 0x73, 0x74, - 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x46, - 0x0a, 0x13, 0x4c, 0x69, 0x73, 0x74, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, 
0x73, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2f, 0x0a, 0x08, 0x61, 0x70, 0x69, 0x5f, 0x6b, 0x65, 0x79, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x14, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x52, 0x07, 0x61, - 0x70, 0x69, 0x4b, 0x65, 0x79, 0x73, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, - 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6a, 0x75, 0x61, 0x6e, 0x66, 0x6f, 0x6e, 0x74, 0x2f, 0x68, 0x65, - 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x67, 0x6f, 0x2f, 0x76, - 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +const file_headscale_v1_apikey_proto_rawDesc = "" + + "\n" + + "\x19headscale/v1/apikey.proto\x12\fheadscale.v1\x1a\x1fgoogle/protobuf/timestamp.proto\"\xe0\x01\n" + + "\x06ApiKey\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\x12\x16\n" + + "\x06prefix\x18\x02 \x01(\tR\x06prefix\x12:\n" + + "\n" + + "expiration\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\n" + + "expiration\x129\n" + + "\n" + + "created_at\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\tcreatedAt\x127\n" + + "\tlast_seen\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampR\blastSeen\"Q\n" + + "\x13CreateApiKeyRequest\x12:\n" + + "\n" + + "expiration\x18\x01 \x01(\v2\x1a.google.protobuf.TimestampR\n" + + "expiration\"/\n" + + "\x14CreateApiKeyResponse\x12\x17\n" + + "\aapi_key\x18\x01 \x01(\tR\x06apiKey\"=\n" + + "\x13ExpireApiKeyRequest\x12\x16\n" + + "\x06prefix\x18\x01 \x01(\tR\x06prefix\x12\x0e\n" + + "\x02id\x18\x02 \x01(\x04R\x02id\"\x16\n" + + "\x14ExpireApiKeyResponse\"\x14\n" + + "\x12ListApiKeysRequest\"F\n" + + "\x13ListApiKeysResponse\x12/\n" + + "\bapi_keys\x18\x01 \x03(\v2\x14.headscale.v1.ApiKeyR\aapiKeys\"=\n" + + "\x13DeleteApiKeyRequest\x12\x16\n" + + "\x06prefix\x18\x01 \x01(\tR\x06prefix\x12\x0e\n" + + "\x02id\x18\x02 \x01(\x04R\x02id\"\x16\n" + + "\x14DeleteApiKeyResponseB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" var ( file_headscale_v1_apikey_proto_rawDescOnce sync.Once - file_headscale_v1_apikey_proto_rawDescData = file_headscale_v1_apikey_proto_rawDesc + file_headscale_v1_apikey_proto_rawDescData []byte ) func file_headscale_v1_apikey_proto_rawDescGZIP() []byte { file_headscale_v1_apikey_proto_rawDescOnce.Do(func() { - file_headscale_v1_apikey_proto_rawDescData = protoimpl.X.CompressGZIP(file_headscale_v1_apikey_proto_rawDescData) + file_headscale_v1_apikey_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_apikey_proto_rawDesc), len(file_headscale_v1_apikey_proto_rawDesc))) }) return file_headscale_v1_apikey_proto_rawDescData } -var file_headscale_v1_apikey_proto_msgTypes = make([]protoimpl.MessageInfo, 7) -var file_headscale_v1_apikey_proto_goTypes = []interface{}{ +var file_headscale_v1_apikey_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_headscale_v1_apikey_proto_goTypes = []any{ (*ApiKey)(nil), // 0: headscale.v1.ApiKey (*CreateApiKeyRequest)(nil), // 1: headscale.v1.CreateApiKeyRequest (*CreateApiKeyResponse)(nil), // 2: headscale.v1.CreateApiKeyResponse @@ -431,13 +495,15 @@ var file_headscale_v1_apikey_proto_goTypes = []interface{}{ (*ExpireApiKeyResponse)(nil), // 4: headscale.v1.ExpireApiKeyResponse (*ListApiKeysRequest)(nil), // 5: headscale.v1.ListApiKeysRequest (*ListApiKeysResponse)(nil), // 6: headscale.v1.ListApiKeysResponse - (*timestamppb.Timestamp)(nil), // 7: google.protobuf.Timestamp + (*DeleteApiKeyRequest)(nil), // 7: 
headscale.v1.DeleteApiKeyRequest + (*DeleteApiKeyResponse)(nil), // 8: headscale.v1.DeleteApiKeyResponse + (*timestamppb.Timestamp)(nil), // 9: google.protobuf.Timestamp } var file_headscale_v1_apikey_proto_depIdxs = []int32{ - 7, // 0: headscale.v1.ApiKey.expiration:type_name -> google.protobuf.Timestamp - 7, // 1: headscale.v1.ApiKey.created_at:type_name -> google.protobuf.Timestamp - 7, // 2: headscale.v1.ApiKey.last_seen:type_name -> google.protobuf.Timestamp - 7, // 3: headscale.v1.CreateApiKeyRequest.expiration:type_name -> google.protobuf.Timestamp + 9, // 0: headscale.v1.ApiKey.expiration:type_name -> google.protobuf.Timestamp + 9, // 1: headscale.v1.ApiKey.created_at:type_name -> google.protobuf.Timestamp + 9, // 2: headscale.v1.ApiKey.last_seen:type_name -> google.protobuf.Timestamp + 9, // 3: headscale.v1.CreateApiKeyRequest.expiration:type_name -> google.protobuf.Timestamp 0, // 4: headscale.v1.ListApiKeysResponse.api_keys:type_name -> headscale.v1.ApiKey 5, // [5:5] is the sub-list for method output_type 5, // [5:5] is the sub-list for method input_type @@ -451,99 +517,13 @@ func file_headscale_v1_apikey_proto_init() { if File_headscale_v1_apikey_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_headscale_v1_apikey_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ApiKey); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_apikey_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateApiKeyRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_apikey_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateApiKeyResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_apikey_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExpireApiKeyRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_apikey_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExpireApiKeyResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_apikey_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListApiKeysRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_apikey_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListApiKeysResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_headscale_v1_apikey_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_apikey_proto_rawDesc), len(file_headscale_v1_apikey_proto_rawDesc)), NumEnums: 0, - NumMessages: 7, + NumMessages: 9, NumExtensions: 0, NumServices: 0, }, @@ -552,7 +532,6 @@ func 
file_headscale_v1_apikey_proto_init() { MessageInfos: file_headscale_v1_apikey_proto_msgTypes, }.Build() File_headscale_v1_apikey_proto = out.File - file_headscale_v1_apikey_proto_rawDesc = nil file_headscale_v1_apikey_proto_goTypes = nil file_headscale_v1_apikey_proto_depIdxs = nil } diff --git a/gen/go/headscale/v1/device.pb.go b/gen/go/headscale/v1/device.pb.go index 1de3084f..e2362b05 100644 --- a/gen/go/headscale/v1/device.pb.go +++ b/gen/go/headscale/v1/device.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.31.0 +// protoc-gen-go v1.36.11 // protoc (unknown) // source: headscale/v1/device.proto @@ -12,6 +12,7 @@ import ( timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -22,21 +23,18 @@ const ( ) type Latency struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + LatencyMs float32 `protobuf:"fixed32,1,opt,name=latency_ms,json=latencyMs,proto3" json:"latency_ms,omitempty"` + Preferred bool `protobuf:"varint,2,opt,name=preferred,proto3" json:"preferred,omitempty"` unknownFields protoimpl.UnknownFields - - LatencyMs float32 `protobuf:"fixed32,1,opt,name=latency_ms,json=latencyMs,proto3" json:"latency_ms,omitempty"` - Preferred bool `protobuf:"varint,2,opt,name=preferred,proto3" json:"preferred,omitempty"` + sizeCache protoimpl.SizeCache } func (x *Latency) Reset() { *x = Latency{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_device_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *Latency) String() string { @@ -47,7 +45,7 @@ func (*Latency) ProtoMessage() {} func (x *Latency) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -77,25 +75,22 @@ func (x *Latency) GetPreferred() bool { } type ClientSupports struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + HairPinning bool `protobuf:"varint,1,opt,name=hair_pinning,json=hairPinning,proto3" json:"hair_pinning,omitempty"` + Ipv6 bool `protobuf:"varint,2,opt,name=ipv6,proto3" json:"ipv6,omitempty"` + Pcp bool `protobuf:"varint,3,opt,name=pcp,proto3" json:"pcp,omitempty"` + Pmp bool `protobuf:"varint,4,opt,name=pmp,proto3" json:"pmp,omitempty"` + Udp bool `protobuf:"varint,5,opt,name=udp,proto3" json:"udp,omitempty"` + Upnp bool `protobuf:"varint,6,opt,name=upnp,proto3" json:"upnp,omitempty"` unknownFields protoimpl.UnknownFields - - HairPinning bool `protobuf:"varint,1,opt,name=hair_pinning,json=hairPinning,proto3" json:"hair_pinning,omitempty"` - Ipv6 bool `protobuf:"varint,2,opt,name=ipv6,proto3" json:"ipv6,omitempty"` - Pcp bool `protobuf:"varint,3,opt,name=pcp,proto3" json:"pcp,omitempty"` - Pmp bool `protobuf:"varint,4,opt,name=pmp,proto3" json:"pmp,omitempty"` - Udp bool `protobuf:"varint,5,opt,name=udp,proto3" json:"udp,omitempty"` - Upnp bool `protobuf:"varint,6,opt,name=upnp,proto3" json:"upnp,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ClientSupports) Reset() { *x = ClientSupports{} - if protoimpl.UnsafeEnabled { - mi := 
&file_headscale_v1_device_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ClientSupports) String() string { @@ -106,7 +101,7 @@ func (*ClientSupports) ProtoMessage() {} func (x *ClientSupports) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -164,24 +159,21 @@ func (x *ClientSupports) GetUpnp() bool { } type ClientConnectivity struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Endpoints []string `protobuf:"bytes,1,rep,name=endpoints,proto3" json:"endpoints,omitempty"` - Derp string `protobuf:"bytes,2,opt,name=derp,proto3" json:"derp,omitempty"` - MappingVariesByDestIp bool `protobuf:"varint,3,opt,name=mapping_varies_by_dest_ip,json=mappingVariesByDestIp,proto3" json:"mapping_varies_by_dest_ip,omitempty"` - Latency map[string]*Latency `protobuf:"bytes,4,rep,name=latency,proto3" json:"latency,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` - ClientSupports *ClientSupports `protobuf:"bytes,5,opt,name=client_supports,json=clientSupports,proto3" json:"client_supports,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + Endpoints []string `protobuf:"bytes,1,rep,name=endpoints,proto3" json:"endpoints,omitempty"` + Derp string `protobuf:"bytes,2,opt,name=derp,proto3" json:"derp,omitempty"` + MappingVariesByDestIp bool `protobuf:"varint,3,opt,name=mapping_varies_by_dest_ip,json=mappingVariesByDestIp,proto3" json:"mapping_varies_by_dest_ip,omitempty"` + Latency map[string]*Latency `protobuf:"bytes,4,rep,name=latency,proto3" json:"latency,omitempty" protobuf_key:"bytes,1,opt,name=key" protobuf_val:"bytes,2,opt,name=value"` + ClientSupports *ClientSupports `protobuf:"bytes,5,opt,name=client_supports,json=clientSupports,proto3" json:"client_supports,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ClientConnectivity) Reset() { *x = ClientConnectivity{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_device_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ClientConnectivity) String() string { @@ -192,7 +184,7 @@ func (*ClientConnectivity) ProtoMessage() {} func (x *ClientConnectivity) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -243,20 +235,17 @@ func (x *ClientConnectivity) GetClientSupports() *ClientSupports { } type GetDeviceRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` unknownFields protoimpl.UnknownFields - - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + sizeCache protoimpl.SizeCache } func (x *GetDeviceRequest) 
Reset() { *x = GetDeviceRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_device_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetDeviceRequest) String() string { @@ -267,7 +256,7 @@ func (*GetDeviceRequest) ProtoMessage() {} func (x *GetDeviceRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -290,10 +279,7 @@ func (x *GetDeviceRequest) GetId() string { } type GetDeviceResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - + state protoimpl.MessageState `protogen:"open.v1"` Addresses []string `protobuf:"bytes,1,rep,name=addresses,proto3" json:"addresses,omitempty"` Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` User string `protobuf:"bytes,3,opt,name=user,proto3" json:"user,omitempty"` @@ -314,15 +300,15 @@ type GetDeviceResponse struct { EnabledRoutes []string `protobuf:"bytes,18,rep,name=enabled_routes,json=enabledRoutes,proto3" json:"enabled_routes,omitempty"` AdvertisedRoutes []string `protobuf:"bytes,19,rep,name=advertised_routes,json=advertisedRoutes,proto3" json:"advertised_routes,omitempty"` ClientConnectivity *ClientConnectivity `protobuf:"bytes,20,opt,name=client_connectivity,json=clientConnectivity,proto3" json:"client_connectivity,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *GetDeviceResponse) Reset() { *x = GetDeviceResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_device_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetDeviceResponse) String() string { @@ -333,7 +319,7 @@ func (*GetDeviceResponse) ProtoMessage() {} func (x *GetDeviceResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -489,20 +475,17 @@ func (x *GetDeviceResponse) GetClientConnectivity() *ClientConnectivity { } type DeleteDeviceRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` unknownFields protoimpl.UnknownFields - - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + sizeCache protoimpl.SizeCache } func (x *DeleteDeviceRequest) Reset() { *x = DeleteDeviceRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_device_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *DeleteDeviceRequest) String() string { @@ -513,7 +496,7 @@ func (*DeleteDeviceRequest) ProtoMessage() {} func (x *DeleteDeviceRequest) ProtoReflect() 
protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -536,18 +519,16 @@ func (x *DeleteDeviceRequest) GetId() string { } type DeleteDeviceResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *DeleteDeviceResponse) Reset() { *x = DeleteDeviceResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_device_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *DeleteDeviceResponse) String() string { @@ -558,7 +539,7 @@ func (*DeleteDeviceResponse) ProtoMessage() {} func (x *DeleteDeviceResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[6] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -574,20 +555,17 @@ func (*DeleteDeviceResponse) Descriptor() ([]byte, []int) { } type GetDeviceRoutesRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` unknownFields protoimpl.UnknownFields - - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + sizeCache protoimpl.SizeCache } func (x *GetDeviceRoutesRequest) Reset() { *x = GetDeviceRoutesRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_device_proto_msgTypes[7] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetDeviceRoutesRequest) String() string { @@ -598,7 +576,7 @@ func (*GetDeviceRoutesRequest) ProtoMessage() {} func (x *GetDeviceRoutesRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[7] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -621,21 +599,18 @@ func (x *GetDeviceRoutesRequest) GetId() string { } type GetDeviceRoutesResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - EnabledRoutes []string `protobuf:"bytes,1,rep,name=enabled_routes,json=enabledRoutes,proto3" json:"enabled_routes,omitempty"` - AdvertisedRoutes []string `protobuf:"bytes,2,rep,name=advertised_routes,json=advertisedRoutes,proto3" json:"advertised_routes,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + EnabledRoutes []string `protobuf:"bytes,1,rep,name=enabled_routes,json=enabledRoutes,proto3" json:"enabled_routes,omitempty"` + AdvertisedRoutes []string `protobuf:"bytes,2,rep,name=advertised_routes,json=advertisedRoutes,proto3" json:"advertised_routes,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *GetDeviceRoutesResponse) Reset() { *x = GetDeviceRoutesResponse{} - if protoimpl.UnsafeEnabled { - mi 
:= &file_headscale_v1_device_proto_msgTypes[8] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetDeviceRoutesResponse) String() string { @@ -646,7 +621,7 @@ func (*GetDeviceRoutesResponse) ProtoMessage() {} func (x *GetDeviceRoutesResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[8] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -676,21 +651,18 @@ func (x *GetDeviceRoutesResponse) GetAdvertisedRoutes() []string { } type EnableDeviceRoutesRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` + Routes []string `protobuf:"bytes,2,rep,name=routes,proto3" json:"routes,omitempty"` unknownFields protoimpl.UnknownFields - - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` - Routes []string `protobuf:"bytes,2,rep,name=routes,proto3" json:"routes,omitempty"` + sizeCache protoimpl.SizeCache } func (x *EnableDeviceRoutesRequest) Reset() { *x = EnableDeviceRoutesRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_device_proto_msgTypes[9] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *EnableDeviceRoutesRequest) String() string { @@ -701,7 +673,7 @@ func (*EnableDeviceRoutesRequest) ProtoMessage() {} func (x *EnableDeviceRoutesRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[9] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -731,21 +703,18 @@ func (x *EnableDeviceRoutesRequest) GetRoutes() []string { } type EnableDeviceRoutesResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - EnabledRoutes []string `protobuf:"bytes,1,rep,name=enabled_routes,json=enabledRoutes,proto3" json:"enabled_routes,omitempty"` - AdvertisedRoutes []string `protobuf:"bytes,2,rep,name=advertised_routes,json=advertisedRoutes,proto3" json:"advertised_routes,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + EnabledRoutes []string `protobuf:"bytes,1,rep,name=enabled_routes,json=enabledRoutes,proto3" json:"enabled_routes,omitempty"` + AdvertisedRoutes []string `protobuf:"bytes,2,rep,name=advertised_routes,json=advertisedRoutes,proto3" json:"advertised_routes,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *EnableDeviceRoutesResponse) Reset() { *x = EnableDeviceRoutesResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_device_proto_msgTypes[10] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_device_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *EnableDeviceRoutesResponse) String() string { @@ -756,7 +725,7 @@ func (*EnableDeviceRoutesResponse) ProtoMessage() {} func (x 
*EnableDeviceRoutesResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_device_proto_msgTypes[10] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -787,145 +756,86 @@ func (x *EnableDeviceRoutesResponse) GetAdvertisedRoutes() []string { var File_headscale_v1_device_proto protoreflect.FileDescriptor -var file_headscale_v1_device_proto_rawDesc = []byte{ - 0x0a, 0x19, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x64, - 0x65, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x46, 0x0a, 0x07, 0x4c, 0x61, - 0x74, 0x65, 0x6e, 0x63, 0x79, 0x12, 0x1d, 0x0a, 0x0a, 0x6c, 0x61, 0x74, 0x65, 0x6e, 0x63, 0x79, - 0x5f, 0x6d, 0x73, 0x18, 0x01, 0x20, 0x01, 0x28, 0x02, 0x52, 0x09, 0x6c, 0x61, 0x74, 0x65, 0x6e, - 0x63, 0x79, 0x4d, 0x73, 0x12, 0x1c, 0x0a, 0x09, 0x70, 0x72, 0x65, 0x66, 0x65, 0x72, 0x72, 0x65, - 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x70, 0x72, 0x65, 0x66, 0x65, 0x72, 0x72, - 0x65, 0x64, 0x22, 0x91, 0x01, 0x0a, 0x0e, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x53, 0x75, 0x70, - 0x70, 0x6f, 0x72, 0x74, 0x73, 0x12, 0x21, 0x0a, 0x0c, 0x68, 0x61, 0x69, 0x72, 0x5f, 0x70, 0x69, - 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0b, 0x68, 0x61, 0x69, - 0x72, 0x50, 0x69, 0x6e, 0x6e, 0x69, 0x6e, 0x67, 0x12, 0x12, 0x0a, 0x04, 0x69, 0x70, 0x76, 0x36, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x69, 0x70, 0x76, 0x36, 0x12, 0x10, 0x0a, 0x03, - 0x70, 0x63, 0x70, 0x18, 0x03, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x63, 0x70, 0x12, 0x10, - 0x0a, 0x03, 0x70, 0x6d, 0x70, 0x18, 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x70, 0x6d, 0x70, - 0x12, 0x10, 0x0a, 0x03, 0x75, 0x64, 0x70, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x03, 0x75, - 0x64, 0x70, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x70, 0x6e, 0x70, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x04, 0x75, 0x70, 0x6e, 0x70, 0x22, 0xe3, 0x02, 0x0a, 0x12, 0x43, 0x6c, 0x69, 0x65, 0x6e, - 0x74, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x12, 0x1c, 0x0a, - 0x09, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x09, 0x65, 0x6e, 0x64, 0x70, 0x6f, 0x69, 0x6e, 0x74, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x64, - 0x65, 0x72, 0x70, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x64, 0x65, 0x72, 0x70, 0x12, - 0x38, 0x0a, 0x19, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x5f, 0x76, 0x61, 0x72, 0x69, 0x65, - 0x73, 0x5f, 0x62, 0x79, 0x5f, 0x64, 0x65, 0x73, 0x74, 0x5f, 0x69, 0x70, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x15, 0x6d, 0x61, 0x70, 0x70, 0x69, 0x6e, 0x67, 0x56, 0x61, 0x72, 0x69, 0x65, - 0x73, 0x42, 0x79, 0x44, 0x65, 0x73, 0x74, 0x49, 0x70, 0x12, 0x47, 0x0a, 0x07, 0x6c, 0x61, 0x74, - 0x65, 0x6e, 0x63, 0x79, 0x18, 0x04, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x2d, 0x2e, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, 0x74, - 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x2e, 0x4c, 0x61, 0x74, - 0x65, 0x6e, 0x63, 0x79, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x52, 0x07, 0x6c, 0x61, 0x74, 0x65, 0x6e, - 0x63, 0x79, 
0x12, 0x45, 0x0a, 0x0f, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x73, 0x75, 0x70, - 0x70, 0x6f, 0x72, 0x74, 0x73, 0x18, 0x05, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1c, 0x2e, 0x68, 0x65, - 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x6c, 0x69, 0x65, 0x6e, - 0x74, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x52, 0x0e, 0x63, 0x6c, 0x69, 0x65, 0x6e, - 0x74, 0x53, 0x75, 0x70, 0x70, 0x6f, 0x72, 0x74, 0x73, 0x1a, 0x51, 0x0a, 0x0c, 0x4c, 0x61, 0x74, - 0x65, 0x6e, 0x63, 0x79, 0x45, 0x6e, 0x74, 0x72, 0x79, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x2b, 0x0a, 0x05, 0x76, - 0x61, 0x6c, 0x75, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x15, 0x2e, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x61, 0x74, 0x65, 0x6e, 0x63, - 0x79, 0x52, 0x05, 0x76, 0x61, 0x6c, 0x75, 0x65, 0x3a, 0x02, 0x38, 0x01, 0x22, 0x22, 0x0a, 0x10, - 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, - 0x22, 0xa0, 0x06, 0x0a, 0x11, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1c, 0x0a, 0x09, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, - 0x73, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x61, 0x64, 0x64, 0x72, 0x65, - 0x73, 0x73, 0x65, 0x73, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x02, 0x69, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x1a, 0x0a, 0x08, - 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x05, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, - 0x68, 0x6f, 0x73, 0x74, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x63, 0x6c, 0x69, 0x65, - 0x6e, 0x74, 0x5f, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x18, 0x06, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x0d, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x56, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x12, - 0x29, 0x0a, 0x10, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x5f, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, - 0x62, 0x6c, 0x65, 0x18, 0x07, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0f, 0x75, 0x70, 0x64, 0x61, 0x74, - 0x65, 0x41, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x6f, 0x73, - 0x18, 0x08, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x6f, 0x73, 0x12, 0x34, 0x0a, 0x07, 0x63, 0x72, - 0x65, 0x61, 0x74, 0x65, 0x64, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, - 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, - 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x07, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, - 0x12, 0x37, 0x0a, 0x09, 0x6c, 0x61, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x65, 0x6e, 0x18, 0x0a, 0x20, - 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, - 0x08, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x65, 0x65, 0x6e, 0x12, 0x2e, 0x0a, 0x13, 0x6b, 0x65, 0x79, - 0x5f, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, 0x5f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x64, - 0x18, 0x0b, 0x20, 0x01, 0x28, 0x08, 0x52, 0x11, 0x6b, 0x65, 0x79, 0x45, 0x78, 0x70, 0x69, 0x72, - 0x79, 0x44, 0x69, 0x73, 0x61, 0x62, 
0x6c, 0x65, 0x64, 0x12, 0x34, 0x0a, 0x07, 0x65, 0x78, 0x70, - 0x69, 0x72, 0x65, 0x73, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, - 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x07, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x73, 0x12, - 0x1e, 0x0a, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x18, 0x0d, 0x20, - 0x01, 0x28, 0x08, 0x52, 0x0a, 0x61, 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x12, - 0x1f, 0x0a, 0x0b, 0x69, 0x73, 0x5f, 0x65, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, 0x18, 0x0e, - 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x69, 0x73, 0x45, 0x78, 0x74, 0x65, 0x72, 0x6e, 0x61, 0x6c, - 0x12, 0x1f, 0x0a, 0x0b, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x5f, 0x6b, 0x65, 0x79, 0x18, - 0x0f, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x4b, 0x65, - 0x79, 0x12, 0x19, 0x0a, 0x08, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x10, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x07, 0x6e, 0x6f, 0x64, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x3e, 0x0a, 0x1b, - 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x73, 0x5f, 0x69, 0x6e, 0x63, 0x6f, 0x6d, 0x69, 0x6e, 0x67, 0x5f, - 0x63, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x18, 0x11, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x19, 0x62, 0x6c, 0x6f, 0x63, 0x6b, 0x73, 0x49, 0x6e, 0x63, 0x6f, 0x6d, 0x69, 0x6e, - 0x67, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x12, 0x25, 0x0a, 0x0e, - 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x12, - 0x20, 0x03, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x52, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x61, 0x64, 0x76, 0x65, 0x72, 0x74, 0x69, 0x73, 0x65, - 0x64, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x13, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, - 0x61, 0x64, 0x76, 0x65, 0x72, 0x74, 0x69, 0x73, 0x65, 0x64, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x12, 0x51, 0x0a, 0x13, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x5f, 0x63, 0x6f, 0x6e, 0x6e, 0x65, - 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x18, 0x14, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x20, 0x2e, - 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x6c, 0x69, - 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x76, 0x69, 0x74, 0x79, 0x52, - 0x12, 0x63, 0x6c, 0x69, 0x65, 0x6e, 0x74, 0x43, 0x6f, 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x76, - 0x69, 0x74, 0x79, 0x22, 0x25, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x44, 0x65, 0x76, - 0x69, 0x63, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x22, 0x16, 0x0a, 0x14, 0x44, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x22, 0x28, 0x0a, 0x16, 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x22, 0x6d, 0x0a, 0x17, - 0x47, 0x65, 0x74, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x65, 0x6e, 0x61, 0x62, 0x6c, - 0x65, 0x64, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, - 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x52, 0x6f, 
0x75, 0x74, 0x65, 0x73, 0x12, 0x2b, - 0x0a, 0x11, 0x61, 0x64, 0x76, 0x65, 0x72, 0x74, 0x69, 0x73, 0x65, 0x64, 0x5f, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x10, 0x61, 0x64, 0x76, 0x65, 0x72, - 0x74, 0x69, 0x73, 0x65, 0x64, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x43, 0x0a, 0x19, 0x45, - 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x22, 0x70, 0x0a, 0x1a, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x44, 0x65, 0x76, 0x69, 0x63, 0x65, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x25, - 0x0a, 0x0e, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x18, 0x01, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0d, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x2b, 0x0a, 0x11, 0x61, 0x64, 0x76, 0x65, 0x72, 0x74, 0x69, - 0x73, 0x65, 0x64, 0x5f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, - 0x52, 0x10, 0x61, 0x64, 0x76, 0x65, 0x72, 0x74, 0x69, 0x73, 0x65, 0x64, 0x52, 0x6f, 0x75, 0x74, - 0x65, 0x73, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, - 0x2f, 0x6a, 0x75, 0x61, 0x6e, 0x66, 0x6f, 0x6e, 0x74, 0x2f, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x67, 0x6f, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +const file_headscale_v1_device_proto_rawDesc = "" + + "\n" + + "\x19headscale/v1/device.proto\x12\fheadscale.v1\x1a\x1fgoogle/protobuf/timestamp.proto\"F\n" + + "\aLatency\x12\x1d\n" + + "\n" + + "latency_ms\x18\x01 \x01(\x02R\tlatencyMs\x12\x1c\n" + + "\tpreferred\x18\x02 \x01(\bR\tpreferred\"\x91\x01\n" + + "\x0eClientSupports\x12!\n" + + "\fhair_pinning\x18\x01 \x01(\bR\vhairPinning\x12\x12\n" + + "\x04ipv6\x18\x02 \x01(\bR\x04ipv6\x12\x10\n" + + "\x03pcp\x18\x03 \x01(\bR\x03pcp\x12\x10\n" + + "\x03pmp\x18\x04 \x01(\bR\x03pmp\x12\x10\n" + + "\x03udp\x18\x05 \x01(\bR\x03udp\x12\x12\n" + + "\x04upnp\x18\x06 \x01(\bR\x04upnp\"\xe3\x02\n" + + "\x12ClientConnectivity\x12\x1c\n" + + "\tendpoints\x18\x01 \x03(\tR\tendpoints\x12\x12\n" + + "\x04derp\x18\x02 \x01(\tR\x04derp\x128\n" + + "\x19mapping_varies_by_dest_ip\x18\x03 \x01(\bR\x15mappingVariesByDestIp\x12G\n" + + "\alatency\x18\x04 \x03(\v2-.headscale.v1.ClientConnectivity.LatencyEntryR\alatency\x12E\n" + + "\x0fclient_supports\x18\x05 \x01(\v2\x1c.headscale.v1.ClientSupportsR\x0eclientSupports\x1aQ\n" + + "\fLatencyEntry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12+\n" + + "\x05value\x18\x02 \x01(\v2\x15.headscale.v1.LatencyR\x05value:\x028\x01\"\"\n" + + "\x10GetDeviceRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\"\xa0\x06\n" + + "\x11GetDeviceResponse\x12\x1c\n" + + "\taddresses\x18\x01 \x03(\tR\taddresses\x12\x0e\n" + + "\x02id\x18\x02 \x01(\tR\x02id\x12\x12\n" + + "\x04user\x18\x03 \x01(\tR\x04user\x12\x12\n" + + "\x04name\x18\x04 \x01(\tR\x04name\x12\x1a\n" + + "\bhostname\x18\x05 \x01(\tR\bhostname\x12%\n" + + "\x0eclient_version\x18\x06 \x01(\tR\rclientVersion\x12)\n" + + "\x10update_available\x18\a \x01(\bR\x0fupdateAvailable\x12\x0e\n" + + "\x02os\x18\b \x01(\tR\x02os\x124\n" + + "\acreated\x18\t 
\x01(\v2\x1a.google.protobuf.TimestampR\acreated\x127\n" + + "\tlast_seen\x18\n" + + " \x01(\v2\x1a.google.protobuf.TimestampR\blastSeen\x12.\n" + + "\x13key_expiry_disabled\x18\v \x01(\bR\x11keyExpiryDisabled\x124\n" + + "\aexpires\x18\f \x01(\v2\x1a.google.protobuf.TimestampR\aexpires\x12\x1e\n" + + "\n" + + "authorized\x18\r \x01(\bR\n" + + "authorized\x12\x1f\n" + + "\vis_external\x18\x0e \x01(\bR\n" + + "isExternal\x12\x1f\n" + + "\vmachine_key\x18\x0f \x01(\tR\n" + + "machineKey\x12\x19\n" + + "\bnode_key\x18\x10 \x01(\tR\anodeKey\x12>\n" + + "\x1bblocks_incoming_connections\x18\x11 \x01(\bR\x19blocksIncomingConnections\x12%\n" + + "\x0eenabled_routes\x18\x12 \x03(\tR\renabledRoutes\x12+\n" + + "\x11advertised_routes\x18\x13 \x03(\tR\x10advertisedRoutes\x12Q\n" + + "\x13client_connectivity\x18\x14 \x01(\v2 .headscale.v1.ClientConnectivityR\x12clientConnectivity\"%\n" + + "\x13DeleteDeviceRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\"\x16\n" + + "\x14DeleteDeviceResponse\"(\n" + + "\x16GetDeviceRoutesRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\"m\n" + + "\x17GetDeviceRoutesResponse\x12%\n" + + "\x0eenabled_routes\x18\x01 \x03(\tR\renabledRoutes\x12+\n" + + "\x11advertised_routes\x18\x02 \x03(\tR\x10advertisedRoutes\"C\n" + + "\x19EnableDeviceRoutesRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\tR\x02id\x12\x16\n" + + "\x06routes\x18\x02 \x03(\tR\x06routes\"p\n" + + "\x1aEnableDeviceRoutesResponse\x12%\n" + + "\x0eenabled_routes\x18\x01 \x03(\tR\renabledRoutes\x12+\n" + + "\x11advertised_routes\x18\x02 \x03(\tR\x10advertisedRoutesB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" var ( file_headscale_v1_device_proto_rawDescOnce sync.Once - file_headscale_v1_device_proto_rawDescData = file_headscale_v1_device_proto_rawDesc + file_headscale_v1_device_proto_rawDescData []byte ) func file_headscale_v1_device_proto_rawDescGZIP() []byte { file_headscale_v1_device_proto_rawDescOnce.Do(func() { - file_headscale_v1_device_proto_rawDescData = protoimpl.X.CompressGZIP(file_headscale_v1_device_proto_rawDescData) + file_headscale_v1_device_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_device_proto_rawDesc), len(file_headscale_v1_device_proto_rawDesc))) }) return file_headscale_v1_device_proto_rawDescData } var file_headscale_v1_device_proto_msgTypes = make([]protoimpl.MessageInfo, 12) -var file_headscale_v1_device_proto_goTypes = []interface{}{ +var file_headscale_v1_device_proto_goTypes = []any{ (*Latency)(nil), // 0: headscale.v1.Latency (*ClientSupports)(nil), // 1: headscale.v1.ClientSupports (*ClientConnectivity)(nil), // 2: headscale.v1.ClientConnectivity @@ -960,145 +870,11 @@ func file_headscale_v1_device_proto_init() { if File_headscale_v1_device_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_headscale_v1_device_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Latency); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ClientSupports); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ClientConnectivity); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - 
return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetDeviceRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetDeviceResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeleteDeviceRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeleteDeviceResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetDeviceRoutesRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetDeviceRoutesResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*EnableDeviceRoutesRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_device_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*EnableDeviceRoutesResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_headscale_v1_device_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_device_proto_rawDesc), len(file_headscale_v1_device_proto_rawDesc)), NumEnums: 0, NumMessages: 12, NumExtensions: 0, @@ -1109,7 +885,6 @@ func file_headscale_v1_device_proto_init() { MessageInfos: file_headscale_v1_device_proto_msgTypes, }.Build() File_headscale_v1_device_proto = out.File - file_headscale_v1_device_proto_rawDesc = nil file_headscale_v1_device_proto_goTypes = nil file_headscale_v1_device_proto_depIdxs = nil } diff --git a/gen/go/headscale/v1/headscale.pb.go b/gen/go/headscale/v1/headscale.pb.go index 9917fd81..3d16778c 100644 --- a/gen/go/headscale/v1/headscale.pb.go +++ b/gen/go/headscale/v1/headscale.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: -// protoc-gen-go v1.31.0 +// protoc-gen-go v1.36.11 // protoc (unknown) // source: headscale/v1/headscale.proto @@ -11,6 +11,8 @@ import ( protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" + sync "sync" + unsafe "unsafe" ) const ( @@ -20,315 +22,243 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) +type HealthRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HealthRequest) Reset() { + *x = HealthRequest{} + mi := &file_headscale_v1_headscale_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HealthRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HealthRequest) ProtoMessage() {} + +func (x *HealthRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_headscale_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HealthRequest.ProtoReflect.Descriptor instead. +func (*HealthRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_headscale_proto_rawDescGZIP(), []int{0} +} + +type HealthResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + DatabaseConnectivity bool `protobuf:"varint,1,opt,name=database_connectivity,json=databaseConnectivity,proto3" json:"database_connectivity,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *HealthResponse) Reset() { + *x = HealthResponse{} + mi := &file_headscale_v1_headscale_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *HealthResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HealthResponse) ProtoMessage() {} + +func (x *HealthResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_headscale_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HealthResponse.ProtoReflect.Descriptor instead. 
+func (*HealthResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_headscale_proto_rawDescGZIP(), []int{1} +} + +func (x *HealthResponse) GetDatabaseConnectivity() bool { + if x != nil { + return x.DatabaseConnectivity + } + return false +} + var File_headscale_v1_headscale_proto protoreflect.FileDescriptor -var file_headscale_v1_headscale_proto_rawDesc = []byte{ - 0x0a, 0x1c, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x68, - 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, - 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x1a, 0x1c, 0x67, 0x6f, - 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x61, 0x6e, 0x6e, 0x6f, 0x74, 0x61, 0x74, - 0x69, 0x6f, 0x6e, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x68, 0x65, 0x61, 0x64, - 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x1a, 0x1d, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, - 0x31, 0x2f, 0x70, 0x72, 0x65, 0x61, 0x75, 0x74, 0x68, 0x6b, 0x65, 0x79, 0x2e, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x1a, 0x17, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, - 0x2f, 0x6e, 0x6f, 0x64, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x19, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x19, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, - 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x61, 0x70, 0x69, 0x6b, 0x65, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x32, 0x85, 0x17, 0x0a, 0x10, 0x48, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x53, - 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x63, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x55, 0x73, 0x65, - 0x72, 0x12, 0x1c, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, - 0x2e, 0x47, 0x65, 0x74, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x1d, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x47, - 0x65, 0x74, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, - 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x15, 0x12, 0x13, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, - 0x75, 0x73, 0x65, 0x72, 0x2f, 0x7b, 0x6e, 0x61, 0x6d, 0x65, 0x7d, 0x12, 0x68, 0x0a, 0x0a, 0x43, - 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x12, 0x1f, 0x2e, 0x68, 0x65, 0x61, 0x64, - 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, - 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x17, 0x82, 0xd3, - 0xe4, 0x93, 0x02, 0x11, 0x3a, 0x01, 0x2a, 0x22, 0x0c, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, - 0x2f, 0x75, 0x73, 0x65, 0x72, 0x12, 0x82, 0x01, 0x0a, 0x0a, 0x52, 0x65, 0x6e, 0x61, 0x6d, 0x65, - 0x55, 0x73, 0x65, 0x72, 0x12, 0x1f, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, - 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, - 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 
0x65, 0x22, 0x31, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x2b, 0x22, - 0x29, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x2f, 0x7b, 0x6f, - 0x6c, 0x64, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x7d, 0x2f, 0x72, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x2f, - 0x7b, 0x6e, 0x65, 0x77, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x7d, 0x12, 0x6c, 0x0a, 0x0a, 0x44, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x12, 0x1f, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, - 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x55, 0x73, - 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x68, 0x65, 0x61, 0x64, - 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x55, - 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1b, 0x82, 0xd3, 0xe4, - 0x93, 0x02, 0x15, 0x2a, 0x13, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x75, 0x73, 0x65, - 0x72, 0x2f, 0x7b, 0x6e, 0x61, 0x6d, 0x65, 0x7d, 0x12, 0x62, 0x0a, 0x09, 0x4c, 0x69, 0x73, 0x74, - 0x55, 0x73, 0x65, 0x72, 0x73, 0x12, 0x1e, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, - 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x55, 0x73, 0x65, 0x72, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, - 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x55, 0x73, 0x65, 0x72, 0x73, 0x52, 0x65, - 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x14, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x0e, 0x12, 0x0c, - 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x12, 0x80, 0x01, 0x0a, - 0x10, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, - 0x79, 0x12, 0x25, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, - 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x26, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, - 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x22, 0x1d, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x17, 0x3a, 0x01, 0x2a, 0x22, 0x12, 0x2f, 0x61, 0x70, - 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x70, 0x72, 0x65, 0x61, 0x75, 0x74, 0x68, 0x6b, 0x65, 0x79, 0x12, - 0x87, 0x01, 0x0a, 0x10, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x4b, 0x65, 0x79, 0x12, 0x25, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, - 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, - 0x68, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x26, 0x2e, 0x68, 0x65, - 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x69, 0x72, - 0x65, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x24, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x1e, 0x3a, 0x01, 0x2a, 0x22, 0x19, - 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x70, 0x72, 0x65, 0x61, 0x75, 0x74, 0x68, 0x6b, - 0x65, 0x79, 0x2f, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x12, 0x7a, 0x0a, 0x0f, 0x4c, 0x69, 0x73, - 0x74, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x73, 0x12, 0x24, 0x2e, 0x68, - 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x69, 0x73, 0x74, - 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 
0x73, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x25, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, - 0x31, 0x2e, 0x4c, 0x69, 0x73, 0x74, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, - 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1a, 0x82, 0xd3, 0xe4, 0x93, 0x02, - 0x14, 0x12, 0x12, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x70, 0x72, 0x65, 0x61, 0x75, - 0x74, 0x68, 0x6b, 0x65, 0x79, 0x12, 0x7d, 0x0a, 0x0f, 0x44, 0x65, 0x62, 0x75, 0x67, 0x43, 0x72, - 0x65, 0x61, 0x74, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x24, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, - 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, 0x62, 0x75, 0x67, 0x43, 0x72, 0x65, - 0x61, 0x74, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x25, - 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, - 0x62, 0x75, 0x67, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1d, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x17, 0x3a, 0x01, 0x2a, - 0x22, 0x12, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x64, 0x65, 0x62, 0x75, 0x67, 0x2f, - 0x6e, 0x6f, 0x64, 0x65, 0x12, 0x66, 0x0a, 0x07, 0x47, 0x65, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x12, - 0x1c, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x47, - 0x65, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1d, 0x2e, - 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x65, 0x74, - 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x1e, 0x82, 0xd3, - 0xe4, 0x93, 0x02, 0x18, 0x12, 0x16, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x6f, - 0x64, 0x65, 0x2f, 0x7b, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x7d, 0x12, 0x6e, 0x0a, 0x07, - 0x53, 0x65, 0x74, 0x54, 0x61, 0x67, 0x73, 0x12, 0x1c, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x61, 0x67, 0x73, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1d, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, - 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x74, 0x54, 0x61, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x26, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x20, 0x3a, 0x01, 0x2a, 0x22, - 0x1b, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x6f, 0x64, 0x65, 0x2f, 0x7b, 0x6e, - 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x7d, 0x2f, 0x74, 0x61, 0x67, 0x73, 0x12, 0x74, 0x0a, 0x0c, - 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x21, 0x2e, 0x68, - 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x67, 0x69, - 0x73, 0x74, 0x65, 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x22, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, - 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x1d, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x17, 0x22, 0x15, 0x2f, 0x61, 0x70, - 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x6f, 0x64, 0x65, 0x2f, 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, - 0x65, 0x72, 0x12, 0x6f, 0x0a, 0x0a, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x4e, 0x6f, 0x64, 0x65, - 0x12, 0x1f, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, - 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 
0x65, 0x73, - 0x74, 0x1a, 0x20, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, - 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x22, 0x1e, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x18, 0x2a, 0x16, 0x2f, 0x61, 0x70, - 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x6f, 0x64, 0x65, 0x2f, 0x7b, 0x6e, 0x6f, 0x64, 0x65, 0x5f, - 0x69, 0x64, 0x7d, 0x12, 0x76, 0x0a, 0x0a, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x4e, 0x6f, 0x64, - 0x65, 0x12, 0x1f, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, - 0x2e, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, - 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, - 0x31, 0x2e, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x25, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x1f, 0x22, 0x1d, 0x2f, 0x61, - 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x6f, 0x64, 0x65, 0x2f, 0x7b, 0x6e, 0x6f, 0x64, 0x65, - 0x5f, 0x69, 0x64, 0x7d, 0x2f, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x12, 0x81, 0x01, 0x0a, 0x0a, - 0x52, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x1f, 0x2e, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x6e, 0x61, 0x6d, 0x65, - 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x20, 0x2e, 0x68, 0x65, - 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x65, 0x6e, 0x61, 0x6d, - 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x30, 0x82, - 0xd3, 0xe4, 0x93, 0x02, 0x2a, 0x22, 0x28, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x6e, - 0x6f, 0x64, 0x65, 0x2f, 0x7b, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x7d, 0x2f, 0x72, 0x65, - 0x6e, 0x61, 0x6d, 0x65, 0x2f, 0x7b, 0x6e, 0x65, 0x77, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x7d, 0x12, - 0x62, 0x0a, 0x09, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x73, 0x12, 0x1e, 0x2e, 0x68, - 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x69, 0x73, 0x74, - 0x4e, 0x6f, 0x64, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1f, 0x2e, 0x68, - 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x69, 0x73, 0x74, - 0x4e, 0x6f, 0x64, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x14, 0x82, - 0xd3, 0xe4, 0x93, 0x02, 0x0e, 0x12, 0x0c, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x6e, - 0x6f, 0x64, 0x65, 0x12, 0x6e, 0x0a, 0x08, 0x4d, 0x6f, 0x76, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x12, - 0x1d, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4d, - 0x6f, 0x76, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1e, - 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4d, 0x6f, - 0x76, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x23, - 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x1d, 0x22, 0x1b, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, - 0x6e, 0x6f, 0x64, 0x65, 0x2f, 0x7b, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x7d, 0x2f, 0x75, - 0x73, 0x65, 0x72, 0x12, 0x64, 0x0a, 0x09, 0x47, 0x65, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x12, 0x1e, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, - 0x47, 0x65, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x1a, 0x1f, 
0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, - 0x47, 0x65, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x16, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x10, 0x12, 0x0e, 0x2f, 0x61, 0x70, 0x69, 0x2f, - 0x76, 0x31, 0x2f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x7c, 0x0a, 0x0b, 0x45, 0x6e, 0x61, - 0x62, 0x6c, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x20, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, - 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x28, 0x82, - 0xd3, 0xe4, 0x93, 0x02, 0x22, 0x22, 0x20, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x5f, 0x69, 0x64, 0x7d, - 0x2f, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x80, 0x01, 0x0a, 0x0c, 0x44, 0x69, 0x73, 0x61, - 0x62, 0x6c, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x21, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, - 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x52, - 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x68, 0x65, - 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x69, 0x73, 0x61, 0x62, - 0x6c, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x29, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x23, 0x22, 0x21, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, - 0x2f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x5f, 0x69, - 0x64, 0x7d, 0x2f, 0x64, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x7f, 0x0a, 0x0d, 0x47, 0x65, - 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x22, 0x2e, 0x68, 0x65, - 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x47, 0x65, 0x74, 0x4e, 0x6f, - 0x64, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, - 0x23, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x47, - 0x65, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x25, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x1f, 0x12, 0x1d, 0x2f, 0x61, - 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x6f, 0x64, 0x65, 0x2f, 0x7b, 0x6e, 0x6f, 0x64, 0x65, - 0x5f, 0x69, 0x64, 0x7d, 0x2f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x12, 0x75, 0x0a, 0x0b, 0x44, - 0x65, 0x6c, 0x65, 0x74, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x20, 0x2e, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, 0x2e, 0x68, - 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x44, 0x65, 0x6c, 0x65, - 0x74, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x21, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x1b, 0x2a, 0x19, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, - 0x2f, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x2f, 0x7b, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x5f, 0x69, - 0x64, 0x7d, 0x12, 0x70, 0x0a, 0x0c, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x70, 0x69, 0x4b, - 0x65, 0x79, 0x12, 0x21, 0x2e, 0x68, 
0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, - 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, - 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x41, 0x70, 0x69, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x19, 0x82, 0xd3, 0xe4, 0x93, 0x02, - 0x13, 0x3a, 0x01, 0x2a, 0x22, 0x0e, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, 0x61, 0x70, - 0x69, 0x6b, 0x65, 0x79, 0x12, 0x77, 0x0a, 0x0c, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x41, 0x70, - 0x69, 0x4b, 0x65, 0x79, 0x12, 0x21, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, - 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x22, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x41, 0x70, 0x69, - 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x20, 0x82, 0xd3, 0xe4, - 0x93, 0x02, 0x1a, 0x3a, 0x01, 0x2a, 0x22, 0x15, 0x2f, 0x61, 0x70, 0x69, 0x2f, 0x76, 0x31, 0x2f, - 0x61, 0x70, 0x69, 0x6b, 0x65, 0x79, 0x2f, 0x65, 0x78, 0x70, 0x69, 0x72, 0x65, 0x12, 0x6a, 0x0a, - 0x0b, 0x4c, 0x69, 0x73, 0x74, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x73, 0x12, 0x20, 0x2e, 0x68, - 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x69, 0x73, 0x74, - 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x21, - 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x69, - 0x73, 0x74, 0x41, 0x70, 0x69, 0x4b, 0x65, 0x79, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, - 0x65, 0x22, 0x16, 0x82, 0xd3, 0xe4, 0x93, 0x02, 0x10, 0x12, 0x0e, 0x2f, 0x61, 0x70, 0x69, 0x2f, - 0x76, 0x31, 0x2f, 0x61, 0x70, 0x69, 0x6b, 0x65, 0x79, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, - 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6a, 0x75, 0x61, 0x6e, 0x66, 0x6f, 0x6e, 0x74, - 0x2f, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x67, - 0x6f, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +const file_headscale_v1_headscale_proto_rawDesc = "" + + "\n" + + "\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" + + "\rHealthRequest\"E\n" + + "\x0eHealthResponse\x123\n" + + "\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\x8c\x17\n" + + "\x10HeadscaleService\x12h\n" + + "\n" + + "CreateUser\x12\x1f.headscale.v1.CreateUserRequest\x1a .headscale.v1.CreateUserResponse\"\x17\x82\xd3\xe4\x93\x02\x11:\x01*\"\f/api/v1/user\x12\x80\x01\n" + + "\n" + + "RenameUser\x12\x1f.headscale.v1.RenameUserRequest\x1a .headscale.v1.RenameUserResponse\"/\x82\xd3\xe4\x93\x02)\"'/api/v1/user/{old_id}/rename/{new_name}\x12j\n" + + "\n" + + "DeleteUser\x12\x1f.headscale.v1.DeleteUserRequest\x1a .headscale.v1.DeleteUserResponse\"\x19\x82\xd3\xe4\x93\x02\x13*\x11/api/v1/user/{id}\x12b\n" + + "\tListUsers\x12\x1e.headscale.v1.ListUsersRequest\x1a\x1f.headscale.v1.ListUsersResponse\"\x14\x82\xd3\xe4\x93\x02\x0e\x12\f/api/v1/user\x12\x80\x01\n" + + 
"\x10CreatePreAuthKey\x12%.headscale.v1.CreatePreAuthKeyRequest\x1a&.headscale.v1.CreatePreAuthKeyResponse\"\x1d\x82\xd3\xe4\x93\x02\x17:\x01*\"\x12/api/v1/preauthkey\x12\x87\x01\n" + + "\x10ExpirePreAuthKey\x12%.headscale.v1.ExpirePreAuthKeyRequest\x1a&.headscale.v1.ExpirePreAuthKeyResponse\"$\x82\xd3\xe4\x93\x02\x1e:\x01*\"\x19/api/v1/preauthkey/expire\x12}\n" + + "\x10DeletePreAuthKey\x12%.headscale.v1.DeletePreAuthKeyRequest\x1a&.headscale.v1.DeletePreAuthKeyResponse\"\x1a\x82\xd3\xe4\x93\x02\x14*\x12/api/v1/preauthkey\x12z\n" + + "\x0fListPreAuthKeys\x12$.headscale.v1.ListPreAuthKeysRequest\x1a%.headscale.v1.ListPreAuthKeysResponse\"\x1a\x82\xd3\xe4\x93\x02\x14\x12\x12/api/v1/preauthkey\x12}\n" + + "\x0fDebugCreateNode\x12$.headscale.v1.DebugCreateNodeRequest\x1a%.headscale.v1.DebugCreateNodeResponse\"\x1d\x82\xd3\xe4\x93\x02\x17:\x01*\"\x12/api/v1/debug/node\x12f\n" + + "\aGetNode\x12\x1c.headscale.v1.GetNodeRequest\x1a\x1d.headscale.v1.GetNodeResponse\"\x1e\x82\xd3\xe4\x93\x02\x18\x12\x16/api/v1/node/{node_id}\x12n\n" + + "\aSetTags\x12\x1c.headscale.v1.SetTagsRequest\x1a\x1d.headscale.v1.SetTagsResponse\"&\x82\xd3\xe4\x93\x02 :\x01*\"\x1b/api/v1/node/{node_id}/tags\x12\x96\x01\n" + + "\x11SetApprovedRoutes\x12&.headscale.v1.SetApprovedRoutesRequest\x1a'.headscale.v1.SetApprovedRoutesResponse\"0\x82\xd3\xe4\x93\x02*:\x01*\"%/api/v1/node/{node_id}/approve_routes\x12t\n" + + "\fRegisterNode\x12!.headscale.v1.RegisterNodeRequest\x1a\".headscale.v1.RegisterNodeResponse\"\x1d\x82\xd3\xe4\x93\x02\x17\"\x15/api/v1/node/register\x12o\n" + + "\n" + + "DeleteNode\x12\x1f.headscale.v1.DeleteNodeRequest\x1a .headscale.v1.DeleteNodeResponse\"\x1e\x82\xd3\xe4\x93\x02\x18*\x16/api/v1/node/{node_id}\x12v\n" + + "\n" + + "ExpireNode\x12\x1f.headscale.v1.ExpireNodeRequest\x1a .headscale.v1.ExpireNodeResponse\"%\x82\xd3\xe4\x93\x02\x1f\"\x1d/api/v1/node/{node_id}/expire\x12\x81\x01\n" + + "\n" + + "RenameNode\x12\x1f.headscale.v1.RenameNodeRequest\x1a .headscale.v1.RenameNodeResponse\"0\x82\xd3\xe4\x93\x02*\"(/api/v1/node/{node_id}/rename/{new_name}\x12b\n" + + "\tListNodes\x12\x1e.headscale.v1.ListNodesRequest\x1a\x1f.headscale.v1.ListNodesResponse\"\x14\x82\xd3\xe4\x93\x02\x0e\x12\f/api/v1/node\x12\x80\x01\n" + + "\x0fBackfillNodeIPs\x12$.headscale.v1.BackfillNodeIPsRequest\x1a%.headscale.v1.BackfillNodeIPsResponse\" \x82\xd3\xe4\x93\x02\x1a\"\x18/api/v1/node/backfillips\x12p\n" + + "\fCreateApiKey\x12!.headscale.v1.CreateApiKeyRequest\x1a\".headscale.v1.CreateApiKeyResponse\"\x19\x82\xd3\xe4\x93\x02\x13:\x01*\"\x0e/api/v1/apikey\x12w\n" + + "\fExpireApiKey\x12!.headscale.v1.ExpireApiKeyRequest\x1a\".headscale.v1.ExpireApiKeyResponse\" \x82\xd3\xe4\x93\x02\x1a:\x01*\"\x15/api/v1/apikey/expire\x12j\n" + + "\vListApiKeys\x12 .headscale.v1.ListApiKeysRequest\x1a!.headscale.v1.ListApiKeysResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/apikey\x12v\n" + + "\fDeleteApiKey\x12!.headscale.v1.DeleteApiKeyRequest\x1a\".headscale.v1.DeleteApiKeyResponse\"\x1f\x82\xd3\xe4\x93\x02\x19*\x17/api/v1/apikey/{prefix}\x12d\n" + + "\tGetPolicy\x12\x1e.headscale.v1.GetPolicyRequest\x1a\x1f.headscale.v1.GetPolicyResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/policy\x12g\n" + + "\tSetPolicy\x12\x1e.headscale.v1.SetPolicyRequest\x1a\x1f.headscale.v1.SetPolicyResponse\"\x19\x82\xd3\xe4\x93\x02\x13:\x01*\x1a\x0e/api/v1/policy\x12[\n" + + 
"\x06Health\x12\x1b.headscale.v1.HealthRequest\x1a\x1c.headscale.v1.HealthResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/healthB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" + +var ( + file_headscale_v1_headscale_proto_rawDescOnce sync.Once + file_headscale_v1_headscale_proto_rawDescData []byte +) + +func file_headscale_v1_headscale_proto_rawDescGZIP() []byte { + file_headscale_v1_headscale_proto_rawDescOnce.Do(func() { + file_headscale_v1_headscale_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_headscale_proto_rawDesc), len(file_headscale_v1_headscale_proto_rawDesc))) + }) + return file_headscale_v1_headscale_proto_rawDescData } -var file_headscale_v1_headscale_proto_goTypes = []interface{}{ - (*GetUserRequest)(nil), // 0: headscale.v1.GetUserRequest - (*CreateUserRequest)(nil), // 1: headscale.v1.CreateUserRequest - (*RenameUserRequest)(nil), // 2: headscale.v1.RenameUserRequest - (*DeleteUserRequest)(nil), // 3: headscale.v1.DeleteUserRequest - (*ListUsersRequest)(nil), // 4: headscale.v1.ListUsersRequest - (*CreatePreAuthKeyRequest)(nil), // 5: headscale.v1.CreatePreAuthKeyRequest - (*ExpirePreAuthKeyRequest)(nil), // 6: headscale.v1.ExpirePreAuthKeyRequest - (*ListPreAuthKeysRequest)(nil), // 7: headscale.v1.ListPreAuthKeysRequest - (*DebugCreateNodeRequest)(nil), // 8: headscale.v1.DebugCreateNodeRequest - (*GetNodeRequest)(nil), // 9: headscale.v1.GetNodeRequest - (*SetTagsRequest)(nil), // 10: headscale.v1.SetTagsRequest - (*RegisterNodeRequest)(nil), // 11: headscale.v1.RegisterNodeRequest - (*DeleteNodeRequest)(nil), // 12: headscale.v1.DeleteNodeRequest - (*ExpireNodeRequest)(nil), // 13: headscale.v1.ExpireNodeRequest - (*RenameNodeRequest)(nil), // 14: headscale.v1.RenameNodeRequest - (*ListNodesRequest)(nil), // 15: headscale.v1.ListNodesRequest - (*MoveNodeRequest)(nil), // 16: headscale.v1.MoveNodeRequest - (*GetRoutesRequest)(nil), // 17: headscale.v1.GetRoutesRequest - (*EnableRouteRequest)(nil), // 18: headscale.v1.EnableRouteRequest - (*DisableRouteRequest)(nil), // 19: headscale.v1.DisableRouteRequest - (*GetNodeRoutesRequest)(nil), // 20: headscale.v1.GetNodeRoutesRequest - (*DeleteRouteRequest)(nil), // 21: headscale.v1.DeleteRouteRequest - (*CreateApiKeyRequest)(nil), // 22: headscale.v1.CreateApiKeyRequest - (*ExpireApiKeyRequest)(nil), // 23: headscale.v1.ExpireApiKeyRequest - (*ListApiKeysRequest)(nil), // 24: headscale.v1.ListApiKeysRequest - (*GetUserResponse)(nil), // 25: headscale.v1.GetUserResponse - (*CreateUserResponse)(nil), // 26: headscale.v1.CreateUserResponse - (*RenameUserResponse)(nil), // 27: headscale.v1.RenameUserResponse - (*DeleteUserResponse)(nil), // 28: headscale.v1.DeleteUserResponse - (*ListUsersResponse)(nil), // 29: headscale.v1.ListUsersResponse - (*CreatePreAuthKeyResponse)(nil), // 30: headscale.v1.CreatePreAuthKeyResponse - (*ExpirePreAuthKeyResponse)(nil), // 31: headscale.v1.ExpirePreAuthKeyResponse - (*ListPreAuthKeysResponse)(nil), // 32: headscale.v1.ListPreAuthKeysResponse - (*DebugCreateNodeResponse)(nil), // 33: headscale.v1.DebugCreateNodeResponse - (*GetNodeResponse)(nil), // 34: headscale.v1.GetNodeResponse - (*SetTagsResponse)(nil), // 35: headscale.v1.SetTagsResponse - (*RegisterNodeResponse)(nil), // 36: headscale.v1.RegisterNodeResponse - (*DeleteNodeResponse)(nil), // 37: headscale.v1.DeleteNodeResponse - (*ExpireNodeResponse)(nil), // 38: headscale.v1.ExpireNodeResponse - (*RenameNodeResponse)(nil), // 39: headscale.v1.RenameNodeResponse - 
(*ListNodesResponse)(nil), // 40: headscale.v1.ListNodesResponse - (*MoveNodeResponse)(nil), // 41: headscale.v1.MoveNodeResponse - (*GetRoutesResponse)(nil), // 42: headscale.v1.GetRoutesResponse - (*EnableRouteResponse)(nil), // 43: headscale.v1.EnableRouteResponse - (*DisableRouteResponse)(nil), // 44: headscale.v1.DisableRouteResponse - (*GetNodeRoutesResponse)(nil), // 45: headscale.v1.GetNodeRoutesResponse - (*DeleteRouteResponse)(nil), // 46: headscale.v1.DeleteRouteResponse - (*CreateApiKeyResponse)(nil), // 47: headscale.v1.CreateApiKeyResponse - (*ExpireApiKeyResponse)(nil), // 48: headscale.v1.ExpireApiKeyResponse - (*ListApiKeysResponse)(nil), // 49: headscale.v1.ListApiKeysResponse +var file_headscale_v1_headscale_proto_msgTypes = make([]protoimpl.MessageInfo, 2) +var file_headscale_v1_headscale_proto_goTypes = []any{ + (*HealthRequest)(nil), // 0: headscale.v1.HealthRequest + (*HealthResponse)(nil), // 1: headscale.v1.HealthResponse + (*CreateUserRequest)(nil), // 2: headscale.v1.CreateUserRequest + (*RenameUserRequest)(nil), // 3: headscale.v1.RenameUserRequest + (*DeleteUserRequest)(nil), // 4: headscale.v1.DeleteUserRequest + (*ListUsersRequest)(nil), // 5: headscale.v1.ListUsersRequest + (*CreatePreAuthKeyRequest)(nil), // 6: headscale.v1.CreatePreAuthKeyRequest + (*ExpirePreAuthKeyRequest)(nil), // 7: headscale.v1.ExpirePreAuthKeyRequest + (*DeletePreAuthKeyRequest)(nil), // 8: headscale.v1.DeletePreAuthKeyRequest + (*ListPreAuthKeysRequest)(nil), // 9: headscale.v1.ListPreAuthKeysRequest + (*DebugCreateNodeRequest)(nil), // 10: headscale.v1.DebugCreateNodeRequest + (*GetNodeRequest)(nil), // 11: headscale.v1.GetNodeRequest + (*SetTagsRequest)(nil), // 12: headscale.v1.SetTagsRequest + (*SetApprovedRoutesRequest)(nil), // 13: headscale.v1.SetApprovedRoutesRequest + (*RegisterNodeRequest)(nil), // 14: headscale.v1.RegisterNodeRequest + (*DeleteNodeRequest)(nil), // 15: headscale.v1.DeleteNodeRequest + (*ExpireNodeRequest)(nil), // 16: headscale.v1.ExpireNodeRequest + (*RenameNodeRequest)(nil), // 17: headscale.v1.RenameNodeRequest + (*ListNodesRequest)(nil), // 18: headscale.v1.ListNodesRequest + (*BackfillNodeIPsRequest)(nil), // 19: headscale.v1.BackfillNodeIPsRequest + (*CreateApiKeyRequest)(nil), // 20: headscale.v1.CreateApiKeyRequest + (*ExpireApiKeyRequest)(nil), // 21: headscale.v1.ExpireApiKeyRequest + (*ListApiKeysRequest)(nil), // 22: headscale.v1.ListApiKeysRequest + (*DeleteApiKeyRequest)(nil), // 23: headscale.v1.DeleteApiKeyRequest + (*GetPolicyRequest)(nil), // 24: headscale.v1.GetPolicyRequest + (*SetPolicyRequest)(nil), // 25: headscale.v1.SetPolicyRequest + (*CreateUserResponse)(nil), // 26: headscale.v1.CreateUserResponse + (*RenameUserResponse)(nil), // 27: headscale.v1.RenameUserResponse + (*DeleteUserResponse)(nil), // 28: headscale.v1.DeleteUserResponse + (*ListUsersResponse)(nil), // 29: headscale.v1.ListUsersResponse + (*CreatePreAuthKeyResponse)(nil), // 30: headscale.v1.CreatePreAuthKeyResponse + (*ExpirePreAuthKeyResponse)(nil), // 31: headscale.v1.ExpirePreAuthKeyResponse + (*DeletePreAuthKeyResponse)(nil), // 32: headscale.v1.DeletePreAuthKeyResponse + (*ListPreAuthKeysResponse)(nil), // 33: headscale.v1.ListPreAuthKeysResponse + (*DebugCreateNodeResponse)(nil), // 34: headscale.v1.DebugCreateNodeResponse + (*GetNodeResponse)(nil), // 35: headscale.v1.GetNodeResponse + (*SetTagsResponse)(nil), // 36: headscale.v1.SetTagsResponse + (*SetApprovedRoutesResponse)(nil), // 37: headscale.v1.SetApprovedRoutesResponse + (*RegisterNodeResponse)(nil), 
// 38: headscale.v1.RegisterNodeResponse + (*DeleteNodeResponse)(nil), // 39: headscale.v1.DeleteNodeResponse + (*ExpireNodeResponse)(nil), // 40: headscale.v1.ExpireNodeResponse + (*RenameNodeResponse)(nil), // 41: headscale.v1.RenameNodeResponse + (*ListNodesResponse)(nil), // 42: headscale.v1.ListNodesResponse + (*BackfillNodeIPsResponse)(nil), // 43: headscale.v1.BackfillNodeIPsResponse + (*CreateApiKeyResponse)(nil), // 44: headscale.v1.CreateApiKeyResponse + (*ExpireApiKeyResponse)(nil), // 45: headscale.v1.ExpireApiKeyResponse + (*ListApiKeysResponse)(nil), // 46: headscale.v1.ListApiKeysResponse + (*DeleteApiKeyResponse)(nil), // 47: headscale.v1.DeleteApiKeyResponse + (*GetPolicyResponse)(nil), // 48: headscale.v1.GetPolicyResponse + (*SetPolicyResponse)(nil), // 49: headscale.v1.SetPolicyResponse } var file_headscale_v1_headscale_proto_depIdxs = []int32{ - 0, // 0: headscale.v1.HeadscaleService.GetUser:input_type -> headscale.v1.GetUserRequest - 1, // 1: headscale.v1.HeadscaleService.CreateUser:input_type -> headscale.v1.CreateUserRequest - 2, // 2: headscale.v1.HeadscaleService.RenameUser:input_type -> headscale.v1.RenameUserRequest - 3, // 3: headscale.v1.HeadscaleService.DeleteUser:input_type -> headscale.v1.DeleteUserRequest - 4, // 4: headscale.v1.HeadscaleService.ListUsers:input_type -> headscale.v1.ListUsersRequest - 5, // 5: headscale.v1.HeadscaleService.CreatePreAuthKey:input_type -> headscale.v1.CreatePreAuthKeyRequest - 6, // 6: headscale.v1.HeadscaleService.ExpirePreAuthKey:input_type -> headscale.v1.ExpirePreAuthKeyRequest - 7, // 7: headscale.v1.HeadscaleService.ListPreAuthKeys:input_type -> headscale.v1.ListPreAuthKeysRequest - 8, // 8: headscale.v1.HeadscaleService.DebugCreateNode:input_type -> headscale.v1.DebugCreateNodeRequest - 9, // 9: headscale.v1.HeadscaleService.GetNode:input_type -> headscale.v1.GetNodeRequest - 10, // 10: headscale.v1.HeadscaleService.SetTags:input_type -> headscale.v1.SetTagsRequest - 11, // 11: headscale.v1.HeadscaleService.RegisterNode:input_type -> headscale.v1.RegisterNodeRequest - 12, // 12: headscale.v1.HeadscaleService.DeleteNode:input_type -> headscale.v1.DeleteNodeRequest - 13, // 13: headscale.v1.HeadscaleService.ExpireNode:input_type -> headscale.v1.ExpireNodeRequest - 14, // 14: headscale.v1.HeadscaleService.RenameNode:input_type -> headscale.v1.RenameNodeRequest - 15, // 15: headscale.v1.HeadscaleService.ListNodes:input_type -> headscale.v1.ListNodesRequest - 16, // 16: headscale.v1.HeadscaleService.MoveNode:input_type -> headscale.v1.MoveNodeRequest - 17, // 17: headscale.v1.HeadscaleService.GetRoutes:input_type -> headscale.v1.GetRoutesRequest - 18, // 18: headscale.v1.HeadscaleService.EnableRoute:input_type -> headscale.v1.EnableRouteRequest - 19, // 19: headscale.v1.HeadscaleService.DisableRoute:input_type -> headscale.v1.DisableRouteRequest - 20, // 20: headscale.v1.HeadscaleService.GetNodeRoutes:input_type -> headscale.v1.GetNodeRoutesRequest - 21, // 21: headscale.v1.HeadscaleService.DeleteRoute:input_type -> headscale.v1.DeleteRouteRequest - 22, // 22: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest - 23, // 23: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest - 24, // 24: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest - 25, // 25: headscale.v1.HeadscaleService.GetUser:output_type -> headscale.v1.GetUserResponse - 26, // 26: headscale.v1.HeadscaleService.CreateUser:output_type -> 
headscale.v1.CreateUserResponse - 27, // 27: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse - 28, // 28: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse - 29, // 29: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse - 30, // 30: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse - 31, // 31: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse - 32, // 32: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse - 33, // 33: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse - 34, // 34: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse - 35, // 35: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse - 36, // 36: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse - 37, // 37: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse - 38, // 38: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse - 39, // 39: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse - 40, // 40: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse - 41, // 41: headscale.v1.HeadscaleService.MoveNode:output_type -> headscale.v1.MoveNodeResponse - 42, // 42: headscale.v1.HeadscaleService.GetRoutes:output_type -> headscale.v1.GetRoutesResponse - 43, // 43: headscale.v1.HeadscaleService.EnableRoute:output_type -> headscale.v1.EnableRouteResponse - 44, // 44: headscale.v1.HeadscaleService.DisableRoute:output_type -> headscale.v1.DisableRouteResponse - 45, // 45: headscale.v1.HeadscaleService.GetNodeRoutes:output_type -> headscale.v1.GetNodeRoutesResponse - 46, // 46: headscale.v1.HeadscaleService.DeleteRoute:output_type -> headscale.v1.DeleteRouteResponse - 47, // 47: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse - 48, // 48: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse - 49, // 49: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse + 2, // 0: headscale.v1.HeadscaleService.CreateUser:input_type -> headscale.v1.CreateUserRequest + 3, // 1: headscale.v1.HeadscaleService.RenameUser:input_type -> headscale.v1.RenameUserRequest + 4, // 2: headscale.v1.HeadscaleService.DeleteUser:input_type -> headscale.v1.DeleteUserRequest + 5, // 3: headscale.v1.HeadscaleService.ListUsers:input_type -> headscale.v1.ListUsersRequest + 6, // 4: headscale.v1.HeadscaleService.CreatePreAuthKey:input_type -> headscale.v1.CreatePreAuthKeyRequest + 7, // 5: headscale.v1.HeadscaleService.ExpirePreAuthKey:input_type -> headscale.v1.ExpirePreAuthKeyRequest + 8, // 6: headscale.v1.HeadscaleService.DeletePreAuthKey:input_type -> headscale.v1.DeletePreAuthKeyRequest + 9, // 7: headscale.v1.HeadscaleService.ListPreAuthKeys:input_type -> headscale.v1.ListPreAuthKeysRequest + 10, // 8: headscale.v1.HeadscaleService.DebugCreateNode:input_type -> headscale.v1.DebugCreateNodeRequest + 11, // 9: headscale.v1.HeadscaleService.GetNode:input_type -> headscale.v1.GetNodeRequest + 12, // 10: headscale.v1.HeadscaleService.SetTags:input_type -> headscale.v1.SetTagsRequest + 13, // 11: 
headscale.v1.HeadscaleService.SetApprovedRoutes:input_type -> headscale.v1.SetApprovedRoutesRequest + 14, // 12: headscale.v1.HeadscaleService.RegisterNode:input_type -> headscale.v1.RegisterNodeRequest + 15, // 13: headscale.v1.HeadscaleService.DeleteNode:input_type -> headscale.v1.DeleteNodeRequest + 16, // 14: headscale.v1.HeadscaleService.ExpireNode:input_type -> headscale.v1.ExpireNodeRequest + 17, // 15: headscale.v1.HeadscaleService.RenameNode:input_type -> headscale.v1.RenameNodeRequest + 18, // 16: headscale.v1.HeadscaleService.ListNodes:input_type -> headscale.v1.ListNodesRequest + 19, // 17: headscale.v1.HeadscaleService.BackfillNodeIPs:input_type -> headscale.v1.BackfillNodeIPsRequest + 20, // 18: headscale.v1.HeadscaleService.CreateApiKey:input_type -> headscale.v1.CreateApiKeyRequest + 21, // 19: headscale.v1.HeadscaleService.ExpireApiKey:input_type -> headscale.v1.ExpireApiKeyRequest + 22, // 20: headscale.v1.HeadscaleService.ListApiKeys:input_type -> headscale.v1.ListApiKeysRequest + 23, // 21: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest + 24, // 22: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest + 25, // 23: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest + 0, // 24: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest + 26, // 25: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse + 27, // 26: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse + 28, // 27: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse + 29, // 28: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse + 30, // 29: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse + 31, // 30: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse + 32, // 31: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse + 33, // 32: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse + 34, // 33: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse + 35, // 34: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse + 36, // 35: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse + 37, // 36: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse + 38, // 37: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse + 39, // 38: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse + 40, // 39: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse + 41, // 40: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse + 42, // 41: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse + 43, // 42: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse + 44, // 43: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse + 45, // 44: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse + 46, // 45: 
headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse + 47, // 46: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse + 48, // 47: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse + 49, // 48: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse + 1, // 49: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse 25, // [25:50] is the sub-list for method output_type 0, // [0:25] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name @@ -344,23 +274,23 @@ func file_headscale_v1_headscale_proto_init() { file_headscale_v1_user_proto_init() file_headscale_v1_preauthkey_proto_init() file_headscale_v1_node_proto_init() - file_headscale_v1_routes_proto_init() file_headscale_v1_apikey_proto_init() + file_headscale_v1_policy_proto_init() type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_headscale_v1_headscale_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_headscale_proto_rawDesc), len(file_headscale_v1_headscale_proto_rawDesc)), NumEnums: 0, - NumMessages: 0, + NumMessages: 2, NumExtensions: 0, NumServices: 1, }, GoTypes: file_headscale_v1_headscale_proto_goTypes, DependencyIndexes: file_headscale_v1_headscale_proto_depIdxs, + MessageInfos: file_headscale_v1_headscale_proto_msgTypes, }.Build() File_headscale_v1_headscale_proto = out.File - file_headscale_v1_headscale_proto_rawDesc = nil file_headscale_v1_headscale_proto_goTypes = nil file_headscale_v1_headscale_proto_depIdxs = nil } diff --git a/gen/go/headscale/v1/headscale.pb.gw.go b/gen/go/headscale/v1/headscale.pb.gw.go index 6acf8fc3..ab851614 100644 --- a/gen/go/headscale/v1/headscale.pb.gw.go +++ b/gen/go/headscale/v1/headscale.pb.gw.go @@ -10,6 +10,7 @@ package v1 import ( "context" + "errors" "io" "net/http" @@ -24,1169 +25,900 @@ import ( ) // Suppress "imported and not used" errors -var _ codes.Code -var _ io.Reader -var _ status.Status -var _ = runtime.String -var _ = utilities.NewDoubleArray -var _ = metadata.Join - -func request_HeadscaleService_GetUser_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetUserRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["name"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "name") - } - - protoReq.Name, err = runtime.String(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "name", err) - } - - msg, err := client.GetUser(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_HeadscaleService_GetUser_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetUserRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["name"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "name") - } - - 
protoReq.Name, err = runtime.String(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "name", err) - } - - msg, err := server.GetUser(ctx, &protoReq) - return msg, metadata, err - -} +var ( + _ codes.Code + _ io.Reader + _ status.Status + _ = errors.New + _ = runtime.String + _ = utilities.NewDoubleArray + _ = metadata.Join +) func request_HeadscaleService_CreateUser_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq CreateUserRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq CreateUserRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } msg, err := client.CreateUser(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_CreateUser_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq CreateUserRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq CreateUserRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, err := server.CreateUser(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_RenameUser_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq RenameUserRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq RenameUserRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["old_name"] + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + val, ok := pathParams["old_id"] if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "old_name") + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "old_id") } - - protoReq.OldName, err = runtime.String(val) + protoReq.OldId, err = runtime.Uint64(val) if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "old_name", err) + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "old_id", err) } - val, ok = pathParams["new_name"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "new_name") } - 
protoReq.NewName, err = runtime.String(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "new_name", err) } - msg, err := client.RenameUser(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_RenameUser_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq RenameUserRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq RenameUserRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["old_name"] + val, ok := pathParams["old_id"] if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "old_name") + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "old_id") } - - protoReq.OldName, err = runtime.String(val) + protoReq.OldId, err = runtime.Uint64(val) if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "old_name", err) + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "old_id", err) } - val, ok = pathParams["new_name"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "new_name") } - protoReq.NewName, err = runtime.String(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "new_name", err) } - msg, err := server.RenameUser(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_DeleteUser_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DeleteUserRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq DeleteUserRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["name"] + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + val, ok := pathParams["id"] if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "name") + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "id") } - - protoReq.Name, err = runtime.String(val) + protoReq.Id, err = runtime.Uint64(val) if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "name", err) + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "id", err) } - msg, err := client.DeleteUser(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_DeleteUser_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DeleteUserRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq DeleteUserRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["name"] + val, ok := pathParams["id"] if !ok { - return nil, metadata, 
status.Errorf(codes.InvalidArgument, "missing parameter %s", "name") + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "id") } - - protoReq.Name, err = runtime.String(val) + protoReq.Id, err = runtime.Uint64(val) if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "name", err) + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "id", err) } - msg, err := server.DeleteUser(ctx, &protoReq) return msg, metadata, err - } -func request_HeadscaleService_ListUsers_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ListUsersRequest - var metadata runtime.ServerMetadata +var filter_HeadscaleService_ListUsers_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +func request_HeadscaleService_ListUsers_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq ListUsersRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ListUsers_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } msg, err := client.ListUsers(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_ListUsers_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ListUsersRequest - var metadata runtime.ServerMetadata - + var ( + protoReq ListUsersRequest + metadata runtime.ServerMetadata + ) + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ListUsers_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } msg, err := server.ListUsers(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_CreatePreAuthKey_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq CreatePreAuthKeyRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq CreatePreAuthKeyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } msg, err := client.CreatePreAuthKey(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), 
grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_CreatePreAuthKey_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq CreatePreAuthKeyRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq CreatePreAuthKeyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, err := server.CreatePreAuthKey(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_ExpirePreAuthKey_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ExpirePreAuthKeyRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq ExpirePreAuthKeyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } msg, err := client.ExpirePreAuthKey(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_ExpirePreAuthKey_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ExpirePreAuthKeyRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq ExpirePreAuthKeyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, err := server.ExpirePreAuthKey(ctx, &protoReq) return msg, metadata, err - } -var ( - filter_HeadscaleService_ListPreAuthKeys_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} -) - -func request_HeadscaleService_ListPreAuthKeys_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ListPreAuthKeysRequest - var metadata runtime.ServerMetadata +var filter_HeadscaleService_DeletePreAuthKey_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +func request_HeadscaleService_DeletePreAuthKey_0(ctx context.Context, 
marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq DeletePreAuthKeyRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } if err := req.ParseForm(); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ListPreAuthKeys_0); err != nil { + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_DeletePreAuthKey_0); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } + msg, err := client.DeletePreAuthKey(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} +func local_request_HeadscaleService_DeletePreAuthKey_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq DeletePreAuthKeyRequest + metadata runtime.ServerMetadata + ) + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_DeletePreAuthKey_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.DeletePreAuthKey(ctx, &protoReq) + return msg, metadata, err +} + +func request_HeadscaleService_ListPreAuthKeys_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq ListPreAuthKeysRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } msg, err := client.ListPreAuthKeys(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_ListPreAuthKeys_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ListPreAuthKeysRequest - var metadata runtime.ServerMetadata - - if err := req.ParseForm(); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ListPreAuthKeys_0); err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) - } - + var ( + protoReq ListPreAuthKeysRequest + metadata runtime.ServerMetadata + ) msg, err := server.ListPreAuthKeys(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_DebugCreateNode_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DebugCreateNodeRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq DebugCreateNodeRequest + metadata runtime.ServerMetadata + ) + 
if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } msg, err := client.DebugCreateNode(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_DebugCreateNode_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DebugCreateNodeRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq DebugCreateNodeRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, err := server.DebugCreateNode(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_GetNode_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetNodeRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq GetNodeRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["node_id"] + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - msg, err := client.GetNode(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_GetNode_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetNodeRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq GetNodeRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["node_id"] + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - msg, err := server.GetNode(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_SetTags_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq SetTagsRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) 
- } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq SetTagsRequest + metadata runtime.ServerMetadata + err error + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["node_id"] + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - msg, err := client.SetTags(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_SetTags_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq SetTagsRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq SetTagsRequest + metadata runtime.ServerMetadata + err error + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["node_id"] + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - msg, err := server.SetTags(ctx, &protoReq) return msg, metadata, err - } -var ( - filter_HeadscaleService_RegisterNode_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} -) +func request_HeadscaleService_SetApprovedRoutes_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq SetApprovedRoutesRequest + metadata runtime.ServerMetadata + err error + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + val, ok := pathParams["node_id"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") + } + protoReq.NodeId, err = runtime.Uint64(val) + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) + } + msg, err := client.SetApprovedRoutes(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_SetApprovedRoutes_0(ctx context.Context, 
marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq SetApprovedRoutesRequest + metadata runtime.ServerMetadata + err error + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + val, ok := pathParams["node_id"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") + } + protoReq.NodeId, err = runtime.Uint64(val) + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) + } + msg, err := server.SetApprovedRoutes(ctx, &protoReq) + return msg, metadata, err +} + +var filter_HeadscaleService_RegisterNode_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} func request_HeadscaleService_RegisterNode_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq RegisterNodeRequest - var metadata runtime.ServerMetadata - + var ( + protoReq RegisterNodeRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } if err := req.ParseForm(); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_RegisterNode_0); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, err := client.RegisterNode(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_RegisterNode_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq RegisterNodeRequest - var metadata runtime.ServerMetadata - + var ( + protoReq RegisterNodeRequest + metadata runtime.ServerMetadata + ) if err := req.ParseForm(); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_RegisterNode_0); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, err := server.RegisterNode(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_DeleteNode_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DeleteNodeRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq DeleteNodeRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["node_id"] + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - msg, err := client.DeleteNode(ctx, &protoReq, 
grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_DeleteNode_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DeleteNodeRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq DeleteNodeRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["node_id"] + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - msg, err := server.DeleteNode(ctx, &protoReq) return msg, metadata, err - } +var filter_HeadscaleService_ExpireNode_0 = &utilities.DoubleArray{Encoding: map[string]int{"node_id": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} + func request_HeadscaleService_ExpireNode_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ExpireNodeRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq ExpireNodeRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["node_id"] + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ExpireNode_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } msg, err := client.ExpireNode(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_ExpireNode_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ExpireNodeRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq ExpireNodeRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["node_id"] + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ExpireNode_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } msg, err := server.ExpireNode(ctx, 
&protoReq) return msg, metadata, err - } func request_HeadscaleService_RenameNode_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq RenameNodeRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq RenameNodeRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["node_id"] + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - val, ok = pathParams["new_name"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "new_name") } - protoReq.NewName, err = runtime.String(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "new_name", err) } - msg, err := client.RenameNode(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_RenameNode_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq RenameNodeRequest - var metadata runtime.ServerMetadata - var ( - val string - ok bool - err error - _ = err + protoReq RenameNodeRequest + metadata runtime.ServerMetadata + err error ) - - val, ok = pathParams["node_id"] + val, ok := pathParams["node_id"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") } - protoReq.NodeId, err = runtime.Uint64(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) } - val, ok = pathParams["new_name"] if !ok { return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "new_name") } - protoReq.NewName, err = runtime.String(val) if err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "new_name", err) } - msg, err := server.RenameNode(ctx, &protoReq) return msg, metadata, err - } -var ( - filter_HeadscaleService_ListNodes_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} -) +var filter_HeadscaleService_ListNodes_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} func request_HeadscaleService_ListNodes_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ListNodesRequest - var metadata runtime.ServerMetadata - + var ( + protoReq ListNodesRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } if err := req.ParseForm(); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ListNodes_0); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, 
err := client.ListNodes(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_ListNodes_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ListNodesRequest - var metadata runtime.ServerMetadata - + var ( + protoReq ListNodesRequest + metadata runtime.ServerMetadata + ) if err := req.ParseForm(); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_ListNodes_0); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, err := server.ListNodes(ctx, &protoReq) return msg, metadata, err - } -var ( - filter_HeadscaleService_MoveNode_0 = &utilities.DoubleArray{Encoding: map[string]int{"node_id": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} -) - -func request_HeadscaleService_MoveNode_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq MoveNodeRequest - var metadata runtime.ServerMetadata +var filter_HeadscaleService_BackfillNodeIPs_0 = &utilities.DoubleArray{Encoding: map[string]int{}, Base: []int(nil), Check: []int(nil)} +func request_HeadscaleService_BackfillNodeIPs_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var ( - val string - ok bool - err error - _ = err + protoReq BackfillNodeIPsRequest + metadata runtime.ServerMetadata ) - - val, ok = pathParams["node_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) } - - protoReq.NodeId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) - } - if err := req.ParseForm(); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_MoveNode_0); err != nil { + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_BackfillNodeIPs_0); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - - msg, err := client.MoveNode(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + msg, err := client.BackfillNodeIPs(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } -func local_request_HeadscaleService_MoveNode_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq MoveNodeRequest - var metadata runtime.ServerMetadata - +func local_request_HeadscaleService_BackfillNodeIPs_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var ( - val string - ok bool - err error - _ = err + protoReq BackfillNodeIPsRequest + metadata 
runtime.ServerMetadata ) - - val, ok = pathParams["node_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") - } - - protoReq.NodeId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) - } - if err := req.ParseForm(); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_MoveNode_0); err != nil { + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_BackfillNodeIPs_0); err != nil { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - - msg, err := server.MoveNode(ctx, &protoReq) + msg, err := server.BackfillNodeIPs(ctx, &protoReq) return msg, metadata, err - -} - -func request_HeadscaleService_GetRoutes_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetRoutesRequest - var metadata runtime.ServerMetadata - - msg, err := client.GetRoutes(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_HeadscaleService_GetRoutes_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetRoutesRequest - var metadata runtime.ServerMetadata - - msg, err := server.GetRoutes(ctx, &protoReq) - return msg, metadata, err - -} - -func request_HeadscaleService_EnableRoute_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq EnableRouteRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["route_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "route_id") - } - - protoReq.RouteId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "route_id", err) - } - - msg, err := client.EnableRoute(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_HeadscaleService_EnableRoute_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq EnableRouteRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["route_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "route_id") - } - - protoReq.RouteId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "route_id", err) - } - - msg, err := server.EnableRoute(ctx, &protoReq) - return msg, metadata, err - -} - -func request_HeadscaleService_DisableRoute_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams 
map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DisableRouteRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["route_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "route_id") - } - - protoReq.RouteId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "route_id", err) - } - - msg, err := client.DisableRoute(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_HeadscaleService_DisableRoute_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DisableRouteRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["route_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "route_id") - } - - protoReq.RouteId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "route_id", err) - } - - msg, err := server.DisableRoute(ctx, &protoReq) - return msg, metadata, err - -} - -func request_HeadscaleService_GetNodeRoutes_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetNodeRoutesRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["node_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") - } - - protoReq.NodeId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) - } - - msg, err := client.GetNodeRoutes(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_HeadscaleService_GetNodeRoutes_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq GetNodeRoutesRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["node_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "node_id") - } - - protoReq.NodeId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "node_id", err) - } - - msg, err := server.GetNodeRoutes(ctx, &protoReq) - return msg, metadata, err - -} - -func request_HeadscaleService_DeleteRoute_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DeleteRouteRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["route_id"] - if !ok { - return nil, metadata, 
status.Errorf(codes.InvalidArgument, "missing parameter %s", "route_id") - } - - protoReq.RouteId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "route_id", err) - } - - msg, err := client.DeleteRoute(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) - return msg, metadata, err - -} - -func local_request_HeadscaleService_DeleteRoute_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq DeleteRouteRequest - var metadata runtime.ServerMetadata - - var ( - val string - ok bool - err error - _ = err - ) - - val, ok = pathParams["route_id"] - if !ok { - return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "route_id") - } - - protoReq.RouteId, err = runtime.Uint64(val) - if err != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "route_id", err) - } - - msg, err := server.DeleteRoute(ctx, &protoReq) - return msg, metadata, err - } func request_HeadscaleService_CreateApiKey_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq CreateApiKeyRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq CreateApiKeyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } msg, err := client.CreateApiKey(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_CreateApiKey_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq CreateApiKeyRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq CreateApiKeyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, err := server.CreateApiKey(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_ExpireApiKey_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ExpireApiKeyRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } 
- if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq ExpireApiKeyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } msg, err := client.ExpireApiKey(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_ExpireApiKey_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ExpireApiKeyRequest - var metadata runtime.ServerMetadata - - newReader, berr := utilities.IOReaderFactory(req.Body) - if berr != nil { - return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", berr) - } - if err := marshaler.NewDecoder(newReader()).Decode(&protoReq); err != nil && err != io.EOF { + var ( + protoReq ExpireApiKeyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) } - msg, err := server.ExpireApiKey(ctx, &protoReq) return msg, metadata, err - } func request_HeadscaleService_ListApiKeys_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ListApiKeysRequest - var metadata runtime.ServerMetadata - + var ( + protoReq ListApiKeysRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } msg, err := client.ListApiKeys(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) return msg, metadata, err - } func local_request_HeadscaleService_ListApiKeys_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { - var protoReq ListApiKeysRequest - var metadata runtime.ServerMetadata - + var ( + protoReq ListApiKeysRequest + metadata runtime.ServerMetadata + ) msg, err := server.ListApiKeys(ctx, &protoReq) return msg, metadata, err +} +var filter_HeadscaleService_DeleteApiKey_0 = &utilities.DoubleArray{Encoding: map[string]int{"prefix": 0}, Base: []int{1, 1, 0}, Check: []int{0, 1, 2}} + +func request_HeadscaleService_DeleteApiKey_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq DeleteApiKeyRequest + metadata runtime.ServerMetadata + err error + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + val, ok := pathParams["prefix"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "prefix") + } + protoReq.Prefix, err = runtime.String(val) + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "prefix", err) + } + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, 
filter_HeadscaleService_DeleteApiKey_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := client.DeleteApiKey(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_DeleteApiKey_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq DeleteApiKeyRequest + metadata runtime.ServerMetadata + err error + ) + val, ok := pathParams["prefix"] + if !ok { + return nil, metadata, status.Errorf(codes.InvalidArgument, "missing parameter %s", "prefix") + } + protoReq.Prefix, err = runtime.String(val) + if err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "type mismatch, parameter: %s, error: %v", "prefix", err) + } + if err := req.ParseForm(); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if err := runtime.PopulateQueryParameters(&protoReq, req.Form, filter_HeadscaleService_DeleteApiKey_0); err != nil { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.DeleteApiKey(ctx, &protoReq) + return msg, metadata, err +} + +func request_HeadscaleService_GetPolicy_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetPolicyRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.GetPolicy(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_GetPolicy_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq GetPolicyRequest + metadata runtime.ServerMetadata + ) + msg, err := server.GetPolicy(ctx, &protoReq) + return msg, metadata, err +} + +func request_HeadscaleService_SetPolicy_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq SetPolicyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.SetPolicy(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_SetPolicy_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq SetPolicyRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.SetPolicy(ctx, &protoReq) + return msg, metadata, err +} + +func request_HeadscaleService_Health_0(ctx context.Context, marshaler 
runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq HealthRequest + metadata runtime.ServerMetadata + ) + if req.Body != nil { + _, _ = io.Copy(io.Discard, req.Body) + } + msg, err := client.Health(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_Health_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq HealthRequest + metadata runtime.ServerMetadata + ) + msg, err := server.Health(ctx, &protoReq) + return msg, metadata, err } // RegisterHeadscaleServiceHandlerServer registers the http handlers for service HeadscaleService to "mux". // UnaryRPC :call HeadscaleServiceServer directly. // StreamingRPC :currently unsupported pending https://github.com/grpc/grpc-go/issues/906. // Note that using this registration option will cause many gRPC library features to stop working. Consider using RegisterHeadscaleServiceHandlerFromEndpoint instead. +// GRPC interceptors will not work for this type of registration. To use interceptors, you must use the "runtime.WithMiddlewares" option in the "runtime.NewServeMux" call. func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.ServeMux, server HeadscaleServiceServer) error { - - mux.Handle("GET", pattern_HeadscaleService_GetUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetUser", runtime.WithHTTPPathPattern("/api/v1/user/{name}")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_HeadscaleService_GetUser_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_GetUser_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- - }) - - mux.Handle("POST", pattern_HeadscaleService_CreateUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreateUser", runtime.WithHTTPPathPattern("/api/v1/user")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreateUser", runtime.WithHTTPPathPattern("/api/v1/user")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1198,20 +930,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_CreateUser_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("POST", pattern_HeadscaleService_RenameUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_RenameUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RenameUser", runtime.WithHTTPPathPattern("/api/v1/user/{old_name}/rename/{new_name}")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RenameUser", runtime.WithHTTPPathPattern("/api/v1/user/{old_id}/rename/{new_name}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1223,20 +950,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_RenameUser_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("DELETE", pattern_HeadscaleService_DeleteUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodDelete, pattern_HeadscaleService_DeleteUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteUser", runtime.WithHTTPPathPattern("/api/v1/user/{name}")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteUser", runtime.WithHTTPPathPattern("/api/v1/user/{id}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1248,20 +970,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_DeleteUser_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("GET", pattern_HeadscaleService_ListUsers_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodGet, pattern_HeadscaleService_ListUsers_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListUsers", runtime.WithHTTPPathPattern("/api/v1/user")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListUsers", runtime.WithHTTPPathPattern("/api/v1/user")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1273,20 +990,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ListUsers_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_CreatePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_CreatePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreatePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreatePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1298,20 +1010,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_CreatePreAuthKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("POST", pattern_HeadscaleService_ExpirePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_ExpirePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpirePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey/expire")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpirePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey/expire")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1323,20 +1030,35 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ExpirePreAuthKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("GET", pattern_HeadscaleService_ListPreAuthKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodDelete, pattern_HeadscaleService_DeletePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListPreAuthKeys", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeletePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_DeletePreAuthKey_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_DeletePreAuthKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_HeadscaleService_ListPreAuthKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListPreAuthKeys", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1348,20 +1070,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ListPreAuthKeys_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_DebugCreateNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_DebugCreateNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DebugCreateNode", runtime.WithHTTPPathPattern("/api/v1/debug/node")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DebugCreateNode", runtime.WithHTTPPathPattern("/api/v1/debug/node")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1373,20 +1090,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_DebugCreateNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("GET", pattern_HeadscaleService_GetNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodGet, pattern_HeadscaleService_GetNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1398,20 +1110,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_GetNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_SetTags_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_SetTags_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/SetTags", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/tags")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/SetTags", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/tags")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1423,20 +1130,35 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_SetTags_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("POST", pattern_HeadscaleService_RegisterNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_SetApprovedRoutes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RegisterNode", runtime.WithHTTPPathPattern("/api/v1/node/register")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/SetApprovedRoutes", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/approve_routes")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_SetApprovedRoutes_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_SetApprovedRoutes_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
+ }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_RegisterNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RegisterNode", runtime.WithHTTPPathPattern("/api/v1/node/register")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1448,20 +1170,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_RegisterNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("DELETE", pattern_HeadscaleService_DeleteNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodDelete, pattern_HeadscaleService_DeleteNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1473,20 +1190,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_DeleteNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_ExpireNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_ExpireNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpireNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/expire")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpireNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/expire")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1498,20 +1210,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ExpireNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("POST", pattern_HeadscaleService_RenameNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_RenameNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RenameNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/rename/{new_name}")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RenameNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/rename/{new_name}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1523,20 +1230,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_RenameNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("GET", pattern_HeadscaleService_ListNodes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodGet, pattern_HeadscaleService_ListNodes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListNodes", runtime.WithHTTPPathPattern("/api/v1/node")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListNodes", runtime.WithHTTPPathPattern("/api/v1/node")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1548,170 +1250,35 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ListNodes_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("POST", pattern_HeadscaleService_MoveNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_BackfillNodeIPs_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/MoveNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/user")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/BackfillNodeIPs", runtime.WithHTTPPathPattern("/api/v1/node/backfillips")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } - resp, md, err := local_request_HeadscaleService_MoveNode_0(annotatedContext, inboundMarshaler, server, req, pathParams) + resp, md, err := local_request_HeadscaleService_BackfillNodeIPs_0(annotatedContext, inboundMarshaler, server, req, pathParams) md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - - forward_HeadscaleService_MoveNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - + forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
}) - - mux.Handle("GET", pattern_HeadscaleService_GetRoutes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetRoutes", runtime.WithHTTPPathPattern("/api/v1/routes")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_HeadscaleService_GetRoutes_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_GetRoutes_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_HeadscaleService_EnableRoute_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/EnableRoute", runtime.WithHTTPPathPattern("/api/v1/routes/{route_id}/enable")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_HeadscaleService_EnableRoute_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_EnableRoute_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- - }) - - mux.Handle("POST", pattern_HeadscaleService_DisableRoute_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DisableRoute", runtime.WithHTTPPathPattern("/api/v1/routes/{route_id}/disable")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_HeadscaleService_DisableRoute_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_DisableRoute_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("GET", pattern_HeadscaleService_GetNodeRoutes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetNodeRoutes", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/routes")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_HeadscaleService_GetNodeRoutes_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_GetNodeRoutes_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- - }) - - mux.Handle("DELETE", pattern_HeadscaleService_DeleteRoute_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteRoute", runtime.WithHTTPPathPattern("/api/v1/routes/{route_id}")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := local_request_HeadscaleService_DeleteRoute_0(annotatedContext, inboundMarshaler, server, req, pathParams) - md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_DeleteRoute_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - var stream runtime.ServerTransportStream - ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreateApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreateApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1723,20 +1290,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_CreateApiKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_ExpireApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_ExpireApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpireApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey/expire")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpireApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey/expire")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1748,20 +1310,15 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ExpireApiKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("GET", pattern_HeadscaleService_ListApiKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodGet, pattern_HeadscaleService_ListApiKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() var stream runtime.ServerTransportStream ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListApiKeys", runtime.WithHTTPPathPattern("/api/v1/apikey")) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListApiKeys", runtime.WithHTTPPathPattern("/api/v1/apikey")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1773,9 +1330,87 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ListApiKeys_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- + }) + mux.Handle(http.MethodDelete, pattern_HeadscaleService_DeleteApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey/{prefix}")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_DeleteApiKey_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_DeleteApiKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_HeadscaleService_GetPolicy_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetPolicy", runtime.WithHTTPPathPattern("/api/v1/policy")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_GetPolicy_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_GetPolicy_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
+ }) + mux.Handle(http.MethodPut, pattern_HeadscaleService_SetPolicy_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/SetPolicy", runtime.WithHTTPPathPattern("/api/v1/policy")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_SetPolicy_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_SetPolicy_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_HeadscaleService_Health_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/Health", runtime.WithHTTPPathPattern("/api/v1/health")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_Health_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_Health_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) return nil @@ -1784,25 +1419,24 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser // RegisterHeadscaleServiceHandlerFromEndpoint is same as RegisterHeadscaleServiceHandler but // automatically dials to "endpoint" and closes the connection when "ctx" gets done. func RegisterHeadscaleServiceHandlerFromEndpoint(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) (err error) { - conn, err := grpc.Dial(endpoint, opts...) + conn, err := grpc.NewClient(endpoint, opts...) 
if err != nil { return err } defer func() { if err != nil { if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) } return } go func() { <-ctx.Done() if cerr := conn.Close(); cerr != nil { - grpclog.Infof("Failed to close conn to %s: %v", endpoint, cerr) + grpclog.Errorf("Failed to close conn to %s: %v", endpoint, cerr) } }() }() - return RegisterHeadscaleServiceHandler(ctx, mux, conn) } @@ -1816,38 +1450,13 @@ func RegisterHeadscaleServiceHandler(ctx context.Context, mux *runtime.ServeMux, // to "mux". The handlers forward requests to the grpc endpoint over the given implementation of "HeadscaleServiceClient". // Note: the gRPC framework executes interceptors within the gRPC handler. If the passed in "HeadscaleServiceClient" // doesn't go through the normal gRPC flow (creating a gRPC client etc.) then it will be up to the passed in -// "HeadscaleServiceClient" to call the correct interceptors. +// "HeadscaleServiceClient" to call the correct interceptors. This client ignores the HTTP middlewares. func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.ServeMux, client HeadscaleServiceClient) error { - - mux.Handle("GET", pattern_HeadscaleService_GetUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetUser", runtime.WithHTTPPathPattern("/api/v1/user/{name}")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_HeadscaleService_GetUser_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_GetUser_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_HeadscaleService_CreateUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreateUser", runtime.WithHTTPPathPattern("/api/v1/user")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreateUser", runtime.WithHTTPPathPattern("/api/v1/user")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1858,18 +1467,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_CreateUser_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_RenameUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_RenameUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RenameUser", runtime.WithHTTPPathPattern("/api/v1/user/{old_name}/rename/{new_name}")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RenameUser", runtime.WithHTTPPathPattern("/api/v1/user/{old_id}/rename/{new_name}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1880,18 +1484,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_RenameUser_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("DELETE", pattern_HeadscaleService_DeleteUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodDelete, pattern_HeadscaleService_DeleteUser_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteUser", runtime.WithHTTPPathPattern("/api/v1/user/{name}")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteUser", runtime.WithHTTPPathPattern("/api/v1/user/{id}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1902,18 +1501,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_DeleteUser_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("GET", pattern_HeadscaleService_ListUsers_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodGet, pattern_HeadscaleService_ListUsers_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListUsers", runtime.WithHTTPPathPattern("/api/v1/user")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListUsers", runtime.WithHTTPPathPattern("/api/v1/user")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1924,18 +1518,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ListUsers_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("POST", pattern_HeadscaleService_CreatePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_CreatePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreatePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreatePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1946,18 +1535,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_CreatePreAuthKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_ExpirePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_ExpirePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpirePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey/expire")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpirePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey/expire")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1968,18 +1552,30 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ExpirePreAuthKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("GET", pattern_HeadscaleService_ListPreAuthKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodDelete, pattern_HeadscaleService_DeletePreAuthKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListPreAuthKeys", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeletePreAuthKey", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_DeletePreAuthKey_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_DeletePreAuthKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_HeadscaleService_ListPreAuthKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListPreAuthKeys", runtime.WithHTTPPathPattern("/api/v1/preauthkey")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -1990,18 +1586,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ListPreAuthKeys_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_DebugCreateNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_DebugCreateNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DebugCreateNode", runtime.WithHTTPPathPattern("/api/v1/debug/node")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DebugCreateNode", runtime.WithHTTPPathPattern("/api/v1/debug/node")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2012,18 +1603,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_DebugCreateNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("GET", pattern_HeadscaleService_GetNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodGet, pattern_HeadscaleService_GetNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2034,18 +1620,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_GetNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_SetTags_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_SetTags_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/SetTags", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/tags")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/SetTags", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/tags")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2056,18 +1637,30 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_SetTags_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("POST", pattern_HeadscaleService_RegisterNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_SetApprovedRoutes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RegisterNode", runtime.WithHTTPPathPattern("/api/v1/node/register")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/SetApprovedRoutes", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/approve_routes")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_SetApprovedRoutes_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_SetApprovedRoutes_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_RegisterNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RegisterNode", runtime.WithHTTPPathPattern("/api/v1/node/register")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2078,18 +1671,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_RegisterNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("DELETE", pattern_HeadscaleService_DeleteNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodDelete, pattern_HeadscaleService_DeleteNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2100,18 +1688,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_DeleteNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("POST", pattern_HeadscaleService_ExpireNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_ExpireNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpireNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/expire")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpireNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/expire")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2122,18 +1705,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ExpireNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_RenameNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_RenameNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RenameNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/rename/{new_name}")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/RenameNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/rename/{new_name}")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2144,18 +1722,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_RenameNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("GET", pattern_HeadscaleService_ListNodes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodGet, pattern_HeadscaleService_ListNodes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListNodes", runtime.WithHTTPPathPattern("/api/v1/node")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListNodes", runtime.WithHTTPPathPattern("/api/v1/node")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2166,150 +1739,30 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ListNodes_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("POST", pattern_HeadscaleService_MoveNode_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_BackfillNodeIPs_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/MoveNode", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/user")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/BackfillNodeIPs", runtime.WithHTTPPathPattern("/api/v1/node/backfillips")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return } - resp, md, err := request_HeadscaleService_MoveNode_0(annotatedContext, inboundMarshaler, client, req, pathParams) + resp, md, err := request_HeadscaleService_BackfillNodeIPs_0(annotatedContext, inboundMarshaler, client, req, pathParams) annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) if err != nil { runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - - forward_HeadscaleService_MoveNode_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - + forward_HeadscaleService_BackfillNodeIPs_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) - - mux.Handle("GET", pattern_HeadscaleService_GetRoutes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetRoutes", runtime.WithHTTPPathPattern("/api/v1/routes")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_HeadscaleService_GetRoutes_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_GetRoutes_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- - }) - - mux.Handle("POST", pattern_HeadscaleService_EnableRoute_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/EnableRoute", runtime.WithHTTPPathPattern("/api/v1/routes/{route_id}/enable")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_HeadscaleService_EnableRoute_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_EnableRoute_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_HeadscaleService_DisableRoute_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DisableRoute", runtime.WithHTTPPathPattern("/api/v1/routes/{route_id}/disable")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_HeadscaleService_DisableRoute_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_DisableRoute_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("GET", pattern_HeadscaleService_GetNodeRoutes_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetNodeRoutes", runtime.WithHTTPPathPattern("/api/v1/node/{node_id}/routes")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_HeadscaleService_GetNodeRoutes_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_GetNodeRoutes_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- - }) - - mux.Handle("DELETE", pattern_HeadscaleService_DeleteRoute_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteRoute", runtime.WithHTTPPathPattern("/api/v1/routes/{route_id}")) - if err != nil { - runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) - return - } - resp, md, err := request_HeadscaleService_DeleteRoute_0(annotatedContext, inboundMarshaler, client, req, pathParams) - annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) - if err != nil { - runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) - return - } - - forward_HeadscaleService_DeleteRoute_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - - }) - - mux.Handle("POST", pattern_HeadscaleService_CreateApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { - ctx, cancel := context.WithCancel(req.Context()) - defer cancel() - inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreateApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/CreateApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2320,18 +1773,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_CreateApiKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - - mux.Handle("POST", pattern_HeadscaleService_ExpireApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodPost, pattern_HeadscaleService_ExpireApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpireApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey/expire")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ExpireApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey/expire")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2342,18 +1790,13 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ExpireApiKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
- }) - - mux.Handle("GET", pattern_HeadscaleService_ListApiKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + mux.Handle(http.MethodGet, pattern_HeadscaleService_ListApiKeys_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) - var err error - var annotatedContext context.Context - annotatedContext, err = runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListApiKeys", runtime.WithHTTPPathPattern("/api/v1/apikey")) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/ListApiKeys", runtime.WithHTTPPathPattern("/api/v1/apikey")) if err != nil { runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) return @@ -2364,114 +1807,131 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) return } - forward_HeadscaleService_ListApiKeys_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) - }) - + mux.Handle(http.MethodDelete, pattern_HeadscaleService_DeleteApiKey_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/DeleteApiKey", runtime.WithHTTPPathPattern("/api/v1/apikey/{prefix}")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_DeleteApiKey_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_DeleteApiKey_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_HeadscaleService_GetPolicy_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/GetPolicy", runtime.WithHTTPPathPattern("/api/v1/policy")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_GetPolicy_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_GetPolicy_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
+ }) + mux.Handle(http.MethodPut, pattern_HeadscaleService_SetPolicy_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/SetPolicy", runtime.WithHTTPPathPattern("/api/v1/policy")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_SetPolicy_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_SetPolicy_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) + mux.Handle(http.MethodGet, pattern_HeadscaleService_Health_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/Health", runtime.WithHTTPPathPattern("/api/v1/health")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_Health_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_Health_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) 
+ }) return nil } var ( - pattern_HeadscaleService_GetUser_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "user", "name"}, "")) - - pattern_HeadscaleService_CreateUser_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "user"}, "")) - - pattern_HeadscaleService_RenameUser_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"api", "v1", "user", "old_name", "rename", "new_name"}, "")) - - pattern_HeadscaleService_DeleteUser_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "user", "name"}, "")) - - pattern_HeadscaleService_ListUsers_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "user"}, "")) - - pattern_HeadscaleService_CreatePreAuthKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "preauthkey"}, "")) - - pattern_HeadscaleService_ExpirePreAuthKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "preauthkey", "expire"}, "")) - - pattern_HeadscaleService_ListPreAuthKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "preauthkey"}, "")) - - pattern_HeadscaleService_DebugCreateNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "debug", "node"}, "")) - - pattern_HeadscaleService_GetNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "node", "node_id"}, "")) - - pattern_HeadscaleService_SetTags_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "node", "node_id", "tags"}, "")) - - pattern_HeadscaleService_RegisterNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "node", "register"}, "")) - - pattern_HeadscaleService_DeleteNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "node", "node_id"}, "")) - - pattern_HeadscaleService_ExpireNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "node", "node_id", "expire"}, "")) - - pattern_HeadscaleService_RenameNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"api", "v1", "node", "node_id", "rename", "new_name"}, "")) - - pattern_HeadscaleService_ListNodes_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "node"}, "")) - - pattern_HeadscaleService_MoveNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "node", "node_id", "user"}, "")) - - pattern_HeadscaleService_GetRoutes_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "routes"}, "")) - - pattern_HeadscaleService_EnableRoute_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "routes", "route_id", "enable"}, "")) - - pattern_HeadscaleService_DisableRoute_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "routes", "route_id", "disable"}, "")) - - 
pattern_HeadscaleService_GetNodeRoutes_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "node", "node_id", "routes"}, "")) - - pattern_HeadscaleService_DeleteRoute_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "routes", "route_id"}, "")) - - pattern_HeadscaleService_CreateApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, "")) - - pattern_HeadscaleService_ExpireApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "apikey", "expire"}, "")) - - pattern_HeadscaleService_ListApiKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, "")) + pattern_HeadscaleService_CreateUser_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "user"}, "")) + pattern_HeadscaleService_RenameUser_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 1, 0, 4, 1, 5, 5}, []string{"api", "v1", "user", "old_id", "rename", "new_name"}, "")) + pattern_HeadscaleService_DeleteUser_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "user", "id"}, "")) + pattern_HeadscaleService_ListUsers_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "user"}, "")) + pattern_HeadscaleService_CreatePreAuthKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "preauthkey"}, "")) + pattern_HeadscaleService_ExpirePreAuthKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "preauthkey", "expire"}, "")) + pattern_HeadscaleService_DeletePreAuthKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "preauthkey"}, "")) + pattern_HeadscaleService_ListPreAuthKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "preauthkey"}, "")) + pattern_HeadscaleService_DebugCreateNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "debug", "node"}, "")) + pattern_HeadscaleService_GetNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "node", "node_id"}, "")) + pattern_HeadscaleService_SetTags_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "node", "node_id", "tags"}, "")) + pattern_HeadscaleService_SetApprovedRoutes_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "node", "node_id", "approve_routes"}, "")) + pattern_HeadscaleService_RegisterNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "node", "register"}, "")) + pattern_HeadscaleService_DeleteNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "node", "node_id"}, "")) + pattern_HeadscaleService_ExpireNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4}, []string{"api", "v1", "node", "node_id", "expire"}, "")) + pattern_HeadscaleService_RenameNode_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3, 2, 4, 1, 0, 4, 
1, 5, 5}, []string{"api", "v1", "node", "node_id", "rename", "new_name"}, "")) + pattern_HeadscaleService_ListNodes_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "node"}, "")) + pattern_HeadscaleService_BackfillNodeIPs_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "node", "backfillips"}, "")) + pattern_HeadscaleService_CreateApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, "")) + pattern_HeadscaleService_ExpireApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "apikey", "expire"}, "")) + pattern_HeadscaleService_ListApiKeys_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "apikey"}, "")) + pattern_HeadscaleService_DeleteApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "apikey", "prefix"}, "")) + pattern_HeadscaleService_GetPolicy_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "policy"}, "")) + pattern_HeadscaleService_SetPolicy_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "policy"}, "")) + pattern_HeadscaleService_Health_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "health"}, "")) ) var ( - forward_HeadscaleService_GetUser_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_CreateUser_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_RenameUser_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_DeleteUser_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_ListUsers_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_CreatePreAuthKey_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_ExpirePreAuthKey_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_ListPreAuthKeys_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_DebugCreateNode_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_GetNode_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_SetTags_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_RegisterNode_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_DeleteNode_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_ExpireNode_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_RenameNode_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_ListNodes_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_MoveNode_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_GetRoutes_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_EnableRoute_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_DisableRoute_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_GetNodeRoutes_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_DeleteRoute_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_CreateApiKey_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_ExpireApiKey_0 = runtime.ForwardResponseMessage - - forward_HeadscaleService_ListApiKeys_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_CreateUser_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_RenameUser_0 = runtime.ForwardResponseMessage + 
forward_HeadscaleService_DeleteUser_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_ListUsers_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_CreatePreAuthKey_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_ExpirePreAuthKey_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_DeletePreAuthKey_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_ListPreAuthKeys_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_DebugCreateNode_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_GetNode_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_SetTags_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_SetApprovedRoutes_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_RegisterNode_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_DeleteNode_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_ExpireNode_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_RenameNode_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_ListNodes_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_BackfillNodeIPs_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_CreateApiKey_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_ExpireApiKey_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_ListApiKeys_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_DeleteApiKey_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_GetPolicy_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_SetPolicy_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_Health_0 = runtime.ForwardResponseMessage ) diff --git a/gen/go/headscale/v1/headscale_grpc.pb.go b/gen/go/headscale/v1/headscale_grpc.pb.go index fab8f522..a3963935 100644 --- a/gen/go/headscale/v1/headscale_grpc.pb.go +++ b/gen/go/headscale/v1/headscale_grpc.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.3.0 +// - protoc-gen-go-grpc v1.6.0 // - protoc (unknown) // source: headscale/v1/headscale.proto @@ -15,35 +15,35 @@ import ( // This is a compile-time assertion to ensure that this generated file // is compatible with the grpc package it is being compiled against. -// Requires gRPC-Go v1.32.0 or later. -const _ = grpc.SupportPackageIsVersion7 +// Requires gRPC-Go v1.64.0 or later. 
+const _ = grpc.SupportPackageIsVersion9 const ( - HeadscaleService_GetUser_FullMethodName = "/headscale.v1.HeadscaleService/GetUser" - HeadscaleService_CreateUser_FullMethodName = "/headscale.v1.HeadscaleService/CreateUser" - HeadscaleService_RenameUser_FullMethodName = "/headscale.v1.HeadscaleService/RenameUser" - HeadscaleService_DeleteUser_FullMethodName = "/headscale.v1.HeadscaleService/DeleteUser" - HeadscaleService_ListUsers_FullMethodName = "/headscale.v1.HeadscaleService/ListUsers" - HeadscaleService_CreatePreAuthKey_FullMethodName = "/headscale.v1.HeadscaleService/CreatePreAuthKey" - HeadscaleService_ExpirePreAuthKey_FullMethodName = "/headscale.v1.HeadscaleService/ExpirePreAuthKey" - HeadscaleService_ListPreAuthKeys_FullMethodName = "/headscale.v1.HeadscaleService/ListPreAuthKeys" - HeadscaleService_DebugCreateNode_FullMethodName = "/headscale.v1.HeadscaleService/DebugCreateNode" - HeadscaleService_GetNode_FullMethodName = "/headscale.v1.HeadscaleService/GetNode" - HeadscaleService_SetTags_FullMethodName = "/headscale.v1.HeadscaleService/SetTags" - HeadscaleService_RegisterNode_FullMethodName = "/headscale.v1.HeadscaleService/RegisterNode" - HeadscaleService_DeleteNode_FullMethodName = "/headscale.v1.HeadscaleService/DeleteNode" - HeadscaleService_ExpireNode_FullMethodName = "/headscale.v1.HeadscaleService/ExpireNode" - HeadscaleService_RenameNode_FullMethodName = "/headscale.v1.HeadscaleService/RenameNode" - HeadscaleService_ListNodes_FullMethodName = "/headscale.v1.HeadscaleService/ListNodes" - HeadscaleService_MoveNode_FullMethodName = "/headscale.v1.HeadscaleService/MoveNode" - HeadscaleService_GetRoutes_FullMethodName = "/headscale.v1.HeadscaleService/GetRoutes" - HeadscaleService_EnableRoute_FullMethodName = "/headscale.v1.HeadscaleService/EnableRoute" - HeadscaleService_DisableRoute_FullMethodName = "/headscale.v1.HeadscaleService/DisableRoute" - HeadscaleService_GetNodeRoutes_FullMethodName = "/headscale.v1.HeadscaleService/GetNodeRoutes" - HeadscaleService_DeleteRoute_FullMethodName = "/headscale.v1.HeadscaleService/DeleteRoute" - HeadscaleService_CreateApiKey_FullMethodName = "/headscale.v1.HeadscaleService/CreateApiKey" - HeadscaleService_ExpireApiKey_FullMethodName = "/headscale.v1.HeadscaleService/ExpireApiKey" - HeadscaleService_ListApiKeys_FullMethodName = "/headscale.v1.HeadscaleService/ListApiKeys" + HeadscaleService_CreateUser_FullMethodName = "/headscale.v1.HeadscaleService/CreateUser" + HeadscaleService_RenameUser_FullMethodName = "/headscale.v1.HeadscaleService/RenameUser" + HeadscaleService_DeleteUser_FullMethodName = "/headscale.v1.HeadscaleService/DeleteUser" + HeadscaleService_ListUsers_FullMethodName = "/headscale.v1.HeadscaleService/ListUsers" + HeadscaleService_CreatePreAuthKey_FullMethodName = "/headscale.v1.HeadscaleService/CreatePreAuthKey" + HeadscaleService_ExpirePreAuthKey_FullMethodName = "/headscale.v1.HeadscaleService/ExpirePreAuthKey" + HeadscaleService_DeletePreAuthKey_FullMethodName = "/headscale.v1.HeadscaleService/DeletePreAuthKey" + HeadscaleService_ListPreAuthKeys_FullMethodName = "/headscale.v1.HeadscaleService/ListPreAuthKeys" + HeadscaleService_DebugCreateNode_FullMethodName = "/headscale.v1.HeadscaleService/DebugCreateNode" + HeadscaleService_GetNode_FullMethodName = "/headscale.v1.HeadscaleService/GetNode" + HeadscaleService_SetTags_FullMethodName = "/headscale.v1.HeadscaleService/SetTags" + HeadscaleService_SetApprovedRoutes_FullMethodName = "/headscale.v1.HeadscaleService/SetApprovedRoutes" + 
HeadscaleService_RegisterNode_FullMethodName = "/headscale.v1.HeadscaleService/RegisterNode" + HeadscaleService_DeleteNode_FullMethodName = "/headscale.v1.HeadscaleService/DeleteNode" + HeadscaleService_ExpireNode_FullMethodName = "/headscale.v1.HeadscaleService/ExpireNode" + HeadscaleService_RenameNode_FullMethodName = "/headscale.v1.HeadscaleService/RenameNode" + HeadscaleService_ListNodes_FullMethodName = "/headscale.v1.HeadscaleService/ListNodes" + HeadscaleService_BackfillNodeIPs_FullMethodName = "/headscale.v1.HeadscaleService/BackfillNodeIPs" + HeadscaleService_CreateApiKey_FullMethodName = "/headscale.v1.HeadscaleService/CreateApiKey" + HeadscaleService_ExpireApiKey_FullMethodName = "/headscale.v1.HeadscaleService/ExpireApiKey" + HeadscaleService_ListApiKeys_FullMethodName = "/headscale.v1.HeadscaleService/ListApiKeys" + HeadscaleService_DeleteApiKey_FullMethodName = "/headscale.v1.HeadscaleService/DeleteApiKey" + HeadscaleService_GetPolicy_FullMethodName = "/headscale.v1.HeadscaleService/GetPolicy" + HeadscaleService_SetPolicy_FullMethodName = "/headscale.v1.HeadscaleService/SetPolicy" + HeadscaleService_Health_FullMethodName = "/headscale.v1.HeadscaleService/Health" ) // HeadscaleServiceClient is the client API for HeadscaleService service. @@ -51,7 +51,6 @@ const ( // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type HeadscaleServiceClient interface { // --- User start --- - GetUser(ctx context.Context, in *GetUserRequest, opts ...grpc.CallOption) (*GetUserResponse, error) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) RenameUser(ctx context.Context, in *RenameUserRequest, opts ...grpc.CallOption) (*RenameUserResponse, error) DeleteUser(ctx context.Context, in *DeleteUserRequest, opts ...grpc.CallOption) (*DeleteUserResponse, error) @@ -59,27 +58,29 @@ type HeadscaleServiceClient interface { // --- PreAuthKeys start --- CreatePreAuthKey(ctx context.Context, in *CreatePreAuthKeyRequest, opts ...grpc.CallOption) (*CreatePreAuthKeyResponse, error) ExpirePreAuthKey(ctx context.Context, in *ExpirePreAuthKeyRequest, opts ...grpc.CallOption) (*ExpirePreAuthKeyResponse, error) + DeletePreAuthKey(ctx context.Context, in *DeletePreAuthKeyRequest, opts ...grpc.CallOption) (*DeletePreAuthKeyResponse, error) ListPreAuthKeys(ctx context.Context, in *ListPreAuthKeysRequest, opts ...grpc.CallOption) (*ListPreAuthKeysResponse, error) // --- Node start --- DebugCreateNode(ctx context.Context, in *DebugCreateNodeRequest, opts ...grpc.CallOption) (*DebugCreateNodeResponse, error) GetNode(ctx context.Context, in *GetNodeRequest, opts ...grpc.CallOption) (*GetNodeResponse, error) SetTags(ctx context.Context, in *SetTagsRequest, opts ...grpc.CallOption) (*SetTagsResponse, error) + SetApprovedRoutes(ctx context.Context, in *SetApprovedRoutesRequest, opts ...grpc.CallOption) (*SetApprovedRoutesResponse, error) RegisterNode(ctx context.Context, in *RegisterNodeRequest, opts ...grpc.CallOption) (*RegisterNodeResponse, error) DeleteNode(ctx context.Context, in *DeleteNodeRequest, opts ...grpc.CallOption) (*DeleteNodeResponse, error) ExpireNode(ctx context.Context, in *ExpireNodeRequest, opts ...grpc.CallOption) (*ExpireNodeResponse, error) RenameNode(ctx context.Context, in *RenameNodeRequest, opts ...grpc.CallOption) (*RenameNodeResponse, error) ListNodes(ctx context.Context, in *ListNodesRequest, opts ...grpc.CallOption) 
(*ListNodesResponse, error) - MoveNode(ctx context.Context, in *MoveNodeRequest, opts ...grpc.CallOption) (*MoveNodeResponse, error) - // --- Route start --- - GetRoutes(ctx context.Context, in *GetRoutesRequest, opts ...grpc.CallOption) (*GetRoutesResponse, error) - EnableRoute(ctx context.Context, in *EnableRouteRequest, opts ...grpc.CallOption) (*EnableRouteResponse, error) - DisableRoute(ctx context.Context, in *DisableRouteRequest, opts ...grpc.CallOption) (*DisableRouteResponse, error) - GetNodeRoutes(ctx context.Context, in *GetNodeRoutesRequest, opts ...grpc.CallOption) (*GetNodeRoutesResponse, error) - DeleteRoute(ctx context.Context, in *DeleteRouteRequest, opts ...grpc.CallOption) (*DeleteRouteResponse, error) + BackfillNodeIPs(ctx context.Context, in *BackfillNodeIPsRequest, opts ...grpc.CallOption) (*BackfillNodeIPsResponse, error) // --- ApiKeys start --- CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error) ExpireApiKey(ctx context.Context, in *ExpireApiKeyRequest, opts ...grpc.CallOption) (*ExpireApiKeyResponse, error) ListApiKeys(ctx context.Context, in *ListApiKeysRequest, opts ...grpc.CallOption) (*ListApiKeysResponse, error) + DeleteApiKey(ctx context.Context, in *DeleteApiKeyRequest, opts ...grpc.CallOption) (*DeleteApiKeyResponse, error) + // --- Policy start --- + GetPolicy(ctx context.Context, in *GetPolicyRequest, opts ...grpc.CallOption) (*GetPolicyResponse, error) + SetPolicy(ctx context.Context, in *SetPolicyRequest, opts ...grpc.CallOption) (*SetPolicyResponse, error) + // --- Health start --- + Health(ctx context.Context, in *HealthRequest, opts ...grpc.CallOption) (*HealthResponse, error) } type headscaleServiceClient struct { @@ -90,18 +91,10 @@ func NewHeadscaleServiceClient(cc grpc.ClientConnInterface) HeadscaleServiceClie return &headscaleServiceClient{cc} } -func (c *headscaleServiceClient) GetUser(ctx context.Context, in *GetUserRequest, opts ...grpc.CallOption) (*GetUserResponse, error) { - out := new(GetUserResponse) - err := c.cc.Invoke(ctx, HeadscaleService_GetUser_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - func (c *headscaleServiceClient) CreateUser(ctx context.Context, in *CreateUserRequest, opts ...grpc.CallOption) (*CreateUserResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(CreateUserResponse) - err := c.cc.Invoke(ctx, HeadscaleService_CreateUser_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_CreateUser_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -109,8 +102,9 @@ func (c *headscaleServiceClient) CreateUser(ctx context.Context, in *CreateUserR } func (c *headscaleServiceClient) RenameUser(ctx context.Context, in *RenameUserRequest, opts ...grpc.CallOption) (*RenameUserResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(RenameUserResponse) - err := c.cc.Invoke(ctx, HeadscaleService_RenameUser_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_RenameUser_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -118,8 +112,9 @@ func (c *headscaleServiceClient) RenameUser(ctx context.Context, in *RenameUserR } func (c *headscaleServiceClient) DeleteUser(ctx context.Context, in *DeleteUserRequest, opts ...grpc.CallOption) (*DeleteUserResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) 
out := new(DeleteUserResponse) - err := c.cc.Invoke(ctx, HeadscaleService_DeleteUser_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_DeleteUser_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -127,8 +122,9 @@ func (c *headscaleServiceClient) DeleteUser(ctx context.Context, in *DeleteUserR } func (c *headscaleServiceClient) ListUsers(ctx context.Context, in *ListUsersRequest, opts ...grpc.CallOption) (*ListUsersResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListUsersResponse) - err := c.cc.Invoke(ctx, HeadscaleService_ListUsers_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_ListUsers_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -136,8 +132,9 @@ func (c *headscaleServiceClient) ListUsers(ctx context.Context, in *ListUsersReq } func (c *headscaleServiceClient) CreatePreAuthKey(ctx context.Context, in *CreatePreAuthKeyRequest, opts ...grpc.CallOption) (*CreatePreAuthKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(CreatePreAuthKeyResponse) - err := c.cc.Invoke(ctx, HeadscaleService_CreatePreAuthKey_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_CreatePreAuthKey_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -145,8 +142,19 @@ func (c *headscaleServiceClient) CreatePreAuthKey(ctx context.Context, in *Creat } func (c *headscaleServiceClient) ExpirePreAuthKey(ctx context.Context, in *ExpirePreAuthKeyRequest, opts ...grpc.CallOption) (*ExpirePreAuthKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ExpirePreAuthKeyResponse) - err := c.cc.Invoke(ctx, HeadscaleService_ExpirePreAuthKey_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_ExpirePreAuthKey_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *headscaleServiceClient) DeletePreAuthKey(ctx context.Context, in *DeletePreAuthKeyRequest, opts ...grpc.CallOption) (*DeletePreAuthKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DeletePreAuthKeyResponse) + err := c.cc.Invoke(ctx, HeadscaleService_DeletePreAuthKey_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -154,8 +162,9 @@ func (c *headscaleServiceClient) ExpirePreAuthKey(ctx context.Context, in *Expir } func (c *headscaleServiceClient) ListPreAuthKeys(ctx context.Context, in *ListPreAuthKeysRequest, opts ...grpc.CallOption) (*ListPreAuthKeysResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListPreAuthKeysResponse) - err := c.cc.Invoke(ctx, HeadscaleService_ListPreAuthKeys_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_ListPreAuthKeys_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -163,8 +172,9 @@ func (c *headscaleServiceClient) ListPreAuthKeys(ctx context.Context, in *ListPr } func (c *headscaleServiceClient) DebugCreateNode(ctx context.Context, in *DebugCreateNodeRequest, opts ...grpc.CallOption) (*DebugCreateNodeResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DebugCreateNodeResponse) - err := c.cc.Invoke(ctx, HeadscaleService_DebugCreateNode_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_DebugCreateNode_FullMethodName, in, out, cOpts...) 
if err != nil { return nil, err } @@ -172,8 +182,9 @@ func (c *headscaleServiceClient) DebugCreateNode(ctx context.Context, in *DebugC } func (c *headscaleServiceClient) GetNode(ctx context.Context, in *GetNodeRequest, opts ...grpc.CallOption) (*GetNodeResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(GetNodeResponse) - err := c.cc.Invoke(ctx, HeadscaleService_GetNode_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_GetNode_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -181,8 +192,19 @@ func (c *headscaleServiceClient) GetNode(ctx context.Context, in *GetNodeRequest } func (c *headscaleServiceClient) SetTags(ctx context.Context, in *SetTagsRequest, opts ...grpc.CallOption) (*SetTagsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(SetTagsResponse) - err := c.cc.Invoke(ctx, HeadscaleService_SetTags_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_SetTags_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *headscaleServiceClient) SetApprovedRoutes(ctx context.Context, in *SetApprovedRoutesRequest, opts ...grpc.CallOption) (*SetApprovedRoutesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(SetApprovedRoutesResponse) + err := c.cc.Invoke(ctx, HeadscaleService_SetApprovedRoutes_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -190,8 +212,9 @@ func (c *headscaleServiceClient) SetTags(ctx context.Context, in *SetTagsRequest } func (c *headscaleServiceClient) RegisterNode(ctx context.Context, in *RegisterNodeRequest, opts ...grpc.CallOption) (*RegisterNodeResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(RegisterNodeResponse) - err := c.cc.Invoke(ctx, HeadscaleService_RegisterNode_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_RegisterNode_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -199,8 +222,9 @@ func (c *headscaleServiceClient) RegisterNode(ctx context.Context, in *RegisterN } func (c *headscaleServiceClient) DeleteNode(ctx context.Context, in *DeleteNodeRequest, opts ...grpc.CallOption) (*DeleteNodeResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(DeleteNodeResponse) - err := c.cc.Invoke(ctx, HeadscaleService_DeleteNode_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_DeleteNode_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -208,8 +232,9 @@ func (c *headscaleServiceClient) DeleteNode(ctx context.Context, in *DeleteNodeR } func (c *headscaleServiceClient) ExpireNode(ctx context.Context, in *ExpireNodeRequest, opts ...grpc.CallOption) (*ExpireNodeResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ExpireNodeResponse) - err := c.cc.Invoke(ctx, HeadscaleService_ExpireNode_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_ExpireNode_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -217,8 +242,9 @@ func (c *headscaleServiceClient) ExpireNode(ctx context.Context, in *ExpireNodeR } func (c *headscaleServiceClient) RenameNode(ctx context.Context, in *RenameNodeRequest, opts ...grpc.CallOption) (*RenameNodeResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) 
out := new(RenameNodeResponse) - err := c.cc.Invoke(ctx, HeadscaleService_RenameNode_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_RenameNode_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -226,62 +252,19 @@ func (c *headscaleServiceClient) RenameNode(ctx context.Context, in *RenameNodeR } func (c *headscaleServiceClient) ListNodes(ctx context.Context, in *ListNodesRequest, opts ...grpc.CallOption) (*ListNodesResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListNodesResponse) - err := c.cc.Invoke(ctx, HeadscaleService_ListNodes_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_ListNodes_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } return out, nil } -func (c *headscaleServiceClient) MoveNode(ctx context.Context, in *MoveNodeRequest, opts ...grpc.CallOption) (*MoveNodeResponse, error) { - out := new(MoveNodeResponse) - err := c.cc.Invoke(ctx, HeadscaleService_MoveNode_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *headscaleServiceClient) GetRoutes(ctx context.Context, in *GetRoutesRequest, opts ...grpc.CallOption) (*GetRoutesResponse, error) { - out := new(GetRoutesResponse) - err := c.cc.Invoke(ctx, HeadscaleService_GetRoutes_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *headscaleServiceClient) EnableRoute(ctx context.Context, in *EnableRouteRequest, opts ...grpc.CallOption) (*EnableRouteResponse, error) { - out := new(EnableRouteResponse) - err := c.cc.Invoke(ctx, HeadscaleService_EnableRoute_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *headscaleServiceClient) DisableRoute(ctx context.Context, in *DisableRouteRequest, opts ...grpc.CallOption) (*DisableRouteResponse, error) { - out := new(DisableRouteResponse) - err := c.cc.Invoke(ctx, HeadscaleService_DisableRoute_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *headscaleServiceClient) GetNodeRoutes(ctx context.Context, in *GetNodeRoutesRequest, opts ...grpc.CallOption) (*GetNodeRoutesResponse, error) { - out := new(GetNodeRoutesResponse) - err := c.cc.Invoke(ctx, HeadscaleService_GetNodeRoutes_FullMethodName, in, out, opts...) - if err != nil { - return nil, err - } - return out, nil -} - -func (c *headscaleServiceClient) DeleteRoute(ctx context.Context, in *DeleteRouteRequest, opts ...grpc.CallOption) (*DeleteRouteResponse, error) { - out := new(DeleteRouteResponse) - err := c.cc.Invoke(ctx, HeadscaleService_DeleteRoute_FullMethodName, in, out, opts...) +func (c *headscaleServiceClient) BackfillNodeIPs(ctx context.Context, in *BackfillNodeIPsRequest, opts ...grpc.CallOption) (*BackfillNodeIPsResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(BackfillNodeIPsResponse) + err := c.cc.Invoke(ctx, HeadscaleService_BackfillNodeIPs_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -289,8 +272,9 @@ func (c *headscaleServiceClient) DeleteRoute(ctx context.Context, in *DeleteRout } func (c *headscaleServiceClient) CreateApiKey(ctx context.Context, in *CreateApiKeyRequest, opts ...grpc.CallOption) (*CreateApiKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) 
out := new(CreateApiKeyResponse) - err := c.cc.Invoke(ctx, HeadscaleService_CreateApiKey_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_CreateApiKey_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -298,8 +282,9 @@ func (c *headscaleServiceClient) CreateApiKey(ctx context.Context, in *CreateApi } func (c *headscaleServiceClient) ExpireApiKey(ctx context.Context, in *ExpireApiKeyRequest, opts ...grpc.CallOption) (*ExpireApiKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ExpireApiKeyResponse) - err := c.cc.Invoke(ctx, HeadscaleService_ExpireApiKey_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_ExpireApiKey_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -307,8 +292,49 @@ func (c *headscaleServiceClient) ExpireApiKey(ctx context.Context, in *ExpireApi } func (c *headscaleServiceClient) ListApiKeys(ctx context.Context, in *ListApiKeysRequest, opts ...grpc.CallOption) (*ListApiKeysResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(ListApiKeysResponse) - err := c.cc.Invoke(ctx, HeadscaleService_ListApiKeys_FullMethodName, in, out, opts...) + err := c.cc.Invoke(ctx, HeadscaleService_ListApiKeys_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *headscaleServiceClient) DeleteApiKey(ctx context.Context, in *DeleteApiKeyRequest, opts ...grpc.CallOption) (*DeleteApiKeyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DeleteApiKeyResponse) + err := c.cc.Invoke(ctx, HeadscaleService_DeleteApiKey_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *headscaleServiceClient) GetPolicy(ctx context.Context, in *GetPolicyRequest, opts ...grpc.CallOption) (*GetPolicyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(GetPolicyResponse) + err := c.cc.Invoke(ctx, HeadscaleService_GetPolicy_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *headscaleServiceClient) SetPolicy(ctx context.Context, in *SetPolicyRequest, opts ...grpc.CallOption) (*SetPolicyResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(SetPolicyResponse) + err := c.cc.Invoke(ctx, HeadscaleService_SetPolicy_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +func (c *headscaleServiceClient) Health(ctx context.Context, in *HealthRequest, opts ...grpc.CallOption) (*HealthResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(HealthResponse) + err := c.cc.Invoke(ctx, HeadscaleService_Health_FullMethodName, in, out, cOpts...) if err != nil { return nil, err } @@ -317,10 +343,9 @@ func (c *headscaleServiceClient) ListApiKeys(ctx context.Context, in *ListApiKey // HeadscaleServiceServer is the server API for HeadscaleService service. // All implementations must embed UnimplementedHeadscaleServiceServer -// for forward compatibility +// for forward compatibility. 
type HeadscaleServiceServer interface { // --- User start --- - GetUser(context.Context, *GetUserRequest) (*GetUserResponse, error) CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) RenameUser(context.Context, *RenameUserRequest) (*RenameUserResponse, error) DeleteUser(context.Context, *DeleteUserRequest) (*DeleteUserResponse, error) @@ -328,110 +353,116 @@ type HeadscaleServiceServer interface { // --- PreAuthKeys start --- CreatePreAuthKey(context.Context, *CreatePreAuthKeyRequest) (*CreatePreAuthKeyResponse, error) ExpirePreAuthKey(context.Context, *ExpirePreAuthKeyRequest) (*ExpirePreAuthKeyResponse, error) + DeletePreAuthKey(context.Context, *DeletePreAuthKeyRequest) (*DeletePreAuthKeyResponse, error) ListPreAuthKeys(context.Context, *ListPreAuthKeysRequest) (*ListPreAuthKeysResponse, error) // --- Node start --- DebugCreateNode(context.Context, *DebugCreateNodeRequest) (*DebugCreateNodeResponse, error) GetNode(context.Context, *GetNodeRequest) (*GetNodeResponse, error) SetTags(context.Context, *SetTagsRequest) (*SetTagsResponse, error) + SetApprovedRoutes(context.Context, *SetApprovedRoutesRequest) (*SetApprovedRoutesResponse, error) RegisterNode(context.Context, *RegisterNodeRequest) (*RegisterNodeResponse, error) DeleteNode(context.Context, *DeleteNodeRequest) (*DeleteNodeResponse, error) ExpireNode(context.Context, *ExpireNodeRequest) (*ExpireNodeResponse, error) RenameNode(context.Context, *RenameNodeRequest) (*RenameNodeResponse, error) ListNodes(context.Context, *ListNodesRequest) (*ListNodesResponse, error) - MoveNode(context.Context, *MoveNodeRequest) (*MoveNodeResponse, error) - // --- Route start --- - GetRoutes(context.Context, *GetRoutesRequest) (*GetRoutesResponse, error) - EnableRoute(context.Context, *EnableRouteRequest) (*EnableRouteResponse, error) - DisableRoute(context.Context, *DisableRouteRequest) (*DisableRouteResponse, error) - GetNodeRoutes(context.Context, *GetNodeRoutesRequest) (*GetNodeRoutesResponse, error) - DeleteRoute(context.Context, *DeleteRouteRequest) (*DeleteRouteResponse, error) + BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error) // --- ApiKeys start --- CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error) ExpireApiKey(context.Context, *ExpireApiKeyRequest) (*ExpireApiKeyResponse, error) ListApiKeys(context.Context, *ListApiKeysRequest) (*ListApiKeysResponse, error) + DeleteApiKey(context.Context, *DeleteApiKeyRequest) (*DeleteApiKeyResponse, error) + // --- Policy start --- + GetPolicy(context.Context, *GetPolicyRequest) (*GetPolicyResponse, error) + SetPolicy(context.Context, *SetPolicyRequest) (*SetPolicyResponse, error) + // --- Health start --- + Health(context.Context, *HealthRequest) (*HealthResponse, error) mustEmbedUnimplementedHeadscaleServiceServer() } -// UnimplementedHeadscaleServiceServer must be embedded to have forward compatible implementations. -type UnimplementedHeadscaleServiceServer struct { -} +// UnimplementedHeadscaleServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. 
+type UnimplementedHeadscaleServiceServer struct{} -func (UnimplementedHeadscaleServiceServer) GetUser(context.Context, *GetUserRequest) (*GetUserResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetUser not implemented") -} func (UnimplementedHeadscaleServiceServer) CreateUser(context.Context, *CreateUserRequest) (*CreateUserResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method CreateUser not implemented") + return nil, status.Error(codes.Unimplemented, "method CreateUser not implemented") } func (UnimplementedHeadscaleServiceServer) RenameUser(context.Context, *RenameUserRequest) (*RenameUserResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method RenameUser not implemented") + return nil, status.Error(codes.Unimplemented, "method RenameUser not implemented") } func (UnimplementedHeadscaleServiceServer) DeleteUser(context.Context, *DeleteUserRequest) (*DeleteUserResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeleteUser not implemented") + return nil, status.Error(codes.Unimplemented, "method DeleteUser not implemented") } func (UnimplementedHeadscaleServiceServer) ListUsers(context.Context, *ListUsersRequest) (*ListUsersResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListUsers not implemented") + return nil, status.Error(codes.Unimplemented, "method ListUsers not implemented") } func (UnimplementedHeadscaleServiceServer) CreatePreAuthKey(context.Context, *CreatePreAuthKeyRequest) (*CreatePreAuthKeyResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method CreatePreAuthKey not implemented") + return nil, status.Error(codes.Unimplemented, "method CreatePreAuthKey not implemented") } func (UnimplementedHeadscaleServiceServer) ExpirePreAuthKey(context.Context, *ExpirePreAuthKeyRequest) (*ExpirePreAuthKeyResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ExpirePreAuthKey not implemented") + return nil, status.Error(codes.Unimplemented, "method ExpirePreAuthKey not implemented") +} +func (UnimplementedHeadscaleServiceServer) DeletePreAuthKey(context.Context, *DeletePreAuthKeyRequest) (*DeletePreAuthKeyResponse, error) { + return nil, status.Error(codes.Unimplemented, "method DeletePreAuthKey not implemented") } func (UnimplementedHeadscaleServiceServer) ListPreAuthKeys(context.Context, *ListPreAuthKeysRequest) (*ListPreAuthKeysResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListPreAuthKeys not implemented") + return nil, status.Error(codes.Unimplemented, "method ListPreAuthKeys not implemented") } func (UnimplementedHeadscaleServiceServer) DebugCreateNode(context.Context, *DebugCreateNodeRequest) (*DebugCreateNodeResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DebugCreateNode not implemented") + return nil, status.Error(codes.Unimplemented, "method DebugCreateNode not implemented") } func (UnimplementedHeadscaleServiceServer) GetNode(context.Context, *GetNodeRequest) (*GetNodeResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetNode not implemented") + return nil, status.Error(codes.Unimplemented, "method GetNode not implemented") } func (UnimplementedHeadscaleServiceServer) SetTags(context.Context, *SetTagsRequest) (*SetTagsResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method SetTags not implemented") + return nil, status.Error(codes.Unimplemented, "method SetTags not implemented") +} +func 
(UnimplementedHeadscaleServiceServer) SetApprovedRoutes(context.Context, *SetApprovedRoutesRequest) (*SetApprovedRoutesResponse, error) { + return nil, status.Error(codes.Unimplemented, "method SetApprovedRoutes not implemented") } func (UnimplementedHeadscaleServiceServer) RegisterNode(context.Context, *RegisterNodeRequest) (*RegisterNodeResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method RegisterNode not implemented") + return nil, status.Error(codes.Unimplemented, "method RegisterNode not implemented") } func (UnimplementedHeadscaleServiceServer) DeleteNode(context.Context, *DeleteNodeRequest) (*DeleteNodeResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeleteNode not implemented") + return nil, status.Error(codes.Unimplemented, "method DeleteNode not implemented") } func (UnimplementedHeadscaleServiceServer) ExpireNode(context.Context, *ExpireNodeRequest) (*ExpireNodeResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ExpireNode not implemented") + return nil, status.Error(codes.Unimplemented, "method ExpireNode not implemented") } func (UnimplementedHeadscaleServiceServer) RenameNode(context.Context, *RenameNodeRequest) (*RenameNodeResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method RenameNode not implemented") + return nil, status.Error(codes.Unimplemented, "method RenameNode not implemented") } func (UnimplementedHeadscaleServiceServer) ListNodes(context.Context, *ListNodesRequest) (*ListNodesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListNodes not implemented") + return nil, status.Error(codes.Unimplemented, "method ListNodes not implemented") } -func (UnimplementedHeadscaleServiceServer) MoveNode(context.Context, *MoveNodeRequest) (*MoveNodeResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method MoveNode not implemented") -} -func (UnimplementedHeadscaleServiceServer) GetRoutes(context.Context, *GetRoutesRequest) (*GetRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetRoutes not implemented") -} -func (UnimplementedHeadscaleServiceServer) EnableRoute(context.Context, *EnableRouteRequest) (*EnableRouteResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method EnableRoute not implemented") -} -func (UnimplementedHeadscaleServiceServer) DisableRoute(context.Context, *DisableRouteRequest) (*DisableRouteResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DisableRoute not implemented") -} -func (UnimplementedHeadscaleServiceServer) GetNodeRoutes(context.Context, *GetNodeRoutesRequest) (*GetNodeRoutesResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method GetNodeRoutes not implemented") -} -func (UnimplementedHeadscaleServiceServer) DeleteRoute(context.Context, *DeleteRouteRequest) (*DeleteRouteResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method DeleteRoute not implemented") +func (UnimplementedHeadscaleServiceServer) BackfillNodeIPs(context.Context, *BackfillNodeIPsRequest) (*BackfillNodeIPsResponse, error) { + return nil, status.Error(codes.Unimplemented, "method BackfillNodeIPs not implemented") } func (UnimplementedHeadscaleServiceServer) CreateApiKey(context.Context, *CreateApiKeyRequest) (*CreateApiKeyResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method CreateApiKey not implemented") + return nil, status.Error(codes.Unimplemented, "method CreateApiKey not implemented") } func 
(UnimplementedHeadscaleServiceServer) ExpireApiKey(context.Context, *ExpireApiKeyRequest) (*ExpireApiKeyResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ExpireApiKey not implemented") + return nil, status.Error(codes.Unimplemented, "method ExpireApiKey not implemented") } func (UnimplementedHeadscaleServiceServer) ListApiKeys(context.Context, *ListApiKeysRequest) (*ListApiKeysResponse, error) { - return nil, status.Errorf(codes.Unimplemented, "method ListApiKeys not implemented") + return nil, status.Error(codes.Unimplemented, "method ListApiKeys not implemented") +} +func (UnimplementedHeadscaleServiceServer) DeleteApiKey(context.Context, *DeleteApiKeyRequest) (*DeleteApiKeyResponse, error) { + return nil, status.Error(codes.Unimplemented, "method DeleteApiKey not implemented") +} +func (UnimplementedHeadscaleServiceServer) GetPolicy(context.Context, *GetPolicyRequest) (*GetPolicyResponse, error) { + return nil, status.Error(codes.Unimplemented, "method GetPolicy not implemented") +} +func (UnimplementedHeadscaleServiceServer) SetPolicy(context.Context, *SetPolicyRequest) (*SetPolicyResponse, error) { + return nil, status.Error(codes.Unimplemented, "method SetPolicy not implemented") +} +func (UnimplementedHeadscaleServiceServer) Health(context.Context, *HealthRequest) (*HealthResponse, error) { + return nil, status.Error(codes.Unimplemented, "method Health not implemented") } func (UnimplementedHeadscaleServiceServer) mustEmbedUnimplementedHeadscaleServiceServer() {} +func (UnimplementedHeadscaleServiceServer) testEmbeddedByValue() {} // UnsafeHeadscaleServiceServer may be embedded to opt out of forward compatibility for this service. // Use of this interface is not recommended, as added methods to HeadscaleServiceServer will @@ -441,27 +472,16 @@ type UnsafeHeadscaleServiceServer interface { } func RegisterHeadscaleServiceServer(s grpc.ServiceRegistrar, srv HeadscaleServiceServer) { + // If the following call panics, it indicates UnimplementedHeadscaleServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. 
+ if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } s.RegisterService(&HeadscaleService_ServiceDesc, srv) } -func _HeadscaleService_GetUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(GetUserRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(HeadscaleServiceServer).GetUser(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: HeadscaleService_GetUser_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(HeadscaleServiceServer).GetUser(ctx, req.(*GetUserRequest)) - } - return interceptor(ctx, in, info, handler) -} - func _HeadscaleService_CreateUser_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(CreateUserRequest) if err := dec(in); err != nil { @@ -570,6 +590,24 @@ func _HeadscaleService_ExpirePreAuthKey_Handler(srv interface{}, ctx context.Con return interceptor(ctx, in, info, handler) } +func _HeadscaleService_DeletePreAuthKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeletePreAuthKeyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).DeletePreAuthKey(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_DeletePreAuthKey_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).DeletePreAuthKey(ctx, req.(*DeletePreAuthKeyRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _HeadscaleService_ListPreAuthKeys_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(ListPreAuthKeysRequest) if err := dec(in); err != nil { @@ -642,6 +680,24 @@ func _HeadscaleService_SetTags_Handler(srv interface{}, ctx context.Context, dec return interceptor(ctx, in, info, handler) } +func _HeadscaleService_SetApprovedRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SetApprovedRoutesRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).SetApprovedRoutes(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_SetApprovedRoutes_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).SetApprovedRoutes(ctx, req.(*SetApprovedRoutesRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _HeadscaleService_RegisterNode_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(RegisterNodeRequest) if err := dec(in); err != nil { @@ -732,110 +788,20 @@ func _HeadscaleService_ListNodes_Handler(srv interface{}, ctx context.Context, d return interceptor(ctx, in, info, handler) } -func _HeadscaleService_MoveNode_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { 
- in := new(MoveNodeRequest) +func _HeadscaleService_BackfillNodeIPs_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(BackfillNodeIPsRequest) if err := dec(in); err != nil { return nil, err } if interceptor == nil { - return srv.(HeadscaleServiceServer).MoveNode(ctx, in) + return srv.(HeadscaleServiceServer).BackfillNodeIPs(ctx, in) } info := &grpc.UnaryServerInfo{ Server: srv, - FullMethod: HeadscaleService_MoveNode_FullMethodName, + FullMethod: HeadscaleService_BackfillNodeIPs_FullMethodName, } handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(HeadscaleServiceServer).MoveNode(ctx, req.(*MoveNodeRequest)) - } - return interceptor(ctx, in, info, handler) -} - -func _HeadscaleService_GetRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(GetRoutesRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(HeadscaleServiceServer).GetRoutes(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: HeadscaleService_GetRoutes_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(HeadscaleServiceServer).GetRoutes(ctx, req.(*GetRoutesRequest)) - } - return interceptor(ctx, in, info, handler) -} - -func _HeadscaleService_EnableRoute_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(EnableRouteRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(HeadscaleServiceServer).EnableRoute(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: HeadscaleService_EnableRoute_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(HeadscaleServiceServer).EnableRoute(ctx, req.(*EnableRouteRequest)) - } - return interceptor(ctx, in, info, handler) -} - -func _HeadscaleService_DisableRoute_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(DisableRouteRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(HeadscaleServiceServer).DisableRoute(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: HeadscaleService_DisableRoute_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(HeadscaleServiceServer).DisableRoute(ctx, req.(*DisableRouteRequest)) - } - return interceptor(ctx, in, info, handler) -} - -func _HeadscaleService_GetNodeRoutes_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(GetNodeRoutesRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(HeadscaleServiceServer).GetNodeRoutes(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: HeadscaleService_GetNodeRoutes_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(HeadscaleServiceServer).GetNodeRoutes(ctx, req.(*GetNodeRoutesRequest)) - } - return interceptor(ctx, in, info, handler) -} - -func 
_HeadscaleService_DeleteRoute_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { - in := new(DeleteRouteRequest) - if err := dec(in); err != nil { - return nil, err - } - if interceptor == nil { - return srv.(HeadscaleServiceServer).DeleteRoute(ctx, in) - } - info := &grpc.UnaryServerInfo{ - Server: srv, - FullMethod: HeadscaleService_DeleteRoute_FullMethodName, - } - handler := func(ctx context.Context, req interface{}) (interface{}, error) { - return srv.(HeadscaleServiceServer).DeleteRoute(ctx, req.(*DeleteRouteRequest)) + return srv.(HeadscaleServiceServer).BackfillNodeIPs(ctx, req.(*BackfillNodeIPsRequest)) } return interceptor(ctx, in, info, handler) } @@ -894,6 +860,78 @@ func _HeadscaleService_ListApiKeys_Handler(srv interface{}, ctx context.Context, return interceptor(ctx, in, info, handler) } +func _HeadscaleService_DeleteApiKey_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DeleteApiKeyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).DeleteApiKey(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_DeleteApiKey_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).DeleteApiKey(ctx, req.(*DeleteApiKeyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _HeadscaleService_GetPolicy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(GetPolicyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).GetPolicy(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_GetPolicy_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).GetPolicy(ctx, req.(*GetPolicyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _HeadscaleService_SetPolicy_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SetPolicyRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).SetPolicy(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_SetPolicy_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).SetPolicy(ctx, req.(*SetPolicyRequest)) + } + return interceptor(ctx, in, info, handler) +} + +func _HeadscaleService_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(HealthRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).Health(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_Health_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).Health(ctx, req.(*HealthRequest)) + } + return interceptor(ctx, in, info, 
handler) +} + // HeadscaleService_ServiceDesc is the grpc.ServiceDesc for HeadscaleService service. // It's only intended for direct use with grpc.RegisterService, // and not to be introspected or modified (even as a copy) @@ -901,10 +939,6 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{ ServiceName: "headscale.v1.HeadscaleService", HandlerType: (*HeadscaleServiceServer)(nil), Methods: []grpc.MethodDesc{ - { - MethodName: "GetUser", - Handler: _HeadscaleService_GetUser_Handler, - }, { MethodName: "CreateUser", Handler: _HeadscaleService_CreateUser_Handler, @@ -929,6 +963,10 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{ MethodName: "ExpirePreAuthKey", Handler: _HeadscaleService_ExpirePreAuthKey_Handler, }, + { + MethodName: "DeletePreAuthKey", + Handler: _HeadscaleService_DeletePreAuthKey_Handler, + }, { MethodName: "ListPreAuthKeys", Handler: _HeadscaleService_ListPreAuthKeys_Handler, @@ -945,6 +983,10 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SetTags", Handler: _HeadscaleService_SetTags_Handler, }, + { + MethodName: "SetApprovedRoutes", + Handler: _HeadscaleService_SetApprovedRoutes_Handler, + }, { MethodName: "RegisterNode", Handler: _HeadscaleService_RegisterNode_Handler, @@ -966,28 +1008,8 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{ Handler: _HeadscaleService_ListNodes_Handler, }, { - MethodName: "MoveNode", - Handler: _HeadscaleService_MoveNode_Handler, - }, - { - MethodName: "GetRoutes", - Handler: _HeadscaleService_GetRoutes_Handler, - }, - { - MethodName: "EnableRoute", - Handler: _HeadscaleService_EnableRoute_Handler, - }, - { - MethodName: "DisableRoute", - Handler: _HeadscaleService_DisableRoute_Handler, - }, - { - MethodName: "GetNodeRoutes", - Handler: _HeadscaleService_GetNodeRoutes_Handler, - }, - { - MethodName: "DeleteRoute", - Handler: _HeadscaleService_DeleteRoute_Handler, + MethodName: "BackfillNodeIPs", + Handler: _HeadscaleService_BackfillNodeIPs_Handler, }, { MethodName: "CreateApiKey", @@ -1001,6 +1023,22 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{ MethodName: "ListApiKeys", Handler: _HeadscaleService_ListApiKeys_Handler, }, + { + MethodName: "DeleteApiKey", + Handler: _HeadscaleService_DeleteApiKey_Handler, + }, + { + MethodName: "GetPolicy", + Handler: _HeadscaleService_GetPolicy_Handler, + }, + { + MethodName: "SetPolicy", + Handler: _HeadscaleService_SetPolicy_Handler, + }, + { + MethodName: "Health", + Handler: _HeadscaleService_Health_Handler, + }, }, Streams: []grpc.StreamDesc{}, Metadata: "headscale/v1/headscale.proto", diff --git a/gen/go/headscale/v1/node.pb.go b/gen/go/headscale/v1/node.pb.go index e567d3ca..b4b7e8f6 100644 --- a/gen/go/headscale/v1/node.pb.go +++ b/gen/go/headscale/v1/node.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: -// protoc-gen-go v1.31.0 +// protoc-gen-go v1.36.11 // protoc (unknown) // source: headscale/v1/node.proto @@ -12,6 +12,7 @@ import ( timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -74,37 +75,38 @@ func (RegisterMethod) EnumDescriptor() ([]byte, []int) { } type Node struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` - MachineKey string `protobuf:"bytes,2,opt,name=machine_key,json=machineKey,proto3" json:"machine_key,omitempty"` - NodeKey string `protobuf:"bytes,3,opt,name=node_key,json=nodeKey,proto3" json:"node_key,omitempty"` - DiscoKey string `protobuf:"bytes,4,opt,name=disco_key,json=discoKey,proto3" json:"disco_key,omitempty"` - IpAddresses []string `protobuf:"bytes,5,rep,name=ip_addresses,json=ipAddresses,proto3" json:"ip_addresses,omitempty"` - Name string `protobuf:"bytes,6,opt,name=name,proto3" json:"name,omitempty"` - User *User `protobuf:"bytes,7,opt,name=user,proto3" json:"user,omitempty"` - LastSeen *timestamppb.Timestamp `protobuf:"bytes,8,opt,name=last_seen,json=lastSeen,proto3" json:"last_seen,omitempty"` - LastSuccessfulUpdate *timestamppb.Timestamp `protobuf:"bytes,9,opt,name=last_successful_update,json=lastSuccessfulUpdate,proto3" json:"last_successful_update,omitempty"` - Expiry *timestamppb.Timestamp `protobuf:"bytes,10,opt,name=expiry,proto3" json:"expiry,omitempty"` - PreAuthKey *PreAuthKey `protobuf:"bytes,11,opt,name=pre_auth_key,json=preAuthKey,proto3" json:"pre_auth_key,omitempty"` - CreatedAt *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` - RegisterMethod RegisterMethod `protobuf:"varint,13,opt,name=register_method,json=registerMethod,proto3,enum=headscale.v1.RegisterMethod" json:"register_method,omitempty"` - ForcedTags []string `protobuf:"bytes,18,rep,name=forced_tags,json=forcedTags,proto3" json:"forced_tags,omitempty"` - InvalidTags []string `protobuf:"bytes,19,rep,name=invalid_tags,json=invalidTags,proto3" json:"invalid_tags,omitempty"` - ValidTags []string `protobuf:"bytes,20,rep,name=valid_tags,json=validTags,proto3" json:"valid_tags,omitempty"` - GivenName string `protobuf:"bytes,21,opt,name=given_name,json=givenName,proto3" json:"given_name,omitempty"` - Online bool `protobuf:"varint,22,opt,name=online,proto3" json:"online,omitempty"` + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + MachineKey string `protobuf:"bytes,2,opt,name=machine_key,json=machineKey,proto3" json:"machine_key,omitempty"` + NodeKey string `protobuf:"bytes,3,opt,name=node_key,json=nodeKey,proto3" json:"node_key,omitempty"` + DiscoKey string `protobuf:"bytes,4,opt,name=disco_key,json=discoKey,proto3" json:"disco_key,omitempty"` + IpAddresses []string `protobuf:"bytes,5,rep,name=ip_addresses,json=ipAddresses,proto3" json:"ip_addresses,omitempty"` + Name string `protobuf:"bytes,6,opt,name=name,proto3" json:"name,omitempty"` + User *User `protobuf:"bytes,7,opt,name=user,proto3" json:"user,omitempty"` + LastSeen *timestamppb.Timestamp `protobuf:"bytes,8,opt,name=last_seen,json=lastSeen,proto3" json:"last_seen,omitempty"` + Expiry *timestamppb.Timestamp `protobuf:"bytes,10,opt,name=expiry,proto3" json:"expiry,omitempty"` + PreAuthKey *PreAuthKey `protobuf:"bytes,11,opt,name=pre_auth_key,json=preAuthKey,proto3" 
json:"pre_auth_key,omitempty"` + CreatedAt *timestamppb.Timestamp `protobuf:"bytes,12,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + RegisterMethod RegisterMethod `protobuf:"varint,13,opt,name=register_method,json=registerMethod,proto3,enum=headscale.v1.RegisterMethod" json:"register_method,omitempty"` + // Deprecated + // repeated string forced_tags = 18; + // repeated string invalid_tags = 19; + // repeated string valid_tags = 20; + GivenName string `protobuf:"bytes,21,opt,name=given_name,json=givenName,proto3" json:"given_name,omitempty"` + Online bool `protobuf:"varint,22,opt,name=online,proto3" json:"online,omitempty"` + ApprovedRoutes []string `protobuf:"bytes,23,rep,name=approved_routes,json=approvedRoutes,proto3" json:"approved_routes,omitempty"` + AvailableRoutes []string `protobuf:"bytes,24,rep,name=available_routes,json=availableRoutes,proto3" json:"available_routes,omitempty"` + SubnetRoutes []string `protobuf:"bytes,25,rep,name=subnet_routes,json=subnetRoutes,proto3" json:"subnet_routes,omitempty"` + Tags []string `protobuf:"bytes,26,rep,name=tags,proto3" json:"tags,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *Node) Reset() { *x = Node{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *Node) String() string { @@ -115,7 +117,7 @@ func (*Node) ProtoMessage() {} func (x *Node) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_node_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -186,13 +188,6 @@ func (x *Node) GetLastSeen() *timestamppb.Timestamp { return nil } -func (x *Node) GetLastSuccessfulUpdate() *timestamppb.Timestamp { - if x != nil { - return x.LastSuccessfulUpdate - } - return nil -} - func (x *Node) GetExpiry() *timestamppb.Timestamp { if x != nil { return x.Expiry @@ -221,27 +216,6 @@ func (x *Node) GetRegisterMethod() RegisterMethod { return RegisterMethod_REGISTER_METHOD_UNSPECIFIED } -func (x *Node) GetForcedTags() []string { - if x != nil { - return x.ForcedTags - } - return nil -} - -func (x *Node) GetInvalidTags() []string { - if x != nil { - return x.InvalidTags - } - return nil -} - -func (x *Node) GetValidTags() []string { - if x != nil { - return x.ValidTags - } - return nil -} - func (x *Node) GetGivenName() string { if x != nil { return x.GivenName @@ -256,22 +230,47 @@ func (x *Node) GetOnline() bool { return false } -type RegisterNodeRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields +func (x *Node) GetApprovedRoutes() []string { + if x != nil { + return x.ApprovedRoutes + } + return nil +} - User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` - Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"` +func (x *Node) GetAvailableRoutes() []string { + if x != nil { + return x.AvailableRoutes + } + return nil +} + +func (x *Node) GetSubnetRoutes() []string { + if x != nil { + return x.SubnetRoutes + } + return nil +} + +func (x *Node) GetTags() []string { + if x != nil { + return x.Tags + } + return nil +} + +type RegisterNodeRequest struct { + state 
protoimpl.MessageState `protogen:"open.v1"` + User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` + Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *RegisterNodeRequest) Reset() { *x = RegisterNodeRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *RegisterNodeRequest) String() string { @@ -282,7 +281,7 @@ func (*RegisterNodeRequest) ProtoMessage() {} func (x *RegisterNodeRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_node_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -312,20 +311,17 @@ func (x *RegisterNodeRequest) GetKey() string { } type RegisterNodeResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` unknownFields protoimpl.UnknownFields - - Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` + sizeCache protoimpl.SizeCache } func (x *RegisterNodeResponse) Reset() { *x = RegisterNodeResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *RegisterNodeResponse) String() string { @@ -336,7 +332,7 @@ func (*RegisterNodeResponse) ProtoMessage() {} func (x *RegisterNodeResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_node_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -359,20 +355,17 @@ func (x *RegisterNodeResponse) GetNode() *Node { } type GetNodeRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` unknownFields protoimpl.UnknownFields - - NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + sizeCache protoimpl.SizeCache } func (x *GetNodeRequest) Reset() { *x = GetNodeRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetNodeRequest) String() string { @@ -383,7 +376,7 @@ func (*GetNodeRequest) ProtoMessage() {} func (x *GetNodeRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_node_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -406,20 +399,17 @@ func (x *GetNodeRequest) 
GetNodeId() uint64 { } type GetNodeResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` unknownFields protoimpl.UnknownFields - - Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` + sizeCache protoimpl.SizeCache } func (x *GetNodeResponse) Reset() { *x = GetNodeResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *GetNodeResponse) String() string { @@ -430,7 +420,7 @@ func (*GetNodeResponse) ProtoMessage() {} func (x *GetNodeResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_node_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -453,21 +443,18 @@ func (x *GetNodeResponse) GetNode() *Node { } type SetTagsRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + Tags []string `protobuf:"bytes,2,rep,name=tags,proto3" json:"tags,omitempty"` unknownFields protoimpl.UnknownFields - - NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` - Tags []string `protobuf:"bytes,2,rep,name=tags,proto3" json:"tags,omitempty"` + sizeCache protoimpl.SizeCache } func (x *SetTagsRequest) Reset() { *x = SetTagsRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *SetTagsRequest) String() string { @@ -478,7 +465,7 @@ func (*SetTagsRequest) ProtoMessage() {} func (x *SetTagsRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_node_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -508,20 +495,17 @@ func (x *SetTagsRequest) GetTags() []string { } type SetTagsResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` unknownFields protoimpl.UnknownFields - - Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` + sizeCache protoimpl.SizeCache } func (x *SetTagsResponse) Reset() { *x = SetTagsResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *SetTagsResponse) String() string { @@ -532,7 +516,7 @@ func (*SetTagsResponse) ProtoMessage() {} func (x *SetTagsResponse) ProtoReflect() protoreflect.Message { mi := 
&file_headscale_v1_node_proto_msgTypes[6] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -554,21 +538,114 @@ func (x *SetTagsResponse) GetNode() *Node { return nil } -type DeleteNodeRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache +type SetApprovedRoutesRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + Routes []string `protobuf:"bytes,2,rep,name=routes,proto3" json:"routes,omitempty"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} - NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` +func (x *SetApprovedRoutesRequest) Reset() { + *x = SetApprovedRoutesRequest{} + mi := &file_headscale_v1_node_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetApprovedRoutesRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetApprovedRoutesRequest) ProtoMessage() {} + +func (x *SetApprovedRoutesRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_node_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetApprovedRoutesRequest.ProtoReflect.Descriptor instead. +func (*SetApprovedRoutesRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_node_proto_rawDescGZIP(), []int{7} +} + +func (x *SetApprovedRoutesRequest) GetNodeId() uint64 { + if x != nil { + return x.NodeId + } + return 0 +} + +func (x *SetApprovedRoutesRequest) GetRoutes() []string { + if x != nil { + return x.Routes + } + return nil +} + +type SetApprovedRoutesResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetApprovedRoutesResponse) Reset() { + *x = SetApprovedRoutesResponse{} + mi := &file_headscale_v1_node_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetApprovedRoutesResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetApprovedRoutesResponse) ProtoMessage() {} + +func (x *SetApprovedRoutesResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_node_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetApprovedRoutesResponse.ProtoReflect.Descriptor instead. 
+func (*SetApprovedRoutesResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_node_proto_rawDescGZIP(), []int{8} +} + +func (x *SetApprovedRoutesResponse) GetNode() *Node { + if x != nil { + return x.Node + } + return nil +} + +type DeleteNodeRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *DeleteNodeRequest) Reset() { *x = DeleteNodeRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[7] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *DeleteNodeRequest) String() string { @@ -578,8 +655,8 @@ func (x *DeleteNodeRequest) String() string { func (*DeleteNodeRequest) ProtoMessage() {} func (x *DeleteNodeRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[7] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_node_proto_msgTypes[9] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -591,7 +668,7 @@ func (x *DeleteNodeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteNodeRequest.ProtoReflect.Descriptor instead. func (*DeleteNodeRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{7} + return file_headscale_v1_node_proto_rawDescGZIP(), []int{9} } func (x *DeleteNodeRequest) GetNodeId() uint64 { @@ -602,18 +679,16 @@ func (x *DeleteNodeRequest) GetNodeId() uint64 { } type DeleteNodeResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *DeleteNodeResponse) Reset() { *x = DeleteNodeResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[8] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[10] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *DeleteNodeResponse) String() string { @@ -623,8 +698,8 @@ func (x *DeleteNodeResponse) String() string { func (*DeleteNodeResponse) ProtoMessage() {} func (x *DeleteNodeResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[8] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_node_proto_msgTypes[10] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -636,24 +711,22 @@ func (x *DeleteNodeResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteNodeResponse.ProtoReflect.Descriptor instead. 
func (*DeleteNodeResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{8} + return file_headscale_v1_node_proto_rawDescGZIP(), []int{10} } type ExpireNodeRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + Expiry *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=expiry,proto3" json:"expiry,omitempty"` unknownFields protoimpl.UnknownFields - - NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ExpireNodeRequest) Reset() { *x = ExpireNodeRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[9] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[11] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ExpireNodeRequest) String() string { @@ -663,8 +736,8 @@ func (x *ExpireNodeRequest) String() string { func (*ExpireNodeRequest) ProtoMessage() {} func (x *ExpireNodeRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[9] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_node_proto_msgTypes[11] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -676,7 +749,7 @@ func (x *ExpireNodeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ExpireNodeRequest.ProtoReflect.Descriptor instead. func (*ExpireNodeRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{9} + return file_headscale_v1_node_proto_rawDescGZIP(), []int{11} } func (x *ExpireNodeRequest) GetNodeId() uint64 { @@ -686,21 +759,25 @@ func (x *ExpireNodeRequest) GetNodeId() uint64 { return 0 } -type ExpireNodeResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields +func (x *ExpireNodeRequest) GetExpiry() *timestamppb.Timestamp { + if x != nil { + return x.Expiry + } + return nil +} - Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` +type ExpireNodeResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExpireNodeResponse) Reset() { *x = ExpireNodeResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[10] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[12] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ExpireNodeResponse) String() string { @@ -710,8 +787,8 @@ func (x *ExpireNodeResponse) String() string { func (*ExpireNodeResponse) ProtoMessage() {} func (x *ExpireNodeResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[10] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_node_proto_msgTypes[12] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -723,7 +800,7 @@ func (x *ExpireNodeResponse) ProtoReflect() 
protoreflect.Message { // Deprecated: Use ExpireNodeResponse.ProtoReflect.Descriptor instead. func (*ExpireNodeResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{10} + return file_headscale_v1_node_proto_rawDescGZIP(), []int{12} } func (x *ExpireNodeResponse) GetNode() *Node { @@ -734,21 +811,18 @@ func (x *ExpireNodeResponse) GetNode() *Node { } type RenameNodeRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` + NewName string `protobuf:"bytes,2,opt,name=new_name,json=newName,proto3" json:"new_name,omitempty"` unknownFields protoimpl.UnknownFields - - NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` - NewName string `protobuf:"bytes,2,opt,name=new_name,json=newName,proto3" json:"new_name,omitempty"` + sizeCache protoimpl.SizeCache } func (x *RenameNodeRequest) Reset() { *x = RenameNodeRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[11] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[13] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *RenameNodeRequest) String() string { @@ -758,8 +832,8 @@ func (x *RenameNodeRequest) String() string { func (*RenameNodeRequest) ProtoMessage() {} func (x *RenameNodeRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[11] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_node_proto_msgTypes[13] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -771,7 +845,7 @@ func (x *RenameNodeRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RenameNodeRequest.ProtoReflect.Descriptor instead. 
func (*RenameNodeRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{11} + return file_headscale_v1_node_proto_rawDescGZIP(), []int{13} } func (x *RenameNodeRequest) GetNodeId() uint64 { @@ -789,20 +863,17 @@ func (x *RenameNodeRequest) GetNewName() string { } type RenameNodeResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` unknownFields protoimpl.UnknownFields - - Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` + sizeCache protoimpl.SizeCache } func (x *RenameNodeResponse) Reset() { *x = RenameNodeResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[12] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[14] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *RenameNodeResponse) String() string { @@ -812,8 +883,8 @@ func (x *RenameNodeResponse) String() string { func (*RenameNodeResponse) ProtoMessage() {} func (x *RenameNodeResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[12] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_node_proto_msgTypes[14] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -825,7 +896,7 @@ func (x *RenameNodeResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RenameNodeResponse.ProtoReflect.Descriptor instead. func (*RenameNodeResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{12} + return file_headscale_v1_node_proto_rawDescGZIP(), []int{14} } func (x *RenameNodeResponse) GetNode() *Node { @@ -836,20 +907,17 @@ func (x *RenameNodeResponse) GetNode() *Node { } type ListNodesRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` unknownFields protoimpl.UnknownFields - - User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ListNodesRequest) Reset() { *x = ListNodesRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[13] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[15] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ListNodesRequest) String() string { @@ -859,8 +927,8 @@ func (x *ListNodesRequest) String() string { func (*ListNodesRequest) ProtoMessage() {} func (x *ListNodesRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[13] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_node_proto_msgTypes[15] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -872,7 +940,7 @@ func (x *ListNodesRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNodesRequest.ProtoReflect.Descriptor instead. 
func (*ListNodesRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{13} + return file_headscale_v1_node_proto_rawDescGZIP(), []int{15} } func (x *ListNodesRequest) GetUser() string { @@ -883,20 +951,17 @@ func (x *ListNodesRequest) GetUser() string { } type ListNodesResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Nodes []*Node `protobuf:"bytes,1,rep,name=nodes,proto3" json:"nodes,omitempty"` unknownFields protoimpl.UnknownFields - - Nodes []*Node `protobuf:"bytes,1,rep,name=nodes,proto3" json:"nodes,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ListNodesResponse) Reset() { *x = ListNodesResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[14] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[16] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ListNodesResponse) String() string { @@ -906,8 +971,8 @@ func (x *ListNodesResponse) String() string { func (*ListNodesResponse) ProtoMessage() {} func (x *ListNodesResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[14] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_node_proto_msgTypes[16] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -919,7 +984,7 @@ func (x *ListNodesResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListNodesResponse.ProtoReflect.Descriptor instead. func (*ListNodesResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{14} + return file_headscale_v1_node_proto_rawDescGZIP(), []int{16} } func (x *ListNodesResponse) GetNodes() []*Node { @@ -929,126 +994,21 @@ func (x *ListNodesResponse) GetNodes() []*Node { return nil } -type MoveNodeRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` - User string `protobuf:"bytes,2,opt,name=user,proto3" json:"user,omitempty"` -} - -func (x *MoveNodeRequest) Reset() { - *x = MoveNodeRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[15] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *MoveNodeRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*MoveNodeRequest) ProtoMessage() {} - -func (x *MoveNodeRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[15] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use MoveNodeRequest.ProtoReflect.Descriptor instead. 
-func (*MoveNodeRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{15} -} - -func (x *MoveNodeRequest) GetNodeId() uint64 { - if x != nil { - return x.NodeId - } - return 0 -} - -func (x *MoveNodeRequest) GetUser() string { - if x != nil { - return x.User - } - return "" -} - -type MoveNodeResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` -} - -func (x *MoveNodeResponse) Reset() { - *x = MoveNodeResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[16] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *MoveNodeResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*MoveNodeResponse) ProtoMessage() {} - -func (x *MoveNodeResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_node_proto_msgTypes[16] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use MoveNodeResponse.ProtoReflect.Descriptor instead. -func (*MoveNodeResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_node_proto_rawDescGZIP(), []int{16} -} - -func (x *MoveNodeResponse) GetNode() *Node { - if x != nil { - return x.Node - } - return nil -} - type DebugCreateNodeRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` + Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"` + Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"` + Routes []string `protobuf:"bytes,4,rep,name=routes,proto3" json:"routes,omitempty"` unknownFields protoimpl.UnknownFields - - User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` - Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"` - Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"` - Routes []string `protobuf:"bytes,4,rep,name=routes,proto3" json:"routes,omitempty"` + sizeCache protoimpl.SizeCache } func (x *DebugCreateNodeRequest) Reset() { *x = DebugCreateNodeRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[17] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[17] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *DebugCreateNodeRequest) String() string { @@ -1059,7 +1019,7 @@ func (*DebugCreateNodeRequest) ProtoMessage() {} func (x *DebugCreateNodeRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_node_proto_msgTypes[17] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1103,20 +1063,17 @@ func (x *DebugCreateNodeRequest) GetRoutes() []string { } type DebugCreateNodeResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Node *Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` unknownFields protoimpl.UnknownFields - - Node 
*Node `protobuf:"bytes,1,opt,name=node,proto3" json:"node,omitempty"` + sizeCache protoimpl.SizeCache } func (x *DebugCreateNodeResponse) Reset() { *x = DebugCreateNodeResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_node_proto_msgTypes[18] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_node_proto_msgTypes[18] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *DebugCreateNodeResponse) String() string { @@ -1127,7 +1084,7 @@ func (*DebugCreateNodeResponse) ProtoMessage() {} func (x *DebugCreateNodeResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_node_proto_msgTypes[18] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -1149,198 +1106,233 @@ func (x *DebugCreateNodeResponse) GetNode() *Node { return nil } +type BackfillNodeIPsRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Confirmed bool `protobuf:"varint,1,opt,name=confirmed,proto3" json:"confirmed,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *BackfillNodeIPsRequest) Reset() { + *x = BackfillNodeIPsRequest{} + mi := &file_headscale_v1_node_proto_msgTypes[19] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *BackfillNodeIPsRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BackfillNodeIPsRequest) ProtoMessage() {} + +func (x *BackfillNodeIPsRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_node_proto_msgTypes[19] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BackfillNodeIPsRequest.ProtoReflect.Descriptor instead. +func (*BackfillNodeIPsRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_node_proto_rawDescGZIP(), []int{19} +} + +func (x *BackfillNodeIPsRequest) GetConfirmed() bool { + if x != nil { + return x.Confirmed + } + return false +} + +type BackfillNodeIPsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Changes []string `protobuf:"bytes,1,rep,name=changes,proto3" json:"changes,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *BackfillNodeIPsResponse) Reset() { + *x = BackfillNodeIPsResponse{} + mi := &file_headscale_v1_node_proto_msgTypes[20] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *BackfillNodeIPsResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*BackfillNodeIPsResponse) ProtoMessage() {} + +func (x *BackfillNodeIPsResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_node_proto_msgTypes[20] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use BackfillNodeIPsResponse.ProtoReflect.Descriptor instead. 
+func (*BackfillNodeIPsResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_node_proto_rawDescGZIP(), []int{20} +} + +func (x *BackfillNodeIPsResponse) GetChanges() []string { + if x != nil { + return x.Changes + } + return nil +} + var File_headscale_v1_node_proto protoreflect.FileDescriptor -var file_headscale_v1_node_proto_rawDesc = []byte{ - 0x0a, 0x17, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x6e, - 0x6f, 0x64, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x68, 0x65, 0x61, 0x64, 0x73, - 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, - 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x75, 0x73, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, - 0x6f, 0x1a, 0x1d, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, - 0x70, 0x72, 0x65, 0x61, 0x75, 0x74, 0x68, 0x6b, 0x65, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x22, 0xeb, 0x05, 0x0a, 0x04, 0x4e, 0x6f, 0x64, 0x65, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x69, 0x64, 0x12, 0x1f, 0x0a, 0x0b, 0x6d, 0x61, 0x63, - 0x68, 0x69, 0x6e, 0x65, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, - 0x6d, 0x61, 0x63, 0x68, 0x69, 0x6e, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x19, 0x0a, 0x08, 0x6e, 0x6f, - 0x64, 0x65, 0x5f, 0x6b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6e, 0x6f, - 0x64, 0x65, 0x4b, 0x65, 0x79, 0x12, 0x1b, 0x0a, 0x09, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x5f, 0x6b, - 0x65, 0x79, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x64, 0x69, 0x73, 0x63, 0x6f, 0x4b, - 0x65, 0x79, 0x12, 0x21, 0x0a, 0x0c, 0x69, 0x70, 0x5f, 0x61, 0x64, 0x64, 0x72, 0x65, 0x73, 0x73, - 0x65, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x69, 0x70, 0x41, 0x64, 0x64, 0x72, - 0x65, 0x73, 0x73, 0x65, 0x73, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x06, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x26, 0x0a, 0x04, 0x75, 0x73, 0x65, - 0x72, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x55, 0x73, 0x65, 0x72, 0x52, 0x04, 0x75, 0x73, 0x65, - 0x72, 0x12, 0x37, 0x0a, 0x09, 0x6c, 0x61, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x65, 0x6e, 0x18, 0x08, - 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, - 0x52, 0x08, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x65, 0x65, 0x6e, 0x12, 0x50, 0x0a, 0x16, 0x6c, 0x61, - 0x73, 0x74, 0x5f, 0x73, 0x75, 0x63, 0x63, 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x5f, 0x75, 0x70, - 0x64, 0x61, 0x74, 0x65, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, - 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x14, 0x6c, 0x61, 0x73, 0x74, 0x53, 0x75, 0x63, 0x63, - 0x65, 0x73, 0x73, 0x66, 0x75, 0x6c, 0x55, 0x70, 0x64, 0x61, 0x74, 0x65, 0x12, 0x32, 0x0a, 0x06, - 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, 0x18, 0x0a, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, - 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, - 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 
0x06, 0x65, 0x78, 0x70, 0x69, 0x72, 0x79, - 0x12, 0x3a, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6b, 0x65, 0x79, - 0x18, 0x0b, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, - 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, - 0x52, 0x0a, 0x70, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x12, 0x39, 0x0a, 0x0a, - 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x0c, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, - 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x45, 0x0a, 0x0f, 0x72, 0x65, 0x67, 0x69, 0x73, - 0x74, 0x65, 0x72, 0x5f, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x18, 0x0d, 0x20, 0x01, 0x28, 0x0e, - 0x32, 0x1c, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, - 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x52, 0x0e, - 0x72, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x1f, - 0x0a, 0x0b, 0x66, 0x6f, 0x72, 0x63, 0x65, 0x64, 0x5f, 0x74, 0x61, 0x67, 0x73, 0x18, 0x12, 0x20, - 0x03, 0x28, 0x09, 0x52, 0x0a, 0x66, 0x6f, 0x72, 0x63, 0x65, 0x64, 0x54, 0x61, 0x67, 0x73, 0x12, - 0x21, 0x0a, 0x0c, 0x69, 0x6e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x5f, 0x74, 0x61, 0x67, 0x73, 0x18, - 0x13, 0x20, 0x03, 0x28, 0x09, 0x52, 0x0b, 0x69, 0x6e, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x54, 0x61, - 0x67, 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x5f, 0x74, 0x61, 0x67, 0x73, - 0x18, 0x14, 0x20, 0x03, 0x28, 0x09, 0x52, 0x09, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x54, 0x61, 0x67, - 0x73, 0x12, 0x1d, 0x0a, 0x0a, 0x67, 0x69, 0x76, 0x65, 0x6e, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, - 0x15, 0x20, 0x01, 0x28, 0x09, 0x52, 0x09, 0x67, 0x69, 0x76, 0x65, 0x6e, 0x4e, 0x61, 0x6d, 0x65, - 0x12, 0x16, 0x0a, 0x06, 0x6f, 0x6e, 0x6c, 0x69, 0x6e, 0x65, 0x18, 0x16, 0x20, 0x01, 0x28, 0x08, - 0x52, 0x06, 0x6f, 0x6e, 0x6c, 0x69, 0x6e, 0x65, 0x4a, 0x04, 0x08, 0x0e, 0x10, 0x12, 0x22, 0x3b, - 0x0a, 0x13, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, - 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x22, 0x3e, 0x0a, 0x14, 0x52, - 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, 0x72, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, - 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x22, 0x29, 0x0a, 0x0e, 0x47, - 0x65, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, - 0x07, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, - 0x6e, 0x6f, 0x64, 0x65, 0x49, 0x64, 0x22, 0x39, 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x4e, 0x6f, 0x64, - 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x04, 0x6e, 0x6f, 0x64, - 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x04, 
0x6e, 0x6f, 0x64, - 0x65, 0x22, 0x3d, 0x0a, 0x0e, 0x53, 0x65, 0x74, 0x54, 0x61, 0x67, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x6e, 0x6f, 0x64, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, - 0x74, 0x61, 0x67, 0x73, 0x18, 0x02, 0x20, 0x03, 0x28, 0x09, 0x52, 0x04, 0x74, 0x61, 0x67, 0x73, - 0x22, 0x39, 0x0a, 0x0f, 0x53, 0x65, 0x74, 0x54, 0x61, 0x67, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, - 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x22, 0x2c, 0x0a, 0x11, 0x44, - 0x65, 0x6c, 0x65, 0x74, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x17, 0x0a, 0x07, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x04, 0x52, 0x06, 0x6e, 0x6f, 0x64, 0x65, 0x49, 0x64, 0x22, 0x14, 0x0a, 0x12, 0x44, 0x65, 0x6c, - 0x65, 0x74, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x2c, 0x0a, 0x11, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, - 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x6e, 0x6f, 0x64, 0x65, 0x49, 0x64, 0x22, 0x3c, 0x0a, - 0x12, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, - 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x22, 0x47, 0x0a, 0x11, 0x52, - 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, - 0x12, 0x17, 0x0a, 0x07, 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x04, 0x52, 0x06, 0x6e, 0x6f, 0x64, 0x65, 0x49, 0x64, 0x12, 0x19, 0x0a, 0x08, 0x6e, 0x65, 0x77, - 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6e, 0x65, 0x77, - 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x3c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x4e, 0x6f, - 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x04, 0x6e, 0x6f, - 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, - 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x04, 0x6e, 0x6f, - 0x64, 0x65, 0x22, 0x26, 0x0a, 0x10, 0x4c, 0x69, 0x73, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x73, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x01, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x22, 0x3d, 0x0a, 0x11, 0x4c, 0x69, - 0x73, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x28, 0x0a, 0x05, 0x6e, 0x6f, 0x64, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, - 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, - 0x64, 0x65, 0x52, 0x05, 0x6e, 0x6f, 0x64, 0x65, 0x73, 0x22, 0x3e, 0x0a, 0x0f, 0x4d, 0x6f, 0x76, - 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, - 0x6e, 0x6f, 0x64, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x6e, - 0x6f, 
0x64, 0x65, 0x49, 0x64, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x02, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x22, 0x3a, 0x0a, 0x10, 0x4d, 0x6f, 0x76, - 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, - 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, - 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x52, - 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x22, 0x6a, 0x0a, 0x16, 0x44, 0x65, 0x62, 0x75, 0x67, 0x43, 0x72, - 0x65, 0x61, 0x74, 0x65, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, - 0x73, 0x65, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, - 0x52, 0x03, 0x6b, 0x65, 0x79, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x03, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x18, 0x04, 0x20, 0x03, 0x28, 0x09, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, - 0x73, 0x22, 0x41, 0x0a, 0x17, 0x44, 0x65, 0x62, 0x75, 0x67, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x04, - 0x6e, 0x6f, 0x64, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x04, - 0x6e, 0x6f, 0x64, 0x65, 0x2a, 0x82, 0x01, 0x0a, 0x0e, 0x52, 0x65, 0x67, 0x69, 0x73, 0x74, 0x65, - 0x72, 0x4d, 0x65, 0x74, 0x68, 0x6f, 0x64, 0x12, 0x1f, 0x0a, 0x1b, 0x52, 0x45, 0x47, 0x49, 0x53, - 0x54, 0x45, 0x52, 0x5f, 0x4d, 0x45, 0x54, 0x48, 0x4f, 0x44, 0x5f, 0x55, 0x4e, 0x53, 0x50, 0x45, - 0x43, 0x49, 0x46, 0x49, 0x45, 0x44, 0x10, 0x00, 0x12, 0x1c, 0x0a, 0x18, 0x52, 0x45, 0x47, 0x49, - 0x53, 0x54, 0x45, 0x52, 0x5f, 0x4d, 0x45, 0x54, 0x48, 0x4f, 0x44, 0x5f, 0x41, 0x55, 0x54, 0x48, - 0x5f, 0x4b, 0x45, 0x59, 0x10, 0x01, 0x12, 0x17, 0x0a, 0x13, 0x52, 0x45, 0x47, 0x49, 0x53, 0x54, - 0x45, 0x52, 0x5f, 0x4d, 0x45, 0x54, 0x48, 0x4f, 0x44, 0x5f, 0x43, 0x4c, 0x49, 0x10, 0x02, 0x12, - 0x18, 0x0a, 0x14, 0x52, 0x45, 0x47, 0x49, 0x53, 0x54, 0x45, 0x52, 0x5f, 0x4d, 0x45, 0x54, 0x48, - 0x4f, 0x44, 0x5f, 0x4f, 0x49, 0x44, 0x43, 0x10, 0x03, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, - 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6a, 0x75, 0x61, 0x6e, 0x66, 0x6f, 0x6e, 0x74, - 0x2f, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x67, - 0x6f, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +const file_headscale_v1_node_proto_rawDesc = "" + + "\n" + + "\x17headscale/v1/node.proto\x12\fheadscale.v1\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/user.proto\"\xc9\x05\n" + + "\x04Node\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\x12\x1f\n" + + "\vmachine_key\x18\x02 \x01(\tR\n" + + "machineKey\x12\x19\n" + + "\bnode_key\x18\x03 \x01(\tR\anodeKey\x12\x1b\n" + + "\tdisco_key\x18\x04 \x01(\tR\bdiscoKey\x12!\n" + + "\fip_addresses\x18\x05 \x03(\tR\vipAddresses\x12\x12\n" + + "\x04name\x18\x06 \x01(\tR\x04name\x12&\n" + + "\x04user\x18\a \x01(\v2\x12.headscale.v1.UserR\x04user\x127\n" + + "\tlast_seen\x18\b \x01(\v2\x1a.google.protobuf.TimestampR\blastSeen\x122\n" + + "\x06expiry\x18\n" + + " \x01(\v2\x1a.google.protobuf.TimestampR\x06expiry\x12:\n" + + 
"\fpre_auth_key\x18\v \x01(\v2\x18.headscale.v1.PreAuthKeyR\n" + + "preAuthKey\x129\n" + + "\n" + + "created_at\x18\f \x01(\v2\x1a.google.protobuf.TimestampR\tcreatedAt\x12E\n" + + "\x0fregister_method\x18\r \x01(\x0e2\x1c.headscale.v1.RegisterMethodR\x0eregisterMethod\x12\x1d\n" + + "\n" + + "given_name\x18\x15 \x01(\tR\tgivenName\x12\x16\n" + + "\x06online\x18\x16 \x01(\bR\x06online\x12'\n" + + "\x0fapproved_routes\x18\x17 \x03(\tR\x0eapprovedRoutes\x12)\n" + + "\x10available_routes\x18\x18 \x03(\tR\x0favailableRoutes\x12#\n" + + "\rsubnet_routes\x18\x19 \x03(\tR\fsubnetRoutes\x12\x12\n" + + "\x04tags\x18\x1a \x03(\tR\x04tagsJ\x04\b\t\x10\n" + + "J\x04\b\x0e\x10\x15\";\n" + + "\x13RegisterNodeRequest\x12\x12\n" + + "\x04user\x18\x01 \x01(\tR\x04user\x12\x10\n" + + "\x03key\x18\x02 \x01(\tR\x03key\">\n" + + "\x14RegisterNodeResponse\x12&\n" + + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\")\n" + + "\x0eGetNodeRequest\x12\x17\n" + + "\anode_id\x18\x01 \x01(\x04R\x06nodeId\"9\n" + + "\x0fGetNodeResponse\x12&\n" + + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"=\n" + + "\x0eSetTagsRequest\x12\x17\n" + + "\anode_id\x18\x01 \x01(\x04R\x06nodeId\x12\x12\n" + + "\x04tags\x18\x02 \x03(\tR\x04tags\"9\n" + + "\x0fSetTagsResponse\x12&\n" + + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"K\n" + + "\x18SetApprovedRoutesRequest\x12\x17\n" + + "\anode_id\x18\x01 \x01(\x04R\x06nodeId\x12\x16\n" + + "\x06routes\x18\x02 \x03(\tR\x06routes\"C\n" + + "\x19SetApprovedRoutesResponse\x12&\n" + + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\",\n" + + "\x11DeleteNodeRequest\x12\x17\n" + + "\anode_id\x18\x01 \x01(\x04R\x06nodeId\"\x14\n" + + "\x12DeleteNodeResponse\"`\n" + + "\x11ExpireNodeRequest\x12\x17\n" + + "\anode_id\x18\x01 \x01(\x04R\x06nodeId\x122\n" + + "\x06expiry\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\x06expiry\"<\n" + + "\x12ExpireNodeResponse\x12&\n" + + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"G\n" + + "\x11RenameNodeRequest\x12\x17\n" + + "\anode_id\x18\x01 \x01(\x04R\x06nodeId\x12\x19\n" + + "\bnew_name\x18\x02 \x01(\tR\anewName\"<\n" + + "\x12RenameNodeResponse\x12&\n" + + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"&\n" + + "\x10ListNodesRequest\x12\x12\n" + + "\x04user\x18\x01 \x01(\tR\x04user\"=\n" + + "\x11ListNodesResponse\x12(\n" + + "\x05nodes\x18\x01 \x03(\v2\x12.headscale.v1.NodeR\x05nodes\"j\n" + + "\x16DebugCreateNodeRequest\x12\x12\n" + + "\x04user\x18\x01 \x01(\tR\x04user\x12\x10\n" + + "\x03key\x18\x02 \x01(\tR\x03key\x12\x12\n" + + "\x04name\x18\x03 \x01(\tR\x04name\x12\x16\n" + + "\x06routes\x18\x04 \x03(\tR\x06routes\"A\n" + + "\x17DebugCreateNodeResponse\x12&\n" + + "\x04node\x18\x01 \x01(\v2\x12.headscale.v1.NodeR\x04node\"6\n" + + "\x16BackfillNodeIPsRequest\x12\x1c\n" + + "\tconfirmed\x18\x01 \x01(\bR\tconfirmed\"3\n" + + "\x17BackfillNodeIPsResponse\x12\x18\n" + + "\achanges\x18\x01 \x03(\tR\achanges*\x82\x01\n" + + "\x0eRegisterMethod\x12\x1f\n" + + "\x1bREGISTER_METHOD_UNSPECIFIED\x10\x00\x12\x1c\n" + + "\x18REGISTER_METHOD_AUTH_KEY\x10\x01\x12\x17\n" + + "\x13REGISTER_METHOD_CLI\x10\x02\x12\x18\n" + + "\x14REGISTER_METHOD_OIDC\x10\x03B)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" var ( file_headscale_v1_node_proto_rawDescOnce sync.Once - file_headscale_v1_node_proto_rawDescData = file_headscale_v1_node_proto_rawDesc + file_headscale_v1_node_proto_rawDescData []byte ) func file_headscale_v1_node_proto_rawDescGZIP() []byte { 
file_headscale_v1_node_proto_rawDescOnce.Do(func() { - file_headscale_v1_node_proto_rawDescData = protoimpl.X.CompressGZIP(file_headscale_v1_node_proto_rawDescData) + file_headscale_v1_node_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_node_proto_rawDesc), len(file_headscale_v1_node_proto_rawDesc))) }) return file_headscale_v1_node_proto_rawDescData } var file_headscale_v1_node_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_headscale_v1_node_proto_msgTypes = make([]protoimpl.MessageInfo, 19) -var file_headscale_v1_node_proto_goTypes = []interface{}{ - (RegisterMethod)(0), // 0: headscale.v1.RegisterMethod - (*Node)(nil), // 1: headscale.v1.Node - (*RegisterNodeRequest)(nil), // 2: headscale.v1.RegisterNodeRequest - (*RegisterNodeResponse)(nil), // 3: headscale.v1.RegisterNodeResponse - (*GetNodeRequest)(nil), // 4: headscale.v1.GetNodeRequest - (*GetNodeResponse)(nil), // 5: headscale.v1.GetNodeResponse - (*SetTagsRequest)(nil), // 6: headscale.v1.SetTagsRequest - (*SetTagsResponse)(nil), // 7: headscale.v1.SetTagsResponse - (*DeleteNodeRequest)(nil), // 8: headscale.v1.DeleteNodeRequest - (*DeleteNodeResponse)(nil), // 9: headscale.v1.DeleteNodeResponse - (*ExpireNodeRequest)(nil), // 10: headscale.v1.ExpireNodeRequest - (*ExpireNodeResponse)(nil), // 11: headscale.v1.ExpireNodeResponse - (*RenameNodeRequest)(nil), // 12: headscale.v1.RenameNodeRequest - (*RenameNodeResponse)(nil), // 13: headscale.v1.RenameNodeResponse - (*ListNodesRequest)(nil), // 14: headscale.v1.ListNodesRequest - (*ListNodesResponse)(nil), // 15: headscale.v1.ListNodesResponse - (*MoveNodeRequest)(nil), // 16: headscale.v1.MoveNodeRequest - (*MoveNodeResponse)(nil), // 17: headscale.v1.MoveNodeResponse - (*DebugCreateNodeRequest)(nil), // 18: headscale.v1.DebugCreateNodeRequest - (*DebugCreateNodeResponse)(nil), // 19: headscale.v1.DebugCreateNodeResponse - (*User)(nil), // 20: headscale.v1.User - (*timestamppb.Timestamp)(nil), // 21: google.protobuf.Timestamp - (*PreAuthKey)(nil), // 22: headscale.v1.PreAuthKey +var file_headscale_v1_node_proto_msgTypes = make([]protoimpl.MessageInfo, 21) +var file_headscale_v1_node_proto_goTypes = []any{ + (RegisterMethod)(0), // 0: headscale.v1.RegisterMethod + (*Node)(nil), // 1: headscale.v1.Node + (*RegisterNodeRequest)(nil), // 2: headscale.v1.RegisterNodeRequest + (*RegisterNodeResponse)(nil), // 3: headscale.v1.RegisterNodeResponse + (*GetNodeRequest)(nil), // 4: headscale.v1.GetNodeRequest + (*GetNodeResponse)(nil), // 5: headscale.v1.GetNodeResponse + (*SetTagsRequest)(nil), // 6: headscale.v1.SetTagsRequest + (*SetTagsResponse)(nil), // 7: headscale.v1.SetTagsResponse + (*SetApprovedRoutesRequest)(nil), // 8: headscale.v1.SetApprovedRoutesRequest + (*SetApprovedRoutesResponse)(nil), // 9: headscale.v1.SetApprovedRoutesResponse + (*DeleteNodeRequest)(nil), // 10: headscale.v1.DeleteNodeRequest + (*DeleteNodeResponse)(nil), // 11: headscale.v1.DeleteNodeResponse + (*ExpireNodeRequest)(nil), // 12: headscale.v1.ExpireNodeRequest + (*ExpireNodeResponse)(nil), // 13: headscale.v1.ExpireNodeResponse + (*RenameNodeRequest)(nil), // 14: headscale.v1.RenameNodeRequest + (*RenameNodeResponse)(nil), // 15: headscale.v1.RenameNodeResponse + (*ListNodesRequest)(nil), // 16: headscale.v1.ListNodesRequest + (*ListNodesResponse)(nil), // 17: headscale.v1.ListNodesResponse + (*DebugCreateNodeRequest)(nil), // 18: headscale.v1.DebugCreateNodeRequest + (*DebugCreateNodeResponse)(nil), // 19: headscale.v1.DebugCreateNodeResponse + 
(*BackfillNodeIPsRequest)(nil), // 20: headscale.v1.BackfillNodeIPsRequest + (*BackfillNodeIPsResponse)(nil), // 21: headscale.v1.BackfillNodeIPsResponse + (*User)(nil), // 22: headscale.v1.User + (*timestamppb.Timestamp)(nil), // 23: google.protobuf.Timestamp + (*PreAuthKey)(nil), // 24: headscale.v1.PreAuthKey } var file_headscale_v1_node_proto_depIdxs = []int32{ - 20, // 0: headscale.v1.Node.user:type_name -> headscale.v1.User - 21, // 1: headscale.v1.Node.last_seen:type_name -> google.protobuf.Timestamp - 21, // 2: headscale.v1.Node.last_successful_update:type_name -> google.protobuf.Timestamp - 21, // 3: headscale.v1.Node.expiry:type_name -> google.protobuf.Timestamp - 22, // 4: headscale.v1.Node.pre_auth_key:type_name -> headscale.v1.PreAuthKey - 21, // 5: headscale.v1.Node.created_at:type_name -> google.protobuf.Timestamp - 0, // 6: headscale.v1.Node.register_method:type_name -> headscale.v1.RegisterMethod - 1, // 7: headscale.v1.RegisterNodeResponse.node:type_name -> headscale.v1.Node - 1, // 8: headscale.v1.GetNodeResponse.node:type_name -> headscale.v1.Node - 1, // 9: headscale.v1.SetTagsResponse.node:type_name -> headscale.v1.Node - 1, // 10: headscale.v1.ExpireNodeResponse.node:type_name -> headscale.v1.Node - 1, // 11: headscale.v1.RenameNodeResponse.node:type_name -> headscale.v1.Node - 1, // 12: headscale.v1.ListNodesResponse.nodes:type_name -> headscale.v1.Node - 1, // 13: headscale.v1.MoveNodeResponse.node:type_name -> headscale.v1.Node + 22, // 0: headscale.v1.Node.user:type_name -> headscale.v1.User + 23, // 1: headscale.v1.Node.last_seen:type_name -> google.protobuf.Timestamp + 23, // 2: headscale.v1.Node.expiry:type_name -> google.protobuf.Timestamp + 24, // 3: headscale.v1.Node.pre_auth_key:type_name -> headscale.v1.PreAuthKey + 23, // 4: headscale.v1.Node.created_at:type_name -> google.protobuf.Timestamp + 0, // 5: headscale.v1.Node.register_method:type_name -> headscale.v1.RegisterMethod + 1, // 6: headscale.v1.RegisterNodeResponse.node:type_name -> headscale.v1.Node + 1, // 7: headscale.v1.GetNodeResponse.node:type_name -> headscale.v1.Node + 1, // 8: headscale.v1.SetTagsResponse.node:type_name -> headscale.v1.Node + 1, // 9: headscale.v1.SetApprovedRoutesResponse.node:type_name -> headscale.v1.Node + 23, // 10: headscale.v1.ExpireNodeRequest.expiry:type_name -> google.protobuf.Timestamp + 1, // 11: headscale.v1.ExpireNodeResponse.node:type_name -> headscale.v1.Node + 1, // 12: headscale.v1.RenameNodeResponse.node:type_name -> headscale.v1.Node + 1, // 13: headscale.v1.ListNodesResponse.nodes:type_name -> headscale.v1.Node 1, // 14: headscale.v1.DebugCreateNodeResponse.node:type_name -> headscale.v1.Node 15, // [15:15] is the sub-list for method output_type 15, // [15:15] is the sub-list for method input_type @@ -1354,245 +1346,15 @@ func file_headscale_v1_node_proto_init() { if File_headscale_v1_node_proto != nil { return } - file_headscale_v1_user_proto_init() file_headscale_v1_preauthkey_proto_init() - if !protoimpl.UnsafeEnabled { - file_headscale_v1_node_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Node); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RegisterNodeRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - 
file_headscale_v1_node_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RegisterNodeResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetNodeRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetNodeResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SetTagsRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*SetTagsResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeleteNodeRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeleteNodeResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExpireNodeRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExpireNodeResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[11].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RenameNodeRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[12].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RenameNodeResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[13].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListNodesRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[14].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListNodesResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[15].Exporter = func(v interface{}, i int) interface{} { - switch v := 
v.(*MoveNodeRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[16].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*MoveNodeResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[17].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DebugCreateNodeRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_node_proto_msgTypes[18].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DebugCreateNodeResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } + file_headscale_v1_user_proto_init() type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_headscale_v1_node_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_node_proto_rawDesc), len(file_headscale_v1_node_proto_rawDesc)), NumEnums: 1, - NumMessages: 19, + NumMessages: 21, NumExtensions: 0, NumServices: 0, }, @@ -1602,7 +1364,6 @@ func file_headscale_v1_node_proto_init() { MessageInfos: file_headscale_v1_node_proto_msgTypes, }.Build() File_headscale_v1_node_proto = out.File - file_headscale_v1_node_proto_rawDesc = nil file_headscale_v1_node_proto_goTypes = nil file_headscale_v1_node_proto_depIdxs = nil } diff --git a/gen/go/headscale/v1/policy.pb.go b/gen/go/headscale/v1/policy.pb.go new file mode 100644 index 00000000..faa3fc40 --- /dev/null +++ b/gen/go/headscale/v1/policy.pb.go @@ -0,0 +1,278 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc (unknown) +// source: headscale/v1/policy.proto + +package v1 + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + timestamppb "google.golang.org/protobuf/types/known/timestamppb" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type SetPolicyRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Policy string `protobuf:"bytes,1,opt,name=policy,proto3" json:"policy,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetPolicyRequest) Reset() { + *x = SetPolicyRequest{} + mi := &file_headscale_v1_policy_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetPolicyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetPolicyRequest) ProtoMessage() {} + +func (x *SetPolicyRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_policy_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetPolicyRequest.ProtoReflect.Descriptor instead. 
+func (*SetPolicyRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_policy_proto_rawDescGZIP(), []int{0} +} + +func (x *SetPolicyRequest) GetPolicy() string { + if x != nil { + return x.Policy + } + return "" +} + +type SetPolicyResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Policy string `protobuf:"bytes,1,opt,name=policy,proto3" json:"policy,omitempty"` + UpdatedAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=updated_at,json=updatedAt,proto3" json:"updated_at,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetPolicyResponse) Reset() { + *x = SetPolicyResponse{} + mi := &file_headscale_v1_policy_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetPolicyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetPolicyResponse) ProtoMessage() {} + +func (x *SetPolicyResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_policy_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetPolicyResponse.ProtoReflect.Descriptor instead. +func (*SetPolicyResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_policy_proto_rawDescGZIP(), []int{1} +} + +func (x *SetPolicyResponse) GetPolicy() string { + if x != nil { + return x.Policy + } + return "" +} + +func (x *SetPolicyResponse) GetUpdatedAt() *timestamppb.Timestamp { + if x != nil { + return x.UpdatedAt + } + return nil +} + +type GetPolicyRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetPolicyRequest) Reset() { + *x = GetPolicyRequest{} + mi := &file_headscale_v1_policy_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetPolicyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetPolicyRequest) ProtoMessage() {} + +func (x *GetPolicyRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_policy_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetPolicyRequest.ProtoReflect.Descriptor instead. 
+func (*GetPolicyRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_policy_proto_rawDescGZIP(), []int{2} +} + +type GetPolicyResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Policy string `protobuf:"bytes,1,opt,name=policy,proto3" json:"policy,omitempty"` + UpdatedAt *timestamppb.Timestamp `protobuf:"bytes,2,opt,name=updated_at,json=updatedAt,proto3" json:"updated_at,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *GetPolicyResponse) Reset() { + *x = GetPolicyResponse{} + mi := &file_headscale_v1_policy_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *GetPolicyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*GetPolicyResponse) ProtoMessage() {} + +func (x *GetPolicyResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_policy_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use GetPolicyResponse.ProtoReflect.Descriptor instead. +func (*GetPolicyResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_policy_proto_rawDescGZIP(), []int{3} +} + +func (x *GetPolicyResponse) GetPolicy() string { + if x != nil { + return x.Policy + } + return "" +} + +func (x *GetPolicyResponse) GetUpdatedAt() *timestamppb.Timestamp { + if x != nil { + return x.UpdatedAt + } + return nil +} + +var File_headscale_v1_policy_proto protoreflect.FileDescriptor + +const file_headscale_v1_policy_proto_rawDesc = "" + + "\n" + + "\x19headscale/v1/policy.proto\x12\fheadscale.v1\x1a\x1fgoogle/protobuf/timestamp.proto\"*\n" + + "\x10SetPolicyRequest\x12\x16\n" + + "\x06policy\x18\x01 \x01(\tR\x06policy\"f\n" + + "\x11SetPolicyResponse\x12\x16\n" + + "\x06policy\x18\x01 \x01(\tR\x06policy\x129\n" + + "\n" + + "updated_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\tupdatedAt\"\x12\n" + + "\x10GetPolicyRequest\"f\n" + + "\x11GetPolicyResponse\x12\x16\n" + + "\x06policy\x18\x01 \x01(\tR\x06policy\x129\n" + + "\n" + + "updated_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\tupdatedAtB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" + +var ( + file_headscale_v1_policy_proto_rawDescOnce sync.Once + file_headscale_v1_policy_proto_rawDescData []byte +) + +func file_headscale_v1_policy_proto_rawDescGZIP() []byte { + file_headscale_v1_policy_proto_rawDescOnce.Do(func() { + file_headscale_v1_policy_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_policy_proto_rawDesc), len(file_headscale_v1_policy_proto_rawDesc))) + }) + return file_headscale_v1_policy_proto_rawDescData +} + +var file_headscale_v1_policy_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_headscale_v1_policy_proto_goTypes = []any{ + (*SetPolicyRequest)(nil), // 0: headscale.v1.SetPolicyRequest + (*SetPolicyResponse)(nil), // 1: headscale.v1.SetPolicyResponse + (*GetPolicyRequest)(nil), // 2: headscale.v1.GetPolicyRequest + (*GetPolicyResponse)(nil), // 3: headscale.v1.GetPolicyResponse + (*timestamppb.Timestamp)(nil), // 4: google.protobuf.Timestamp +} +var file_headscale_v1_policy_proto_depIdxs = []int32{ + 4, // 0: headscale.v1.SetPolicyResponse.updated_at:type_name -> google.protobuf.Timestamp + 4, // 1: headscale.v1.GetPolicyResponse.updated_at:type_name -> google.protobuf.Timestamp + 2, // [2:2] is the sub-list for 
method output_type + 2, // [2:2] is the sub-list for method input_type + 2, // [2:2] is the sub-list for extension type_name + 2, // [2:2] is the sub-list for extension extendee + 0, // [0:2] is the sub-list for field type_name +} + +func init() { file_headscale_v1_policy_proto_init() } +func file_headscale_v1_policy_proto_init() { + if File_headscale_v1_policy_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_policy_proto_rawDesc), len(file_headscale_v1_policy_proto_rawDesc)), + NumEnums: 0, + NumMessages: 4, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_headscale_v1_policy_proto_goTypes, + DependencyIndexes: file_headscale_v1_policy_proto_depIdxs, + MessageInfos: file_headscale_v1_policy_proto_msgTypes, + }.Build() + File_headscale_v1_policy_proto = out.File + file_headscale_v1_policy_proto_goTypes = nil + file_headscale_v1_policy_proto_depIdxs = nil +} diff --git a/gen/go/headscale/v1/preauthkey.pb.go b/gen/go/headscale/v1/preauthkey.pb.go index 856377f2..ff902d45 100644 --- a/gen/go/headscale/v1/preauthkey.pb.go +++ b/gen/go/headscale/v1/preauthkey.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.31.0 +// protoc-gen-go v1.36.11 // protoc (unknown) // source: headscale/v1/preauthkey.proto @@ -12,6 +12,7 @@ import ( timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -22,28 +23,25 @@ const ( ) type PreAuthKey struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + User *User `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` + Id uint64 `protobuf:"varint,2,opt,name=id,proto3" json:"id,omitempty"` + Key string `protobuf:"bytes,3,opt,name=key,proto3" json:"key,omitempty"` + Reusable bool `protobuf:"varint,4,opt,name=reusable,proto3" json:"reusable,omitempty"` + Ephemeral bool `protobuf:"varint,5,opt,name=ephemeral,proto3" json:"ephemeral,omitempty"` + Used bool `protobuf:"varint,6,opt,name=used,proto3" json:"used,omitempty"` + Expiration *timestamppb.Timestamp `protobuf:"bytes,7,opt,name=expiration,proto3" json:"expiration,omitempty"` + CreatedAt *timestamppb.Timestamp `protobuf:"bytes,8,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + AclTags []string `protobuf:"bytes,9,rep,name=acl_tags,json=aclTags,proto3" json:"acl_tags,omitempty"` unknownFields protoimpl.UnknownFields - - User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` - Id string `protobuf:"bytes,2,opt,name=id,proto3" json:"id,omitempty"` - Key string `protobuf:"bytes,3,opt,name=key,proto3" json:"key,omitempty"` - Reusable bool `protobuf:"varint,4,opt,name=reusable,proto3" json:"reusable,omitempty"` - Ephemeral bool `protobuf:"varint,5,opt,name=ephemeral,proto3" json:"ephemeral,omitempty"` - Used bool `protobuf:"varint,6,opt,name=used,proto3" json:"used,omitempty"` - Expiration *timestamppb.Timestamp `protobuf:"bytes,7,opt,name=expiration,proto3" json:"expiration,omitempty"` - CreatedAt *timestamppb.Timestamp `protobuf:"bytes,8,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` - AclTags []string `protobuf:"bytes,9,rep,name=acl_tags,json=aclTags,proto3" json:"acl_tags,omitempty"` + sizeCache protoimpl.SizeCache } func (x *PreAuthKey) Reset() { *x = PreAuthKey{} - if 
protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_preauthkey_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_preauthkey_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *PreAuthKey) String() string { @@ -54,7 +52,7 @@ func (*PreAuthKey) ProtoMessage() {} func (x *PreAuthKey) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_preauthkey_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -69,18 +67,18 @@ func (*PreAuthKey) Descriptor() ([]byte, []int) { return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{0} } -func (x *PreAuthKey) GetUser() string { +func (x *PreAuthKey) GetUser() *User { if x != nil { return x.User } - return "" + return nil } -func (x *PreAuthKey) GetId() string { +func (x *PreAuthKey) GetId() uint64 { if x != nil { return x.Id } - return "" + return 0 } func (x *PreAuthKey) GetKey() string { @@ -133,24 +131,21 @@ func (x *PreAuthKey) GetAclTags() []string { } type CreatePreAuthKeyRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + User uint64 `protobuf:"varint,1,opt,name=user,proto3" json:"user,omitempty"` + Reusable bool `protobuf:"varint,2,opt,name=reusable,proto3" json:"reusable,omitempty"` + Ephemeral bool `protobuf:"varint,3,opt,name=ephemeral,proto3" json:"ephemeral,omitempty"` + Expiration *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=expiration,proto3" json:"expiration,omitempty"` + AclTags []string `protobuf:"bytes,5,rep,name=acl_tags,json=aclTags,proto3" json:"acl_tags,omitempty"` unknownFields protoimpl.UnknownFields - - User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` - Reusable bool `protobuf:"varint,2,opt,name=reusable,proto3" json:"reusable,omitempty"` - Ephemeral bool `protobuf:"varint,3,opt,name=ephemeral,proto3" json:"ephemeral,omitempty"` - Expiration *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=expiration,proto3" json:"expiration,omitempty"` - AclTags []string `protobuf:"bytes,5,rep,name=acl_tags,json=aclTags,proto3" json:"acl_tags,omitempty"` + sizeCache protoimpl.SizeCache } func (x *CreatePreAuthKeyRequest) Reset() { *x = CreatePreAuthKeyRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_preauthkey_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_preauthkey_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *CreatePreAuthKeyRequest) String() string { @@ -161,7 +156,7 @@ func (*CreatePreAuthKeyRequest) ProtoMessage() {} func (x *CreatePreAuthKeyRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_preauthkey_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -176,11 +171,11 @@ func (*CreatePreAuthKeyRequest) Descriptor() ([]byte, []int) { return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{1} } -func (x *CreatePreAuthKeyRequest) GetUser() string { +func (x *CreatePreAuthKeyRequest) GetUser() uint64 { if x != nil { return x.User } - return "" + return 0 } func (x *CreatePreAuthKeyRequest) GetReusable() 
bool { @@ -212,20 +207,17 @@ func (x *CreatePreAuthKeyRequest) GetAclTags() []string { } type CreatePreAuthKeyResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + PreAuthKey *PreAuthKey `protobuf:"bytes,1,opt,name=pre_auth_key,json=preAuthKey,proto3" json:"pre_auth_key,omitempty"` unknownFields protoimpl.UnknownFields - - PreAuthKey *PreAuthKey `protobuf:"bytes,1,opt,name=pre_auth_key,json=preAuthKey,proto3" json:"pre_auth_key,omitempty"` + sizeCache protoimpl.SizeCache } func (x *CreatePreAuthKeyResponse) Reset() { *x = CreatePreAuthKeyResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_preauthkey_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_preauthkey_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *CreatePreAuthKeyResponse) String() string { @@ -236,7 +228,7 @@ func (*CreatePreAuthKeyResponse) ProtoMessage() {} func (x *CreatePreAuthKeyResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_preauthkey_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -259,21 +251,17 @@ func (x *CreatePreAuthKeyResponse) GetPreAuthKey() *PreAuthKey { } type ExpirePreAuthKeyRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` unknownFields protoimpl.UnknownFields - - User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` - Key string `protobuf:"bytes,2,opt,name=key,proto3" json:"key,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ExpirePreAuthKeyRequest) Reset() { *x = ExpirePreAuthKeyRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_preauthkey_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_preauthkey_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ExpirePreAuthKeyRequest) String() string { @@ -284,7 +272,7 @@ func (*ExpirePreAuthKeyRequest) ProtoMessage() {} func (x *ExpirePreAuthKeyRequest) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_preauthkey_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -299,33 +287,24 @@ func (*ExpirePreAuthKeyRequest) Descriptor() ([]byte, []int) { return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{3} } -func (x *ExpirePreAuthKeyRequest) GetUser() string { +func (x *ExpirePreAuthKeyRequest) GetId() uint64 { if x != nil { - return x.User + return x.Id } - return "" -} - -func (x *ExpirePreAuthKeyRequest) GetKey() string { - if x != nil { - return x.Key - } - return "" + return 0 } type ExpirePreAuthKeyResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ExpirePreAuthKeyResponse) Reset() { *x = ExpirePreAuthKeyResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_preauthkey_proto_msgTypes[4] - ms 
:= protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_preauthkey_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ExpirePreAuthKeyResponse) String() string { @@ -336,7 +315,7 @@ func (*ExpirePreAuthKeyResponse) ProtoMessage() {} func (x *ExpirePreAuthKeyResponse) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_preauthkey_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -351,21 +330,97 @@ func (*ExpirePreAuthKeyResponse) Descriptor() ([]byte, []int) { return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{4} } -type ListPreAuthKeysRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache +type DeletePreAuthKeyRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} - User string `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` +func (x *DeletePreAuthKeyRequest) Reset() { + *x = DeletePreAuthKeyRequest{} + mi := &file_headscale_v1_preauthkey_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeletePreAuthKeyRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeletePreAuthKeyRequest) ProtoMessage() {} + +func (x *DeletePreAuthKeyRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_preauthkey_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeletePreAuthKeyRequest.ProtoReflect.Descriptor instead. +func (*DeletePreAuthKeyRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{5} +} + +func (x *DeletePreAuthKeyRequest) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +type DeletePreAuthKeyResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DeletePreAuthKeyResponse) Reset() { + *x = DeletePreAuthKeyResponse{} + mi := &file_headscale_v1_preauthkey_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DeletePreAuthKeyResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DeletePreAuthKeyResponse) ProtoMessage() {} + +func (x *DeletePreAuthKeyResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_preauthkey_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DeletePreAuthKeyResponse.ProtoReflect.Descriptor instead. 
+func (*DeletePreAuthKeyResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{6} +} + +type ListPreAuthKeysRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ListPreAuthKeysRequest) Reset() { *x = ListPreAuthKeysRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_preauthkey_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_preauthkey_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ListPreAuthKeysRequest) String() string { @@ -375,8 +430,8 @@ func (x *ListPreAuthKeysRequest) String() string { func (*ListPreAuthKeysRequest) ProtoMessage() {} func (x *ListPreAuthKeysRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_preauthkey_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_preauthkey_proto_msgTypes[7] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -388,31 +443,21 @@ func (x *ListPreAuthKeysRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListPreAuthKeysRequest.ProtoReflect.Descriptor instead. func (*ListPreAuthKeysRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{5} -} - -func (x *ListPreAuthKeysRequest) GetUser() string { - if x != nil { - return x.User - } - return "" + return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{7} } type ListPreAuthKeysResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + PreAuthKeys []*PreAuthKey `protobuf:"bytes,1,rep,name=pre_auth_keys,json=preAuthKeys,proto3" json:"pre_auth_keys,omitempty"` unknownFields protoimpl.UnknownFields - - PreAuthKeys []*PreAuthKey `protobuf:"bytes,1,rep,name=pre_auth_keys,json=preAuthKeys,proto3" json:"pre_auth_keys,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ListPreAuthKeysResponse) Reset() { *x = ListPreAuthKeysResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_preauthkey_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_preauthkey_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ListPreAuthKeysResponse) String() string { @@ -422,8 +467,8 @@ func (x *ListPreAuthKeysResponse) String() string { func (*ListPreAuthKeysResponse) ProtoMessage() {} func (x *ListPreAuthKeysResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_preauthkey_proto_msgTypes[6] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_preauthkey_proto_msgTypes[8] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -435,7 +480,7 @@ func (x *ListPreAuthKeysResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListPreAuthKeysResponse.ProtoReflect.Descriptor instead. 
func (*ListPreAuthKeysResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{6} + return file_headscale_v1_preauthkey_proto_rawDescGZIP(), []int{8} } func (x *ListPreAuthKeysResponse) GetPreAuthKeys() []*PreAuthKey { @@ -447,102 +492,82 @@ func (x *ListPreAuthKeysResponse) GetPreAuthKeys() []*PreAuthKey { var File_headscale_v1_preauthkey_proto protoreflect.FileDescriptor -var file_headscale_v1_preauthkey_proto_rawDesc = []byte{ - 0x0a, 0x1d, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x70, - 0x72, 0x65, 0x61, 0x75, 0x74, 0x68, 0x6b, 0x65, 0x79, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, - 0x0c, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x1a, 0x1f, 0x67, - 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, - 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa2, - 0x02, 0x0a, 0x0a, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x12, 0x12, 0x0a, - 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, - 0x72, 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, - 0x64, 0x12, 0x10, 0x0a, 0x03, 0x6b, 0x65, 0x79, 0x18, 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, - 0x6b, 0x65, 0x79, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x75, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x18, - 0x04, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x72, 0x65, 0x75, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x12, - 0x1c, 0x0a, 0x09, 0x65, 0x70, 0x68, 0x65, 0x6d, 0x65, 0x72, 0x61, 0x6c, 0x18, 0x05, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x09, 0x65, 0x70, 0x68, 0x65, 0x6d, 0x65, 0x72, 0x61, 0x6c, 0x12, 0x12, 0x0a, - 0x04, 0x75, 0x73, 0x65, 0x64, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x04, 0x75, 0x73, 0x65, - 0x64, 0x12, 0x3a, 0x0a, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, - 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, - 0x70, 0x52, 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x39, 0x0a, - 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x08, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, - 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x61, 0x63, 0x6c, 0x5f, - 0x74, 0x61, 0x67, 0x73, 0x18, 0x09, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x61, 0x63, 0x6c, 0x54, - 0x61, 0x67, 0x73, 0x22, 0xbe, 0x01, 0x0a, 0x17, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, - 0x73, 0x65, 0x72, 0x12, 0x1a, 0x0a, 0x08, 0x72, 0x65, 0x75, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x18, - 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x08, 0x72, 0x65, 0x75, 0x73, 0x61, 0x62, 0x6c, 0x65, 0x12, - 0x1c, 0x0a, 0x09, 0x65, 0x70, 0x68, 0x65, 0x6d, 0x65, 0x72, 0x61, 0x6c, 0x18, 0x03, 0x20, 0x01, - 0x28, 0x08, 0x52, 0x09, 0x65, 0x70, 0x68, 0x65, 0x6d, 0x65, 0x72, 0x61, 0x6c, 0x12, 0x3a, 0x0a, - 0x0a, 0x65, 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x0b, 0x32, 0x1a, 0x2e, 0x67, 
0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x0a, 0x65, - 0x78, 0x70, 0x69, 0x72, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x12, 0x19, 0x0a, 0x08, 0x61, 0x63, 0x6c, - 0x5f, 0x74, 0x61, 0x67, 0x73, 0x18, 0x05, 0x20, 0x03, 0x28, 0x09, 0x52, 0x07, 0x61, 0x63, 0x6c, - 0x54, 0x61, 0x67, 0x73, 0x22, 0x56, 0x0a, 0x18, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x50, 0x72, - 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x3a, 0x0a, 0x0c, 0x70, 0x72, 0x65, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6b, 0x65, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, - 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, - 0x52, 0x0a, 0x70, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x22, 0x3f, 0x0a, 0x17, - 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x12, 0x10, 0x0a, 0x03, 0x6b, - 0x65, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x03, 0x6b, 0x65, 0x79, 0x22, 0x1a, 0x0a, - 0x18, 0x45, 0x78, 0x70, 0x69, 0x72, 0x65, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, - 0x79, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x2c, 0x0a, 0x16, 0x4c, 0x69, 0x73, - 0x74, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x22, 0x57, 0x0a, 0x17, 0x4c, 0x69, 0x73, 0x74, 0x50, - 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, - 0x73, 0x65, 0x12, 0x3c, 0x0a, 0x0d, 0x70, 0x72, 0x65, 0x5f, 0x61, 0x75, 0x74, 0x68, 0x5f, 0x6b, - 0x65, 0x79, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x18, 0x2e, 0x68, 0x65, 0x61, 0x64, - 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x50, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, - 0x4b, 0x65, 0x79, 0x52, 0x0b, 0x70, 0x72, 0x65, 0x41, 0x75, 0x74, 0x68, 0x4b, 0x65, 0x79, 0x73, - 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6a, - 0x75, 0x61, 0x6e, 0x66, 0x6f, 0x6e, 0x74, 0x2f, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, - 0x65, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x67, 0x6f, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, - 0x74, 0x6f, 0x33, -} +const file_headscale_v1_preauthkey_proto_rawDesc = "" + + "\n" + + "\x1dheadscale/v1/preauthkey.proto\x12\fheadscale.v1\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x17headscale/v1/user.proto\"\xb6\x02\n" + + "\n" + + "PreAuthKey\x12&\n" + + "\x04user\x18\x01 \x01(\v2\x12.headscale.v1.UserR\x04user\x12\x0e\n" + + "\x02id\x18\x02 \x01(\x04R\x02id\x12\x10\n" + + "\x03key\x18\x03 \x01(\tR\x03key\x12\x1a\n" + + "\breusable\x18\x04 \x01(\bR\breusable\x12\x1c\n" + + "\tephemeral\x18\x05 \x01(\bR\tephemeral\x12\x12\n" + + "\x04used\x18\x06 \x01(\bR\x04used\x12:\n" + + "\n" + + "expiration\x18\a \x01(\v2\x1a.google.protobuf.TimestampR\n" + + "expiration\x129\n" + + "\n" + + "created_at\x18\b \x01(\v2\x1a.google.protobuf.TimestampR\tcreatedAt\x12\x19\n" + + "\bacl_tags\x18\t \x03(\tR\aaclTags\"\xbe\x01\n" + + "\x17CreatePreAuthKeyRequest\x12\x12\n" + + "\x04user\x18\x01 \x01(\x04R\x04user\x12\x1a\n" + + 
"\breusable\x18\x02 \x01(\bR\breusable\x12\x1c\n" + + "\tephemeral\x18\x03 \x01(\bR\tephemeral\x12:\n" + + "\n" + + "expiration\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\n" + + "expiration\x12\x19\n" + + "\bacl_tags\x18\x05 \x03(\tR\aaclTags\"V\n" + + "\x18CreatePreAuthKeyResponse\x12:\n" + + "\fpre_auth_key\x18\x01 \x01(\v2\x18.headscale.v1.PreAuthKeyR\n" + + "preAuthKey\")\n" + + "\x17ExpirePreAuthKeyRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\"\x1a\n" + + "\x18ExpirePreAuthKeyResponse\")\n" + + "\x17DeletePreAuthKeyRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\"\x1a\n" + + "\x18DeletePreAuthKeyResponse\"\x18\n" + + "\x16ListPreAuthKeysRequest\"W\n" + + "\x17ListPreAuthKeysResponse\x12<\n" + + "\rpre_auth_keys\x18\x01 \x03(\v2\x18.headscale.v1.PreAuthKeyR\vpreAuthKeysB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" var ( file_headscale_v1_preauthkey_proto_rawDescOnce sync.Once - file_headscale_v1_preauthkey_proto_rawDescData = file_headscale_v1_preauthkey_proto_rawDesc + file_headscale_v1_preauthkey_proto_rawDescData []byte ) func file_headscale_v1_preauthkey_proto_rawDescGZIP() []byte { file_headscale_v1_preauthkey_proto_rawDescOnce.Do(func() { - file_headscale_v1_preauthkey_proto_rawDescData = protoimpl.X.CompressGZIP(file_headscale_v1_preauthkey_proto_rawDescData) + file_headscale_v1_preauthkey_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_preauthkey_proto_rawDesc), len(file_headscale_v1_preauthkey_proto_rawDesc))) }) return file_headscale_v1_preauthkey_proto_rawDescData } -var file_headscale_v1_preauthkey_proto_msgTypes = make([]protoimpl.MessageInfo, 7) -var file_headscale_v1_preauthkey_proto_goTypes = []interface{}{ +var file_headscale_v1_preauthkey_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_headscale_v1_preauthkey_proto_goTypes = []any{ (*PreAuthKey)(nil), // 0: headscale.v1.PreAuthKey (*CreatePreAuthKeyRequest)(nil), // 1: headscale.v1.CreatePreAuthKeyRequest (*CreatePreAuthKeyResponse)(nil), // 2: headscale.v1.CreatePreAuthKeyResponse (*ExpirePreAuthKeyRequest)(nil), // 3: headscale.v1.ExpirePreAuthKeyRequest (*ExpirePreAuthKeyResponse)(nil), // 4: headscale.v1.ExpirePreAuthKeyResponse - (*ListPreAuthKeysRequest)(nil), // 5: headscale.v1.ListPreAuthKeysRequest - (*ListPreAuthKeysResponse)(nil), // 6: headscale.v1.ListPreAuthKeysResponse - (*timestamppb.Timestamp)(nil), // 7: google.protobuf.Timestamp + (*DeletePreAuthKeyRequest)(nil), // 5: headscale.v1.DeletePreAuthKeyRequest + (*DeletePreAuthKeyResponse)(nil), // 6: headscale.v1.DeletePreAuthKeyResponse + (*ListPreAuthKeysRequest)(nil), // 7: headscale.v1.ListPreAuthKeysRequest + (*ListPreAuthKeysResponse)(nil), // 8: headscale.v1.ListPreAuthKeysResponse + (*User)(nil), // 9: headscale.v1.User + (*timestamppb.Timestamp)(nil), // 10: google.protobuf.Timestamp } var file_headscale_v1_preauthkey_proto_depIdxs = []int32{ - 7, // 0: headscale.v1.PreAuthKey.expiration:type_name -> google.protobuf.Timestamp - 7, // 1: headscale.v1.PreAuthKey.created_at:type_name -> google.protobuf.Timestamp - 7, // 2: headscale.v1.CreatePreAuthKeyRequest.expiration:type_name -> google.protobuf.Timestamp - 0, // 3: headscale.v1.CreatePreAuthKeyResponse.pre_auth_key:type_name -> headscale.v1.PreAuthKey - 0, // 4: headscale.v1.ListPreAuthKeysResponse.pre_auth_keys:type_name -> headscale.v1.PreAuthKey - 5, // [5:5] is the sub-list for method output_type - 5, // [5:5] is the sub-list for method input_type - 5, // [5:5] is the sub-list for extension 
type_name - 5, // [5:5] is the sub-list for extension extendee - 0, // [0:5] is the sub-list for field type_name + 9, // 0: headscale.v1.PreAuthKey.user:type_name -> headscale.v1.User + 10, // 1: headscale.v1.PreAuthKey.expiration:type_name -> google.protobuf.Timestamp + 10, // 2: headscale.v1.PreAuthKey.created_at:type_name -> google.protobuf.Timestamp + 10, // 3: headscale.v1.CreatePreAuthKeyRequest.expiration:type_name -> google.protobuf.Timestamp + 0, // 4: headscale.v1.CreatePreAuthKeyResponse.pre_auth_key:type_name -> headscale.v1.PreAuthKey + 0, // 5: headscale.v1.ListPreAuthKeysResponse.pre_auth_keys:type_name -> headscale.v1.PreAuthKey + 6, // [6:6] is the sub-list for method output_type + 6, // [6:6] is the sub-list for method input_type + 6, // [6:6] is the sub-list for extension type_name + 6, // [6:6] is the sub-list for extension extendee + 0, // [0:6] is the sub-list for field type_name } func init() { file_headscale_v1_preauthkey_proto_init() } @@ -550,99 +575,14 @@ func file_headscale_v1_preauthkey_proto_init() { if File_headscale_v1_preauthkey_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_headscale_v1_preauthkey_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*PreAuthKey); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_preauthkey_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreatePreAuthKeyRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_preauthkey_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreatePreAuthKeyResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_preauthkey_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExpirePreAuthKeyRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_preauthkey_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ExpirePreAuthKeyResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_preauthkey_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListPreAuthKeysRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_preauthkey_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListPreAuthKeysResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } + file_headscale_v1_user_proto_init() type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_headscale_v1_preauthkey_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_preauthkey_proto_rawDesc), len(file_headscale_v1_preauthkey_proto_rawDesc)), NumEnums: 0, - NumMessages: 7, + NumMessages: 9, NumExtensions: 0, NumServices: 0, }, @@ -651,7 +591,6 @@ func 
file_headscale_v1_preauthkey_proto_init() { MessageInfos: file_headscale_v1_preauthkey_proto_msgTypes, }.Build() File_headscale_v1_preauthkey_proto = out.File - file_headscale_v1_preauthkey_proto_rawDesc = nil file_headscale_v1_preauthkey_proto_goTypes = nil file_headscale_v1_preauthkey_proto_depIdxs = nil } diff --git a/gen/go/headscale/v1/routes.pb.go b/gen/go/headscale/v1/routes.pb.go deleted file mode 100644 index 051e92e9..00000000 --- a/gen/go/headscale/v1/routes.pb.go +++ /dev/null @@ -1,833 +0,0 @@ -// Code generated by protoc-gen-go. DO NOT EDIT. -// versions: -// protoc-gen-go v1.31.0 -// protoc (unknown) -// source: headscale/v1/routes.proto - -package v1 - -import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" - timestamppb "google.golang.org/protobuf/types/known/timestamppb" - reflect "reflect" - sync "sync" -) - -const ( - // Verify that this generated code is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) - // Verify that runtime/protoimpl is sufficiently up-to-date. - _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) -) - -type Route struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` - Node *Node `protobuf:"bytes,2,opt,name=node,proto3" json:"node,omitempty"` - Prefix string `protobuf:"bytes,3,opt,name=prefix,proto3" json:"prefix,omitempty"` - Advertised bool `protobuf:"varint,4,opt,name=advertised,proto3" json:"advertised,omitempty"` - Enabled bool `protobuf:"varint,5,opt,name=enabled,proto3" json:"enabled,omitempty"` - IsPrimary bool `protobuf:"varint,6,opt,name=is_primary,json=isPrimary,proto3" json:"is_primary,omitempty"` - CreatedAt *timestamppb.Timestamp `protobuf:"bytes,7,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` - UpdatedAt *timestamppb.Timestamp `protobuf:"bytes,8,opt,name=updated_at,json=updatedAt,proto3" json:"updated_at,omitempty"` - DeletedAt *timestamppb.Timestamp `protobuf:"bytes,9,opt,name=deleted_at,json=deletedAt,proto3" json:"deleted_at,omitempty"` -} - -func (x *Route) Reset() { - *x = Route{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *Route) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*Route) ProtoMessage() {} - -func (x *Route) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use Route.ProtoReflect.Descriptor instead. 
-func (*Route) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{0} -} - -func (x *Route) GetId() uint64 { - if x != nil { - return x.Id - } - return 0 -} - -func (x *Route) GetNode() *Node { - if x != nil { - return x.Node - } - return nil -} - -func (x *Route) GetPrefix() string { - if x != nil { - return x.Prefix - } - return "" -} - -func (x *Route) GetAdvertised() bool { - if x != nil { - return x.Advertised - } - return false -} - -func (x *Route) GetEnabled() bool { - if x != nil { - return x.Enabled - } - return false -} - -func (x *Route) GetIsPrimary() bool { - if x != nil { - return x.IsPrimary - } - return false -} - -func (x *Route) GetCreatedAt() *timestamppb.Timestamp { - if x != nil { - return x.CreatedAt - } - return nil -} - -func (x *Route) GetUpdatedAt() *timestamppb.Timestamp { - if x != nil { - return x.UpdatedAt - } - return nil -} - -func (x *Route) GetDeletedAt() *timestamppb.Timestamp { - if x != nil { - return x.DeletedAt - } - return nil -} - -type GetRoutesRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields -} - -func (x *GetRoutesRequest) Reset() { - *x = GetRoutesRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetRoutesRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetRoutesRequest) ProtoMessage() {} - -func (x *GetRoutesRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetRoutesRequest.ProtoReflect.Descriptor instead. -func (*GetRoutesRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{1} -} - -type GetRoutesResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` -} - -func (x *GetRoutesResponse) Reset() { - *x = GetRoutesResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetRoutesResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetRoutesResponse) ProtoMessage() {} - -func (x *GetRoutesResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetRoutesResponse.ProtoReflect.Descriptor instead. 
-func (*GetRoutesResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{2} -} - -func (x *GetRoutesResponse) GetRoutes() []*Route { - if x != nil { - return x.Routes - } - return nil -} - -type EnableRouteRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - RouteId uint64 `protobuf:"varint,1,opt,name=route_id,json=routeId,proto3" json:"route_id,omitempty"` -} - -func (x *EnableRouteRequest) Reset() { - *x = EnableRouteRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *EnableRouteRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*EnableRouteRequest) ProtoMessage() {} - -func (x *EnableRouteRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use EnableRouteRequest.ProtoReflect.Descriptor instead. -func (*EnableRouteRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{3} -} - -func (x *EnableRouteRequest) GetRouteId() uint64 { - if x != nil { - return x.RouteId - } - return 0 -} - -type EnableRouteResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields -} - -func (x *EnableRouteResponse) Reset() { - *x = EnableRouteResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *EnableRouteResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*EnableRouteResponse) ProtoMessage() {} - -func (x *EnableRouteResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use EnableRouteResponse.ProtoReflect.Descriptor instead. 
-func (*EnableRouteResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{4} -} - -type DisableRouteRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - RouteId uint64 `protobuf:"varint,1,opt,name=route_id,json=routeId,proto3" json:"route_id,omitempty"` -} - -func (x *DisableRouteRequest) Reset() { - *x = DisableRouteRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *DisableRouteRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*DisableRouteRequest) ProtoMessage() {} - -func (x *DisableRouteRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use DisableRouteRequest.ProtoReflect.Descriptor instead. -func (*DisableRouteRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{5} -} - -func (x *DisableRouteRequest) GetRouteId() uint64 { - if x != nil { - return x.RouteId - } - return 0 -} - -type DisableRouteResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields -} - -func (x *DisableRouteResponse) Reset() { - *x = DisableRouteResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *DisableRouteResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*DisableRouteResponse) ProtoMessage() {} - -func (x *DisableRouteResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[6] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use DisableRouteResponse.ProtoReflect.Descriptor instead. -func (*DisableRouteResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{6} -} - -type GetNodeRoutesRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - NodeId uint64 `protobuf:"varint,1,opt,name=node_id,json=nodeId,proto3" json:"node_id,omitempty"` -} - -func (x *GetNodeRoutesRequest) Reset() { - *x = GetNodeRoutesRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[7] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetNodeRoutesRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetNodeRoutesRequest) ProtoMessage() {} - -func (x *GetNodeRoutesRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[7] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetNodeRoutesRequest.ProtoReflect.Descriptor instead. 
-func (*GetNodeRoutesRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{7} -} - -func (x *GetNodeRoutesRequest) GetNodeId() uint64 { - if x != nil { - return x.NodeId - } - return 0 -} - -type GetNodeRoutesResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Routes []*Route `protobuf:"bytes,1,rep,name=routes,proto3" json:"routes,omitempty"` -} - -func (x *GetNodeRoutesResponse) Reset() { - *x = GetNodeRoutesResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[8] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetNodeRoutesResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetNodeRoutesResponse) ProtoMessage() {} - -func (x *GetNodeRoutesResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[8] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetNodeRoutesResponse.ProtoReflect.Descriptor instead. -func (*GetNodeRoutesResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{8} -} - -func (x *GetNodeRoutesResponse) GetRoutes() []*Route { - if x != nil { - return x.Routes - } - return nil -} - -type DeleteRouteRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - RouteId uint64 `protobuf:"varint,1,opt,name=route_id,json=routeId,proto3" json:"route_id,omitempty"` -} - -func (x *DeleteRouteRequest) Reset() { - *x = DeleteRouteRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[9] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *DeleteRouteRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*DeleteRouteRequest) ProtoMessage() {} - -func (x *DeleteRouteRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[9] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use DeleteRouteRequest.ProtoReflect.Descriptor instead. 
-func (*DeleteRouteRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{9} -} - -func (x *DeleteRouteRequest) GetRouteId() uint64 { - if x != nil { - return x.RouteId - } - return 0 -} - -type DeleteRouteResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields -} - -func (x *DeleteRouteResponse) Reset() { - *x = DeleteRouteResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_routes_proto_msgTypes[10] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *DeleteRouteResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*DeleteRouteResponse) ProtoMessage() {} - -func (x *DeleteRouteResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_routes_proto_msgTypes[10] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use DeleteRouteResponse.ProtoReflect.Descriptor instead. -func (*DeleteRouteResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_routes_proto_rawDescGZIP(), []int{10} -} - -var File_headscale_v1_routes_proto protoreflect.FileDescriptor - -var file_headscale_v1_routes_proto_rawDesc = []byte{ - 0x0a, 0x19, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x72, - 0x6f, 0x75, 0x74, 0x65, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x68, 0x65, 0x61, - 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x1a, 0x17, 0x68, 0x65, 0x61, 0x64, - 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x6e, 0x6f, 0x64, 0x65, 0x2e, 0x70, 0x72, - 0x6f, 0x74, 0x6f, 0x22, 0xe1, 0x02, 0x0a, 0x05, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x12, 0x0e, 0x0a, - 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x02, 0x69, 0x64, 0x12, 0x26, 0x0a, - 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, - 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4e, 0x6f, 0x64, 0x65, 0x52, - 0x04, 0x6e, 0x6f, 0x64, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x18, - 0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x06, 0x70, 0x72, 0x65, 0x66, 0x69, 0x78, 0x12, 0x1e, 0x0a, - 0x0a, 0x61, 0x64, 0x76, 0x65, 0x72, 0x74, 0x69, 0x73, 0x65, 0x64, 0x18, 0x04, 0x20, 0x01, 0x28, - 0x08, 0x52, 0x0a, 0x61, 0x64, 0x76, 0x65, 0x72, 0x74, 0x69, 0x73, 0x65, 0x64, 0x12, 0x18, 0x0a, - 0x07, 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x18, 0x05, 0x20, 0x01, 0x28, 0x08, 0x52, 0x07, - 0x65, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x64, 0x12, 0x1d, 0x0a, 0x0a, 0x69, 0x73, 0x5f, 0x70, 0x72, - 0x69, 0x6d, 0x61, 0x72, 0x79, 0x18, 0x06, 0x20, 0x01, 0x28, 0x08, 0x52, 0x09, 0x69, 0x73, 0x50, - 0x72, 0x69, 0x6d, 0x61, 0x72, 0x79, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x64, 0x5f, 0x61, 0x74, 0x18, 0x07, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, - 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, - 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, - 0x74, 0x12, 0x39, 0x0a, 0x0a, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x64, 
0x5f, 0x61, 0x74, 0x18, - 0x08, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, - 0x70, 0x52, 0x09, 0x75, 0x70, 0x64, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x12, 0x39, 0x0a, 0x0a, - 0x64, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x64, 0x5f, 0x61, 0x74, 0x18, 0x09, 0x20, 0x01, 0x28, 0x0b, - 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, - 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x64, 0x65, - 0x6c, 0x65, 0x74, 0x65, 0x64, 0x41, 0x74, 0x22, 0x12, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x52, 0x6f, - 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x22, 0x40, 0x0a, 0x11, 0x47, - 0x65, 0x74, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, - 0x12, 0x2b, 0x0a, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, - 0x32, 0x13, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x2f, 0x0a, - 0x12, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x5f, 0x69, 0x64, 0x18, - 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x64, 0x22, 0x15, - 0x0a, 0x13, 0x45, 0x6e, 0x61, 0x62, 0x6c, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, 0x30, 0x0a, 0x13, 0x44, 0x69, 0x73, 0x61, 0x62, 0x6c, 0x65, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, - 0x72, 0x6f, 0x75, 0x74, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, - 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x64, 0x22, 0x16, 0x0a, 0x14, 0x44, 0x69, 0x73, 0x61, 0x62, - 0x6c, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x2f, 0x0a, 0x14, 0x47, 0x65, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x73, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x17, 0x0a, 0x07, 0x6e, 0x6f, 0x64, 0x65, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x06, 0x6e, 0x6f, 0x64, 0x65, 0x49, 0x64, - 0x22, 0x44, 0x0a, 0x15, 0x47, 0x65, 0x74, 0x4e, 0x6f, 0x64, 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, - 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x2b, 0x0a, 0x06, 0x72, 0x6f, 0x75, - 0x74, 0x65, 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x13, 0x2e, 0x68, 0x65, 0x61, 0x64, - 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x06, - 0x72, 0x6f, 0x75, 0x74, 0x65, 0x73, 0x22, 0x2f, 0x0a, 0x12, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, - 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, - 0x72, 0x6f, 0x75, 0x74, 0x65, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x04, 0x52, 0x07, - 0x72, 0x6f, 0x75, 0x74, 0x65, 0x49, 0x64, 0x22, 0x15, 0x0a, 0x13, 0x44, 0x65, 0x6c, 0x65, 0x74, - 0x65, 0x52, 0x6f, 0x75, 0x74, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x29, - 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x6a, 0x75, 0x61, - 0x6e, 0x66, 0x6f, 0x6e, 0x74, 0x2f, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, - 0x67, 0x65, 0x6e, 0x2f, 0x67, 0x6f, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, - 
0x33, -} - -var ( - file_headscale_v1_routes_proto_rawDescOnce sync.Once - file_headscale_v1_routes_proto_rawDescData = file_headscale_v1_routes_proto_rawDesc -) - -func file_headscale_v1_routes_proto_rawDescGZIP() []byte { - file_headscale_v1_routes_proto_rawDescOnce.Do(func() { - file_headscale_v1_routes_proto_rawDescData = protoimpl.X.CompressGZIP(file_headscale_v1_routes_proto_rawDescData) - }) - return file_headscale_v1_routes_proto_rawDescData -} - -var file_headscale_v1_routes_proto_msgTypes = make([]protoimpl.MessageInfo, 11) -var file_headscale_v1_routes_proto_goTypes = []interface{}{ - (*Route)(nil), // 0: headscale.v1.Route - (*GetRoutesRequest)(nil), // 1: headscale.v1.GetRoutesRequest - (*GetRoutesResponse)(nil), // 2: headscale.v1.GetRoutesResponse - (*EnableRouteRequest)(nil), // 3: headscale.v1.EnableRouteRequest - (*EnableRouteResponse)(nil), // 4: headscale.v1.EnableRouteResponse - (*DisableRouteRequest)(nil), // 5: headscale.v1.DisableRouteRequest - (*DisableRouteResponse)(nil), // 6: headscale.v1.DisableRouteResponse - (*GetNodeRoutesRequest)(nil), // 7: headscale.v1.GetNodeRoutesRequest - (*GetNodeRoutesResponse)(nil), // 8: headscale.v1.GetNodeRoutesResponse - (*DeleteRouteRequest)(nil), // 9: headscale.v1.DeleteRouteRequest - (*DeleteRouteResponse)(nil), // 10: headscale.v1.DeleteRouteResponse - (*Node)(nil), // 11: headscale.v1.Node - (*timestamppb.Timestamp)(nil), // 12: google.protobuf.Timestamp -} -var file_headscale_v1_routes_proto_depIdxs = []int32{ - 11, // 0: headscale.v1.Route.node:type_name -> headscale.v1.Node - 12, // 1: headscale.v1.Route.created_at:type_name -> google.protobuf.Timestamp - 12, // 2: headscale.v1.Route.updated_at:type_name -> google.protobuf.Timestamp - 12, // 3: headscale.v1.Route.deleted_at:type_name -> google.protobuf.Timestamp - 0, // 4: headscale.v1.GetRoutesResponse.routes:type_name -> headscale.v1.Route - 0, // 5: headscale.v1.GetNodeRoutesResponse.routes:type_name -> headscale.v1.Route - 6, // [6:6] is the sub-list for method output_type - 6, // [6:6] is the sub-list for method input_type - 6, // [6:6] is the sub-list for extension type_name - 6, // [6:6] is the sub-list for extension extendee - 0, // [0:6] is the sub-list for field type_name -} - -func init() { file_headscale_v1_routes_proto_init() } -func file_headscale_v1_routes_proto_init() { - if File_headscale_v1_routes_proto != nil { - return - } - file_headscale_v1_node_proto_init() - if !protoimpl.UnsafeEnabled { - file_headscale_v1_routes_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*Route); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_routes_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetRoutesRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_routes_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetRoutesResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_routes_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*EnableRouteRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - 
file_headscale_v1_routes_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*EnableRouteResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_routes_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DisableRouteRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_routes_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DisableRouteResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_routes_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetNodeRoutesRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_routes_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetNodeRoutesResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_routes_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeleteRouteRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_routes_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeleteRouteResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } - type x struct{} - out := protoimpl.TypeBuilder{ - File: protoimpl.DescBuilder{ - GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_headscale_v1_routes_proto_rawDesc, - NumEnums: 0, - NumMessages: 11, - NumExtensions: 0, - NumServices: 0, - }, - GoTypes: file_headscale_v1_routes_proto_goTypes, - DependencyIndexes: file_headscale_v1_routes_proto_depIdxs, - MessageInfos: file_headscale_v1_routes_proto_msgTypes, - }.Build() - File_headscale_v1_routes_proto = out.File - file_headscale_v1_routes_proto_rawDesc = nil - file_headscale_v1_routes_proto_goTypes = nil - file_headscale_v1_routes_proto_depIdxs = nil -} diff --git a/gen/go/headscale/v1/user.pb.go b/gen/go/headscale/v1/user.pb.go index 3fbe969d..5f05d084 100644 --- a/gen/go/headscale/v1/user.pb.go +++ b/gen/go/headscale/v1/user.pb.go @@ -1,6 +1,6 @@ // Code generated by protoc-gen-go. DO NOT EDIT. 
// versions: -// protoc-gen-go v1.31.0 +// protoc-gen-go v1.36.11 // protoc (unknown) // source: headscale/v1/user.proto @@ -12,6 +12,7 @@ import ( timestamppb "google.golang.org/protobuf/types/known/timestamppb" reflect "reflect" sync "sync" + unsafe "unsafe" ) const ( @@ -22,22 +23,24 @@ const ( ) type User struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + CreatedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + DisplayName string `protobuf:"bytes,4,opt,name=display_name,json=displayName,proto3" json:"display_name,omitempty"` + Email string `protobuf:"bytes,5,opt,name=email,proto3" json:"email,omitempty"` + ProviderId string `protobuf:"bytes,6,opt,name=provider_id,json=providerId,proto3" json:"provider_id,omitempty"` + Provider string `protobuf:"bytes,7,opt,name=provider,proto3" json:"provider,omitempty"` + ProfilePicUrl string `protobuf:"bytes,8,opt,name=profile_pic_url,json=profilePicUrl,proto3" json:"profile_pic_url,omitempty"` unknownFields protoimpl.UnknownFields - - Id string `protobuf:"bytes,1,opt,name=id,proto3" json:"id,omitempty"` - Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` - CreatedAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + sizeCache protoimpl.SizeCache } func (x *User) Reset() { *x = User{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[0] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_user_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *User) String() string { @@ -48,7 +51,7 @@ func (*User) ProtoMessage() {} func (x *User) ProtoReflect() protoreflect.Message { mi := &file_headscale_v1_user_proto_msgTypes[0] - if protoimpl.UnsafeEnabled && x != nil { + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -63,11 +66,11 @@ func (*User) Descriptor() ([]byte, []int) { return file_headscale_v1_user_proto_rawDescGZIP(), []int{0} } -func (x *User) GetId() string { +func (x *User) GetId() uint64 { if x != nil { return x.Id } - return "" + return 0 } func (x *User) GetName() string { @@ -84,115 +87,56 @@ func (x *User) GetCreatedAt() *timestamppb.Timestamp { return nil } -type GetUserRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` -} - -func (x *GetUserRequest) Reset() { - *x = GetUserRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[1] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetUserRequest) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetUserRequest) ProtoMessage() {} - -func (x *GetUserRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[1] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return 
mi.MessageOf(x) -} - -// Deprecated: Use GetUserRequest.ProtoReflect.Descriptor instead. -func (*GetUserRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{1} -} - -func (x *GetUserRequest) GetName() string { +func (x *User) GetDisplayName() string { if x != nil { - return x.Name + return x.DisplayName } return "" } -type GetUserResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields - - User *User `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` -} - -func (x *GetUserResponse) Reset() { - *x = GetUserResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[2] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } -} - -func (x *GetUserResponse) String() string { - return protoimpl.X.MessageStringOf(x) -} - -func (*GetUserResponse) ProtoMessage() {} - -func (x *GetUserResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[2] - if protoimpl.UnsafeEnabled && x != nil { - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - if ms.LoadMessageInfo() == nil { - ms.StoreMessageInfo(mi) - } - return ms - } - return mi.MessageOf(x) -} - -// Deprecated: Use GetUserResponse.ProtoReflect.Descriptor instead. -func (*GetUserResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{2} -} - -func (x *GetUserResponse) GetUser() *User { +func (x *User) GetEmail() string { if x != nil { - return x.User + return x.Email } - return nil + return "" +} + +func (x *User) GetProviderId() string { + if x != nil { + return x.ProviderId + } + return "" +} + +func (x *User) GetProvider() string { + if x != nil { + return x.Provider + } + return "" +} + +func (x *User) GetProfilePicUrl() string { + if x != nil { + return x.ProfilePicUrl + } + return "" } type CreateUserRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + DisplayName string `protobuf:"bytes,2,opt,name=display_name,json=displayName,proto3" json:"display_name,omitempty"` + Email string `protobuf:"bytes,3,opt,name=email,proto3" json:"email,omitempty"` + PictureUrl string `protobuf:"bytes,4,opt,name=picture_url,json=pictureUrl,proto3" json:"picture_url,omitempty"` unknownFields protoimpl.UnknownFields - - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + sizeCache protoimpl.SizeCache } func (x *CreateUserRequest) Reset() { *x = CreateUserRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[3] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_user_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *CreateUserRequest) String() string { @@ -202,8 +146,8 @@ func (x *CreateUserRequest) String() string { func (*CreateUserRequest) ProtoMessage() {} func (x *CreateUserRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[3] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_user_proto_msgTypes[1] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -215,7 +159,7 @@ func (x *CreateUserRequest) ProtoReflect() 
protoreflect.Message { // Deprecated: Use CreateUserRequest.ProtoReflect.Descriptor instead. func (*CreateUserRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{3} + return file_headscale_v1_user_proto_rawDescGZIP(), []int{1} } func (x *CreateUserRequest) GetName() string { @@ -225,21 +169,39 @@ func (x *CreateUserRequest) GetName() string { return "" } -type CreateUserResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache - unknownFields protoimpl.UnknownFields +func (x *CreateUserRequest) GetDisplayName() string { + if x != nil { + return x.DisplayName + } + return "" +} - User *User `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` +func (x *CreateUserRequest) GetEmail() string { + if x != nil { + return x.Email + } + return "" +} + +func (x *CreateUserRequest) GetPictureUrl() string { + if x != nil { + return x.PictureUrl + } + return "" +} + +type CreateUserResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + User *User `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *CreateUserResponse) Reset() { *x = CreateUserResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[4] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_user_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *CreateUserResponse) String() string { @@ -249,8 +211,8 @@ func (x *CreateUserResponse) String() string { func (*CreateUserResponse) ProtoMessage() {} func (x *CreateUserResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[4] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_user_proto_msgTypes[2] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -262,7 +224,7 @@ func (x *CreateUserResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use CreateUserResponse.ProtoReflect.Descriptor instead. 
func (*CreateUserResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{4} + return file_headscale_v1_user_proto_rawDescGZIP(), []int{2} } func (x *CreateUserResponse) GetUser() *User { @@ -273,21 +235,18 @@ func (x *CreateUserResponse) GetUser() *User { } type RenameUserRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + OldId uint64 `protobuf:"varint,1,opt,name=old_id,json=oldId,proto3" json:"old_id,omitempty"` + NewName string `protobuf:"bytes,2,opt,name=new_name,json=newName,proto3" json:"new_name,omitempty"` unknownFields protoimpl.UnknownFields - - OldName string `protobuf:"bytes,1,opt,name=old_name,json=oldName,proto3" json:"old_name,omitempty"` - NewName string `protobuf:"bytes,2,opt,name=new_name,json=newName,proto3" json:"new_name,omitempty"` + sizeCache protoimpl.SizeCache } func (x *RenameUserRequest) Reset() { *x = RenameUserRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[5] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_user_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *RenameUserRequest) String() string { @@ -297,8 +256,8 @@ func (x *RenameUserRequest) String() string { func (*RenameUserRequest) ProtoMessage() {} func (x *RenameUserRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[5] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_user_proto_msgTypes[3] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -310,14 +269,14 @@ func (x *RenameUserRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use RenameUserRequest.ProtoReflect.Descriptor instead. 
func (*RenameUserRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{5} + return file_headscale_v1_user_proto_rawDescGZIP(), []int{3} } -func (x *RenameUserRequest) GetOldName() string { +func (x *RenameUserRequest) GetOldId() uint64 { if x != nil { - return x.OldName + return x.OldId } - return "" + return 0 } func (x *RenameUserRequest) GetNewName() string { @@ -328,20 +287,17 @@ func (x *RenameUserRequest) GetNewName() string { } type RenameUserResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + User *User `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` unknownFields protoimpl.UnknownFields - - User *User `protobuf:"bytes,1,opt,name=user,proto3" json:"user,omitempty"` + sizeCache protoimpl.SizeCache } func (x *RenameUserResponse) Reset() { *x = RenameUserResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[6] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_user_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *RenameUserResponse) String() string { @@ -351,8 +307,8 @@ func (x *RenameUserResponse) String() string { func (*RenameUserResponse) ProtoMessage() {} func (x *RenameUserResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[6] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_user_proto_msgTypes[4] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -364,7 +320,7 @@ func (x *RenameUserResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use RenameUserResponse.ProtoReflect.Descriptor instead. 
func (*RenameUserResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{6} + return file_headscale_v1_user_proto_rawDescGZIP(), []int{4} } func (x *RenameUserResponse) GetUser() *User { @@ -375,20 +331,17 @@ func (x *RenameUserResponse) GetUser() *User { } type DeleteUserRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` unknownFields protoimpl.UnknownFields - - Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + sizeCache protoimpl.SizeCache } func (x *DeleteUserRequest) Reset() { *x = DeleteUserRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[7] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_user_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *DeleteUserRequest) String() string { @@ -398,8 +351,8 @@ func (x *DeleteUserRequest) String() string { func (*DeleteUserRequest) ProtoMessage() {} func (x *DeleteUserRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[7] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_user_proto_msgTypes[5] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -411,29 +364,27 @@ func (x *DeleteUserRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteUserRequest.ProtoReflect.Descriptor instead. func (*DeleteUserRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{7} + return file_headscale_v1_user_proto_rawDescGZIP(), []int{5} } -func (x *DeleteUserRequest) GetName() string { +func (x *DeleteUserRequest) GetId() uint64 { if x != nil { - return x.Name + return x.Id } - return "" + return 0 } type DeleteUserResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *DeleteUserResponse) Reset() { *x = DeleteUserResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[8] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_user_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *DeleteUserResponse) String() string { @@ -443,8 +394,8 @@ func (x *DeleteUserResponse) String() string { func (*DeleteUserResponse) ProtoMessage() {} func (x *DeleteUserResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[8] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_user_proto_msgTypes[6] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -456,22 +407,23 @@ func (x *DeleteUserResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use DeleteUserResponse.ProtoReflect.Descriptor instead. 
func (*DeleteUserResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{8} + return file_headscale_v1_user_proto_rawDescGZIP(), []int{6} } type ListUsersRequest struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Id uint64 `protobuf:"varint,1,opt,name=id,proto3" json:"id,omitempty"` + Name string `protobuf:"bytes,2,opt,name=name,proto3" json:"name,omitempty"` + Email string `protobuf:"bytes,3,opt,name=email,proto3" json:"email,omitempty"` unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *ListUsersRequest) Reset() { *x = ListUsersRequest{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[9] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_user_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ListUsersRequest) String() string { @@ -481,8 +433,8 @@ func (x *ListUsersRequest) String() string { func (*ListUsersRequest) ProtoMessage() {} func (x *ListUsersRequest) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[9] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_user_proto_msgTypes[7] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -494,24 +446,42 @@ func (x *ListUsersRequest) ProtoReflect() protoreflect.Message { // Deprecated: Use ListUsersRequest.ProtoReflect.Descriptor instead. func (*ListUsersRequest) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{9} + return file_headscale_v1_user_proto_rawDescGZIP(), []int{7} +} + +func (x *ListUsersRequest) GetId() uint64 { + if x != nil { + return x.Id + } + return 0 +} + +func (x *ListUsersRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *ListUsersRequest) GetEmail() string { + if x != nil { + return x.Email + } + return "" } type ListUsersResponse struct { - state protoimpl.MessageState - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + Users []*User `protobuf:"bytes,1,rep,name=users,proto3" json:"users,omitempty"` unknownFields protoimpl.UnknownFields - - Users []*User `protobuf:"bytes,1,rep,name=users,proto3" json:"users,omitempty"` + sizeCache protoimpl.SizeCache } func (x *ListUsersResponse) Reset() { *x = ListUsersResponse{} - if protoimpl.UnsafeEnabled { - mi := &file_headscale_v1_user_proto_msgTypes[10] - ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) - ms.StoreMessageInfo(mi) - } + mi := &file_headscale_v1_user_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) } func (x *ListUsersResponse) String() string { @@ -521,8 +491,8 @@ func (x *ListUsersResponse) String() string { func (*ListUsersResponse) ProtoMessage() {} func (x *ListUsersResponse) ProtoReflect() protoreflect.Message { - mi := &file_headscale_v1_user_proto_msgTypes[10] - if protoimpl.UnsafeEnabled && x != nil { + mi := &file_headscale_v1_user_proto_msgTypes[8] + if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { ms.StoreMessageInfo(mi) @@ -534,7 +504,7 @@ func (x *ListUsersResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use ListUsersResponse.ProtoReflect.Descriptor instead. 
func (*ListUsersResponse) Descriptor() ([]byte, []int) { - return file_headscale_v1_user_proto_rawDescGZIP(), []int{10} + return file_headscale_v1_user_proto_rawDescGZIP(), []int{8} } func (x *ListUsersResponse) GetUsers() []*User { @@ -546,92 +516,78 @@ func (x *ListUsersResponse) GetUsers() []*User { var File_headscale_v1_user_proto protoreflect.FileDescriptor -var file_headscale_v1_user_proto_rawDesc = []byte{ - 0x0a, 0x17, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x2f, 0x75, - 0x73, 0x65, 0x72, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x0c, 0x68, 0x65, 0x61, 0x64, 0x73, - 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x1a, 0x1f, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x2f, - 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2f, 0x74, 0x69, 0x6d, 0x65, 0x73, 0x74, 0x61, - 0x6d, 0x70, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x65, 0x0a, 0x04, 0x55, 0x73, 0x65, 0x72, - 0x12, 0x0e, 0x0a, 0x02, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x02, 0x69, 0x64, - 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, - 0x6e, 0x61, 0x6d, 0x65, 0x12, 0x39, 0x0a, 0x0a, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x5f, - 0x61, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x67, 0x6f, 0x6f, 0x67, 0x6c, - 0x65, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x62, 0x75, 0x66, 0x2e, 0x54, 0x69, 0x6d, 0x65, 0x73, - 0x74, 0x61, 0x6d, 0x70, 0x52, 0x09, 0x63, 0x72, 0x65, 0x61, 0x74, 0x65, 0x64, 0x41, 0x74, 0x22, - 0x24, 0x0a, 0x0e, 0x47, 0x65, 0x74, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, - 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, - 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x39, 0x0a, 0x0f, 0x47, 0x65, 0x74, 0x55, 0x73, 0x65, 0x72, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, - 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x55, 0x73, 0x65, 0x72, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, - 0x22, 0x27, 0x0a, 0x11, 0x43, 0x72, 0x65, 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x3c, 0x0a, 0x12, 0x43, 0x72, 0x65, - 0x61, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x26, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, - 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x55, 0x73, 0x65, - 0x72, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, 0x22, 0x49, 0x0a, 0x11, 0x52, 0x65, 0x6e, 0x61, 0x6d, - 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, - 0x6f, 0x6c, 0x64, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, - 0x6f, 0x6c, 0x64, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x19, 0x0a, 0x08, 0x6e, 0x65, 0x77, 0x5f, 0x6e, - 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x6e, 0x65, 0x77, 0x4e, 0x61, - 0x6d, 0x65, 0x22, 0x3c, 0x0a, 0x12, 0x52, 0x65, 0x6e, 0x61, 0x6d, 0x65, 0x55, 0x73, 0x65, 0x72, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x26, 0x0a, 0x04, 0x75, 0x73, 0x65, 0x72, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, 0x61, - 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x55, 0x73, 0x65, 0x72, 0x52, 0x04, 0x75, 0x73, 0x65, 0x72, - 
0x22, 0x27, 0x0a, 0x11, 0x44, 0x65, 0x6c, 0x65, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, - 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, - 0x01, 0x28, 0x09, 0x52, 0x04, 0x6e, 0x61, 0x6d, 0x65, 0x22, 0x14, 0x0a, 0x12, 0x44, 0x65, 0x6c, - 0x65, 0x74, 0x65, 0x55, 0x73, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x22, - 0x12, 0x0a, 0x10, 0x4c, 0x69, 0x73, 0x74, 0x55, 0x73, 0x65, 0x72, 0x73, 0x52, 0x65, 0x71, 0x75, - 0x65, 0x73, 0x74, 0x22, 0x3d, 0x0a, 0x11, 0x4c, 0x69, 0x73, 0x74, 0x55, 0x73, 0x65, 0x72, 0x73, - 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x28, 0x0a, 0x05, 0x75, 0x73, 0x65, 0x72, - 0x73, 0x18, 0x01, 0x20, 0x03, 0x28, 0x0b, 0x32, 0x12, 0x2e, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x55, 0x73, 0x65, 0x72, 0x52, 0x05, 0x75, 0x73, 0x65, - 0x72, 0x73, 0x42, 0x29, 0x5a, 0x27, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, - 0x2f, 0x6a, 0x75, 0x61, 0x6e, 0x66, 0x6f, 0x6e, 0x74, 0x2f, 0x68, 0x65, 0x61, 0x64, 0x73, 0x63, - 0x61, 0x6c, 0x65, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x67, 0x6f, 0x2f, 0x76, 0x31, 0x62, 0x06, 0x70, - 0x72, 0x6f, 0x74, 0x6f, 0x33, -} +const file_headscale_v1_user_proto_rawDesc = "" + + "\n" + + "\x17headscale/v1/user.proto\x12\fheadscale.v1\x1a\x1fgoogle/protobuf/timestamp.proto\"\x83\x02\n" + + "\x04User\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\x12\x12\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x129\n" + + "\n" + + "created_at\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\tcreatedAt\x12!\n" + + "\fdisplay_name\x18\x04 \x01(\tR\vdisplayName\x12\x14\n" + + "\x05email\x18\x05 \x01(\tR\x05email\x12\x1f\n" + + "\vprovider_id\x18\x06 \x01(\tR\n" + + "providerId\x12\x1a\n" + + "\bprovider\x18\a \x01(\tR\bprovider\x12&\n" + + "\x0fprofile_pic_url\x18\b \x01(\tR\rprofilePicUrl\"\x81\x01\n" + + "\x11CreateUserRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12!\n" + + "\fdisplay_name\x18\x02 \x01(\tR\vdisplayName\x12\x14\n" + + "\x05email\x18\x03 \x01(\tR\x05email\x12\x1f\n" + + "\vpicture_url\x18\x04 \x01(\tR\n" + + "pictureUrl\"<\n" + + "\x12CreateUserResponse\x12&\n" + + "\x04user\x18\x01 \x01(\v2\x12.headscale.v1.UserR\x04user\"E\n" + + "\x11RenameUserRequest\x12\x15\n" + + "\x06old_id\x18\x01 \x01(\x04R\x05oldId\x12\x19\n" + + "\bnew_name\x18\x02 \x01(\tR\anewName\"<\n" + + "\x12RenameUserResponse\x12&\n" + + "\x04user\x18\x01 \x01(\v2\x12.headscale.v1.UserR\x04user\"#\n" + + "\x11DeleteUserRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\"\x14\n" + + "\x12DeleteUserResponse\"L\n" + + "\x10ListUsersRequest\x12\x0e\n" + + "\x02id\x18\x01 \x01(\x04R\x02id\x12\x12\n" + + "\x04name\x18\x02 \x01(\tR\x04name\x12\x14\n" + + "\x05email\x18\x03 \x01(\tR\x05email\"=\n" + + "\x11ListUsersResponse\x12(\n" + + "\x05users\x18\x01 \x03(\v2\x12.headscale.v1.UserR\x05usersB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" var ( file_headscale_v1_user_proto_rawDescOnce sync.Once - file_headscale_v1_user_proto_rawDescData = file_headscale_v1_user_proto_rawDesc + file_headscale_v1_user_proto_rawDescData []byte ) func file_headscale_v1_user_proto_rawDescGZIP() []byte { file_headscale_v1_user_proto_rawDescOnce.Do(func() { - file_headscale_v1_user_proto_rawDescData = protoimpl.X.CompressGZIP(file_headscale_v1_user_proto_rawDescData) + file_headscale_v1_user_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_headscale_v1_user_proto_rawDesc), len(file_headscale_v1_user_proto_rawDesc))) }) return 
file_headscale_v1_user_proto_rawDescData } -var file_headscale_v1_user_proto_msgTypes = make([]protoimpl.MessageInfo, 11) -var file_headscale_v1_user_proto_goTypes = []interface{}{ +var file_headscale_v1_user_proto_msgTypes = make([]protoimpl.MessageInfo, 9) +var file_headscale_v1_user_proto_goTypes = []any{ (*User)(nil), // 0: headscale.v1.User - (*GetUserRequest)(nil), // 1: headscale.v1.GetUserRequest - (*GetUserResponse)(nil), // 2: headscale.v1.GetUserResponse - (*CreateUserRequest)(nil), // 3: headscale.v1.CreateUserRequest - (*CreateUserResponse)(nil), // 4: headscale.v1.CreateUserResponse - (*RenameUserRequest)(nil), // 5: headscale.v1.RenameUserRequest - (*RenameUserResponse)(nil), // 6: headscale.v1.RenameUserResponse - (*DeleteUserRequest)(nil), // 7: headscale.v1.DeleteUserRequest - (*DeleteUserResponse)(nil), // 8: headscale.v1.DeleteUserResponse - (*ListUsersRequest)(nil), // 9: headscale.v1.ListUsersRequest - (*ListUsersResponse)(nil), // 10: headscale.v1.ListUsersResponse - (*timestamppb.Timestamp)(nil), // 11: google.protobuf.Timestamp + (*CreateUserRequest)(nil), // 1: headscale.v1.CreateUserRequest + (*CreateUserResponse)(nil), // 2: headscale.v1.CreateUserResponse + (*RenameUserRequest)(nil), // 3: headscale.v1.RenameUserRequest + (*RenameUserResponse)(nil), // 4: headscale.v1.RenameUserResponse + (*DeleteUserRequest)(nil), // 5: headscale.v1.DeleteUserRequest + (*DeleteUserResponse)(nil), // 6: headscale.v1.DeleteUserResponse + (*ListUsersRequest)(nil), // 7: headscale.v1.ListUsersRequest + (*ListUsersResponse)(nil), // 8: headscale.v1.ListUsersResponse + (*timestamppb.Timestamp)(nil), // 9: google.protobuf.Timestamp } var file_headscale_v1_user_proto_depIdxs = []int32{ - 11, // 0: headscale.v1.User.created_at:type_name -> google.protobuf.Timestamp - 0, // 1: headscale.v1.GetUserResponse.user:type_name -> headscale.v1.User - 0, // 2: headscale.v1.CreateUserResponse.user:type_name -> headscale.v1.User - 0, // 3: headscale.v1.RenameUserResponse.user:type_name -> headscale.v1.User - 0, // 4: headscale.v1.ListUsersResponse.users:type_name -> headscale.v1.User - 5, // [5:5] is the sub-list for method output_type - 5, // [5:5] is the sub-list for method input_type - 5, // [5:5] is the sub-list for extension type_name - 5, // [5:5] is the sub-list for extension extendee - 0, // [0:5] is the sub-list for field type_name + 9, // 0: headscale.v1.User.created_at:type_name -> google.protobuf.Timestamp + 0, // 1: headscale.v1.CreateUserResponse.user:type_name -> headscale.v1.User + 0, // 2: headscale.v1.RenameUserResponse.user:type_name -> headscale.v1.User + 0, // 3: headscale.v1.ListUsersResponse.users:type_name -> headscale.v1.User + 4, // [4:4] is the sub-list for method output_type + 4, // [4:4] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name } func init() { file_headscale_v1_user_proto_init() } @@ -639,147 +595,13 @@ func file_headscale_v1_user_proto_init() { if File_headscale_v1_user_proto != nil { return } - if !protoimpl.UnsafeEnabled { - file_headscale_v1_user_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*User); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetUserRequest); i { - case 0: - 
return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[2].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*GetUserResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[3].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateUserRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[4].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*CreateUserResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[5].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RenameUserRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[6].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*RenameUserResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[7].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeleteUserRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[8].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*DeleteUserResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[9].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListUsersRequest); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - file_headscale_v1_user_proto_msgTypes[10].Exporter = func(v interface{}, i int) interface{} { - switch v := v.(*ListUsersResponse); i { - case 0: - return &v.state - case 1: - return &v.sizeCache - case 2: - return &v.unknownFields - default: - return nil - } - } - } type x struct{} out := protoimpl.TypeBuilder{ File: protoimpl.DescBuilder{ GoPackagePath: reflect.TypeOf(x{}).PkgPath(), - RawDescriptor: file_headscale_v1_user_proto_rawDesc, + RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_user_proto_rawDesc), len(file_headscale_v1_user_proto_rawDesc)), NumEnums: 0, - NumMessages: 11, + NumMessages: 9, NumExtensions: 0, NumServices: 0, }, @@ -788,7 +610,6 @@ func file_headscale_v1_user_proto_init() { MessageInfos: file_headscale_v1_user_proto_msgTypes, }.Build() File_headscale_v1_user_proto = out.File - file_headscale_v1_user_proto_rawDesc = nil file_headscale_v1_user_proto_goTypes = nil file_headscale_v1_user_proto_depIdxs = nil } diff --git a/gen/openapiv2/headscale/v1/apikey.swagger.json b/gen/openapiv2/headscale/v1/apikey.swagger.json index 0d4ebbe9..8c8596a9 100644 --- a/gen/openapiv2/headscale/v1/apikey.swagger.json +++ b/gen/openapiv2/headscale/v1/apikey.swagger.json @@ -34,6 +34,7 @@ "details": { "type": "array", "items": { + "type": "object", "$ref": 
"#/definitions/protobufAny" } } diff --git a/gen/openapiv2/headscale/v1/device.swagger.json b/gen/openapiv2/headscale/v1/device.swagger.json index 5360527a..99d20deb 100644 --- a/gen/openapiv2/headscale/v1/device.swagger.json +++ b/gen/openapiv2/headscale/v1/device.swagger.json @@ -34,6 +34,7 @@ "details": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/protobufAny" } } diff --git a/gen/openapiv2/headscale/v1/headscale.swagger.json b/gen/openapiv2/headscale/v1/headscale.swagger.json index bf48b143..1db1db94 100644 --- a/gen/openapiv2/headscale/v1/headscale.swagger.json +++ b/gen/openapiv2/headscale/v1/headscale.swagger.json @@ -101,6 +101,43 @@ ] } }, + "/api/v1/apikey/{prefix}": { + "delete": { + "operationId": "HeadscaleService_DeleteApiKey", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1DeleteApiKeyResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "prefix", + "in": "path", + "required": true, + "type": "string" + }, + { + "name": "id", + "in": "query", + "required": false, + "type": "string", + "format": "uint64" + } + ], + "tags": [ + "HeadscaleService" + ] + } + }, "/api/v1/debug/node": { "post": { "summary": "--- Node start ---", @@ -134,6 +171,29 @@ ] } }, + "/api/v1/health": { + "get": { + "summary": "--- Health start ---", + "operationId": "HeadscaleService_Health", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1HealthResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "tags": [ + "HeadscaleService" + ] + } + }, "/api/v1/node": { "get": { "operationId": "HeadscaleService_ListNodes", @@ -164,6 +224,36 @@ ] } }, + "/api/v1/node/backfillips": { + "post": { + "operationId": "HeadscaleService_BackfillNodeIPs", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1BackfillNodeIPsResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "confirmed", + "in": "query", + "required": false, + "type": "boolean" + } + ], + "tags": [ + "HeadscaleService" + ] + } + }, "/api/v1/node/register": { "post": { "operationId": "HeadscaleService_RegisterNode", @@ -260,6 +350,45 @@ ] } }, + "/api/v1/node/{nodeId}/approve_routes": { + "post": { + "operationId": "HeadscaleService_SetApprovedRoutes", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1SetApprovedRoutesResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "nodeId", + "in": "path", + "required": true, + "type": "string", + "format": "uint64" + }, + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/HeadscaleServiceSetApprovedRoutesBody" + } + } + ], + "tags": [ + "HeadscaleService" + ] + } + }, "/api/v1/node/{nodeId}/expire": { "post": { "operationId": "HeadscaleService_ExpireNode", @@ -284,6 +413,13 @@ "required": true, "type": "string", "format": "uint64" + }, + { + "name": "expiry", + "in": "query", + "required": false, + "type": "string", + "format": "date-time" } 
], "tags": [ @@ -328,37 +464,6 @@ ] } }, - "/api/v1/node/{nodeId}/routes": { - "get": { - "operationId": "HeadscaleService_GetNodeRoutes", - "responses": { - "200": { - "description": "A successful response.", - "schema": { - "$ref": "#/definitions/v1GetNodeRoutesResponse" - } - }, - "default": { - "description": "An unexpected error response.", - "schema": { - "$ref": "#/definitions/rpcStatus" - } - } - }, - "parameters": [ - { - "name": "nodeId", - "in": "path", - "required": true, - "type": "string", - "format": "uint64" - } - ], - "tags": [ - "HeadscaleService" - ] - } - }, "/api/v1/node/{nodeId}/tags": { "post": { "operationId": "HeadscaleService_SetTags", @@ -389,15 +494,7 @@ "in": "body", "required": true, "schema": { - "type": "object", - "properties": { - "tags": { - "type": "array", - "items": { - "type": "string" - } - } - } + "$ref": "#/definitions/HeadscaleServiceSetTagsBody" } } ], @@ -406,14 +503,35 @@ ] } }, - "/api/v1/node/{nodeId}/user": { - "post": { - "operationId": "HeadscaleService_MoveNode", + "/api/v1/policy": { + "get": { + "summary": "--- Policy start ---", + "operationId": "HeadscaleService_GetPolicy", "responses": { "200": { "description": "A successful response.", "schema": { - "$ref": "#/definitions/v1MoveNodeResponse" + "$ref": "#/definitions/v1GetPolicyResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "tags": [ + "HeadscaleService" + ] + }, + "put": { + "operationId": "HeadscaleService_SetPolicy", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1SetPolicyResponse" } }, "default": { @@ -425,17 +543,12 @@ }, "parameters": [ { - "name": "nodeId", - "in": "path", + "name": "body", + "in": "body", "required": true, - "type": "string", - "format": "uint64" - }, - { - "name": "user", - "in": "query", - "required": false, - "type": "string" + "schema": { + "$ref": "#/definitions/v1SetPolicyRequest" + } } ], "tags": [ @@ -460,12 +573,33 @@ } } }, + "tags": [ + "HeadscaleService" + ] + }, + "delete": { + "operationId": "HeadscaleService_DeletePreAuthKey", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1DeletePreAuthKeyResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, "parameters": [ { - "name": "user", + "name": "id", "in": "query", "required": false, - "type": "string" + "type": "string", + "format": "uint64" } ], "tags": [ @@ -536,122 +670,6 @@ ] } }, - "/api/v1/routes": { - "get": { - "summary": "--- Route start ---", - "operationId": "HeadscaleService_GetRoutes", - "responses": { - "200": { - "description": "A successful response.", - "schema": { - "$ref": "#/definitions/v1GetRoutesResponse" - } - }, - "default": { - "description": "An unexpected error response.", - "schema": { - "$ref": "#/definitions/rpcStatus" - } - } - }, - "tags": [ - "HeadscaleService" - ] - } - }, - "/api/v1/routes/{routeId}": { - "delete": { - "operationId": "HeadscaleService_DeleteRoute", - "responses": { - "200": { - "description": "A successful response.", - "schema": { - "$ref": "#/definitions/v1DeleteRouteResponse" - } - }, - "default": { - "description": "An unexpected error response.", - "schema": { - "$ref": "#/definitions/rpcStatus" - } - } - }, - "parameters": [ - { - "name": "routeId", - "in": "path", - "required": true, - "type": "string", - "format": "uint64" 
- } - ], - "tags": [ - "HeadscaleService" - ] - } - }, - "/api/v1/routes/{routeId}/disable": { - "post": { - "operationId": "HeadscaleService_DisableRoute", - "responses": { - "200": { - "description": "A successful response.", - "schema": { - "$ref": "#/definitions/v1DisableRouteResponse" - } - }, - "default": { - "description": "An unexpected error response.", - "schema": { - "$ref": "#/definitions/rpcStatus" - } - } - }, - "parameters": [ - { - "name": "routeId", - "in": "path", - "required": true, - "type": "string", - "format": "uint64" - } - ], - "tags": [ - "HeadscaleService" - ] - } - }, - "/api/v1/routes/{routeId}/enable": { - "post": { - "operationId": "HeadscaleService_EnableRoute", - "responses": { - "200": { - "description": "A successful response.", - "schema": { - "$ref": "#/definitions/v1EnableRouteResponse" - } - }, - "default": { - "description": "An unexpected error response.", - "schema": { - "$ref": "#/definitions/rpcStatus" - } - } - }, - "parameters": [ - { - "name": "routeId", - "in": "path", - "required": true, - "type": "string", - "format": "uint64" - } - ], - "tags": [ - "HeadscaleService" - ] - } - }, "/api/v1/user": { "get": { "operationId": "HeadscaleService_ListUsers", @@ -669,11 +687,33 @@ } } }, + "parameters": [ + { + "name": "id", + "in": "query", + "required": false, + "type": "string", + "format": "uint64" + }, + { + "name": "name", + "in": "query", + "required": false, + "type": "string" + }, + { + "name": "email", + "in": "query", + "required": false, + "type": "string" + } + ], "tags": [ "HeadscaleService" ] }, "post": { + "summary": "--- User start ---", "operationId": "HeadscaleService_CreateUser", "responses": { "200": { @@ -704,36 +744,7 @@ ] } }, - "/api/v1/user/{name}": { - "get": { - "summary": "--- User start ---", - "operationId": "HeadscaleService_GetUser", - "responses": { - "200": { - "description": "A successful response.", - "schema": { - "$ref": "#/definitions/v1GetUserResponse" - } - }, - "default": { - "description": "An unexpected error response.", - "schema": { - "$ref": "#/definitions/rpcStatus" - } - } - }, - "parameters": [ - { - "name": "name", - "in": "path", - "required": true, - "type": "string" - } - ], - "tags": [ - "HeadscaleService" - ] - }, + "/api/v1/user/{id}": { "delete": { "operationId": "HeadscaleService_DeleteUser", "responses": { @@ -752,10 +763,11 @@ }, "parameters": [ { - "name": "name", + "name": "id", "in": "path", "required": true, - "type": "string" + "type": "string", + "format": "uint64" } ], "tags": [ @@ -763,7 +775,7 @@ ] } }, - "/api/v1/user/{oldName}/rename/{newName}": { + "/api/v1/user/{oldId}/rename/{newName}": { "post": { "operationId": "HeadscaleService_RenameUser", "responses": { @@ -782,10 +794,11 @@ }, "parameters": [ { - "name": "oldName", + "name": "oldId", "in": "path", "required": true, - "type": "string" + "type": "string", + "format": "uint64" }, { "name": "newName", @@ -801,6 +814,28 @@ } }, "definitions": { + "HeadscaleServiceSetApprovedRoutesBody": { + "type": "object", + "properties": { + "routes": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, + "HeadscaleServiceSetTagsBody": { + "type": "object", + "properties": { + "tags": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, "protobufAny": { "type": "object", "properties": { @@ -823,6 +858,7 @@ "details": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/protobufAny" } } @@ -852,6 +888,17 @@ } } }, + "v1BackfillNodeIPsResponse": { + "type": "object", + "properties": { 
+ "changes": { + "type": "array", + "items": { + "type": "string" + } + } + } + }, "v1CreateApiKeyRequest": { "type": "object", "properties": { @@ -873,7 +920,8 @@ "type": "object", "properties": { "user": { - "type": "string" + "type": "string", + "format": "uint64" }, "reusable": { "type": "boolean" @@ -906,6 +954,15 @@ "properties": { "name": { "type": "string" + }, + "displayName": { + "type": "string" + }, + "email": { + "type": "string" + }, + "pictureUrl": { + "type": "string" } } }, @@ -945,26 +1002,27 @@ } } }, + "v1DeleteApiKeyResponse": { + "type": "object" + }, "v1DeleteNodeResponse": { "type": "object" }, - "v1DeleteRouteResponse": { + "v1DeletePreAuthKeyResponse": { "type": "object" }, "v1DeleteUserResponse": { "type": "object" }, - "v1DisableRouteResponse": { - "type": "object" - }, - "v1EnableRouteResponse": { - "type": "object" - }, "v1ExpireApiKeyRequest": { "type": "object", "properties": { "prefix": { "type": "string" + }, + "id": { + "type": "string", + "format": "uint64" } } }, @@ -982,11 +1040,9 @@ "v1ExpirePreAuthKeyRequest": { "type": "object", "properties": { - "user": { - "type": "string" - }, - "key": { - "type": "string" + "id": { + "type": "string", + "format": "uint64" } } }, @@ -1001,33 +1057,23 @@ } } }, - "v1GetNodeRoutesResponse": { + "v1GetPolicyResponse": { "type": "object", "properties": { - "routes": { - "type": "array", - "items": { - "$ref": "#/definitions/v1Route" - } + "policy": { + "type": "string" + }, + "updatedAt": { + "type": "string", + "format": "date-time" } } }, - "v1GetRoutesResponse": { + "v1HealthResponse": { "type": "object", "properties": { - "routes": { - "type": "array", - "items": { - "$ref": "#/definitions/v1Route" - } - } - } - }, - "v1GetUserResponse": { - "type": "object", - "properties": { - "user": { - "$ref": "#/definitions/v1User" + "databaseConnectivity": { + "type": "boolean" } } }, @@ -1037,6 +1083,7 @@ "apiKeys": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/v1ApiKey" } } @@ -1048,6 +1095,7 @@ "nodes": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/v1Node" } } @@ -1059,6 +1107,7 @@ "preAuthKeys": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/v1PreAuthKey" } } @@ -1070,19 +1119,12 @@ "users": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/v1User" } } } }, - "v1MoveNodeResponse": { - "type": "object", - "properties": { - "node": { - "$ref": "#/definitions/v1Node" - } - } - }, "v1Node": { "type": "object", "properties": { @@ -1115,10 +1157,6 @@ "type": "string", "format": "date-time" }, - "lastSuccessfulUpdate": { - "type": "string", - "format": "date-time" - }, "expiry": { "type": "string", "format": "date-time" @@ -1133,29 +1171,36 @@ "registerMethod": { "$ref": "#/definitions/v1RegisterMethod" }, - "forcedTags": { - "type": "array", - "items": { - "type": "string" - } - }, - "invalidTags": { - "type": "array", - "items": { - "type": "string" - } - }, - "validTags": { - "type": "array", - "items": { - "type": "string" - } - }, "givenName": { - "type": "string" + "type": "string", + "title": "Deprecated\nrepeated string forced_tags = 18;\nrepeated string invalid_tags = 19;\nrepeated string valid_tags = 20;" }, "online": { "type": "boolean" + }, + "approvedRoutes": { + "type": "array", + "items": { + "type": "string" + } + }, + "availableRoutes": { + "type": "array", + "items": { + "type": "string" + } + }, + "subnetRoutes": { + "type": "array", + "items": { + "type": "string" + } + }, + "tags": { + "type": "array", + 
"items": { + "type": "string" + } } } }, @@ -1163,10 +1208,11 @@ "type": "object", "properties": { "user": { - "type": "string" + "$ref": "#/definitions/v1User" }, "id": { - "type": "string" + "type": "string", + "format": "uint64" }, "key": { "type": "string" @@ -1230,39 +1276,31 @@ } } }, - "v1Route": { + "v1SetApprovedRoutesResponse": { "type": "object", "properties": { - "id": { - "type": "string", - "format": "uint64" - }, "node": { "$ref": "#/definitions/v1Node" - }, - "prefix": { + } + } + }, + "v1SetPolicyRequest": { + "type": "object", + "properties": { + "policy": { + "type": "string" + } + } + }, + "v1SetPolicyResponse": { + "type": "object", + "properties": { + "policy": { "type": "string" }, - "advertised": { - "type": "boolean" - }, - "enabled": { - "type": "boolean" - }, - "isPrimary": { - "type": "boolean" - }, - "createdAt": { - "type": "string", - "format": "date-time" - }, "updatedAt": { "type": "string", "format": "date-time" - }, - "deletedAt": { - "type": "string", - "format": "date-time" } } }, @@ -1278,7 +1316,8 @@ "type": "object", "properties": { "id": { - "type": "string" + "type": "string", + "format": "uint64" }, "name": { "type": "string" @@ -1286,6 +1325,21 @@ "createdAt": { "type": "string", "format": "date-time" + }, + "displayName": { + "type": "string" + }, + "email": { + "type": "string" + }, + "providerId": { + "type": "string" + }, + "provider": { + "type": "string" + }, + "profilePicUrl": { + "type": "string" } } } diff --git a/gen/openapiv2/headscale/v1/node.swagger.json b/gen/openapiv2/headscale/v1/node.swagger.json index 8271250e..16321347 100644 --- a/gen/openapiv2/headscale/v1/node.swagger.json +++ b/gen/openapiv2/headscale/v1/node.swagger.json @@ -34,6 +34,7 @@ "details": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/protobufAny" } } diff --git a/gen/openapiv2/headscale/v1/routes.swagger.json b/gen/openapiv2/headscale/v1/policy.swagger.json similarity index 91% rename from gen/openapiv2/headscale/v1/routes.swagger.json rename to gen/openapiv2/headscale/v1/policy.swagger.json index 34eda676..63057ed0 100644 --- a/gen/openapiv2/headscale/v1/routes.swagger.json +++ b/gen/openapiv2/headscale/v1/policy.swagger.json @@ -1,7 +1,7 @@ { "swagger": "2.0", "info": { - "title": "headscale/v1/routes.proto", + "title": "headscale/v1/policy.proto", "version": "version not set" }, "consumes": [ @@ -34,6 +34,7 @@ "details": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/protobufAny" } } diff --git a/gen/openapiv2/headscale/v1/preauthkey.swagger.json b/gen/openapiv2/headscale/v1/preauthkey.swagger.json index ef16319c..17a2be1a 100644 --- a/gen/openapiv2/headscale/v1/preauthkey.swagger.json +++ b/gen/openapiv2/headscale/v1/preauthkey.swagger.json @@ -34,6 +34,7 @@ "details": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/protobufAny" } } diff --git a/gen/openapiv2/headscale/v1/user.swagger.json b/gen/openapiv2/headscale/v1/user.swagger.json index 1355a9cc..008ca3e8 100644 --- a/gen/openapiv2/headscale/v1/user.swagger.json +++ b/gen/openapiv2/headscale/v1/user.swagger.json @@ -34,6 +34,7 @@ "details": { "type": "array", "items": { + "type": "object", "$ref": "#/definitions/protobufAny" } } diff --git a/go.mod b/go.mod index 46086c06..5cc9a7dd 100644 --- a/go.mod +++ b/go.mod @@ -1,170 +1,228 @@ module github.com/juanfont/headscale -go 1.21 - -toolchain go1.21.1 +go 1.25.5 require ( - github.com/AlecAivazis/survey/v2 v2.3.7 - github.com/coreos/go-oidc/v3 v3.8.0 + github.com/arl/statsviz 
v0.8.0 + github.com/cenkalti/backoff/v5 v5.0.3 + github.com/chasefleming/elem-go v0.31.0 + github.com/coder/websocket v1.8.14 + github.com/coreos/go-oidc/v3 v3.16.0 + github.com/creachadair/command v0.2.0 + github.com/creachadair/flax v0.0.5 github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc - github.com/deckarep/golang-set/v2 v2.4.0 - github.com/efekarakus/termcolor v1.0.1 - github.com/glebarez/sqlite v1.10.0 - github.com/gofrs/uuid/v5 v5.0.0 - github.com/google/go-cmp v0.6.0 + github.com/docker/docker v28.5.2+incompatible + github.com/fsnotify/fsnotify v1.9.0 + github.com/glebarez/sqlite v1.11.0 + github.com/go-gormigrate/gormigrate/v2 v2.1.5 + github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced + github.com/gofrs/uuid/v5 v5.4.0 + github.com/google/go-cmp v0.7.0 github.com/gorilla/mux v1.8.1 - github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 - github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 - github.com/klauspost/compress v1.17.3 - github.com/oauth2-proxy/mockoidc v0.0.0-20220308204021-b9169deeb282 - github.com/ory/dockertest/v3 v3.10.0 - github.com/patrickmn/go-cache v2.1.0+incompatible + github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.4 + github.com/jagottsicher/termcolor v1.0.2 + github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25 + github.com/ory/dockertest/v3 v3.12.0 github.com/philip-bui/grpc-zerolog v1.0.1 github.com/pkg/profile v1.7.0 - github.com/prometheus/client_golang v1.17.0 - github.com/prometheus/common v0.45.0 - github.com/pterm/pterm v0.12.71 - github.com/puzpuzpuz/xsync/v3 v3.0.2 - github.com/rs/zerolog v1.31.0 - github.com/samber/lo v1.38.1 - github.com/spf13/cobra v1.8.0 - github.com/spf13/viper v1.17.0 - github.com/stretchr/testify v1.8.4 - github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a + github.com/prometheus/client_golang v1.23.2 + github.com/prometheus/common v0.67.5 + github.com/pterm/pterm v0.12.82 + github.com/puzpuzpuz/xsync/v4 v4.3.0 + github.com/rs/zerolog v1.34.0 + github.com/samber/lo v1.52.0 + github.com/sasha-s/go-deadlock v0.3.6 + github.com/spf13/cobra v1.10.2 + github.com/spf13/viper v1.21.0 + github.com/stretchr/testify v1.11.1 + github.com/tailscale/hujson v0.0.0-20250605163823-992244df8c5a + github.com/tailscale/squibble v0.0.0-20251104223530-a961feffb67f + github.com/tailscale/tailsql v0.0.0-20260105194658-001575c3ca09 github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e - go4.org/netipx v0.0.0-20230824141953-6213f710f925 - golang.org/x/crypto v0.16.0 - golang.org/x/exp v0.0.0-20231127185646-65229373498e - golang.org/x/net v0.19.0 - golang.org/x/oauth2 v0.15.0 - golang.org/x/sync v0.5.0 - google.golang.org/genproto/googleapis/api v0.0.0-20231127180814-3a041ad873d4 - google.golang.org/grpc v1.59.0 - google.golang.org/protobuf v1.31.0 - gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c + go4.org/netipx v0.0.0-20231129151722-fdeea329fbba + golang.org/x/crypto v0.46.0 + golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 + golang.org/x/net v0.48.0 + golang.org/x/oauth2 v0.34.0 + golang.org/x/sync v0.19.0 + google.golang.org/genproto/googleapis/api v0.0.0-20251222181119-0a764e51fe1b + google.golang.org/grpc v1.78.0 + google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 - gorm.io/driver/postgres v1.5.4 - gorm.io/gorm v1.25.5 - tailscale.com v1.54.0 + gorm.io/driver/postgres v1.6.0 + gorm.io/gorm v1.31.1 + tailscale.com v1.94.0 + zgo.at/zcache/v2 v2.4.1 + zombiezen.com/go/postgrestest v1.0.1 +) + +// NOTE: modernc sqlite has a fragile dependency +// chain and it is 
important that they are updated +// in lockstep to ensure that they do not break +// some architectures and similar at runtime: +// https://github.com/juanfont/headscale/issues/2188 +// +// Fragile libc dependency: +// https://pkg.go.dev/modernc.org/sqlite#hdr-Fragile_modernc_org_libc_dependency +// https://gitlab.com/cznic/sqlite/-/issues/177 +// +// To upgrade, determine the new SQLite version to +// be used, and consult the `go.mod` file: +// https://gitlab.com/cznic/sqlite/-/blob/master/go.mod +// to find +// the appropriate `libc` version, then upgrade them +// together, e.g: +// go get modernc.org/libc@v1.55.3 modernc.org/sqlite@v1.33.1 +require ( + modernc.org/libc v1.67.6 // indirect + modernc.org/mathutil v1.7.1 // indirect + modernc.org/memory v1.11.0 // indirect + modernc.org/sqlite v1.44.3 ) require ( atomicgo.dev/cursor v0.2.0 // indirect atomicgo.dev/keyboard v0.2.9 // indirect atomicgo.dev/schedule v0.1.0 // indirect - filippo.io/edwards25519 v1.0.0 // indirect - github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 // indirect - github.com/Microsoft/go-winio v0.6.1 // indirect + dario.cat/mergo v1.0.2 // indirect + filippo.io/edwards25519 v1.1.0 // indirect + github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c // indirect + github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 // indirect github.com/akutz/memconn v0.1.0 // indirect github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa // indirect + github.com/aws/aws-sdk-go-v2 v1.41.0 // indirect + github.com/aws/aws-sdk-go-v2/config v1.29.5 // indirect + github.com/aws/aws-sdk-go-v2/credentials v1.17.58 // indirect + github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 // indirect + github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 // indirect + github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 // indirect + github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 // indirect + github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 // indirect + github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 // indirect + github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 // indirect + github.com/aws/smithy-go v1.24.0 // indirect + github.com/axiomhq/hyperloglog v0.0.0-20240319100328-84253e514e02 // indirect github.com/beorn7/perks v1.0.1 // indirect - github.com/cenkalti/backoff/v4 v4.2.1 // indirect - github.com/cespare/xxhash/v2 v2.2.0 // indirect - github.com/containerd/console v1.0.3 // indirect - github.com/containerd/continuity v0.4.3 // indirect - github.com/coreos/go-iptables v0.7.0 // indirect - github.com/dblohm7/wingoes v0.0.0-20231025182615-65d8b4b5428f // indirect - github.com/docker/cli v24.0.7+incompatible // indirect - github.com/docker/docker v24.0.7+incompatible // indirect - github.com/docker/go-connections v0.4.0 // indirect + github.com/cenkalti/backoff/v4 v4.3.0 // indirect + github.com/cespare/xxhash/v2 v2.3.0 // indirect + github.com/clipperhouse/uax29/v2 v2.2.0 // indirect + github.com/containerd/console v1.0.5 // indirect + github.com/containerd/continuity v0.4.5 // indirect + github.com/containerd/errdefs v1.0.0 // indirect + github.com/containerd/errdefs/pkg v0.3.0 // indirect + github.com/creachadair/mds v0.25.10 // indirect + github.com/creachadair/msync v0.7.1 // indirect + github.com/dblohm7/wingoes v0.0.0-20240123200102-b75a8a7d7eb0 // indirect 
+ github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc // indirect + github.com/distribution/reference v0.6.0 // indirect + github.com/docker/cli v28.5.1+incompatible // indirect + github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-units v0.5.0 // indirect github.com/dustin/go-humanize v1.0.1 // indirect - github.com/felixge/fgprof v0.9.3 // indirect - github.com/fsnotify/fsnotify v1.7.0 // indirect - github.com/fxamacker/cbor/v2 v2.5.0 // indirect - github.com/glebarez/go-sqlite v1.21.2 // indirect - github.com/go-gormigrate/gormigrate/v2 v2.1.1 // indirect - github.com/go-jose/go-jose/v3 v3.0.1 // indirect - github.com/gogo/protobuf v1.3.2 // indirect - github.com/golang-jwt/jwt v3.2.2+incompatible // indirect - github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect - github.com/golang/protobuf v1.5.3 // indirect + github.com/felixge/fgprof v0.9.5 // indirect + github.com/felixge/httpsnoop v1.0.4 // indirect + github.com/fxamacker/cbor/v2 v2.9.0 // indirect + github.com/gaissmai/bart v0.18.0 // indirect + github.com/glebarez/go-sqlite v1.22.0 // indirect + github.com/go-jose/go-jose/v3 v3.0.4 // indirect + github.com/go-jose/go-jose/v4 v4.1.3 // indirect + github.com/go-logr/logr v1.4.3 // indirect + github.com/go-logr/stdr v1.2.2 // indirect + github.com/go-viper/mapstructure/v2 v2.4.0 // indirect + github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 // indirect + github.com/golang-jwt/jwt/v5 v5.3.0 // indirect + github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 // indirect + github.com/golang/protobuf v1.5.4 // indirect + github.com/google/btree v1.1.3 // indirect github.com/google/go-github v17.0.0+incompatible // indirect github.com/google/go-querystring v1.1.0 // indirect - github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c // indirect - github.com/google/pprof v0.0.0-20231127191134-f3a68a39ae15 // indirect + github.com/google/pprof v0.0.0-20251007162407-5df77e3f7d1d // indirect github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 // indirect - github.com/google/uuid v1.4.0 // indirect - github.com/gookit/color v1.5.4 // indirect - github.com/hashicorp/go-version v1.6.0 // indirect - github.com/hashicorp/hcl v1.0.0 // indirect - github.com/hdevalence/ed25519consensus v0.1.0 // indirect - github.com/imdario/mergo v0.3.16 // indirect + github.com/google/uuid v1.6.0 // indirect + github.com/gookit/color v1.6.0 // indirect + github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 // indirect + github.com/hashicorp/go-version v1.7.0 // indirect + github.com/hdevalence/ed25519consensus v0.2.0 // indirect + github.com/huin/goupnp v1.3.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect - github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect - github.com/jackc/pgx/v5 v5.5.0 // indirect - github.com/jackc/puddle/v2 v2.2.1 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect + github.com/jackc/pgx/v5 v5.7.6 // indirect + github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect - github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 // indirect - github.com/jsimonetti/rtnetlink v1.4.0 // indirect - github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect - github.com/kr/pretty v0.3.1 // indirect - github.com/kr/text v0.2.0 // indirect - 
github.com/lib/pq v1.10.7 // indirect + github.com/jsimonetti/rtnetlink v1.4.1 // indirect + github.com/klauspost/compress v1.18.2 // indirect + github.com/lib/pq v1.10.9 // indirect github.com/lithammer/fuzzysearch v1.1.8 // indirect - github.com/magiconair/properties v1.8.7 // indirect - github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-isatty v0.0.20 // indirect - github.com/mattn/go-runewidth v0.0.15 // indirect - github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect - github.com/mdlayher/netlink v1.7.2 // indirect + github.com/mattn/go-runewidth v0.0.19 // indirect + github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 // indirect github.com/mdlayher/socket v0.5.0 // indirect - github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d // indirect - github.com/miekg/dns v1.1.57 // indirect github.com/mitchellh/go-ps v1.0.0 // indirect - github.com/mitchellh/mapstructure v1.5.0 // indirect - github.com/moby/term v0.5.0 // indirect + github.com/moby/docker-image-spec v1.3.1 // indirect + github.com/moby/sys/atomicwriter v0.1.0 // indirect + github.com/moby/sys/user v0.4.0 // indirect + github.com/moby/term v0.5.2 // indirect + github.com/morikuni/aec v1.0.0 // indirect + github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect + github.com/ncruces/go-strftime v1.0.0 // indirect github.com/opencontainers/go-digest v1.0.0 // indirect - github.com/opencontainers/image-spec v1.1.0-rc5 // indirect - github.com/opencontainers/runc v1.1.10 // indirect - github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/opencontainers/image-spec v1.1.1 // indirect + github.com/opencontainers/runc v1.3.2 // indirect + github.com/pelletier/go-toml/v2 v2.2.4 // indirect + github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 // indirect + github.com/pires/go-proxyproto v0.8.1 // indirect github.com/pkg/errors v0.9.1 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect - github.com/prometheus/client_model v0.5.0 // indirect - github.com/prometheus/procfs v0.12.0 // indirect + github.com/prometheus-community/pro-bing v0.4.0 // indirect + github.com/prometheus/client_model v0.6.2 // indirect + github.com/prometheus/procfs v0.16.1 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect - github.com/rivo/uniseg v0.4.4 // indirect - github.com/rogpeppe/go-internal v1.11.0 // indirect - github.com/sagikazarmark/locafero v0.4.0 // indirect - github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/safchain/ethtool v0.3.0 // indirect + github.com/sagikazarmark/locafero v0.12.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect - github.com/sourcegraph/conc v0.3.0 // indirect - github.com/spf13/afero v1.11.0 // indirect - github.com/spf13/cast v1.6.0 // indirect - github.com/spf13/pflag v1.0.5 // indirect + github.com/spf13/afero v1.15.0 // indirect + github.com/spf13/cast v1.10.0 // indirect + github.com/spf13/pflag v1.0.10 // indirect github.com/subosito/gotenv v1.6.0 // indirect + github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e // indirect github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55 // indirect - github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 // indirect - github.com/vishvananda/netlink v1.2.1-beta.2 // indirect - github.com/vishvananda/netns v0.0.4 // indirect + github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc // 
indirect + github.com/tailscale/setec v0.0.0-20251203133219-2ab774e4129a // indirect + github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 // indirect + github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da // indirect github.com/x448/float16 v0.8.4 // indirect github.com/xeipuuv/gojsonpointer v0.0.0-20190905194746-02993c407bfb // indirect github.com/xeipuuv/gojsonreference v0.0.0-20180127040603-bd5ef7bd5415 // indirect github.com/xeipuuv/gojsonschema v1.2.0 // indirect github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect - go.uber.org/multierr v1.11.0 // indirect - go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect - golang.org/x/mod v0.14.0 // indirect - golang.org/x/sys v0.15.0 // indirect - golang.org/x/term v0.15.0 // indirect - golang.org/x/text v0.14.0 // indirect - golang.org/x/time v0.5.0 // indirect - golang.org/x/tools v0.16.0 // indirect + go.opentelemetry.io/auto/sdk v1.2.1 // indirect + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 // indirect + go.opentelemetry.io/otel v1.39.0 // indirect + go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0 // indirect + go.opentelemetry.io/otel/metric v1.39.0 // indirect + go.opentelemetry.io/otel/trace v1.39.0 // indirect + go.yaml.in/yaml/v2 v2.4.3 // indirect + go.yaml.in/yaml/v3 v3.0.4 // indirect + go4.org/mem v0.0.0-20240501181205-ae6ca9944745 // indirect + golang.org/x/mod v0.30.0 // indirect + golang.org/x/sys v0.40.0 // indirect + golang.org/x/term v0.38.0 // indirect + golang.org/x/text v0.32.0 // indirect + golang.org/x/time v0.12.0 // indirect + golang.org/x/tools v0.39.0 // indirect + golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wireguard/windows v0.5.3 // indirect - google.golang.org/appengine v1.6.8 // indirect - google.golang.org/genproto v0.0.0-20231127180814-3a041ad873d4 // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20231127180814-3a041ad873d4 // indirect - gopkg.in/ini.v1 v1.67.0 // indirect - gopkg.in/square/go-jose.v2 v2.6.0 // indirect - gopkg.in/yaml.v2 v2.4.0 // indirect - gotest.tools/v3 v3.4.0 // indirect - modernc.org/libc v1.34.11 // indirect - modernc.org/mathutil v1.6.0 // indirect - modernc.org/memory v1.7.2 // indirect - modernc.org/sqlite v1.27.0 // indirect - nhooyr.io/websocket v1.8.10 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b // indirect + gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 // indirect +) + +tool ( + golang.org/x/tools/cmd/stringer + tailscale.com/cmd/viewer ) diff --git a/go.sum b/go.sum index aa0a650b..1021d749 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +9fans.net/go v0.0.8-0.20250307142834-96bdba94b63f h1:1C7nZuxUMNz7eiQALRfiqNOm04+m3edWlRff/BYHf0Q= +9fans.net/go v0.0.8-0.20250307142834-96bdba94b63f/go.mod h1:hHyrZRryGqVdqrknjq5OWDLGCTJ2NeEvtrpR96mjraM= atomicgo.dev/assert v0.0.2 h1:FiKeMiZSgRrZsPo9qn/7vmr7mCsh5SZyXY4YGYiYwrg= atomicgo.dev/assert v0.0.2/go.mod h1:ut4NcI3QDdJtlmAxQULOmA13Gz6e2DWbSAS8RUOmNYQ= atomicgo.dev/cursor v0.2.0 h1:H6XN5alUJ52FZZUkI7AlJbUc1aW38GWZalpYRPpoPOw= @@ -6,16 +8,16 @@ atomicgo.dev/keyboard v0.2.9 h1:tOsIid3nlPLZ3lwgG8KZMp/SFmr7P0ssEN5JUsm78K8= atomicgo.dev/keyboard v0.2.9/go.mod h1:BC4w9g00XkxH/f1HXhW2sXmJFOCWbKn9xrOunSFtExQ= atomicgo.dev/schedule v0.1.0 h1:nTthAbhZS5YZmgYbb2+DH8uQIZcTlIrd4eYr3UQxEjs= atomicgo.dev/schedule v0.1.0/go.mod h1:xeUa3oAkiuHYh8bKiQBRojqAMq3PXXbJujjb0hw8pEU= -cloud.google.com/go v0.26.0/go.mod 
h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= -filippo.io/edwards25519 v1.0.0 h1:0wAIcmJUqRdI8IJ/3eGi5/HwXZWPujYXXlkrQogz0Ek= -filippo.io/edwards25519 v1.0.0/go.mod h1:N1IkdkCkiLB6tki+MYJoSx2JTY9NUlxZE7eHn5EwJns= +dario.cat/mergo v1.0.2 h1:85+piFYR1tMbRrLcDwR18y4UKJ3aH1Tbzi24VRW1TK8= +dario.cat/mergo v1.0.2/go.mod h1:E/hbnu0NxMFBjpMIE34DRGLWqDy0g5FuKDhCb31ngxA= +filippo.io/edwards25519 v1.1.0 h1:FNf4tywRC1HmFuKW5xopWpigGjJKiJSV0Cqo0cJWDaA= +filippo.io/edwards25519 v1.1.0/go.mod h1:BxyFTGdWcka3PhytdK4V28tE5sGfRvvvRV7EaN4VDT4= filippo.io/mkcert v1.4.4 h1:8eVbbwfVlaqUM7OwuftKc2nuYOoTDQWqsoXmzoXZdbc= filippo.io/mkcert v1.4.4/go.mod h1:VyvOchVuAye3BoUsPUOOofKygVwLV2KQMVFJNRq+1dA= -github.com/AlecAivazis/survey/v2 v2.3.7 h1:6I/u8FvytdGsgonrYsVn2t8t4QiRnh6QSTqkkhIiSjQ= -github.com/AlecAivazis/survey/v2 v2.3.7/go.mod h1:xUTIdE4KCOIjsBAE1JYsUPoCqYdZ1reCfTwbto0Fduo= -github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161 h1:L/gRVlceqvL25UVaW/CKtUDjefjrs0SPonmDGUVOYP0= -github.com/Azure/go-ansiterm v0.0.0-20230124172434-306776ec8161/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= -github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c h1:udKWzYgxTojEKWjV8V+WSxDXJ4NFATAsZjh8iIbsQIg= +github.com/Azure/go-ansiterm v0.0.0-20250102033503-faa5f7b0171c/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/BurntSushi/toml v1.5.0 h1:W5quZX/G/csjUnuI8SUYlsHs9M38FC7znL0lIO+DvMg= +github.com/BurntSushi/toml v1.5.0/go.mod h1:ukJfTF/6rtPPRCnwkur4qwRxa8vTRFBF0uk2lLoLwho= github.com/MarvinJWendt/testza v0.1.0/go.mod h1:7AxNvlfeHP7Z/hDQ5JtE3OKYT3XFUeLCDE2DQninSqs= github.com/MarvinJWendt/testza v0.2.1/go.mod h1:God7bhG8n6uQxwdScay+gjm9/LnO4D3kkcZX4hv9Rp8= github.com/MarvinJWendt/testza v0.2.8/go.mod h1:nwIcjmr0Zz+Rcwfh3/4UhBp7ePKVhuBExvZqnKYWlII= @@ -25,258 +27,371 @@ github.com/MarvinJWendt/testza v0.3.0/go.mod h1:eFcL4I0idjtIx8P9C6KkAuLgATNKpX4/ github.com/MarvinJWendt/testza v0.4.2/go.mod h1:mSdhXiKH8sg/gQehJ63bINcCKp7RtYewEjXsvsVUPbE= github.com/MarvinJWendt/testza v0.5.2 h1:53KDo64C1z/h/d/stCYCPY69bt/OSwjq5KpFNwi+zB4= github.com/MarvinJWendt/testza v0.5.2/go.mod h1:xu53QFE5sCdjtMCKk8YMQ2MnymimEctc4n3EjyIYvEY= -github.com/Microsoft/go-winio v0.6.1 h1:9/kr64B9VUZrLm5YYwbGtUJnMgqWVOdUAXu6Migciow= -github.com/Microsoft/go-winio v0.6.1/go.mod h1:LRdKpFKfdobln8UmuiYcKPot9D2v6svN5+sAH+4kjUM= -github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2 h1:+vx7roKuyA63nhn5WAunQHLTznkw5W8b1Xc0dNjp83s= -github.com/Netflix/go-expect v0.0.0-20220104043353-73e0943537d2/go.mod h1:HBCaDeC1lPdgDeDbhX8XFpy1jqjK0IBG8W5K+xYqA0w= +github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= +github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5 h1:TngWCqHvy9oXAN6lEVMRuU21PR1EtLVZJmdB18Gu3Rw= github.com/Nvveen/Gotty v0.0.0-20120604004816-cd527374f1e5/go.mod h1:lmUJ/7eu/Q8D7ML55dXQrVaamCz2vxCfdQBasLZfHKk= github.com/akutz/memconn v0.1.0 h1:NawI0TORU4hcOMsMr11g7vwlCdkYeLKXBcxWu2W/P8A= github.com/akutz/memconn v0.1.0/go.mod h1:Jo8rI7m0NieZyLI5e2CDlRdRqRRB4S7Xp77ukDjH+Fw= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI= github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be 
h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/arl/statsviz v0.8.0 h1:O6GjjVxEDxcByAucOSl29HaGYLXsuwA3ujJw8H9E7/U= +github.com/arl/statsviz v0.8.0/go.mod h1:XlrbiT7xYT03xaW9JMMfD8KFUhBOESJwfyNJu83PbB0= github.com/atomicgo/cursor v0.0.1/go.mod h1:cBON2QmmrysudxNBFthvMtN32r3jxVRIvzkUiF/RuIk= -github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= +github.com/aws/aws-sdk-go-v2 v1.41.0 h1:tNvqh1s+v0vFYdA1xq0aOJH+Y5cRyZ5upu6roPgPKd4= +github.com/aws/aws-sdk-go-v2 v1.41.0/go.mod h1:MayyLB8y+buD9hZqkCW3kX1AKq07Y5pXxtgB+rRFhz0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8 h1:zAxi9p3wsZMIaVCdoiQp2uZ9k1LsZvmAnoTBeZPXom0= +github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.8/go.mod h1:3XkePX5dSaxveLAYY7nsbsZZrKxCyEuE5pM4ziFxyGg= +github.com/aws/aws-sdk-go-v2/config v1.29.5 h1:4lS2IB+wwkj5J43Tq/AwvnscBerBJtQQ6YS7puzCI1k= +github.com/aws/aws-sdk-go-v2/config v1.29.5/go.mod h1:SNzldMlDVbN6nWxM7XsUiNXPSa1LWlqiXtvh/1PrJGg= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58 h1:/d7FUpAPU8Lf2KUdjniQvfNdlMID0Sd9pS23FJ3SS9Y= +github.com/aws/aws-sdk-go-v2/credentials v1.17.58/go.mod h1:aVYW33Ow10CyMQGFgC0ptMRIqJWvJ4nxZb0sUiuQT/A= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27 h1:7lOW8NUwE9UZekS1DYoiPdVAqZ6A+LheHWb+mHbNOq8= +github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.27/go.mod h1:w1BASFIPOPUae7AgaH4SbjNbfdkxuggLyGfNFTn8ITY= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16 h1:rgGwPzb82iBYSvHMHXc8h9mRoOUBZIGFgKb9qniaZZc= +github.com/aws/aws-sdk-go-v2/internal/configsources v1.4.16/go.mod h1:L/UxsGeKpGoIj6DxfhOWHWQ/kGKcd4I1VncE4++IyKA= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16 h1:1jtGzuV7c82xnqOVfx2F0xmJcOw5374L7N6juGW6x6U= +github.com/aws/aws-sdk-go-v2/internal/endpoints/v2 v2.7.16/go.mod h1:M2E5OQf+XLe+SZGmmpaI2yy+J326aFf6/+54PoxSANc= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2 h1:Pg9URiobXy85kgFev3og2CuOZ8JZUBENF+dcgWBaYNk= +github.com/aws/aws-sdk-go-v2/internal/ini v1.8.2/go.mod h1:FbtygfRFze9usAadmnGJNc8KsP346kEe+y2/oyhGAGc= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31 h1:8IwBjuLdqIO1dGB+dZ9zJEl8wzY3bVYxcs0Xyu/Lsc0= +github.com/aws/aws-sdk-go-v2/internal/v4a v1.3.31/go.mod h1:8tMBcuVjL4kP/ECEIWTCWtwV2kj6+ouEKl4cqR4iWLw= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4 h1:0ryTNEdJbzUCEWkVXEXoqlXV72J5keC1GvILMOuD00E= +github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.13.4/go.mod h1:HQ4qwNZh32C3CBeO6iJLQlgtMzqeG17ziAA/3KDJFow= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5 h1:siiQ+jummya9OLPDEyHVb2dLW4aOMe22FGDd0sAfuSw= +github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.5.5/go.mod h1:iHVx2J9pWzITdP5MJY6qWfG34TfD9EA+Qi3eV6qQCXw= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16 h1:oHjJHeUy0ImIV0bsrX0X91GkV5nJAyv1l1CC9lnO0TI= +github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.13.16/go.mod h1:iRSNGgOYmiYwSCXxXaKb9HfOEj40+oTKn8pTxMlYkRM= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12 h1:tkVNm99nkJnFo1H9IIQb5QkCiPcvCDn3Pos+IeTbGRA= +github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.18.12/go.mod h1:dIVlquSPUMqEJtx2/W17SM2SuESRaVEhEV9alcMqxjw= +github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3 h1:JBod0SnNqcWQ0+uAyzeRFG1zCHotW8DukumYYyNy0zo= +github.com/aws/aws-sdk-go-v2/service/s3 v1.75.3/go.mod h1:FHSHmyEUkzRbaFFqqm6bkLAOQHgqhsLmfCahvCBMiyA= 
+github.com/aws/aws-sdk-go-v2/service/ssm v1.45.0 h1:IOdss+igJDFdic9w3WKwxGCmHqUxydvIhJOm9LJ32Dk= +github.com/aws/aws-sdk-go-v2/service/ssm v1.45.0/go.mod h1:Q7XIWsMo0JcMpI/6TGD6XXcXcV1DbTj6e9BKNntIMIM= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14 h1:c5WJ3iHz7rLIgArznb3JCSQT3uUMiz9DLZhIX+1G8ok= +github.com/aws/aws-sdk-go-v2/service/sso v1.24.14/go.mod h1:+JJQTxB6N4niArC14YNtxcQtwEqzS3o9Z32n7q33Rfs= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13 h1:f1L/JtUkVODD+k1+IiSJUUv8A++2qVr+Xvb3xWXETMU= +github.com/aws/aws-sdk-go-v2/service/ssooidc v1.28.13/go.mod h1:tvqlFoja8/s0o+UruA1Nrezo/df0PzdunMDDurUfg6U= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5 h1:SciGFVNZ4mHdm7gpD1dgZYnCuVdX1s+lFTg4+4DOy70= +github.com/aws/aws-sdk-go-v2/service/sts v1.41.5/go.mod h1:iW40X4QBmUxdP+fZNOpfmkdMZqsovezbAeO+Ubiv2pk= +github.com/aws/smithy-go v1.24.0 h1:LpilSUItNPFr1eY85RYgTIg5eIEPtvFbskaFcmmIUnk= +github.com/aws/smithy-go v1.24.0/go.mod h1:LEj2LM3rBRQJxPZTB4KuzZkaZYnZPnvgIhb4pu07mx0= +github.com/axiomhq/hyperloglog v0.0.0-20240319100328-84253e514e02 h1:bXAPYSbdYbS5VTy92NIUbeDI1qyggi+JYh5op9IFlcQ= +github.com/axiomhq/hyperloglog v0.0.0-20240319100328-84253e514e02/go.mod h1:k08r+Yj1PRAmuayFiRK6MYuR5Ve4IuZtTfxErMIh0+c= github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= -github.com/cenkalti/backoff/v4 v4.2.1 h1:y4OZtCnogmCPw98Zjyt5a6+QwPLGkiQsYW5oUqylYbM= -github.com/cenkalti/backoff/v4 v4.2.1/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= -github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= -github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= -github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/cenkalti/backoff/v4 v4.3.0 h1:MyRJ/UdXutAwSAT+s3wNd7MfTIcy71VQueUuFK343L8= +github.com/cenkalti/backoff/v4 v4.3.0/go.mod h1:Y3VNntkOUPxTVeUxJ/G5vcM//AlwfmyYozVcomhLiZE= +github.com/cenkalti/backoff/v5 v5.0.3 h1:ZN+IMa753KfX5hd8vVaMixjnqRZ3y8CuJKRKj1xcsSM= +github.com/cenkalti/backoff/v5 v5.0.3/go.mod h1:rkhZdG3JZukswDf7f0cwqPNk4K0sa+F97BxZthm/crw= +github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= +github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chasefleming/elem-go v0.31.0 h1:vZsuKmKdv6idnUbu3awMruxTiFqZ/ertFJFAyBCkVhI= +github.com/chasefleming/elem-go v0.31.0/go.mod h1:UBmmZfso2LkXA0HZInbcwsmhE/LXFClEcBPNCGeARtA= +github.com/chromedp/cdproto v0.0.0-20230802225258-3cf4e6d46a89/go.mod h1:GKljq0VrfU4D5yc+2qA6OVr8pmO/MBbPEWqWQ/oqGEs= +github.com/chromedp/chromedp v0.9.2/go.mod h1:LkSXJKONWTCHAfQasKFUZI+mxqS4tZqhmtGzzhLsnLs= +github.com/chromedp/sysutil v1.0.0/go.mod h1:kgWmDdq8fTzXYcKIBqIYvRRTnYb9aNS9moAV0xufSww= github.com/chzyer/logex v1.1.10/go.mod h1:+Ywpsq7O8HXn0nuIou7OrIPyXbp3wmkHB+jjWRnGsAI= +github.com/chzyer/logex v1.2.1/go.mod h1:JLbx6lG2kDbNRFnfkgvh4eRJRPX1QCoOIWomwysCBrQ= github.com/chzyer/readline v0.0.0-20180603132655-2972be24d48e/go.mod h1:nSuG5e5PlCu98SY8svDHJxuZscDgtXS6KTTbou5AhLI= +github.com/chzyer/readline v1.5.1/go.mod h1:Eh+b79XXUwfKfcPLepksvw2tcLE/Ct21YObkaSkeBlk= github.com/chzyer/test v0.0.0-20180213035817-a1ea475d72b1/go.mod h1:Q3SI9o4m/ZMnBNeIyt5eFwwo7qiLfzFZmjNmxjkiQlU= -github.com/cilium/ebpf v0.12.3 h1:8ht6F9MquybnY97at+VDZb3eQQr8ev79RueWeVaEcG4= -github.com/cilium/ebpf v0.12.3/go.mod 
h1:TctK1ivibvI3znr66ljgi4hqOT8EYQjz1KWBfb1UVgM= -github.com/client9/misspell v0.3.4/go.mod h1:qj6jICC3Q7zFZvVWo7KLAzC3yx5G7kyvSDkc90ppPyw= -github.com/cncf/udpa/go v0.0.0-20191209042840-269d4d468f6f/go.mod h1:M8M6+tZqaGXZJjfX53e64911xZQV5JYwmTeXPW+k8Sc= -github.com/containerd/console v1.0.3 h1:lIr7SlA5PxZyMV30bDW0MGbiOPXwc63yRuCP0ARubLw= +github.com/chzyer/test v1.0.0/go.mod h1:2JlltgoNkt4TW/z9V/IzDdFaMTM2JPIi26O1pF38GC8= +github.com/cilium/ebpf v0.17.3 h1:FnP4r16PWYSE4ux6zN+//jMcW4nMVRvuTLVTvCjyyjg= +github.com/cilium/ebpf v0.17.3/go.mod h1:G5EDHij8yiLzaqn0WjyfJHvRa+3aDlReIaLVRMvOyJk= +github.com/clipperhouse/uax29/v2 v2.2.0 h1:ChwIKnQN3kcZteTXMgb1wztSgaU+ZemkgWdohwgs8tY= +github.com/clipperhouse/uax29/v2 v2.2.0/go.mod h1:EFJ2TJMRUaplDxHKj1qAEhCtQPW2tJSwu5BF98AuoVM= +github.com/coder/websocket v1.8.14 h1:9L0p0iKiNOibykf283eHkKUHHrpG7f65OE3BhhO7v9g= +github.com/coder/websocket v1.8.14/go.mod h1:NX3SzP+inril6yawo5CQXx8+fk145lPDC6pumgx0mVg= github.com/containerd/console v1.0.3/go.mod h1:7LqA/THxQ86k76b8c/EMSiaJ3h1eZkMkXar0TQ1gf3U= -github.com/containerd/continuity v0.4.3 h1:6HVkalIp+2u1ZLH1J/pYX2oBVXlJZvh1X1A7bEZ9Su8= -github.com/containerd/continuity v0.4.3/go.mod h1:F6PTNCKepoxEaXLQp3wDAjygEnImnZ/7o4JzpodfroQ= -github.com/coreos/go-iptables v0.7.0 h1:XWM3V+MPRr5/q51NuWSgU0fqMad64Zyxs8ZUoMsamr8= -github.com/coreos/go-iptables v0.7.0/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= -github.com/coreos/go-oidc/v3 v3.8.0 h1:s3e30r6VEl3/M7DTSCEuImmrfu1/1WBgA0cXkdzkrAY= -github.com/coreos/go-oidc/v3 v3.8.0/go.mod h1:yQzSCqBnK3e6Fs5l+f5i0F8Kwf0zpH9bPEsbY00KanM= +github.com/containerd/console v1.0.5 h1:R0ymNeydRqH2DmakFNdmjR2k0t7UPuiOV/N/27/qqsc= +github.com/containerd/console v1.0.5/go.mod h1:YynlIjWYF8myEu6sdkwKIvGQq+cOckRm6So2avqoYAk= +github.com/containerd/continuity v0.4.5 h1:ZRoN1sXq9u7V6QoHMcVWGhOwDFqZ4B9i5H6un1Wh0x4= +github.com/containerd/continuity v0.4.5/go.mod h1:/lNJvtJKUQStBzpVQ1+rasXO1LAWtUQssk28EZvJ3nE= +github.com/containerd/errdefs v1.0.0 h1:tg5yIfIlQIrxYtu9ajqY42W3lpS19XqdxRQeEwYG8PI= +github.com/containerd/errdefs v1.0.0/go.mod h1:+YBYIdtsnF4Iw6nWZhJcqGSg/dwvV7tyJ/kCkyJ2k+M= +github.com/containerd/errdefs/pkg v0.3.0 h1:9IKJ06FvyNlexW690DXuQNx2KA2cUJXx151Xdx3ZPPE= +github.com/containerd/errdefs/pkg v0.3.0/go.mod h1:NJw6s9HwNuRhnjJhM7pylWwMyAkmCQvQ4GpJHEqRLVk= +github.com/containerd/log v0.1.0 h1:TCJt7ioM2cr/tfR8GPbGf9/VRAX8D2B4PjzCpfX540I= +github.com/containerd/log v0.1.0/go.mod h1:VRRf09a7mHDIRezVKTRCrOq78v577GXq3bSa3EhrzVo= +github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6 h1:8h5+bWd7R6AYUslN6c6iuZWTKsKxUFDlpnmilO6R2n0= +github.com/coreos/go-iptables v0.7.1-0.20240112124308-65c67c9f46e6/go.mod h1:Qe8Bv2Xik5FyTXwgIbLAnv2sWSBmvWdFETJConOQ//Q= +github.com/coreos/go-oidc/v3 v3.16.0 h1:qRQUCFstKpXwmEjDQTIbyY/5jF00+asXzSkmkoa/mow= +github.com/coreos/go-oidc/v3 v3.16.0/go.mod h1:wqPbKFrVnE90vty060SB40FCJ8fTHTxSwyXJqZH+sI8= github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= -github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= -github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= -github.com/creack/pty v1.1.17/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= -github.com/creack/pty v1.1.18 h1:n56/Zwd5o6whRC5PMGretI4IdRLlmBXYNjScPaBgsbY= -github.com/creack/pty v1.1.18/go.mod h1:MOBLtS5ELjhRRrroQr9kyvTxUAFNvYEK993ew/Vr4O4= +github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= 
+github.com/creachadair/command v0.2.0 h1:qTA9cMMhZePAxFoNdnk6F6nn94s1qPndIg9hJbqI9cA= +github.com/creachadair/command v0.2.0/go.mod h1:j+Ar+uYnFsHpkMeV9kGj6lJ45y9u2xqtg8FYy6cm+0o= +github.com/creachadair/flax v0.0.5 h1:zt+CRuXQASxwQ68e9GHAOnEgAU29nF0zYMHOCrL5wzE= +github.com/creachadair/flax v0.0.5/go.mod h1:F1PML0JZLXSNDMNiRGK2yjm5f+L9QCHchyHBldFymj8= +github.com/creachadair/mds v0.25.10 h1:9k9JB35D1xhOCFl0liBhagBBp8fWWkKZrA7UXsfoHtA= +github.com/creachadair/mds v0.25.10/go.mod h1:4hatI3hRM+qhzuAmqPRFvaBM8mONkS7nsLxkcuTYUIs= +github.com/creachadair/msync v0.7.1 h1:SeZmuEBXQPe5GqV/C94ER7QIZPwtvFbeQiykzt/7uho= +github.com/creachadair/msync v0.7.1/go.mod h1:8CcFlLsSujfHE5wWm19uUBLHIPDAUr6LXDwneVMO008= +github.com/creachadair/taskgroup v0.13.2 h1:3KyqakBuFsm3KkXi/9XIb0QcA8tEzLHLgaoidf0MdVc= +github.com/creachadair/taskgroup v0.13.2/go.mod h1:i3V1Zx7H8RjwljUEeUWYT30Lmb9poewSb2XI1yTwD0g= +github.com/creack/pty v1.1.23 h1:4M6+isWdcStXEf15G/RbrMPOQj1dZ7HPZCGwE4kOeP0= +github.com/creack/pty v1.1.23/go.mod h1:08sCNb52WyoAwi2QDyzUCTgcvVFhUzewun7wtTfvcwE= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/dblohm7/wingoes v0.0.0-20231025182615-65d8b4b5428f h1:c5mkOIXbHZVKGQaSEZZyLW9ORD+h4PT2TPF8IQPwyOs= -github.com/dblohm7/wingoes v0.0.0-20231025182615-65d8b4b5428f/go.mod h1:6NCrWM5jRefaG7iN0iMShPalLsljHWBh9v1zxM2f8Xs= -github.com/deckarep/golang-set/v2 v2.4.0 h1:DnfgWKdhvHM8Kihdw9fKWXd08EdsPiyoHsk5bfsmkNI= -github.com/deckarep/golang-set/v2 v2.4.0/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= -github.com/docker/cli v24.0.7+incompatible h1:wa/nIwYFW7BVTGa7SWPVyyXU9lgORqUb1xfI36MSkFg= -github.com/docker/cli v24.0.7+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= -github.com/docker/docker v24.0.7+incompatible h1:Wo6l37AuwP3JaMnZa226lzVXGA3F9Ig1seQen0cKYlM= -github.com/docker/docker v24.0.7+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= -github.com/docker/go-connections v0.4.0 h1:El9xVISelRB7BuFusrZozjnkIM5YnzCViNKohAFqRJQ= -github.com/docker/go-connections v0.4.0/go.mod h1:Gbd7IOopHjR8Iph03tsViu4nIes5XhDvyHbTtUxmeec= +github.com/dblohm7/wingoes v0.0.0-20240123200102-b75a8a7d7eb0 h1:vrC07UZcgPzu/OjWsmQKMGg3LoPSz9jh/pQXIrHjUj4= +github.com/dblohm7/wingoes v0.0.0-20240123200102-b75a8a7d7eb0/go.mod h1:Nx87SkVqTKd8UtT+xu7sM/l+LgXs6c0aHrlKusR+2EQ= +github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc h1:8WFBn63wegobsYAX0YjD+8suexZDga5CctH4CCTx2+8= +github.com/dgryski/go-metro v0.0.0-20180109044635-280f6062b5bc/go.mod h1:c9O8+fpSOX1DM8cPNSkX/qsBWdkD4yd2dpciOWQjpBw= +github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e h1:vUmf0yezR0y7jJ5pceLHthLaYf4bA5T14B6q39S4q2Q= +github.com/digitalocean/go-smbios v0.0.0-20180907143718-390a4f403a8e/go.mod h1:YTIHhz/QFSYnu/EhlF2SpU2Uk+32abacUYA5ZPljz1A= +github.com/distribution/reference v0.6.0 h1:0IXCQ5g4/QMHHkarYzh5l+u8T3t73zM5QvfrDyIgxBk= +github.com/distribution/reference v0.6.0/go.mod h1:BbU0aIcezP1/5jX/8MP0YiH4SdvB5Y4f/wlDRiLyi3E= +github.com/djherbis/times v1.6.0 h1:w2ctJ92J8fBvWPxugmXIv7Nz7Q3iDMKNx9v5ocVH20c= +github.com/djherbis/times v1.6.0/go.mod h1:gOHeRAz2h+VJNZ5Gmc/o7iD9k4wW7NMVqieYCY99oc0= +github.com/docker/cli 
v28.5.1+incompatible h1:ESutzBALAD6qyCLqbQSEf1a/U8Ybms5agw59yGVc+yY= +github.com/docker/cli v28.5.1+incompatible/go.mod h1:JLrzqnKDaYBop7H2jaqPtU4hHvMKP+vjCwu2uszcLI8= +github.com/docker/docker v28.5.2+incompatible h1:DBX0Y0zAjZbSrm1uzOkdr1onVghKaftjlSWt4AFexzM= +github.com/docker/docker v28.5.2+incompatible/go.mod h1:eEKB0N0r5NX/I1kEveEz05bcu8tLC/8azJZsviup8Sk= +github.com/docker/go-connections v0.6.0 h1:LlMG9azAe1TqfR7sO+NJttz1gy6KO7VJBh+pMmjSD94= +github.com/docker/go-connections v0.6.0/go.mod h1:AahvXYshr6JgfUJGdDCs2b5EZG/vmaMAntpSFH5BFKE= github.com/docker/go-units v0.5.0 h1:69rxXcBk27SvSaaxTtLh/8llcHD8vYHT7WSdRZ/jvr4= github.com/docker/go-units v0.5.0/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= -github.com/efekarakus/termcolor v1.0.1 h1:YAKFO3bnLrqZGTWyNLcYoSIAQFKVOmbqmDnwsU/znzg= -github.com/efekarakus/termcolor v1.0.1/go.mod h1:AitrZNrE4nPO538fRsqf+p0WgLdAsGN5pUNrHEPsEMM= -github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= -github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= -github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/felixge/fgprof v0.9.3 h1:VvyZxILNuCiUCSXtPtYmmtGvb65nqXh2QFWc0Wpf2/g= github.com/felixge/fgprof v0.9.3/go.mod h1:RdbpDgzqYVh/T9fPELJyV7EYJuHB55UTEULNun8eiPw= +github.com/felixge/fgprof v0.9.5 h1:8+vR6yu2vvSKn08urWyEuxx75NWPEvybbkBirEpsbVY= +github.com/felixge/fgprof v0.9.5/go.mod h1:yKl+ERSa++RYOs32d8K6WEXCB4uXdLls4ZaZPpayhMM= +github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= +github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= -github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= -github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= -github.com/fxamacker/cbor/v2 v2.5.0 h1:oHsG0V/Q6E/wqTS2O1Cozzsy69nqCiguo5Q1a1ADivE= -github.com/fxamacker/cbor/v2 v2.5.0/go.mod h1:TA1xS00nchWmaBnEIxPSE5oHLuJBAVvqrtAnWBwBCVo= -github.com/glebarez/go-sqlite v1.21.2 h1:3a6LFC4sKahUunAmynQKLZceZCOzUthkRkEAl9gAXWo= -github.com/glebarez/go-sqlite v1.21.2/go.mod h1:sfxdZyhQjTM2Wry3gVYWaW072Ri1WMdWJi0k6+3382k= -github.com/glebarez/sqlite v1.10.0 h1:u4gt8y7OND/cCei/NMHmfbLxF6xP2wgKcT/BJf2pYkc= -github.com/glebarez/sqlite v1.10.0/go.mod h1:IJ+lfSOmiekhQsFTJRx/lHtGYmCdtAiTaf5wI9u5uHA= -github.com/go-gormigrate/gormigrate/v2 v2.1.1 h1:eGS0WTFRV30r103lU8JNXY27KbviRnqqIDobW3EV3iY= -github.com/go-gormigrate/gormigrate/v2 v2.1.1/go.mod h1:L7nJ620PFDKei9QOhJzqA8kRCk+E3UbV2f5gv+1ndLc= -github.com/go-jose/go-jose/v3 v3.0.1 h1:pWmKFVtt+Jl0vBZTIpz/eAKwsm6LkIxDVVbFHKkchhA= -github.com/go-jose/go-jose/v3 v3.0.1/go.mod h1:RNkWWRld676jZEYoV3+XK8L2ZnNSvIsxFMht0mSX+u8= -github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= -github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= -github.com/go-sql-driver/mysql v1.6.0 
h1:BCTh4TKNUYmOmMUcQ3IipzF5prigylS7XXjEkfCHuOE= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= -github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/fsnotify/fsnotify v1.9.0 h1:2Ml+OJNzbYCTzsxtv8vKSFD9PbJjmhYF14k/jKC7S9k= +github.com/fsnotify/fsnotify v1.9.0/go.mod h1:8jBTzvmWwFyi3Pb8djgCCO5IBqzKJ/Jwo8TRcHyHii0= +github.com/fxamacker/cbor/v2 v2.9.0 h1:NpKPmjDBgUfBms6tr6JZkTHtfFGcMKsw3eGcmD/sapM= +github.com/fxamacker/cbor/v2 v2.9.0/go.mod h1:vM4b+DJCtHn+zz7h3FFp/hDAI9WNWCsZj23V5ytsSxQ= +github.com/gaissmai/bart v0.18.0 h1:jQLBT/RduJu0pv/tLwXE+xKPgtWJejbxuXAR+wLJafo= +github.com/gaissmai/bart v0.18.0/go.mod h1:JJzMAhNF5Rjo4SF4jWBrANuJfqY+FvsFhW7t1UZJ+XY= +github.com/github/fakeca v0.1.0 h1:Km/MVOFvclqxPM9dZBC4+QE564nU4gz4iZ0D9pMw28I= +github.com/github/fakeca v0.1.0/go.mod h1:+bormgoGMMuamOscx7N91aOuUST7wdaJ2rNjeohylyo= +github.com/glebarez/go-sqlite v1.22.0 h1:uAcMJhaA6r3LHMTFgP0SifzgXg46yJkgxqyuyec+ruQ= +github.com/glebarez/go-sqlite v1.22.0/go.mod h1:PlBIdHe0+aUEFn+r2/uthrWq4FxbzugL0L8Li6yQJbc= +github.com/glebarez/sqlite v1.11.0 h1:wSG0irqzP6VurnMEpFGer5Li19RpIRi2qvQz++w0GMw= +github.com/glebarez/sqlite v1.11.0/go.mod h1:h8/o8j5wiAsqSPoWELDUdJXhjAhsVliSn7bWZjOhrgQ= +github.com/go-gormigrate/gormigrate/v2 v2.1.5 h1:1OyorA5LtdQw12cyJDEHuTrEV3GiXiIhS4/QTTa/SM8= +github.com/go-gormigrate/gormigrate/v2 v2.1.5/go.mod h1:mj9ekk/7CPF3VjopaFvWKN2v7fN3D9d3eEOAXRhi/+M= +github.com/go-jose/go-jose/v3 v3.0.4 h1:Wp5HA7bLQcKnf6YYao/4kpRpVMp/yf6+pJKV8WFSaNY= +github.com/go-jose/go-jose/v3 v3.0.4/go.mod h1:5b+7YgP7ZICgJDBdfjZaIt+H/9L9T/YQrVfLAMboGkQ= +github.com/go-jose/go-jose/v4 v4.1.3 h1:CVLmWDhDVRa6Mi/IgCgaopNosCaHz7zrMeF9MlZRkrs= +github.com/go-jose/go-jose/v4 v4.1.3/go.mod h1:x4oUasVrzR7071A4TnHLGSPpNOm2a21K9Kf04k1rs08= +github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced h1:Q311OHjMh/u5E2TITc++WlTP5We0xNseRMkHDyvhW7I= +github.com/go-json-experiment/json v0.0.0-20250813024750-ebf49471dced/go.mod h1:TiCD2a1pcmjd7YnhGH0f/zKNcCD06B029pHhzV23c2M= +github.com/go-logr/logr v1.2.2/go.mod h1:jdQByPbusPIv2/zmleS9BjJVeZ6kBagPoEUsqbVz/1A= +github.com/go-logr/logr v1.4.3 h1:CjnDlHq8ikf6E492q6eKboGOC0T8CDaOvkHCIg8idEI= +github.com/go-logr/logr v1.4.3/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/go-ole/go-ole v1.3.0 h1:Dt6ye7+vXGIKZ7Xtk4s6/xVdGDQynvom7xCFEdWr6uE= +github.com/go-ole/go-ole v1.3.0/go.mod h1:5LS6F96DhAwUc7C+1HLexzMXY1xGRSryjyPPKW6zv78= +github.com/go-sql-driver/mysql v1.8.1 h1:LedoTUt/eveggdHS9qUFC1EFSa8bU2+1pZjSRpvNJ1Y= +github.com/go-sql-driver/mysql v1.8.1/go.mod h1:wEBSXgmK//2ZFJyE+qWnIsVGmvmEKlqwuVSjsCm7DZg= +github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= +github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737 h1:cf60tHxREO3g1nroKr2osU3JWZsJzkfi7rEg+oAB0Lo= +github.com/go4org/plan9netshell v0.0.0-20250324183649-788daa080737/go.mod h1:MIS0jDzbU/vuM9MC4YnBITCv+RYuTRq8dJzmCrFsK9g= +github.com/gobwas/httphead v0.1.0/go.mod h1:O/RXo79gxV8G+RqlR/otEwx4Q36zl9rqC5u12GKvMCM= +github.com/gobwas/pool v0.2.1/go.mod h1:q8bcK0KcYlCgd9e7WYLm9LpyS+YeLd8JVDW6WezmKEw= +github.com/gobwas/ws v1.2.1/go.mod h1:hRKAFb8wOxFROYNsT1bqfWnhX+b5MFeJM9r2ZSwg/KY= 
github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= -github.com/gofrs/uuid/v5 v5.0.0 h1:p544++a97kEL+svbcFbCQVM9KFu0Yo25UoISXGNNH9M= -github.com/gofrs/uuid/v5 v5.0.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= -github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q= -github.com/gogo/protobuf v1.3.2/go.mod h1:P1XiOD3dCwIKUDQYPy72D8LYyHL2YPYrpS2s69NZV8Q= -github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= -github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= -github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da h1:oI5xCqsCo564l8iNU+DwB5epxmsaqB+rhGL0m5jtYqE= -github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= -github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= -github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= -github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= -github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= -github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= -github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= -github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466 h1:sQspH8M4niEijh3PFscJRLDnkL547IeP7kpPe3uUhEg= +github.com/godbus/dbus/v5 v5.1.1-0.20230522191255-76236955d466/go.mod h1:ZiQxhyQ+bbbfxUKVvjfO498oPYvtYhZzycal3G/NHmU= +github.com/gofrs/uuid/v5 v5.4.0 h1:EfbpCTjqMuGyq5ZJwxqzn3Cbr2d0rUZU7v5ycAk/e/0= +github.com/gofrs/uuid/v5 v5.4.0/go.mod h1:CDOjlDMVAtN56jqyRUZh58JT31Tiw7/oQyEXZV+9bD8= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8 h1:f+oWsMOmNPc8JmEHVZIycC7hBoQxHH9pNKQORJNozsQ= +github.com/golang/groupcache v0.0.0-20241129210726-2c02b8208cf8/go.mod h1:wcDNUvekVysuuOpQKo3191zZyTpiI6se1N1ULghS0sw= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= +github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= +github.com/google/btree v1.1.3/go.mod h1:qOPhT0dTNdNzV6Z/lhRX0YXUafgPLFUh+gZMl761Gm4= github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.4/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= -github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/google/go-cmp 
v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= +github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/go-github v17.0.0+incompatible h1:N0LgJ1j65A7kfXrZnUDaYCs/Sf4rEjNlfyDHW9dolSY= github.com/google/go-github v17.0.0+incompatible/go.mod h1:zLgOLi98H3fifZn+44m+umXrS52loVEgC2AApnigrVQ= github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8= github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU= -github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c h1:06RMfw+TMMHtRuUOroMeatRCCgSMWXCJQeABvHU69YQ= -github.com/google/nftables v0.1.1-0.20230115205135-9aa6fdf5a28c/go.mod h1:BVIYo3cdnT4qSylnYqcd5YtmXhr51cJPGtnLBe/uLBU= +github.com/google/go-tpm v0.9.4 h1:awZRf9FwOeTunQmHoDYSHJps3ie6f1UlhS1fOdPEt1I= +github.com/google/go-tpm v0.9.4/go.mod h1:h9jEsEECg7gtLis0upRBQU+GhYVH6jMjrFxI8u6bVUY= +github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806 h1:wG8RYIyctLhdFk6Vl1yPGtSRtwGpVkWyZww1OCil2MI= +github.com/google/nftables v0.2.1-0.20240414091927-5e242ec57806/go.mod h1:Beg6V6zZ3oEn0JuiUQ4wqwuyqqzasOltcoXPtgLbFp4= github.com/google/pprof v0.0.0-20211214055906-6f57359322fd/go.mod h1:KgnwoLYCZ8IQu3XUZ8Nc/bM9CCZFOyjUNOSygVozoDg= -github.com/google/pprof v0.0.0-20231127191134-f3a68a39ae15 h1:t2sLhFuGXwoomaKLTuoxFfFqqlG1Gp2DpsupXq3UvZ0= -github.com/google/pprof v0.0.0-20231127191134-f3a68a39ae15/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= +github.com/google/pprof v0.0.0-20240227163752-401108e1b7e7/go.mod h1:czg5+yv1E0ZGTi6S6vVK1mke0fV+FaUhNGcd6VRS9Ik= +github.com/google/pprof v0.0.0-20251007162407-5df77e3f7d1d h1:KJIErDwbSHjnp/SGzE5ed8Aol7JsKiI5X7yWKAtzhM0= +github.com/google/pprof v0.0.0-20251007162407-5df77e3f7d1d/go.mod h1:I6V7YzU0XDpsHqbsyrghnFZLO1gwK6NPTNvmetQIk9U= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 h1:El6M4kTTCOh6aBiKaUGG7oYTSPP8MxqL4YI3kZKwcP4= github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510/go.mod h1:pupxD2MaaD3pAXIBCelhxNneeOaAeabZDe5s4K6zSpQ= -github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= -github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gookit/assert v0.1.1 h1:lh3GcawXe/p+cU7ESTZ5Ui3Sm/x8JWpIis4/1aF0mY0= +github.com/gookit/assert v0.1.1/go.mod h1:jS5bmIVQZTIwk42uXl4lyj4iaaxx32tqH16CFj0VX2E= github.com/gookit/color v1.4.2/go.mod h1:fqRyamkC1W8uxl+lxCQxOT09l/vYfZ+QeiX3rKQHCoQ= github.com/gookit/color v1.5.0/go.mod h1:43aQb+Zerm/BWh2GnrgOQm7ffz7tvQXEKV6BFMl7wAo= -github.com/gookit/color v1.5.4 h1:FZmqs7XOyGgCAxmWyPslpiok1k05wmY3SJTytgvYFs0= -github.com/gookit/color v1.5.4/go.mod h1:pZJOeOS8DM43rXbp4AZo1n9zCU2qjpcRko0b6/QJi9w= +github.com/gookit/color v1.6.0 h1:JjJXBTk1ETNyqyilJhkTXJYYigHG24TM9Xa2M1xAhRA= +github.com/gookit/color v1.6.0/go.mod h1:9ACFc7/1IpHGBW8RwuDm/0YEnhg3dwwXpoMsmtyHfjs= github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= -github.com/grpc-ecosystem/go-grpc-middleware v1.4.0 h1:UH//fgunKIs4JdUbpDl1VZCDaL56wXCB/5+wF6uHfaI= -github.com/grpc-ecosystem/go-grpc-middleware v1.4.0/go.mod h1:g5qyo/la0ALbONm6Vbp88Yd8NsDy6rZz+RcrMPxvld8= -github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 h1:6UKoz5ujsI55KNpsJH3UwCq3T8kKbZwNZBNPuTTje8U= 
-github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1/go.mod h1:YvJ2f6MplWDhfxiUC3KpyTy76kYUZA4W3pTv/wdKQ9Y= -github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek= -github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= -github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= -github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= -github.com/hdevalence/ed25519consensus v0.1.0 h1:jtBwzzcHuTmFrQN6xQZn6CQEO/V9f7HsjsjeEZ6auqU= -github.com/hdevalence/ed25519consensus v0.1.0/go.mod h1:w3BHWjwJbFU29IRHL1Iqkw3sus+7FctEyM4RqDxYNzo= -github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec h1:qv2VnGeEQHchGaZ/u7lxST/RaJw+cv273q79D81Xbog= -github.com/hinshun/vt10x v0.0.0-20220119200601-820417d04eec/go.mod h1:Q48J4R4DvxnHolD5P8pOtXigYlRuPLGl6moFx3ulM68= +github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674 h1:JeSE6pjso5THxAzdVpqr6/geYxZytqFMBCOtn/ujyeo= +github.com/gorilla/websocket v1.5.4-0.20250319132907-e064f32e3674/go.mod h1:r4w70xmWCQKmi1ONH4KIaBptdivuRPyosB9RmPlGEwA= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.4 h1:kEISI/Gx67NzH3nJxAmY/dGac80kKZgZt134u7Y/k1s= +github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.4/go.mod h1:6Nz966r3vQYCqIzWsuEl9d7cf7mRhtDmm++sOxlnfxI= +github.com/hashicorp/go-version v1.7.0 h1:5tqGy27NaOTB8yJKUZELlFAS/LTKJkrmONwQKeRZfjY= +github.com/hashicorp/go-version v1.7.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/hashicorp/golang-lru v0.6.0 h1:uL2shRDx7RTrOrTCUZEGP/wJUFiUI8QT6E7z5o8jga4= +github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= +github.com/hashicorp/golang-lru/v2 v2.0.7/go.mod h1:QeFd9opnmA6QUJc5vARoKUSoFhyfM2/ZepoAG6RGpeM= +github.com/hdevalence/ed25519consensus v0.2.0 h1:37ICyZqdyj0lAZ8P4D1d1id3HqbbG1N3iBb1Tb4rdcU= +github.com/hdevalence/ed25519consensus v0.2.0/go.mod h1:w3BHWjwJbFU29IRHL1Iqkw3sus+7FctEyM4RqDxYNzo= +github.com/huin/goupnp v1.3.0 h1:UvLUlWDNpoUdYzb2TCn+MuTWtcjXKSza2n6CBdQ0xXc= +github.com/huin/goupnp v1.3.0/go.mod h1:gnGPsThkYa7bFi/KWmEysQRf48l2dvR5bxr2OFckNX8= github.com/ianlancetaylor/demangle v0.0.0-20210905161508-09a460cdf81d/go.mod h1:aYm2/VgdVmcIU8iMfdMvDMsRAQjcfZSKFby6HOFvi/w= -github.com/imdario/mergo v0.3.16 h1:wwQJbIsHYGMUyLSPrEq1CT16AhnhNJQ51+4fdHUnCl4= -github.com/imdario/mergo v0.3.16/go.mod h1:WBLT9ZmE3lPoWsEzCh9LPo3TiwVN+ZKEjmz+hD27ysY= +github.com/ianlancetaylor/demangle v0.0.0-20230524184225-eabc099b10ab/go.mod h1:gx7rwoVhcfuVKG5uya9Hs3Sxj7EIvldVofAWIUtGouw= +github.com/illarion/gonotify/v3 v3.0.2 h1:O7S6vcopHexutmpObkeWsnzMJt/r1hONIEogeVNmJMk= +github.com/illarion/gonotify/v3 v3.0.2/go.mod h1:HWGPdPe817GfvY3w7cx6zkbzNZfi3QjcBm/wgVvEL1U= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/insomniacslk/dhcp v0.0.0-20240129002554-15c9b8791914 h1:kD8PseueGeYiid/Mmcv17Q0Qqicc4F46jcX22L/e/Hs= +github.com/insomniacslk/dhcp v0.0.0-20240129002554-15c9b8791914/go.mod h1:3A9PQ1cunSDF/1rbTq99Ts4pVnycWg+vlPkfeD2NLFI= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= -github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod 
h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= -github.com/jackc/pgx/v5 v5.5.0 h1:NxstgwndsTRy7eq9/kqYc/BZh5w2hHJV86wjvO+1xPw= -github.com/jackc/pgx/v5 v5.5.0/go.mod h1:Ig06C2Vu0t5qXC60W8sqIthScaEnFvojjj9dSljmHRA= -github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= -github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.7.6 h1:rWQc5FwZSPX58r1OQmkuaNicxdmExaEz5A2DO2hUuTk= +github.com/jackc/pgx/v5 v5.7.6/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jagottsicher/termcolor v1.0.2 h1:fo0c51pQSuLBN1+yVX2ZE+hE+P7ULb/TY8eRowJnrsM= +github.com/jagottsicher/termcolor v1.0.2/go.mod h1:RcH8uFwF/0wbEdQmi83rjmlJ+QOKdMSE9Rc1BEB7zFo= +github.com/jellydator/ttlcache/v3 v3.1.0 h1:0gPFG0IHHP6xyUyXq+JaD8fwkDCqgqwohXNJBcYE71g= +github.com/jellydator/ttlcache/v3 v3.1.0/go.mod h1:hi7MGFdMAwZna5n2tuvh63DvFLzVKySzCVW6+0gA2n4= github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= -github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86 h1:elKwZS1OcdQ0WwEDBeqxKwb7WB62QX8bvZ/FJnVXIfk= -github.com/josharian/native v1.1.1-0.20230202152459-5c7d0dd6ab86/go.mod h1:aFAMtuldEgx/4q7iSGazk22+IcgvtiC+HIimFO9XlS8= -github.com/jsimonetti/rtnetlink v1.4.0 h1:Z1BF0fRgcETPEa0Kt0MRk3yV5+kF1FWTni6KUFKrq2I= -github.com/jsimonetti/rtnetlink v1.4.0/go.mod h1:5W1jDvWdnthFJ7fxYX1GMK07BUpI4oskfOqvPteYS6E= -github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 h1:Z9n2FFNUXsshfwJMBgNA0RU6/i7WVaAegv3PtuIHPMs= -github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51/go.mod h1:CzGEWj7cYgsdH8dAjBGEr58BoE7ScuLd+fwFZ44+/x8= -github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= -github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= -github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA= -github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= +github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= +github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo= +github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y= +github.com/jsimonetti/rtnetlink v1.4.1 h1:JfD4jthWBqZMEffc5RjgmlzpYttAVw1sdnmiNaPO3hE= +github.com/jsimonetti/rtnetlink v1.4.1/go.mod h1:xJjT7t59UIZ62GLZbv6PLLo8VFrostJMPBAheR6OM8w= +github.com/klauspost/compress v1.18.2 h1:iiPHWW0YrcFgpBYhsA6D1+fqHssJscY/Tm/y2Uqnapk= +github.com/klauspost/compress v1.18.2/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/klauspost/cpuid/v2 v2.0.10/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 
v2.0.12/go.mod h1:g2LTdtYhdyuGPqyWyv7qRAmj1WBqxuObKfj5c0PQa7c= github.com/klauspost/cpuid/v2 v2.2.3 h1:sxCkb+qR91z4vsqw4vGGZlDgPz3G7gjaLyK3V8y70BU= github.com/klauspost/cpuid/v2 v2.2.3/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= -github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a h1:+RR6SqnTkDLWyICxS1xpjCi/3dhyV+TgZwA6Ww3KncQ= +github.com/kortschak/wol v0.0.0-20200729010619-da482cc4850a/go.mod h1:YTtCCM3ryyfiu4F7t8HQ1mxvp1UBdWM2r6Xa+nGWvDk= +github.com/kr/fs v0.1.0 h1:Jskdu9ieNAYnjxsi0LbQp1ulIKZV1LAFgK1tWhpZgl8= +github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/lib/pq v1.10.7 h1:p7ZhMD+KsSRozJr34udlUrhboJwWAgCg34+/ZZNvZZw= -github.com/lib/pq v1.10.7/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= +github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= +github.com/ledongthuc/pdf v0.0.0-20220302134840-0c2507a12d80/go.mod h1:imJHygn/1yfhB7XSJJKlFZKl/J+dCPAknuiaGOshXAs= +github.com/lib/pq v1.8.0/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lithammer/fuzzysearch v1.1.8 h1:/HIuJnjHuXS8bKaiTMeeDlW2/AyIWk2brx1V8LFgLN4= github.com/lithammer/fuzzysearch v1.1.8/go.mod h1:IdqeyBClc3FFqSzYq/MXESsS4S0FsZ5ajtkr5xPLts4= -github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= -github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= -github.com/mattn/go-colorable v0.1.2/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= -github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= +github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc= github.com/mattn/go-colorable v0.1.13/go.mod h1:7S9/ev0klgBDR4GtXTXX8a3vIGJpMovkB8vQcUbaXHg= -github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= -github.com/mattn/go-isatty v0.0.10/go.mod h1:qgIWMr58cqv1PHHyhnkY9lrL7etaEgOFcMEpPG5Rm84= +github.com/mattn/go-colorable v0.1.14 h1:9A9LHSqF/7dyVVX6g0U9cwm9pG3kP9gSzcuIPHPsaIE= +github.com/mattn/go-colorable v0.1.14/go.mod h1:6LmQG8QLFO4G5z1gPvYEzlUgJ2wF+stgPZH1UqBm1s8= github.com/mattn/go-isatty v0.0.16/go.mod h1:kYGgaQfpe5nmfYZH+SKPsOc2e4SrIfOl2e/yFXSvRLM= github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/mattn/go-runewidth 
v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/mattn/go-runewidth v0.0.15 h1:UNAjwbU9l54TA3KzvqLGxwWjHmMgBUVhBiTjelZgg3U= -github.com/mattn/go-runewidth v0.0.15/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= -github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= -github.com/mdlayher/netlink v1.7.2 h1:/UtM3ofJap7Vl4QWCPDGXY8d3GIY2UGSDbK+QWmY8/g= -github.com/mdlayher/netlink v1.7.2/go.mod h1:xraEF7uJbxLhc5fpHL4cPe221LI2bdttWlU+ZGLfQSw= +github.com/mattn/go-runewidth v0.0.19 h1:v++JhqYnZuu5jSKrk9RbgF5v4CGUjqRfBm05byFGLdw= +github.com/mattn/go-runewidth v0.0.19/go.mod h1:XBkDxAl56ILZc9knddidhrOlY5R/pDhgLpndooCuJAs= +github.com/mdlayher/genetlink v1.3.2 h1:KdrNKe+CTu+IbZnm/GVUMXSqBBLqcGpRDa0xkQy56gw= +github.com/mdlayher/genetlink v1.3.2/go.mod h1:tcC3pkCrPUGIKKsCsp0B3AdaaKuHtaxoJRz3cc+528o= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0tpKMbdCXCIfBclan4oHk1Jb+Hrejirg= +github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= +github.com/mdlayher/sdnotify v1.0.0 h1:Ma9XeLVN/l0qpyx1tNeMSeTjCPH6NtuD6/N9XdTlQ3c= +github.com/mdlayher/sdnotify v1.0.0/go.mod h1:HQUmpM4XgYkhDLtd+Uad8ZFK1T9D5+pNxnXQjCeJlGE= github.com/mdlayher/socket v0.5.0 h1:ilICZmJcQz70vrWVes1MFera4jGiWNocSkykwwoy3XI= github.com/mdlayher/socket v0.5.0/go.mod h1:WkcBFfvyG8QENs5+hfQPl1X6Jpd2yeLIYgrGFmJiJxI= -github.com/mgutz/ansi v0.0.0-20170206155736-9520e82c474b/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= -github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d h1:5PJl274Y63IEHC+7izoQE9x6ikvDFZS2mDVS3drnohI= -github.com/mgutz/ansi v0.0.0-20200706080929-d51e80ef957d/go.mod h1:01TrycV0kFyexm33Z7vhZRXopbI8J3TDReVlkTgMUxE= -github.com/miekg/dns v1.1.57 h1:Jzi7ApEIzwEPLHWRcafCN9LZSBbqQpxjt/wpgvg7wcM= -github.com/miekg/dns v1.1.57/go.mod h1:uqRjCRUuEAA6qsOiJvDd+CFo/vW+y5WR6SNmHE55hZk= +github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4= +github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY= github.com/mitchellh/go-ps v1.0.0 h1:i6ampVEEF4wQFF+bkYfwYgY+F/uYJDktmvLPf7qIgjc= github.com/mitchellh/go-ps v1.0.0/go.mod h1:J4lOc8z8yJs6vUwklHw2XEIiT4z4C40KtWVN3nvg8Pg= -github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= -github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= -github.com/moby/term v0.5.0 h1:xt8Q1nalod/v7BqbG21f8mQPqH+xAaC9C3N3wfWbVP0= -github.com/moby/term v0.5.0/go.mod h1:8FzsFHVUBGZdbDsJw/ot+X+d5HLUbvklYLJ9uGfcI3Y= -github.com/oauth2-proxy/mockoidc v0.0.0-20220308204021-b9169deeb282 h1:TQMyrpijtkFyXpNI3rY5hsZQZw+paiH+BfAlsb81HBY= -github.com/oauth2-proxy/mockoidc v0.0.0-20220308204021-b9169deeb282/go.mod h1:rW25Kyd08Wdn3UVn0YBsDTSvReu0jqpmJKzxITPSjks= +github.com/moby/docker-image-spec v1.3.1 h1:jMKff3w6PgbfSa69GfNg+zN/XLhfXJGnEx3Nl2EsFP0= +github.com/moby/docker-image-spec v1.3.1/go.mod h1:eKmb5VW8vQEh/BAr2yvVNvuiJuY6UIocYsFu/DxxRpo= +github.com/moby/sys/atomicwriter v0.1.0 h1:kw5D/EqkBwsBFi0ss9v1VG3wIkVhzGvLklJ+w3A14Sw= +github.com/moby/sys/atomicwriter v0.1.0/go.mod h1:Ul8oqv2ZMNHOceF643P6FKPXeCmYtlQMvpizfsSoaWs= +github.com/moby/sys/sequential v0.6.0 h1:qrx7XFUd/5DxtqcoH1h438hF5TmOvzC/lspjy7zgvCU= +github.com/moby/sys/sequential v0.6.0/go.mod 
h1:uyv8EUTrca5PnDsdMGXhZe6CCe8U/UiTWd+lL+7b/Ko= +github.com/moby/sys/user v0.4.0 h1:jhcMKit7SA80hivmFJcbB1vqmw//wU61Zdui2eQXuMs= +github.com/moby/sys/user v0.4.0/go.mod h1:bG+tYYYJgaMtRKgEmuueC0hJEAZWwtIbZTB+85uoHjs= +github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ= +github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc= +github.com/morikuni/aec v1.0.0 h1:nP9CBfwrvYnBRgY6qfDQkygYDmYwOilePFkwzv4dU8A= +github.com/morikuni/aec v1.0.0/go.mod h1:BbKIizmSmc5MMPqRYbxO4ZU0S0+P200+tUnFx7PXmsc= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq1c1nUAm88MOHcQC9l5mIlSMApZMrHA= +github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ= +github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w= +github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ= +github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8= +github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25 h1:9bCMuD3TcnjeqjPT2gSlha4asp8NvgcFRYExCaikCxk= +github.com/oauth2-proxy/mockoidc v0.0.0-20240214162133-caebfff84d25/go.mod h1:eDjgYHYDJbPLBLsyZ6qRaugP0mX8vePOhZ5id1fdzJw= github.com/opencontainers/go-digest v1.0.0 h1:apOUWs51W5PlhuyGyz9FCeeBIOUDA/6nW8Oi/yOhh5U= github.com/opencontainers/go-digest v1.0.0/go.mod h1:0JzlMkj0TRzQZfJkVvzbP0HBR3IKzErnv2BNG4W4MAM= -github.com/opencontainers/image-spec v1.1.0-rc5 h1:Ygwkfw9bpDvs+c9E34SdgGOj41dX/cbdlwvlWt0pnFI= -github.com/opencontainers/image-spec v1.1.0-rc5/go.mod h1:X4pATf0uXsnn3g5aiGIsVnJBR4mxhKzfwmvK/B2NTm8= -github.com/opencontainers/runc v1.1.10 h1:EaL5WeO9lv9wmS6SASjszOeQdSctvpbu0DdBQBizE40= -github.com/opencontainers/runc v1.1.10/go.mod h1:+/R6+KmDlh+hOO8NkjmgkG9Qzvypzk0yXxAPYYR65+M= -github.com/opentracing/opentracing-go v1.1.0/go.mod h1:UkNAQd3GIcIGf0SeVgPpRdFStlNbqXla1AfSYxPUl2o= -github.com/ory/dockertest/v3 v3.10.0 h1:4K3z2VMe8Woe++invjaTB7VRyQXQy5UY+loujO4aNE4= -github.com/ory/dockertest/v3 v3.10.0/go.mod h1:nr57ZbRWMqfsdGdFNLHz5jjNdDb7VVFnzAeW1n5N1Lg= -github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc= -github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ= -github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= -github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/opencontainers/image-spec v1.1.1 h1:y0fUlFfIZhPF1W537XOLg0/fcx6zcHCJwooC2xJA040= +github.com/opencontainers/image-spec v1.1.1/go.mod h1:qpqAh3Dmcf36wStyyWU+kCeDgrGnAve2nCC8+7h8Q0M= +github.com/opencontainers/runc v1.3.2 h1:GUwgo0Fx9M/pl2utaSYlJfdBcXAB/CZXDxe322lvJ3Y= +github.com/opencontainers/runc v1.3.2/go.mod h1:F7UQQEsxcjUNnFpT1qPLHZBKYP7yWwk6hq8suLy9cl0= +github.com/orisano/pixelmatch v0.0.0-20220722002657-fb0b55479cde/go.mod h1:nZgzbfBr3hhjoZnS66nKrHmduYNpc34ny7RK4z5/HM0= +github.com/ory/dockertest/v3 v3.12.0 h1:3oV9d0sDzlSQfHtIaB5k6ghUCVMVLpAY8hwrqoCyRCw= +github.com/ory/dockertest/v3 v3.12.0/go.mod h1:aKNDTva3cp8dwOWwb9cWuX84aH5akkxXRvO7KCwWVjE= +github.com/pelletier/go-toml/v2 v2.2.4 h1:mye9XuhQ6gvn5h28+VilKrrPoQVanw5PMw/TB0t5Ec4= +github.com/pelletier/go-toml/v2 v2.2.4/go.mod h1:2gIqNv+qfxSVS7cM2xJQKtLSTLUE9V8t9Stt+h56mCY= +github.com/petermattis/goid 
v0.0.0-20250813065127-a731cc31b4fe/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= +github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490 h1:QTvNkZ5ylY0PGgA+Lih+GdboMLY/G9SEGLMEGVjTVA4= +github.com/petermattis/goid v0.0.0-20250904145737-900bdf8bb490/go.mod h1:pxMtw7cyUw6B2bRH0ZBANSPg+AoSud1I1iyJHI69jH4= github.com/philip-bui/grpc-zerolog v1.0.1 h1:EMacvLRUd2O1K0eWod27ZP5CY1iTNkhBDLSN+Q4JEvA= github.com/philip-bui/grpc-zerolog v1.0.1/go.mod h1:qXbiq/2X4ZUMMshsqlWyTHOcw7ns+GZmlqZZN05ZHcQ= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= -github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pierrec/lz4/v4 v4.1.21 h1:yOVMLb6qSIDP67pl/5F7RepeKYu/VmTyEXvuMI5d9mQ= +github.com/pierrec/lz4/v4 v4.1.21/go.mod h1:gZWDp/Ze/IJXGXf23ltt2EXimqmTUXEy0GFuRQyBid4= +github.com/pires/go-proxyproto v0.8.1 h1:9KEixbdJfhrbtjpz/ZwCdWDD2Xem0NZ38qMYaASJgp0= +github.com/pires/go-proxyproto v0.8.1/go.mod h1:ZKAAyp3cgy5Y5Mo4n9AlScrkCZwUy0g3Jf+slqQVcuU= github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/profile v1.7.0 h1:hnbDkaNWPCLMO9wGLdBFTIZvzDrDfBM2072E1S9gJkA= github.com/pkg/profile v1.7.0/go.mod h1:8Uer0jas47ZQMJ7VD+OHknK4YDY07LPUC6dEvqDjvNo= +github.com/pkg/sftp v1.13.6 h1:JFZT4XbOU7l77xGSpOdW+pwIMqP044IyjXX6FGyEKFo= +github.com/pkg/sftp v1.13.6/go.mod h1:tz1ryNURKu77RL+GuCzmoJYxQczL3wLNNpPWagdg4Qk= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q= -github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY= -github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= -github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw= -github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI= -github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM= -github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY= -github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo= -github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo= +github.com/prometheus-community/pro-bing v0.4.0 h1:YMbv+i08gQz97OZZBwLyvmmQEEzyfyrrjEaAchdy3R4= +github.com/prometheus-community/pro-bing v0.4.0/go.mod h1:b7wRYZtCcPmt4Sz319BykUU241rWLe1VFXyiyWK/dH4= +github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= +github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= +github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= +github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= +github.com/prometheus/common v0.67.5 h1:pIgK94WWlQt1WLwAC5j2ynLaBRDiinoAb86HZHTUGI4= +github.com/prometheus/common v0.67.5/go.mod h1:SjE/0MzDEEAyrdr5Gqc6G+sXI67maCxzaT3A2+HqjUw= +github.com/prometheus/procfs v0.16.1 
h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= +github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= github.com/pterm/pterm v0.12.27/go.mod h1:PhQ89w4i95rhgE+xedAoqous6K9X+r6aSOI2eFF7DZI= github.com/pterm/pterm v0.12.29/go.mod h1:WI3qxgvoQFFGKGjGnJR849gU0TsEOvKn5Q8LlY1U7lg= github.com/pterm/pterm v0.12.30/go.mod h1:MOqLIyMOgmTDz9yorcYbcw+HsgoZo3BQfg2wtl3HEFE= @@ -284,75 +399,95 @@ github.com/pterm/pterm v0.12.31/go.mod h1:32ZAWZVXD7ZfG0s8qqHXePte42kdz8ECtRyEej github.com/pterm/pterm v0.12.33/go.mod h1:x+h2uL+n7CP/rel9+bImHD5lF3nM9vJj80k9ybiiTTE= github.com/pterm/pterm v0.12.36/go.mod h1:NjiL09hFhT/vWjQHSj1athJpx6H8cjpHXNAK5bUw8T8= github.com/pterm/pterm v0.12.40/go.mod h1:ffwPLwlbXxP+rxT0GsgDTzS3y3rmpAO1NMjUkGTYf8s= -github.com/pterm/pterm v0.12.71 h1:KcEJ98EiVCbzDkFbktJ2gMlr4pn8IzyGb9bwK6ffkuA= -github.com/pterm/pterm v0.12.71/go.mod h1:SUAcoZjRt+yjPWlWba+/Fd8zJJ2lSXBQWf0Z0HbFiIQ= -github.com/puzpuzpuz/xsync/v3 v3.0.2 h1:3yESHrRFYr6xzkz61LLkvNiPFXxJEAABanTQpKbAaew= -github.com/puzpuzpuz/xsync/v3 v3.0.2/go.mod h1:VjzYrABPabuM4KyBh1Ftq6u8nhwY5tBPKP9jpmh0nnA= +github.com/pterm/pterm v0.12.82 h1:+D9wYhCaeaK0FIQoZtqbNQuNpe2lB2tajKKsTd5paVQ= +github.com/pterm/pterm v0.12.82/go.mod h1:TyuyrPjnxfwP+ccJdBTeWHtd/e0ybQHkOS/TakajZCw= +github.com/puzpuzpuz/xsync/v4 v4.3.0 h1:w/bWkEJdYuRNYhHn5eXnIT8LzDM1O629X1I9MJSkD7Q= +github.com/puzpuzpuz/xsync/v4 v4.3.0/go.mod h1:VJDmTCJMBt8igNxnkQd86r+8KUeN1quSfNKu5bLYFQo= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= -github.com/rivo/uniseg v0.4.4 h1:8TfxU8dW6PdqD27gjM8MVNuicgxIjxpm4K7x4jp8sis= -github.com/rivo/uniseg v0.4.4/go.mod h1:FN3SvrM+Zdj16jyLfmOkMNblXMcoc8DfTHruCPUcx88= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= -github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M= -github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA= -github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= -github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= -github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= +github.com/rogpeppe/go-internal v1.14.1 h1:UQB4HGPB6osV0SQTLymcB4TgvyWu6ZyliaW0tI/otEQ= +github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7so1lCWt35ZSgc= +github.com/rs/xid v1.6.0/go.mod h1:7XoLgs4eV+QndskICGsho+ADou8ySMSjJKDIan90Nz0= +github.com/rs/zerolog v1.34.0 h1:k43nTLIwcTVQAncfCw4KZ2VY6ukYoZaBPNOE8txlOeY= +github.com/rs/zerolog v1.34.0/go.mod h1:bJsvje4Z08ROH4Nhs5iH600c3IkWhwp44iRc54W6wYQ= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= -github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= -github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= -github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= -github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= -github.com/samber/lo v1.38.1 h1:j2XEAqXKb09Am4ebOg31SpvzUTTs6EN3VfgeLUhPdXM= -github.com/samber/lo v1.38.1/go.mod h1:+m/ZKRl6ClXCE2Lgf3MsQlWfh4bn1bz6CXEOxnEXnEA= 
+github.com/safchain/ethtool v0.3.0 h1:gimQJpsI6sc1yIqP/y8GYgiXn/NjgvpM0RNoWLVVmP0= +github.com/safchain/ethtool v0.3.0/go.mod h1:SA9BwrgyAqNo7M+uaL6IYbxpm5wk3L7Mm6ocLW+CJUs= +github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4= +github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI= +github.com/samber/lo v1.52.0 h1:Rvi+3BFHES3A8meP33VPAxiBZX/Aws5RxrschYGjomw= +github.com/samber/lo v1.52.0/go.mod h1:4+MXEGsJzbKGaUEQFKBq2xtfuznW9oz/WrgyzMzRoM0= +github.com/sasha-s/go-deadlock v0.3.6 h1:TR7sfOnZ7x00tWPfD397Peodt57KzMDo+9Ae9rMiUmw= +github.com/sasha-s/go-deadlock v0.3.6/go.mod h1:CUqNyyvMxTyjFqDT7MRg9mb4Dv/btmGTqSR+rky/UXo= github.com/sergi/go-diff v1.2.0/go.mod h1:STckp+ISIX8hZLjrqAeVduY0gWCT9IjLuqbuNXdaHfM= -github.com/sergi/go-diff v1.3.1 h1:xkr+Oxo4BOQKmkn/B9eMK0g5Kg/983T9DqqPHwYqD+8= -github.com/sergi/go-diff v1.3.1/go.mod h1:aMJSSKb2lpPvRNec0+w3fl7LP9IOFzdc9Pa4NFbPK1I= -github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3 h1:n661drycOFuPLCN3Uc8sB6B/s6Z4t2xvBgU1htSHuq8= +github.com/sergi/go-diff v1.3.2-0.20230802210424-5b0b94c5c0d3/go.mod h1:A0bzQcvG0E7Rwjx0REVgAGH58e96+X0MeOfepqsbeW4= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= -github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= -github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= -github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= -github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= -github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= -github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= -github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= -github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= -github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= -github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -github.com/spf13/viper v1.17.0 h1:I5txKw7MJasPL/BrfkbA0Jyo/oELqVmux4pR/UxOMfI= -github.com/spf13/viper v1.17.0/go.mod h1:BmMMMLQXSbcHK6KAOiFLz0l5JHrU89OdIRHvsk0+yVI= +github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I= +github.com/spf13/afero v1.15.0/go.mod h1:NC2ByUVxtQs4b3sIUphxK0NioZnmxgyCrfzeuq8lxMg= +github.com/spf13/cast v1.10.0 h1:h2x0u2shc1QuLHfxi+cTJvs30+ZAHOGRic8uyGTDWxY= +github.com/spf13/cast v1.10.0/go.mod h1:jNfB8QC9IA6ZuY2ZjDp0KtFO2LZZlg4S/7bzP6qqeHo= +github.com/spf13/cobra v1.10.2 h1:DMTTonx5m65Ic0GOoRY2c16WCbHxOOw6xxezuLaBpcU= +github.com/spf13/cobra v1.10.2/go.mod h1:7C1pvHqHw5A4vrJfjNwvOdzYu0Gml16OCs2GRiTUUS4= +github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/pflag v1.0.10 h1:4EBh2KAYBwaONj6b2Ye1GiHfwjqyROoF4RwYO+vPwFk= +github.com/spf13/pflag v1.0.10/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/spf13/viper v1.21.0 h1:x5S+0EU27Lbphp4UKm1C+1oQO+rKx36vfCoaVebLFSU= +github.com/spf13/viper v1.21.0/go.mod h1:P0lhsswPGWD/1lZJ9ny3fYnVqxiegrlNrEmgLjbTCAY= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= -github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= -github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= -github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= -github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= -github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= +github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e h1:PtWT87weP5LWHEY//SWsYkSO3RWRZo4OSWagh3YD2vQ= +github.com/tailscale/certstore v0.1.1-0.20231202035212-d3fa0460f47e/go.mod h1:XrBNfAFN+pwoWuksbFS9Ccxnopa15zJGgXRFN90l3K4= github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55 h1:Gzfnfk2TWrk8Jj4P4c1a3CtQyMaTVCznlkLZI++hok4= github.com/tailscale/go-winio v0.0.0-20231025203758-c4f33415bf55/go.mod h1:4k4QO+dQ3R5FofL+SanAUZe+/QfeK0+OIuwDIRu2vSg= -github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a h1:SJy1Pu0eH1C29XwJucQo73FrleVK6t4kYz4NVhp34Yw= -github.com/tailscale/hujson v0.0.0-20221223112325-20486734a56a/go.mod h1:DFSS3NAGHthKo1gTlmEcSBiZrRJXi28rLNd/1udP1c8= -github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85 h1:zrsUcqrG2uQSPhaUPjUQwozcRdDdSxxqhNgNZ3drZFk= -github.com/tailscale/netlink v1.1.1-0.20211101221916-cabfb018fe85/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= +github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869 h1:SRL6irQkKGQKKLzvQP/ke/2ZuB7Py5+XuqtOgSj+iMM= +github.com/tailscale/golang-x-crypto v0.0.0-20250404221719-a5573b049869/go.mod h1:ikbF+YT089eInTp9f2vmvy4+ZVnW5hzX1q2WknxSprQ= +github.com/tailscale/hujson v0.0.0-20250605163823-992244df8c5a h1:a6TNDN9CgG+cYjaeN8l2mc4kSz2iMiCDQxPEyltUV/I= +github.com/tailscale/hujson v0.0.0-20250605163823-992244df8c5a/go.mod h1:EbW0wDK/qEUYI0A5bqq0C2kF8JTQwWONmGDBbzsxxHo= +github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7 h1:uFsXVBE9Qr4ZoF094vE6iYTLDl0qCiKzYXlL6UeWObU= +github.com/tailscale/netlink v1.1.1-0.20240822203006-4d49adab4de7/go.mod h1:NzVQi3Mleb+qzq8VmcWpSkcSYxXIg0DkI6XDzpVkhJ0= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc h1:24heQPtnFR+yfntqhI3oAu9i27nEojcQ4NuBQOo5ZFA= +github.com/tailscale/peercred v0.0.0-20250107143737-35a0c7bd7edc/go.mod h1:f93CXfllFsO9ZQVq+Zocb1Gp4G5Fz0b0rXHLOzt/Djc= +github.com/tailscale/setec v0.0.0-20251203133219-2ab774e4129a 
h1:TApskGPim53XY5WRt5hX4DnO8V6CmVoimSklryIoGMM= +github.com/tailscale/setec v0.0.0-20251203133219-2ab774e4129a/go.mod h1:+6WyG6kub5/5uPsMdYQuSti8i6F5WuKpFWLQnZt/Mms= +github.com/tailscale/squibble v0.0.0-20251104223530-a961feffb67f h1:CL6gu95Y1o2ko4XiWPvWkJka0QmQWcUyPywWVWDPQbQ= +github.com/tailscale/squibble v0.0.0-20251104223530-a961feffb67f/go.mod h1:xJkMmR3t+thnUQhA3Q4m2VSlS5pcOq+CIjmU/xfKKx4= +github.com/tailscale/tailsql v0.0.0-20260105194658-001575c3ca09 h1:Fc9lE2cDYJbBLpCqnVmoLdf7McPqoHZiDxDPPpkJM04= +github.com/tailscale/tailsql v0.0.0-20260105194658-001575c3ca09/go.mod h1:QMNhC4XGFiXKngHVLXE+ERDmQoH0s5fD7AUxupykocQ= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976 h1:UBPHPtv8+nEAy2PD8RyAhOYvau1ek0HDJqLS/Pysi14= +github.com/tailscale/web-client-prebuilt v0.0.0-20250124233751-d4cd19a26976/go.mod h1:agQPE6y6ldqCOui2gkIh7ZMztTkIQKH049tv8siLuNQ= +github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6 h1:l10Gi6w9jxvinoiq15g8OToDdASBni4CyJOdHY1Hr8M= +github.com/tailscale/wf v0.0.0-20240214030419-6fbb0a674ee6/go.mod h1:ZXRML051h7o4OcI0d3AaILDIad/Xw0IkXaHM17dic1Y= +github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da h1:jVRUZPRs9sqyKlYHHzHjAqKN+6e/Vog6NpHYeNPJqOw= +github.com/tailscale/wireguard-go v0.0.0-20250716170648-1d0488a3d7da/go.mod h1:BOm5fXUBFM+m9woLNBoxI9TaBXXhGNP50LX/TGIvGb4= +github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e h1:zOGKqN5D5hHhiYUp091JqK7DPCqSARyUfduhGUY8Bek= +github.com/tailscale/xnet v0.0.0-20240729143630-8497ac4dab2e/go.mod h1:orPd6JZXXRyuDusYilywte7k094d7dycXXU5YnWsrwg= +github.com/tc-hib/winres v0.2.1 h1:YDE0FiP0VmtRaDn7+aaChp1KiF4owBiJa5l964l5ujA= +github.com/tc-hib/winres v0.2.1/go.mod h1:C/JaNhH3KBvhNKVbvdlDWkbMDO9H4fKKDaN7/07SSuk= github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e h1:IWllFTiDjjLIf2oeKxpIUmtiDV5sn71VgeQgg6vcE7k= github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e/go.mod h1:d7u6HkTYKSv5m6MCKkOQlHwaShTMl3HjqSGW3XtVhXM= -github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= -github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= -github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae/go.mod h1:DD4vA1DwXk04H54A1oHXtwZmA0grkVMdPxx/VGLCah0= -github.com/vishvananda/netns v0.0.4 h1:Oeaw1EM2JMxD51g9uhtC0D7erkIjgmj8+JZc26m1YX8= -github.com/vishvananda/netns v0.0.4/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= +github.com/tink-crypto/tink-go/v2 v2.1.0 h1:QXFBguwMwTIaU17EgZpEJWsUSc60b1BAGTzBIoMdmok= +github.com/tink-crypto/tink-go/v2 v2.1.0/go.mod h1:y1TnYFt1i2eZVfx4OGc+C+EMp4CoKWAw2VSEuoicHHI= +github.com/u-root/u-root v0.14.0 h1:Ka4T10EEML7dQ5XDvO9c3MBN8z4nuSnGjcd1jmU2ivg= +github.com/u-root/u-root v0.14.0/go.mod h1:hAyZorapJe4qzbLWlAkmSVCJGbfoU9Pu4jpJ1WMluqE= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701 h1:pyC9PaHYZFgEKFdlp3G8RaCKgVpHZnecvArXvPXcFkM= +github.com/u-root/uio v0.0.0-20240224005618-d2acac8f3701/go.mod h1:P3a5rG4X7tI17Nn3aOIAYr5HbIMukwXG0urG0WuL8OA= +github.com/vishvananda/netns v0.0.5 h1:DfiHV+j8bA32MFM7bfEunvT8IAqQ/NzSJHtcmW5zdEY= +github.com/vishvananda/netns v0.0.5/go.mod h1:SpkAiCQRtJ6TvvxPnOSyH3BMl6unz3xZlaprSwhNNJM= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= github.com/xeipuuv/gojsonpointer v0.0.0-20180127040702-4e3ac2762d5f/go.mod h1:N2zxlSyiKSe5eX1tZViRH5QA0qijqEDrYZiPEAiq3wU= @@ -365,199 +500,186 @@ 
github.com/xeipuuv/gojsonschema v1.2.0/go.mod h1:anYRn/JVcOK2ZgGU+IjEV4nwlhoK5sQ github.com/xo/terminfo v0.0.0-20210125001918-ca9a967f8778/go.mod h1:2MuV+tbUrU1zIOPMxZ5EncGwgmMJsa+9ucAQZXxsObs= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e h1:JVG44RsyaB9T2KIHavMF/ppJZNG9ZpyihvCd0w101no= github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e/go.mod h1:RbqR21r5mrJuqunuUZ/Dhy/avygyECGrLceyNeo4LiM= -github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= -github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= -go.uber.org/goleak v1.1.10/go.mod h1:8a7PlsEVH3e/a/GLqe5IIrQx6GzcnRmZEufDUTk4A7A= -go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= -go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= -go.uber.org/multierr v1.11.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= -go.uber.org/zap v1.18.1/go.mod h1:xg/QME4nWcxGxrpdeYfq7UvYrLh66cuVKdrbD1XF/NI= -go4.org/mem v0.0.0-20220726221520-4f986261bf13 h1:CbZeCBZ0aZj8EfVgnqQcYZgf0lpZ3H9rmp5nkDTAst8= -go4.org/mem v0.0.0-20220726221520-4f986261bf13/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= -go4.org/netipx v0.0.0-20230824141953-6213f710f925 h1:eeQDDVKFkx0g4Hyy8pHgmZaK0EqB4SD6rvKbUdN3ziQ= -go4.org/netipx v0.0.0-20230824141953-6213f710f925/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= +go.opentelemetry.io/auto/sdk v1.2.1 h1:jXsnJ4Lmnqd11kwkBV2LgLoFMZKizbCi5fNZ/ipaZ64= +go.opentelemetry.io/auto/sdk v1.2.1/go.mod h1:KRTj+aOaElaLi+wW1kO/DZRXwkF4C5xPbEe3ZiIhN7Y= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0 h1:ssfIgGNANqpVFCndZvcuyKbl0g+UAVcbBcqGkG28H0Y= +go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.64.0/go.mod h1:GQ/474YrbE4Jx8gZ4q5I4hrhUzM6UPzyrqJYV2AqPoQ= +go.opentelemetry.io/otel v1.39.0 h1:8yPrr/S0ND9QEfTfdP9V+SiwT4E0G7Y5MO7p85nis48= +go.opentelemetry.io/otel v1.39.0/go.mod h1:kLlFTywNWrFyEdH0oj2xK0bFYZtHRYUdv1NklR/tgc8= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0 h1:dNzwXjZKpMpE2JhmO+9HsPl42NIXFIFSUSSs0fiqra0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace v1.36.0/go.mod h1:90PoxvaEB5n6AOdZvi+yWJQoE95U8Dhhw2bSyRqnTD0= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0 h1:nRVXXvf78e00EwY6Wp0YII8ww2JVWshZ20HfTlE11AM= +go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp v1.36.0/go.mod h1:r49hO7CgrxY9Voaj3Xe8pANWtr0Oq916d0XAmOoCZAQ= +go.opentelemetry.io/otel/metric v1.39.0 h1:d1UzonvEZriVfpNKEVmHXbdf909uGTOQjA0HF0Ls5Q0= +go.opentelemetry.io/otel/metric v1.39.0/go.mod h1:jrZSWL33sD7bBxg1xjrqyDjnuzTUB0x1nBERXd7Ftcs= +go.opentelemetry.io/otel/sdk v1.39.0 h1:nMLYcjVsvdui1B/4FRkwjzoRVsMK8uL/cj0OyhKzt18= +go.opentelemetry.io/otel/sdk v1.39.0/go.mod h1:vDojkC4/jsTJsE+kh+LXYQlbL8CgrEcwmt1ENZszdJE= +go.opentelemetry.io/otel/sdk/metric v1.39.0 h1:cXMVVFVgsIf2YL6QkRF4Urbr/aMInf+2WKg+sEJTtB8= +go.opentelemetry.io/otel/sdk/metric v1.39.0/go.mod h1:xq9HEVH7qeX69/JnwEfp6fVq5wosJsY1mt4lLfYdVew= +go.opentelemetry.io/otel/trace v1.39.0 h1:2d2vfpEDmCJ5zVYz7ijaJdOF59xLomrvj7bjt6/qCJI= +go.opentelemetry.io/otel/trace v1.39.0/go.mod h1:88w4/PnZSazkGzz/w84VHpQafiU4EtqqlVdxWy+rNOA= +go.opentelemetry.io/proto/otlp v1.6.0 h1:jQjP+AQyTf+Fe7OKj/MfkDrmK4MNVtw2NpXsf9fefDI= +go.opentelemetry.io/proto/otlp v1.6.0/go.mod 
h1:cicgGehlFuNdgZkcALOCh3VE6K/u2tAjzlRhDwmVpZc= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.yaml.in/yaml/v2 v2.4.3 h1:6gvOSjQoTB3vt1l+CU+tSyi/HOjfOjRLJ4YwYZGwRO0= +go.yaml.in/yaml/v2 v2.4.3/go.mod h1:zSxWcmIDjOzPXpjlTTbAsKokqkDNAVtZO0WOMiT90s8= +go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc= +go.yaml.in/yaml/v3 v3.0.4/go.mod h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745 h1:Tl++JLUCe4sxGu8cTpDzRLd3tN7US4hOxG5YpKCzkek= +go4.org/mem v0.0.0-20240501181205-ae6ca9944745/go.mod h1:reUoABIJ9ikfM5sgtSF3Wushcza7+WeD01VB9Lirh3g= +go4.org/netipx v0.0.0-20231129151722-fdeea329fbba h1:0b9z3AuHCjxk0x/opv64kcgZLBseWJUpBw5I82+2U4M= +go4.org/netipx v0.0.0-20231129151722-fdeea329fbba/go.mod h1:PLyyIXexvUFg3Owu6p/WfdlivPbZJsZdgWZlrGope/Y= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20190911031432-227b76d455e7/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= -golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.0.0-20220214200702-86341886e292/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= -golang.org/x/crypto v0.16.0 h1:mMMrFzRSCF0GvB7Ne27XVtVAaXLrPmgPC7/v0tkwHaY= -golang.org/x/crypto v0.16.0/go.mod h1:gCAAfMLgwOJRpTjQ2zCCt2OcSfYMTeZVSRtQlPC7Nq4= -golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= -golang.org/x/exp v0.0.0-20231127185646-65229373498e h1:Gvh4YaCaXNs6dKTlfgismwWZKyjVZXwOPfIyUaqU3No= -golang.org/x/exp v0.0.0-20231127185646-65229373498e/go.mod h1:iRJReGqOEeBhDZGkGbynYwcHlctCvnjTYIamk7uXpHI= -golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= -golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= -golang.org/x/lint v0.0.0-20190313153728-d0100b6bd8b3/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= -golang.org/x/mod v0.2.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= -golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= +golang.org/x/crypto v0.46.0 h1:cKRW/pmt1pKAfetfu+RCEvjvZkA9RimPbh7bhFjGVBU= +golang.org/x/crypto v0.46.0/go.mod h1:Evb/oLKmMraqjZ2iQTwDwvCtJkczlDuTmdJXoZVzqU0= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546 h1:mgKeJMpvi0yx/sU5GsxQ7p6s2wtOnGAHZWCHUM4KGzY= +golang.org/x/exp v0.0.0-20251023183803-a4bb9ffd2546/go.mod h1:j/pmGrbnkbPtQfxEe5D0VQhZC6qKbfKifgD0oM7sR70= +golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f h1:phY1HzDcf18Aq9A8KkmRtY9WvOFIxN8wgfvy6Zm1DV8= +golang.org/x/exp/typeparams v0.0.0-20240314144324-c7f7c6466f7f/go.mod h1:AbB0pIl9nAr9wVwH+Z2ZpaocVmF5I4GyWCDIsVjR0bk= +golang.org/x/image v0.27.0 h1:C8gA4oWU/tKkdCfYT6T2u4faJu3MeNS5O8UPWlPF61w= +golang.org/x/image v0.27.0/go.mod h1:xbdrClrAUway1MUTEZDq9mz/UpRwYAkFFNUslZtcB+g= golang.org/x/mod 
v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= -golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= -golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= -golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190213061140-3a22650c66bd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/mod v0.30.0 h1:fDEXFVZ/fmCKProc/yAXXUijritrDzahmwwefnjoPFk= +golang.org/x/mod v0.30.0/go.mod h1:lAsf5O2EvJeSFMiBxXDki7sCgAxEUcZHXoXMKT4GJKc= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= -golang.org/x/net v0.19.0 h1:zTwKpTd2XuCqf8huc7Fo2iSy+4RHPd10s4KzeTnVr1c= -golang.org/x/net v0.19.0/go.mod h1:CfAk/cbD4CthTvqiEl8NpboMuiuOYsAr/7NOjZJtv1U= -golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.15.0 h1:s8pnnxNVzjWyrvYdFUQq5llS1PX2zhPXmccZv99h7uQ= -golang.org/x/oauth2 v0.15.0/go.mod h1:q48ptWNTY5XWf+JNten23lcvHpLJ0ZSxF5ttTHKVCAM= -golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= +golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.5.0 h1:60k92dhOjHxJkrqnwsfl8KuaHbn/5dl0lUPUklKo3qE= -golang.org/x/sync v0.5.0/go.mod 
h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sync v0.19.0 h1:vV+1eWNmZ5geRlYjzm2adRgW2/mcpevXNg50YZtPCE4= +golang.org/x/sync v0.19.0/go.mod h1:9KTHXmSnoGruLpwFjVSX0lNNA75CykiMECbovNTZqGI= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20191113165036-4c7a9d0fe056/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200217220822-9197077df867/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200728102440-3e129f6d46b1/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210119212857-b64e53b001e4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210124154548-22da62e12c0c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210616094352-59db8d763f22/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211007075335-d3039528d8ac/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20211013075003-97ac67df715c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20211025201205-69cdffdb9359/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220310020820-b874c991c1a5/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220319134239-a9b59b0215f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.1-0.20230131160137-e7d7f63158de/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220817070843-5a390386f1f2/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys 
v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= -golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/sys v0.40.0 h1:DBZZqJ2Rkml6QMQsZywtnjnnGvHza6BTfYFWY9kjEWQ= +golang.org/x/sys v0.40.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210220032956-6a3ed077a48d/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210615171337-6886f2dfbf5b/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= -golang.org/x/term v0.15.0 h1:y/Oo/a/q3IXu26lQgl04j/gjuBDOBlx7X6Om1j2CPW4= -golang.org/x/term v0.15.0/go.mod h1:BDl952bC7+uMoWR75FIrCDx79TPU9oHkTZ9yRbYOrX0= +golang.org/x/term v0.8.0/go.mod h1:xPskH00ivmX89bAKVGSKKtLOWNx2+17Eiy94tnKShWo= +golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= +golang.org/x/term v0.38.0 h1:PQ5pkm/rLO6HnxFR7N2lJHOZX6Kez5Y1gDSJla6jo7Q= +golang.org/x/term v0.38.0/go.mod h1:bSEAKrOT1W+VSu9TSCMtoGEOUcKxOKgl3LE5QEF/xVg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= -golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= -golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= -golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= +golang.org/x/text v0.32.0 h1:ZD01bjUt1FQ9WJ0ClOL5vxgxOI/sVCNgX1YtKwcY0mU= +golang.org/x/text v0.32.0/go.mod h1:o/rUWzghvpD5TXrTIBuJU77MTaN0ljMWE47kxGJQ7jY= +golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= +golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= -golang.org/x/tools v0.0.0-20190226205152-f727befe758c/go.mod h1:9Yl7xja0Znq3iFh3HoIrodX9oNMXvdceNzlUR8zjMvY= -golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20190524140312-2c0ae7006135/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= -golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod 
h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= -golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= -golang.org/x/tools v0.1.0/go.mod h1:xkSsbof2nBLbhDlRMhhhyNLN/zl3eTqcnHD5viDpcZ0= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= -golang.org/x/tools v0.16.0 h1:GO788SKMRunPIBCXiQyo2AaexLstOrVhuAL5YwsckQM= -golang.org/x/tools v0.16.0/go.mod h1:kYVVN6I1mBNoB1OX+noeBjbRk4IUEPa7JJ+TJMEooJ0= +golang.org/x/tools v0.39.0 h1:ik4ho21kwuQln40uelmciQPp9SipgNDdrafrYA4TmQQ= +golang.org/x/tools v0.39.0/go.mod h1:JnefbkDPyD8UU2kI5fuf8ZX4/yUeh9W877ZeBONxUqQ= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 h1:B82qJJgjvYKsXS9jeunTOisW56dUokqW/FOteYJJ/yg= +golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2/go.mod h1:deeaetjYA+DHMHg+sMSMI58GrEteJUUzzw7en6TJQcI= golang.zx2c4.com/wireguard/windows v0.5.3 h1:On6j2Rpn3OEMXqBq00QEDC7bWSZrPIHKIus8eIuExIE= golang.zx2c4.com/wireguard/windows v0.5.3/go.mod h1:9TEe8TJmtwyQebdFwAkEWOPr3prrtqm+REGFifP60hI= -google.golang.org/appengine v1.1.0/go.mod h1:EbEs0AVv82hx2wNQdGPgUI5lhzA/G0D9YwlJXL52JkM= -google.golang.org/appengine v1.4.0/go.mod h1:xpcJRLb0r/rnEns0DIKYYv+WjYCduHsrkT7/EB5XEv4= -google.golang.org/appengine v1.6.8 h1:IhEN5q69dyKagZPYMSdIjS2HqprW324FRQZJcGqPAsM= -google.golang.org/appengine v1.6.8/go.mod h1:1jJ3jBArFh5pcgW8gCtRJnepW8FzD1V44FJffLiz/Ds= -google.golang.org/genproto v0.0.0-20180817151627-c66870c02cf8/go.mod h1:JiN7NxoALGmiZfu7CAH4rXhgtRTLTxftemlI0sWmxmc= -google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98Agz4BDEuKkezgsaosCRResVns1a3J2ZsMNc= -google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= -google.golang.org/genproto v0.0.0-20231127180814-3a041ad873d4 h1:W12Pwm4urIbRdGhMEg2NM9O3TWKjNcxQhs46V0ypf/k= -google.golang.org/genproto v0.0.0-20231127180814-3a041ad873d4/go.mod h1:5RBcpGRxr25RbDzY5w+dmaqpSEvl8Gwl1x2CICf60ic= -google.golang.org/genproto/googleapis/api v0.0.0-20231127180814-3a041ad873d4 h1:ZcOkrmX74HbKFYnpPY8Qsw93fC29TbJXspYKaBkSXDQ= -google.golang.org/genproto/googleapis/api v0.0.0-20231127180814-3a041ad873d4/go.mod h1:k2dtGpRrbsSyKcNPKKI5sstZkrNCZwpU/ns96JoHbGg= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231127180814-3a041ad873d4 h1:DC7wcm+i+P1rN3Ff07vL+OndGg5OhNddHyTA+ocPqYE= -google.golang.org/genproto/googleapis/rpc v0.0.0-20231127180814-3a041ad873d4/go.mod h1:eJVxU6o+4G1PSczBr85xmyvSNYAKvAYgkub40YGomFM= -google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= -google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= -google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQciAY= -google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= -google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= -google.golang.org/grpc v1.59.0 h1:Z5Iec2pjwb+LEOqzpB2MR12/eKFhDPhuqW91O+4bwUk= 
-google.golang.org/grpc v1.59.0/go.mod h1:aUPDwccQo6OTjy7Hct4AfBPD1GptF4fyUjIkQ9YtF98= -google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= -google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= -google.golang.org/protobuf v1.31.0 h1:g0LDEJHgrBl9N9r17Ru3sqWhkIx2NB67okBHPwC7hs8= -google.golang.org/protobuf v1.31.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gonum.org/v1/gonum v0.16.0 h1:5+ul4Swaf3ESvrOnidPp4GZbzf0mxVQpDCYUQE7OJfk= +gonum.org/v1/gonum v0.16.0/go.mod h1:fef3am4MQ93R2HHpKnLk4/Tbh/s0+wqD5nfa6Pnwy4E= +google.golang.org/genproto/googleapis/api v0.0.0-20251222181119-0a764e51fe1b h1:uA40e2M6fYRBf0+8uN5mLlqUtV192iiksiICIBkYJ1E= +google.golang.org/genproto/googleapis/api v0.0.0-20251222181119-0a764e51fe1b/go.mod h1:Xa7le7qx2vmqB/SzWUBa7KdMjpdpAHlh5QCSnjessQk= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b h1:Mv8VFug0MP9e5vUxfBcE3vUkV6CImK3cMNMIDFjmzxU= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251222181119-0a764e51fe1b/go.mod h1:j9x/tPzZkyxcgEFkiKEEGxfvyumM01BEtsW8xzOahRQ= +google.golang.org/grpc v1.78.0 h1:K1XZG/yGDJnzMdd/uZHAkVqJE+xIDOcmdSFZkBUicNc= +google.golang.org/grpc v1.78.0/go.mod h1:I47qjTo4OKbMkjA/aOOwxDIiPSBofUtQUI5EfpWvW7U= +google.golang.org/protobuf v1.36.11 h1:fV6ZwhNocDyBLK0dj+fg8ektcVegBBuEolpbTQyBNVE= +google.golang.org/protobuf v1.36.11/go.mod h1:HTf+CrKn2C3g5S8VImy6tdcUvCska2kB7j23XfzDpco= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= -gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= -gopkg.in/square/go-jose.v2 v2.6.0 h1:NGk74WTnPKBNUhNzQX7PYcTLUjoq7mzKk2OKbvwk2iI= -gopkg.in/square/go-jose.v2 v2.6.0/go.mod h1:M9dMgbHiYLoDGQrXy7OpJDJWiKiU//h+vD76mk0e1AI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gorm.io/driver/postgres v1.5.4 h1:Iyrp9Meh3GmbSuyIAGyjkN+n9K+GHX9b9MqsTL4EJCo= -gorm.io/driver/postgres v1.5.4/go.mod h1:Bgo89+h0CRcdA33Y6frlaHHVuTdOf87pmyzwW9C/BH0= -gorm.io/gorm v1.25.5 h1:zR9lOiiYf09VNh5Q1gphfyia1JpiClIWG9hQaxB/mls= -gorm.io/gorm v1.25.5/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= -gotest.tools/v3 v3.4.0 h1:ZazjZUfuVeZGLAmlKKuyv3IKP5orXcwtOwDQH6YVr6o= -gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= 
-honnef.co/go/tools v0.0.0-20190102054323-c2f93a96b099/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= -honnef.co/go/tools v0.0.0-20190523083050-ea95bdfd59fc/go.mod h1:rf3lG4BRIbNafJWhAfAdb/ePZxsR/4RtNHQocxwk9r4= +gorm.io/driver/postgres v1.6.0 h1:2dxzU8xJ+ivvqTRph34QX+WrRaJlmfyPqXmoGVjMBa4= +gorm.io/driver/postgres v1.6.0/go.mod h1:vUw0mrGgrTK+uPHEhAdV4sfFELrByKVGnaVRkXDhtWo= +gorm.io/gorm v1.31.1 h1:7CA8FTFz/gRfgqgpeKIBcervUn3xSyPUmr6B2WXJ7kg= +gorm.io/gorm v1.31.1/go.mod h1:XyQVbO2k6YkOis7C2437jSit3SsDK72s7n7rsSHd+Gs= +gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= +gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633 h1:2gap+Kh/3F47cO6hAu3idFvsJ0ue6TRcEi2IUkv/F8k= +gvisor.dev/gvisor v0.0.0-20250205023644-9414b50a5633/go.mod h1:5DMfjtclAbTIjbXqO1qCe2K5GKKxWz2JHvCChuTcJEM= +honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0 h1:5SXjd4ET5dYijLaf0O3aOenC0Z4ZafIWSpjUzsQaNho= +honnef.co/go/tools v0.7.0-0.dev.0.20251022135355-8273271481d0/go.mod h1:EPDDhEZqVHhWuPI5zPAsjU0U7v9xNIWjoOVyZ5ZcniQ= howett.net/plist v1.0.0 h1:7CrbWYbPPO/PyNy38b2EB/+gYbjCe2DXBxgtOOZbSQM= howett.net/plist v1.0.0/go.mod h1:lqaXoTrLY4hg8tnEzNru53gicrbv7rrk+2xJA/7hw9g= -modernc.org/libc v1.34.11 h1:hQDcIUlSG4QAOkXCIQKkaAOV5ptXvkOx4ddbXzgW2JU= -modernc.org/libc v1.34.11/go.mod h1:YAXkAZ8ktnkCKaN9sw/UDeUVkGYJ/YquGO4FTi5nmHE= -modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= -modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= -modernc.org/memory v1.7.2 h1:Klh90S215mmH8c9gO98QxQFsY+W451E8AnzjoE2ee1E= -modernc.org/memory v1.7.2/go.mod h1:NO4NVCQy0N7ln+T9ngWqOQfi7ley4vpwvARR+Hjw95E= -modernc.org/sqlite v1.27.0 h1:MpKAHoyYB7xqcwnUwkuD+npwEa0fojF0B5QRbN+auJ8= -modernc.org/sqlite v1.27.0/go.mod h1:Qxpazz0zH8Z1xCFyi5GSL3FzbtZ3fvbjmywNogldEW0= -nhooyr.io/websocket v1.8.10 h1:mv4p+MnGrLDcPlBoWsvPP7XCzTYMXP9F9eIGoKbgx7Q= -nhooyr.io/websocket v1.8.10/go.mod h1:rN9OFWIUwuxg4fR5tELlYC04bXYowCP9GX47ivo2l+c= -software.sslmate.com/src/go-pkcs12 v0.2.1 h1:tbT1jjaeFOF230tzOIRJ6U5S1jNqpsSyNjzDd58H3J8= -software.sslmate.com/src/go-pkcs12 v0.2.1/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= -tailscale.com v1.54.0 h1:Dri5BTKkHYpl+/t8ofY+tyvoTDbH/FpP7iB4B0cAQOY= -tailscale.com v1.54.0/go.mod h1:MnLFoCRwzFWr3qtkSW2nZdQpK7wQRZEk1KtcEGAuZYw= +modernc.org/cc/v4 v4.27.1 h1:9W30zRlYrefrDV2JE2O8VDtJ1yPGownxciz5rrbQZis= +modernc.org/cc/v4 v4.27.1/go.mod h1:uVtb5OGqUKpoLWhqwNQo/8LwvoiEBLvZXIQ/SmO6mL0= +modernc.org/ccgo/v4 v4.30.1 h1:4r4U1J6Fhj98NKfSjnPUN7Ze2c6MnAdL0hWw6+LrJpc= +modernc.org/ccgo/v4 v4.30.1/go.mod h1:bIOeI1JL54Utlxn+LwrFyjCx2n2RDiYEaJVSrgdrRfM= +modernc.org/fileutil v1.3.40 h1:ZGMswMNc9JOCrcrakF1HrvmergNLAmxOPjizirpfqBA= +modernc.org/fileutil v1.3.40/go.mod h1:HxmghZSZVAz/LXcMNwZPA/DRrQZEVP9VX0V4LQGQFOc= +modernc.org/gc/v2 v2.6.5 h1:nyqdV8q46KvTpZlsw66kWqwXRHdjIlJOhG6kxiV/9xI= +modernc.org/gc/v2 v2.6.5/go.mod h1:YgIahr1ypgfe7chRuJi2gD7DBQiKSLMPgBQe9oIiito= +modernc.org/gc/v3 v3.1.1 h1:k8T3gkXWY9sEiytKhcgyiZ2L0DTyCQ/nvX+LoCljoRE= +modernc.org/gc/v3 v3.1.1/go.mod h1:HFK/6AGESC7Ex+EZJhJ2Gni6cTaYpSMmU/cT9RmlfYY= +modernc.org/goabi0 v0.2.0 h1:HvEowk7LxcPd0eq6mVOAEMai46V+i7Jrj13t4AzuNks= +modernc.org/goabi0 v0.2.0/go.mod h1:CEFRnnJhKvWT1c1JTI3Avm+tgOWbkOu5oPA8eH8LnMI= +modernc.org/libc v1.67.6 h1:eVOQvpModVLKOdT+LvBPjdQqfrZq+pC39BygcT+E7OI= +modernc.org/libc v1.67.6/go.mod h1:JAhxUVlolfYDErnwiqaLvUqc8nfb2r6S6slAgZOnaiE= +modernc.org/mathutil v1.7.1 
h1:GCZVGXdaN8gTqB1Mf/usp1Y/hSqgI2vAGGP4jZMCxOU= +modernc.org/mathutil v1.7.1/go.mod h1:4p5IwJITfppl0G4sUEDtCr4DthTaT47/N3aT6MhfgJg= +modernc.org/memory v1.11.0 h1:o4QC8aMQzmcwCK3t3Ux/ZHmwFPzE6hf2Y5LbkRs+hbI= +modernc.org/memory v1.11.0/go.mod h1:/JP4VbVC+K5sU2wZi9bHoq2MAkCnrt2r98UGeSK7Mjw= +modernc.org/opt v0.1.4 h1:2kNGMRiUjrp4LcaPuLY2PzUfqM/w9N23quVwhKt5Qm8= +modernc.org/opt v0.1.4/go.mod h1:03fq9lsNfvkYSfxrfUhZCWPk1lm4cq4N+Bh//bEtgns= +modernc.org/sortutil v1.2.1 h1:+xyoGf15mM3NMlPDnFqrteY07klSFxLElE2PVuWIJ7w= +modernc.org/sortutil v1.2.1/go.mod h1:7ZI3a3REbai7gzCLcotuw9AC4VZVpYMjDzETGsSMqJE= +modernc.org/sqlite v1.44.3 h1:+39JvV/HWMcYslAwRxHb8067w+2zowvFOUrOWIy9PjY= +modernc.org/sqlite v1.44.3/go.mod h1:CzbrU2lSB1DKUusvwGz7rqEKIq+NUd8GWuBBZDs9/nA= +modernc.org/strutil v1.2.1 h1:UneZBkQA+DX2Rp35KcM69cSsNES9ly8mQWD71HKlOA0= +modernc.org/strutil v1.2.1/go.mod h1:EHkiggD70koQxjVdSBM3JKM7k6L0FbGE5eymy9i3B9A= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= +software.sslmate.com/src/go-pkcs12 v0.4.0 h1:H2g08FrTvSFKUj+D309j1DPfk5APnIdAQAB8aEykJ5k= +software.sslmate.com/src/go-pkcs12 v0.4.0/go.mod h1:Qiz0EyvDRJjjxGyUQa2cCNZn/wMyzrRJ/qcDXOQazLI= +tailscale.com v1.94.0 h1:5oW3SF35aU9ekHDhP2J4CHewnA2NxE7SRilDB2pVjaA= +tailscale.com v1.94.0/go.mod h1:gLnVrEOP32GWvroaAHHGhjSGMPJ1i4DvqNwEg+Yuov4= +zgo.at/zcache/v2 v2.4.1 h1:Dfjoi8yI0Uq7NCc4lo2kaQJJmp9Mijo21gef+oJstbY= +zgo.at/zcache/v2 v2.4.1/go.mod h1:gyCeoLVo01QjDZynjime8xUGHHMbsLiPyUTBpDGd4Gk= +zombiezen.com/go/postgrestest v1.0.1 h1:aXoADQAJmZDU3+xilYVut0pHhgc0sF8ZspPW9gFNwP4= +zombiezen.com/go/postgrestest v1.0.1/go.mod h1:marlZezr+k2oSJrvXHnZUs1olHqpE9czlz8ZYkVxliQ= diff --git a/hscontrol/app.go b/hscontrol/app.go index 3ad32788..aa011503 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -8,37 +8,39 @@ import ( "io" "net" "net/http" - _ "net/http/pprof" //nolint + _ "net/http/pprof" // nolint "os" "os/signal" + "path/filepath" "runtime" - "strconv" "strings" "sync" "syscall" "time" - "github.com/coreos/go-oidc/v3/oidc" + "github.com/cenkalti/backoff/v5" + "github.com/davecgh/go-spew/spew" "github.com/gorilla/mux" - grpcMiddleware "github.com/grpc-ecosystem/go-grpc-middleware" grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" "github.com/juanfont/headscale" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/capver" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/derp" derpServer "github.com/juanfont/headscale/hscontrol/derp/server" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/dns" + "github.com/juanfont/headscale/hscontrol/mapper" + "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" - "github.com/patrickmn/go-cache" zerolog "github.com/philip-bui/grpc-zerolog" - "github.com/prometheus/client_golang/prometheus/promhttp" + "github.com/pkg/profile" zl "github.com/rs/zerolog" "github.com/rs/zerolog/log" + "github.com/sasha-s/go-deadlock" "golang.org/x/crypto/acme" "golang.org/x/crypto/acme/autocert" - "golang.org/x/oauth2" "golang.org/x/sync/errgroup" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -48,56 +50,72 @@ import ( "google.golang.org/grpc/peer" 
"google.golang.org/grpc/reflection" "google.golang.org/grpc/status" + "tailscale.com/envknob" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" "tailscale.com/types/key" + "tailscale.com/util/dnsname" ) var ( errSTUNAddressNotSet = errors.New("STUN address not set") - errUnsupportedDatabase = errors.New("unsupported DB") errUnsupportedLetsEncryptChallengeType = errors.New( "unknown value for Lets Encrypt challenge type", ) - errEmptyInitialDERPMap = errors.New("initial DERPMap is empty, Headscale requries at least one entry") + errEmptyInitialDERPMap = errors.New( + "initial DERPMap is empty, Headscale requires at least one entry", + ) ) +var ( + debugDeadlock = envknob.Bool("HEADSCALE_DEBUG_DEADLOCK") + debugDeadlockTimeout = envknob.RegisterDuration("HEADSCALE_DEBUG_DEADLOCK_TIMEOUT") +) + +func init() { + deadlock.Opts.Disable = !debugDeadlock + if debugDeadlock { + deadlock.Opts.DeadlockTimeout = debugDeadlockTimeout() + deadlock.Opts.PrintAllCurrentGoroutines = true + } +} + const ( AuthPrefix = "Bearer " - updateInterval = 5000 + updateInterval = 5 * time.Second privateKeyFileMode = 0o600 - - registerCacheExpiration = time.Minute * 15 - registerCacheCleanup = time.Minute * 20 + headscaleDirPerm = 0o700 ) // Headscale represents the base app of the service. type Headscale struct { cfg *types.Config - db *db.HSDatabase - dbString string - dbType string - dbDebug bool + state *state.State noisePrivateKey *key.MachinePrivate + ephemeralGC *db.EphemeralGarbageCollector - DERPMap *tailcfg.DERPMap DERPServer *derpServer.DERPServer - ACLPolicy *policy.ACLPolicy + // Things that generate changes + extraRecordMan *dns.ExtraRecordsMan + authProvider AuthProvider + mapBatcher mapper.Batcher - nodeNotifier *notifier.Notifier - - oidcProvider *oidc.Provider - oauth2Config *oauth2.Config - - registrationCache *cache.Cache - - shutdownChan chan struct{} - pollNetMapStreamWG sync.WaitGroup + clientStreamsOpen sync.WaitGroup } +var ( + profilingEnabled = envknob.Bool("HEADSCALE_DEBUG_PROFILING_ENABLED") + profilingPath = envknob.String("HEADSCALE_DEBUG_PROFILING_PATH") + tailsqlEnabled = envknob.Bool("HEADSCALE_DEBUG_TAILSQL_ENABLED") + tailsqlStateDir = envknob.String("HEADSCALE_DEBUG_TAILSQL_STATE_DIR") + tailsqlTSKey = envknob.String("TS_AUTHKEY") + dumpConfig = envknob.Bool("HEADSCALE_DEBUG_DUMP_CONFIG") +) + func NewHeadscale(cfg *types.Config) (*Headscale, error) { - if _, enableProfile := os.LookupEnv("HEADSCALE_PROFILING_ENABLED"); enableProfile { + var err error + if profilingEnabled { runtime.SetBlockProfileRate(1) } @@ -106,84 +124,82 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { return nil, fmt.Errorf("failed to read or create Noise protocol private key: %w", err) } - var dbString string - switch cfg.DBtype { - case db.Postgres: - dbString = fmt.Sprintf( - "host=%s dbname=%s user=%s", - cfg.DBhost, - cfg.DBname, - cfg.DBuser, - ) - - if sslEnabled, err := strconv.ParseBool(cfg.DBssl); err == nil { - if !sslEnabled { - dbString += " sslmode=disable" - } - } else { - dbString += fmt.Sprintf(" sslmode=%s", cfg.DBssl) - } - - if cfg.DBport != 0 { - dbString += fmt.Sprintf(" port=%d", cfg.DBport) - } - - if cfg.DBpass != "" { - dbString += fmt.Sprintf(" password=%s", cfg.DBpass) - } - case db.Sqlite: - dbString = cfg.DBpath - default: - return nil, errUnsupportedDatabase + s, err := state.NewState(cfg) + if err != nil { + return nil, fmt.Errorf("init state: %w", err) } - registrationCache := cache.New( - registerCacheExpiration, - registerCacheCleanup, - ) - app := 
Headscale{ - cfg: cfg, - dbType: cfg.DBtype, - dbString: dbString, - noisePrivateKey: noisePrivateKey, - registrationCache: registrationCache, - pollNetMapStreamWG: sync.WaitGroup{}, - nodeNotifier: notifier.NewNotifier(), + cfg: cfg, + noisePrivateKey: noisePrivateKey, + clientStreamsOpen: sync.WaitGroup{}, + state: s, } - database, err := db.NewHeadscaleDatabase( - cfg.DBtype, - dbString, - app.dbDebug, - app.nodeNotifier, - cfg.IPPrefixes, - cfg.BaseDomain) - if err != nil { - return nil, err - } + // Initialize ephemeral garbage collector + ephemeralGC := db.NewEphemeralGarbageCollector(func(ni types.NodeID) { + node, ok := app.state.GetNodeByID(ni) + if !ok { + log.Error().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed") + log.Debug().Caller().Uint64("node.id", ni.Uint64()).Msg("Ephemeral node deletion failed because node not found in NodeStore") + return + } - app.db = database + policyChanged, err := app.state.DeleteNode(node) + if err != nil { + log.Error().Err(err).Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deletion failed") + return + } + app.Change(policyChanged) + log.Debug().Caller().Uint64("node.id", ni.Uint64()).Str("node.name", node.Hostname()).Msg("Ephemeral node deleted because garbage collection timeout reached") + }) + app.ephemeralGC = ephemeralGC + + var authProvider AuthProvider + authProvider = NewAuthProviderWeb(cfg.ServerURL) if cfg.OIDC.Issuer != "" { - err = app.initOIDC() + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + oidcProvider, err := NewAuthProviderOIDC( + ctx, + &app, + cfg.ServerURL, + &cfg.OIDC, + ) if err != nil { if cfg.OIDC.OnlyStartIfOIDCIsAvailable { return nil, err } else { log.Warn().Err(err).Msg("failed to set up OIDC provider, falling back to CLI based authentication") } + } else { + authProvider = oidcProvider } } + app.authProvider = authProvider + + if app.cfg.TailcfgDNSConfig != nil && app.cfg.TailcfgDNSConfig.Proxied { // if MagicDNS + // TODO(kradalby): revisit why this takes a list. + + var magicDNSDomains []dnsname.FQDN + if cfg.PrefixV4 != nil { + magicDNSDomains = append( + magicDNSDomains, + util.GenerateIPv4DNSRootDomain(*cfg.PrefixV4)...) + } + if cfg.PrefixV6 != nil { + magicDNSDomains = append( + magicDNSDomains, + util.GenerateIPv6DNSRootDomain(*cfg.PrefixV6)...) 
+ } - if app.cfg.DNSConfig != nil && app.cfg.DNSConfig.Proxied { // if MagicDNS - magicDNSDomains := util.GenerateMagicDNSRootDomains(app.cfg.IPPrefixes) // we might have routes already from Split DNS - if app.cfg.DNSConfig.Routes == nil { - app.cfg.DNSConfig.Routes = make(map[string][]*dnstype.Resolver) + if app.cfg.TailcfgDNSConfig.Routes == nil { + app.cfg.TailcfgDNSConfig.Routes = make(map[string][]*dnstype.Resolver) } for _, d := range magicDNSDomains { - app.cfg.DNSConfig.Routes[d.WithoutTrailingDot()] = nil + app.cfg.TailcfgDNSConfig.Routes[d.WithoutTrailingDot()] = nil } } @@ -200,6 +216,14 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { ) } + if cfg.DERP.ServerVerifyClients { + t := http.DefaultTransport.(*http.Transport) //nolint:forcetypeassert + t.RegisterProtocol( + derpServer.DerpVerifyScheme, + derpServer.NewDERPVerifyTransport(app.handleVerifyRequest), + ) + } + embeddedDERPServer, err := derpServer.NewDERPServer( cfg.ServerURL, key.NodePrivate(*derpServerKey), @@ -220,69 +244,89 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) { http.Redirect(w, req, target, http.StatusFound) } -// expireEphemeralNodes deletes ephemeral node records that have not been -// seen for longer than h.cfg.EphemeralNodeInactivityTimeout. -func (h *Headscale) expireEphemeralNodes(milliSeconds int64) { - ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) - for range ticker.C { - h.db.ExpireEphemeralNodes(h.cfg.EphemeralNodeInactivityTimeout) +func (h *Headscale) scheduledTasks(ctx context.Context) { + expireTicker := time.NewTicker(updateInterval) + defer expireTicker.Stop() + + lastExpiryCheck := time.Unix(0, 0) + + derpTickerChan := make(<-chan time.Time) + if h.cfg.DERP.AutoUpdate && h.cfg.DERP.UpdateFrequency != 0 { + derpTicker := time.NewTicker(h.cfg.DERP.UpdateFrequency) + defer derpTicker.Stop() + derpTickerChan = derpTicker.C } -} -// expireExpiredMachines expires nodes that have an explicit expiry set -// after that expiry time has passed. -func (h *Headscale) expireExpiredMachines(intervalMs int64) { - interval := time.Duration(intervalMs) * time.Millisecond - ticker := time.NewTicker(interval) - - lastCheck := time.Unix(0, 0) - - for range ticker.C { - lastCheck = h.db.ExpireExpiredNodes(lastCheck) + var extraRecordsUpdate <-chan []tailcfg.DNSRecord + if h.extraRecordMan != nil { + extraRecordsUpdate = h.extraRecordMan.UpdateCh() + } else { + extraRecordsUpdate = make(chan []tailcfg.DNSRecord) } -} - -// scheduledDERPMapUpdateWorker refreshes the DERPMap stored on the global object -// at a set interval. -func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { - log.Info(). - Dur("frequency", h.cfg.DERP.UpdateFrequency). 
- Msg("Setting up a DERPMap update worker") - ticker := time.NewTicker(h.cfg.DERP.UpdateFrequency) for { select { - case <-cancelChan: + case <-ctx.Done(): + log.Info().Caller().Msg("scheduled task worker is shutting down.") return - case <-ticker.C: - log.Info().Msg("Fetching DERPMap updates") - h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - if h.cfg.DERP.ServerEnabled { - region, _ := h.DERPServer.GenerateRegion() - h.DERPMap.Regions[region.RegionID] = ®ion + case <-expireTicker.C: + var expiredNodeChanges []change.Change + var changed bool + + lastExpiryCheck, expiredNodeChanges, changed = h.state.ExpireExpiredNodes(lastExpiryCheck) + + if changed { + log.Trace().Interface("changes", expiredNodeChanges).Msgf("expiring nodes") + + // Send the changes directly since they're already in the new format + for _, nodeChange := range expiredNodeChanges { + h.Change(nodeChange) + } } - stateUpdate := types.StateUpdate{ - Type: types.StateDERPUpdated, - DERPMap: h.DERPMap, + case <-derpTickerChan: + log.Info().Msg("Fetching DERPMap updates") + derpMap, err := backoff.Retry(ctx, func() (*tailcfg.DERPMap, error) { + derpMap, err := derp.GetDERPMap(h.cfg.DERP) + if err != nil { + return nil, err + } + if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { + region, _ := h.DERPServer.GenerateRegion() + derpMap.Regions[region.RegionID] = ®ion + } + + return derpMap, nil + }, backoff.WithBackOff(backoff.NewExponentialBackOff())) + if err != nil { + log.Error().Err(err).Msg("failed to build new DERPMap, retrying later") + continue } - if stateUpdate.Valid() { - h.nodeNotifier.NotifyAll(stateUpdate) + h.state.SetDERPMap(derpMap) + + h.Change(change.DERPMap()) + + case records, ok := <-extraRecordsUpdate: + if !ok { + continue } + h.cfg.TailcfgDNSConfig.ExtraRecords = records + + h.Change(change.ExtraRecords()) } } } func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, - req interface{}, + req any, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, -) (interface{}, error) { +) (any, error) { // Check if the request is coming from the on-server client. // This is not secure, but it is to maintain maintainability // with the "legacy" database-based client - // It is also neede for grpc-gateway to be able to connect to + // It is also needed for grpc-gateway to be able to connect to // the server client, _ := peer.FromContext(ctx) @@ -293,11 +337,6 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, meta, ok := metadata.FromIncomingContext(ctx) if !ok { - log.Error(). - Caller(). - Str("client_address", client.Addr.String()). - Msg("Retrieving metadata is failed") - return ctx, status.Errorf( codes.InvalidArgument, "Retrieving metadata is failed", @@ -306,11 +345,6 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, authHeader, ok := meta["authorization"] if !ok { - log.Error(). - Caller(). - Str("client_address", client.Addr.String()). - Msg("Authorization token is not supplied") - return ctx, status.Errorf( codes.Unauthenticated, "Authorization token is not supplied", @@ -320,25 +354,14 @@ func (h *Headscale) grpcAuthenticationInterceptor(ctx context.Context, token := authHeader[0] if !strings.HasPrefix(token, AuthPrefix) { - log.Error(). - Caller(). - Str("client_address", client.Addr.String()). 
- Msg(`missing "Bearer " prefix in "Authorization" header`) - return ctx, status.Error( codes.Unauthenticated, `missing "Bearer " prefix in "Authorization" header`, ) } - valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix)) + valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(token, AuthPrefix)) if err != nil { - log.Error(). - Caller(). - Err(err). - Str("client_address", client.Addr.String()). - Msg("failed to validate token") - return ctx, status.Error(codes.Internal, "failed to validate token") } @@ -363,42 +386,32 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler Str("client_address", req.RemoteAddr). Msg("HTTP authentication invoked") - authHeader := req.Header.Get("authorization") + authHeader := req.Header.Get("Authorization") + + writeUnauthorized := func(statusCode int) { + writer.WriteHeader(statusCode) + if _, err := writer.Write([]byte("Unauthorized")); err != nil { + log.Error().Err(err).Msg("writing HTTP response failed") + } + } if !strings.HasPrefix(authHeader, AuthPrefix) { log.Error(). Caller(). Str("client_address", req.RemoteAddr). Msg(`missing "Bearer " prefix in "Authorization" header`) - writer.WriteHeader(http.StatusUnauthorized) - _, err := writer.Write([]byte("Unauthorized")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + writeUnauthorized(http.StatusUnauthorized) return } - valid, err := h.db.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) + valid, err := h.state.ValidateAPIKey(strings.TrimPrefix(authHeader, AuthPrefix)) if err != nil { - log.Error(). + log.Info(). Caller(). Err(err). Str("client_address", req.RemoteAddr). Msg("failed to validate token") - - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Unauthorized")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + writeUnauthorized(http.StatusUnauthorized) return } @@ -406,16 +419,7 @@ func (h *Headscale) httpAuthenticationMiddleware(next http.Handler) http.Handler log.Info(). Str("client_address", req.RemoteAddr). Msg("invalid token") - - writer.WriteHeader(http.StatusUnauthorized) - _, err := writer.Write([]byte("Unauthorized")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + writeUnauthorized(http.StatusUnauthorized) return } @@ -436,49 +440,80 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error { func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { router := mux.NewRouter() - router.PathPrefix("/debug/pprof/").Handler(http.DefaultServeMux) + router.Use(prometheusMiddleware) - router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost) + router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler). + Methods(http.MethodPost, http.MethodGet) + router.HandleFunc("/robots.txt", h.RobotsHandler).Methods(http.MethodGet) router.HandleFunc("/health", h.HealthHandler).Methods(http.MethodGet) + router.HandleFunc("/version", h.VersionHandler).Methods(http.MethodGet) router.HandleFunc("/key", h.KeyHandler).Methods(http.MethodGet) - router.HandleFunc("/register/{mkey}", h.RegisterWebAPI).Methods(http.MethodGet) + router.HandleFunc("/register/{registration_id}", h.authProvider.RegisterHandler). 
+ Methods(http.MethodGet) - router.HandleFunc("/oidc/register/{mkey}", h.RegisterOIDC).Methods(http.MethodGet) - router.HandleFunc("/oidc/callback", h.OIDCCallback).Methods(http.MethodGet) + if provider, ok := h.authProvider.(*AuthProviderOIDC); ok { + router.HandleFunc("/oidc/callback", provider.OIDCCallbackHandler).Methods(http.MethodGet) + } router.HandleFunc("/apple", h.AppleConfigMessage).Methods(http.MethodGet) router.HandleFunc("/apple/{platform}", h.ApplePlatformConfig). Methods(http.MethodGet) router.HandleFunc("/windows", h.WindowsConfigMessage).Methods(http.MethodGet) - router.HandleFunc("/windows/tailscale.reg", h.WindowsRegConfig). - Methods(http.MethodGet) // TODO(kristoffer): move swagger into a package router.HandleFunc("/swagger", headscale.SwaggerUI).Methods(http.MethodGet) router.HandleFunc("/swagger/v1/openapiv2.json", headscale.SwaggerAPIv1). Methods(http.MethodGet) + router.HandleFunc("/verify", h.VerifyHandler).Methods(http.MethodPost) + if h.cfg.DERP.ServerEnabled { router.HandleFunc("/derp", h.DERPServer.DERPHandler) router.HandleFunc("/derp/probe", derpServer.DERPProbeHandler) - router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.DERPMap)) + router.HandleFunc("/derp/latency-check", derpServer.DERPProbeHandler) + router.HandleFunc("/bootstrap-dns", derpServer.DERPBootstrapDNSHandler(h.state.DERPMap())) } apiRouter := router.PathPrefix("/api").Subrouter() apiRouter.Use(h.httpAuthenticationMiddleware) apiRouter.PathPrefix("/v1/").HandlerFunc(grpcMux.ServeHTTP) - - router.PathPrefix("/").HandlerFunc(notFoundHandler) + router.HandleFunc("/favicon.ico", FaviconHandler) + router.PathPrefix("/").HandlerFunc(BlankHandler) return router } -// Serve launches a GIN server with the Headscale API. +// Serve launches the HTTP and gRPC server service Headscale and the API. func (h *Headscale) Serve() error { var err error + capver.CanOldCodeBeCleanedUp() - // Fetch an initial DERP Map before we start serving - h.DERPMap = derp.GetDERPMap(h.cfg.DERP) + if profilingEnabled { + if profilingPath != "" { + err = os.MkdirAll(profilingPath, os.ModePerm) + if err != nil { + log.Fatal().Err(err).Msg("failed to create profiling directory") + } + + defer profile.Start(profile.ProfilePath(profilingPath)).Stop() + } else { + defer profile.Start().Stop() + } + } + + if dumpConfig { + spew.Dump(h.cfg) + } + + versionInfo := types.GetVersionInfo() + log.Info().Str("version", versionInfo.Version).Str("commit", versionInfo.Commit).Msg("Starting Headscale") + log.Info(). + Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)). 
+ Msg("Clients with a lower minimum version will be rejected") + + h.mapBatcher = mapper.NewBatcherAndMapper(h.cfg, h.state) + h.mapBatcher.Start() + defer h.mapBatcher.Close() if h.cfg.DERP.ServerEnabled { // When embedded DERP is enabled we always need a STUN server @@ -486,30 +521,50 @@ func (h *Headscale) Serve() error { return errSTUNAddressNotSet } - region, err := h.DERPServer.GenerateRegion() - if err != nil { - return err - } - - h.DERPMap.Regions[region.RegionID] = ®ion - go h.DERPServer.ServeSTUN() } - if h.cfg.DERP.AutoUpdate { - derpMapCancelChannel := make(chan struct{}) - defer func() { derpMapCancelChannel <- struct{}{} }() - go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) + derpMap, err := derp.GetDERPMap(h.cfg.DERP) + if err != nil { + return fmt.Errorf("failed to get DERPMap: %w", err) } - if len(h.DERPMap.Regions) == 0 { + if h.cfg.DERP.ServerEnabled && h.cfg.DERP.AutomaticallyAddEmbeddedDerpRegion { + region, _ := h.DERPServer.GenerateRegion() + derpMap.Regions[region.RegionID] = ®ion + } + + if len(derpMap.Regions) == 0 { return errEmptyInitialDERPMap } - // TODO(kradalby): These should have cancel channels and be cleaned - // up on shutdown. - go h.expireEphemeralNodes(updateInterval) - go h.expireExpiredMachines(updateInterval) + h.state.SetDERPMap(derpMap) + + // Start ephemeral node garbage collector and schedule all nodes + // that are already in the database and ephemeral. If they are still + // around between restarts, they will reconnect and the GC will + // be cancelled. + go h.ephemeralGC.Start() + ephmNodes := h.state.ListEphemeralNodes() + for _, node := range ephmNodes.All() { + h.ephemeralGC.Schedule(node.ID(), h.cfg.EphemeralNodeInactivityTimeout) + } + + if h.cfg.DNSConfig.ExtraRecordsPath != "" { + h.extraRecordMan, err = dns.NewExtraRecordsManager(h.cfg.DNSConfig.ExtraRecordsPath) + if err != nil { + return fmt.Errorf("setting up extrarecord manager: %w", err) + } + h.cfg.TailcfgDNSConfig.ExtraRecords = h.extraRecordMan.Records() + go h.extraRecordMan.Run() + defer h.extraRecordMan.Close() + } + + // Start all scheduled tasks, e.g. expiring nodes, derp updates and + // records updates + scheduleCtx, scheduleCancel := context.WithCancel(context.Background()) + defer scheduleCancel() + go h.scheduledTasks(scheduleCtx) if zl.GlobalLevel() == zl.TraceLevel { zerolog.RespLog = true @@ -534,6 +589,12 @@ func (h *Headscale) Serve() error { return fmt.Errorf("unable to remove old socket file: %w", err) } + socketDir := filepath.Dir(h.cfg.UnixSocket) + err = util.EnsureDir(socketDir) + if err != nil { + return fmt.Errorf("setting up unix socket: %w", err) + } + socketListener, err := net.Listen("unix", h.cfg.UnixSocket) if err != nil { return fmt.Errorf("failed to set up gRPC socket: %w", err) @@ -555,14 +616,14 @@ func (h *Headscale) Serve() error { }..., ) if err != nil { - return err + return fmt.Errorf("setting up gRPC gateway via socket: %w", err) } // Connect to the gRPC server over localhost to skip // the authentication. 
err = v1.RegisterHeadscaleServiceHandler(ctx, grpcGatewayMux, grpcGatewayConn) if err != nil { - return err + return fmt.Errorf("registering Headscale API service to gRPC: %w", err) } // Start the local gRPC server without TLS and without authentication @@ -583,9 +644,7 @@ func (h *Headscale) Serve() error { tlsConfig, err := h.getTLSSettings() if err != nil { - log.Error().Err(err).Msg("Failed to set up TLS configuration") - - return err + return fmt.Errorf("configuring TLS settings: %w", err) } // @@ -606,12 +665,10 @@ func (h *Headscale) Serve() error { log.Info().Msgf("Enabling remote gRPC at %s", h.cfg.GRPCAddr) grpcOptions := []grpc.ServerOption{ - grpc.UnaryInterceptor( - grpcMiddleware.ChainUnaryServer( - h.grpcAuthenticationInterceptor, - // Uncomment to debug grpc communication. - // zerolog.NewUnaryServerInterceptor(), - ), + grpc.ChainUnaryInterceptor( + h.grpcAuthenticationInterceptor, + // Uncomment to debug grpc communication. + // zerolog.NewUnaryServerInterceptor(), ), } @@ -644,18 +701,17 @@ func (h *Headscale) Serve() error { // HTTP setup // // This is the regular router that we expose - // over our main Addr. It also serves the legacy Tailcale API + // over our main Addr router := h.createRouter(grpcGatewayMux) httpServer := &http.Server{ Addr: h.cfg.Addr, Handler: router, - ReadTimeout: types.HTTPReadTimeout, - // Go does not handle timeouts in HTTP very well, and there is - // no good way to handle streaming timeouts, therefore we need to - // keep this at unlimited and be careful to clean up connections - // https://blog.cloudflare.com/the-complete-guide-to-golang-net-http-timeouts/#aboutstreaming - WriteTimeout: 0, + ReadTimeout: types.HTTPTimeout, + + // Long polling should not have any timeout, this is overridden + // further down the chain + WriteTimeout: types.HTTPTimeout, } var httpListener net.Listener @@ -674,30 +730,43 @@ func (h *Headscale) Serve() error { log.Info(). Msgf("listening and serving HTTP on: %s", h.cfg.Addr) - promMux := http.NewServeMux() - promMux.Handle("/metrics", promhttp.Handler()) + // Only start debug/metrics server if address is configured + var debugHTTPServer *http.Server - promHTTPServer := &http.Server{ - Addr: h.cfg.MetricsAddr, - Handler: promMux, - ReadTimeout: types.HTTPReadTimeout, - WriteTimeout: 0, + var debugHTTPListener net.Listener + + if h.cfg.MetricsAddr != "" { + debugHTTPListener, err = (&net.ListenConfig{}).Listen(ctx, "tcp", h.cfg.MetricsAddr) + if err != nil { + return fmt.Errorf("failed to bind to TCP address: %w", err) + } + + debugHTTPServer = h.debugHTTPServer() + + errorGroup.Go(func() error { return debugHTTPServer.Serve(debugHTTPListener) }) + + log.Info(). + Msgf("listening and serving debug and metrics on: %s", h.cfg.MetricsAddr) + } else { + log.Info().Msg("metrics server disabled (metrics_listen_addr is empty)") } - var promHTTPListener net.Listener - promHTTPListener, err = net.Listen("tcp", h.cfg.MetricsAddr) - if err != nil { - return fmt.Errorf("failed to bind to TCP address: %w", err) + var tailsqlContext context.Context + if tailsqlEnabled { + if h.cfg.Database.Type != types.DatabaseSqlite { + log.Fatal(). + Str("type", h.cfg.Database.Type). 
+ Msgf("tailsql only support %q", types.DatabaseSqlite) + } + if tailsqlTSKey == "" { + log.Fatal().Msg("tailsql requires TS_AUTHKEY to be set") + } + tailsqlContext = context.Background() + go runTailSQLService(ctx, util.TSLogfWrapper(), tailsqlStateDir, h.cfg.Database.Sqlite.Path) } - errorGroup.Go(func() error { return promHTTPServer.Serve(promHTTPListener) }) - - log.Info(). - Msgf("listening and serving metrics on: %s", h.cfg.MetricsAddr) - // Handle common process-killing signals so we can gracefully shut down: - h.shutdownChan = make(chan struct{}) sigc := make(chan os.Signal, 1) signal.Notify(sigc, syscall.SIGHUP, @@ -713,74 +782,95 @@ func (h *Headscale) Serve() error { case syscall.SIGHUP: log.Info(). Str("signal", sig.String()). - Msg("Received SIGHUP, reloading ACL and Config") + Msg("Received SIGHUP, reloading ACL policy") - // TODO(kradalby): Reload config on SIGHUP - - if h.cfg.ACL.PolicyPath != "" { - aclPath := util.AbsolutePathFromConfigPath(h.cfg.ACL.PolicyPath) - pol, err := policy.LoadACLPolicyFromPath(aclPath) - if err != nil { - log.Error().Err(err).Msg("Failed to reload ACL policy") - } - - h.ACLPolicy = pol - log.Info(). - Str("path", aclPath). - Msg("ACL policy successfully reloaded, notifying nodes of change") - - h.nodeNotifier.NotifyAll(types.StateUpdate{ - Type: types.StateFullUpdate, - }) + if h.cfg.Policy.IsEmpty() { + continue } + changes, err := h.state.ReloadPolicy() + if err != nil { + log.Error().Err(err).Msgf("reloading policy") + continue + } + + h.Change(changes...) + default: + info := func(msg string) { log.Info().Msg(msg) } log.Info(). Str("signal", sig.String()). Msg("Received signal to stop, shutting down gracefully") - close(h.shutdownChan) - - h.pollNetMapStreamWG.Wait() + scheduleCancel() + h.ephemeralGC.Close() // Gracefully shut down servers - ctx, cancel := context.WithTimeout( - context.Background(), + shutdownCtx, cancel := context.WithTimeout( + context.WithoutCancel(ctx), types.HTTPShutdownTimeout, ) - if err := promHTTPServer.Shutdown(ctx); err != nil { - log.Error().Err(err).Msg("Failed to shutdown prometheus http") + defer cancel() + + if debugHTTPServer != nil { + info("shutting down debug http server") + + err := debugHTTPServer.Shutdown(shutdownCtx) + if err != nil { + log.Error().Err(err).Msg("failed to shutdown prometheus http") + } } - if err := httpServer.Shutdown(ctx); err != nil { - log.Error().Err(err).Msg("Failed to shutdown http") + + info("shutting down main http server") + + err := httpServer.Shutdown(shutdownCtx) + if err != nil { + log.Error().Err(err).Msg("failed to shutdown http") } + + info("closing batcher") + h.mapBatcher.Close() + + info("waiting for netmap stream to close") + h.clientStreamsOpen.Wait() + + info("shutting down grpc server (socket)") grpcSocket.GracefulStop() if grpcServer != nil { + info("shutting down grpc server (external)") grpcServer.GracefulStop() grpcListener.Close() } + if tailsqlContext != nil { + info("shutting down tailsql") + tailsqlContext.Done() + } + // Close network listeners - promHTTPListener.Close() + info("closing network listeners") + + if debugHTTPListener != nil { + debugHTTPListener.Close() + } httpListener.Close() grpcGatewayConn.Close() // Stop listening (and unlink the socket if unix type): + info("closing socket listener") socketListener.Close() - // Close db connections - err = h.db.Close() + // Close state connections + info("closing state and database") + err = h.state.Close() if err != nil { - log.Error().Err(err).Msg("Failed to close db") + 
log.Error().Err(err).Msg("failed to close state") } log.Info(). Msg("Headscale stopped") - // And we're done: - cancel() - return } } @@ -808,6 +898,11 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { Cache: autocert.DirCache(h.cfg.TLS.LetsEncrypt.CacheDir), Client: &acme.Client{ DirectoryURL: h.cfg.ACMEURL, + HTTPClient: &http.Client{ + Transport: &acmeLogger{ + rt: http.DefaultTransport, + }, + }, }, Email: h.cfg.ACMEEmail, } @@ -827,7 +922,7 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { server := &http.Server{ Addr: h.cfg.TLS.LetsEncrypt.Listen, Handler: certManager.HTTPHandler(http.HandlerFunc(h.redirect)), - ReadTimeout: types.HTTPReadTimeout, + ReadTimeout: types.HTTPTimeout, } go func() { @@ -866,22 +961,13 @@ func (h *Headscale) getTLSSettings() (*tls.Config, error) { } } -func notFoundHandler( - writer http.ResponseWriter, - req *http.Request, -) { - body, _ := io.ReadAll(req.Body) - - log.Trace(). - Interface("header", req.Header). - Interface("proto", req.Proto). - Interface("url", req.URL). - Bytes("body", body). - Msg("Request did not match") - writer.WriteHeader(http.StatusNotFound) -} - func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { + dir := filepath.Dir(path) + err := util.EnsureDir(dir) + if err != nil { + return nil, fmt.Errorf("ensuring private key directory: %w", err) + } + privateKey, err := os.ReadFile(path) if errors.Is(err, os.ErrNotExist) { log.Info().Str("path", path).Msg("No private key file at path, creating...") @@ -898,7 +984,8 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { err = os.WriteFile(path, machineKeyStr, privateKeyFileMode) if err != nil { return nil, fmt.Errorf( - "failed to save private key to disk: %w", + "failed to save private key to disk at path %q: %w", + path, err, ) } @@ -917,3 +1004,35 @@ func readOrCreatePrivateKey(path string) (*key.MachinePrivate, error) { return &machineKey, nil } + +// Change is used to send changes to nodes. +// All change should be enqueued here and empty will be automatically +// ignored. +func (h *Headscale) Change(cs ...change.Change) { + h.mapBatcher.AddWork(cs...) +} + +// Provide some middleware that can inspect the ACME/autocert https calls +// and log when things are failing. +type acmeLogger struct { + rt http.RoundTripper +} + +// RoundTrip will log when ACME/autocert failures happen either when err != nil OR +// when http status codes indicate a failure has occurred. +func (l *acmeLogger) RoundTrip(req *http.Request) (*http.Response, error) { + resp, err := l.rt.RoundTrip(req) + if err != nil { + log.Error().Err(err).Str("url", req.URL.String()).Msg("ACME request failed") + return nil, err + } + + if resp.StatusCode >= http.StatusBadRequest { + defer resp.Body.Close() + + body, _ := io.ReadAll(resp.Body) + log.Error().Int("status_code", resp.StatusCode).Str("url", req.URL.String()).Bytes("body", body).Msg("ACME request returned error") + } + + return resp, nil +} diff --git a/hscontrol/assets/assets.go b/hscontrol/assets/assets.go new file mode 100644 index 00000000..13904247 --- /dev/null +++ b/hscontrol/assets/assets.go @@ -0,0 +1,24 @@ +// Package assets provides embedded static assets for Headscale. +// All static files (favicon, CSS, SVG) are embedded here for +// centralized asset management. 
+package assets + +import ( + _ "embed" +) + +// Favicon is the embedded favicon.png file served at /favicon.ico +// +//go:embed favicon.png +var Favicon []byte + +// CSS is the embedded style.css stylesheet used in HTML templates. +// Contains Material for MkDocs design system styles. +// +//go:embed style.css +var CSS string + +// SVG is the embedded headscale.svg logo used in HTML templates. +// +//go:embed headscale.svg +var SVG string diff --git a/hscontrol/assets/favicon.png b/hscontrol/assets/favicon.png new file mode 100644 index 00000000..4989810f Binary files /dev/null and b/hscontrol/assets/favicon.png differ diff --git a/hscontrol/assets/headscale.svg b/hscontrol/assets/headscale.svg new file mode 100644 index 00000000..caf19697 --- /dev/null +++ b/hscontrol/assets/headscale.svg @@ -0,0 +1 @@ +<svg class="headscale-logo" width="400" height="140" xmlns="http://www.w3.org/2000/svg" xml:space="preserve" style="fill-rule:evenodd;clip-rule:evenodd;stroke-linejoin:round;stroke-miterlimit:2" viewBox="32.92 0 1247.08 640"><path d="M.08 0v-.736h.068v.3C.203-.509.27-.545.347-.545c.029 0 .055.005.079.015.024.01.045.025.062.045.017.02.031.045.041.075.009.03.014.065.014.105V0H.475v-.289C.475-.352.464-.4.443-.433.422-.466.385-.483.334-.483c-.027 0-.052.006-.075.017C.236-.455.216-.439.2-.419c-.017.02-.029.044-.038.072-.009.028-.014.059-.014.093V0H.08Z" style="fill:#f8b5cb;fill-rule:nonzero" transform="translate(32.92220721 521.8022953) scale(235.3092)"/><path d="M.051-.264c0-.036.007-.071.02-.105.013-.034.031-.064.055-.09.023-.026.052-.047.086-.063.033-.015.071-.023.112-.023.039 0 .076.007.109.021.033.014.062.033.087.058.025.025.044.054.058.088.014.035.021.072.021.113v.005H.121c.001.031.007.059.018.084.01.025.024.047.042.065.018.019.04.033.065.043.025.01.052.015.082.015.026 0 .049-.003.069-.01.02-.007.038-.016.054-.028C.466-.102.48-.115.492-.13c.011-.015.022-.03.032-.046l.057.03C.556-.097.522-.058.48-.03.437-.001.387.013.328.013.284.013.245.006.21-.01.175-.024.146-.045.123-.07.1-.095.082-.125.07-.159.057-.192.051-.227.051-.264ZM.128-.32h.396C.51-.375.485-.416.449-.441.412-.466.371-.479.325-.479c-.048 0-.089.013-.123.039-.034.026-.059.066-.074.12Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(177.16674681 521.8022953) scale(235.3092)"/><path d="M.051-.267c0-.038.007-.074.021-.108.014-.033.033-.063.058-.088.025-.025.054-.045.087-.06.033-.015.069-.022.108-.022.043 0 .083.009.119.027.035.019.066.047.093.084v-.097h.067V0H.537v-.091C.508-.056.475-.029.44-.013.404.005.365.013.323.013.284.013.248.006.215-.01.182-.024.153-.045.129-.071.104-.096.085-.126.072-.16.058-.193.051-.229.051-.267Zm.279.218c.027 0 .054-.005.079-.015.025-.01.048-.024.068-.043.019-.018.035-.04.047-.067.012-.027.018-.056.018-.089 0-.031-.005-.059-.016-.086C.515-.375.501-.398.482-.417.462-.436.44-.452.415-.463.389-.474.361-.479.331-.479c-.031 0-.059.006-.084.017C.221-.45.199-.434.18-.415c-.019.02-.033.043-.043.068-.011.026-.016.053-.016.082 0 .029.005.056.016.082.011.026.025.049.044.069.019.02.041.036.066.047.025.012.053.018.083.018Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(327.76463481 521.8022953) scale(235.3092)"/><path d="M.051-.267c0-.038.007-.074.021-.108.014-.033.033-.063.058-.088.025-.025.054-.045.087-.06.033-.015.069-.022.108-.022.043 0 .083.009.119.027.035.019.066.047.093.084v-.302h.068V0H.537v-.091C.508-.056.475-.029.44-.013.404.005.365.013.323.013.284.013.248.006.215-.01.182-.024.153-.045.129-.071.104-.096.085-.126.072-.16.058-.193.051-.229.051-.267Zm.279.218c.027 0 
.054-.005.079-.015.025-.01.048-.024.068-.043.019-.018.035-.04.047-.067.011-.027.017-.056.017-.089 0-.031-.005-.059-.016-.086C.514-.375.5-.398.481-.417.462-.436.439-.452.414-.463.389-.474.361-.479.331-.479c-.031 0-.059.006-.084.017C.221-.45.199-.434.18-.415c-.019.02-.033.043-.043.068-.011.026-.016.053-.016.082 0 .029.005.056.016.082.011.026.025.049.044.069.019.02.041.036.066.047.025.012.053.018.083.018Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(488.71612761 521.8022953) scale(235.3092)"/><path d="m.034-.062.043-.049c.017.019.035.034.054.044.018.01.037.015.057.015.013 0 .026-.002.038-.007.011-.004.021-.01.031-.018.009-.008.016-.017.021-.028.005-.011.008-.022.008-.035 0-.019-.005-.034-.014-.047C.263-.199.248-.21.229-.221.205-.234.183-.247.162-.259.14-.271.122-.284.107-.298.092-.311.08-.327.071-.344.062-.361.058-.381.058-.404c0-.021.004-.04.012-.058.007-.016.018-.031.031-.044.013-.013.028-.022.046-.029.018-.007.037-.01.057-.01.029 0 .056.006.079.019s.045.031.068.053l-.044.045C.291-.443.275-.456.258-.465.241-.474.221-.479.2-.479c-.022 0-.041.007-.056.02C.128-.445.12-.428.12-.408c0 .019.006.035.017.048.011.013.027.026.048.037.027.015.05.028.071.04.021.013.038.026.052.039.014.013.025.028.032.044.007.016.011.035.011.057 0 .021-.004.041-.011.059-.008.019-.019.036-.033.05-.014.015-.031.026-.05.035C.237.01.215.014.191.014c-.03 0-.059-.006-.086-.02C.077-.019.053-.037.034-.062Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(649.90292961 521.8022953) scale(235.3092)"/><path d="M.051-.266c0-.04.007-.077.022-.111.014-.034.034-.063.059-.089.025-.025.054-.044.089-.058.035-.014.072-.021.113-.021.051 0 .098.01.139.03.041.021.075.049.1.085l-.05.043C.498-.418.47-.441.439-.456.408-.471.372-.479.331-.479c-.03 0-.058.005-.083.016C.222-.452.2-.436.181-.418.162-.399.148-.376.137-.35c-.011.026-.016.054-.016.084 0 .031.005.06.016.086.011.027.025.049.044.068.019.019.041.034.067.044.025.011.053.016.084.016.077 0 .141-.03.191-.09l.051.04c-.028.036-.062.064-.103.085C.43.004.384.014.332.014.291.014.254.007.219-.008.184-.022.155-.042.13-.067.105-.092.086-.121.072-.156.058-.19.051-.227.051-.266Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(741.20289921 521.8022953) scale(235.3092)"/><path d="M.051-.267c0-.038.007-.074.021-.108.014-.033.033-.063.058-.088.025-.025.054-.045.087-.06.033-.015.069-.022.108-.022.043 0 .083.009.119.027.035.019.066.047.093.084v-.097h.067V0H.537v-.091C.508-.056.475-.029.44-.013.404.005.365.013.323.013.284.013.248.006.215-.01.182-.024.153-.045.129-.071.104-.096.085-.126.072-.16.058-.193.051-.229.051-.267Zm.279.218c.027 0 .054-.005.079-.015.025-.01.048-.024.068-.043.019-.018.035-.04.047-.067.012-.027.018-.056.018-.089 0-.031-.005-.059-.016-.086C.515-.375.501-.398.482-.417.462-.436.44-.452.415-.463.389-.474.361-.479.331-.479c-.031 0-.059.006-.084.017C.221-.45.199-.434.18-.415c-.019.02-.033.043-.043.068-.011.026-.016.053-.016.082 0 .029.005.056.016.082.011.026.025.049.044.069.019.02.041.036.066.047.025.012.053.018.083.018Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(884.27089281 521.8022953) scale(235.3092)"/><path d="M.066-.736h.068V0H.066z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(1045.22238561 521.8022953) scale(235.3092)"/><path d="M.051-.264c0-.036.007-.071.02-.105.013-.034.031-.064.055-.09.023-.026.052-.047.086-.063.033-.015.071-.023.112-.023.039 0 
.076.007.109.021.033.014.062.033.087.058.025.025.044.054.058.088.014.035.021.072.021.113v.005H.121c.001.031.007.059.018.084.01.025.024.047.042.065.018.019.04.033.065.043.025.01.052.015.082.015.026 0 .049-.003.069-.01.02-.007.038-.016.054-.028C.466-.102.48-.115.492-.13c.011-.015.022-.03.032-.046l.057.03C.556-.097.522-.058.48-.03.437-.001.387.013.328.013.284.013.245.006.21-.01.175-.024.146-.045.123-.07.1-.095.082-.125.07-.159.057-.192.051-.227.051-.264ZM.128-.32h.396C.51-.375.485-.416.449-.441.412-.466.371-.479.325-.479c-.048 0-.089.013-.123.039-.034.026-.059.066-.074.12Z" style="fill:#8d8d8d;fill-rule:nonzero" transform="translate(1092.28422561 521.8022953) scale(235.3092)"/><circle cx="141.023" cy="338.36" r="117.472" style="fill:#f8b5cb" transform="matrix(.581302 0 0 .58613 40.06479894 12.59842153)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 32.39345942 21.2386)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 32.39345942 88.80371146)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 120.7528627 88.80371146)"/><circle cx="352.014" cy="268.302" r="33.095" style="fill:#a2a2a2" transform="matrix(.59308 0 0 .58289 120.99825939 21.2386)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.5782 0 0 .58289 36.19871106 15.26642564)"/><circle cx="805.557" cy="336.915" r="118.199" style="fill:#8d8d8d" transform="matrix(.5782 0 0 .58289 183.24041937 15.26642564)"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030" transform="translate(34.2345 21.2386) scale(.58289)"/><path d="M680.282 124.808h-68.093v390.325h68.081v-28.23H640V153.228h40.282v-28.42Z" style="fill:#303030" transform="matrix(-.58289 0 0 .58289 1116.7719791 21.2386)"/></svg> diff --git a/hscontrol/assets/oidc_callback_template.html b/hscontrol/assets/oidc_callback_template.html deleted file mode 100644 index 2236f365..00000000 --- a/hscontrol/assets/oidc_callback_template.html +++ /dev/null @@ -1,307 +0,0 @@ -<!doctype html> -<html lang="en"> - <head> - <meta charset="UTF-8" /> - <meta name="viewport" content="width=device-width, initial-scale=1" /> - <title>Headscale Authentication Succeeded - - - -
-    <!-- [template markup lost in extraction; the deleted page showed "Signed in via your
-    OIDC provider", "{{.Verb}} as {{.User}}, you can now close this window.", a "Not sure
-    how to get started?" hint, and links to the headscale and tailscale documentation] -->
- - diff --git a/hscontrol/assets/style.css b/hscontrol/assets/style.css new file mode 100644 index 00000000..d1eac385 --- /dev/null +++ b/hscontrol/assets/style.css @@ -0,0 +1,143 @@ +/* CSS Variables from Material for MkDocs */ +:root { + --md-default-fg-color: rgba(0, 0, 0, 0.87); + --md-default-fg-color--light: rgba(0, 0, 0, 0.54); + --md-default-fg-color--lighter: rgba(0, 0, 0, 0.32); + --md-default-fg-color--lightest: rgba(0, 0, 0, 0.07); + --md-code-fg-color: #36464e; + --md-code-bg-color: #f5f5f5; + --md-primary-fg-color: #4051b5; + --md-accent-fg-color: #526cfe; + --md-typeset-a-color: var(--md-primary-fg-color); + --md-text-font: "Roboto", -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif; + --md-code-font: "Roboto Mono", "SF Mono", Monaco, "Cascadia Code", Consolas, "Courier New", monospace; +} + +/* Base Typography */ +.md-typeset { + font-size: 0.8rem; + line-height: 1.6; + color: var(--md-default-fg-color); + font-family: var(--md-text-font); + overflow-wrap: break-word; + text-align: left; +} + +/* Headings */ +.md-typeset h1 { + color: var(--md-default-fg-color--light); + font-size: 2em; + line-height: 1.3; + margin: 0 0 1.25em; + font-weight: 300; + letter-spacing: -0.01em; +} + +.md-typeset h1:not(:first-child) { + margin-top: 2em; +} + +.md-typeset h2 { + font-size: 1.5625em; + line-height: 1.4; + margin: 2.4em 0 0.64em; + font-weight: 300; + letter-spacing: -0.01em; + color: var(--md-default-fg-color--light); +} + +.md-typeset h3 { + font-size: 1.25em; + line-height: 1.5; + margin: 2em 0 0.8em; + font-weight: 400; + letter-spacing: -0.01em; + color: var(--md-default-fg-color--light); +} + +/* Paragraphs and block elements */ +.md-typeset p { + margin: 1em 0; +} + +.md-typeset blockquote, +.md-typeset dl, +.md-typeset figure, +.md-typeset ol, +.md-typeset pre, +.md-typeset ul { + margin-bottom: 1em; + margin-top: 1em; +} + +/* Lists */ +.md-typeset ol, +.md-typeset ul { + padding-left: 2em; +} + +/* Links */ +.md-typeset a { + color: var(--md-typeset-a-color); + text-decoration: none; + word-break: break-word; +} + +.md-typeset a:hover, +.md-typeset a:focus { + color: var(--md-accent-fg-color); +} + +/* Code (inline) */ +.md-typeset code { + background-color: var(--md-code-bg-color); + color: var(--md-code-fg-color); + border-radius: 0.1rem; + font-size: 0.85em; + font-family: var(--md-code-font); + padding: 0 0.2941176471em; + word-break: break-word; +} + +/* Code blocks (pre) */ +.md-typeset pre { + display: block; + line-height: 1.4; + margin: 1em 0; + overflow-x: auto; +} + +.md-typeset pre > code { + background-color: var(--md-code-bg-color); + color: var(--md-code-fg-color); + display: block; + padding: 0.7720588235em 1.1764705882em; + font-family: var(--md-code-font); + font-size: 0.85em; + line-height: 1.4; + overflow-wrap: break-word; + word-wrap: break-word; + white-space: pre-wrap; +} + +/* Links in code */ +.md-typeset a code { + color: currentcolor; +} + +/* Logo */ +.headscale-logo { + display: block; + width: 400px; + max-width: 100%; + height: auto; + margin: 0 0 3rem 0; + padding: 0; +} + +@media (max-width: 768px) { + .headscale-logo { + width: 200px; + margin-left: 0; + } +} diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 4fe5a16b..ac5968e3 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -1,10 +1,12 @@ package hscontrol import ( - "encoding/json" + "cmp" + "context" "errors" "fmt" "net/http" + "net/url" "strings" "time" @@ -14,717 +16,449 @@ import ( "gorm.io/gorm" "tailscale.com/tailcfg" 
"tailscale.com/types/key" + "tailscale.com/types/ptr" ) -func logAuthFunc( - registerRequest tailcfg.RegisterRequest, - machineKey key.MachinePublic, -) (func(string), func(string), func(error, string)) { - return func(msg string) { - log.Info(). - Caller(). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("node", registerRequest.Hostinfo.Hostname). - Str("followup", registerRequest.Followup). - Time("expiry", registerRequest.Expiry). - Msg(msg) - }, - func(msg string) { - log.Trace(). - Caller(). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("node", registerRequest.Hostinfo.Hostname). - Str("followup", registerRequest.Followup). - Time("expiry", registerRequest.Expiry). - Msg(msg) - }, - func(err error, msg string) { - log.Error(). - Caller(). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("node", registerRequest.Hostinfo.Hostname). - Str("followup", registerRequest.Followup). - Time("expiry", registerRequest.Expiry). - Err(err). - Msg(msg) - } +type AuthProvider interface { + RegisterHandler(http.ResponseWriter, *http.Request) + AuthURL(types.RegistrationID) string } -// handleRegister is the logic for registering a client. func (h *Headscale) handleRegister( - writer http.ResponseWriter, - req *http.Request, - registerRequest tailcfg.RegisterRequest, + ctx context.Context, + req tailcfg.RegisterRequest, machineKey key.MachinePublic, -) { - logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey) - now := time.Now().UTC() - logTrace("handleRegister called, looking up machine in DB") - node, err := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) - logTrace("handleRegister database lookup has returned") - if errors.Is(err, gorm.ErrRecordNotFound) { - // If the node has AuthKey set, handle registration via PreAuthKeys - if registerRequest.Auth.AuthKey != "" { - h.handleAuthKey(writer, registerRequest, machineKey) +) (*tailcfg.RegisterResponse, error) { + // Check for logout/expiry FIRST, before checking auth key. + // Tailscale clients may send logout requests with BOTH a past expiry AND an auth key. + // A past expiry takes precedence - it's a logout regardless of other fields. + if !req.Expiry.IsZero() && req.Expiry.Before(time.Now()) { + log.Debug(). + Str("node.key", req.NodeKey.ShortString()). + Time("expiry", req.Expiry). + Bool("has_auth", req.Auth != nil). + Msg("Detected logout attempt with past expiry") - return - } - - // Check if the node is waiting for interactive login. - // - // TODO(juan): We could use this field to improve our protocol implementation, - // and hold the request until the client closes it, or the interactive - // login is completed (i.e., the user registers the node). - // This is not implemented yet, as it is no strictly required. The only side-effect - // is that the client will hammer headscale with requests until it gets a - // successful RegisterResponse. 
- if registerRequest.Followup != "" { - logTrace("register request is a followup") - if _, ok := h.registrationCache.Get(machineKey.String()); ok { - logTrace("Node is waiting for interactive login") - - select { - case <-req.Context().Done(): - return - case <-time.After(registrationHoldoff): - h.handleNewNode(writer, registerRequest, machineKey) - - return - } - } - } - - logInfo("Node not found in database, creating new") - - givenName, err := h.db.GenerateGivenName( - machineKey, - registerRequest.Hostinfo.Hostname, - ) - if err != nil { - logErr(err, "Failed to generate given name for node") - - return - } - - // The node did not have a key to authenticate, which means - // that we rely on a method that calls back some how (OpenID or CLI) - // We create the node and then keep it around until a callback - // happens - newNode := types.Node{ - MachineKey: machineKey, - Hostname: registerRequest.Hostinfo.Hostname, - GivenName: givenName, - NodeKey: registerRequest.NodeKey, - LastSeen: &now, - Expiry: &time.Time{}, - } - - if !registerRequest.Expiry.IsZero() { - logTrace("Non-zero expiry time requested") - newNode.Expiry = ®isterRequest.Expiry - } - - h.registrationCache.Set( - machineKey.String(), - newNode, - registerCacheExpiration, - ) - - h.handleNewNode(writer, registerRequest, machineKey) - - return - } - - // The node is already in the DB. This could mean one of the following: - // - The node is authenticated and ready to /map - // - We are doing a key refresh - // - The node is logged out (or expired) and pending to be authorized. TODO(juan): We need to keep alive the connection here - if node != nil { - // (juan): For a while we had a bug where we were not storing the MachineKey for the nodes using the TS2021, - // due to a misunderstanding of the protocol https://github.com/juanfont/headscale/issues/1054 - // So if we have a not valid MachineKey (but we were able to fetch the node with the NodeKeys), we update it. - if err != nil || node.MachineKey.IsZero() { - if err := h.db.NodeSetMachineKey(node, machineKey); err != nil { - log.Error(). - Caller(). - Str("func", "RegistrationHandler"). - Str("node", node.Hostname). - Err(err). 
- Msg("Error saving machine key to database") - - return - } - } - - // If the NodeKey stored in headscale is the same as the key presented in a registration - // request, then we have a node that is either: - // - Trying to log out (sending a expiry in the past) - // - A valid, registered node, looking for /map - // - Expired node wanting to reauthenticate - if node.NodeKey.String() == registerRequest.NodeKey.String() { - // The client sends an Expiry in the past if the client is requesting to expire the key (aka logout) - // https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go#L648 - if !registerRequest.Expiry.IsZero() && - registerRequest.Expiry.UTC().Before(now) { - h.handleNodeLogOut(writer, *node, machineKey) - - return - } - - // If node is not expired, and it is register, we have a already accepted this node, - // let it proceed with a valid registration - if !node.IsExpired() { - h.handleNodeWithValidRegistration(writer, *node, machineKey) - - return - } - } - - // The NodeKey we have matches OldNodeKey, which means this is a refresh after a key expiration - if node.NodeKey.String() == registerRequest.OldNodeKey.String() && - !node.IsExpired() { - h.handleNodeKeyRefresh( - writer, - registerRequest, - *node, - machineKey, - ) - - return - } - - if registerRequest.Followup != "" { - select { - case <-req.Context().Done(): - return - case <-time.After(registrationHoldoff): - } - } - - // The node has expired or it is logged out - h.handleNodeExpiredOrLoggedOut(writer, registerRequest, *node, machineKey) - - // TODO(juan): RegisterRequest includes an Expiry time, that we could optionally use - node.Expiry = &time.Time{} - - // If we are here it means the client needs to be reauthorized, - // we need to make sure the NodeKey matches the one in the request - // TODO(juan): What happens when using fast user switching between two - // headscale-managed tailnets? - node.NodeKey = registerRequest.NodeKey - h.registrationCache.Set( - machineKey.String(), - *node, - registerCacheExpiration, - ) - - return - } -} - -// handleAuthKey contains the logic to manage auth key client registration -// When using Noise, the machineKey is Zero. -// -// TODO: check if any locks are needed around IP allocation. -func (h *Headscale) handleAuthKey( - writer http.ResponseWriter, - registerRequest tailcfg.RegisterRequest, - machineKey key.MachinePublic, -) { - log.Debug(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Msgf("Processing auth key for %s", registerRequest.Hostinfo.Hostname) - resp := tailcfg.RegisterResponse{} - - pak, err := h.db.ValidatePreAuthKey(registerRequest.Auth.AuthKey) - if err != nil { - log.Error(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Failed authentication via AuthKey") - resp.MachineAuthorized = false - - respBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() - - return - } - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusUnauthorized) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - log.Error(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). 
- Msg("Failed authentication via AuthKey") - - if pak != nil { - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() - } else { - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc() - } - - return - } - - log.Debug(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Msg("Authentication key was valid, proceeding to acquire IP addresses") - - nodeKey := registerRequest.NodeKey - - // retrieve node information if it exist - // The error is not important, because if it does not - // exist, then this is a new node and we will move - // on to registration. - node, _ := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) - if node != nil { - log.Trace(). - Caller(). - Str("node", node.Hostname). - Msg("node was already registered before, refreshing with new auth key") - - node.NodeKey = nodeKey - node.AuthKeyID = uint(pak.ID) - err := h.db.NodeSetExpiry(node, registerRequest.Expiry) - if err != nil { - log.Error(). - Caller(). - Str("node", node.Hostname). - Err(err). - Msg("Failed to refresh node") - - return - } - - aclTags := pak.Proto().GetAclTags() - if len(aclTags) > 0 { - // This conditional preserves the existing behaviour, although SaaS would reset the tags on auth-key login - err = h.db.SetTags(node, aclTags) + // This is a logout attempt (expiry in the past) + if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok { + log.Debug(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Bool("is_ephemeral", node.IsEphemeral()). + Bool("has_authkey", node.AuthKey().Valid()). + Msg("Found existing node for logout, calling handleLogout") + resp, err := h.handleLogout(node, req, machineKey) if err != nil { - log.Error(). - Caller(). - Str("node", node.Hostname). - Strs("aclTags", aclTags). - Err(err). - Msg("Failed to set tags after refreshing node") - - return + return nil, fmt.Errorf("handling logout: %w", err) } + if resp != nil { + return resp, nil + } + } else { + log.Warn(). + Str("node.key", req.NodeKey.ShortString()). + Msg("Logout attempt but node not found in NodeStore") } - } else { - now := time.Now().UTC() + } - givenName, err := h.db.GenerateGivenName(machineKey, registerRequest.Hostinfo.Hostname) + // If the register request does not contain a Auth struct, it means we are logging + // out an existing node (legacy logout path for clients that send Auth=nil). + if req.Auth == nil { + // If the register request present a NodeKey that is currently in use, we will + // check if the node needs to be sent to re-auth, or if the node is logging out. + // We do not look up nodes by [key.MachinePublic] as it might belong to multiple + // nodes, separated by users and this path is handling expiring/logout paths. + if node, ok := h.state.GetNodeByNodeKey(req.NodeKey); ok { + // When tailscaled restarts, it sends RegisterRequest with Auth=nil and Expiry=zero. + // Return the current node state without modification. + // See: https://github.com/juanfont/headscale/issues/2862 + if req.Expiry.IsZero() && node.Expiry().Valid() && !node.IsExpired() { + return nodeToRegisterResponse(node), nil + } + + resp, err := h.handleLogout(node, req, machineKey) + if err != nil { + return nil, fmt.Errorf("handling existing node: %w", err) + } + + // If resp is not nil, we have a response to return to the node. + // If resp is nil, we should proceed and see if the node is trying to re-auth. 
+ if resp != nil { + return resp, nil + } + } else { + // If the register request is not attempting to register a node, and + // we cannot match it with an existing node, we consider that unexpected + // as only register nodes should attempt to log out. + log.Debug(). + Str("node.key", req.NodeKey.ShortString()). + Str("machine.key", machineKey.ShortString()). + Bool("unexpected", true). + Msg("received register request with no auth, and no existing node") + } + } + + // If the [tailcfg.RegisterRequest] has a Followup URL, it means that the + // node has already started the registration process and we should wait for + // it to finish the original registration. + if req.Followup != "" { + return h.waitForFollowup(ctx, req, machineKey) + } + + // Pre authenticated keys are handled slightly different than interactive + // logins as they can be done fully sync and we can respond to the node with + // the result as it is waiting. + if isAuthKey(req) { + resp, err := h.handleRegisterWithAuthKey(req, machineKey) if err != nil { - log.Error(). - Caller(). - Str("func", "RegistrationHandler"). - Str("hostinfo.name", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Failed to generate given name for node") + // Preserve HTTPError types so they can be handled properly by the HTTP layer + var httpErr HTTPError + if errors.As(err, &httpErr) { + return nil, httpErr + } - return + return nil, fmt.Errorf("handling register with auth key: %w", err) } - nodeToRegister := types.Node{ - Hostname: registerRequest.Hostinfo.Hostname, - GivenName: givenName, - UserID: pak.User.ID, - MachineKey: machineKey, - RegisterMethod: util.RegisterMethodAuthKey, - Expiry: ®isterRequest.Expiry, - NodeKey: nodeKey, - LastSeen: &now, - AuthKeyID: uint(pak.ID), - ForcedTags: pak.Proto().GetAclTags(), - } - - node, err = h.db.RegisterNode( - nodeToRegister, - ) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("could not register node") - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } + return resp, nil } - err = h.db.UsePreAuthKey(pak) + resp, err := h.handleRegisterInteractive(req, machineKey) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to use pre-auth key") - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return + return nil, fmt.Errorf("handling register interactive: %w", err) } - resp.MachineAuthorized = true - resp.User = *pak.User.TailscaleUser() - // Provide LoginName when registering with pre-auth key - // Otherwise it will need to exec `tailscale up` twice to fetch the *LoginName* - resp.Login = *pak.User.TailscaleLogin() - - respBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Cannot encode message") - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name). - Inc() - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). 
- Msg("Failed to write response") - } - - log.Info(). - Str("node", registerRequest.Hostinfo.Hostname). - Str("ips", strings.Join(node.IPAddresses.StringSlice(), ", ")). - Msg("Successfully authenticated via AuthKey") + return resp, nil } -// handleNewNode returns the authorisation URL to the client based on what type -// of registration headscale is configured with. -// This url is then showed to the user by the local Tailscale client. -func (h *Headscale) handleNewNode( - writer http.ResponseWriter, - registerRequest tailcfg.RegisterRequest, +// handleLogout checks if the [tailcfg.RegisterRequest] is a +// logout attempt from a node. If the node is not attempting to +func (h *Headscale) handleLogout( + node types.NodeView, + req tailcfg.RegisterRequest, machineKey key.MachinePublic, -) { - logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey) - - resp := tailcfg.RegisterResponse{} - - // The node registration is new, redirect the client to the registration URL - logTrace("The node seems to be new, sending auth url") - - if h.oauth2Config != nil { - resp.AuthURL = fmt.Sprintf( - "%s/oidc/register/%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), - machineKey.String(), - ) - } else { - resp.AuthURL = fmt.Sprintf("%s/register/%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), - machineKey.String()) +) (*tailcfg.RegisterResponse, error) { + // Fail closed if it looks like this is an attempt to modify a node where + // the node key and the machine key the noise session was started with does + // not align. + if node.MachineKey() != machineKey { + return nil, NewHTTPError(http.StatusUnauthorized, "node exist with different machine key", nil) } - respBody, err := json.Marshal(resp) + // Note: We do NOT return early if req.Auth is set, because Tailscale clients + // may send logout requests with BOTH a past expiry AND an auth key. + // A past expiry indicates logout, regardless of whether Auth is present. + // The expiry check below will handle the logout logic. + + // If the node is expired and this is not a re-authentication attempt, + // force the client to re-authenticate. + // TODO(kradalby): I wonder if this is a path we ever hit? + if node.IsExpired() { + log.Trace().Str("node.name", node.Hostname()). + Uint64("node.id", node.ID().Uint64()). + Interface("reg.req", req). + Bool("unexpected", true). + Msg("Node key expired, forcing re-authentication") + return &tailcfg.RegisterResponse{ + NodeKeyExpired: true, + MachineAuthorized: false, + AuthURL: "", // Client will need to re-authenticate + }, nil + } + + // If we get here, the node is not currently expired, and not trying to + // do an auth. + // The node is likely logging out, but before we run that logic, we will validate + // that the node is not attempting to tamper/extend their expiry. + // If it is not, we will expire the node or in the case of an ephemeral node, delete it. + + // The client is trying to extend their key, this is not allowed. + if req.Expiry.After(time.Now()) { + return nil, NewHTTPError(http.StatusBadRequest, "extending key is not allowed", nil) + } + + // If the request expiry is in the past, we consider it a logout. + // Zero expiry is handled in handleRegister() before calling this function. + if req.Expiry.Before(time.Now()) { + log.Debug(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Bool("is_ephemeral", node.IsEphemeral()). + Bool("has_authkey", node.AuthKey().Valid()). + Time("req.expiry", req.Expiry). 
+ Msg("Processing logout request with past expiry") + + if node.IsEphemeral() { + log.Info(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Msg("Deleting ephemeral node during logout") + + c, err := h.state.DeleteNode(node) + if err != nil { + return nil, fmt.Errorf("deleting ephemeral node: %w", err) + } + + h.Change(c) + + return &tailcfg.RegisterResponse{ + NodeKeyExpired: true, + MachineAuthorized: false, + }, nil + } + + log.Debug(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Msg("Node is not ephemeral, setting expiry instead of deleting") + } + + // Update the internal state with the nodes new expiry, meaning it is + // logged out. + updatedNode, c, err := h.state.SetNodeExpiry(node.ID(), req.Expiry) if err != nil { - logErr(err, "Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return + return nil, fmt.Errorf("setting node expiry: %w", err) } - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - logErr(err, "Failed to write response") - } + h.Change(c) - logInfo(fmt.Sprintf("Successfully sent auth url: %s", resp.AuthURL)) + return nodeToRegisterResponse(updatedNode), nil } -func (h *Headscale) handleNodeLogOut( - writer http.ResponseWriter, - node types.Node, - machineKey key.MachinePublic, -) { - resp := tailcfg.RegisterResponse{} +// isAuthKey reports if the register request is a registration request +// using an pre auth key. +func isAuthKey(req tailcfg.RegisterRequest) bool { + return req.Auth != nil && req.Auth.AuthKey != "" +} - log.Info(). - Str("node", node.Hostname). - Msg("Client requested logout") +func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse { + resp := &tailcfg.RegisterResponse{ + NodeKeyExpired: node.IsExpired(), - now := time.Now() - err := h.db.NodeSetExpiry(&node, now) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to expire node") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return + // Headscale does not implement the concept of machine authorization + // so we always return true here. + // Revisit this if #2176 gets implemented. 
+ MachineAuthorized: true, } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &now, - }, + // For tagged nodes, use the TaggedDevices special user + // For user-owned nodes, include User and Login information from the actual user + if node.IsTagged() { + resp.User = types.TaggedDevices.View().TailscaleUser() + resp.Login = types.TaggedDevices.View().TailscaleLogin() + } else if node.Owner().Valid() { + resp.User = node.Owner().TailscaleUser() + resp.Login = node.Owner().TailscaleLogin() + } + + return resp +} + +func (h *Headscale) waitForFollowup( + ctx context.Context, + req tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (*tailcfg.RegisterResponse, error) { + fu, err := url.Parse(req.Followup) + if err != nil { + return nil, NewHTTPError(http.StatusUnauthorized, "invalid followup URL", err) + } + + followupReg, err := types.RegistrationIDFromString(strings.ReplaceAll(fu.Path, "/register/", "")) + if err != nil { + return nil, NewHTTPError(http.StatusUnauthorized, "invalid registration ID", err) + } + + if reg, ok := h.state.GetRegistrationCacheEntry(followupReg); ok { + select { + case <-ctx.Done(): + return nil, NewHTTPError(http.StatusUnauthorized, "registration timed out", err) + case node := <-reg.Registered: + if node == nil { + // registration is expired in the cache, instruct the client to try a new registration + return h.reqToNewRegisterResponse(req, machineKey) + } + return nodeToRegisterResponse(node.View()), nil + } + } + + // if the follow-up registration isn't found anymore, instruct the client to try a new registration + return h.reqToNewRegisterResponse(req, machineKey) +} + +// reqToNewRegisterResponse refreshes the registration flow by creating a new +// registration ID and returning the corresponding AuthURL so the client can +// restart the authentication process. 
+func (h *Headscale) reqToNewRegisterResponse( + req tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (*tailcfg.RegisterResponse, error) { + newRegID, err := types.NewRegistrationID() + if err != nil { + return nil, NewHTTPError(http.StatusInternalServerError, "failed to generate registration ID", err) + } + + // Ensure we have a valid hostname + hostname := util.EnsureHostname( + req.Hostinfo, + machineKey.String(), + req.NodeKey.String(), + ) + + // Ensure we have valid hostinfo + hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{}) + hostinfo.Hostname = hostname + + nodeToRegister := types.NewRegisterNode( + types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: ptr.To(time.Now()), }, - } - if stateUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) + ) + + if !req.Expiry.IsZero() { + nodeToRegister.Node.Expiry = &req.Expiry } - resp.AuthURL = "" - resp.MachineAuthorized = false - resp.NodeKeyExpired = true - resp.User = *node.User.TailscaleUser() - respBody, err := json.Marshal(resp) + log.Info().Msgf("New followup node registration using key: %s", newRegID) + h.state.SetRegistrationCacheEntry(newRegID, nodeToRegister) + + return &tailcfg.RegisterResponse{ + AuthURL: h.authProvider.AuthURL(newRegID), + }, nil +} + +func (h *Headscale) handleRegisterWithAuthKey( + req tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (*tailcfg.RegisterResponse, error) { + node, changed, err := h.state.HandleNodeFromPreAuthKey( + req, + machineKey, + ) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - - return - } - - if node.IsEphemeral() { - err = h.db.DeleteNode(&node) - if err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Msg("Cannot delete ephemeral node from the database") + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, NewHTTPError(http.StatusUnauthorized, "invalid pre auth key", nil) + } + var perr types.PAKError + if errors.As(err, &perr) { + return nil, NewHTTPError(http.StatusUnauthorized, perr.Error(), nil) } - return + return nil, err } - log.Info(). - Caller(). - Str("node", node.Hostname). - Msg("Successfully logged out") -} + // If node is not valid, it means an ephemeral node was deleted during logout + if !node.Valid() { + h.Change(changed) + return nil, nil + } -func (h *Headscale) handleNodeWithValidRegistration( - writer http.ResponseWriter, - node types.Node, - machineKey key.MachinePublic, -) { - resp := tailcfg.RegisterResponse{} - - // The node registration is valid, respond with redirect to /map - log.Debug(). - Caller(). - Str("node", node.Hostname). - Msg("Client is registered and we have the current NodeKey. All clear to /map") - - resp.AuthURL = "" - resp.MachineAuthorized = true - resp.User = *node.User.TailscaleUser() - resp.Login = *node.User.TailscaleLogin() - - respBody, err := json.Marshal(resp) + // This is a bit of a back and forth, but we have a bit of a chicken and egg + // dependency here. 
+ // Because the way the policy manager works, we need to have the node + // in the database, then add it to the policy manager and then we can + // approve the route. This means we get this dance where the node is + // first added to the database, then we add it to the policy manager via + // nodesChangedHook and then we can auto approve the routes. + // As that only approves the struct object, we need to save it again and + // ensure we send an update. + // This works, but might be another good candidate for doing some sort of + // eventbus. + // TODO(kradalby): This needs to be ran as part of the batcher maybe? + // now since we dont update the node/pol here anymore + routesChange, err := h.state.AutoApproveRoutes(node) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot encode message") - nodeRegistrations.WithLabelValues("update", "web", "error", node.User.Name). - Inc() - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - nodeRegistrations.WithLabelValues("update", "web", "success", node.User.Name). - Inc() - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + return nil, fmt.Errorf("auto approving routes: %w", err) } - log.Info(). - Caller(). - Str("node", node.Hostname). - Msg("Node successfully authorized") -} + // Send both changes. Empty changes are ignored by Change(). + h.Change(changed, routesChange) -func (h *Headscale) handleNodeKeyRefresh( - writer http.ResponseWriter, - registerRequest tailcfg.RegisterRequest, - node types.Node, - machineKey key.MachinePublic, -) { - resp := tailcfg.RegisterResponse{} + // TODO(kradalby): I think this is covered above, but we need to validate that. + // // If policy changed due to node registration, send a separate policy change + // if policyChanged { + // policyChange := change.PolicyChange() + // h.Change(policyChange) + // } - log.Info(). - Caller(). - Str("node", node.Hostname). - Msg("We have the OldNodeKey in the database. This is a key refresh") - - err := h.db.NodeSetNodeKey(&node, registerRequest.NodeKey) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to update machine key in the database") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - resp.AuthURL = "" - resp.User = *node.User.TailscaleUser() - respBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot encode message") - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - log.Info(). - Caller(). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("old_node_key", registerRequest.OldNodeKey.ShortString()). - Str("node", node.Hostname). 
- Msg("Node key successfully refreshed") -} - -func (h *Headscale) handleNodeExpiredOrLoggedOut( - writer http.ResponseWriter, - registerRequest tailcfg.RegisterRequest, - node types.Node, - machineKey key.MachinePublic, -) { - resp := tailcfg.RegisterResponse{} - - if registerRequest.Auth.AuthKey != "" { - h.handleAuthKey(writer, registerRequest, machineKey) - - return - } - - // The client has registered before, but has expired or logged out - log.Trace(). - Caller(). - Str("node", node.Hostname). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Msg("Node registration has expired or logged out. Sending a auth url to register") - - if h.oauth2Config != nil { - resp.AuthURL = fmt.Sprintf("%s/oidc/register/%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), - machineKey.String()) - } else { - resp.AuthURL = fmt.Sprintf("%s/register/%s", - strings.TrimSuffix(h.cfg.ServerURL, "/"), - machineKey.String()) - } - - respBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot encode message") - nodeRegistrations.WithLabelValues("reauth", "web", "error", node.User.Name). - Inc() - http.Error(writer, "Internal server error", http.StatusInternalServerError) - - return - } - nodeRegistrations.WithLabelValues("reauth", "web", "success", node.User.Name). - Inc() - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(respBody) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + resp := &tailcfg.RegisterResponse{ + MachineAuthorized: true, + NodeKeyExpired: node.IsExpired(), + User: node.Owner().TailscaleUser(), + Login: node.Owner().TailscaleLogin(), } log.Trace(). Caller(). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("node", node.Hostname). - Msg("Node logged out. Sent AuthURL for reauthentication") + Interface("reg.resp", resp). + Interface("reg.req", req). + Str("node.name", node.Hostname()). + Uint64("node.id", node.ID().Uint64()). + Msg("RegisterResponse") + + return resp, nil +} + +func (h *Headscale) handleRegisterInteractive( + req tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (*tailcfg.RegisterResponse, error) { + registrationId, err := types.NewRegistrationID() + if err != nil { + return nil, fmt.Errorf("generating registration ID: %w", err) + } + + // Ensure we have a valid hostname + hostname := util.EnsureHostname( + req.Hostinfo, + machineKey.String(), + req.NodeKey.String(), + ) + + // Ensure we have valid hostinfo + hostinfo := cmp.Or(req.Hostinfo, &tailcfg.Hostinfo{}) + if req.Hostinfo == nil { + log.Warn(). + Str("machine.key", machineKey.ShortString()). + Str("node.key", req.NodeKey.ShortString()). + Str("generated.hostname", hostname). + Msg("Received registration request with nil hostinfo, generated default hostname") + } else if req.Hostinfo.Hostname == "" { + log.Warn(). + Str("machine.key", machineKey.ShortString()). + Str("node.key", req.NodeKey.ShortString()). + Str("generated.hostname", hostname). 
+ Msg("Received registration request with empty hostname, generated default") + } + hostinfo.Hostname = hostname + + nodeToRegister := types.NewRegisterNode( + types.Node{ + Hostname: hostname, + MachineKey: machineKey, + NodeKey: req.NodeKey, + Hostinfo: hostinfo, + LastSeen: ptr.To(time.Now()), + }, + ) + + if !req.Expiry.IsZero() { + nodeToRegister.Node.Expiry = &req.Expiry + } + + h.state.SetRegistrationCacheEntry( + registrationId, + nodeToRegister, + ) + + log.Info().Msgf("Starting node registration using key: %s", registrationId) + + return &tailcfg.RegisterResponse{ + AuthURL: h.authProvider.AuthURL(registrationId), + }, nil } diff --git a/hscontrol/auth_noise.go b/hscontrol/auth_noise.go deleted file mode 100644 index 323a49b0..00000000 --- a/hscontrol/auth_noise.go +++ /dev/null @@ -1,57 +0,0 @@ -package hscontrol - -import ( - "encoding/json" - "io" - "net/http" - - "github.com/rs/zerolog/log" - "tailscale.com/tailcfg" -) - -// // NoiseRegistrationHandler handles the actual registration process of a node. -func (ns *noiseServer) NoiseRegistrationHandler( - writer http.ResponseWriter, - req *http.Request, -) { - log.Trace().Caller().Msgf("Noise registration handler for client %s", req.RemoteAddr) - if req.Method != http.MethodPost { - http.Error(writer, "Wrong method", http.StatusMethodNotAllowed) - - return - } - - log.Trace(). - Any("headers", req.Header). - Caller(). - Msg("Headers") - - body, _ := io.ReadAll(req.Body) - registerRequest := tailcfg.RegisterRequest{} - if err := json.Unmarshal(body, ®isterRequest); err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot parse RegisterRequest") - nodeRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() - http.Error(writer, "Internal error", http.StatusInternalServerError) - - return - } - - // Reject unsupported versions - if registerRequest.Version < MinimumCapVersion { - log.Info(). - Caller(). - Int("min_version", int(MinimumCapVersion)). - Int("client_version", int(registerRequest.Version)). - Msg("unsupported client connected") - http.Error(writer, "Internal error", http.StatusBadRequest) - - return - } - - ns.nodeKey = registerRequest.NodeKey - - ns.headscale.handleRegister(writer, req, registerRequest, ns.conn.Peer()) -} diff --git a/hscontrol/auth_tags_test.go b/hscontrol/auth_tags_test.go new file mode 100644 index 00000000..bbaa834b --- /dev/null +++ b/hscontrol/auth_tags_test.go @@ -0,0 +1,689 @@ +package hscontrol + +import ( + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// TestTaggedPreAuthKeyCreatesTaggedNode tests that a PreAuthKey with tags creates +// a tagged node with: +// - Tags from the PreAuthKey +// - UserID tracking who created the key (informational "created by") +// - IsTagged() returns true. 
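+//
+// In the tags-as-identity model the tag set is the node's identity for policy
+// purposes; the creating user is only recorded as an informational "created by"
+// reference and does not own the node.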
+func TestTaggedPreAuthKeyCreatesTaggedNode(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server", "tag:prod"} + + // Create a tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + require.NotEmpty(t, pak.Tags, "PreAuthKey should have tags") + require.ElementsMatch(t, tags, pak.Tags, "PreAuthKey should have specified tags") + + // Register a node using the tagged key + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify the node was created with tags + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + + // Critical assertions for tags-as-identity model + assert.True(t, node.IsTagged(), "Node should be tagged") + assert.ElementsMatch(t, tags, node.Tags().AsSlice(), "Node should have tags from PreAuthKey") + assert.True(t, node.UserID().Valid(), "Node should have UserID tracking creator") + assert.Equal(t, user.ID, node.UserID().Get(), "UserID should track PreAuthKey creator") + + // Verify node is identified correctly + assert.True(t, node.IsTagged(), "Tagged node is not user-owned") + assert.True(t, node.HasTag("tag:server"), "Node should have tag:server") + assert.True(t, node.HasTag("tag:prod"), "Node should have tag:prod") + assert.False(t, node.HasTag("tag:other"), "Node should not have tag:other") +} + +// TestReAuthDoesNotReapplyTags tests that when a node re-authenticates using the +// same PreAuthKey, the tags are NOT re-applied. Tags are only set during initial +// authentication. This is critical for the container restart scenario (#2830). +// +// NOTE: This test verifies that re-authentication preserves the node's current tags +// without testing tag modification via SetNodeTags (which requires ACL policy setup). 
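+//
+// Illustrative scenario (not exercised by this test): a containerised client such as
+//
+//	docker run -e TS_AUTHKEY=<reusable tagged key> tailscale/tailscale
+//
+// re-sends a RegisterRequest with the same key and NodeKey on every restart. Any tag
+// changes an operator made in the meantime must survive that re-authentication.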
+func TestReAuthDoesNotReapplyTags(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + initialTags := []string{"tag:server", "tag:dev"} + + // Create a tagged PreAuthKey with reusable=true for re-auth + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, initialTags) + require.NoError(t, err) + + // Initial registration + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify initial tags + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + require.True(t, node.IsTagged()) + require.ElementsMatch(t, initialTags, node.Tags().AsSlice()) + + // Re-authenticate with the SAME PreAuthKey (container restart scenario) + // Key behavior: Tags should NOT be re-applied during re-auth + reAuthReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same key + }, + NodeKey: nodeKey.Public(), // Same node key + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + reAuthResp, err := app.handleRegisterWithAuthKey(reAuthReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, reAuthResp.MachineAuthorized) + + // CRITICAL: Tags should remain unchanged after re-auth + // They should match the original tags, proving they weren't re-applied + nodeAfterReauth, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, nodeAfterReauth.IsTagged(), "Node should still be tagged") + assert.ElementsMatch(t, initialTags, nodeAfterReauth.Tags().AsSlice(), "Tags should remain unchanged on re-auth") + + // Verify only one node was created (no duplicates) + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + assert.Equal(t, 1, nodes.Len(), "Should have exactly one node") +} + +// NOTE: TestSetTagsOnUserOwnedNode functionality is covered by gRPC tests in grpcv1_test.go +// which properly handle ACL policy setup. The test verifies that SetTags can convert +// user-owned nodes to tagged nodes while preserving UserID. + +// TestCannotRemoveAllTags tests that attempting to remove all tags from a +// tagged node fails with ErrCannotRemoveAllTags. Once a node is tagged, +// it must always have at least one tag (Tailscale requirement). 
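+//
+// Tags on an already-tagged node can still be replaced with a different, non-empty
+// set via SetNodeTags (subject to ACL policy); only clearing them entirely is rejected.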
+func TestCannotRemoveAllTags(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a tagged node + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify node is tagged + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + require.True(t, node.IsTagged()) + + // Attempt to remove all tags by setting empty array + _, _, err = app.state.SetNodeTags(node.ID(), []string{}) + require.Error(t, err, "Should not be able to remove all tags") + require.ErrorIs(t, err, types.ErrCannotRemoveAllTags, "Error should be ErrCannotRemoveAllTags") + + // Verify node still has original tags + nodeAfter, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, nodeAfter.IsTagged(), "Node should still be tagged") + assert.ElementsMatch(t, tags, nodeAfter.Tags().AsSlice(), "Tags should be unchanged") +} + +// TestUserOwnedNodeCreatedWithUntaggedPreAuthKey tests that using a PreAuthKey +// without tags creates a user-owned node (no tags, UserID is the owner). +func TestUserOwnedNodeCreatedWithUntaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("node-owner") + + // Create an untagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + require.NoError(t, err) + require.Empty(t, pak.Tags, "PreAuthKey should not be tagged") + require.Empty(t, pak.Tags, "PreAuthKey should have no tags") + + // Register a node + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "user-owned-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify node is user-owned + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + + // Critical assertions for user-owned node + assert.False(t, node.IsTagged(), "Node should not be tagged") + assert.False(t, node.IsTagged(), "Node should be user-owned (not tagged)") + assert.Empty(t, node.Tags().AsSlice(), "Node should have no tags") + assert.True(t, node.UserID().Valid(), "Node should have UserID") + assert.Equal(t, user.ID, node.UserID().Get(), "UserID should be the PreAuthKey owner") +} + +// TestMultipleNodesWithSameReusableTaggedPreAuthKey tests that a reusable +// PreAuthKey with tags can be used to register multiple nodes, and all nodes +// receive the same tags from the key. 
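+//
+// This is the fleet-provisioning pattern: the same reusable key is baked into an image
+// or deployment template, every instance that boots with it joins with an identical
+// tag set, and ACL policy can then target the tag rather than individual machines.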
+func TestMultipleNodesWithSameReusableTaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server", "tag:prod"} + + // Create a REUSABLE tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Register first node + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + // Register second node with SAME PreAuthKey + machineKey2 := key.NewMachine() + nodeKey2 := key.NewNode() + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same key + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.NoError(t, err) + require.True(t, resp2.MachineAuthorized) + + // Verify both nodes exist and have the same tags + node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found) + + // Both nodes should be tagged with the same tags + assert.True(t, node1.IsTagged(), "First node should be tagged") + assert.True(t, node2.IsTagged(), "Second node should be tagged") + assert.ElementsMatch(t, tags, node1.Tags().AsSlice(), "First node should have PreAuthKey tags") + assert.ElementsMatch(t, tags, node2.Tags().AsSlice(), "Second node should have PreAuthKey tags") + + // Both nodes should track the same creator + assert.Equal(t, user.ID, node1.UserID().Get(), "First node should track creator") + assert.Equal(t, user.ID, node2.UserID().Get(), "Second node should track creator") + + // Verify we have exactly 2 nodes + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + assert.Equal(t, 2, nodes.Len(), "Should have exactly two nodes") +} + +// TestNonReusableTaggedPreAuthKey tests that a non-reusable PreAuthKey with tags +// can only be used once. The second attempt should fail. 
+func TestNonReusableTaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a NON-REUSABLE tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Register first node - should succeed + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + // Verify first node was created with tags + node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + assert.True(t, node1.IsTagged()) + assert.ElementsMatch(t, tags, node1.Tags().AsSlice()) + + // Attempt to register second node with SAME non-reusable key - should fail + machineKey2 := key.NewMachine() + nodeKey2 := key.NewNode() + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same non-reusable key + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.Error(t, err, "Should not be able to reuse non-reusable PreAuthKey") + + // Verify only one node was created + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + assert.Equal(t, 1, nodes.Len(), "Should have exactly one node") +} + +// TestExpiredTaggedPreAuthKey tests that an expired PreAuthKey with tags +// cannot be used to register a node. +func TestExpiredTaggedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a PreAuthKey that expires immediately + expiration := time.Now().Add(-1 * time.Hour) // Already expired + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, &expiration, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Attempt to register with expired key + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.Error(t, err, "Should not be able to use expired PreAuthKey") + + // Verify no node was created + _, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + assert.False(t, found, "No node should be created with expired key") +} + +// TestSingleVsMultipleTags tests that PreAuthKeys work correctly with both +// a single tag and multiple tags. 
+func TestSingleVsMultipleTags(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + + // Test with single tag + singleTag := []string{"tag:server"} + pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, singleTag) + require.NoError(t, err) + + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "single-tag-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + assert.True(t, node1.IsTagged()) + assert.ElementsMatch(t, singleTag, node1.Tags().AsSlice()) + + // Test with multiple tags + multipleTags := []string{"tag:server", "tag:prod", "tag:database"} + pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, multipleTags) + require.NoError(t, err) + + machineKey2 := key.NewMachine() + nodeKey2 := key.NewNode() + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak2.Key, + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "multi-tag-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.NoError(t, err) + require.True(t, resp2.MachineAuthorized) + + node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found) + assert.True(t, node2.IsTagged()) + assert.ElementsMatch(t, multipleTags, node2.Tags().AsSlice()) + + // Verify HasTag works for all tags + assert.True(t, node2.HasTag("tag:server")) + assert.True(t, node2.HasTag("tag:prod")) + assert.True(t, node2.HasTag("tag:database")) + assert.False(t, node2.HasTag("tag:other")) +} + +// TestTaggedPreAuthKeyDisablesKeyExpiry tests that nodes registered with +// a tagged PreAuthKey have key expiry disabled (expiry is nil). 
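+//
+// This mirrors Tailscale's behaviour for tagged devices: a tagged node is a service
+// identity rather than a user's device, so there is no interactive login to renew and
+// key expiry is disabled at registration, regardless of the expiry the client requests.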
+func TestTaggedPreAuthKeyDisablesKeyExpiry(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server", "tag:prod"} + + // Create a tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + require.ElementsMatch(t, tags, pak.Tags) + + // Register a node using the tagged key + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Client requests an expiry time, but for tagged nodes it should be ignored + clientRequestedExpiry := time.Now().Add(24 * time.Hour) + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-expiry-test", + }, + Expiry: clientRequestedExpiry, + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify the node has key expiry DISABLED (expiry is nil/zero) + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + + // Critical assertion: Tagged nodes should have expiry disabled + assert.True(t, node.IsTagged(), "Node should be tagged") + assert.False(t, node.Expiry().Valid(), "Tagged node should have expiry disabled (nil)") +} + +// TestUntaggedPreAuthKeyPreservesKeyExpiry tests that nodes registered with +// an untagged PreAuthKey preserve the client's requested key expiry. +func TestUntaggedPreAuthKeyPreservesKeyExpiry(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("node-owner") + + // Create an untagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + require.NoError(t, err) + require.Empty(t, pak.Tags, "PreAuthKey should not be tagged") + + // Register a node + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Client requests an expiry time + clientRequestedExpiry := time.Now().Add(24 * time.Hour) + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "untagged-expiry-test", + }, + Expiry: clientRequestedExpiry, + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify the node has the client's requested expiry + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + + // Critical assertion: User-owned nodes should preserve client expiry + assert.False(t, node.IsTagged(), "Node should not be tagged") + assert.True(t, node.Expiry().Valid(), "User-owned node should have expiry set") + // Allow some tolerance for test execution time + assert.WithinDuration(t, clientRequestedExpiry, node.Expiry().Get(), 5*time.Second, + "User-owned node should have the client's requested expiry") +} + +// TestTaggedNodeReauthPreservesDisabledExpiry tests that when a tagged node +// re-authenticates, the disabled expiry is preserved (not updated from client request). 
+func TestTaggedNodeReauthPreservesDisabledExpiry(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a reusable tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + + // Initial registration + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-reauth-test", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp.MachineAuthorized) + + // Verify initial registration has expiry disabled + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + require.True(t, node.IsTagged()) + require.False(t, node.Expiry().Valid(), "Initial registration should have expiry disabled") + + // Re-authenticate with a NEW expiry request (should be ignored for tagged nodes) + newRequestedExpiry := time.Now().Add(48 * time.Hour) + reAuthReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-reauth-test", + }, + Expiry: newRequestedExpiry, // Client requests new expiry + } + + reAuthResp, err := app.handleRegisterWithAuthKey(reAuthReq, machineKey.Public()) + require.NoError(t, err) + require.True(t, reAuthResp.MachineAuthorized) + + // Verify expiry is STILL disabled after re-auth + nodeAfterReauth, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + + // Critical assertion: Tagged node should preserve disabled expiry on re-auth + assert.True(t, nodeAfterReauth.IsTagged(), "Node should still be tagged") + assert.False(t, nodeAfterReauth.Expiry().Valid(), + "Tagged node should have expiry PRESERVED as disabled after re-auth") +} + +// TestReAuthWithDifferentMachineKey tests the edge case where a node attempts +// to re-authenticate with the same NodeKey but a DIFFERENT MachineKey. +// This scenario should be handled gracefully (currently creates a new node). 
+func TestReAuthWithDifferentMachineKey(t *testing.T) { + app := createTestApp(t) + + user := app.state.CreateUserForTest("tag-creator") + tags := []string{"tag:server"} + + // Create a reusable tagged PreAuthKey + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + require.NoError(t, err) + + // Initial registration + machineKey1 := key.NewMachine() + nodeKey := key.NewNode() // Same NodeKey for both attempts + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized) + + // Verify initial node + node1, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, node1.IsTagged()) + + // Re-authenticate with DIFFERENT MachineKey but SAME NodeKey + machineKey2 := key.NewMachine() // Different machine key + + regReq2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), // Same NodeKey + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public()) + require.NoError(t, err) + require.True(t, resp2.MachineAuthorized) + + // Verify the node still exists and has tags + // Note: Depending on implementation, this might be the same node or a new node + node2, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, node2.IsTagged()) + assert.ElementsMatch(t, tags, node2.Tags().AsSlice()) +} diff --git a/hscontrol/auth_test.go b/hscontrol/auth_test.go new file mode 100644 index 00000000..1677642f --- /dev/null +++ b/hscontrol/auth_test.go @@ -0,0 +1,3727 @@ +package hscontrol + +import ( + "context" + "fmt" + "net/url" + "strings" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/mapper" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +// Interactive step type constants +const ( + stepTypeInitialRequest = "initial_request" + stepTypeAuthCompletion = "auth_completion" + stepTypeFollowupRequest = "followup_request" +) + +// interactiveStep defines a step in the interactive authentication workflow +type interactiveStep struct { + stepType string // stepTypeInitialRequest, stepTypeAuthCompletion, or stepTypeFollowupRequest + expectAuthURL bool + expectCacheEntry bool + callAuthPath bool // Real call to HandleNodeFromAuthPath, not mocked +} + +func TestAuthenticationFlows(t *testing.T) { + // Shared test keys for consistent behavior across test cases + machineKey1 := key.NewMachine() + machineKey2 := key.NewMachine() + nodeKey1 := key.NewNode() + nodeKey2 := key.NewNode() + + tests := []struct { + name string + setupFunc func(*testing.T, *Headscale) (string, error) // Returns dynamic values like auth keys + request func(dynamicValue string) tailcfg.RegisterRequest + machineKey func() key.MachinePublic + wantAuth bool + wantError bool + wantAuthURL bool + wantExpired bool + validate func(*testing.T, *tailcfg.RegisterResponse, *Headscale) + + // Interactive workflow support + requiresInteractiveFlow bool + interactiveSteps []interactiveStep + 
validateRegistrationCache bool + expectedAuthURLPattern string + simulateAuthCompletion bool + validateCompleteResponse bool + }{ + // === PRE-AUTH KEY SCENARIOS === + // Tests authentication using pre-authorization keys for automated node registration. + // Pre-auth keys allow nodes to join without interactive authentication. + + // TEST: Valid pre-auth key registers a new node + // WHAT: Tests successful node registration using a valid pre-auth key + // INPUT: Register request with valid pre-auth key, node key, and hostinfo + // EXPECTED: Node is authorized immediately, registered in database + // WHY: Pre-auth keys enable automated/headless node registration without user interaction + { + name: "preauth_key_valid_new_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("preauth-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "preauth-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + assert.NotEmpty(t, resp.User.DisplayName) + + // Verify node was created in database + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "preauth-node-1", node.Hostname()) + }, + }, + + // TEST: Reusable pre-auth key can register multiple nodes + // WHAT: Tests that a reusable pre-auth key can be used for multiple node registrations + // INPUT: Same reusable pre-auth key used to register two different nodes + // EXPECTED: Both nodes successfully register with the same key + // WHY: Reusable keys allow multiple machines to join using one key (useful for fleet deployments) + { + name: "preauth_key_reusable_multiple_nodes", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("reusable-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Use the key for first node + firstReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reusable-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reusable-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() 
key.MachinePublic { return machineKey2.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify both nodes exist + node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) + node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) + assert.True(t, found2) + assert.Equal(t, "reusable-node-1", node1.Hostname()) + assert.Equal(t, "reusable-node-2", node2.Hostname()) + }, + }, + + // TEST: Single-use pre-auth key cannot be reused + // WHAT: Tests that a single-use pre-auth key fails on second use + // INPUT: Single-use key used for first node (succeeds), then attempted for second node + // EXPECTED: First node registers successfully, second node fails with error + // WHY: Single-use keys provide security by preventing key reuse after initial registration + { + name: "preauth_key_single_use_exhausted", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("single-use-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + if err != nil { + return "", err + } + + // Use the key for first node (should work) + firstReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "single-use-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(firstReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "single-use-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey2.Public() }, + wantError: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // First node should exist, second should not + _, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) + _, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found1) + assert.False(t, found2) + }, + }, + + // TEST: Invalid pre-auth key is rejected + // WHAT: Tests that an invalid/non-existent pre-auth key is rejected + // INPUT: Register request with invalid auth key string + // EXPECTED: Registration fails with error + // WHY: Invalid keys must be rejected to prevent unauthorized node registration + { + name: "preauth_key_invalid", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "invalid-key-12345", nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "invalid-key-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + + 
// TEST: Ephemeral pre-auth key creates ephemeral node + // WHAT: Tests that a node registered with ephemeral key is marked as ephemeral + // INPUT: Pre-auth key with ephemeral=true, standard register request + // EXPECTED: Node registers and is marked as ephemeral (will be deleted on logout) + // WHY: Ephemeral nodes auto-cleanup when disconnected, useful for temporary/CI environments + { + name: "preauth_key_ephemeral_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("ephemeral-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "ephemeral-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify ephemeral node was created + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.NotNil(t, node.AuthKey) + assert.True(t, node.AuthKey().Ephemeral()) + }, + }, + + // === INTERACTIVE REGISTRATION SCENARIOS === + // Tests interactive authentication flow where user completes registration via web UI. + // Interactive flow: node requests registration → receives AuthURL → user authenticates → node gets registered + + // TEST: Complete interactive workflow for new node + // WHAT: Tests full interactive registration flow from initial request to completion + // INPUT: Register request with no auth → user completes auth → followup request + // EXPECTED: Initial request returns AuthURL, after auth completion node is registered + // WHY: Interactive flow is the standard user-facing authentication method for new nodes + { + name: "full_interactive_workflow_new_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-flow-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, // cleaned up after completion + }, + validateCompleteResponse: true, + expectedAuthURLPattern: "/register/", + }, + // TEST: Interactive workflow with no Auth struct in request + // WHAT: Tests interactive flow when request has no Auth field (nil) + // INPUT: Register request with Auth field set to nil + // EXPECTED: Node receives AuthURL and can complete registration via interactive flow + // WHY: Validates handling of requests without Auth field, same as empty auth + { + name: "interactive_workflow_no_auth_struct", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + // No Auth field at all + NodeKey: 
nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-no-auth-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, // cleaned up after completion + }, + validateCompleteResponse: true, + expectedAuthURLPattern: "/register/", + }, + + // === EXISTING NODE SCENARIOS === + // Tests behavior when existing registered nodes send requests (logout, re-auth, expiry, etc.) + + // TEST: Existing node logout with past expiry + // WHAT: Tests node logout by sending request with expiry in the past + // INPUT: Previously registered node sends request with Auth=nil and past expiry time + // EXPECTED: Node expiry is updated, NodeKeyExpired=true, MachineAuthorized=true (for compatibility) + // WHY: Nodes signal logout by setting expiry to past time; system updates node state accordingly + { + name: "existing_node_logout", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("logout-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node first + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "logout-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + t.Logf("Setup registered node: %+v", resp) + + // Wait for node to be available in NodeStore with debug info + var attemptCount int + require.EventuallyWithT(t, func(c *assert.CollectT) { + attemptCount++ + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + if assert.True(c, found, "node should be available in NodeStore") { + t.Logf("Node found in NodeStore after %d attempts", attemptCount) + } + }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(-1 * time.Hour), // Past expiry = logout + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + wantExpired: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.True(t, resp.NodeKeyExpired) + }, + }, + // TEST: Existing node with different machine key is rejected + // WHAT: Tests that requests for existing node with wrong machine key are rejected + // INPUT: Node key matches existing node, but machine key is different + // EXPECTED: Request fails with unauthorized error (machine key mismatch) + // WHY: Machine key must match to prevent node hijacking/impersonation + { + name: "existing_node_machine_key_mismatch", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("mismatch-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register with machineKey1 + regReq := tailcfg.RegisterRequest{ + Auth: 
&tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "mismatch-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(-1 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey2.Public() }, // Different machine key + wantError: true, + }, + // TEST: Existing node cannot extend expiry without re-auth + // WHAT: Tests that nodes cannot extend their expiry time without authentication + // INPUT: Existing node sends request with Auth=nil and future expiry (extension attempt) + // EXPECTED: Request fails with error (extending key not allowed) + // WHY: Prevents nodes from extending their own lifetime; must re-authenticate + { + name: "existing_node_key_extension_not_allowed", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("extend-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node first + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "extend-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(48 * time.Hour), // Future time = extend attempt + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + // TEST: Expired node must re-authenticate + // WHAT: Tests that expired nodes receive NodeKeyExpired=true and must re-auth + // INPUT: Previously expired node sends request with no auth + // EXPECTED: Response has NodeKeyExpired=true, node must re-authenticate + // WHY: Expired nodes must go through authentication again for security + { + name: "existing_node_expired_forces_reauth", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("reauth-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node first + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-node", + }, + Expiry: time.Now().Add(24 * 
time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + var node types.NodeView + var found bool + require.EventuallyWithT(t, func(c *assert.CollectT) { + node, found = app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + if !found { + return "", fmt.Errorf("node not found after setup") + } + + // Expire the node + expiredTime := time.Now().Add(-1 * time.Hour) + _, _, err = app.state.SetNodeExpiry(node.ID(), expiredTime) + return "", err + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(24 * time.Hour), // Future expiry + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantExpired: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.NodeKeyExpired) + assert.False(t, resp.MachineAuthorized) + }, + }, + // TEST: Ephemeral node is deleted on logout + // WHAT: Tests that ephemeral nodes are deleted (not just expired) on logout + // INPUT: Ephemeral node sends logout request (past expiry) + // EXPECTED: Node is completely deleted from database, not just marked expired + // WHY: Ephemeral nodes should not persist after logout; auto-cleanup + { + name: "ephemeral_node_logout_deletion", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("ephemeral-logout-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) + if err != nil { + return "", err + } + + // Register ephemeral node + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "ephemeral-logout-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available in NodeStore + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now().Add(-1 * time.Hour), // Logout + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantExpired: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.NodeKeyExpired) + assert.False(t, resp.MachineAuthorized) + + // Ephemeral node should be deleted, not just marked expired + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.False(t, found, "ephemeral node should be deleted on logout") + }, + }, + + // === FOLLOWUP REGISTRATION SCENARIOS === + // Tests followup request handling after interactive registration is initiated. + // Followup requests are sent by nodes waiting for auth completion. 
+ + // TEST: Successful followup registration after auth completion + // WHAT: Tests node successfully completes registration via followup URL + // INPUT: Register request with followup URL after auth completion + // EXPECTED: Node receives successful registration response with user info + // WHY: Followup mechanism allows nodes to poll/wait for auth completion + { + name: "followup_registration_success", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + regID, err := types.NewRegistrationID() + if err != nil { + return "", err + } + + registered := make(chan *types.Node, 1) + nodeToRegister := types.RegisterNode{ + Node: types.Node{ + Hostname: "followup-success-node", + }, + Registered: registered, + } + app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + + // Simulate successful registration - send to buffered channel + // The channel is buffered (size 1), so this can complete immediately + // and handleRegister will receive the value when it starts waiting + go func() { + user := app.state.CreateUserForTest("followup-user") + node := app.state.CreateNodeForTest(user, "followup-success-node") + registered <- node + }() + + return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + }, + }, + // TEST: Followup registration times out when auth not completed + // WHAT: Tests that followup request times out if auth is not completed in time + // INPUT: Followup request with short timeout, no auth completion + // EXPECTED: Request times out with unauthorized error + // WHY: Prevents indefinite waiting; nodes must retry if auth takes too long + { + name: "followup_registration_timeout", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + regID, err := types.NewRegistrationID() + if err != nil { + return "", err + } + + registered := make(chan *types.Node, 1) + nodeToRegister := types.RegisterNode{ + Node: types.Node{ + Hostname: "followup-timeout-node", + }, + Registered: registered, + } + app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + // Don't send anything on channel - will timeout + + return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + // TEST: Invalid followup URL is rejected + // WHAT: Tests that malformed/invalid followup URLs are rejected + // INPUT: Register request with invalid URL in Followup field + // EXPECTED: Request fails with error (invalid followup URL) + // WHY: Validates URL format to prevent errors and potential exploits + { + name: "followup_invalid_url", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "invalid://url[malformed", nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + // TEST: 
Non-existent registration ID is rejected + // WHAT: Tests that followup with non-existent registration ID fails + // INPUT: Valid followup URL but registration ID not in cache + // EXPECTED: Request fails with unauthorized error + // WHY: Registration must exist in cache; prevents invalid/expired registrations + { + name: "followup_registration_not_found", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "http://localhost:8080/register/nonexistent-id", nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + + // === EDGE CASES === + // Tests handling of malformed, invalid, or unusual input data + + // TEST: Empty hostname is handled with defensive code + // WHAT: Tests that empty hostname in hostinfo generates a default hostname + // INPUT: Register request with hostinfo containing empty hostname string + // EXPECTED: Node registers successfully with generated hostname (node-MACHINEKEY) + // WHY: Defensive code prevents errors from missing hostnames; generates sensible default + { + name: "empty_hostname", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("empty-hostname-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "", // Empty hostname should be handled gracefully + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + + // Node should be created with generated hostname + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.NotEmpty(t, node.Hostname()) + }, + }, + // TEST: Nil hostinfo is handled with defensive code + // WHAT: Tests that nil hostinfo in register request is handled gracefully + // INPUT: Register request with Hostinfo field set to nil + // EXPECTED: Node registers successfully with generated hostname starting with "node-" + // WHY: Defensive code prevents nil pointer panics; creates valid default hostinfo + { + name: "nil_hostinfo", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("nil-hostinfo-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: nil, // Nil hostinfo should be handled with defensive code + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + + // Node should be created with generated hostname from defensive code + node, found := 
app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.NotEmpty(t, node.Hostname()) + // Hostname should start with "node-" (generated from machine key) + assert.True(t, strings.HasPrefix(node.Hostname(), "node-")) + }, + }, + + // === PRE-AUTH KEY WITH EXPIRY SCENARIOS === + // Tests pre-auth key expiration handling + + // TEST: Expired pre-auth key is rejected + // WHAT: Tests that a pre-auth key with past expiration date cannot be used + // INPUT: Pre-auth key with expiry 1 hour in the past + // EXPECTED: Registration fails with error + // WHY: Expired keys must be rejected to maintain security and key lifecycle management + { + name: "preauth_key_expired", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("expired-pak-user") + expiry := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "expired-pak-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + + // TEST: Pre-auth key with ACL tags applies tags to node + // WHAT: Tests that ACL tags from pre-auth key are applied to registered node + // INPUT: Pre-auth key with ACL tags ["tag:server", "tag:database"], register request + // EXPECTED: Node registers with specified ACL tags applied as ForcedTags + // WHY: Pre-auth keys can enforce ACL policies on nodes during registration + { + name: "preauth_key_with_acl_tags", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("tagged-pak-user") + tags := []string{"tag:server", "tag:database"} + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-pak-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify node was created with tags + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "tagged-pak-node", node.Hostname()) + if node.AuthKey().Valid() { + assert.NotEmpty(t, node.AuthKey().Tags()) + } + }, + }, + + // === ADVERTISE-TAGS (RequestTags) SCENARIOS === + // Tests for client-provided tags via --advertise-tags flag + + // TEST: PreAuthKey registration rejects client-provided RequestTags + // WHAT: Tests that PreAuthKey registrations cannot use client-provided tags + // INPUT: PreAuthKey registration with RequestTags in Hostinfo + // EXPECTED: Registration fails with "requested tags [...] 
are invalid or not permitted" error + // WHY: PreAuthKey nodes get their tags from the key itself, not from client requests + { + name: "preauth_key_rejects_request_tags", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + t.Helper() + + user := app.state.CreateUserForTest("pak-requesttags-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "pak-requesttags-node", + RequestTags: []string{"tag:unauthorized"}, + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: machineKey1.Public, + wantError: true, + }, + + // TEST: Tagged PreAuthKey ignores client-provided RequestTags + // WHAT: Tests that tagged PreAuthKey uses key tags, not client RequestTags + // INPUT: Tagged PreAuthKey registration with different RequestTags + // EXPECTED: Registration fails because RequestTags are rejected for PreAuthKey + // WHY: Tags-as-identity: PreAuthKey tags are authoritative, client cannot override + { + name: "tagged_preauth_key_rejects_client_request_tags", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + t.Helper() + + user := app.state.CreateUserForTest("tagged-pak-clienttags-user") + keyTags := []string{"tag:authorized"} + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, keyTags) + if err != nil { + return "", err + } + + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-pak-clienttags-node", + RequestTags: []string{"tag:client-wants-this"}, // Should be rejected + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: machineKey1.Public, + wantError: true, // RequestTags rejected for PreAuthKey registrations + }, + + // === RE-AUTHENTICATION SCENARIOS === + // TEST: Existing node re-authenticates with new pre-auth key + // WHAT: Tests that existing node can re-authenticate using new pre-auth key + // INPUT: Existing node sends request with new valid pre-auth key + // EXPECTED: Node successfully re-authenticates, stays authorized + // WHY: Allows nodes to refresh authentication using pre-auth keys + { + name: "existing_node_reauth_with_new_authkey", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("reauth-user") + + // First, register with initial auth key + pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + 
+ // Create new auth key for re-authentication + pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + return pak2.Key, nil + }, + request: func(newAuthKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: newAuthKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-node-updated", + }, + Expiry: time.Now().Add(48 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify node was updated, not duplicated + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "reauth-node-updated", node.Hostname()) + }, + }, + // TEST: Existing node re-authenticates via interactive flow + // WHAT: Tests that existing expired node can re-authenticate interactively + // INPUT: Expired node initiates interactive re-authentication + // EXPECTED: Node receives AuthURL and can complete re-authentication + // WHY: Allows expired nodes to re-authenticate without pre-auth keys + { + name: "existing_node_reauth_interactive_flow", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("interactive-reauth-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register initially with auth key + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-reauth-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: "", // Empty auth key triggers interactive flow + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-reauth-node-updated", + }, + Expiry: time.Now().Add(48 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuthURL: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.Contains(t, resp.AuthURL, "register/") + assert.False(t, resp.MachineAuthorized) + }, + }, + + // === NODE KEY ROTATION SCENARIOS === + // Tests node key rotation where node changes its node key while keeping same machine key + + // TEST: Node key rotation with same machine key updates in place + // WHAT: Tests that registering with new node key and same machine key updates existing node + // INPUT: Register node with nodeKey1, then register again with nodeKey2 but same machineKey + // EXPECTED: Node is updated in place; nodeKey2 exists, nodeKey1 no longer exists + // WHY: Same machine key means same physical device; node key rotation updates, doesn't 
duplicate + { + name: "node_key_rotation_same_machine", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("rotation-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register with initial node key + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "rotation-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + // Create new auth key for rotation + pakRotation, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + return pakRotation.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey2.Public(), // Different node key, same machine + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "rotation-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // When same machine key is used, node is updated in place (not duplicated) + // The old nodeKey1 should no longer exist + _, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.False(t, found1, "old node key should not exist after rotation") + + // The new nodeKey2 should exist with the same machine key + node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + assert.True(t, found2, "new node key should exist after rotation") + assert.Equal(t, machineKey1.Public(), node2.MachineKey(), "machine key should remain the same") + }, + }, + + // === MALFORMED REQUEST SCENARIOS === + // Tests handling of requests with malformed or unusual field values + + // TEST: Zero-time expiry is handled correctly + // WHAT: Tests registration with expiry set to zero time value + // INPUT: Register request with Expiry set to time.Time{} (zero value) + // EXPECTED: Node registers successfully; zero time treated as no expiry + // WHY: Zero time is valid Go default; should be handled gracefully + { + name: "malformed_expiry_zero_time", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("zero-expiry-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "zero-expiry-node", + }, + Expiry: time.Time{}, // Zero time + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp 
*tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + + // Node should be created with default expiry handling + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "zero-expiry-node", node.Hostname()) + }, + }, + // TEST: Malformed hostinfo with very long hostname is truncated + // WHAT: Tests that excessively long hostname is truncated to DNS label limit + // INPUT: Hostinfo with 110-character hostname (exceeds 63-char DNS limit) + // EXPECTED: Node registers successfully; hostname truncated to 63 characters + // WHY: Defensive code enforces DNS label limit (RFC 1123); prevents errors + { + name: "malformed_hostinfo_invalid_data", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("invalid-hostinfo-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node-with-very-long-hostname-that-might-exceed-normal-limits-and-contain-special-chars-!@#$%", + BackendLogID: "invalid-log-id", + OS: "unknown-os", + OSVersion: "999.999.999", + DeviceModel: "test-device-model", + // Note: RequestTags are not included for PreAuthKey registrations + // since tags come from the key itself, not client requests. + Services: []tailcfg.Service{{Proto: "tcp", Port: 65535}}, + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + + // Node should be created even with malformed hostinfo + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + // Hostname should be sanitized or handled gracefully + assert.NotEmpty(t, node.Hostname()) + }, + }, + + // === REGISTRATION CACHE EDGE CASES === + // Tests edge cases in registration cache handling during interactive flow + + // TEST: Followup registration with nil response (cache expired during auth) + // WHAT: Tests that followup request handles nil node response (cache expired/cleared) + // INPUT: Followup request where auth completion sends nil (cache was cleared) + // EXPECTED: Returns new AuthURL so client can retry authentication + // WHY: Nil response means cache expired - give client new AuthURL instead of error + { + name: "followup_registration_node_nil_response", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + regID, err := types.NewRegistrationID() + if err != nil { + return "", err + } + + registered := make(chan *types.Node, 1) + nodeToRegister := types.RegisterNode{ + Node: types.Node{ + Hostname: "nil-response-node", + }, + Registered: registered, + } + app.state.SetRegistrationCacheEntry(regID, nodeToRegister) + + // Simulate registration that returns nil (cache expired during auth) + // The channel is buffered (size 1), so this can complete immediately + go func() { + registered <- nil // Nil indicates cache expiry + }() + + return fmt.Sprintf("http://localhost:8080/register/%s", regID), nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: 
nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "nil-response-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: false, // Should not be authorized yet - needs to use new AuthURL + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Should get a new AuthURL, not an error + assert.NotEmpty(t, resp.AuthURL, "should receive new AuthURL when cache returns nil") + assert.Contains(t, resp.AuthURL, "/register/", "AuthURL should contain registration path") + assert.False(t, resp.MachineAuthorized, "machine should not be authorized yet") + }, + }, + // TEST: Malformed followup path is rejected + // WHAT: Tests that followup URL with malformed path is rejected + // INPUT: Followup URL with path that doesn't match expected format + // EXPECTED: Request fails with error (invalid followup URL) + // WHY: Path validation prevents processing of corrupted/invalid URLs + { + name: "followup_registration_malformed_path", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "http://localhost:8080/register/", nil // Missing registration ID + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + // TEST: Wrong followup path format is rejected + // WHAT: Tests that followup URL with incorrect path structure fails + // INPUT: Valid URL but path doesn't start with "/register/" + // EXPECTED: Request fails with error (invalid path format) + // WHY: Strict path validation ensures only valid registration URLs accepted + { + name: "followup_registration_wrong_path_format", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "http://localhost:8080/wrong/path/format", nil + }, + request: func(followupURL string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: followupURL, + NodeKey: nodeKey1.Public(), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantError: true, + }, + + // === AUTH PROVIDER EDGE CASES === + // TEST: Interactive workflow preserves custom hostinfo + // WHAT: Tests that custom hostinfo fields are preserved through interactive flow + // INPUT: Interactive registration with detailed hostinfo (OS, version, model) + // EXPECTED: Node registers with all hostinfo fields preserved + // WHY: Ensures interactive flow doesn't lose custom hostinfo data + // NOTE: RequestTags are NOT tested here because tag authorization via + // advertise-tags requires the user to have existing nodes (for IP-based + // ownership verification). New users registering their first node cannot + // claim tags via RequestTags - they must use a tagged PreAuthKey instead. 
+ { + name: "interactive_workflow_with_custom_hostinfo", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "custom-interactive-node", + OS: "linux", + OSVersion: "20.04", + DeviceModel: "server", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, // cleaned up after completion + }, + validateCompleteResponse: true, + expectedAuthURLPattern: "/register/", + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Verify custom hostinfo was preserved through interactive workflow + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found, "node should be found after interactive registration") + if found { + assert.Equal(t, "custom-interactive-node", node.Hostname()) + assert.Equal(t, "linux", node.Hostinfo().OS()) + assert.Equal(t, "20.04", node.Hostinfo().OSVersion()) + assert.Equal(t, "server", node.Hostinfo().DeviceModel()) + } + }, + }, + + // === PRE-AUTH KEY USAGE TRACKING === + // Tests accurate tracking of pre-auth key usage counts + + // TEST: Pre-auth key usage count is tracked correctly + // WHAT: Tests that each use of a pre-auth key increments its usage counter + // INPUT: Reusable pre-auth key used to register three different nodes + // EXPECTED: All three nodes register successfully, key usage count increments each time + // WHY: Usage tracking enables monitoring and auditing of pre-auth key usage + { + name: "preauth_key_usage_count_tracking", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("usage-count-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // Single use + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "usage-count-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify auth key usage was tracked + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "usage-count-node", node.Hostname()) + + // Key should now be used up (single use) + if node.AuthKey().Valid() { + assert.False(t, node.AuthKey().Reusable()) + } + }, + }, + + // === REGISTRATION ID GENERATION AND ADVANCED EDGE CASES === + // TEST: Interactive workflow generates valid registration IDs + // WHAT: Tests that interactive flow generates unique, valid registration IDs + // INPUT: Interactive registration request + // EXPECTED: AuthURL contains valid registration ID that can be extracted + // WHY: Registration IDs must be unique and valid for cache lookup + { + name: 
"interactive_workflow_registration_id_generation", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "registration-id-test-node", + OS: "test-os", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, + }, + validateCompleteResponse: true, + expectedAuthURLPattern: "/register/", + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Verify registration ID was properly generated and used + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found, "node should be registered after interactive workflow") + if found { + assert.Equal(t, "registration-id-test-node", node.Hostname()) + assert.Equal(t, "test-os", node.Hostinfo().OS()) + } + }, + }, + { + name: "concurrent_registration_same_node_key", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("concurrent-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "concurrent-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify node was registered + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "concurrent-node", node.Hostname()) + }, + }, + // TEST: Auth key expiry vs request expiry handling + // WHAT: Tests that pre-auth key expiry is independent of request expiry + // INPUT: Valid pre-auth key (future expiry), request with past expiry + // EXPECTED: Node registers with request expiry used (logout scenario) + // WHY: Request expiry overrides key expiry; allows logout with valid key + { + name: "auth_key_with_future_expiry_past_request_expiry", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("future-expiry-user") + // Auth key expires in the future + expiry := time.Now().Add(48 * time.Hour) + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil) + if err != nil { + return "", err + } + return pak.Key, nil + }, + request: func(authKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "future-expiry-node", + }, + // Request expires before auth key + Expiry: time.Now().Add(12 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app 
*Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Node should be created with request expiry (shorter than auth key expiry) + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.Equal(t, "future-expiry-node", node.Hostname()) + }, + }, + // TEST: Re-authentication with different user's auth key + // WHAT: Tests that re-authenticating with a different user's auth key creates a new node + // INPUT: Node registered with user1's auth key, re-authenticates with user2's auth key + // EXPECTED: A new node is created for user2; user1's original node remains + // WHY: The same physical machine can hold a separate node identity per user + { + name: "reauth_existing_node_different_user_auth_key", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + // Create two users + user1 := app.state.CreateUserForTest("user1-context") + user2 := app.state.CreateUserForTest("user2-context") + + // Register node with user1's auth key + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "context-node-user1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + // Return user2's auth key for re-authentication + pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + return pak2.Key, nil + }, + request: func(user2AuthKey string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: user2AuthKey, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "context-node-user2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.False(t, resp.NodeKeyExpired) + + // Verify NEW node was created for user2 + node2, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2)) + require.True(t, found, "new node should exist for user2") + assert.Equal(t, uint(2), node2.UserID().Get(), "new node should belong to user2") + + user := node2.User() + assert.Equal(t, "user2-context", user.Name(), "new node should show user2 username") + + // Verify original node still exists for user1 + node1, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1)) + require.True(t, found, "original node should still exist for user1") + assert.Equal(t, uint(1), node1.UserID().Get(), "original node should still belong to user1") + + // Verify they are different nodes (different IDs) + assert.NotEqual(t, node1.ID(), node2.ID(), "should be different node IDs") + }, + }, + // TEST: Re-authentication with different user via interactive flow creates new node + // WHAT: Tests new node creation when 
re-authenticating interactively with a different user + // INPUT: Node registered with user1, re-authenticates interactively as user2 (same machine key, same node key) + // EXPECTED: New node is created for user2, user1's original node remains (no transfer) + // WHY: Same physical machine can have separate node identities per user + { + name: "interactive_reauth_existing_node_different_user_creates_new_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + // Create user1 and register a node with auth key + user1 := app.state.CreateUserForTest("interactive-user-1") + + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register node with user1's auth key first + initialReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "transfer-node-user1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{}, // Empty auth triggers interactive flow + NodeKey: nodeKey1.Public(), // Same node key as original registration + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "transfer-node-user2", // Different hostname + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, // Same machine key + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, + }, + validateCompleteResponse: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // User1's original node should STILL exist (not transferred) + node1, found1 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1)) + require.True(t, found1, "user1's original node should still exist") + assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1") + assert.Equal(t, nodeKey1.Public(), node1.NodeKey(), "user1's node should have original node key") + + // User2 should have a NEW node created + node2, found2 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2)) + require.True(t, found2, "user2 should have new node created") + assert.Equal(t, uint(2), node2.UserID().Get(), "user2's node should belong to user2") + + user := node2.User() + assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should show correct username") + + // Both nodes should have the same machine key but different IDs + assert.NotEqual(t, node1.ID(), node2.ID(), "should be different nodes (different IDs)") + assert.Equal(t, machineKey1.Public(), node2.MachineKey(), "user2's node should have same machine key") + }, + }, + // TEST: Followup request after registration cache expiry + // WHAT: Tests that expired followup requests get a new AuthURL instead of error + // 
INPUT: Followup request for registration ID that has expired/been evicted from cache + // EXPECTED: Returns new AuthURL (not error) so client can retry authentication + // WHY: Validates new reqToNewRegisterResponse functionality - prevents client getting stuck + { + name: "followup_request_after_cache_expiry", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + // Generate a registration ID that doesn't exist in cache + // This simulates an expired/missing cache entry + regID, err := types.NewRegistrationID() + if err != nil { + return "", err + } + // Don't add it to cache - it's already expired/missing + return regID.String(), nil + }, + request: func(regID string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Followup: "http://localhost:8080/register/" + regID, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "expired-cache-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: false, // Should not be authorized yet - needs to use new AuthURL + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Should get a new AuthURL, not an error + assert.NotEmpty(t, resp.AuthURL, "should receive new AuthURL when registration expired") + assert.Contains(t, resp.AuthURL, "/register/", "AuthURL should contain registration path") + assert.False(t, resp.MachineAuthorized, "machine should not be authorized yet") + + // Verify the response contains a valid registration URL + authURL, err := url.Parse(resp.AuthURL) + assert.NoError(t, err, "AuthURL should be a valid URL") + assert.True(t, strings.HasPrefix(authURL.Path, "/register/"), "AuthURL path should start with /register/") + + // Extract and validate the new registration ID exists in cache + newRegIDStr := strings.TrimPrefix(authURL.Path, "/register/") + newRegID, err := types.RegistrationIDFromString(newRegIDStr) + assert.NoError(t, err, "should be able to parse new registration ID") + + // Verify new registration entry exists in cache + _, found := app.state.GetRegistrationCacheEntry(newRegID) + assert.True(t, found, "new registration should exist in cache") + }, + }, + // TEST: Logout with expiry exactly at current time + // WHAT: Tests logout when expiry is set to exact current time (boundary case) + // INPUT: Existing node sends request with expiry=time.Now() (not past, not future) + // EXPECTED: Node is logged out (treated as expired) + // WHY: Edge case: current time should be treated as expired + { + name: "logout_with_exactly_now_expiry", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + user := app.state.CreateUserForTest("exact-now-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node first + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "exact-now-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegisterWithAuthKey(regReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + 
return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: nil, + NodeKey: nodeKey1.Public(), + Expiry: time.Now(), // Exactly now (edge case between past and future) + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + wantAuth: true, + wantExpired: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + assert.True(t, resp.MachineAuthorized) + assert.True(t, resp.NodeKeyExpired) + + // Node should be marked as expired but still exist + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found) + assert.True(t, node.IsExpired()) + }, + }, + // TEST: Interactive workflow timeout cleans up cache + // WHAT: Tests that timed-out interactive registrations clean up cache entries + // INPUT: Interactive registration that times out without completion + // EXPECTED: Cache entry should be cleaned up (behavior depends on implementation) + // WHY: Prevents cache bloat from abandoned registrations + { + name: "interactive_workflow_timeout_cleanup", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey2.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "interactive-timeout-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey2.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + // NOTE: No auth_completion step - simulates timeout scenario + }, + validateRegistrationCache: true, // should be cleaned up eventually + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Verify AuthURL was generated but registration not completed + assert.Contains(t, resp.AuthURL, "/register/") + assert.False(t, resp.MachineAuthorized) + }, + }, + + // === COMPREHENSIVE INTERACTIVE WORKFLOW EDGE CASES === + // TEST: Interactive workflow with existing node from different user creates new node + // WHAT: Tests new node creation when re-authenticating interactively with different user + // INPUT: Node already registered with user1, interactive auth with user2 (same machine key, different node key) + // EXPECTED: New node is created for user2, user1's original node remains (no transfer) + // WHY: Same physical machine can have separate node identities per user + { + name: "interactive_workflow_with_existing_node_different_user_creates_new_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + // First create a node under user1 + user1 := app.state.CreateUserForTest("existing-user-1") + + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + // Register the node with user1 first + initialReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "existing-node-user1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available 
in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{}, // Empty auth triggers interactive flow + NodeKey: nodeKey2.Public(), // Different node key for different user + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "existing-node-user2", // Different hostname + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, + }, + validateCompleteResponse: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // User1's original node with nodeKey1 should STILL exist + node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found1, "user1's original node with nodeKey1 should still exist") + assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1") + assert.Equal(t, uint64(1), node1.ID().Uint64(), "user1's node should be ID=1") + + // User2 should have a NEW node with nodeKey2 + node2, found2 := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found2, "user2 should have new node with nodeKey2") + + assert.Equal(t, "existing-node-user2", node2.Hostname(), "hostname should be from new registration") + user := node2.User() + assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2") + assert.Equal(t, machineKey1.Public(), node2.MachineKey(), "machine key should be the same") + + // Verify it's a NEW node, not transferred + assert.NotEqual(t, uint64(1), node2.ID().Uint64(), "should be a NEW node (different ID)") + }, + }, + // TEST: Interactive workflow with malformed followup URL + // WHAT: Tests that malformed followup URLs in interactive flow are rejected + // INPUT: Interactive registration with invalid followup URL format + // EXPECTED: Request fails with error (invalid URL) + // WHY: Validates followup URLs to prevent errors + { + name: "interactive_workflow_malformed_followup_url", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "malformed-followup-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + }, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Test malformed followup URLs after getting initial AuthURL + authURL := resp.AuthURL + assert.Contains(t, authURL, "/register/") + + // Test various malformed followup URLs - use completely invalid IDs to avoid blocking + malformedURLs := []string{ + "invalid-url", + "/register/", + "/register/invalid-id-that-does-not-exist", + "/register/00000000-0000-0000-0000-000000000000", + "http://malicious-site.com/register/invalid-id", + } + + for _, malformedURL := range malformedURLs { + followupReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), 
+ Followup: malformedURL, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "malformed-followup-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + // These should all fail gracefully + _, err := app.handleRegister(context.Background(), followupReq, machineKey1.Public()) + assert.Error(t, err, "malformed followup URL should be rejected: %s", malformedURL) + } + }, + }, + // TEST: Concurrent interactive workflow registrations + // WHAT: Tests multiple simultaneous interactive registrations + // INPUT: Two nodes initiate interactive registration concurrently + // EXPECTED: Both registrations succeed independently + // WHY: System should handle concurrent interactive flows without conflicts + { + name: "interactive_workflow_concurrent_registrations", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "concurrent-registration-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // This test validates concurrent interactive registration attempts + assert.Contains(t, resp.AuthURL, "/register/") + + // Start multiple concurrent followup requests + authURL := resp.AuthURL + numConcurrent := 3 + results := make(chan error, numConcurrent) + + for i := range numConcurrent { + go func(index int) { + followupReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Followup: authURL, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: fmt.Sprintf("concurrent-node-%d", index), + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err := app.handleRegister(context.Background(), followupReq, machineKey1.Public()) + results <- err + }(i) + } + + // Complete the authentication to signal the waiting goroutines + // The goroutines will receive from the buffered channel when ready + registrationID, err := extractRegistrationIDFromAuthURL(authURL) + require.NoError(t, err) + + user := app.state.CreateUserForTest("concurrent-test-user") + _, _, err = app.state.HandleNodeFromAuthPath( + registrationID, + types.UserID(user.ID), + nil, + "concurrent-test-method", + ) + require.NoError(t, err) + + // Collect results - at least one should succeed + successCount := 0 + for range numConcurrent { + select { + case err := <-results: + if err == nil { + successCount++ + } + case <-time.After(2 * time.Second): + // Some may timeout, which is expected + } + } + + // At least one concurrent request should have succeeded + assert.GreaterOrEqual(t, successCount, 1, "at least one concurrent registration should succeed") + }, + }, + // TEST: Interactive workflow with node key rotation attempt + // WHAT: Tests interactive registration with different node key (appears as rotation) + // INPUT: Node registered with nodeKey1, then interactive registration with nodeKey2 + // EXPECTED: Creates new node for different user (not true rotation) + // WHY: Interactive flow creates new nodes with new users; doesn't rotate existing nodes + { + name: "interactive_workflow_node_key_rotation", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + // Register initial node + user := app.state.CreateUserForTest("rotation-user") + + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + if err != nil { + return "", err + } + + initialReq := 
tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "rotation-node-initial", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegister(context.Background(), initialReq, machineKey1.Public()) + if err != nil { + return "", err + } + + // Wait for node to be available + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(c, found, "node should be available in NodeStore") + }, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore") + + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey2.Public(), // Different node key (rotation scenario) + OldNodeKey: nodeKey1.Public(), // Previous node key + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "rotation-node-updated", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, + }, + validateCompleteResponse: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // User1's original node with nodeKey1 should STILL exist + oldNode, foundOld := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, foundOld, "user1's original node with nodeKey1 should still exist") + assert.Equal(t, uint(1), oldNode.UserID().Get(), "user1's node should still belong to user1") + assert.Equal(t, uint64(1), oldNode.ID().Uint64(), "user1's node should be ID=1") + + // User2 should have a NEW node with nodeKey2 + newNode, found := app.state.GetNodeByNodeKey(nodeKey2.Public()) + require.True(t, found, "user2 should have new node with nodeKey2") + assert.Equal(t, "rotation-node-updated", newNode.Hostname()) + assert.Equal(t, machineKey1.Public(), newNode.MachineKey()) + + user := newNode.User() + assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2") + + // Verify it's a NEW node, not transferred + assert.NotEqual(t, uint64(1), newNode.ID().Uint64(), "should be a NEW node (different ID)") + }, + }, + // TEST: Interactive workflow with nil hostinfo + // WHAT: Tests interactive registration when request has nil hostinfo + // INPUT: Interactive registration request with Hostinfo=nil + // EXPECTED: Node registers successfully with generated default hostname + // WHY: Defensive code handles nil hostinfo in interactive flow + { + name: "interactive_workflow_with_nil_hostinfo", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: nil, // Nil hostinfo should be handled gracefully + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + requiresInteractiveFlow: true, + interactiveSteps: []interactiveStep{ + {stepType: stepTypeInitialRequest, expectAuthURL: true, expectCacheEntry: true}, + {stepType: stepTypeAuthCompletion, callAuthPath: true, expectCacheEntry: false}, + }, + validateCompleteResponse: true, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app 
*Headscale) { + // Should handle nil hostinfo gracefully + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found, "node should be registered despite nil hostinfo") + if found { + // Should have some default hostname or handle nil gracefully + hostname := node.Hostname() + assert.NotEmpty(t, hostname, "should have some hostname even with nil hostinfo") + } + }, + }, + // TEST: Registration cache behavior on authentication error + // WHAT: Tests that the cache entry survives a failed auth completion + // INPUT: Interactive registration whose auth completion fails (invalid user ID) + // EXPECTED: Cache entry remains after the error so the client can retry + // WHY: A failed auth completion must not discard the pending registration + { + name: "interactive_workflow_registration_cache_cleanup_on_error", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "cache-cleanup-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Get initial AuthURL and extract registration ID + authURL := resp.AuthURL + assert.Contains(t, authURL, "/register/") + + registrationID, err := extractRegistrationIDFromAuthURL(authURL) + require.NoError(t, err) + + // Verify cache entry exists + cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + assert.True(t, found, "registration cache entry should exist initially") + assert.NotNil(t, cacheEntry) + + // Try to complete authentication with invalid user ID (should cause error) + invalidUserID := types.UserID(99999) // Non-existent user + _, _, err = app.state.HandleNodeFromAuthPath( + registrationID, + invalidUserID, + nil, + "error-test-method", + ) + assert.Error(t, err, "should fail with invalid user ID") + + // Cache entry should still exist after auth error (for retry scenarios) + _, stillFound := app.state.GetRegistrationCacheEntry(registrationID) + assert.True(t, stillFound, "registration cache entry should still exist after auth error for potential retry") + }, + }, + // TEST: Interactive workflow with multiple registration attempts for same node + // WHAT: Tests that multiple interactive registrations can be created for same node + // INPUT: Start two interactive registrations, verify both cache entries exist + // EXPECTED: Both registrations get different IDs and can coexist + // WHY: Validates that multiple pending registrations don't interfere with each other + { + name: "interactive_workflow_multiple_steps_same_node", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "multi-step-node", + OS: "linux", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + validate: func(t 
*testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + // Test multiple interactive registration attempts for the same node can coexist + authURL1 := resp.AuthURL + assert.Contains(t, authURL1, "/register/") + + // Start a second interactive registration for the same node + secondReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "multi-step-node-updated", + OS: "linux-updated", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) + require.NoError(t, err) + authURL2 := resp2.AuthURL + assert.Contains(t, authURL2, "/register/") + + // Both should have different registration IDs + regID1, err1 := extractRegistrationIDFromAuthURL(authURL1) + regID2, err2 := extractRegistrationIDFromAuthURL(authURL2) + require.NoError(t, err1) + require.NoError(t, err2) + assert.NotEqual(t, regID1, regID2, "different registration attempts should have different IDs") + + // Both cache entries should exist simultaneously + _, found1 := app.state.GetRegistrationCacheEntry(regID1) + _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first registration cache entry should exist") + assert.True(t, found2, "second registration cache entry should exist") + + // This validates that multiple pending registrations can coexist + // without interfering with each other + }, + }, + // TEST: Complete one of multiple pending registrations + // WHAT: Tests completing the second of two pending registrations for same node + // INPUT: Create two pending registrations, complete the second one + // EXPECTED: Second registration completes successfully, node is created + // WHY: Validates that you can complete any pending registration, not just the first + { + name: "interactive_workflow_complete_second_of_multiple_pending", + setupFunc: func(t *testing.T, app *Headscale) (string, error) { + return "", nil + }, + request: func(_ string) tailcfg.RegisterRequest { + return tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "pending-node-1", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + }, + machineKey: func() key.MachinePublic { return machineKey1.Public() }, + validate: func(t *testing.T, resp *tailcfg.RegisterResponse, app *Headscale) { + authURL1 := resp.AuthURL + regID1, err := extractRegistrationIDFromAuthURL(authURL1) + require.NoError(t, err) + + // Start a second interactive registration for the same node + secondReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "pending-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegister(context.Background(), secondReq, machineKey1.Public()) + require.NoError(t, err) + authURL2 := resp2.AuthURL + regID2, err := extractRegistrationIDFromAuthURL(authURL2) + require.NoError(t, err) + + // Verify both exist + _, found1 := app.state.GetRegistrationCacheEntry(regID1) + _, found2 := app.state.GetRegistrationCacheEntry(regID2) + assert.True(t, found1, "first cache entry should exist") + assert.True(t, found2, "second cache entry should exist") + + // Complete the SECOND registration (not the first) + user := app.state.CreateUserForTest("second-registration-user") + + // Start followup request in goroutine (it will wait for auth completion) + responseChan := make(chan *tailcfg.RegisterResponse, 1) + errorChan := make(chan error, 1) + + followupReq := 
tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Followup: authURL2, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "pending-node-2", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + go func() { + resp, err := app.handleRegister(context.Background(), followupReq, machineKey1.Public()) + if err != nil { + errorChan <- err + return + } + responseChan <- resp + }() + + // Complete authentication for second registration + // The goroutine will receive the node from the buffered channel + _, _, err = app.state.HandleNodeFromAuthPath( + regID2, + types.UserID(user.ID), + nil, + "second-registration-method", + ) + require.NoError(t, err) + + // Wait for followup to complete + select { + case err := <-errorChan: + t.Fatalf("followup request failed: %v", err) + case finalResp := <-responseChan: + require.NotNil(t, finalResp) + assert.True(t, finalResp.MachineAuthorized, "machine should be authorized") + case <-time.After(2 * time.Second): + t.Fatal("followup request timed out") + } + + // Verify the node was created with the second registration's data + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + assert.True(t, found, "node should be registered") + if found { + assert.Equal(t, "pending-node-2", node.Hostname()) + assert.Equal(t, "second-registration-user", node.User().Name()) + } + + // First registration should still be in cache (not completed) + _, stillFound := app.state.GetRegistrationCacheEntry(regID1) + assert.True(t, stillFound, "first registration should still be pending") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create test app + app := createTestApp(t) + + // Run setup function + dynamicValue, err := tt.setupFunc(t, app) + require.NoError(t, err, "setup should not fail") + + // Check if this test requires interactive workflow + if tt.requiresInteractiveFlow { + runInteractiveWorkflowTest(t, tt, app, dynamicValue) + return + } + + // Build request + req := tt.request(dynamicValue) + machineKey := tt.machineKey() + + // Set up context with timeout for followup tests + ctx := context.Background() + if req.Followup != "" { + var cancel context.CancelFunc + ctx, cancel = context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + } + + // Debug: check node availability before test execution + if req.Auth == nil { + if node, found := app.state.GetNodeByNodeKey(req.NodeKey); found { + t.Logf("Node found before handleRegister: hostname=%s, expired=%t", node.Hostname(), node.IsExpired()) + } else { + t.Logf("Node NOT found before handleRegister for key %s", req.NodeKey.ShortString()) + } + } + + // Execute the test + resp, err := app.handleRegister(ctx, req, machineKey) + + // Validate error expectations + if tt.wantError { + assert.Error(t, err, "expected error but got none") + return + } + + require.NoError(t, err, "unexpected error: %v", err) + require.NotNil(t, resp, "response should not be nil") + + // Validate basic response properties + if tt.wantAuth { + assert.True(t, resp.MachineAuthorized, "machine should be authorized") + } else { + assert.False(t, resp.MachineAuthorized, "machine should not be authorized") + } + + if tt.wantAuthURL { + assert.NotEmpty(t, resp.AuthURL, "should have AuthURL") + assert.Contains(t, resp.AuthURL, "register/", "AuthURL should contain registration path") + } + + if tt.wantExpired { + assert.True(t, resp.NodeKeyExpired, "node key should be expired") + } else { + assert.False(t, resp.NodeKeyExpired, "node key should not be expired") + } + + // Run custom 
validation if provided + if tt.validate != nil { + tt.validate(t, resp, app) + } + }) + } +} + +// runInteractiveWorkflowTest executes a multi-step interactive authentication workflow +func runInteractiveWorkflowTest(t *testing.T, tt struct { + name string + setupFunc func(*testing.T, *Headscale) (string, error) + request func(dynamicValue string) tailcfg.RegisterRequest + machineKey func() key.MachinePublic + wantAuth bool + wantError bool + wantAuthURL bool + wantExpired bool + validate func(*testing.T, *tailcfg.RegisterResponse, *Headscale) + requiresInteractiveFlow bool + interactiveSteps []interactiveStep + validateRegistrationCache bool + expectedAuthURLPattern string + simulateAuthCompletion bool + validateCompleteResponse bool +}, app *Headscale, dynamicValue string, +) { + // Build initial request + req := tt.request(dynamicValue) + machineKey := tt.machineKey() + ctx := context.Background() + + // Execute interactive workflow steps + var ( + initialResp *tailcfg.RegisterResponse + authURL string + registrationID types.RegistrationID + finalResp *tailcfg.RegisterResponse + err error + ) + + // Execute the steps in the correct sequence for interactive workflow + for i, step := range tt.interactiveSteps { + t.Logf("Executing interactive step %d: %s", i+1, step.stepType) + + switch step.stepType { + case stepTypeInitialRequest: + // Step 1: Initial request should get AuthURL back + initialResp, err = app.handleRegister(ctx, req, machineKey) + require.NoError(t, err, "initial request should not fail") + require.NotNil(t, initialResp, "initial response should not be nil") + + if step.expectAuthURL { + require.NotEmpty(t, initialResp.AuthURL, "should have AuthURL") + require.Contains(t, initialResp.AuthURL, "/register/", "AuthURL should contain registration path") + authURL = initialResp.AuthURL + + // Extract registration ID from AuthURL + registrationID, err = extractRegistrationIDFromAuthURL(authURL) + require.NoError(t, err, "should be able to extract registration ID from AuthURL") + } + + if step.expectCacheEntry { + // Verify registration cache entry was created + cacheEntry, found := app.state.GetRegistrationCacheEntry(registrationID) + require.True(t, found, "registration cache entry should exist") + require.NotNil(t, cacheEntry, "cache entry should not be nil") + require.Equal(t, req.NodeKey, cacheEntry.Node.NodeKey, "cache entry should have correct node key") + } + + case stepTypeAuthCompletion: + // Step 2: Start followup request that will wait, then complete authentication + if step.callAuthPath { + require.NotEmpty(t, registrationID, "registration ID should be available from previous step") + + // Prepare followup request + followupReq := tt.request(dynamicValue) + followupReq.Followup = authURL + + // Start the followup request in a goroutine - it will wait for channel signal + responseChan := make(chan *tailcfg.RegisterResponse, 1) + errorChan := make(chan error, 1) + + go func() { + resp, err := app.handleRegister(context.Background(), followupReq, machineKey) + if err != nil { + errorChan <- err + return + } + responseChan <- resp + }() + + // Complete the authentication - the goroutine will receive from the buffered channel + user := app.state.CreateUserForTest("interactive-test-user") + _, _, err = app.state.HandleNodeFromAuthPath( + registrationID, + types.UserID(user.ID), + nil, // no custom expiry + "test-method", + ) + require.NoError(t, err, "HandleNodeFromAuthPath should succeed") + + // Wait for the followup request to complete + select { + case err := 
<-errorChan: + require.NoError(t, err, "followup request should not fail") + case finalResp = <-responseChan: + require.NotNil(t, finalResp, "final response should not be nil") + // Verify machine is now authorized + require.True(t, finalResp.MachineAuthorized, "machine should be authorized after followup") + case <-time.After(5 * time.Second): + t.Fatal("followup request timed out waiting for authentication completion") + } + } + + case stepTypeFollowupRequest: + // This step is deprecated - followup is now handled within auth_completion step + t.Logf("followup_request step is deprecated - use expectCacheEntry in auth_completion instead") + + default: + t.Fatalf("unknown interactive step type: %s", step.stepType) + } + + // Check cache cleanup expectation for this step + if step.expectCacheEntry == false && registrationID != "" { + // Verify cache entry was cleaned up + _, found := app.state.GetRegistrationCacheEntry(registrationID) + require.False(t, found, "registration cache entry should be cleaned up after step: %s", step.stepType) + } + } + + // Validate final response if requested + if tt.validateCompleteResponse && finalResp != nil { + validateCompleteRegistrationResponse(t, finalResp, req) + } + + // Run custom validation if provided + if tt.validate != nil { + responseToValidate := finalResp + if responseToValidate == nil { + responseToValidate = initialResp + } + tt.validate(t, responseToValidate, app) + } +} + +// extractRegistrationIDFromAuthURL extracts the registration ID from an AuthURL +func extractRegistrationIDFromAuthURL(authURL string) (types.RegistrationID, error) { + // AuthURL format: "http://localhost/register/abc123" + const registerPrefix = "/register/" + idx := strings.LastIndex(authURL, registerPrefix) + if idx == -1 { + return "", fmt.Errorf("invalid AuthURL format: %s", authURL) + } + + idStr := authURL[idx+len(registerPrefix):] + return types.RegistrationIDFromString(idStr) +} + +// validateCompleteRegistrationResponse performs comprehensive validation of a registration response +func validateCompleteRegistrationResponse(t *testing.T, resp *tailcfg.RegisterResponse, originalReq tailcfg.RegisterRequest) { + // Basic response validation + require.NotNil(t, resp, "response should not be nil") + require.True(t, resp.MachineAuthorized, "machine should be authorized") + require.False(t, resp.NodeKeyExpired, "node key should not be expired") + require.NotEmpty(t, resp.User.DisplayName, "user should have display name") + + // Additional validation can be added here as needed + // Note: NodeKey field may not be present in all response types + + // Additional validation can be added here as needed +} + +// Simple test to validate basic node creation and lookup +func TestNodeStoreLookup(t *testing.T) { + app := createTestApp(t) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + user := app.state.CreateUserForTest("test-user") + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + require.NoError(t, err) + + // Register a node + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err) + require.NotNil(t, resp) + require.True(t, resp.MachineAuthorized) + + t.Logf("Registered node successfully: %+v", resp) + + // Wait for node to be available in NodeStore 
+ var node types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { + var found bool + node, found = app.state.GetNodeByNodeKey(nodeKey.Public()) + assert.True(c, found, "Node should be found in NodeStore") + }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available in NodeStore") + + require.Equal(t, "test-node", node.Hostname()) + + t.Logf("Found node: hostname=%s, id=%d", node.Hostname(), node.ID().Uint64()) +} + +// TestPreAuthKeyLogoutAndReloginDifferentUser tests the scenario where: +// 1. Multiple nodes register with different users using pre-auth keys +// 2. All nodes logout +// 3. All nodes re-login using a different user's pre-auth key +// EXPECTED BEHAVIOR: Should create NEW nodes for the new user, leaving old nodes with the old user. +// This matches the integration test expectation and web flow behavior. +func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) { + app := createTestApp(t) + + // Create two users + user1 := app.state.CreateUserForTest("user1") + user2 := app.state.CreateUserForTest("user2") + + // Create pre-auth keys for both users + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) + require.NoError(t, err) + pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil) + require.NoError(t, err) + + // Create machine and node keys for 4 nodes (2 per user) + type nodeInfo struct { + machineKey key.MachinePrivate + nodeKey key.NodePrivate + hostname string + nodeID types.NodeID + } + + nodes := []nodeInfo{ + {machineKey: key.NewMachine(), nodeKey: key.NewNode(), hostname: "user1-node1"}, + {machineKey: key.NewMachine(), nodeKey: key.NewNode(), hostname: "user1-node2"}, + {machineKey: key.NewMachine(), nodeKey: key.NewNode(), hostname: "user2-node1"}, + {machineKey: key.NewMachine(), nodeKey: key.NewNode(), hostname: "user2-node2"}, + } + + // Register nodes: first 2 to user1, last 2 to user2 + for i, node := range nodes { + authKey := pak1.Key + if i >= 2 { + authKey = pak2.Key + } + + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: authKey, + }, + NodeKey: node.nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: node.hostname, + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, node.machineKey.Public()) + require.NoError(t, err) + require.NotNil(t, resp) + require.True(t, resp.MachineAuthorized) + + // Get the node ID + var registeredNode types.NodeView + require.EventuallyWithT(t, func(c *assert.CollectT) { + var found bool + registeredNode, found = app.state.GetNodeByNodeKey(node.nodeKey.Public()) + assert.True(c, found, "Node should be found in NodeStore") + }, 1*time.Second, 100*time.Millisecond, "waiting for node to be available") + + nodes[i].nodeID = registeredNode.ID() + t.Logf("Registered node %s with ID %d to user%d", node.hostname, registeredNode.ID().Uint64(), i/2+1) + } + + // Verify initial state: user1 has 2 nodes, user2 has 2 nodes + user1Nodes := app.state.ListNodesByUser(types.UserID(user1.ID)) + user2Nodes := app.state.ListNodesByUser(types.UserID(user2.ID)) + require.Equal(t, 2, user1Nodes.Len(), "user1 should have 2 nodes initially") + require.Equal(t, 2, user2Nodes.Len(), "user2 should have 2 nodes initially") + + t.Logf("Initial state verified: user1=%d nodes, user2=%d nodes", user1Nodes.Len(), user2Nodes.Len()) + + // Simulate logout for all nodes + for _, node := range nodes { + logoutReq := tailcfg.RegisterRequest{ + Auth: nil, // nil Auth indicates 
logout + NodeKey: node.nodeKey.Public(), + } + + resp, err := app.handleRegister(context.Background(), logoutReq, node.machineKey.Public()) + require.NoError(t, err) + t.Logf("Logout response for %s: %+v", node.hostname, resp) + } + + t.Logf("All nodes logged out") + + // Create a new pre-auth key for user1 (reusable for all nodes) + newPak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) + require.NoError(t, err) + + // Re-login all nodes using user1's new pre-auth key + for i, node := range nodes { + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: newPak1.Key, + }, + NodeKey: node.nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: node.hostname, + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, node.machineKey.Public()) + require.NoError(t, err) + require.NotNil(t, resp) + require.True(t, resp.MachineAuthorized) + + t.Logf("Re-registered node %s (originally user%d) with user1's pre-auth key", node.hostname, i/2+1) + } + + // Verify final state after re-login + // EXPECTED: New nodes created for user1, old nodes remain with original users + user1NodesAfter := app.state.ListNodesByUser(types.UserID(user1.ID)) + user2NodesAfter := app.state.ListNodesByUser(types.UserID(user2.ID)) + + t.Logf("Final state: user1=%d nodes, user2=%d nodes", user1NodesAfter.Len(), user2NodesAfter.Len()) + + // CORRECT BEHAVIOR: When re-authenticating with a DIFFERENT user's pre-auth key, + // new nodes should be created (not transferred). This matches: + // 1. The integration test expectation + // 2. The web flow behavior (creates new nodes) + // 3. The principle that each user owns distinct node entries + require.Equal(t, 4, user1NodesAfter.Len(), "user1 should have 4 nodes total (2 original + 2 new from user2's machines)") + require.Equal(t, 2, user2NodesAfter.Len(), "user2 should still have 2 nodes (old nodes from original registration)") + + // Verify original nodes still exist with original users + for i := range 2 { + node := nodes[i] + // User1's original nodes should still be owned by user1 + registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID)) + require.True(t, found, "User1's original node %s should still exist", node.hostname) + require.Equal(t, user1.ID, registeredNode.UserID().Get(), "Node %s should still belong to user1", node.hostname) + t.Logf("✓ User1's original node %s (ID=%d) still owned by user1", node.hostname, registeredNode.ID().Uint64()) + } + + for i := 2; i < 4; i++ { + node := nodes[i] + // User2's original nodes should still be owned by user2 + registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user2.ID)) + require.True(t, found, "User2's original node %s should still exist", node.hostname) + require.Equal(t, user2.ID, registeredNode.UserID().Get(), "Node %s should still belong to user2", node.hostname) + t.Logf("✓ User2's original node %s (ID=%d) still owned by user2", node.hostname, registeredNode.ID().Uint64()) + } + + // Verify new nodes were created for user1 with the same machine keys + t.Logf("Verifying new nodes created for user1 from user2's machine keys...") + for i := 2; i < 4; i++ { + node := nodes[i] + // Should be able to find a node with user1 and this machine key (the new one) + newNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID)) + require.True(t, found, "Should have created new node for user1 with 
machine key from %s", node.hostname) + require.Equal(t, user1.ID, newNode.UserID().Get(), "New node should belong to user1") + t.Logf("✓ New node created for user1 with machine key from %s (ID=%d)", node.hostname, newNode.ID().Uint64()) + } +} + +// TestWebFlowReauthDifferentUser validates CLI registration behavior when switching users. +// This test replicates the TestAuthWebFlowLogoutAndReloginNewUser integration test scenario. +// +// IMPORTANT: CLI registration creates NEW nodes (different from interactive flow which transfers). +// +// Scenario: +// 1. Node registers with user1 via pre-auth key +// 2. Node logs out (expires) +// 3. Admin runs: headscale nodes register --user user2 --key +// +// Expected behavior: +// - User1's original node should STILL EXIST (expired) +// - User2 should get a NEW node created (NOT transfer) +// - Both nodes share the same machine key (same physical device) +func TestWebFlowReauthDifferentUser(t *testing.T) { + machineKey := key.NewMachine() + nodeKey1 := key.NewNode() + nodeKey2 := key.NewNode() // Node key rotates on re-auth + + app := createTestApp(t) + + // Step 1: Register node for user1 via pre-auth key (simulating initial web flow registration) + user1 := app.state.CreateUserForTest("user1") + pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil) + require.NoError(t, err) + + regReq1 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak1.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-machine", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey.Public()) + require.NoError(t, err) + require.True(t, resp1.MachineAuthorized, "Should be authorized via pre-auth key") + + // Verify node exists for user1 + user1Node, found := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) + require.True(t, found, "Node should exist for user1") + require.Equal(t, user1.ID, user1Node.UserID().Get(), "Node should belong to user1") + user1NodeID := user1Node.ID() + t.Logf("✓ User1 node created with ID: %d", user1NodeID) + + // Step 2: Simulate logout by expiring the node + pastTime := time.Now().Add(-1 * time.Hour) + logoutReq := tailcfg.RegisterRequest{ + NodeKey: nodeKey1.Public(), + Expiry: pastTime, // Expired = logout + } + _, err = app.handleRegister(context.Background(), logoutReq, machineKey.Public()) + require.NoError(t, err) + + // Verify node is expired + user1Node, found = app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) + require.True(t, found, "Node should still exist after logout") + require.True(t, user1Node.IsExpired(), "Node should be expired after logout") + t.Logf("✓ User1 node expired (logged out)") + + // Step 3: Start interactive re-authentication (simulates "tailscale up") + user2 := app.state.CreateUserForTest("user2") + + reAuthReq := tailcfg.RegisterRequest{ + // No Auth field - triggers interactive flow + NodeKey: nodeKey2.Public(), // New node key (rotated on re-auth) + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-machine", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + // Initial request should return AuthURL + initialResp, err := app.handleRegister(context.Background(), reAuthReq, machineKey.Public()) + require.NoError(t, err) + require.NotEmpty(t, initialResp.AuthURL, "Should receive AuthURL for interactive flow") + t.Logf("✓ Interactive flow started, AuthURL: %s", initialResp.AuthURL) + + // Extract registration ID 
from AuthURL + regID, err := extractRegistrationIDFromAuthURL(initialResp.AuthURL) + require.NoError(t, err, "Should extract registration ID from AuthURL") + require.NotEmpty(t, regID, "Should have valid registration ID") + + // Step 4: Admin completes authentication via CLI + // This simulates: headscale nodes register --user user2 --key + node, _, err := app.state.HandleNodeFromAuthPath( + regID, + types.UserID(user2.ID), // Register to user2, not user1! + nil, // No custom expiry + "cli", // Registration method (CLI register command) + ) + require.NoError(t, err, "HandleNodeFromAuthPath should succeed") + t.Logf("✓ Admin registered node to user2 via CLI (node ID: %d)", node.ID()) + + t.Run("user1_original_node_still_exists", func(t *testing.T) { + // User1's original node should STILL exist (not transferred to user2) + user1NodeAfter, found1 := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) + assert.True(t, found1, "User1's original node should still exist (not transferred)") + + if !found1 { + t.Fatal("User1's node was transferred or deleted - this breaks the integration test!") + } + + assert.Equal(t, user1.ID, user1NodeAfter.UserID().Get(), "User1's node should still belong to user1") + assert.Equal(t, user1NodeID, user1NodeAfter.ID(), "Should be the same node (same ID)") + assert.True(t, user1NodeAfter.IsExpired(), "User1's node should still be expired") + t.Logf("✓ User1's original node still exists (ID: %d, expired: %v)", user1NodeAfter.ID(), user1NodeAfter.IsExpired()) + }) + + t.Run("user2_has_new_node_created", func(t *testing.T) { + // User2 should have a NEW node created (not transfer from user1) + user2Node, found2 := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user2.ID)) + assert.True(t, found2, "User2 should have a new node created") + + if !found2 { + t.Fatal("User2 doesn't have a node - registration failed!") + } + + assert.Equal(t, user2.ID, user2Node.UserID().Get(), "User2's node should belong to user2") + assert.NotEqual(t, user1NodeID, user2Node.ID(), "Should be a NEW node (different ID), not transfer!") + assert.Equal(t, machineKey.Public(), user2Node.MachineKey(), "Should have same machine key") + assert.Equal(t, nodeKey2.Public(), user2Node.NodeKey(), "Should have new node key") + assert.False(t, user2Node.IsExpired(), "User2's node should NOT be expired (active)") + t.Logf("✓ User2's new node created (ID: %d, active)", user2Node.ID()) + }) + + t.Run("returned_node_is_user2_new_node", func(t *testing.T) { + // The node returned from HandleNodeFromAuthPath should be user2's NEW node + assert.Equal(t, user2.ID, node.UserID().Get(), "Returned node should belong to user2") + assert.NotEqual(t, user1NodeID, node.ID(), "Returned node should be NEW, not transferred from user1") + t.Logf("✓ HandleNodeFromAuthPath returned user2's new node (ID: %d)", node.ID()) + }) + + t.Run("both_nodes_share_machine_key", func(t *testing.T) { + // Both nodes should have the same machine key (same physical device) + user1NodeFinal, found1 := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID)) + user2NodeFinal, found2 := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user2.ID)) + + require.True(t, found1, "User1 node should exist") + require.True(t, found2, "User2 node should exist") + + assert.Equal(t, machineKey.Public(), user1NodeFinal.MachineKey(), "User1 node should have correct machine key") + assert.Equal(t, machineKey.Public(), user2NodeFinal.MachineKey(), "User2 node should have same machine 
key") + t.Logf("✓ Both nodes share machine key: %s", machineKey.Public().ShortString()) + }) + + t.Run("total_node_count", func(t *testing.T) { + // We should have exactly 2 nodes total: one for user1 (expired), one for user2 (active) + allNodesSlice := app.state.ListNodes() + assert.Equal(t, 2, allNodesSlice.Len(), "Should have exactly 2 nodes total") + + // Count nodes per user + user1Nodes := 0 + user2Nodes := 0 + for i := 0; i < allNodesSlice.Len(); i++ { + n := allNodesSlice.At(i) + if n.UserID().Get() == user1.ID { + user1Nodes++ + } + + if n.UserID().Get() == user2.ID { + user2Nodes++ + } + } + + assert.Equal(t, 1, user1Nodes, "User1 should have 1 node") + assert.Equal(t, 1, user2Nodes, "User2 should have 1 node") + t.Logf("✓ Total: 2 nodes (user1: 1 expired, user2: 1 active)") + }) +} + +// Helper function to create test app +func createTestApp(t *testing.T) *Headscale { + t.Helper() + + tmpDir := t.TempDir() + + cfg := types.Config{ + ServerURL: "http://localhost:8080", + NoisePrivateKeyPath: tmpDir + "/noise_private.key", + Database: types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + OIDC: types.OIDCConfig{}, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, + Tuning: types.Tuning{ + BatchChangeDelay: 100 * time.Millisecond, + BatcherWorkers: 1, + }, + } + + app, err := NewHeadscale(&cfg) + require.NoError(t, err) + + // Initialize and start the mapBatcher to handle Change() calls + app.mapBatcher = mapper.NewBatcherAndMapper(&cfg, app.state) + app.mapBatcher.Start() + + // Clean up the batcher when the test finishes + t.Cleanup(func() { + if app.mapBatcher != nil { + app.mapBatcher.Close() + } + }) + + return app +} + +// TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey tests the scenario reported in +// https://github.com/juanfont/headscale/issues/2830 +// +// Scenario: +// 1. Node registers successfully with a single-use pre-auth key +// 2. Node is running fine +// 3. Node restarts (e.g., after headscale upgrade or tailscale container restart) +// 4. Node sends RegisterRequest with the same pre-auth key +// 5. BUG: Headscale rejects the request with "authkey expired" or "authkey already used" +// +// Expected behavior: +// When an existing node (identified by matching NodeKey + MachineKey) re-registers +// with a pre-auth key that it previously used, the registration should succeed. +// The node is not creating a new registration - it's re-authenticating the same device. 
+func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create user and single-use pre-auth key + user := app.state.CreateUserForTest("test-user") + pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // reusable=false + require.NoError(t, err) + + // Fetch the full pre-auth key to check Reusable field + pak, err := app.state.GetPreAuthKey(pakNew.Key) + require.NoError(t, err) + require.False(t, pak.Reusable, "key should be single-use for this test") + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // STEP 1: Initial registration with pre-auth key (simulates fresh node joining) + initialReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pakNew.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + t.Log("Step 1: Initial registration with pre-auth key") + initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public()) + require.NoError(t, err, "initial registration should succeed") + require.NotNil(t, initialResp) + assert.True(t, initialResp.MachineAuthorized, "node should be authorized") + assert.False(t, initialResp.NodeKeyExpired, "node key should not be expired") + + // Verify node was created in database + node, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found, "node should exist after initial registration") + assert.Equal(t, "test-node", node.Hostname()) + assert.Equal(t, nodeKey.Public(), node.NodeKey()) + assert.Equal(t, machineKey.Public(), node.MachineKey()) + + // Verify pre-auth key is now marked as used + usedPak, err := app.state.GetPreAuthKey(pakNew.Key) + require.NoError(t, err) + assert.True(t, usedPak.Used, "pre-auth key should be marked as used after initial registration") + + // STEP 2: Simulate node restart - node sends RegisterRequest again with same pre-auth key + // This happens when: + // - Tailscale container restarts + // - Tailscaled service restarts + // - System reboots + // The Tailscale client persists the pre-auth key in its state and sends it on every registration + t.Log("Step 2: Node restart - re-registration with same (now used) pre-auth key") + restartReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pakNew.Key, // Same key, now marked as Used=true + }, + NodeKey: nodeKey.Public(), // Same node key + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + // BUG: This fails with "authkey already used" or "authkey expired" + // EXPECTED: Should succeed because it's the same node re-registering + restartResp, err := app.handleRegister(context.Background(), restartReq, machineKey.Public()) + + // This is the assertion that currently FAILS in v0.27.0 + assert.NoError(t, err, "BUG: existing node re-registration with its own used pre-auth key should succeed") + if err != nil { + t.Logf("Error received (this is the bug): %v", err) + t.Logf("Expected behavior: Node should be able to re-register with the same pre-auth key it used initially") + return // Stop here to show the bug clearly + } + + require.NotNil(t, restartResp) + assert.True(t, restartResp.MachineAuthorized, "node should remain authorized after restart") + assert.False(t, restartResp.NodeKeyExpired, "node key should not be expired after restart") + + // Verify it's the same node (not a duplicate) + nodeAfterRestart, found 
:= app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found, "node should still exist after restart") + assert.Equal(t, node.ID(), nodeAfterRestart.ID(), "should be the same node, not a new one") + assert.Equal(t, "test-node", nodeAfterRestart.Hostname()) +} + +// TestNodeReregistrationWithReusablePreAuthKey tests that reusable keys work correctly +// for node re-registration. +func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + user := app.state.CreateUserForTest("test-user") + pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) // reusable=true + require.NoError(t, err) + + // Fetch the full pre-auth key to check Reusable field + pak, err := app.state.GetPreAuthKey(pakNew.Key) + require.NoError(t, err) + require.True(t, pak.Reusable) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Initial registration + initialReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pakNew.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reusable-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + initialResp, err := app.handleRegister(context.Background(), initialReq, machineKey.Public()) + require.NoError(t, err) + require.NotNil(t, initialResp) + assert.True(t, initialResp.MachineAuthorized) + + // Node restart - re-registration with reusable key + restartReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pakNew.Key, // Reusable key + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reusable-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + restartResp, err := app.handleRegister(context.Background(), restartReq, machineKey.Public()) + require.NoError(t, err, "reusable key should allow re-registration") + require.NotNil(t, restartResp) + assert.True(t, restartResp.MachineAuthorized) + assert.False(t, restartResp.NodeKeyExpired) +} + +// TestNodeReregistrationWithExpiredPreAuthKey tests that truly expired keys +// are still rejected even for existing nodes. +func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + user := app.state.CreateUserForTest("test-user") + expiry := time.Now().Add(-1 * time.Hour) // Already expired + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Try to register with expired key + req := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "expired-key-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegister(context.Background(), req, machineKey.Public()) + assert.Error(t, err, "expired pre-auth key should be rejected") + assert.Contains(t, err.Error(), "authkey expired", "error should mention key expiration") +} + +// TestIssue2830_ExistingNodeReregistersWithExpiredKey tests the fix for issue #2830. +// When a node is already registered and the pre-auth key expires, the node should +// still be able to re-register (e.g., after a container restart) using the same +// expired key. The key was only needed for initial authentication. 
+func TestIssue2830_ExistingNodeReregistersWithExpiredKey(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + user := app.state.CreateUserForTest("test-user") + + // Create a valid key (will expire it later) + expiry := time.Now().Add(1 * time.Hour) + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, &expiry, nil) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Register the node initially (key is still valid) + req := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "issue2830-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegister(context.Background(), req, machineKey.Public()) + require.NoError(t, err, "initial registration should succeed") + require.NotNil(t, resp) + require.True(t, resp.MachineAuthorized, "node should be authorized after initial registration") + + // Verify node was created + allNodes := app.state.ListNodes() + require.Equal(t, 1, allNodes.Len()) + initialNodeID := allNodes.At(0).ID() + + // Now expire the key by updating it in the database to have an expiry in the past. + // This simulates the real-world scenario where a key expires after initial registration. + pastExpiry := time.Now().Add(-1 * time.Hour) + err = app.state.DB().DB.Model(&types.PreAuthKey{}). + Where("id = ?", pak.ID). + Update("expiration", pastExpiry).Error + require.NoError(t, err, "should be able to update key expiration") + + // Reload the key to verify it's now expired + expiredPak, err := app.state.GetPreAuthKey(pak.Key) + require.NoError(t, err) + require.NotNil(t, expiredPak.Expiration) + require.True(t, expiredPak.Expiration.Before(time.Now()), "key should be expired") + + // Verify the expired key would fail validation + err = expiredPak.Validate() + require.Error(t, err, "key should fail validation when expired") + require.Contains(t, err.Error(), "authkey expired") + + // Attempt to re-register with the SAME key (now expired). + // This should SUCCEED because: + // - The node already exists with the same MachineKey and User + // - The fix allows existing nodes to re-register even with expired keys + // - The key was only needed for initial authentication + req2 := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, // Same key as initial registration (now expired) + }, + NodeKey: nodeKey.Public(), // Same NodeKey as initial registration + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "issue2830-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp2, err := app.handleRegister(context.Background(), req2, machineKey.Public()) + require.NoError(t, err, "re-registration should succeed even with expired key for existing node") + assert.NotNil(t, resp2) + assert.True(t, resp2.MachineAuthorized, "node should remain authorized after re-registration") + + // Verify we still have only one node (re-registered, not created new) + allNodes = app.state.ListNodes() + require.Equal(t, 1, allNodes.Len(), "should have exactly one node (re-registered)") + assert.Equal(t, initialNodeID, allNodes.At(0).ID(), "node ID should not change on re-registration") +} + +// TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey tests that an existing node +// can re-register using a pre-auth key that's already marked as Used=true, as long as: +// 1. The node is re-registering with the same MachineKey it originally used +// 2. 
The node is using the same pre-auth key it was originally registered with (AuthKeyID matches) +// +// This is the fix for GitHub issue #2830: https://github.com/juanfont/headscale/issues/2830 +// +// Background: When Docker/Kubernetes containers restart, they keep their persistent state +// (including the MachineKey), but container entrypoints unconditionally run: +// +// tailscale up --authkey=$TS_AUTHKEY +// +// This caused nodes to be rejected after restart because the pre-auth key was already +// marked as Used=true from the initial registration. The fix allows re-registration of +// existing nodes with their own used keys. +func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing.T) { + app := createTestApp(t) + + // Create a user + user := app.state.CreateUserForTest("testuser") + + // Create a SINGLE-USE pre-auth key (reusable=false) + // This is the type of key that triggers the bug in issue #2830 + preAuthKeyNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) + + // Fetch the full pre-auth key to check Reusable and Used fields + preAuthKey, err := app.state.GetPreAuthKey(preAuthKeyNew.Key) + require.NoError(t, err) + require.False(t, preAuthKey.Reusable, "Pre-auth key must be single-use to test issue #2830") + require.False(t, preAuthKey.Used, "Pre-auth key should not be used yet") + + // Generate node keys for the client + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Step 1: Initial registration with the pre-auth key + // This simulates the first time the container starts and runs 'tailscale up --authkey=...' + initialReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: preAuthKeyNew.Key, // Use the full key from creation + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "issue-2830-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + initialResp, err := app.handleRegisterWithAuthKey(initialReq, machineKey.Public()) + require.NoError(t, err, "Initial registration should succeed") + require.True(t, initialResp.MachineAuthorized, "Node should be authorized after initial registration") + require.NotNil(t, initialResp.User, "User should be set in response") + require.Equal(t, "testuser", initialResp.User.DisplayName, "User should match the pre-auth key's user") + + // Verify the pre-auth key is now marked as Used + updatedKey, err := app.state.GetPreAuthKey(preAuthKeyNew.Key) + require.NoError(t, err) + require.True(t, updatedKey.Used, "Pre-auth key should be marked as Used after initial registration") + + // Step 2: Container restart scenario + // The container keeps its MachineKey (persistent state), but the entrypoint script + // unconditionally runs 'tailscale up --authkey=$TS_AUTHKEY' again + // + // WITHOUT THE FIX: This would fail with "authkey already used" error + // WITH THE FIX: This succeeds because it's the same node re-registering with its own key + + // Simulate sending the same RegisterRequest again (same MachineKey, same AuthKey) + // This is exactly what happens when a container restarts + reregisterReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: preAuthKeyNew.Key, // Same key, now marked as Used=true + }, + NodeKey: nodeKey.Public(), // Same NodeKey + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "issue-2830-test-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + reregisterResp, err := app.handleRegisterWithAuthKey(reregisterReq, machineKey.Public()) // Same 
MachineKey + require.NoError(t, err, "Re-registration with same MachineKey and used pre-auth key should succeed (fixes #2830)") + require.True(t, reregisterResp.MachineAuthorized, "Node should remain authorized after re-registration") + require.NotNil(t, reregisterResp.User, "User should be set in re-registration response") + require.Equal(t, "testuser", reregisterResp.User.DisplayName, "User should remain the same") + + // Verify that only ONE node was created (not a duplicate) + nodes := app.state.ListNodesByUser(types.UserID(user.ID)) + require.Equal(t, 1, nodes.Len(), "Should have exactly one node (no duplicates created)") + require.Equal(t, "issue-2830-test-node", nodes.At(0).Hostname(), "Node hostname should match") + + // Step 3: Verify that a DIFFERENT machine cannot use the same used key + // This ensures we didn't break the security model - only the original node can re-register + differentMachineKey := key.NewMachine() + differentNodeKey := key.NewNode() + + attackReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: preAuthKeyNew.Key, // Try to use the same key + }, + NodeKey: differentNodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "attacker-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + _, err = app.handleRegisterWithAuthKey(attackReq, differentMachineKey.Public()) + require.Error(t, err, "Different machine should NOT be able to use the same used pre-auth key") + require.Contains(t, err.Error(), "already used", "Error should indicate key is already used") + + // Verify still only one node (the original one) + nodesAfterAttack := app.state.ListNodesByUser(types.UserID(user.ID)) + require.Equal(t, 1, nodesAfterAttack.Len(), "Should still have exactly one node (attack prevented)") +} + +// TestWebAuthRejectsUnauthorizedRequestTags tests that web auth registrations +// validate RequestTags against policy and reject unauthorized tags. 
+func TestWebAuthRejectsUnauthorizedRequestTags(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create a user that will authenticate via web auth + user := app.state.CreateUserForTest("webauth-tags-user") + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Simulate a registration cache entry (as would be created during web auth) + registrationID := types.MustRegistrationID() + regEntry := types.NewRegisterNode(types.Node{ + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + Hostname: "webauth-tags-node", + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "webauth-tags-node", + RequestTags: []string{"tag:unauthorized"}, // This tag is not in policy + }, + }) + app.state.SetRegistrationCacheEntry(registrationID, regEntry) + + // Complete the web auth - should fail because tag is unauthorized + _, _, err := app.state.HandleNodeFromAuthPath( + registrationID, + types.UserID(user.ID), + nil, // no expiry + "webauth", + ) + + // Expect error due to unauthorized tags + require.Error(t, err, "HandleNodeFromAuthPath should reject unauthorized RequestTags") + require.Contains(t, err.Error(), "requested tags", + "Error should indicate requested tags are invalid or not permitted") + require.Contains(t, err.Error(), "tag:unauthorized", + "Error should mention the rejected tag") + + // Verify no node was created + _, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.False(t, found, "Node should not be created when tags are unauthorized") +} + +// TestWebAuthReauthWithEmptyTagsRemovesAllTags tests that when an existing tagged node +// reauths with empty RequestTags, all tags are removed and ownership returns to user. +// This is the fix for issue #2979. +func TestWebAuthReauthWithEmptyTagsRemovesAllTags(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create a user + user := app.state.CreateUserForTest("reauth-untag-user") + + // Update policy manager to recognize the new user + // This is necessary because CreateUserForTest doesn't update the policy manager + err := app.state.UpdatePolicyManagerUsersForTest() + require.NoError(t, err, "Failed to update policy manager users") + + // Set up policy that allows the user to own these tags + policy := `{ + "tagOwners": { + "tag:valid-owned": ["reauth-untag-user@"], + "tag:second": ["reauth-untag-user@"] + }, + "acls": [{"action": "accept", "src": ["*"], "dst": ["*:*"]}] + }` + _, err = app.state.SetPolicy([]byte(policy)) + require.NoError(t, err, "Failed to set policy") + + machineKey := key.NewMachine() + nodeKey1 := key.NewNode() + + // Step 1: Initial registration with tags + registrationID1 := types.MustRegistrationID() + regEntry1 := types.NewRegisterNode(types.Node{ + MachineKey: machineKey.Public(), + NodeKey: nodeKey1.Public(), + Hostname: "reauth-untag-node", + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-untag-node", + RequestTags: []string{"tag:valid-owned", "tag:second"}, + }, + }) + app.state.SetRegistrationCacheEntry(registrationID1, regEntry1) + + // Complete initial registration with tags + node, _, err := app.state.HandleNodeFromAuthPath( + registrationID1, + types.UserID(user.ID), + nil, + "webauth", + ) + require.NoError(t, err, "Initial registration should succeed") + require.True(t, node.IsTagged(), "Node should be tagged after initial registration") + require.ElementsMatch(t, []string{"tag:valid-owned", "tag:second"}, node.Tags().AsSlice()) + t.Logf("Initial registration complete - Node ID: %d, Tags: %v, IsTagged: %t", + node.ID().Uint64(), node.Tags().AsSlice(), 
node.IsTagged()) + + // Step 2: Reauth with EMPTY tags to untag + nodeKey2 := key.NewNode() // New node key for reauth + registrationID2 := types.MustRegistrationID() + regEntry2 := types.NewRegisterNode(types.Node{ + MachineKey: machineKey.Public(), // Same machine key + NodeKey: nodeKey2.Public(), // Different node key (rotation) + Hostname: "reauth-untag-node", + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "reauth-untag-node", + RequestTags: []string{}, // EMPTY - should untag + }, + }) + app.state.SetRegistrationCacheEntry(registrationID2, regEntry2) + + // Complete reauth with empty tags + nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( + registrationID2, + types.UserID(user.ID), + nil, + "webauth", + ) + require.NoError(t, err, "Reauth should succeed") + + // Verify tags were removed + require.False(t, nodeAfterReauth.IsTagged(), "Node should NOT be tagged after reauth with empty tags") + require.Empty(t, nodeAfterReauth.Tags().AsSlice(), "Node should have no tags") + + // Verify ownership returned to user + require.True(t, nodeAfterReauth.UserID().Valid(), "Node should have a user ID") + require.Equal(t, user.ID, nodeAfterReauth.UserID().Get(), "Node should be owned by the user again") + + // Verify it's the same node (not a new one) + require.Equal(t, node.ID(), nodeAfterReauth.ID(), "Should be the same node after reauth") + + t.Logf("Reauth complete - Node ID: %d, Tags: %v, IsTagged: %t, UserID: %d", + nodeAfterReauth.ID().Uint64(), nodeAfterReauth.Tags().AsSlice(), + nodeAfterReauth.IsTagged(), nodeAfterReauth.UserID().Get()) +} + +// TestAuthKeyTaggedToUserOwnedViaReauth tests that a node originally registered +// with a tagged pre-auth key can transition to user-owned by re-authenticating +// via web auth with empty RequestTags. This ensures authkey-tagged nodes are +// not permanently locked to being tagged. 
+func TestAuthKeyTaggedToUserOwnedViaReauth(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create a user + user := app.state.CreateUserForTest("authkey-to-user") + + // Create a tagged pre-auth key + authKeyTags := []string{"tag:server", "tag:prod"} + pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, authKeyTags) + require.NoError(t, err, "Failed to create tagged pre-auth key") + + machineKey := key.NewMachine() + nodeKey1 := key.NewNode() + + // Step 1: Initial registration with tagged pre-auth key + regReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "authkey-tagged-node", + }, + Expiry: time.Now().Add(24 * time.Hour), + } + + resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public()) + require.NoError(t, err, "Initial registration should succeed") + require.True(t, resp.MachineAuthorized, "Node should be authorized") + + // Verify initial state: node is tagged via authkey + node, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found, "Node should be found") + require.True(t, node.IsTagged(), "Node should be tagged after authkey registration") + require.ElementsMatch(t, authKeyTags, node.Tags().AsSlice(), "Node should have authkey tags") + require.NotNil(t, node.AuthKey(), "Node should have AuthKey reference") + require.Positive(t, node.AuthKey().Tags().Len(), "AuthKey should have tags") + + t.Logf("Initial registration complete - Node ID: %d, Tags: %v, IsTagged: %t, AuthKey.Tags.Len: %d", + node.ID().Uint64(), node.Tags().AsSlice(), node.IsTagged(), node.AuthKey().Tags().Len()) + + // Step 2: Reauth via web auth with EMPTY tags to transition to user-owned + nodeKey2 := key.NewNode() // New node key for reauth + registrationID := types.MustRegistrationID() + regEntry := types.NewRegisterNode(types.Node{ + MachineKey: machineKey.Public(), // Same machine key + NodeKey: nodeKey2.Public(), // Different node key (rotation) + Hostname: "authkey-tagged-node", + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "authkey-tagged-node", + RequestTags: []string{}, // EMPTY - should untag + }, + }) + app.state.SetRegistrationCacheEntry(registrationID, regEntry) + + // Complete reauth with empty tags + nodeAfterReauth, _, err := app.state.HandleNodeFromAuthPath( + registrationID, + types.UserID(user.ID), + nil, + "webauth", + ) + require.NoError(t, err, "Reauth should succeed") + + // Verify tags were removed (authkey-tagged → user-owned transition) + require.False(t, nodeAfterReauth.IsTagged(), "Node should NOT be tagged after reauth with empty tags") + require.Empty(t, nodeAfterReauth.Tags().AsSlice(), "Node should have no tags") + + // Verify ownership returned to user + require.True(t, nodeAfterReauth.UserID().Valid(), "Node should have a user ID") + require.Equal(t, user.ID, nodeAfterReauth.UserID().Get(), "Node should be owned by the user") + + // Verify it's the same node (not a new one) + require.Equal(t, node.ID(), nodeAfterReauth.ID(), "Should be the same node after reauth") + + // AuthKey reference should still exist (for audit purposes) + require.NotNil(t, nodeAfterReauth.AuthKey(), "AuthKey reference should be preserved") + + t.Logf("Reauth complete - Node ID: %d, Tags: %v, IsTagged: %t, UserID: %d", + nodeAfterReauth.ID().Uint64(), nodeAfterReauth.Tags().AsSlice(), + nodeAfterReauth.IsTagged(), nodeAfterReauth.UserID().Get()) +} diff --git a/hscontrol/capver/capver.go b/hscontrol/capver/capver.go 
new file mode 100644
index 00000000..61d67444
--- /dev/null
+++ b/hscontrol/capver/capver.go
@@ -0,0 +1,131 @@
+package capver
+
+//go:generate go run ../../tools/capver/main.go
+
+import (
+	"slices"
+	"sort"
+	"strings"
+
+	xmaps "golang.org/x/exp/maps"
+	"tailscale.com/tailcfg"
+	"tailscale.com/util/set"
+)
+
+const (
+	// minVersionParts is the minimum number of version parts needed for major.minor.
+	minVersionParts = 2
+
+	// legacyDERPCapVer is the capability version when LegacyDERP can be cleaned up.
+	legacyDERPCapVer = 111
+)
+
+// CanOldCodeBeCleanedUp is intended to be called on startup to see if
+// there is old code that can be cleaned up; entries should contain
+// a CapVer where something can be cleaned up, and panic if it can.
+// This is only intended to catch things in tests.
+//
+// All uses of Capability version checks should be listed here.
+func CanOldCodeBeCleanedUp() {
+	if MinSupportedCapabilityVersion >= legacyDERPCapVer {
+		panic("LegacyDERP can be cleaned up in tail.go")
+	}
+}
+
+func tailscaleVersSorted() []string {
+	vers := xmaps.Keys(tailscaleToCapVer)
+	sort.Strings(vers)
+
+	return vers
+}
+
+func capVersSorted() []tailcfg.CapabilityVersion {
+	capVers := xmaps.Keys(capVerToTailscaleVer)
+	slices.Sort(capVers)
+
+	return capVers
+}
+
+// TailscaleVersion returns the Tailscale version for the given CapabilityVersion.
+func TailscaleVersion(ver tailcfg.CapabilityVersion) string {
+	return capVerToTailscaleVer[ver]
+}
+
+// CapabilityVersion returns the CapabilityVersion for the given Tailscale version.
+// It accepts both full versions (v1.90.1) and minor versions (v1.90).
+func CapabilityVersion(ver string) tailcfg.CapabilityVersion {
+	if !strings.HasPrefix(ver, "v") {
+		ver = "v" + ver
+	}
+
+	// Try direct lookup first (works for minor versions like v1.90)
+	if cv, ok := tailscaleToCapVer[ver]; ok {
+		return cv
+	}
+
+	// Try extracting minor version from full version (v1.90.1 -> v1.90)
+	parts := strings.Split(strings.TrimPrefix(ver, "v"), ".")
+	if len(parts) >= minVersionParts {
+		minor := "v" + parts[0] + "." + parts[1]
+		return tailscaleToCapVer[minor]
+	}
+
+	return 0
+}
+
+// TailscaleLatest returns the n latest Tailscale versions.
+func TailscaleLatest(n int) []string {
+	if n <= 0 {
+		return nil
+	}
+
+	tsSorted := tailscaleVersSorted()
+
+	if n > len(tsSorted) {
+		return tsSorted
+	}
+
+	return tsSorted[len(tsSorted)-n:]
+}
+
+// TailscaleLatestMajorMinor returns the n latest Tailscale versions (e.g. 1.80).
+func TailscaleLatestMajorMinor(n int, stripV bool) []string {
+	if n <= 0 {
+		return nil
+	}
+
+	majors := set.Set[string]{}
+
+	for _, vers := range tailscaleVersSorted() {
+		if stripV {
+			vers = strings.TrimPrefix(vers, "v")
+		}
+
+		v := strings.Split(vers, ".")
+		majors.Add(v[0] + "." + v[1])
+	}
+
+	majorSl := majors.Slice()
+	sort.Strings(majorSl)
+
+	if n > len(majorSl) {
+		return majorSl
+	}
+
+	return majorSl[len(majorSl)-n:]
+}
+
+// CapVerLatest returns the n latest CapabilityVersions.
+func CapVerLatest(n int) []tailcfg.CapabilityVersion { + if n <= 0 { + return nil + } + + s := capVersSorted() + + if n > len(s) { + return s + } + + return s[len(s)-n:] +} diff --git a/hscontrol/capver/capver_generated.go b/hscontrol/capver/capver_generated.go new file mode 100644 index 00000000..11ad89cc --- /dev/null +++ b/hscontrol/capver/capver_generated.go @@ -0,0 +1,84 @@ +package capver + +// Generated DO NOT EDIT + +import "tailscale.com/tailcfg" + +var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{ + "v1.24": 32, + "v1.26": 32, + "v1.28": 32, + "v1.30": 41, + "v1.32": 46, + "v1.34": 51, + "v1.36": 56, + "v1.38": 58, + "v1.40": 61, + "v1.42": 62, + "v1.44": 63, + "v1.46": 65, + "v1.48": 68, + "v1.50": 74, + "v1.52": 79, + "v1.54": 79, + "v1.56": 82, + "v1.58": 85, + "v1.60": 87, + "v1.62": 88, + "v1.64": 90, + "v1.66": 95, + "v1.68": 97, + "v1.70": 102, + "v1.72": 104, + "v1.74": 106, + "v1.76": 106, + "v1.78": 109, + "v1.80": 113, + "v1.82": 115, + "v1.84": 116, + "v1.86": 123, + "v1.88": 125, + "v1.90": 130, + "v1.92": 131, +} + +var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{ + 32: "v1.24", + 41: "v1.30", + 46: "v1.32", + 51: "v1.34", + 56: "v1.36", + 58: "v1.38", + 61: "v1.40", + 62: "v1.42", + 63: "v1.44", + 65: "v1.46", + 68: "v1.48", + 74: "v1.50", + 79: "v1.52", + 82: "v1.56", + 85: "v1.58", + 87: "v1.60", + 88: "v1.62", + 90: "v1.64", + 95: "v1.66", + 97: "v1.68", + 102: "v1.70", + 104: "v1.72", + 106: "v1.74", + 109: "v1.78", + 113: "v1.80", + 115: "v1.82", + 116: "v1.84", + 123: "v1.86", + 125: "v1.88", + 130: "v1.90", + 131: "v1.92", +} + +// SupportedMajorMinorVersions is the number of major.minor Tailscale versions supported. +const SupportedMajorMinorVersions = 10 + +// MinSupportedCapabilityVersion represents the minimum capability version +// supported by this Headscale instance (latest 10 minor versions) +const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = 106 diff --git a/hscontrol/capver/capver_test.go b/hscontrol/capver/capver_test.go new file mode 100644 index 00000000..5c5d5b44 --- /dev/null +++ b/hscontrol/capver/capver_test.go @@ -0,0 +1,29 @@ +package capver + +import ( + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestTailscaleLatestMajorMinor(t *testing.T) { + for _, test := range tailscaleLatestMajorMinorTests { + t.Run("", func(t *testing.T) { + output := TailscaleLatestMajorMinor(test.n, test.stripV) + if diff := cmp.Diff(output, test.expected); diff != "" { + t.Errorf("TailscaleLatestMajorMinor(%d, %v) mismatch (-want +got):\n%s", test.n, test.stripV, diff) + } + }) + } +} + +func TestCapVerMinimumTailscaleVersion(t *testing.T) { + for _, test := range capVerMinimumTailscaleVersionTests { + t.Run("", func(t *testing.T) { + output := TailscaleVersion(test.input) + if output != test.expected { + t.Errorf("CapVerFromTailscaleVersion(%d) = %s; want %s", test.input, output, test.expected) + } + }) + } +} diff --git a/hscontrol/capver/capver_test_data.go b/hscontrol/capver/capver_test_data.go new file mode 100644 index 00000000..91928d29 --- /dev/null +++ b/hscontrol/capver/capver_test_data.go @@ -0,0 +1,40 @@ +package capver + +// Generated DO NOT EDIT + +import "tailscale.com/tailcfg" + +var tailscaleLatestMajorMinorTests = []struct { + n int + stripV bool + expected []string +}{ + {3, false, []string{"v1.88", "v1.90", "v1.92"}}, + {2, true, []string{"1.90", "1.92"}}, + {10, true, []string{ + "1.74", + "1.76", + "1.78", + "1.80", + "1.82", + "1.84", + "1.86", + "1.88", + "1.90", + "1.92", + }}, + {0, 
false, nil}, +} + +var capVerMinimumTailscaleVersionTests = []struct { + input tailcfg.CapabilityVersion + expected string +}{ + {106, "v1.74"}, + {32, "v1.24"}, + {41, "v1.30"}, + {46, "v1.32"}, + {51, "v1.34"}, + {9001, ""}, // Test case for a version higher than any in the map + {60, ""}, // Test case for a version lower than any in the map +} diff --git a/hscontrol/db/addresses.go b/hscontrol/db/addresses.go deleted file mode 100644 index beccf843..00000000 --- a/hscontrol/db/addresses.go +++ /dev/null @@ -1,99 +0,0 @@ -// Codehere is mostly taken from github.com/tailscale/tailscale -// Copyright (c) 2020 Tailscale Inc & AUTHORS All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package db - -import ( - "errors" - "fmt" - "net/netip" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "go4.org/netipx" -) - -var ErrCouldNotAllocateIP = errors.New("could not find any suitable IP") - -func (hsdb *HSDatabase) getAvailableIPs() (types.NodeAddresses, error) { - var ips types.NodeAddresses - var err error - for _, ipPrefix := range hsdb.ipPrefixes { - var ip *netip.Addr - ip, err = hsdb.getAvailableIP(ipPrefix) - if err != nil { - return ips, err - } - ips = append(ips, *ip) - } - - return ips, err -} - -func (hsdb *HSDatabase) getAvailableIP(ipPrefix netip.Prefix) (*netip.Addr, error) { - usedIps, err := hsdb.getUsedIPs() - if err != nil { - return nil, err - } - - ipPrefixNetworkAddress, ipPrefixBroadcastAddress := util.GetIPPrefixEndpoints(ipPrefix) - - // Get the first IP in our prefix - ip := ipPrefixNetworkAddress.Next() - - for { - if !ipPrefix.Contains(ip) { - return nil, ErrCouldNotAllocateIP - } - - switch { - case ip.Compare(ipPrefixBroadcastAddress) == 0: - fallthrough - case usedIps.Contains(ip): - fallthrough - case ip == netip.Addr{} || ip.IsLoopback(): - ip = ip.Next() - - continue - - default: - return &ip, nil - } - } -} - -func (hsdb *HSDatabase) getUsedIPs() (*netipx.IPSet, error) { - // FIXME: This really deserves a better data model, - // but this was quick to get running and it should be enough - // to begin experimenting with a dual stack tailnet. 
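The capver helpers earlier in this patch (`CapVerLatest`, `TailscaleLatestMajorMinor`, `TailscaleVersion`, and the `MinSupportedCapabilityVersion` constant) are easiest to read with a usage sketch. The `checkClient` wrapper below is invented for illustration only; the capver calls themselves are the ones defined in the generated package above.

```go
package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/capver"
	"tailscale.com/tailcfg"
)

// checkClient rejects clients older than the minimum supported capability
// version and reports which Tailscale releases are still accepted.
func checkClient(v tailcfg.CapabilityVersion) error {
	if v < capver.MinSupportedCapabilityVersion {
		return fmt.Errorf(
			"capability version %d is below the minimum %d (Tailscale %s and newer, i.e. %v)",
			v,
			capver.MinSupportedCapabilityVersion,
			capver.TailscaleVersion(capver.MinSupportedCapabilityVersion),
			capver.TailscaleLatestMajorMinor(capver.SupportedMajorMinorVersions, true),
		)
	}

	return nil
}

func main() {
	fmt.Println(checkClient(95))  // v1.66, below the cutoff: error
	fmt.Println(checkClient(123)) // v1.86, supported: <nil>
}
```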
- var addressesSlices []string - hsdb.db.Model(&types.Node{}).Pluck("ip_addresses", &addressesSlices) - - var ips netipx.IPSetBuilder - for _, slice := range addressesSlices { - var machineAddresses types.NodeAddresses - err := machineAddresses.Scan(slice) - if err != nil { - return &netipx.IPSet{}, fmt.Errorf( - "failed to read ip from database: %w", - err, - ) - } - - for _, ip := range machineAddresses { - ips.Add(ip) - } - } - - ipSet, err := ips.IPSet() - if err != nil { - return &netipx.IPSet{}, fmt.Errorf( - "failed to build IP Set: %w", - err, - ) - } - - return ipSet, nil -} diff --git a/hscontrol/db/addresses_test.go b/hscontrol/db/addresses_test.go deleted file mode 100644 index 07059eab..00000000 --- a/hscontrol/db/addresses_test.go +++ /dev/null @@ -1,182 +0,0 @@ -package db - -import ( - "net/netip" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "go4.org/netipx" - "gopkg.in/check.v1" -) - -func (s *Suite) TestGetAvailableIp(c *check.C) { - ips, err := db.getAvailableIPs() - - c.Assert(err, check.IsNil) - - expected := netip.MustParseAddr("10.27.0.1") - - c.Assert(len(ips), check.Equals, 1) - c.Assert(ips[0].String(), check.Equals, expected.String()) -} - -func (s *Suite) TestGetUsedIps(c *check.C) { - ips, err := db.getAvailableIPs() - c.Assert(err, check.IsNil) - - user, err := db.CreateUser("test-ip") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "testnode") - c.Assert(err, check.NotNil) - - node := types.Node{ - ID: 0, - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - IPAddresses: ips, - } - db.db.Save(&node) - - usedIps, err := db.getUsedIPs() - - c.Assert(err, check.IsNil) - - expected := netip.MustParseAddr("10.27.0.1") - expectedIPSetBuilder := netipx.IPSetBuilder{} - expectedIPSetBuilder.Add(expected) - expectedIPSet, _ := expectedIPSetBuilder.IPSet() - - c.Assert(usedIps.Equal(expectedIPSet), check.Equals, true) - c.Assert(usedIps.Contains(expected), check.Equals, true) - - node1, err := db.GetNodeByID(0) - c.Assert(err, check.IsNil) - - c.Assert(len(node1.IPAddresses), check.Equals, 1) - c.Assert(node1.IPAddresses[0], check.Equals, expected) -} - -func (s *Suite) TestGetMultiIp(c *check.C) { - user, err := db.CreateUser("test-ip-multi") - c.Assert(err, check.IsNil) - - for index := 1; index <= 350; index++ { - db.ipAllocationMutex.Lock() - - ips, err := db.getAvailableIPs() - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "testnode") - c.Assert(err, check.NotNil) - - node := types.Node{ - ID: uint64(index), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - IPAddresses: ips, - } - db.db.Save(&node) - - db.ipAllocationMutex.Unlock() - } - - usedIps, err := db.getUsedIPs() - c.Assert(err, check.IsNil) - - expected0 := netip.MustParseAddr("10.27.0.1") - expected9 := netip.MustParseAddr("10.27.0.10") - expected300 := netip.MustParseAddr("10.27.0.45") - - notExpectedIPSetBuilder := netipx.IPSetBuilder{} - notExpectedIPSetBuilder.Add(expected0) - notExpectedIPSetBuilder.Add(expected9) - notExpectedIPSetBuilder.Add(expected300) - notExpectedIPSet, err := notExpectedIPSetBuilder.IPSet() - c.Assert(err, check.IsNil) - - // We actually expect it to be a lot larger - 
c.Assert(usedIps.Equal(notExpectedIPSet), check.Equals, false) - - c.Assert(usedIps.Contains(expected0), check.Equals, true) - c.Assert(usedIps.Contains(expected9), check.Equals, true) - c.Assert(usedIps.Contains(expected300), check.Equals, true) - - // Check that we can read back the IPs - node1, err := db.GetNodeByID(1) - c.Assert(err, check.IsNil) - c.Assert(len(node1.IPAddresses), check.Equals, 1) - c.Assert( - node1.IPAddresses[0], - check.Equals, - netip.MustParseAddr("10.27.0.1"), - ) - - node50, err := db.GetNodeByID(50) - c.Assert(err, check.IsNil) - c.Assert(len(node50.IPAddresses), check.Equals, 1) - c.Assert( - node50.IPAddresses[0], - check.Equals, - netip.MustParseAddr("10.27.0.50"), - ) - - expectedNextIP := netip.MustParseAddr("10.27.1.95") - nextIP, err := db.getAvailableIPs() - c.Assert(err, check.IsNil) - - c.Assert(len(nextIP), check.Equals, 1) - c.Assert(nextIP[0].String(), check.Equals, expectedNextIP.String()) - - // If we call get Available again, we should receive - // the same IP, as it has not been reserved. - nextIP2, err := db.getAvailableIPs() - c.Assert(err, check.IsNil) - - c.Assert(len(nextIP2), check.Equals, 1) - c.Assert(nextIP2[0].String(), check.Equals, expectedNextIP.String()) -} - -func (s *Suite) TestGetAvailableIpNodeWithoutIP(c *check.C) { - ips, err := db.getAvailableIPs() - c.Assert(err, check.IsNil) - - expected := netip.MustParseAddr("10.27.0.1") - - c.Assert(len(ips), check.Equals, 1) - c.Assert(ips[0].String(), check.Equals, expected.String()) - - user, err := db.CreateUser("test-ip") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "testnode") - c.Assert(err, check.NotNil) - - node := types.Node{ - ID: 0, - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(&node) - - ips2, err := db.getAvailableIPs() - c.Assert(err, check.IsNil) - - c.Assert(len(ips2), check.Equals, 1) - c.Assert(ips2[0].String(), check.Equals, expected.String()) -} diff --git a/hscontrol/db/api_key.go b/hscontrol/db/api_key.go index bc8dc2bb..7457670c 100644 --- a/hscontrol/db/api_key.go +++ b/hscontrol/db/api_key.go @@ -9,36 +9,64 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "golang.org/x/crypto/bcrypt" + "gorm.io/gorm" ) const ( - apiPrefixLength = 7 - apiKeyLength = 32 + apiKeyPrefix = "hskey-api-" //nolint:gosec // This is a prefix, not a credential + apiKeyPrefixLength = 12 + apiKeyHashLength = 64 + + // Legacy format constants. + legacyAPIPrefixLength = 7 + legacyAPIKeyLength = 32 ) -var ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") +var ( + ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey") + ErrAPIKeyGenerationFailed = errors.New("failed to generate API key") + ErrAPIKeyInvalidGeneration = errors.New("generated API key failed validation") +) // CreateAPIKey creates a new ApiKey in a user, and returns it. 
func (hsdb *HSDatabase) CreateAPIKey( expiration *time.Time, ) (string, *types.APIKey, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - prefix, err := util.GenerateRandomStringURLSafe(apiPrefixLength) + // Generate public prefix (12 chars) + prefix, err := util.GenerateRandomStringURLSafe(apiKeyPrefixLength) if err != nil { return "", nil, err } - toBeHashed, err := util.GenerateRandomStringURLSafe(apiKeyLength) + // Validate prefix + if len(prefix) != apiKeyPrefixLength { + return "", nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrAPIKeyInvalidGeneration, apiKeyPrefixLength, len(prefix)) + } + + if !isValidBase64URLSafe(prefix) { + return "", nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrAPIKeyInvalidGeneration) + } + + // Generate secret (64 chars) + secret, err := util.GenerateRandomStringURLSafe(apiKeyHashLength) if err != nil { return "", nil, err } - // Key to return to user, this will only be visible _once_ - keyStr := prefix + "." + toBeHashed + // Validate secret + if len(secret) != apiKeyHashLength { + return "", nil, fmt.Errorf("%w: generated secret has invalid length: expected %d, got %d", ErrAPIKeyInvalidGeneration, apiKeyHashLength, len(secret)) + } - hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost) + if !isValidBase64URLSafe(secret) { + return "", nil, fmt.Errorf("%w: generated secret contains invalid characters", ErrAPIKeyInvalidGeneration) + } + + // Full key string (shown ONCE to user) + keyStr := apiKeyPrefix + prefix + "-" + secret + + // bcrypt hash of secret + hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost) if err != nil { return "", nil, err } @@ -49,7 +77,7 @@ func (hsdb *HSDatabase) CreateAPIKey( Expiration: expiration, } - if err := hsdb.db.Save(&key).Error; err != nil { + if err := hsdb.DB.Save(&key).Error; err != nil { return "", nil, fmt.Errorf("failed to save API key to database: %w", err) } @@ -58,11 +86,8 @@ func (hsdb *HSDatabase) CreateAPIKey( // ListAPIKeys returns the list of ApiKeys for a user. func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - keys := []types.APIKey{} - if err := hsdb.db.Find(&keys).Error; err != nil { + if err := hsdb.DB.Find(&keys).Error; err != nil { return nil, err } @@ -71,11 +96,8 @@ func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) { // GetAPIKey returns a ApiKey for a given key. func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - key := types.APIKey{} - if result := hsdb.db.First(&key, "prefix = ?", prefix); result.Error != nil { + if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil { return nil, result.Error } @@ -84,11 +106,8 @@ func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) { // GetAPIKeyByID returns a ApiKey for a given id. func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - key := types.APIKey{} - if result := hsdb.db.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { + if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil { return nil, result.Error } @@ -98,10 +117,7 @@ func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) { // DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey // does not exist. 
func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if result := hsdb.db.Unscoped().Delete(key); result.Error != nil { + if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil { return result.Error } @@ -110,10 +126,7 @@ func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error { // ExpireAPIKey marks a ApiKey as expired. func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if err := hsdb.db.Model(&key).Update("Expiration", time.Now()).Error; err != nil { + if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil { return err } @@ -121,26 +134,164 @@ func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error { } func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - prefix, hash, found := strings.Cut(keyStr, ".") - if !found { - return false, ErrAPIKeyFailedToParse - } - - key, err := hsdb.GetAPIKey(prefix) + key, err := validateAPIKey(hsdb.DB, keyStr) if err != nil { - return false, fmt.Errorf("failed to validate api key: %w", err) - } - - if key.Expiration.Before(time.Now()) { - return false, nil - } - - if err := bcrypt.CompareHashAndPassword(key.Hash, []byte(hash)); err != nil { return false, err } + if key.Expiration != nil && key.Expiration.Before(time.Now()) { + return false, nil + } + return true, nil } + +// ParseAPIKeyPrefix extracts the database prefix from a display prefix. +// Handles formats: "hskey-api-{12chars}-***", "hskey-api-{12chars}", or just "{12chars}". +// Returns the 12-character prefix suitable for database lookup. +func ParseAPIKeyPrefix(displayPrefix string) (string, error) { + // If it's already just the 12-character prefix, return it + if len(displayPrefix) == apiKeyPrefixLength && isValidBase64URLSafe(displayPrefix) { + return displayPrefix, nil + } + + // If it starts with the API key prefix, parse it + if strings.HasPrefix(displayPrefix, apiKeyPrefix) { + // Remove the "hskey-api-" prefix + _, remainder, found := strings.Cut(displayPrefix, apiKeyPrefix) + if !found { + return "", fmt.Errorf("%w: invalid display prefix format", ErrAPIKeyFailedToParse) + } + + // Extract just the first 12 characters (the actual prefix) + if len(remainder) < apiKeyPrefixLength { + return "", fmt.Errorf("%w: prefix too short", ErrAPIKeyFailedToParse) + } + + prefix := remainder[:apiKeyPrefixLength] + + // Validate it's base64 URL-safe + if !isValidBase64URLSafe(prefix) { + return "", fmt.Errorf("%w: prefix contains invalid characters", ErrAPIKeyFailedToParse) + } + + return prefix, nil + } + + // For legacy 7-character prefixes or other formats, return as-is + return displayPrefix, nil +} + +// validateAPIKey validates an API key and returns the key if valid. +// Handles both new (hskey-api-{prefix}-{secret}) and legacy (prefix.secret) formats. 
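Before the parsing code that follows, a usage-level sketch may help make the new key format concrete. The `MintAndCheck` wrapper and its package are assumptions for illustration; `CreateAPIKey`, `ParseAPIKeyPrefix`, and `ValidateAPIKey` are the functions added or changed in this patch.

```go
package apikeysketch

import (
	"time"

	"github.com/juanfont/headscale/hscontrol/db"
)

// MintAndCheck creates an API key, shows the one-time string form, and
// validates it again through the database.
func MintAndCheck(hsdb *db.HSDatabase) (string, bool, error) {
	expiry := time.Now().Add(90 * 24 * time.Hour)

	// keyStr is "hskey-api-" + <12-char prefix> + "-" + <64-char secret>.
	// Only the prefix and a bcrypt hash of the secret are persisted, so the
	// full string can never be recovered later.
	keyStr, apiKey, err := hsdb.CreateAPIKey(&expiry)
	if err != nil {
		return "", false, err
	}

	// A redacted display form such as "hskey-api-<prefix>-***" maps back to
	// the stored 12-character prefix used for database lookups.
	if _, err := db.ParseAPIKeyPrefix("hskey-api-" + apiKey.Prefix + "-***"); err != nil {
		return "", false, err
	}

	valid, err := hsdb.ValidateAPIKey(keyStr)

	return keyStr, valid, err
}
```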
+func validateAPIKey(db *gorm.DB, keyStr string) (*types.APIKey, error) { + // Validate input is not empty + if keyStr == "" { + return nil, ErrAPIKeyFailedToParse + } + + // Check for new format: hskey-api-{prefix}-{secret} + _, prefixAndSecret, found := strings.Cut(keyStr, apiKeyPrefix) + + if !found { + // Legacy format: prefix.secret + return validateLegacyAPIKey(db, keyStr) + } + + // New format: parse and verify + const expectedMinLength = apiKeyPrefixLength + 1 + apiKeyHashLength + if len(prefixAndSecret) < expectedMinLength { + return nil, fmt.Errorf( + "%w: key too short, expected at least %d chars after prefix, got %d", + ErrAPIKeyFailedToParse, + expectedMinLength, + len(prefixAndSecret), + ) + } + + // Use fixed-length parsing + prefix := prefixAndSecret[:apiKeyPrefixLength] + + // Validate separator at expected position + if prefixAndSecret[apiKeyPrefixLength] != '-' { + return nil, fmt.Errorf( + "%w: expected separator '-' at position %d, got '%c'", + ErrAPIKeyFailedToParse, + apiKeyPrefixLength, + prefixAndSecret[apiKeyPrefixLength], + ) + } + + secret := prefixAndSecret[apiKeyPrefixLength+1:] + + // Validate secret length + if len(secret) != apiKeyHashLength { + return nil, fmt.Errorf( + "%w: secret length mismatch, expected %d chars, got %d", + ErrAPIKeyFailedToParse, + apiKeyHashLength, + len(secret), + ) + } + + // Validate prefix contains only base64 URL-safe characters + if !isValidBase64URLSafe(prefix) { + return nil, fmt.Errorf( + "%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", + ErrAPIKeyFailedToParse, + ) + } + + // Validate secret contains only base64 URL-safe characters + if !isValidBase64URLSafe(secret) { + return nil, fmt.Errorf( + "%w: secret contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", + ErrAPIKeyFailedToParse, + ) + } + + // Look up by prefix (indexed) + var key types.APIKey + + err := db.First(&key, "prefix = ?", prefix).Error + if err != nil { + return nil, fmt.Errorf("API key not found: %w", err) + } + + // Verify bcrypt hash + err = bcrypt.CompareHashAndPassword(key.Hash, []byte(secret)) + if err != nil { + return nil, fmt.Errorf("invalid API key: %w", err) + } + + return &key, nil +} + +// validateLegacyAPIKey validates a legacy format API key (prefix.secret). +func validateLegacyAPIKey(db *gorm.DB, keyStr string) (*types.APIKey, error) { + // Legacy format uses "." 
as separator + prefix, secret, found := strings.Cut(keyStr, ".") + if !found { + return nil, ErrAPIKeyFailedToParse + } + + // Legacy prefix is 7 chars + if len(prefix) != legacyAPIPrefixLength { + return nil, fmt.Errorf("%w: legacy prefix length mismatch", ErrAPIKeyFailedToParse) + } + + var key types.APIKey + + err := db.First(&key, "prefix = ?", prefix).Error + if err != nil { + return nil, fmt.Errorf("API key not found: %w", err) + } + + // Verify bcrypt (key.Hash stores bcrypt of full secret) + err = bcrypt.CompareHashAndPassword(key.Hash, []byte(secret)) + if err != nil { + return nil, fmt.Errorf("invalid API key: %w", err) + } + + return &key, nil +} diff --git a/hscontrol/db/api_key_test.go b/hscontrol/db/api_key_test.go index c0b4e988..a34dd94b 100644 --- a/hscontrol/db/api_key_test.go +++ b/hscontrol/db/api_key_test.go @@ -1,89 +1,275 @@ package db import ( + "strings" + "testing" "time" - "gopkg.in/check.v1" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" ) -func (*Suite) TestCreateAPIKey(c *check.C) { +func TestCreateAPIKey(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + apiKeyStr, apiKey, err := db.CreateAPIKey(nil) - c.Assert(err, check.IsNil) - c.Assert(apiKey, check.NotNil) + require.NoError(t, err) + require.NotNil(t, apiKey) // Did we get a valid key? - c.Assert(apiKey.Prefix, check.NotNil) - c.Assert(apiKey.Hash, check.NotNil) - c.Assert(apiKeyStr, check.Not(check.Equals), "") + assert.NotNil(t, apiKey.Prefix) + assert.NotNil(t, apiKey.Hash) + assert.NotEmpty(t, apiKeyStr) _, err = db.ListAPIKeys() - c.Assert(err, check.IsNil) + require.NoError(t, err) keys, err := db.ListAPIKeys() - c.Assert(err, check.IsNil) - c.Assert(len(keys), check.Equals, 1) + require.NoError(t, err) + assert.Len(t, keys, 1) } -func (*Suite) TestAPIKeyDoesNotExist(c *check.C) { +func TestAPIKeyDoesNotExist(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + key, err := db.GetAPIKey("does-not-exist") - c.Assert(err, check.NotNil) - c.Assert(key, check.IsNil) + require.Error(t, err) + assert.Nil(t, key) } -func (*Suite) TestValidateAPIKeyOk(c *check.C) { +func TestValidateAPIKeyOk(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + nowPlus2 := time.Now().Add(2 * time.Hour) apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) - c.Assert(err, check.IsNil) - c.Assert(apiKey, check.NotNil) + require.NoError(t, err) + require.NotNil(t, apiKey) valid, err := db.ValidateAPIKey(apiKeyStr) - c.Assert(err, check.IsNil) - c.Assert(valid, check.Equals, true) + require.NoError(t, err) + assert.True(t, valid) } -func (*Suite) TestValidateAPIKeyNotOk(c *check.C) { +func TestValidateAPIKeyNotOk(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + nowMinus2 := time.Now().Add(time.Duration(-2) * time.Hour) apiKeyStr, apiKey, err := db.CreateAPIKey(&nowMinus2) - c.Assert(err, check.IsNil) - c.Assert(apiKey, check.NotNil) + require.NoError(t, err) + require.NotNil(t, apiKey) valid, err := db.ValidateAPIKey(apiKeyStr) - c.Assert(err, check.IsNil) - c.Assert(valid, check.Equals, false) + require.NoError(t, err) + assert.False(t, valid) now := time.Now() apiKeyStrNow, apiKey, err := db.CreateAPIKey(&now) - c.Assert(err, check.IsNil) - c.Assert(apiKey, check.NotNil) + require.NoError(t, err) + require.NotNil(t, apiKey) validNow, err := db.ValidateAPIKey(apiKeyStrNow) - c.Assert(err, check.IsNil) - c.Assert(validNow, 
check.Equals, false) + require.NoError(t, err) + assert.False(t, validNow) validSilly, err := db.ValidateAPIKey("nota.validkey") - c.Assert(err, check.NotNil) - c.Assert(validSilly, check.Equals, false) + require.Error(t, err) + assert.False(t, validSilly) validWithErr, err := db.ValidateAPIKey("produceerrorkey") - c.Assert(err, check.NotNil) - c.Assert(validWithErr, check.Equals, false) + require.Error(t, err) + assert.False(t, validWithErr) } -func (*Suite) TestExpireAPIKey(c *check.C) { +func TestExpireAPIKey(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + nowPlus2 := time.Now().Add(2 * time.Hour) apiKeyStr, apiKey, err := db.CreateAPIKey(&nowPlus2) - c.Assert(err, check.IsNil) - c.Assert(apiKey, check.NotNil) + require.NoError(t, err) + require.NotNil(t, apiKey) valid, err := db.ValidateAPIKey(apiKeyStr) - c.Assert(err, check.IsNil) - c.Assert(valid, check.Equals, true) + require.NoError(t, err) + assert.True(t, valid) err = db.ExpireAPIKey(apiKey) - c.Assert(err, check.IsNil) - c.Assert(apiKey.Expiration, check.NotNil) + require.NoError(t, err) + assert.NotNil(t, apiKey.Expiration) notValid, err := db.ValidateAPIKey(apiKeyStr) - c.Assert(err, check.IsNil) - c.Assert(notValid, check.Equals, false) + require.NoError(t, err) + assert.False(t, notValid) +} + +func TestAPIKeyWithPrefix(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "new_key_with_prefix", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + keyStr, apiKey, err := db.CreateAPIKey(nil) + require.NoError(t, err) + + // Verify format: hskey-api-{12-char-prefix}-{64-char-secret} + assert.True(t, strings.HasPrefix(keyStr, "hskey-api-")) + + _, prefixAndSecret, found := strings.Cut(keyStr, "hskey-api-") + assert.True(t, found) + assert.GreaterOrEqual(t, len(prefixAndSecret), 12+1+64) + + prefix := prefixAndSecret[:12] + assert.Len(t, prefix, 12) + assert.Equal(t, byte('-'), prefixAndSecret[12]) + secret := prefixAndSecret[13:] + assert.Len(t, secret, 64) + + // Verify stored fields + assert.Len(t, apiKey.Prefix, types.NewAPIKeyPrefixLength) + assert.NotNil(t, apiKey.Hash) + }, + }, + { + name: "new_key_can_be_retrieved", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + keyStr, createdKey, err := db.CreateAPIKey(nil) + require.NoError(t, err) + + // Validate the created key + valid, err := db.ValidateAPIKey(keyStr) + require.NoError(t, err) + assert.True(t, valid) + + // Verify prefix is correct length + assert.Len(t, createdKey.Prefix, types.NewAPIKeyPrefixLength) + }, + }, + { + name: "invalid_key_format_rejected", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + invalidKeys := []string{ + "", + "hskey-api-short", + "hskey-api-ABCDEFGHIJKL-tooshort", + "hskey-api-ABC$EFGHIJKL-" + strings.Repeat("a", 64), + "hskey-api-ABCDEFGHIJKL" + strings.Repeat("a", 64), // missing separator + } + + for _, invalidKey := range invalidKeys { + valid, err := db.ValidateAPIKey(invalidKey) + require.Error(t, err, "key should be rejected: %s", invalidKey) + assert.False(t, valid) + } + }, + }, + { + name: "legacy_key_still_works", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + // Insert legacy API key directly (7-char prefix + 32-char secret) + legacyPrefix := "abcdefg" + legacySecret := strings.Repeat("x", 32) + legacyKey := legacyPrefix + "." 
+ legacySecret + hash, err := bcrypt.GenerateFromPassword([]byte(legacySecret), bcrypt.DefaultCost) + require.NoError(t, err) + + now := time.Now() + err = db.DB.Exec(` + INSERT INTO api_keys (prefix, hash, created_at) + VALUES (?, ?, ?) + `, legacyPrefix, hash, now).Error + require.NoError(t, err) + + // Validate legacy key + valid, err := db.ValidateAPIKey(legacyKey) + require.NoError(t, err) + assert.True(t, valid) + }, + }, + { + name: "wrong_secret_rejected", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + keyStr, _, err := db.CreateAPIKey(nil) + require.NoError(t, err) + + // Tamper with the secret + _, prefixAndSecret, _ := strings.Cut(keyStr, "hskey-api-") + prefix := prefixAndSecret[:12] + tamperedKey := "hskey-api-" + prefix + "-" + strings.Repeat("x", 64) + + valid, err := db.ValidateAPIKey(tamperedKey) + require.Error(t, err) + assert.False(t, valid) + }, + }, + { + name: "expired_key_rejected", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + // Create expired key + expired := time.Now().Add(-1 * time.Hour) + keyStr, _, err := db.CreateAPIKey(&expired) + require.NoError(t, err) + + // Should fail validation + valid, err := db.ValidateAPIKey(keyStr) + require.NoError(t, err) + assert.False(t, valid) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + tt.test(t, db) + }) + } +} + +func TestGetAPIKeyByID(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + // Create an API key + _, apiKey, err := db.CreateAPIKey(nil) + require.NoError(t, err) + require.NotNil(t, apiKey) + + // Retrieve by ID + retrievedKey, err := db.GetAPIKeyByID(apiKey.ID) + require.NoError(t, err) + require.NotNil(t, retrievedKey) + assert.Equal(t, apiKey.ID, retrievedKey.ID) + assert.Equal(t, apiKey.Prefix, retrievedKey.Prefix) +} + +func TestGetAPIKeyByIDNotFound(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + // Try to get a non-existent key by ID + key, err := db.GetAPIKeyByID(99999) + require.Error(t, err) + assert.Nil(t, key) } diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 030a6f0b..a1429aa6 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -2,355 +2,841 @@ package db import ( "context" - "database/sql" + _ "embed" + "encoding/json" "errors" "fmt" "net/netip" - "strings" - "sync" + "path/filepath" + "slices" + "strconv" "time" "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" - "github.com/juanfont/headscale/hscontrol/notifier" + "github.com/juanfont/headscale/hscontrol/db/sqliteconfig" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" + "github.com/tailscale/squibble" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/logger" + "gorm.io/gorm/schema" + "tailscale.com/net/tsaddr" + "zgo.at/zcache/v2" ) -const ( - Postgres = "postgres" - Sqlite = "sqlite3" -) +//go:embed schema.sql +var dbSchema string + +func init() { + schema.RegisterSerializer("text", TextSerialiser{}) +} var errDatabaseNotSupported = errors.New("database type not supported") -// KV is a key-value store in a psql table. For future use... -// TODO(kradalby): Is this used for anything? 
-type KV struct { - Key string - Value string -} +var errForeignKeyConstraintsViolated = errors.New("foreign key constraints violated") + +const ( + maxIdleConns = 100 + maxOpenConns = 100 + contextTimeoutSecs = 10 +) type HSDatabase struct { - db *gorm.DB - notifier *notifier.Notifier - - mu sync.RWMutex - - ipAllocationMutex sync.Mutex - - ipPrefixes []netip.Prefix - baseDomain string + DB *gorm.DB + cfg *types.Config + regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] } -// TODO(kradalby): assemble this struct from toptions or something typed -// rather than arguments. +// NewHeadscaleDatabase creates a new database connection and runs migrations. +// It accepts the full configuration to allow migrations access to policy settings. func NewHeadscaleDatabase( - dbType, connectionAddr string, - debug bool, - notifier *notifier.Notifier, - ipPrefixes []netip.Prefix, - baseDomain string, + cfg *types.Config, + regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], ) (*HSDatabase, error) { - dbConn, err := openDB(dbType, connectionAddr, debug) + dbConn, err := openDB(cfg.Database) if err != nil { return nil, err } - migrations := gormigrate.New(dbConn, gormigrate.DefaultOptions, []*gormigrate.Migration{ - // New migrations should be added as transactions at the end of this list. - // The initial commit here is quite messy, completely out of order and - // has no versioning and is the tech debt of not having versioned migrations - // prior to this point. This first migration is all DB changes to bring a DB - // up to 0.23.0. - { - ID: "202312101416", - Migrate: func(tx *gorm.DB) error { - if dbType == Postgres { - tx.Exec(`create extension if not exists "uuid-ossp";`) - } + migrations := gormigrate.New( + dbConn, + gormigrate.DefaultOptions, + []*gormigrate.Migration{ + // New migrations must be added as transactions at the end of this list. + // Migrations start from v0.25.0. If upgrading from v0.24.x or earlier, + // you must first upgrade to v0.25.1 before upgrading to this version. - _ = tx.Migrator().RenameTable("namespaces", "users") - - // the big rename from Machine to Node - _ = tx.Migrator().RenameTable("machines", "nodes") - _ = tx.Migrator().RenameColumn(&types.Route{}, "machine_id", "node_id") - - err = tx.AutoMigrate(types.User{}) - if err != nil { - return err - } - - _ = tx.Migrator().RenameColumn(&types.Node{}, "namespace_id", "user_id") - _ = tx.Migrator().RenameColumn(&types.PreAuthKey{}, "namespace_id", "user_id") - - _ = tx.Migrator().RenameColumn(&types.Node{}, "ip_address", "ip_addresses") - _ = tx.Migrator().RenameColumn(&types.Node{}, "name", "hostname") - - // GivenName is used as the primary source of DNS names, make sure - // the field is populated and normalized if it was not when the - // node was registered. - _ = tx.Migrator().RenameColumn(&types.Node{}, "nickname", "given_name") - - // If the Node table has a column for registered, - // find all occourences of "false" and drop them. Then - // remove the column. - if tx.Migrator().HasColumn(&types.Node{}, "registered") { - log.Info(). - Msg(`Database has legacy "registered" column in node, removing...`) - - nodes := types.Nodes{} - if err := tx.Not("registered").Find(&nodes).Error; err != nil { - log.Error().Err(err).Msg("Error accessing db") + // v0.25.0 + { + // Add a constraint to routes ensuring they cannot exist without a node. + ID: "202501221827", + Migrate: func(tx *gorm.DB) error { + // Remove any invalid routes associated with a node that does not exist. 
+ if tx.Migrator().HasTable(&types.Route{}) && tx.Migrator().HasTable(&types.Node{}) { + err := tx.Exec("delete from routes where node_id not in (select id from nodes)").Error + if err != nil { + return err + } } - for _, node := range nodes { - log.Info(). - Str("node", node.Hostname). - Str("machine_key", node.MachineKey.ShortString()). - Msg("Deleting unregistered node") - if err := tx.Delete(&types.Node{}, node.ID).Error; err != nil { - log.Error(). + // Remove any invalid routes without a node_id. + if tx.Migrator().HasTable(&types.Route{}) { + err := tx.Exec("delete from routes where node_id is null").Error + if err != nil { + return err + } + } + + err := tx.AutoMigrate(&types.Route{}) + if err != nil { + return fmt.Errorf("automigrating types.Route: %w", err) + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + // Add back constraint so you cannot delete preauth keys that + // is still used by a node. + { + ID: "202501311657", + Migrate: func(tx *gorm.DB) error { + err := tx.AutoMigrate(&types.PreAuthKey{}) + if err != nil { + return fmt.Errorf("automigrating types.PreAuthKey: %w", err) + } + err = tx.AutoMigrate(&types.Node{}) + if err != nil { + return fmt.Errorf("automigrating types.Node: %w", err) + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + // Ensure there are no nodes referring to a deleted preauthkey. + { + ID: "202502070949", + Migrate: func(tx *gorm.DB) error { + if tx.Migrator().HasTable(&types.PreAuthKey{}) { + err := tx.Exec(` +UPDATE nodes +SET auth_key_id = NULL +WHERE auth_key_id IS NOT NULL +AND auth_key_id NOT IN ( + SELECT id FROM pre_auth_keys +); + `).Error + if err != nil { + return fmt.Errorf("setting auth_key to null on nodes with non-existing keys: %w", err) + } + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + // v0.26.0 + // Migrate all routes from the Route table to the new field ApprovedRoutes + // in the Node table. Then drop the Route table. + { + ID: "202502131714", + Migrate: func(tx *gorm.DB) error { + if !tx.Migrator().HasColumn(&types.Node{}, "approved_routes") { + err := tx.Migrator().AddColumn(&types.Node{}, "approved_routes") + if err != nil { + return fmt.Errorf("adding column types.Node: %w", err) + } + } + + nodeRoutes := map[uint64][]netip.Prefix{} + + var routes []types.Route + err = tx.Find(&routes).Error + if err != nil { + return fmt.Errorf("fetching routes: %w", err) + } + + for _, route := range routes { + if route.Enabled { + nodeRoutes[route.NodeID] = append(nodeRoutes[route.NodeID], route.Prefix) + } + } + + for nodeID, routes := range nodeRoutes { + tsaddr.SortPrefixes(routes) + routes = slices.Compact(routes) + + data, err := json.Marshal(routes) + + err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", data).Error + if err != nil { + return fmt.Errorf("saving approved routes to new column: %w", err) + } + } + + // Drop the old table. + _ = tx.Migrator().DropTable(&types.Route{}) + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + { + ID: "202502171819", + Migrate: func(tx *gorm.DB) error { + // This migration originally removed the last_seen column + // from the node table, but it was added back in + // 202505091439. + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + // Add back last_seen column to node table. + { + ID: "202505091439", + Migrate: func(tx *gorm.DB) error { + // Add back last_seen column to node table if it does not exist. 
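The routes-to-approved_routes migration (202502131714) above boils down to a small data transformation. A standalone worked example, with invented values that are not part of the patch, of what ends up in the new column:

```go
package main

import (
	"encoding/json"
	"fmt"
	"net/netip"
	"slices"

	"tailscale.com/net/tsaddr"
)

func main() {
	// Two formerly enabled routes plus a duplicate, as they might have
	// existed in the old routes table for a single node.
	routes := []netip.Prefix{
		netip.MustParsePrefix("10.0.0.0/24"),
		netip.MustParsePrefix("0.0.0.0/0"),
		netip.MustParsePrefix("10.0.0.0/24"),
	}

	// Same normalisation the migration performs: sort, then de-duplicate.
	tsaddr.SortPrefixes(routes)
	routes = slices.Compact(routes)

	data, err := json.Marshal(routes)
	if err != nil {
		panic(err)
	}

	// Prints a JSON array such as ["0.0.0.0/0","10.0.0.0/24"], which is the
	// value written to nodes.approved_routes.
	fmt.Println(string(data))
}
```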
+ // This is a workaround for the fact that the last_seen column + // was removed in the 202502171819 migration, but only for some + // beta testers. + if !tx.Migrator().HasColumn(&types.Node{}, "last_seen") { + _ = tx.Migrator().AddColumn(&types.Node{}, "last_seen") + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + // Fix the provider identifier for users that have a double slash in the + // provider identifier. + { + ID: "202505141324", + Migrate: func(tx *gorm.DB) error { + users, err := ListUsers(tx) + if err != nil { + return fmt.Errorf("listing users: %w", err) + } + + for _, user := range users { + user.ProviderIdentifier.String = types.CleanIdentifier(user.ProviderIdentifier.String) + + err := tx.Save(user).Error + if err != nil { + return fmt.Errorf("saving user: %w", err) + } + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + // v0.27.0 + // Schema migration to ensure all tables match the expected schema. + // This migration recreates all tables to match the exact structure in schema.sql, + // preserving all data during the process. + // Only SQLite will be migrated for consistency. + { + ID: "202507021200", + Migrate: func(tx *gorm.DB) error { + // Only run on SQLite + if cfg.Database.Type != types.DatabaseSqlite { + log.Info().Msg("Skipping schema migration on non-SQLite database") + return nil + } + + log.Info().Msg("Starting schema recreation with table renaming") + + // Rename existing tables to _old versions + tablesToRename := []string{"users", "pre_auth_keys", "api_keys", "nodes", "policies"} + + // Check if routes table exists and drop it (should have been migrated already) + var routesExists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='routes'").Row().Scan(&routesExists) + if err == nil && routesExists { + log.Info().Msg("Dropping leftover routes table") + if err := tx.Exec("DROP TABLE routes").Error; err != nil { + return fmt.Errorf("dropping routes table: %w", err) + } + } + + // Drop all indexes first to avoid conflicts + indexesToDrop := []string{ + "idx_users_deleted_at", + "idx_provider_identifier", + "idx_name_provider_identifier", + "idx_name_no_provider_identifier", + "idx_api_keys_prefix", + "idx_policies_deleted_at", + } + + for _, index := range indexesToDrop { + _ = tx.Exec("DROP INDEX IF EXISTS " + index).Error + } + + for _, table := range tablesToRename { + // Check if table exists before renaming + var exists bool + err := tx.Raw("SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name=?", table).Row().Scan(&exists) + if err != nil { + return fmt.Errorf("checking if table %s exists: %w", table, err) + } + + if exists { + // Drop old table if it exists from previous failed migration + _ = tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error + + // Rename current table to _old + if err := tx.Exec("ALTER TABLE " + table + " RENAME TO " + table + "_old").Error; err != nil { + return fmt.Errorf("renaming table %s to %s_old: %w", table, table, err) + } + } + } + + // Create new tables with correct schema + tableCreationSQL := []string{ + `CREATE TABLE users( + id integer PRIMARY KEY AUTOINCREMENT, + name text, + display_name text, + email text, + provider_identifier text, + provider text, + profile_pic_url text, + created_at datetime, + updated_at datetime, + deleted_at datetime +)`, + `CREATE TABLE pre_auth_keys( + id integer PRIMARY KEY AUTOINCREMENT, + key text, + user_id integer, + reusable numeric, + ephemeral numeric DEFAULT false, + used 
numeric DEFAULT false, + tags text, + expiration datetime, + created_at datetime, + CONSTRAINT fk_pre_auth_keys_user FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE SET NULL +)`, + `CREATE TABLE api_keys( + id integer PRIMARY KEY AUTOINCREMENT, + prefix text, + hash blob, + expiration datetime, + last_seen datetime, + created_at datetime +)`, + `CREATE TABLE nodes( + id integer PRIMARY KEY AUTOINCREMENT, + machine_key text, + node_key text, + disco_key text, + endpoints text, + host_info text, + ipv4 text, + ipv6 text, + hostname text, + given_name varchar(63), + user_id integer, + register_method text, + forced_tags text, + auth_key_id integer, + last_seen datetime, + expiry datetime, + approved_routes text, + created_at datetime, + updated_at datetime, + deleted_at datetime, + CONSTRAINT fk_nodes_user FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT fk_nodes_auth_key FOREIGN KEY(auth_key_id) REFERENCES pre_auth_keys(id) +)`, + `CREATE TABLE policies( + id integer PRIMARY KEY AUTOINCREMENT, + data text, + created_at datetime, + updated_at datetime, + deleted_at datetime +)`, + } + + for _, createSQL := range tableCreationSQL { + if err := tx.Exec(createSQL).Error; err != nil { + return fmt.Errorf("creating new table: %w", err) + } + } + + // Copy data directly using SQL + dataCopySQL := []string{ + `INSERT INTO users (id, name, display_name, email, provider_identifier, provider, profile_pic_url, created_at, updated_at, deleted_at) + SELECT id, name, display_name, email, provider_identifier, provider, profile_pic_url, created_at, updated_at, deleted_at + FROM users_old`, + + `INSERT INTO pre_auth_keys (id, key, user_id, reusable, ephemeral, used, tags, expiration, created_at) + SELECT id, key, user_id, reusable, ephemeral, used, tags, expiration, created_at + FROM pre_auth_keys_old`, + + `INSERT INTO api_keys (id, prefix, hash, expiration, last_seen, created_at) + SELECT id, prefix, hash, expiration, last_seen, created_at + FROM api_keys_old`, + + `INSERT INTO nodes (id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at) + SELECT id, machine_key, node_key, disco_key, endpoints, host_info, ipv4, ipv6, hostname, given_name, user_id, register_method, forced_tags, auth_key_id, last_seen, expiry, approved_routes, created_at, updated_at, deleted_at + FROM nodes_old`, + + `INSERT INTO policies (id, data, created_at, updated_at, deleted_at) + SELECT id, data, created_at, updated_at, deleted_at + FROM policies_old`, + } + + for _, copySQL := range dataCopySQL { + if err := tx.Exec(copySQL).Error; err != nil { + return fmt.Errorf("copying data: %w", err) + } + } + + // Create indexes + indexes := []string{ + "CREATE INDEX idx_users_deleted_at ON users(deleted_at)", + `CREATE UNIQUE INDEX idx_provider_identifier ON users( + provider_identifier +) WHERE provider_identifier IS NOT NULL`, + `CREATE UNIQUE INDEX idx_name_provider_identifier ON users( + name, + provider_identifier +)`, + `CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users( + name +) WHERE provider_identifier IS NULL`, + "CREATE UNIQUE INDEX idx_api_keys_prefix ON api_keys(prefix)", + "CREATE INDEX idx_policies_deleted_at ON policies(deleted_at)", + } + + for _, indexSQL := range indexes { + if err := tx.Exec(indexSQL).Error; err != nil { + return fmt.Errorf("creating index: %w", err) + } + } + + // Drop old tables only after everything 
succeeds + for _, table := range tablesToRename { + if err := tx.Exec("DROP TABLE IF EXISTS " + table + "_old").Error; err != nil { + log.Warn().Str("table", table+"_old").Err(err).Msg("Failed to drop old table, but migration succeeded") + } + } + + log.Info().Msg("Schema recreation completed successfully") + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + // v0.27.1 + { + // Drop all tables that are no longer in use and has existed. + // They potentially still present from broken migrations in the past. + ID: "202510311551", + Migrate: func(tx *gorm.DB) error { + for _, oldTable := range []string{"namespaces", "machines", "shared_machines", "kvs", "pre_auth_key_acl_tags", "routes"} { + err := tx.Migrator().DropTable(oldTable) + if err != nil { + log.Trace().Str("table", oldTable). Err(err). - Str("node", node.Hostname). - Str("machine_key", node.MachineKey.ShortString()). - Msg("Error deleting unregistered node") + Msg("Error dropping old table, continuing...") } } - err := tx.Migrator().DropColumn(&types.Node{}, "registered") + return nil + }, + Rollback: func(tx *gorm.DB) error { + return nil + }, + }, + { + // Drop all indices that are no longer in use and has existed. + // They potentially still present from broken migrations in the past. + // They should all be cleaned up by the db engine, but we are a bit + // conservative to ensure all our previous mess is cleaned up. + ID: "202511101554-drop-old-idx", + Migrate: func(tx *gorm.DB) error { + for _, oldIdx := range []struct{ name, table string }{ + {"idx_namespaces_deleted_at", "namespaces"}, + {"idx_routes_deleted_at", "routes"}, + {"idx_shared_machines_deleted_at", "shared_machines"}, + } { + err := tx.Migrator().DropIndex(oldIdx.table, oldIdx.name) + if err != nil { + log.Trace(). + Str("index", oldIdx.name). + Str("table", oldIdx.table). + Err(err). + Msg("Error dropping old index, continuing...") + } + } + + return nil + }, + Rollback: func(tx *gorm.DB) error { + return nil + }, + }, + + // Migrations **above** this points will be REMOVED in version **0.29.0** + // This is to clean up a lot of old migrations that is seldom used + // and carries a lot of technical debt. + // Any new migrations should be added after the comment below and follow + // the rules it sets out. + + // From this point, the following rules must be followed: + // - NEVER use gorm.AutoMigrate, write the exact migration steps needed + // - AutoMigrate depends on the struct staying exactly the same, which it won't over time. + // - Never write migrations that requires foreign keys to be disabled. + // - ALL errors in migrations must be handled properly. + + { + // Add columns for prefix and hash for pre auth keys, implementing + // them with the same security model as api keys. 
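The migration rules stated above are easier to apply with a concrete pattern in mind. A hypothetical entry in the same style as the surrounding migration literals, following those rules (the ID and the `notes` column are invented purely for illustration and are not part of this patch), could look like:

```go
{
	// Example only: add a nullable text column with an explicit statement,
	// foreign keys left enabled, and every error wrapped and returned.
	ID: "209912312359-example-add-notes-column",
	Migrate: func(tx *gorm.DB) error {
		if !tx.Migrator().HasColumn(&types.Node{}, "notes") {
			err := tx.Exec("ALTER TABLE nodes ADD COLUMN notes text").Error
			if err != nil {
				return fmt.Errorf("adding notes column to nodes: %w", err)
			}
		}

		return nil
	},
	Rollback: func(db *gorm.DB) error { return nil },
},
```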
+ ID: "202511011637-preauthkey-bcrypt", + Migrate: func(tx *gorm.DB) error { + // Check and add prefix column if it doesn't exist + if !tx.Migrator().HasColumn(&types.PreAuthKey{}, "prefix") { + err := tx.Migrator().AddColumn(&types.PreAuthKey{}, "prefix") + if err != nil { + return fmt.Errorf("adding prefix column: %w", err) + } + } + + // Check and add hash column if it doesn't exist + if !tx.Migrator().HasColumn(&types.PreAuthKey{}, "hash") { + err := tx.Migrator().AddColumn(&types.PreAuthKey{}, "hash") + if err != nil { + return fmt.Errorf("adding hash column: %w", err) + } + } + + // Create partial unique index to allow multiple legacy keys (NULL/empty prefix) + // while enforcing uniqueness for new bcrypt-based keys + err := tx.Exec("CREATE UNIQUE INDEX IF NOT EXISTS idx_pre_auth_keys_prefix ON pre_auth_keys(prefix) WHERE prefix IS NOT NULL AND prefix != ''").Error if err != nil { - log.Error().Err(err).Msg("Error dropping registered column") - } - } - - err = tx.AutoMigrate(&types.Route{}) - if err != nil { - return err - } - - err = tx.AutoMigrate(&types.Node{}) - if err != nil { - return err - } - - // Ensure all keys have correct prefixes - // https://github.com/tailscale/tailscale/blob/main/types/key/node.go#L35 - type result struct { - ID uint64 - MachineKey string - NodeKey string - DiscoKey string - } - var results []result - err = tx.Raw("SELECT id, node_key, machine_key, disco_key FROM nodes").Find(&results).Error - if err != nil { - return err - } - - for _, node := range results { - mKey := node.MachineKey - if !strings.HasPrefix(node.MachineKey, "mkey:") { - mKey = "mkey:" + node.MachineKey - } - nKey := node.NodeKey - if !strings.HasPrefix(node.NodeKey, "nodekey:") { - nKey = "nodekey:" + node.NodeKey + return fmt.Errorf("creating prefix index: %w", err) } - dKey := node.DiscoKey - if !strings.HasPrefix(node.DiscoKey, "discokey:") { - dKey = "discokey:" + node.DiscoKey + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + { + ID: "202511122344-remove-newline-index", + Migrate: func(tx *gorm.DB) error { + // Reformat multi-line indexes to single-line for consistency + // This migration drops and recreates the three user identity indexes + // to match the single-line format expected by schema validation + + // Drop existing multi-line indexes + dropIndexes := []string{ + `DROP INDEX IF EXISTS idx_provider_identifier`, + `DROP INDEX IF EXISTS idx_name_provider_identifier`, + `DROP INDEX IF EXISTS idx_name_no_provider_identifier`, } - err := tx.Exec( - "UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id", - sql.Named("mKey", mKey), - sql.Named("nKey", nKey), - sql.Named("dKey", dKey), - sql.Named("id", node.ID), - ).Error + for _, dropSQL := range dropIndexes { + err := tx.Exec(dropSQL).Error + if err != nil { + return fmt.Errorf("dropping index: %w", err) + } + } + + // Recreate indexes in single-line format + createIndexes := []string{ + `CREATE UNIQUE INDEX idx_provider_identifier ON users(provider_identifier) WHERE provider_identifier IS NOT NULL`, + `CREATE UNIQUE INDEX idx_name_provider_identifier ON users(name, provider_identifier)`, + `CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users(name) WHERE provider_identifier IS NULL`, + } + + for _, createSQL := range createIndexes { + err := tx.Exec(createSQL).Error + if err != nil { + return fmt.Errorf("creating index: %w", err) + } + } + + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + { + // Rename forced_tags column to 
tags in nodes table. + // This must run after migration 202505141324 which creates tables with forced_tags. + ID: "202511131445-node-forced-tags-to-tags", + Migrate: func(tx *gorm.DB) error { + // Rename the column from forced_tags to tags + err := tx.Migrator().RenameColumn(&types.Node{}, "forced_tags", "tags") if err != nil { - return err - } - } - - if tx.Migrator().HasColumn(&types.Node{}, "enabled_routes") { - log.Info().Msgf("Database has legacy enabled_routes column in node, migrating...") - - type NodeAux struct { - ID uint64 - EnabledRoutes types.IPPrefixes + return fmt.Errorf("renaming forced_tags to tags: %w", err) } - nodesAux := []NodeAux{} - err := tx.Table("nodes").Select("id, enabled_routes").Scan(&nodesAux).Error + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, + }, + { + // Migrate RequestTags from host_info JSON to tags column. + // In 0.27.x, tags from --advertise-tags (ValidTags) were stored only in + // host_info.RequestTags, not in the tags column (formerly forced_tags). + // This migration validates RequestTags against the policy's tagOwners + // and merges validated tags into the tags column. + // Fixes: https://github.com/juanfont/headscale/issues/3006 + ID: "202601121700-migrate-hostinfo-request-tags", + Migrate: func(tx *gorm.DB) error { + // 1. Load policy from file or database based on configuration + policyData, err := PolicyBytes(tx, cfg) if err != nil { - log.Fatal().Err(err).Msg("Error accessing db") + log.Warn().Err(err).Msg("Failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)") + return nil } - for _, node := range nodesAux { - for _, prefix := range node.EnabledRoutes { - if err != nil { - log.Error(). - Err(err). - Str("enabled_route", prefix.String()). - Msg("Error parsing enabled_route") - continue - } + if len(policyData) == 0 { + log.Info().Msg("No policy found, skipping RequestTags migration (tags will be validated on node reconnect)") + return nil + } - err = tx.Preload("Node"). - Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). - First(&types.Route{}). - Error - if err == nil { - log.Info(). - Str("enabled_route", prefix.String()). - Msg("Route already migrated to new table, skipping") + // 2. Load users and nodes to create PolicyManager + users, err := ListUsers(tx) + if err != nil { + return fmt.Errorf("loading users for RequestTags migration: %w", err) + } - continue - } + nodes, err := ListNodes(tx) + if err != nil { + return fmt.Errorf("loading nodes for RequestTags migration: %w", err) + } - route := types.Route{ - NodeID: node.ID, - Advertised: true, - Enabled: true, - Prefix: types.IPPrefix(prefix), - } - if err := tx.Create(&route).Error; err != nil { - log.Error().Err(err).Msg("Error creating route") + // 3. Create PolicyManager (handles HuJSON parsing, groups, nested tags, etc.) + polMan, err := policy.NewPolicyManager(policyData, users, nodes.ViewSlice()) + if err != nil { + log.Warn().Err(err).Msg("Failed to parse policy, skipping RequestTags migration (tags will be validated on node reconnect)") + return nil + } + + // 4. 
Process each node + for _, node := range nodes { + if node.Hostinfo == nil { + continue + } + + requestTags := node.Hostinfo.RequestTags + if len(requestTags) == 0 { + continue + } + + existingTags := node.Tags + + var validatedTags, rejectedTags []string + + nodeView := node.View() + + for _, tag := range requestTags { + if polMan.NodeCanHaveTag(nodeView, tag) { + if !slices.Contains(existingTags, tag) { + validatedTags = append(validatedTags, tag) + } } else { - log.Info(). - Uint64("node_id", route.NodeID). - Str("prefix", prefix.String()). - Msg("Route migrated") + rejectedTags = append(rejectedTags, tag) } } - } - err = tx.Migrator().DropColumn(&types.Node{}, "enabled_routes") - if err != nil { - log.Error().Err(err).Msg("Error dropping enabled_routes column") - } - } - - if tx.Migrator().HasColumn(&types.Node{}, "given_name") { - nodes := types.Nodes{} - if err := tx.Find(&nodes).Error; err != nil { - log.Error().Err(err).Msg("Error accessing db") - } - - for item, node := range nodes { - if node.GivenName == "" { - normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( - node.Hostname, - ) - if err != nil { - log.Error(). - Caller(). - Str("hostname", node.Hostname). - Err(err). - Msg("Failed to normalize node hostname in DB migration") + if len(validatedTags) == 0 { + if len(rejectedTags) > 0 { + log.Debug(). + Uint64("node.id", uint64(node.ID)). + Str("node.name", node.Hostname). + Strs("rejected_tags", rejectedTags). + Msg("RequestTags rejected during migration (not authorized)") } - err = tx.Model(nodes[item]).Updates(types.Node{ - GivenName: normalizedHostname, - }).Error - if err != nil { - log.Error(). - Caller(). - Str("hostname", node.Hostname). - Err(err). - Msg("Failed to save normalized node name in DB migration") - } + continue } + + mergedTags := append(existingTags, validatedTags...) + slices.Sort(mergedTags) + mergedTags = slices.Compact(mergedTags) + + tagsJSON, err := json.Marshal(mergedTags) + if err != nil { + return fmt.Errorf("serializing merged tags for node %d: %w", node.ID, err) + } + + err = tx.Exec("UPDATE nodes SET tags = ? WHERE id = ?", string(tagsJSON), node.ID).Error + if err != nil { + return fmt.Errorf("updating tags for node %d: %w", node.ID, err) + } + + log.Info(). + Uint64("node.id", uint64(node.ID)). + Str("node.name", node.Hostname). + Strs("validated_tags", validatedTags). + Strs("rejected_tags", rejectedTags). + Strs("existing_tags", existingTags). + Strs("merged_tags", mergedTags). + Msg("Migrated validated RequestTags from host_info to tags column") } - } - err = tx.AutoMigrate(&KV{}) - if err != nil { - return err - } - - err = tx.AutoMigrate(&types.PreAuthKey{}) - if err != nil { - return err - } - - err = tx.AutoMigrate(&types.PreAuthKeyACLTag{}) - if err != nil { - return err - } - - _ = tx.Migrator().DropTable("shared_machines") - - err = tx.AutoMigrate(&types.APIKey{}) - if err != nil { - return err - } - - return nil - }, - Rollback: func(tx *gorm.DB) error { - return nil - }, - }, - { - // drop key-value table, it is not used, and has not contained - // useful data for a long time or ever. 
- ID: "202312101430", - Migrate: func(tx *gorm.DB) error { - return tx.Migrator().DropTable("kvs") - }, - Rollback: func(tx *gorm.DB) error { - return nil + return nil + }, + Rollback: func(db *gorm.DB) error { return nil }, }, }, + ) + + migrations.InitSchema(func(tx *gorm.DB) error { + // Create all tables using AutoMigrate + err := tx.AutoMigrate( + &types.User{}, + &types.PreAuthKey{}, + &types.APIKey{}, + &types.Node{}, + &types.Policy{}, + ) + if err != nil { + return err + } + + // Drop all indexes (both GORM-created and potentially pre-existing ones) + // to ensure we can recreate them in the correct format + dropIndexes := []string{ + `DROP INDEX IF EXISTS "idx_users_deleted_at"`, + `DROP INDEX IF EXISTS "idx_api_keys_prefix"`, + `DROP INDEX IF EXISTS "idx_policies_deleted_at"`, + `DROP INDEX IF EXISTS "idx_provider_identifier"`, + `DROP INDEX IF EXISTS "idx_name_provider_identifier"`, + `DROP INDEX IF EXISTS "idx_name_no_provider_identifier"`, + `DROP INDEX IF EXISTS "idx_pre_auth_keys_prefix"`, + } + + for _, dropSQL := range dropIndexes { + err := tx.Exec(dropSQL).Error + if err != nil { + return err + } + } + + // Recreate indexes without backticks to match schema.sql format + indexes := []string{ + `CREATE INDEX idx_users_deleted_at ON users(deleted_at)`, + `CREATE UNIQUE INDEX idx_api_keys_prefix ON api_keys(prefix)`, + `CREATE INDEX idx_policies_deleted_at ON policies(deleted_at)`, + `CREATE UNIQUE INDEX idx_provider_identifier ON users(provider_identifier) WHERE provider_identifier IS NOT NULL`, + `CREATE UNIQUE INDEX idx_name_provider_identifier ON users(name, provider_identifier)`, + `CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users(name) WHERE provider_identifier IS NULL`, + `CREATE UNIQUE INDEX idx_pre_auth_keys_prefix ON pre_auth_keys(prefix) WHERE prefix IS NOT NULL AND prefix != ''`, + } + + for _, indexSQL := range indexes { + err := tx.Exec(indexSQL).Error + if err != nil { + return err + } + } + + return nil }) - if err = migrations.Migrate(); err != nil { - log.Fatal().Err(err).Msgf("Migration failed: %v", err) + err = runMigrations(cfg.Database, dbConn, migrations) + if err != nil { + return nil, fmt.Errorf("migration failed: %w", err) + } + + // Validate that the schema ends up in the expected state. + // This is currently only done on sqlite as squibble does not + // support Postgres and we use our sqlite schema as our source of + // truth. + if cfg.Database.Type == types.DatabaseSqlite { + sqlConn, err := dbConn.DB() + if err != nil { + return nil, fmt.Errorf("getting DB from gorm: %w", err) + } + + // or else it blocks... 
+ sqlConn.SetMaxIdleConns(maxIdleConns) + sqlConn.SetMaxOpenConns(maxOpenConns) + defer sqlConn.SetMaxIdleConns(1) + defer sqlConn.SetMaxOpenConns(1) + + ctx, cancel := context.WithTimeout(context.Background(), contextTimeoutSecs*time.Second) + defer cancel() + + opts := squibble.DigestOptions{ + IgnoreTables: []string{ + // Litestream tables, these are inserted by + // litestream and not part of our schema + // https://litestream.io/how-it-works + "_litestream_lock", + "_litestream_seq", + }, + } + + if err := squibble.Validate(ctx, sqlConn, dbSchema, &opts); err != nil { + return nil, fmt.Errorf("validating schema: %w", err) + } } db := HSDatabase{ - db: dbConn, - notifier: notifier, - - ipPrefixes: ipPrefixes, - baseDomain: baseDomain, + DB: dbConn, + cfg: cfg, + regCache: regCache, } return &db, err } -func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { - log.Debug().Str("type", dbType).Str("connection", connectionAddr).Msg("opening database") - +func openDB(cfg types.DatabaseConfig) (*gorm.DB, error) { + // TODO(kradalby): Integrate this with zerolog var dbLogger logger.Interface - if debug { - dbLogger = logger.Default + if cfg.Debug { + dbLogger = util.NewDBLogWrapper(&log.Logger, cfg.Gorm.SlowThreshold, cfg.Gorm.SkipErrRecordNotFound, cfg.Gorm.ParameterizedQueries) } else { dbLogger = logger.Default.LogMode(logger.Silent) } - switch dbType { - case Sqlite: + switch cfg.Type { + case types.DatabaseSqlite: + dir := filepath.Dir(cfg.Sqlite.Path) + err := util.EnsureDir(dir) + if err != nil { + return nil, fmt.Errorf("creating directory for sqlite: %w", err) + } + + log.Info(). + Str("database", types.DatabaseSqlite). + Str("path", cfg.Sqlite.Path). + Msg("Opening database") + + // Build SQLite configuration with pragmas set at connection time + sqliteConfig := sqliteconfig.Default(cfg.Sqlite.Path) + if cfg.Sqlite.WriteAheadLog { + sqliteConfig.JournalMode = sqliteconfig.JournalModeWAL + sqliteConfig.WALAutocheckpoint = cfg.Sqlite.WALAutoCheckPoint + } + + connectionURL, err := sqliteConfig.ToURL() + if err != nil { + return nil, fmt.Errorf("building sqlite connection URL: %w", err) + } + db, err := gorm.Open( - sqlite.Open(connectionAddr+"?_synchronous=1&_journal_mode=WAL"), + sqlite.Open(connectionURL), &gorm.Config{ - DisableForeignKeyConstraintWhenMigrating: true, - Logger: dbLogger, + PrepareStmt: cfg.Gorm.PrepareStmt, + Logger: dbLogger, }, ) - db.Exec("PRAGMA foreign_keys=ON") - // The pure Go SQLite library does not handle locking in - // the same way as the C based one and we cant use the gorm + // the same way as the C based one and we can't use the gorm // connection pool as of 2022/02/23. sqlDB, _ := db.DB() sqlDB.SetMaxIdleConns(1) @@ -359,24 +845,178 @@ func openDB(dbType, connectionAddr string, debug bool) (*gorm.DB, error) { return db, err - case Postgres: - return gorm.Open(postgres.Open(connectionAddr), &gorm.Config{ - DisableForeignKeyConstraintWhenMigrating: true, - Logger: dbLogger, + case types.DatabasePostgres: + dbString := fmt.Sprintf( + "host=%s dbname=%s user=%s", + cfg.Postgres.Host, + cfg.Postgres.Name, + cfg.Postgres.User, + ) + + log.Info(). + Str("database", types.DatabasePostgres). + Str("path", dbString). 
+ Msg("Opening database") + + if sslEnabled, err := strconv.ParseBool(cfg.Postgres.Ssl); err == nil { + if !sslEnabled { + dbString += " sslmode=disable" + } + } else { + dbString += " sslmode=" + cfg.Postgres.Ssl + } + + if cfg.Postgres.Port != 0 { + dbString += fmt.Sprintf(" port=%d", cfg.Postgres.Port) + } + + if cfg.Postgres.Pass != "" { + dbString += " password=" + cfg.Postgres.Pass + } + + db, err := gorm.Open(postgres.Open(dbString), &gorm.Config{ + Logger: dbLogger, }) + if err != nil { + return nil, err + } + + sqlDB, _ := db.DB() + sqlDB.SetMaxIdleConns(cfg.Postgres.MaxIdleConnections) + sqlDB.SetMaxOpenConns(cfg.Postgres.MaxOpenConnections) + sqlDB.SetConnMaxIdleTime( + time.Duration(cfg.Postgres.ConnMaxIdleTimeSecs) * time.Second, + ) + + return db, nil } return nil, fmt.Errorf( "database of type %s is not supported: %w", - dbType, + cfg.Type, errDatabaseNotSupported, ) } +func runMigrations(cfg types.DatabaseConfig, dbConn *gorm.DB, migrations *gormigrate.Gormigrate) error { + if cfg.Type == types.DatabaseSqlite { + // SQLite: Run migrations step-by-step, only disabling foreign keys when necessary + + // List of migration IDs that require foreign keys to be disabled + // These are migrations that perform complex schema changes that GORM cannot handle safely with FK enabled + // NO NEW MIGRATIONS SHOULD BE ADDED HERE. ALL NEW MIGRATIONS MUST RUN WITH FOREIGN KEYS ENABLED. + migrationsRequiringFKDisabled := map[string]bool{ + "202501221827": true, // Route table automigration with FK constraint issues + "202501311657": true, // PreAuthKey table automigration with FK constraint issues + // Add other migration IDs here as they are identified to need FK disabled + } + + // Get the current foreign key status + var fkOriginallyEnabled int + if err := dbConn.Raw("PRAGMA foreign_keys").Scan(&fkOriginallyEnabled).Error; err != nil { + return fmt.Errorf("checking foreign key status: %w", err) + } + + // Get all migration IDs in order from the actual migration definitions + // Only IDs that are in the migrationsRequiringFKDisabled map will be processed with FK disabled + // any other new migrations are ran after. + migrationIDs := []string{ + // v0.25.0 + "202501221827", + "202501311657", + "202502070949", + + // v0.26.0 + "202502131714", + "202502171819", + "202505091439", + "202505141324", + + // As of 2025-07-02, no new IDs should be added here. + // They will be ran by the migrations.Migrate() call below. 
+ } + + for _, migrationID := range migrationIDs { + log.Trace().Caller().Str("migration_id", migrationID).Msg("Running migration") + needsFKDisabled := migrationsRequiringFKDisabled[migrationID] + + if needsFKDisabled { + // Disable foreign keys for this migration + if err := dbConn.Exec("PRAGMA foreign_keys = OFF").Error; err != nil { + return fmt.Errorf("disabling foreign keys for migration %s: %w", migrationID, err) + } + } else { + // Ensure foreign keys are enabled for this migration + if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { + return fmt.Errorf("enabling foreign keys for migration %s: %w", migrationID, err) + } + } + + // Run up to this specific migration (will only run the next pending migration) + if err := migrations.MigrateTo(migrationID); err != nil { + return fmt.Errorf("running migration %s: %w", migrationID, err) + } + } + + if err := dbConn.Exec("PRAGMA foreign_keys = ON").Error; err != nil { + return fmt.Errorf("restoring foreign keys: %w", err) + } + + // Run the rest of the migrations + if err := migrations.Migrate(); err != nil { + return err + } + + // Check for constraint violations at the end + type constraintViolation struct { + Table string + RowID int + Parent string + ConstraintIndex int + } + + var violatedConstraints []constraintViolation + + rows, err := dbConn.Raw("PRAGMA foreign_key_check").Rows() + if err != nil { + return err + } + + for rows.Next() { + var violation constraintViolation + if err := rows.Scan(&violation.Table, &violation.RowID, &violation.Parent, &violation.ConstraintIndex); err != nil { + return err + } + + violatedConstraints = append(violatedConstraints, violation) + } + _ = rows.Close() + + if len(violatedConstraints) > 0 { + for _, violation := range violatedConstraints { + log.Error(). + Str("table", violation.Table). + Int("row_id", violation.RowID). + Str("parent", violation.Parent). 
+ Msg("Foreign key constraint violated") + } + + return errForeignKeyConstraintsViolated + } + } else { + // PostgreSQL can run all migrations in one block - no foreign key issues + if err := migrations.Migrate(); err != nil { + return err + } + } + + return nil +} + func (hsdb *HSDatabase) PingDB(ctx context.Context) error { ctx, cancel := context.WithTimeout(ctx, time.Second) defer cancel() - sqlDB, err := hsdb.db.DB() + sqlDB, err := hsdb.DB.DB() if err != nil { return err } @@ -385,10 +1025,54 @@ func (hsdb *HSDatabase) PingDB(ctx context.Context) error { } func (hsdb *HSDatabase) Close() error { - db, err := hsdb.db.DB() + db, err := hsdb.DB.DB() if err != nil { return err } + if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog { + db.Exec("VACUUM") + } + return db.Close() } + +func (hsdb *HSDatabase) Read(fn func(rx *gorm.DB) error) error { + rx := hsdb.DB.Begin() + defer rx.Rollback() + return fn(rx) +} + +func Read[T any](db *gorm.DB, fn func(rx *gorm.DB) (T, error)) (T, error) { + rx := db.Begin() + defer rx.Rollback() + ret, err := fn(rx) + if err != nil { + var no T + return no, err + } + + return ret, nil +} + +func (hsdb *HSDatabase) Write(fn func(tx *gorm.DB) error) error { + tx := hsdb.DB.Begin() + defer tx.Rollback() + if err := fn(tx); err != nil { + return err + } + + return tx.Commit().Error +} + +func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) { + tx := db.Begin() + defer tx.Rollback() + ret, err := fn(tx) + if err != nil { + var no T + return no, err + } + + return ret, tx.Commit().Error +} diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go new file mode 100644 index 00000000..3cd0d14e --- /dev/null +++ b/hscontrol/db/db_test.go @@ -0,0 +1,443 @@ +package db + +import ( + "database/sql" + "os" + "os/exec" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "zgo.at/zcache/v2" +) + +// TestSQLiteMigrationAndDataValidation tests specific SQLite migration scenarios +// and validates data integrity after migration. All migrations that require data validation +// should be added here. +func TestSQLiteMigrationAndDataValidation(t *testing.T) { + tests := []struct { + dbPath string + wantFunc func(*testing.T, *HSDatabase) + }{ + // at 14:15:06 ❯ go run ./cmd/headscale preauthkeys list + // ID | Key | Reusable | Ephemeral | Used | Expiration | Created | Tags + // 1 | 09b28f.. | false | false | false | 2024-09-27 | 2024-09-27 | tag:derp + // 2 | 3112b9.. 
| false | false | false | 2024-09-27 | 2024-09-27 | tag:derp + { + dbPath: "testdata/sqlite/failing-node-preauth-constraint_dump.sql", + wantFunc: func(t *testing.T, hsdb *HSDatabase) { + t.Helper() + // Comprehensive data preservation validation for node-preauth constraint issue + // Expected data from dump: 1 user, 2 api_keys, 6 nodes + + // Verify users data preservation + users, err := Read(hsdb.DB, func(rx *gorm.DB) ([]types.User, error) { + return ListUsers(rx) + }) + require.NoError(t, err) + assert.Len(t, users, 1, "should preserve all 1 user from original schema") + + // Verify api_keys data preservation + var apiKeyCount int + err = hsdb.DB.Raw("SELECT COUNT(*) FROM api_keys").Scan(&apiKeyCount).Error + require.NoError(t, err) + assert.Equal(t, 2, apiKeyCount, "should preserve all 2 api_keys from original schema") + + // Verify nodes data preservation and field validation + nodes, err := Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListNodes(rx) + }) + require.NoError(t, err) + assert.Len(t, nodes, 6, "should preserve all 6 nodes from original schema") + + for _, node := range nodes { + assert.Falsef(t, node.MachineKey.IsZero(), "expected non zero machinekey") + assert.Contains(t, node.MachineKey.String(), "mkey:") + assert.Falsef(t, node.NodeKey.IsZero(), "expected non zero nodekey") + assert.Contains(t, node.NodeKey.String(), "nodekey:") + assert.Falsef(t, node.DiscoKey.IsZero(), "expected non zero discokey") + assert.Contains(t, node.DiscoKey.String(), "discokey:") + assert.Nil(t, node.AuthKey) + assert.Nil(t, node.AuthKeyID) + } + }, + }, + // Test for RequestTags migration (202601121700-migrate-hostinfo-request-tags) + // and forced_tags->tags rename migration (202511131445-node-forced-tags-to-tags) + // + // This test validates that: + // 1. The forced_tags column is renamed to tags + // 2. RequestTags from host_info are validated against policy tagOwners + // 3. Authorized tags are migrated to the tags column + // 4. Unauthorized tags are rejected + // 5. Existing tags are preserved + // 6. 
Group membership is evaluated for tag authorization + { + dbPath: "testdata/sqlite/request_tags_migration_test.sql", + wantFunc: func(t *testing.T, hsdb *HSDatabase) { + t.Helper() + + nodes, err := Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + return ListNodes(rx) + }) + require.NoError(t, err) + require.Len(t, nodes, 7, "should have all 7 nodes") + + // Helper to find node by hostname + findNode := func(hostname string) *types.Node { + for _, n := range nodes { + if n.Hostname == hostname { + return n + } + } + + return nil + } + + // Node 1: user1 has RequestTags for tag:server (authorized) + // Expected: tags = ["tag:server"] + node1 := findNode("node1") + require.NotNil(t, node1, "node1 should exist") + assert.Contains(t, node1.Tags, "tag:server", "node1 should have tag:server migrated from RequestTags") + + // Node 2: user1 has RequestTags for tag:unauthorized (NOT authorized) + // Expected: tags = [] (unchanged) + node2 := findNode("node2") + require.NotNil(t, node2, "node2 should exist") + assert.Empty(t, node2.Tags, "node2 should have empty tags (unauthorized tag rejected)") + + // Node 3: user2 has RequestTags for tag:client (authorized) + existing tag:existing + // Expected: tags = ["tag:client", "tag:existing"] + node3 := findNode("node3") + require.NotNil(t, node3, "node3 should exist") + assert.Contains(t, node3.Tags, "tag:client", "node3 should have tag:client migrated from RequestTags") + assert.Contains(t, node3.Tags, "tag:existing", "node3 should preserve existing tag") + + // Node 4: user1 has RequestTags for tag:server which already exists + // Expected: tags = ["tag:server"] (no duplicates) + node4 := findNode("node4") + require.NotNil(t, node4, "node4 should exist") + assert.Equal(t, []string{"tag:server"}, node4.Tags, "node4 should have tag:server without duplicates") + + // Node 5: user2 has no RequestTags + // Expected: tags = [] (unchanged) + node5 := findNode("node5") + require.NotNil(t, node5, "node5 should exist") + assert.Empty(t, node5.Tags, "node5 should have empty tags (no RequestTags)") + + // Node 6: admin1 has RequestTags for tag:admin (authorized via group:admins) + // Expected: tags = ["tag:admin"] + node6 := findNode("node6") + require.NotNil(t, node6, "node6 should exist") + assert.Contains(t, node6.Tags, "tag:admin", "node6 should have tag:admin migrated via group membership") + + // Node 7: user1 has RequestTags for tag:server (authorized) and tag:forbidden (unauthorized) + // Expected: tags = ["tag:server"] (only authorized tag) + node7 := findNode("node7") + require.NotNil(t, node7, "node7 should exist") + assert.Contains(t, node7.Tags, "tag:server", "node7 should have tag:server migrated") + assert.NotContains(t, node7.Tags, "tag:forbidden", "node7 should NOT have tag:forbidden (unauthorized)") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.dbPath, func(t *testing.T) { + if !strings.HasSuffix(tt.dbPath, ".sql") { + t.Fatalf("TestSQLiteMigrationAndDataValidation only supports .sql files, got: %s", tt.dbPath) + } + + hsdb := dbForTestWithPath(t, tt.dbPath) + if tt.wantFunc != nil { + tt.wantFunc(t, hsdb) + } + }) + } +} + +func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { + return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +} + +func createSQLiteFromSQLFile(sqlFilePath, dbPath string) error { + db, err := sql.Open("sqlite", dbPath) + if err != nil { + return err + } + defer db.Close() + + schemaContent, err := os.ReadFile(sqlFilePath) + if err != nil { + return err + } 
+ + _, err = db.Exec(string(schemaContent)) + + return err +} + +// requireConstraintFailed checks if the error is a constraint failure with +// either SQLite and PostgreSQL error messages. +func requireConstraintFailed(t *testing.T, err error) { + t.Helper() + require.Error(t, err) + if !strings.Contains(err.Error(), "UNIQUE constraint failed:") && !strings.Contains(err.Error(), "violates unique constraint") { + require.Failf(t, "expected error to contain a constraint failure, got: %s", err.Error()) + } +} + +func TestConstraints(t *testing.T) { + tests := []struct { + name string + run func(*testing.T, *gorm.DB) + }{ + { + name: "no-duplicate-username-if-no-oidc", + run: func(t *testing.T, db *gorm.DB) { + _, err := CreateUser(db, types.User{Name: "user1"}) + require.NoError(t, err) + _, err = CreateUser(db, types.User{Name: "user1"}) + requireConstraintFailed(t, err) + }, + }, + { + name: "no-oidc-duplicate-username-and-id", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err := db.Save(&user).Error + require.NoError(t, err) + + user = types.User{ + Model: gorm.Model{ID: 2}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err = db.Save(&user).Error + requireConstraintFailed(t, err) + }, + }, + { + name: "no-oidc-duplicate-id", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "user1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err := db.Save(&user).Error + require.NoError(t, err) + + user = types.User{ + Model: gorm.Model{ID: 2}, + Name: "user1.1", + } + user.ProviderIdentifier = sql.NullString{String: "http://test.com/user1", Valid: true} + + err = db.Save(&user).Error + requireConstraintFailed(t, err) + }, + }, + { + name: "allow-duplicate-username-cli-then-oidc", + run: func(t *testing.T, db *gorm.DB) { + _, err := CreateUser(db, types.User{Name: "user1"}) // Create CLI username + require.NoError(t, err) + + user := types.User{ + Name: "user1", + ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true}, + } + + err = db.Save(&user).Error + require.NoError(t, err) + }, + }, + { + name: "allow-duplicate-username-oidc-then-cli", + run: func(t *testing.T, db *gorm.DB) { + user := types.User{ + Name: "user1", + ProviderIdentifier: sql.NullString{String: "http://test.com/user1", Valid: true}, + } + + err := db.Save(&user).Error + require.NoError(t, err) + + _, err = CreateUser(db, types.User{Name: "user1"}) // Create CLI username + require.NoError(t, err) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name+"-postgres", func(t *testing.T) { + db := newPostgresTestDB(t) + tt.run(t, db.DB.Debug()) + }) + t.Run(tt.name+"-sqlite", func(t *testing.T) { + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating database: %s", err) + } + + tt.run(t, db.DB.Debug()) + }) + } +} + +// TestPostgresMigrationAndDataValidation tests specific PostgreSQL migration scenarios +// and validates data integrity after migration. All migrations that require data validation +// should be added here. +// +// TODO(kradalby): Convert to use plain text SQL dumps instead of binary .pssql dumps for consistency +// with SQLite tests and easier version control. 
+func TestPostgresMigrationAndDataValidation(t *testing.T) { + tests := []struct { + name string + dbPath string + wantFunc func(*testing.T, *HSDatabase) + }{} + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + u := newPostgresDBForTest(t) + + pgRestorePath, err := exec.LookPath("pg_restore") + if err != nil { + t.Fatal("pg_restore not found in PATH. Please install it and ensure it is accessible.") + } + + // Construct the pg_restore command + cmd := exec.Command(pgRestorePath, "--verbose", "--if-exists", "--clean", "--no-owner", "--dbname", u.String(), tt.dbPath) + + // Set the output streams + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + + // Execute the command + err = cmd.Run() + if err != nil { + t.Fatalf("failed to restore postgres database: %s", err) + } + + db := newHeadscaleDBFromPostgresURL(t, u) + + if tt.wantFunc != nil { + tt.wantFunc(t, db) + } + }) + } +} + +func dbForTest(t *testing.T) *HSDatabase { + t.Helper() + return dbForTestWithPath(t, "") +} + +func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase { + t.Helper() + + dbPath := t.TempDir() + "/headscale_test.db" + + // If SQL file path provided, validate and create database from it + if sqlFilePath != "" { + // Validate that the file is a SQL text file + if !strings.HasSuffix(sqlFilePath, ".sql") { + t.Fatalf("dbForTestWithPath only accepts .sql files, got: %s", sqlFilePath) + } + + err := createSQLiteFromSQLFile(sqlFilePath, dbPath) + if err != nil { + t.Fatalf("setting up database from SQL file %s: %s", sqlFilePath, err) + } + } + + db, err := NewHeadscaleDatabase( + &types.Config{ + Database: types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, + }, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, + }, + emptyCache(), + ) + if err != nil { + t.Fatalf("setting up database: %s", err) + } + + if sqlFilePath != "" { + t.Logf("database set up from %s at: %s", sqlFilePath, dbPath) + } else { + t.Logf("database set up at: %s", dbPath) + } + + return db +} + +// TestSQLiteAllTestdataMigrations tests migration compatibility across all SQLite schemas +// in the testdata directory. It verifies they can be successfully migrated to the current +// schema version. This test only validates migration success, not data integrity. +// +// All test database files are SQL dumps (created with `sqlite3 headscale.db .dump`) generated +// with old Headscale binaries on empty databases (no user/node data). These dumps include the +// migration history in the `migrations` table, which allows the migration system to correctly +// skip already-applied migrations and only run new ones. 
+func TestSQLiteAllTestdataMigrations(t *testing.T) { + t.Parallel() + schemas, err := os.ReadDir("testdata/sqlite") + require.NoError(t, err) + + t.Logf("loaded %d schemas", len(schemas)) + + for _, schema := range schemas { + if schema.IsDir() { + continue + } + + t.Logf("validating: %s", schema.Name()) + + t.Run(schema.Name(), func(t *testing.T) { + t.Parallel() + + dbPath := t.TempDir() + "/headscale_test.db" + + // Setup a database with the old schema + schemaPath := filepath.Join("testdata/sqlite", schema.Name()) + err := createSQLiteFromSQLFile(schemaPath, dbPath) + require.NoError(t, err) + + _, err = NewHeadscaleDatabase( + &types.Config{ + Database: types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, + }, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, + }, + emptyCache(), + ) + require.NoError(t, err) + }) + } +} diff --git a/hscontrol/db/ephemeral_garbage_collector_test.go b/hscontrol/db/ephemeral_garbage_collector_test.go new file mode 100644 index 00000000..d118b7fd --- /dev/null +++ b/hscontrol/db/ephemeral_garbage_collector_test.go @@ -0,0 +1,395 @@ +package db + +import ( + "runtime" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" +) + +const ( + fiveHundred = 500 * time.Millisecond + oneHundred = 100 * time.Millisecond + fifty = 50 * time.Millisecond +) + +// TestEphemeralGarbageCollectorGoRoutineLeak is a test for a goroutine leak in EphemeralGarbageCollector(). +// It creates a new EphemeralGarbageCollector, schedules several nodes for deletion with a short expiry, +// and verifies that the nodes are deleted when the expiry time passes, and then +// for any leaked goroutines after the garbage collector is closed. 
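The garbage-collector tests that follow all drive the same lifecycle. As a point of reference, this is a minimal sketch of that lifecycle using only the calls these tests exercise (NewEphemeralGarbageCollector, Start, Schedule, Cancel, Close); the callback, node IDs, and durations are made up for the illustration and are not part of the patch.

```go
package main

import (
	"fmt"
	"time"

	"github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	// The callback receives the ID of each ephemeral node whose timer fired.
	gc := db.NewEphemeralGarbageCollector(func(id types.NodeID) {
		fmt.Println("deleting ephemeral node", id)
	})
	go gc.Start()
	defer gc.Close()

	// Schedule a node for deletion, then cancel before the timer fires.
	gc.Schedule(types.NodeID(1), 50*time.Millisecond)
	gc.Cancel(types.NodeID(1))

	// Rescheduling the same node replaces its existing timer.
	gc.Schedule(types.NodeID(2), time.Hour)
	gc.Schedule(types.NodeID(2), 50*time.Millisecond)

	time.Sleep(200 * time.Millisecond) // give the short timer time to fire in this sketch
}
```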
+func TestEphemeralGarbageCollectorGoRoutineLeak(t *testing.T) { + // Count goroutines at the start + initialGoroutines := runtime.NumGoroutine() + t.Logf("Initial number of goroutines: %d", initialGoroutines) + + // Basic deletion tracking mechanism + var deletedIDs []types.NodeID + var deleteMutex sync.Mutex + var deletionWg sync.WaitGroup + + deleteFunc := func(nodeID types.NodeID) { + deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + deletionWg.Done() + } + + // Start the GC + gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() + + // Schedule several nodes for deletion with short expiry + const expiry = fifty + const numNodes = 100 + + // Set up wait group for expected deletions + deletionWg.Add(numNodes) + + for i := 1; i <= numNodes; i++ { + gc.Schedule(types.NodeID(i), expiry) + } + + // Wait for all scheduled deletions to complete + deletionWg.Wait() + + // Check nodes are deleted + deleteMutex.Lock() + assert.Len(t, deletedIDs, numNodes, "Not all nodes were deleted") + deleteMutex.Unlock() + + // Schedule and immediately cancel to test that part of the code + for i := numNodes + 1; i <= numNodes*2; i++ { + nodeID := types.NodeID(i) + gc.Schedule(nodeID, time.Hour) + gc.Cancel(nodeID) + } + + // Close GC + gc.Close() + + // Wait for goroutines to clean up and verify no leaks + assert.EventuallyWithT(t, func(c *assert.CollectT) { + finalGoroutines := runtime.NumGoroutine() + // NB: We have to allow for a small number of extra goroutines because of test itself + assert.LessOrEqual(c, finalGoroutines, initialGoroutines+5, + "There are significantly more goroutines after GC usage, which suggests a leak") + }, time.Second, 10*time.Millisecond, "goroutines should clean up after GC close") + + t.Logf("Final number of goroutines: %d", runtime.NumGoroutine()) +} + +// TestEphemeralGarbageCollectorReschedule is a test for the rescheduling of nodes in EphemeralGarbageCollector(). +// It creates a new EphemeralGarbageCollector, schedules a node for deletion with a longer expiry, +// and then reschedules it with a shorter expiry, and verifies that the node is deleted only once. 
+func TestEphemeralGarbageCollectorReschedule(t *testing.T) { + // Deletion tracking mechanism + var deletedIDs []types.NodeID + var deleteMutex sync.Mutex + + deletionNotifier := make(chan types.NodeID, 1) + + deleteFunc := func(nodeID types.NodeID) { + deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + + deletionNotifier <- nodeID + } + + // Start GC + gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() + defer gc.Close() + + const shortExpiry = fifty + const longExpiry = 1 * time.Hour + + nodeID := types.NodeID(1) + + // Schedule node for deletion with long expiry + gc.Schedule(nodeID, longExpiry) + + // Reschedule the same node with a shorter expiry + gc.Schedule(nodeID, shortExpiry) + + // Wait for deletion notification with timeout + select { + case deletedNodeID := <-deletionNotifier: + assert.Equal(t, nodeID, deletedNodeID, "The correct node should be deleted") + case <-time.After(time.Second): + t.Fatal("Timed out waiting for node deletion") + } + + // Verify that the node was deleted exactly once + deleteMutex.Lock() + assert.Len(t, deletedIDs, 1, "Node should be deleted exactly once") + assert.Equal(t, nodeID, deletedIDs[0], "The correct node should be deleted") + deleteMutex.Unlock() +} + +// TestEphemeralGarbageCollectorCancelAndReschedule is a test for the cancellation and rescheduling of nodes in EphemeralGarbageCollector(). +// It creates a new EphemeralGarbageCollector, schedules a node for deletion, cancels it, and then reschedules it, +// and verifies that the node is deleted only once. +func TestEphemeralGarbageCollectorCancelAndReschedule(t *testing.T) { + // Deletion tracking mechanism + var deletedIDs []types.NodeID + var deleteMutex sync.Mutex + deletionNotifier := make(chan types.NodeID, 1) + + deleteFunc := func(nodeID types.NodeID) { + deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + deletionNotifier <- nodeID + } + + // Start the GC + gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() + defer gc.Close() + + nodeID := types.NodeID(1) + const expiry = fifty + + // Schedule node for deletion + gc.Schedule(nodeID, expiry) + + // Cancel the scheduled deletion + gc.Cancel(nodeID) + + // Use a timeout to verify no deletion occurred + select { + case <-deletionNotifier: + t.Fatal("Node was deleted after cancellation") + case <-time.After(expiry * 2): // Still need a timeout for negative test + // This is expected - no deletion should occur + } + + deleteMutex.Lock() + assert.Empty(t, deletedIDs, "Node should not be deleted after cancellation") + deleteMutex.Unlock() + + // Reschedule the node + gc.Schedule(nodeID, expiry) + + // Wait for deletion with timeout + select { + case deletedNodeID := <-deletionNotifier: + // Verify the correct node was deleted + assert.Equal(t, nodeID, deletedNodeID, "The correct node should be deleted") + case <-time.After(time.Second): // Longer timeout as a safety net + t.Fatal("Timed out waiting for node deletion") + } + + // Verify final state + deleteMutex.Lock() + assert.Len(t, deletedIDs, 1, "Node should be deleted after rescheduling") + assert.Equal(t, nodeID, deletedIDs[0], "The correct node should be deleted") + deleteMutex.Unlock() +} + +// TestEphemeralGarbageCollectorCloseBeforeTimerFires is a test for the closing of the EphemeralGarbageCollector before the timer fires. +// It creates a new EphemeralGarbageCollector, schedules a node for deletion, closes the GC, and verifies that the node is not deleted. 
+func TestEphemeralGarbageCollectorCloseBeforeTimerFires(t *testing.T) { + // Deletion tracking + var deletedIDs []types.NodeID + var deleteMutex sync.Mutex + + deletionNotifier := make(chan types.NodeID, 1) + + deleteFunc := func(nodeID types.NodeID) { + deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + + deletionNotifier <- nodeID + } + + // Start the GC + gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() + + const ( + longExpiry = 1 * time.Hour + shortWait = fifty * 2 + ) + + // Schedule node deletion with a long expiry + gc.Schedule(types.NodeID(1), longExpiry) + + // Close the GC before the timer + gc.Close() + + // Verify that no deletion occurred within a reasonable time + select { + case <-deletionNotifier: + t.Fatal("Node was deleted after GC was closed, which should not happen") + case <-time.After(shortWait): + // Expected: no deletion should occur + } + + // Verify that no deletion occurred + deleteMutex.Lock() + assert.Empty(t, deletedIDs, "No node should be deleted when GC is closed before timer fires") + deleteMutex.Unlock() +} + +// TestEphemeralGarbageCollectorScheduleAfterClose verifies that calling Schedule after Close +// is a no-op and doesn't cause any panics, goroutine leaks, or other issues. +func TestEphemeralGarbageCollectorScheduleAfterClose(t *testing.T) { + // Count initial goroutines to check for leaks + initialGoroutines := runtime.NumGoroutine() + t.Logf("Initial number of goroutines: %d", initialGoroutines) + + // Deletion tracking + var deletedIDs []types.NodeID + var deleteMutex sync.Mutex + nodeDeleted := make(chan struct{}) + + deleteFunc := func(nodeID types.NodeID) { + deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + close(nodeDeleted) // Signal that deletion happened + } + + // Start new GC + gc := NewEphemeralGarbageCollector(deleteFunc) + + // Use a WaitGroup to ensure the GC has started + var startWg sync.WaitGroup + startWg.Add(1) + go func() { + startWg.Done() // Signal that the goroutine has started + gc.Start() + }() + startWg.Wait() // Wait for the GC to start + + // Close GC right away + gc.Close() + + // Now try to schedule node for deletion with a very short expiry + // If the Schedule operation incorrectly creates a timer, it would fire quickly + nodeID := types.NodeID(1) + gc.Schedule(nodeID, 1*time.Millisecond) + + // Check if any node was deleted (which shouldn't happen) + // Use timeout to wait for potential deletion + select { + case <-nodeDeleted: + t.Fatal("Node was deleted after GC was closed, which should not happen") + case <-time.After(fiveHundred): + // This is the expected path - no deletion should occur + } + + // Check no node was deleted + deleteMutex.Lock() + nodesDeleted := len(deletedIDs) + deleteMutex.Unlock() + assert.Equal(t, 0, nodesDeleted, "No nodes should be deleted when Schedule is called after Close") + + // Check for goroutine leaks after GC is fully closed + assert.EventuallyWithT(t, func(c *assert.CollectT) { + finalGoroutines := runtime.NumGoroutine() + // Allow for small fluctuations in goroutine count for testing routines etc + assert.LessOrEqual(c, finalGoroutines, initialGoroutines+2, + "There should be no significant goroutine leaks when Schedule is called after Close") + }, time.Second, 10*time.Millisecond, "goroutines should clean up after GC close") + + t.Logf("Final number of goroutines: %d", runtime.NumGoroutine()) +} + +// TestEphemeralGarbageCollectorConcurrentScheduleAndClose tests the behavior of the garbage 
collector +// when Schedule and Close are called concurrently from multiple goroutines. +func TestEphemeralGarbageCollectorConcurrentScheduleAndClose(t *testing.T) { + // Count initial goroutines + initialGoroutines := runtime.NumGoroutine() + t.Logf("Initial number of goroutines: %d", initialGoroutines) + + // Deletion tracking mechanism + var deletedIDs []types.NodeID + var deleteMutex sync.Mutex + + deleteFunc := func(nodeID types.NodeID) { + deleteMutex.Lock() + deletedIDs = append(deletedIDs, nodeID) + deleteMutex.Unlock() + } + + // Start the GC + gc := NewEphemeralGarbageCollector(deleteFunc) + go gc.Start() + + // Number of concurrent scheduling goroutines + const numSchedulers = 10 + const nodesPerScheduler = 50 + + const closeAfterNodes = 25 // Close GC after this many nodes per scheduler + + // Use WaitGroup to wait for all scheduling goroutines to finish + var wg sync.WaitGroup + wg.Add(numSchedulers + 1) // +1 for the closer goroutine + + // Create a stopper channel to signal scheduling goroutines to stop + stopScheduling := make(chan struct{}) + + // Track how many nodes have been scheduled + var scheduledCount int64 + + // Launch goroutines that continuously schedule nodes + for schedulerIndex := range numSchedulers { + go func(schedulerID int) { + defer wg.Done() + + baseNodeID := schedulerID * nodesPerScheduler + + // Keep scheduling nodes until signaled to stop + for j := range nodesPerScheduler { + select { + case <-stopScheduling: + return + default: + nodeID := types.NodeID(baseNodeID + j + 1) + gc.Schedule(nodeID, 1*time.Hour) // Long expiry to ensure it doesn't trigger during test + atomic.AddInt64(&scheduledCount, 1) + + // Yield to other goroutines to introduce variability + runtime.Gosched() + } + } + }(schedulerIndex) + } + + // Close the garbage collector after some nodes have been scheduled + go func() { + defer wg.Done() + + // Wait until enough nodes have been scheduled + for atomic.LoadInt64(&scheduledCount) < int64(numSchedulers*closeAfterNodes) { + runtime.Gosched() + } + + // Close GC + gc.Close() + + // Signal schedulers to stop + close(stopScheduling) + }() + + // Wait for all goroutines to complete + wg.Wait() + + // Check for leaks using EventuallyWithT + assert.EventuallyWithT(t, func(c *assert.CollectT) { + finalGoroutines := runtime.NumGoroutine() + // Allow for a reasonable small variable routine count due to testing + assert.LessOrEqual(c, finalGoroutines, initialGoroutines+5, + "There should be no significant goroutine leaks during concurrent Schedule and Close operations") + }, time.Second, 10*time.Millisecond, "goroutines should clean up") + + t.Logf("Final number of goroutines: %d", runtime.NumGoroutine()) +} diff --git a/hscontrol/db/ip.go b/hscontrol/db/ip.go new file mode 100644 index 00000000..972d8e72 --- /dev/null +++ b/hscontrol/db/ip.go @@ -0,0 +1,352 @@ +package db + +import ( + "crypto/rand" + "database/sql" + "errors" + "fmt" + "math/big" + "net/netip" + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + "go4.org/netipx" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" +) + +var errGeneratedIPBytesInvalid = errors.New("generated ip bytes are invalid ip") + +// IPAllocator is a singleton responsible for allocating +// IP addresses for nodes and making sure the same +// address is not handed out twice. There can only be one +// and it needs to be created before any other database +// writes occur. 
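As a rough illustration of how the allocator defined below is meant to be used, here is a minimal sketch that constructs one and asks it for the first address pair. The prefixes are example values, and, as the unit tests later in this patch show, a nil *HSDatabase is accepted when there is no existing state to load.

```go
package main

import (
	"fmt"
	"net/netip"

	"github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/types"
)

func main() {
	p4 := netip.MustParsePrefix("100.64.0.0/10")
	p6 := netip.MustParsePrefix("fd7a:115c:a1e0::/48")

	// nil database: the allocator starts from an empty used-IP set.
	alloc, err := db.NewIPAllocator(nil, &p4, &p6, types.IPAllocationStrategySequential)
	if err != nil {
		panic(err)
	}

	// Next returns one address per configured family; a family left nil in the
	// constructor yields a nil pointer here.
	v4, v6, err := alloc.Next()
	if err != nil {
		panic(err)
	}
	fmt.Println(v4, v6) // 100.64.0.1 fd7a:115c:a1e0::1
}
```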
+type IPAllocator struct { + mu sync.Mutex + + prefix4 *netip.Prefix + prefix6 *netip.Prefix + + // Previous IPs handed out + prev4 netip.Addr + prev6 netip.Addr + + // strategy used for handing out IP addresses. + strategy types.IPAllocationStrategy + + // Set of all IPs handed out. + // This might not be in sync with the database, + // but it is more conservative. If saves to the + // database fails, the IP will be allocated here + // until the next restart of Headscale. + usedIPs netipx.IPSetBuilder +} + +// NewIPAllocator returns a new IPAllocator singleton which +// can be used to hand out unique IP addresses within the +// provided IPv4 and IPv6 prefix. It needs to be created +// when headscale starts and needs to finish its read +// transaction before any writes to the database occur. +func NewIPAllocator( + db *HSDatabase, + prefix4, prefix6 *netip.Prefix, + strategy types.IPAllocationStrategy, +) (*IPAllocator, error) { + ret := IPAllocator{ + prefix4: prefix4, + prefix6: prefix6, + + strategy: strategy, + } + + var v4s []sql.NullString + var v6s []sql.NullString + + if db != nil { + err := db.Read(func(rx *gorm.DB) error { + return rx.Model(&types.Node{}).Pluck("ipv4", &v4s).Error + }) + if err != nil { + return nil, fmt.Errorf("reading IPv4 addresses from database: %w", err) + } + + err = db.Read(func(rx *gorm.DB) error { + return rx.Model(&types.Node{}).Pluck("ipv6", &v6s).Error + }) + if err != nil { + return nil, fmt.Errorf("reading IPv6 addresses from database: %w", err) + } + } + + var ips netipx.IPSetBuilder + + // Add network and broadcast addrs to used pool so they + // are not handed out to nodes. + if prefix4 != nil { + network4, broadcast4 := util.GetIPPrefixEndpoints(*prefix4) + ips.Add(network4) + ips.Add(broadcast4) + + // Use network as starting point, it will be used to call .Next() + // TODO(kradalby): Could potentially take all the IPs loaded from + // the database into account to start at a more "educated" location. + ret.prev4 = network4 + } + + if prefix6 != nil { + network6, broadcast6 := util.GetIPPrefixEndpoints(*prefix6) + ips.Add(network6) + ips.Add(broadcast6) + + ret.prev6 = network6 + } + + // Fetch all the IP Addresses currently handed out from the Database + // and add them to the used IP set. + for _, addrStr := range append(v4s, v6s...) { + if addrStr.Valid { + addr, err := netip.ParseAddr(addrStr.String) + if err != nil { + return nil, fmt.Errorf("parsing IP address from database: %w", err) + } + + ips.Add(addr) + } + } + + // Build the initial IPSet to validate that we can use it. 
+ _, err := ips.IPSet() + if err != nil { + return nil, fmt.Errorf( + "building initial IP Set: %w", + err, + ) + } + + ret.usedIPs = ips + + return &ret, nil +} + +func (i *IPAllocator) Next() (*netip.Addr, *netip.Addr, error) { + i.mu.Lock() + defer i.mu.Unlock() + + var err error + var ret4 *netip.Addr + var ret6 *netip.Addr + + if i.prefix4 != nil { + ret4, err = i.next(i.prev4, i.prefix4) + if err != nil { + return nil, nil, fmt.Errorf("allocating IPv4 address: %w", err) + } + i.prev4 = *ret4 + } + + if i.prefix6 != nil { + ret6, err = i.next(i.prev6, i.prefix6) + if err != nil { + return nil, nil, fmt.Errorf("allocating IPv6 address: %w", err) + } + i.prev6 = *ret6 + } + + return ret4, ret6, nil +} + +var ErrCouldNotAllocateIP = errors.New("failed to allocate IP") + +func (i *IPAllocator) nextLocked(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) { + i.mu.Lock() + defer i.mu.Unlock() + + return i.next(prev, prefix) +} + +func (i *IPAllocator) next(prev netip.Addr, prefix *netip.Prefix) (*netip.Addr, error) { + var err error + var ip netip.Addr + + switch i.strategy { + case types.IPAllocationStrategySequential: + // Get the first IP in our prefix + ip = prev.Next() + case types.IPAllocationStrategyRandom: + ip, err = randomNext(*prefix) + if err != nil { + return nil, fmt.Errorf("getting random IP: %w", err) + } + } + + // TODO(kradalby): maybe this can be done less often. + set, err := i.usedIPs.IPSet() + if err != nil { + return nil, err + } + + for { + if !prefix.Contains(ip) { + return nil, ErrCouldNotAllocateIP + } + + // Check if the IP has already been allocated + // or if it is a IP reserved by Tailscale. + if set.Contains(ip) || isTailscaleReservedIP(ip) { + switch i.strategy { + case types.IPAllocationStrategySequential: + ip = ip.Next() + case types.IPAllocationStrategyRandom: + ip, err = randomNext(*prefix) + if err != nil { + return nil, fmt.Errorf("getting random IP: %w", err) + } + } + + continue + } + + i.usedIPs.Add(ip) + + return &ip, nil + } +} + +func randomNext(pfx netip.Prefix) (netip.Addr, error) { + rang := netipx.RangeOfPrefix(pfx) + fromIP, toIP := rang.From(), rang.To() + + var from, to big.Int + + from.SetBytes(fromIP.AsSlice()) + to.SetBytes(toIP.AsSlice()) + + // Find the max, this is how we can do "random range", + // get the "max" as 0 -> to - from and then add back from + // after. + tempMax := big.NewInt(0).Sub(&to, &from) + + out, err := rand.Int(rand.Reader, tempMax) + if err != nil { + return netip.Addr{}, fmt.Errorf("generating random IP: %w", err) + } + + valInRange := big.NewInt(0).Add(&from, out) + + ip, ok := netip.AddrFromSlice(valInRange.Bytes()) + if !ok { + return netip.Addr{}, errGeneratedIPBytesInvalid + } + + if !pfx.Contains(ip) { + return netip.Addr{}, fmt.Errorf( + "generated ip(%s) not in prefix(%s)", + ip.String(), + pfx.String(), + ) + } + + return ip, nil +} + +func isTailscaleReservedIP(ip netip.Addr) bool { + return tsaddr.ChromeOSVMRange().Contains(ip) || + tsaddr.TailscaleServiceIP() == ip || + tsaddr.TailscaleServiceIPv6() == ip +} + +// BackfillNodeIPs will take a database transaction, and +// iterate through all of the current nodes in headscale +// and ensure it has IP addresses according to the current +// configuration. +// This means that if both IPv4 and IPv6 is set in the +// config, and some nodes are missing that type of IP, +// it will be added. +// If a prefix type has been removed (IPv4 or IPv6), it +// will remove the IPs in that family from the node. 
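To make the intended call pattern concrete, here is a small, hypothetical helper showing how the backfill defined just below could be driven, for example from a maintenance command. The package and function names are illustrative; only NewIPAllocator, BackfillNodeIPs, and the returned change descriptions come from the code in this patch.

```go
// Hypothetical helper package; only the db-package calls below are real.
package maintenance

import (
	"fmt"
	"net/netip"

	"github.com/juanfont/headscale/hscontrol/db"
	"github.com/juanfont/headscale/hscontrol/types"
)

// backfillIPs assumes hsdb is an already-open database and p4/p6 come from the
// current prefix configuration (either may be nil to drop that address family).
func backfillIPs(hsdb *db.HSDatabase, p4, p6 *netip.Prefix) error {
	// The allocator is created first so it has registered every address that
	// is already handed out before any new ones are assigned.
	alloc, err := db.NewIPAllocator(hsdb, p4, p6, types.IPAllocationStrategySequential)
	if err != nil {
		return fmt.Errorf("creating IP allocator: %w", err)
	}

	changes, err := hsdb.BackfillNodeIPs(alloc)
	if err != nil {
		return fmt.Errorf("backfilling node IPs: %w", err)
	}

	// BackfillNodeIPs returns one human-readable line per change it made.
	for _, change := range changes {
		fmt.Println(change)
	}

	return nil
}
```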
+func (db *HSDatabase) BackfillNodeIPs(i *IPAllocator) ([]string, error) { + var err error + var ret []string + err = db.Write(func(tx *gorm.DB) error { + if i == nil { + return errors.New("backfilling IPs: ip allocator was nil") + } + + log.Trace().Caller().Msgf("starting to backfill IPs") + + nodes, err := ListNodes(tx) + if err != nil { + return fmt.Errorf("listing nodes to backfill IPs: %w", err) + } + + for _, node := range nodes { + log.Trace().Caller().Uint64("node.id", node.ID.Uint64()).Str("node.name", node.Hostname).Msg("IP backfill check started because node found in database") + + changed := false + // IPv4 prefix is set, but node ip is missing, alloc + if i.prefix4 != nil && node.IPv4 == nil { + ret4, err := i.nextLocked(i.prev4, i.prefix4) + if err != nil { + return fmt.Errorf("failed to allocate ipv4 for node(%d): %w", node.ID, err) + } + + node.IPv4 = ret4 + changed = true + ret = append(ret, fmt.Sprintf("assigned IPv4 %q to Node(%d) %q", ret4.String(), node.ID, node.Hostname)) + } + + // IPv6 prefix is set, but node ip is missing, alloc + if i.prefix6 != nil && node.IPv6 == nil { + ret6, err := i.nextLocked(i.prev6, i.prefix6) + if err != nil { + return fmt.Errorf("failed to allocate ipv6 for node(%d): %w", node.ID, err) + } + + node.IPv6 = ret6 + changed = true + ret = append(ret, fmt.Sprintf("assigned IPv6 %q to Node(%d) %q", ret6.String(), node.ID, node.Hostname)) + } + + // IPv4 prefix is not set, but node has IP, remove + if i.prefix4 == nil && node.IPv4 != nil { + ret = append(ret, fmt.Sprintf("removing IPv4 %q from Node(%d) %q", node.IPv4.String(), node.ID, node.Hostname)) + node.IPv4 = nil + changed = true + } + + // IPv6 prefix is not set, but node has IP, remove + if i.prefix6 == nil && node.IPv6 != nil { + ret = append(ret, fmt.Sprintf("removing IPv6 %q from Node(%d) %q", node.IPv6.String(), node.ID, node.Hostname)) + node.IPv6 = nil + changed = true + } + + if changed { + // Use Updates() with Select() to only update IP fields, avoiding overwriting + // other fields like Expiry. We need Select() because Updates() alone skips + // zero values, but we DO want to update IPv4/IPv6 to nil when removing them. + // See issue #2862. 
+ err := tx.Model(node).Select("ipv4", "ipv6").Updates(node).Error + if err != nil { + return fmt.Errorf("saving node(%d) after adding IPs: %w", node.ID, err) + } + } + } + + return nil + }) + + return ret, err +} + +func (i *IPAllocator) FreeIPs(ips []netip.Addr) { + i.mu.Lock() + defer i.mu.Unlock() + + for _, ip := range ips { + i.usedIPs.Remove(ip) + } +} diff --git a/hscontrol/db/ip_test.go b/hscontrol/db/ip_test.go new file mode 100644 index 00000000..7ba335e8 --- /dev/null +++ b/hscontrol/db/ip_test.go @@ -0,0 +1,514 @@ +package db + +import ( + "fmt" + "net/netip" + "strings" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/net/tsaddr" + "tailscale.com/types/ptr" +) + +var mpp = func(pref string) *netip.Prefix { + p := netip.MustParsePrefix(pref) + return &p +} + +var na = func(pref string) netip.Addr { + return netip.MustParseAddr(pref) +} + +var nap = func(pref string) *netip.Addr { + n := na(pref) + return &n +} + +func TestIPAllocatorSequential(t *testing.T) { + tests := []struct { + name string + dbFunc func() *HSDatabase + + prefix4 *netip.Prefix + prefix6 *netip.Prefix + getCount int + want4 []netip.Addr + want6 []netip.Addr + }{ + { + name: "simple", + dbFunc: func() *HSDatabase { + return nil + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + getCount: 1, + + want4: []netip.Addr{ + na("100.64.0.1"), + }, + want6: []netip.Addr{ + na("fd7a:115c:a1e0::1"), + }, + }, + { + name: "simple-v4", + dbFunc: func() *HSDatabase { + return nil + }, + + prefix4: mpp("100.64.0.0/10"), + + getCount: 1, + + want4: []netip.Addr{ + na("100.64.0.1"), + }, + }, + { + name: "simple-v6", + dbFunc: func() *HSDatabase { + return nil + }, + + prefix6: mpp("fd7a:115c:a1e0::/48"), + + getCount: 1, + + want6: []netip.Addr{ + na("fd7a:115c:a1e0::1"), + }, + }, + { + name: "simple-with-db", + dbFunc: func() *HSDatabase { + db := dbForTest(t) + user := types.User{Name: ""} + db.DB.Save(&user) + + db.DB.Save(&types.Node{ + User: &user, + IPv4: nap("100.64.0.1"), + IPv6: nap("fd7a:115c:a1e0::1"), + }) + + return db + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + getCount: 1, + + want4: []netip.Addr{ + na("100.64.0.2"), + }, + want6: []netip.Addr{ + na("fd7a:115c:a1e0::2"), + }, + }, + { + name: "before-after-free-middle-in-db", + dbFunc: func() *HSDatabase { + db := dbForTest(t) + user := types.User{Name: ""} + db.DB.Save(&user) + + db.DB.Save(&types.Node{ + User: &user, + IPv4: nap("100.64.0.2"), + IPv6: nap("fd7a:115c:a1e0::2"), + }) + + return db + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + getCount: 2, + + want4: []netip.Addr{ + na("100.64.0.1"), + na("100.64.0.3"), + }, + want6: []netip.Addr{ + na("fd7a:115c:a1e0::1"), + na("fd7a:115c:a1e0::3"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := tt.dbFunc() + + alloc, _ := NewIPAllocator( + db, + tt.prefix4, + tt.prefix6, + types.IPAllocationStrategySequential, + ) + + var got4s []netip.Addr + var got6s []netip.Addr + + for range tt.getCount { + got4, got6, err := alloc.Next() + if err != nil { + t.Fatalf("allocating next IP: %s", err) + } + + if got4 != nil { + got4s = append(got4s, *got4) + } + + if got6 != nil { + got6s = append(got6s, *got6) + } + } + if diff := 
cmp.Diff(tt.want4, got4s, util.Comparers...); diff != "" { + t.Errorf("IPAllocator 4s unexpected result (-want +got):\n%s", diff) + } + + if diff := cmp.Diff(tt.want6, got6s, util.Comparers...); diff != "" { + t.Errorf("IPAllocator 6s unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +func TestIPAllocatorRandom(t *testing.T) { + tests := []struct { + name string + dbFunc func() *HSDatabase + + getCount int + + prefix4 *netip.Prefix + prefix6 *netip.Prefix + want4 bool + want6 bool + }{ + { + name: "simple", + dbFunc: func() *HSDatabase { + return nil + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + getCount: 1, + + want4: true, + want6: true, + }, + { + name: "simple-v4", + dbFunc: func() *HSDatabase { + return nil + }, + + prefix4: mpp("100.64.0.0/10"), + + getCount: 1, + + want4: true, + want6: false, + }, + { + name: "simple-v6", + dbFunc: func() *HSDatabase { + return nil + }, + + prefix6: mpp("fd7a:115c:a1e0::/48"), + + getCount: 1, + + want4: false, + want6: true, + }, + { + name: "generate-lots-of-random", + dbFunc: func() *HSDatabase { + return nil + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + getCount: 1000, + + want4: true, + want6: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := tt.dbFunc() + + alloc, _ := NewIPAllocator(db, tt.prefix4, tt.prefix6, types.IPAllocationStrategyRandom) + + for range tt.getCount { + got4, got6, err := alloc.Next() + if err != nil { + t.Fatalf("allocating next IP: %s", err) + } + + t.Logf("addrs ipv4: %v, ipv6: %v", got4, got6) + + if tt.want4 { + if got4 == nil { + t.Fatalf("expected ipv4 addr, got nil") + } + } + + if tt.want6 { + if got6 == nil { + t.Fatalf("expected ipv4 addr, got nil") + } + } + } + }) + } +} + +func TestBackfillIPAddresses(t *testing.T) { + fullNodeP := func(i int) *types.Node { + v4 := fmt.Sprintf("100.64.0.%d", i) + v6 := fmt.Sprintf("fd7a:115c:a1e0::%d", i) + return &types.Node{ + IPv4: nap(v4), + IPv6: nap(v6), + } + } + tests := []struct { + name string + dbFunc func() *HSDatabase + + prefix4 *netip.Prefix + prefix6 *netip.Prefix + want types.Nodes + }{ + { + name: "simple-backfill-ipv6", + dbFunc: func() *HSDatabase { + db := dbForTest(t) + user := types.User{Name: ""} + db.DB.Save(&user) + + db.DB.Save(&types.Node{ + User: &user, + IPv4: nap("100.64.0.1"), + }) + + return db + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + want: types.Nodes{ + &types.Node{ + IPv4: nap("100.64.0.1"), + IPv6: nap("fd7a:115c:a1e0::1"), + }, + }, + }, + { + name: "simple-backfill-ipv4", + dbFunc: func() *HSDatabase { + db := dbForTest(t) + user := types.User{Name: ""} + db.DB.Save(&user) + + db.DB.Save(&types.Node{ + User: &user, + IPv6: nap("fd7a:115c:a1e0::1"), + }) + + return db + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + want: types.Nodes{ + &types.Node{ + IPv4: nap("100.64.0.1"), + IPv6: nap("fd7a:115c:a1e0::1"), + }, + }, + }, + { + name: "simple-backfill-remove-ipv6", + dbFunc: func() *HSDatabase { + db := dbForTest(t) + user := types.User{Name: ""} + db.DB.Save(&user) + + db.DB.Save(&types.Node{ + User: &user, + IPv4: nap("100.64.0.1"), + IPv6: nap("fd7a:115c:a1e0::1"), + }) + + return db + }, + + prefix4: mpp("100.64.0.0/10"), + + want: types.Nodes{ + &types.Node{ + IPv4: nap("100.64.0.1"), + }, + }, + }, + { + name: "simple-backfill-remove-ipv4", + dbFunc: func() *HSDatabase { + db := dbForTest(t) + user := types.User{Name: ""} + 
db.DB.Save(&user) + + db.DB.Save(&types.Node{ + User: &user, + IPv4: nap("100.64.0.1"), + IPv6: nap("fd7a:115c:a1e0::1"), + }) + + return db + }, + + prefix6: mpp("fd7a:115c:a1e0::/48"), + + want: types.Nodes{ + &types.Node{ + IPv6: nap("fd7a:115c:a1e0::1"), + }, + }, + }, + { + name: "multi-backfill-ipv6", + dbFunc: func() *HSDatabase { + db := dbForTest(t) + user := types.User{Name: ""} + db.DB.Save(&user) + + db.DB.Save(&types.Node{ + User: &user, + IPv4: nap("100.64.0.1"), + }) + db.DB.Save(&types.Node{ + User: &user, + IPv4: nap("100.64.0.2"), + }) + db.DB.Save(&types.Node{ + User: &user, + IPv4: nap("100.64.0.3"), + }) + db.DB.Save(&types.Node{ + User: &user, + IPv4: nap("100.64.0.4"), + }) + + return db + }, + + prefix4: mpp("100.64.0.0/10"), + prefix6: mpp("fd7a:115c:a1e0::/48"), + + want: types.Nodes{ + fullNodeP(1), + fullNodeP(2), + fullNodeP(3), + fullNodeP(4), + }, + }, + } + + comps := append(util.Comparers, cmpopts.IgnoreFields(types.Node{}, + "ID", + "User", + "UserID", + "Endpoints", + "Hostinfo", + "CreatedAt", + "UpdatedAt", + )) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db := tt.dbFunc() + + alloc, err := NewIPAllocator( + db, + tt.prefix4, + tt.prefix6, + types.IPAllocationStrategySequential, + ) + if err != nil { + t.Fatalf("failed to set up ip alloc: %s", err) + } + + logs, err := db.BackfillNodeIPs(alloc) + if err != nil { + t.Fatalf("failed to backfill: %s", err) + } + + t.Logf("backfill log: \n%s", strings.Join(logs, "\n")) + + got, err := db.ListNodes() + if err != nil { + t.Fatalf("failed to get nodes: %s", err) + } + + if diff := cmp.Diff(tt.want, got, comps...); diff != "" { + t.Errorf("Backfill unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +func TestIPAllocatorNextNoReservedIPs(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + defer db.Close() + + alloc, err := NewIPAllocator( + db, + ptr.To(tsaddr.CGNATRange()), + ptr.To(tsaddr.TailscaleULARange()), + types.IPAllocationStrategySequential, + ) + if err != nil { + t.Fatalf("failed to set up ip alloc: %s", err) + } + + // Validate that we do not give out 100.100.100.100 + nextQuad100, err := alloc.next(na("100.100.100.99"), ptr.To(tsaddr.CGNATRange())) + require.NoError(t, err) + assert.Equal(t, na("100.100.100.101"), *nextQuad100) + + // Validate that we do not give out fd7a:115c:a1e0::53 + nextQuad100v6, err := alloc.next(na("fd7a:115c:a1e0::52"), ptr.To(tsaddr.TailscaleULARange())) + require.NoError(t, err) + assert.Equal(t, na("fd7a:115c:a1e0::54"), *nextQuad100v6) + + // Validate that we do not give out fd7a:115c:a1e0::53 + nextChrome, err := alloc.next(na("100.115.91.255"), ptr.To(tsaddr.CGNATRange())) + t.Logf("chrome: %s", nextChrome.String()) + require.NoError(t, err) + assert.Equal(t, na("100.115.94.0"), *nextChrome) +} diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index ac0e0b38..bf407bb4 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -1,20 +1,26 @@ package db import ( + "encoding/json" "errors" "fmt" "net/netip" + "regexp" + "slices" "sort" + "strconv" "strings" + "sync" + "testing" "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "github.com/patrickmn/go-cache" "github.com/rs/zerolog/log" "gorm.io/gorm" - "tailscale.com/tailcfg" + "tailscale.com/net/tsaddr" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) const ( @@ -22,6 +28,8 @@ const ( NodeGivenNameTrimSize = 2 ) +var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") + var ( 
ErrNodeNotFound = errors.New("node not found") ErrNodeRouteIsNotAvailable = errors.New("route is not available on node") @@ -29,33 +37,26 @@ var ( "node not found in registration cache", ) ErrCouldNotConvertNodeInterface = errors.New("failed to convert node interface") - ErrDifferentRegisteredUser = errors.New( - "node was previously registered with a different user", - ) ) -// ListPeers returns all peers of node, regardless of any Policy or if the node is expired. -func (hsdb *HSDatabase) ListPeers(node *types.Node) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listPeers(node) +// ListPeers returns peers of node, regardless of any Policy or if the node is expired. +// If no peer IDs are given, all peers are returned. +// If at least one peer ID is given, only these peer nodes will be returned. +func (hsdb *HSDatabase) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { + return ListPeers(hsdb.DB, nodeID, peerIDs...) } -func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { - log.Trace(). - Caller(). - Str("node", node.Hostname). - Msg("Finding direct peers") - +// ListPeers returns peers of node, regardless of any Policy or if the node is expired. +// If no peer IDs are given, all peers are returned. +// If at least one peer ID is given, only these peer nodes will be returned. +func ListPeers(tx *gorm.DB, nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Preload("Routes"). - Where("node_key <> ?", - node.NodeKey.String()).Find(&nodes).Error; err != nil { + Where("id <> ?", nodeID). + Where(peerIDs).Find(&nodes).Error; err != nil { return types.Nodes{}, err } @@ -64,54 +65,47 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { return nodes, nil } -func (hsdb *HSDatabase) ListNodes() ([]types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodes() +// ListNodes queries the database for either all nodes if no parameters are given +// or for the given nodes if at least one node ID is given as parameter. +func (hsdb *HSDatabase) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { + return ListNodes(hsdb.DB, nodeIDs...) } -func (hsdb *HSDatabase) listNodes() ([]types.Node, error) { - nodes := []types.Node{} - if err := hsdb.db. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Find(&nodes).Error; err != nil { - return nil, err - } - - return nodes, nil -} - -func (hsdb *HSDatabase) ListNodesByGivenName(givenName string) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listNodesByGivenName(givenName) -} - -func (hsdb *HSDatabase) listNodesByGivenName(givenName string) (types.Nodes, error) { +// ListNodes queries the database for either all nodes if no parameters are given +// or for the given nodes if at least one node ID is given as parameter. +func ListNodes(tx *gorm.DB, nodeIDs ...types.NodeID) (types.Nodes, error) { nodes := types.Nodes{} - if err := hsdb.db. + if err := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Preload("Routes"). - Where("given_name = ?", givenName).Find(&nodes).Error; err != nil { + Where(nodeIDs).Find(&nodes).Error; err != nil { return nil, err } return nodes, nil } -// GetNode finds a Node by name and user and returns the Node struct. 
-func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() +func (hsdb *HSDatabase) ListEphemeralNodes() (types.Nodes, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (types.Nodes, error) { + nodes := types.Nodes{} + if err := rx.Joins("AuthKey").Where(`"AuthKey"."ephemeral" = true`).Find(&nodes).Error; err != nil { + return nil, err + } - nodes, err := hsdb.ListNodesByUser(user) + return nodes, nil + }) +} + +func (hsdb *HSDatabase) getNode(uid types.UserID, name string) (*types.Node, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.Node, error) { + return getNode(rx, uid, name) + }) +} + +// getNode finds a Node by name and user and returns the Node struct. +func getNode(tx *gorm.DB, uid types.UserID, name string) (*types.Node, error) { + nodes, err := ListNodesByUser(tx, uid) if err != nil { return nil, err } @@ -125,38 +119,17 @@ func (hsdb *HSDatabase) GetNode(user string, name string) (*types.Node, error) { return nil, ErrNodeNotFound } -// GetNodeByGivenName finds a Node by given name and user and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByGivenName( - user string, - givenName string, -) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - node := types.Node{} - if err := hsdb.db. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - Where("given_name = ?", givenName).First(&node).Error; err != nil { - return nil, err - } - - return nil, ErrNodeNotFound +func (hsdb *HSDatabase) GetNodeByID(id types.NodeID) (*types.Node, error) { + return GetNodeByID(hsdb.DB, id) } // GetNodeByID finds a Node by ID and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - +func GetNodeByID(tx *gorm.DB, id types.NodeID) (*types.Node, error) { mach := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Preload("Routes"). Find(&types.Node{ID: id}).First(&mach); result.Error != nil { return nil, result.Error } @@ -164,25 +137,20 @@ func (hsdb *HSDatabase) GetNodeByID(id uint64) (*types.Node, error) { return &mach, nil } -// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByMachineKey( - machineKey key.MachinePublic, -) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeByMachineKey(machineKey) +func (hsdb *HSDatabase) GetNodeByMachineKey(machineKey key.MachinePublic) (*types.Node, error) { + return GetNodeByMachineKey(hsdb.DB, machineKey) } -func (hsdb *HSDatabase) getNodeByMachineKey( +// GetNodeByMachineKey finds a Node by its MachineKey and returns the Node struct. +func GetNodeByMachineKey( + tx *gorm.DB, machineKey key.MachinePublic, ) (*types.Node, error) { mach := types.Node{} - if result := hsdb.db. + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Preload("Routes"). First(&mach, "machine_key = ?", machineKey.String()); result.Error != nil { return nil, result.Error } @@ -190,292 +158,237 @@ func (hsdb *HSDatabase) getNodeByMachineKey( return &mach, nil } -// GetNodeByNodeKey finds a Node by its current NodeKey. 
-func (hsdb *HSDatabase) GetNodeByNodeKey( +func (hsdb *HSDatabase) GetNodeByNodeKey(nodeKey key.NodePublic) (*types.Node, error) { + return GetNodeByNodeKey(hsdb.DB, nodeKey) +} + +// GetNodeByNodeKey finds a Node by its NodeKey and returns the Node struct. +func GetNodeByNodeKey( + tx *gorm.DB, nodeKey key.NodePublic, ) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - node := types.Node{} - if result := hsdb.db. + mach := types.Node{} + if result := tx. Preload("AuthKey"). Preload("AuthKey.User"). Preload("User"). - Preload("Routes"). - First(&node, "node_key = ?", - nodeKey.String()); result.Error != nil { + First(&mach, "node_key = ?", nodeKey.String()); result.Error != nil { return nil, result.Error } - return &node, nil + return &mach, nil } -// GetNodeByAnyKey finds a Node by its MachineKey, its current NodeKey or the old one, and returns the Node struct. -func (hsdb *HSDatabase) GetNodeByAnyKey( - machineKey key.MachinePublic, nodeKey key.NodePublic, oldNodeKey key.NodePublic, -) (*types.Node, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() +func (hsdb *HSDatabase) SetTags( + nodeID types.NodeID, + tags []string, +) error { + return hsdb.Write(func(tx *gorm.DB) error { + return SetTags(tx, nodeID, tags) + }) +} - node := types.Node{} - if result := hsdb.db. - Preload("AuthKey"). - Preload("AuthKey.User"). - Preload("User"). - Preload("Routes"). - First(&node, "machine_key = ? OR node_key = ? OR node_key = ?", - machineKey.String(), - nodeKey.String(), - oldNodeKey.String()); result.Error != nil { - return nil, result.Error +// SetTags takes a NodeID and update the forced tags. +// It will overwrite any tags with the new list. +func SetTags( + tx *gorm.DB, + nodeID types.NodeID, + tags []string, +) error { + if len(tags) == 0 { + // if no tags are provided, we remove all tags + err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", "[]").Error + if err != nil { + return fmt.Errorf("removing tags: %w", err) + } + + return nil } - return &node, nil -} + slices.Sort(tags) + tags = slices.Compact(tags) + b, err := json.Marshal(tags) + if err != nil { + return err + } -func (hsdb *HSDatabase) NodeReloadFromDatabase(node *types.Node) error { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - if result := hsdb.db.Find(node).First(&node); result.Error != nil { - return result.Error + err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", string(b)).Error + if err != nil { + return fmt.Errorf("updating tags: %w", err) } return nil } // SetTags takes a Node struct pointer and update the forced tags. 
-func (hsdb *HSDatabase) SetTags(
-	node *types.Node,
-	tags []string,
+func SetApprovedRoutes(
+	tx *gorm.DB,
+	nodeID types.NodeID,
+	routes []netip.Prefix,
 ) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
+	if len(routes) == 0 {
+		// if no routes are provided, we remove all approved routes
+		if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", "[]").Error; err != nil {
+			return fmt.Errorf("removing approved routes: %w", err)
+		}

-	if len(tags) == 0 {
 		return nil
 	}

-	newTags := []string{}
-	for _, tag := range tags {
-		if !util.StringOrPrefixListContains(newTags, tag) {
-			newTags = append(newTags, tag)
-		}
+	// When approving exit routes, ensure both IPv4 and IPv6 are included
+	// If either 0.0.0.0/0 or ::/0 is being approved, both should be approved
+	hasIPv4Exit := slices.Contains(routes, tsaddr.AllIPv4())
+	hasIPv6Exit := slices.Contains(routes, tsaddr.AllIPv6())
+
+	if hasIPv4Exit && !hasIPv6Exit {
+		routes = append(routes, tsaddr.AllIPv6())
+	} else if hasIPv6Exit && !hasIPv4Exit {
+		routes = append(routes, tsaddr.AllIPv4())
 	}

-	if err := hsdb.db.Model(node).Updates(types.Node{
-		ForcedTags: newTags,
-	}).Error; err != nil {
-		return fmt.Errorf("failed to update tags for node in the database: %w", err)
+	b, err := json.Marshal(routes)
+	if err != nil {
+		return err
 	}

-	stateUpdate := types.StateUpdate{
-		Type: types.StatePeerChanged,
-		ChangeNodes: types.Nodes{node},
-		Message: "called from db.SetTags",
-	}
-	if stateUpdate.Valid() {
-		hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String())
+	if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("approved_routes", string(b)).Error; err != nil {
+		return fmt.Errorf("updating approved routes: %w", err)
 	}

 	return nil
 }

+// SetLastSeen sets a node's last seen field indicating that we
+// have recently communicated with this node.
+func (hsdb *HSDatabase) SetLastSeen(nodeID types.NodeID, lastSeen time.Time) error {
+	return hsdb.Write(func(tx *gorm.DB) error {
+		return SetLastSeen(tx, nodeID, lastSeen)
+	})
+}
+
+// SetLastSeen sets a node's last seen field indicating that we
+// have recently communicated with this node.
+func SetLastSeen(tx *gorm.DB, nodeID types.NodeID, lastSeen time.Time) error {
+	return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("last_seen", lastSeen).Error
+}
+
 // RenameNode takes a Node struct and a new GivenName for the nodes
-// and renames it.
-func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
-	err := util.CheckForFQDNRules(
-		newName,
-	)
-	if err != nil {
-		log.Error().
-			Caller().
-			Str("func", "RenameNode").
-			Str("node", node.Hostname).
-			Str("newName", newName).
-			Err(err).
-			Msg("failed to rename node")
-
-		return err
+// and renames it. Validation should be done in the state layer before calling this function.
+func RenameNode(tx *gorm.DB,
+	nodeID types.NodeID, newName string,
+) error {
+	if err := util.ValidateHostname(newName); err != nil {
+		return fmt.Errorf("renaming node: %w", err)
 	}

-	node.GivenName = newName
-	if err := hsdb.db.Model(node).Updates(types.Node{
-		GivenName: newName,
-	}).Error; err != nil {
+	// Check if the new name is unique
+	var count int64
+	if err := tx.Model(&types.Node{}).Where("given_name = ? 
AND id != ?", newName, nodeID).Count(&count).Error; err != nil { + return fmt.Errorf("failed to check name uniqueness: %w", err) + } + + if count > 0 { + return errors.New("name is not unique") + } + + if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("given_name", newName).Error; err != nil { return fmt.Errorf("failed to rename node in the database: %w", err) } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from db.RenameNode", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) - } - return nil } +func (hsdb *HSDatabase) NodeSetExpiry(nodeID types.NodeID, expiry time.Time) error { + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetExpiry(tx, nodeID, expiry) + }) +} + // NodeSetExpiry takes a Node struct and a new expiry time. -func (hsdb *HSDatabase) NodeSetExpiry(node *types.Node, expiry time.Time) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.nodeSetExpiry(node, expiry) +func NodeSetExpiry(tx *gorm.DB, + nodeID types.NodeID, expiry time.Time, +) error { + return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error } -func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error { - if err := hsdb.db.Model(node).Updates(types.Node{ - Expiry: &expiry, - }).Error; err != nil { - return fmt.Errorf( - "failed to refresh node (update expiration) in the database: %w", - err, - ) - } - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: &expiry, - }, - }, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - - return nil +func (hsdb *HSDatabase) DeleteNode(node *types.Node) error { + return hsdb.Write(func(tx *gorm.DB) error { + return DeleteNode(tx, node) + }) } // DeleteNode deletes a Node from the database. -func (hsdb *HSDatabase) DeleteNode(node *types.Node) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.deleteNode(node) -} - -func (hsdb *HSDatabase) deleteNode(node *types.Node) error { - err := hsdb.deleteNodeRoutes(node) - if err != nil { - return err - } - +// Caller is responsible for notifying all of change. +func DeleteNode(tx *gorm.DB, + node *types.Node, +) error { // Unscoped causes the node to be fully removed from the database. - if err := hsdb.db.Unscoped().Delete(&node).Error; err != nil { + if err := tx.Unscoped().Delete(&types.Node{}, node.ID).Error; err != nil { return err } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - return nil } -// UpdateLastSeen sets a node's last seen field indicating that we -// have recently communicating with this node. -// This is mostly used to indicate if a node is online and is not -// extremely important to make sure is fully correct and to avoid -// holding up the hot path, does not contain any locks and isnt -// concurrency safe. But that should be ok. -func (hsdb *HSDatabase) UpdateLastSeen(node *types.Node) error { - return hsdb.db.Model(node).Updates(types.Node{ - LastSeen: node.LastSeen, - }).Error +// DeleteEphemeralNode deletes a Node from the database, note that this method +// will remove it straight, and not notify any changes or consider any routes. +// It is intended for Ephemeral nodes. 
+func (hsdb *HSDatabase) DeleteEphemeralNode(
+	nodeID types.NodeID,
+) error {
+	return hsdb.Write(func(tx *gorm.DB) error {
+		if err := tx.Unscoped().Delete(&types.Node{}, nodeID).Error; err != nil {
+			return err
+		}
+		return nil
+	})
 }

-func (hsdb *HSDatabase) RegisterNodeFromAuthCallback(
-	cache *cache.Cache,
-	mkey key.MachinePublic,
-	userName string,
-	nodeExpiry *time.Time,
-	registrationMethod string,
-) (*types.Node, error) {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
+// RegisterNodeForTest is used only for testing purposes to register a node directly in the database.
+// Production code should use state.HandleNodeFromAuthPath or state.HandleNodeFromPreAuthKey.
+func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *netip.Addr) (*types.Node, error) {
+	if !testing.Testing() {
+		panic("RegisterNodeForTest can only be called during tests")
+	}

-	log.Debug().
-		Str("machine_key", mkey.ShortString()).
-		Str("userName", userName).
-		Str("registrationMethod", registrationMethod).
-		Str("expiresAt", fmt.Sprintf("%v", nodeExpiry)).
-		Msg("Registering node from API/CLI or auth callback")
+	logEvent := log.Debug().
+		Str("node", node.Hostname).
+		Str("machine_key", node.MachineKey.ShortString()).
+		Str("node_key", node.NodeKey.ShortString())

-	if nodeInterface, ok := cache.Get(mkey.String()); ok {
-		if registrationNode, ok := nodeInterface.(types.Node); ok {
-			user, err := hsdb.getUser(userName)
-			if err != nil {
-				return nil, fmt.Errorf(
-					"failed to find user in register node from auth callback, %w",
-					err,
-				)
-			}
+	if node.User != nil {
+		logEvent = logEvent.Str("user", node.User.Username())
+	} else if node.UserID != nil {
+		logEvent = logEvent.Uint("user_id", *node.UserID)
+	} else {
+		logEvent = logEvent.Str("user", "none")
+	}

-			// Registration of expired node with different user
-			if registrationNode.ID != 0 &&
-				registrationNode.UserID != user.ID {
-				return nil, ErrDifferentRegisteredUser
-			}
+	logEvent.Msg("Registering test node")

-			registrationNode.UserID = user.ID
-			registrationNode.RegisterMethod = registrationMethod
-
-			if nodeExpiry != nil {
-				registrationNode.Expiry = nodeExpiry
-			}
-
-			node, err := hsdb.registerNode(
-				registrationNode,
-			)
-
-			if err == nil {
-				cache.Delete(mkey.String())
-			}
-
-			return node, err
-		} else {
-			return nil, ErrCouldNotConvertNodeInterface
+	// If a new node is registered with the same machine key, to the same user,
+	// update the existing node.
+	// If the same node is registered again, but to a new user, then that is considered
+	// a new node.
+	oldNode, _ := GetNodeByMachineKey(tx, node.MachineKey)
+	if oldNode != nil && oldNode.UserID == node.UserID {
+		node.ID = oldNode.ID
+		node.GivenName = oldNode.GivenName
+		node.ApprovedRoutes = oldNode.ApprovedRoutes
+		// Don't overwrite the provided IPs with old ones when they exist
+		if ipv4 == nil {
+			ipv4 = oldNode.IPv4
+		}
+		if ipv6 == nil {
+			ipv6 = oldNode.IPv6
 		}
 	}

-	return nil, ErrNodeNotFoundRegistrationCache
-}
-
-// RegisterNode is executed from the CLI to register a new Node using its MachineKey.
-func (hsdb *HSDatabase) RegisterNode(node types.Node) (*types.Node, error) {
-	hsdb.mu.Lock()
-	defer hsdb.mu.Unlock()
-
-	return hsdb.registerNode(node)
-}
-
-func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) {
-	log.Debug().
-		Str("node", node.Hostname).
-		Str("machine_key", node.MachineKey.ShortString()).
-		Str("node_key", node.NodeKey.ShortString()).
-		Str("user", node.User.Name). 
- Msg("Registering node") - - // If the node exists and we had already IPs for it, we just save it + // If the node exists and it already has IP(s), we just save it // so we store the node.Expire and node.Nodekey that has been set when // adding it to the registrationCache - if len(node.IPAddresses) > 0 { - if err := hsdb.db.Save(&node).Error; err != nil { + if node.IPv4 != nil || node.IPv6 != nil { + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register existing node in the database: %w", err) } @@ -484,267 +397,85 @@ func (hsdb *HSDatabase) registerNode(node types.Node) (*types.Node, error) { Str("node", node.Hostname). Str("machine_key", node.MachineKey.ShortString()). Str("node_key", node.NodeKey.ShortString()). - Str("user", node.User.Name). - Msg("Node authorized again") + Str("user", node.User.Username()). + Msg("Test node authorized again") return &node, nil } - hsdb.ipAllocationMutex.Lock() - defer hsdb.ipAllocationMutex.Unlock() + node.IPv4 = ipv4 + node.IPv6 = ipv6 - ips, err := hsdb.getAvailableIPs() + var err error + node.Hostname, err = util.NormaliseHostname(node.Hostname) if err != nil { - log.Error(). - Caller(). - Err(err). - Str("node", node.Hostname). - Msg("Could not find IP for the new node") - - return nil, err + newHostname := util.InvalidString() + log.Info().Err(err).Str("invalid-hostname", node.Hostname).Str("new-hostname", newHostname).Msgf("Invalid hostname, replacing") + node.Hostname = newHostname } - node.IPAddresses = ips + if node.GivenName == "" { + givenName, err := EnsureUniqueGivenName(tx, node.Hostname) + if err != nil { + return nil, fmt.Errorf("failed to ensure unique given name: %w", err) + } - if err := hsdb.db.Save(&node).Error; err != nil { + node.GivenName = givenName + } + + if err := tx.Save(&node).Error; err != nil { return nil, fmt.Errorf("failed register(save) node in the database: %w", err) } log.Trace(). Caller(). Str("node", node.Hostname). - Str("ip", strings.Join(ips.StringSlice(), ",")). - Msg("Node registered with the database") + Msg("Test node registered with the database") return &node, nil } // NodeSetNodeKey sets the node key of a node and saves it to the database. -func (hsdb *HSDatabase) NodeSetNodeKey(node *types.Node, nodeKey key.NodePublic) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if err := hsdb.db.Model(node).Updates(types.Node{ +func NodeSetNodeKey(tx *gorm.DB, node *types.Node, nodeKey key.NodePublic) error { + return tx.Model(node).Updates(types.Node{ NodeKey: nodeKey, - }).Error; err != nil { - return err - } - - return nil + }).Error } -// NodeSetMachineKey sets the node key of a node and saves it to the database. func (hsdb *HSDatabase) NodeSetMachineKey( node *types.Node, machineKey key.MachinePublic, ) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() + return hsdb.Write(func(tx *gorm.DB) error { + return NodeSetMachineKey(tx, node, machineKey) + }) +} - if err := hsdb.db.Model(node).Updates(types.Node{ +// NodeSetMachineKey sets the node key of a node and saves it to the database. +func NodeSetMachineKey( + tx *gorm.DB, + node *types.Node, + machineKey key.MachinePublic, +) error { + return tx.Model(node).Updates(types.Node{ MachineKey: machineKey, - }).Error; err != nil { - return err - } - - return nil -} - -// NodeSave saves a node object to the database, prefer to use a specific save method rather -// than this. It is intended to be used when we are changing or. 
-func (hsdb *HSDatabase) NodeSave(node *types.Node) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if err := hsdb.db.Save(node).Error; err != nil { - return err - } - - return nil -} - -// GetAdvertisedRoutes returns the routes that are be advertised by the given node. -func (hsdb *HSDatabase) GetAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getAdvertisedRoutes(node) -} - -func (hsdb *HSDatabase) getAdvertisedRoutes(node *types.Node) ([]netip.Prefix, error) { - routes := types.Routes{} - - err := hsdb.db. - Preload("Node"). - Where("node_id = ? AND advertised = ?", node.ID, true).Find(&routes).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - log.Error(). - Caller(). - Err(err). - Str("node", node.Hostname). - Msg("Could not get advertised routes for node") - - return nil, err - } - - prefixes := []netip.Prefix{} - for _, route := range routes { - prefixes = append(prefixes, netip.Prefix(route.Prefix)) - } - - return prefixes, nil -} - -// GetEnabledRoutes returns the routes that are enabled for the node. -func (hsdb *HSDatabase) GetEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getEnabledRoutes(node) -} - -func (hsdb *HSDatabase) getEnabledRoutes(node *types.Node) ([]netip.Prefix, error) { - routes := types.Routes{} - - err := hsdb.db. - Preload("Node"). - Where("node_id = ? AND advertised = ? AND enabled = ?", node.ID, true, true). - Find(&routes).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - log.Error(). - Caller(). - Err(err). - Str("node", node.Hostname). - Msg("Could not get enabled routes for node") - - return nil, err - } - - prefixes := []netip.Prefix{} - for _, route := range routes { - prefixes = append(prefixes, netip.Prefix(route.Prefix)) - } - - return prefixes, nil -} - -func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - route, err := netip.ParsePrefix(routeStr) - if err != nil { - return false - } - - enabledRoutes, err := hsdb.getEnabledRoutes(node) - if err != nil { - log.Error().Err(err).Msg("Could not get enabled routes") - - return false - } - - for _, enabledRoute := range enabledRoutes { - if route == enabledRoute { - return true - } - } - - return false -} - -// enableRoutes enables new routes based on a list of new routes. -func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error { - newRoutes := make([]netip.Prefix, len(routeStrs)) - for index, routeStr := range routeStrs { - route, err := netip.ParsePrefix(routeStr) - if err != nil { - return err - } - - newRoutes[index] = route - } - - advertisedRoutes, err := hsdb.getAdvertisedRoutes(node) - if err != nil { - return err - } - - for _, newRoute := range newRoutes { - if !util.StringOrPrefixListContains(advertisedRoutes, newRoute) { - return fmt.Errorf( - "route (%s) is not available on node %s: %w", - node.Hostname, - newRoute, ErrNodeRouteIsNotAvailable, - ) - } - } - - // Separate loop so we don't leave things in a half-updated state - for _, prefix := range newRoutes { - route := types.Route{} - err := hsdb.db.Preload("Node"). - Where("node_id = ? AND prefix = ?", node.ID, types.IPPrefix(prefix)). 
- First(&route).Error - if err == nil { - route.Enabled = true - - // Mark already as primary if there is only this node offering this subnet - // (and is not an exit route) - if !route.IsExitRoute() { - route.IsPrimary = hsdb.isUniquePrefix(route) - } - - err = hsdb.db.Save(&route).Error - if err != nil { - return fmt.Errorf("failed to enable route: %w", err) - } - } else { - return fmt.Errorf("failed to find route: %w", err) - } - } - - // Ensure the node has the latest routes when notifying the other - // nodes - nRoutes, err := hsdb.getNodeRoutes(node) - if err != nil { - return fmt.Errorf("failed to read back routes: %w", err) - } - - node.Routes = nRoutes - - log.Trace(). - Caller(). - Str("node", node.Hostname). - Strs("routes", routeStrs). - Msg("enabling routes") - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from db.enableRoutes", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyWithIgnore( - stateUpdate, node.MachineKey.String()) - } - - return nil + }).Error } func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { - normalizedHostname, err := util.NormalizeToFQDNRulesConfigFromViper( - suppliedName, - ) - if err != nil { - return "", err + // Strip invalid DNS characters for givenName + suppliedName = strings.ToLower(suppliedName) + suppliedName = invalidDNSRegex.ReplaceAllString(suppliedName, "") + + if len(suppliedName) > util.LabelHostnameLength { + return "", types.ErrHostnameTooLong } if randomSuffix { // Trim if a hostname will be longer than 63 chars after adding the hash. trimmedHostnameLength := util.LabelHostnameLength - NodeGivenNameHashLength - NodeGivenNameTrimSize - if len(normalizedHostname) > trimmedHostnameLength { - normalizedHostname = normalizedHostname[:trimmedHostnameLength] + if len(suppliedName) > trimmedHostnameLength { + suppliedName = suppliedName[:trimmedHostnameLength] } suffix, err := util.GenerateRandomStringDNSSafe(NodeGivenNameHashLength) @@ -752,39 +483,39 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { return "", err } - normalizedHostname += "-" + suffix + suppliedName += "-" + suffix } - return normalizedHostname, nil + return suppliedName, nil } -func (hsdb *HSDatabase) GenerateGivenName( - mkey key.MachinePublic, - suppliedName string, +func isUniqueName(tx *gorm.DB, name string) (bool, error) { + nodes := types.Nodes{} + if err := tx. + Where("given_name = ?", name).Find(&nodes).Error; err != nil { + return false, err + } + + return len(nodes) == 0, nil +} + +// EnsureUniqueGivenName generates a unique given name for a node based on its hostname. 
+func EnsureUniqueGivenName( + tx *gorm.DB, + name string, ) (string, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - givenName, err := generateGivenName(suppliedName, false) + givenName, err := generateGivenName(name, false) if err != nil { return "", err } - // Tailscale rules (may differ) https://tailscale.com/kb/1098/machine-names/ - nodes, err := hsdb.listNodesByGivenName(givenName) + unique, err := isUniqueName(tx, givenName) if err != nil { return "", err } - var nodeFound *types.Node - for idx, node := range nodes { - if node.GivenName == givenName { - nodeFound = nodes[idx] - } - } - - if nodeFound != nil && nodeFound.MachineKey.String() != mkey.String() { - postfixedName, err := generateGivenName(suppliedName, true) + if !unique { + postfixedName, err := generateGivenName(name, true) if err != nil { return "", err } @@ -795,116 +526,258 @@ func (hsdb *HSDatabase) GenerateGivenName( return givenName, nil } -func (hsdb *HSDatabase) ExpireEphemeralNodes(inactivityThreshhold time.Duration) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +// EphemeralGarbageCollector is a garbage collector that will delete nodes after +// a certain amount of time. +// It is used to delete ephemeral nodes that have disconnected and should be +// cleaned up. +type EphemeralGarbageCollector struct { + mu sync.Mutex - users, err := hsdb.listUsers() - if err != nil { - log.Error().Err(err).Msg("Error listing users") + deleteFunc func(types.NodeID) + toBeDeleted map[types.NodeID]*time.Timer - return + deleteCh chan types.NodeID + cancelCh chan struct{} +} + +// NewEphemeralGarbageCollector creates a new EphemeralGarbageCollector, it takes +// a deleteFunc that will be called when a node is scheduled for deletion. +func NewEphemeralGarbageCollector(deleteFunc func(types.NodeID)) *EphemeralGarbageCollector { + return &EphemeralGarbageCollector{ + toBeDeleted: make(map[types.NodeID]*time.Timer), + deleteCh: make(chan types.NodeID, 10), + cancelCh: make(chan struct{}), + deleteFunc: deleteFunc, + } +} + +// Close stops the garbage collector. +func (e *EphemeralGarbageCollector) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + // Stop all timers + for _, timer := range e.toBeDeleted { + timer.Stop() } - for _, user := range users { - nodes, err := hsdb.listNodesByUser(user.Name) - if err != nil { - log.Error(). - Err(err). - Str("user", user.Name). - Msg("Error listing nodes in user") + // Close the cancel channel to signal all goroutines to exit + close(e.cancelCh) +} +// Schedule schedules a node for deletion after the expiry duration. +// If the garbage collector is already closed, this is a no-op. +func (e *EphemeralGarbageCollector) Schedule(nodeID types.NodeID, expiry time.Duration) { + e.mu.Lock() + defer e.mu.Unlock() + + // Don't schedule new timers if the garbage collector is already closed + select { + case <-e.cancelCh: + // The cancel channel is closed, meaning the GC is shutting down + // or already shut down, so we shouldn't schedule anything new + return + default: + // Continue with scheduling + } + + // If a timer already exists for this node, stop it first + if oldTimer, exists := e.toBeDeleted[nodeID]; exists { + oldTimer.Stop() + } + + timer := time.NewTimer(expiry) + e.toBeDeleted[nodeID] = timer + // Start a goroutine to handle the timer completion + go func() { + select { + case <-timer.C: + // This is to handle the situation where the GC is shutting down and + // we are trying to schedule a new node for deletion at the same time + // i.e. 
We don't want to send to deleteCh if the GC is shutting down + // So, we try to send to deleteCh, but also watch for cancelCh + select { + case e.deleteCh <- nodeID: + // Successfully sent to deleteCh + case <-e.cancelCh: + // GC is shutting down, don't send to deleteCh + return + } + case <-e.cancelCh: + // If the GC is closed, exit the goroutine return } + }() +} - expired := make([]tailcfg.NodeID, 0) - for idx, node := range nodes { - if node.IsEphemeral() && node.LastSeen != nil && - time.Now(). - After(node.LastSeen.Add(inactivityThreshhold)) { - expired = append(expired, tailcfg.NodeID(node.ID)) +// Cancel cancels the deletion of a node. +func (e *EphemeralGarbageCollector) Cancel(nodeID types.NodeID) { + e.mu.Lock() + defer e.mu.Unlock() - log.Info(). - Str("node", node.Hostname). - Msg("Ephemeral client removed from database") + if timer, ok := e.toBeDeleted[nodeID]; ok { + timer.Stop() + delete(e.toBeDeleted, nodeID) + } +} - err = hsdb.deleteNode(nodes[idx]) - if err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Msg("🤮 Cannot delete ephemeral node from the database") - } - } - } +// Start starts the garbage collector. +func (e *EphemeralGarbageCollector) Start() { + for { + select { + case <-e.cancelCh: + return + case nodeID := <-e.deleteCh: + e.mu.Lock() + delete(e.toBeDeleted, nodeID) + e.mu.Unlock() - if len(expired) > 0 { - hsdb.notifier.NotifyAll(types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: expired, - }) + go e.deleteFunc(nodeID) } } } -func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string) *types.Node { + if !testing.Testing() { + panic("CreateNodeForTest can only be called during tests") + } - // use the time of the start of the function to ensure we - // dont miss some nodes by returning it _after_ we have - // checked everything. - started := time.Now() + if user == nil { + panic("CreateNodeForTest requires a valid user") + } - expired := make([]*tailcfg.PeerChange, 0) + nodeName := "testnode" + if len(hostname) > 0 && hostname[0] != "" { + nodeName = hostname[0] + } - nodes, err := hsdb.listNodes() + // Create a preauth key for the node + pak, err := hsdb.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) if err != nil { - log.Error(). - Err(err). - Msg("Error listing nodes to find expired nodes") - - return time.Unix(0, 0) - } - for index, node := range nodes { - if node.IsExpired() && - // TODO(kradalby): Replace this, it is very spammy - // It will notify about all nodes that has been expired. - // It should only notify about expired nodes since _last check_. - node.Expiry.After(lastCheck) { - expired = append(expired, &tailcfg.PeerChange{ - NodeID: tailcfg.NodeID(node.ID), - KeyExpiry: node.Expiry, - }) - - now := time.Now() - // Do not use setNodeExpiry as that has a notifier hook, which - // can cause a deadlock, we are updating all changed nodes later - // and there is no point in notifiying twice. - if err := hsdb.db.Model(nodes[index]).Updates(types.Node{ - Expiry: &now, - }).Error; err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Str("name", node.GivenName). - Msg("🤮 Cannot expire node") - } else { - log.Info(). - Str("node", node.Hostname). - Str("name", node.GivenName). 
- Msg("Node successfully expired") - } - } + panic(fmt.Sprintf("failed to create preauth key for test node: %v", err)) } - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: expired, - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) + nodeKey := key.NewNode() + machineKey := key.NewMachine() + discoKey := key.NewDisco() + + node := &types.Node{ + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + DiscoKey: discoKey.Public(), + Hostname: nodeName, + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), } - return started + err = hsdb.DB.Save(node).Error + if err != nil { + panic(fmt.Sprintf("failed to create test node: %v", err)) + } + + return node +} + +func (hsdb *HSDatabase) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node { + if !testing.Testing() { + panic("CreateRegisteredNodeForTest can only be called during tests") + } + + node := hsdb.CreateNodeForTest(user, hostname...) + + // Allocate IPs for the test node using the database's IP allocator + // This is a simplified allocation for testing - in production this would use State.ipAlloc + ipv4, ipv6, err := hsdb.allocateTestIPs(node.ID) + if err != nil { + panic(fmt.Sprintf("failed to allocate IPs for test node: %v", err)) + } + + var registeredNode *types.Node + err = hsdb.DB.Transaction(func(tx *gorm.DB) error { + var err error + registeredNode, err = RegisterNodeForTest(tx, *node, ipv4, ipv6) + return err + }) + if err != nil { + panic(fmt.Sprintf("failed to register test node: %v", err)) + } + + return registeredNode +} + +func (hsdb *HSDatabase) CreateNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node { + if !testing.Testing() { + panic("CreateNodesForTest can only be called during tests") + } + + if user == nil { + panic("CreateNodesForTest requires a valid user") + } + + prefix := "testnode" + if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { + prefix = hostnamePrefix[0] + } + + nodes := make([]*types.Node, count) + for i := range count { + hostname := prefix + "-" + strconv.Itoa(i) + nodes[i] = hsdb.CreateNodeForTest(user, hostname) + } + + return nodes +} + +func (hsdb *HSDatabase) CreateRegisteredNodesForTest(user *types.User, count int, hostnamePrefix ...string) []*types.Node { + if !testing.Testing() { + panic("CreateRegisteredNodesForTest can only be called during tests") + } + + if user == nil { + panic("CreateRegisteredNodesForTest requires a valid user") + } + + prefix := "testnode" + if len(hostnamePrefix) > 0 && hostnamePrefix[0] != "" { + prefix = hostnamePrefix[0] + } + + nodes := make([]*types.Node, count) + for i := range count { + hostname := prefix + "-" + strconv.Itoa(i) + nodes[i] = hsdb.CreateRegisteredNodeForTest(user, hostname) + } + + return nodes +} + +// allocateTestIPs allocates sequential test IPs for nodes during testing. 
+func (hsdb *HSDatabase) allocateTestIPs(nodeID types.NodeID) (*netip.Addr, *netip.Addr, error) { + if !testing.Testing() { + panic("allocateTestIPs can only be called during tests") + } + + // Use simple sequential allocation for tests + // IPv4: 100.64.x.y (where x = nodeID/256, y = nodeID%256) + // IPv6: fd7a:115c:a1e0::x:y (where x = high byte, y = low byte) + // This supports up to 65535 nodes + const ( + maxTestNodes = 65535 + ipv4ByteDivisor = 256 + ) + + if nodeID > maxTestNodes { + return nil, nil, ErrCouldNotAllocateIP + } + + // Split nodeID into high and low bytes for IPv4 (100.64.high.low) + highByte := byte(nodeID / ipv4ByteDivisor) + lowByte := byte(nodeID % ipv4ByteDivisor) + ipv4 := netip.AddrFrom4([4]byte{100, 64, highByte, lowByte}) + + // For IPv6, use the last two bytes of the address (fd7a:115c:a1e0::high:low) + ipv6 := netip.AddrFrom16([16]byte{0xfd, 0x7a, 0x11, 0x5c, 0xa1, 0xe0, 0, 0, 0, 0, 0, 0, 0, 0, highByte, lowByte}) + + return &ipv4, &ipv6, nil } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 140c264b..7e00f9ca 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -1,30 +1,109 @@ package db import ( + "crypto/rand" "fmt" + "math/big" "net/netip" "regexp" - "strconv" + "runtime" + "sync" + "sync/atomic" "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "gopkg.in/check.v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) -func (s *Suite) TestGetNode(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) +func TestGetNode(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) + user := db.CreateUserForTest("test") - _, err = db.GetNode("test", "testnode") - c.Assert(err, check.NotNil) + _, err = db.getNode(types.UserID(user.ID), "testnode") + require.Error(t, err) + + node := db.CreateNodeForTest(user, "testnode") + + _, err = db.getNode(types.UserID(user.ID), "testnode") + require.NoError(t, err) + assert.Equal(t, "testnode", node.Hostname) +} + +func TestGetNodeByID(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user := db.CreateUserForTest("test") + + _, err = db.GetNodeByID(0) + require.Error(t, err) + + node := db.CreateNodeForTest(user, "testnode") + + retrievedNode, err := db.GetNodeByID(node.ID) + require.NoError(t, err) + assert.Equal(t, "testnode", retrievedNode.Hostname) +} + +func TestHardDeleteNode(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user := db.CreateUserForTest("test") + node := db.CreateNodeForTest(user, "testnode3") + + err = db.DeleteNode(node) + require.NoError(t, err) + + _, err = db.getNode(types.UserID(user.ID), "testnode3") + require.Error(t, err) +} + +func TestListPeersManyNodes(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user := db.CreateUserForTest("test") + + _, err = db.GetNodeByID(0) + require.Error(t, err) + + nodes := db.CreateNodesForTest(user, 11, "testnode") + + firstNode := nodes[0] + peersOfFirstNode, err := db.ListPeers(firstNode.ID) + require.NoError(t, err) + + assert.Len(t, peersOfFirstNode, 10) + assert.Equal(t, "testnode-1", 
peersOfFirstNode[0].Hostname) + assert.Equal(t, "testnode-6", peersOfFirstNode[5].Hostname) + assert.Equal(t, "testnode-10", peersOfFirstNode[9].Hostname) +} + +func TestExpireNode(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) + + _, err = db.getNode(types.UserID(user.ID), "testnode") + require.Error(t, err) nodeKey := key.NewNode() machineKey := key.NewMachine() @@ -34,371 +113,41 @@ func (s *Suite) TestGetNode(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(node) - - _, err = db.GetNode("test", "testnode") - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetNodeByID(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(&node) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetNodeByNodeKey(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(&node) - - _, err = db.GetNodeByNodeKey(nodeKey.Public()) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetNodeByAnyNodeKey(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - oldNodeKey := key.NewNode() - - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(&node) - - _, err = db.GetNodeByAnyKey(machineKey.Public(), nodeKey.Public(), oldNodeKey.Public()) - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestHardDeleteNode(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode3", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(1), - } - db.db.Save(&node) - - err = db.DeleteNode(&node) - c.Assert(err, check.IsNil) - - _, err = db.GetNode(user.Name, "testnode3") - c.Assert(err, check.NotNil) -} - -func (s *Suite) 
TestListPeers(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - for index := 0; index <= 10; index++ { - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: uint64(index), - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode" + strconv.Itoa(index), - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(&node) - } - - node0ByID, err := db.GetNodeByID(0) - c.Assert(err, check.IsNil) - - peersOfNode0, err := db.ListPeers(node0ByID) - c.Assert(err, check.IsNil) - - c.Assert(len(peersOfNode0), check.Equals, 9) - c.Assert(peersOfNode0[0].Hostname, check.Equals, "testnode2") - c.Assert(peersOfNode0[5].Hostname, check.Equals, "testnode7") - c.Assert(peersOfNode0[8].Hostname, check.Equals, "testnode10") -} - -func (s *Suite) TestGetACLFilteredPeers(c *check.C) { - type base struct { - user *types.User - key *types.PreAuthKey - } - - stor := make([]base, 0) - - for _, name := range []string{"test", "admin"} { - user, err := db.CreateUser(name) - c.Assert(err, check.IsNil) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - stor = append(stor, base{user, pak}) - } - - _, err := db.GetNodeByID(0) - c.Assert(err, check.NotNil) - - for index := 0; index <= 10; index++ { - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := types.Node{ - ID: uint64(index), - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr(fmt.Sprintf("100.64.0.%v", strconv.Itoa(index+1))), - }, - Hostname: "testnode" + strconv.Itoa(index), - UserID: stor[index%2].user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(stor[index%2].key.ID), - } - db.db.Save(&node) - } - - aclPolicy := &policy.ACLPolicy{ - Groups: map[string][]string{ - "group:test": {"admin"}, - }, - Hosts: map[string]netip.Prefix{}, - TagOwners: map[string][]string{}, - ACLs: []policy.ACL{ - { - Action: "accept", - Sources: []string{"admin"}, - Destinations: []string{"*:*"}, - }, - { - Action: "accept", - Sources: []string{"test"}, - Destinations: []string{"test:*"}, - }, - }, - Tests: []policy.ACLTest{}, - } - - adminNode, err := db.GetNodeByID(1) - c.Logf("Node(%v), user: %v", adminNode.Hostname, adminNode.User) - c.Assert(err, check.IsNil) - - testNode, err := db.GetNodeByID(2) - c.Logf("Node(%v), user: %v", testNode.Hostname, testNode.User) - c.Assert(err, check.IsNil) - - adminPeers, err := db.ListPeers(adminNode) - c.Assert(err, check.IsNil) - - testPeers, err := db.ListPeers(testNode) - c.Assert(err, check.IsNil) - - adminRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, adminNode, adminPeers) - c.Assert(err, check.IsNil) - - testRules, _, err := policy.GenerateFilterAndSSHRules(aclPolicy, testNode, testPeers) - c.Assert(err, check.IsNil) - - peersOfAdminNode := policy.FilterNodesByACL(adminNode, adminPeers, adminRules) - peersOfTestNode := policy.FilterNodesByACL(testNode, testPeers, testRules) - - c.Log(peersOfTestNode) - c.Assert(len(peersOfTestNode), check.Equals, 9) - c.Assert(peersOfTestNode[0].Hostname, check.Equals, "testnode1") - c.Assert(peersOfTestNode[1].Hostname, check.Equals, "testnode3") - c.Assert(peersOfTestNode[3].Hostname, check.Equals, "testnode5") - - 
c.Log(peersOfAdminNode) - c.Assert(len(peersOfAdminNode), check.Equals, 9) - c.Assert(peersOfAdminNode[0].Hostname, check.Equals, "testnode2") - c.Assert(peersOfAdminNode[2].Hostname, check.Equals, "testnode4") - c.Assert(peersOfAdminNode[5].Hostname, check.Equals, "testnode7") -} - -func (s *Suite) TestExpireNode(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "testnode") - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - node := &types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: ptr.To(pak.ID), Expiry: &time.Time{}, } - db.db.Save(node) + db.DB.Save(node) - nodeFromDB, err := db.GetNode("test", "testnode") - c.Assert(err, check.IsNil) - c.Assert(nodeFromDB, check.NotNil) + nodeFromDB, err := db.getNode(types.UserID(user.ID), "testnode") + require.NoError(t, err) + require.NotNil(t, nodeFromDB) - c.Assert(nodeFromDB.IsExpired(), check.Equals, false) + assert.False(t, nodeFromDB.IsExpired()) now := time.Now() - err = db.NodeSetExpiry(nodeFromDB, now) - c.Assert(err, check.IsNil) + err = db.NodeSetExpiry(nodeFromDB.ID, now) + require.NoError(t, err) - c.Assert(nodeFromDB.IsExpired(), check.Equals, true) + nodeFromDB, err = db.getNode(types.UserID(user.ID), "testnode") + require.NoError(t, err) + + assert.True(t, nodeFromDB.IsExpired()) } -func (s *Suite) TestSerdeAddressStrignSlice(c *check.C) { - input := types.NodeAddresses([]netip.Addr{ - netip.MustParseAddr("192.0.2.1"), - netip.MustParseAddr("2001:db8::1"), - }) - serialized, err := input.Value() - c.Assert(err, check.IsNil) - if serial, ok := serialized.(string); ok { - c.Assert(serial, check.Equals, "192.0.2.1,2001:db8::1") - } +func TestSetTags(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) - var deserialized types.NodeAddresses - err = deserialized.Scan(serialized) - c.Assert(err, check.IsNil) + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) - c.Assert(len(deserialized), check.Equals, len(input)) - for i := range deserialized { - c.Assert(deserialized[i], check.Equals, input[i]) - } -} + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) -func (s *Suite) TestGenerateGivenName(c *check.C) { - user1, err := db.CreateUser("user-1") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user1.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("user-1", "testnode") - c.Assert(err, check.NotNil) - - nodeKey := key.NewNode() - machineKey := key.NewMachine() - - machineKey2 := key.NewMachine() - - node := &types.Node{ - ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), - Hostname: "hostname-1", - GivenName: "hostname-1", - UserID: user1.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(node) - - givenName, err := db.GenerateGivenName(machineKey2.Public(), "hostname-2") - comment := check.Commentf("Same user, unique nodes, unique hostnames, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-2", comment) - - givenName, err = db.GenerateGivenName(machineKey.Public(), "hostname-1") - comment = check.Commentf("Same user, 
same node, same hostname, no conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Equals, "hostname-1", comment) - - givenName, err = db.GenerateGivenName(machineKey2.Public(), "hostname-1") - comment = check.Commentf("Same user, unique nodes, same hostname, conflict") - c.Assert(err, check.IsNil, comment) - c.Assert(givenName, check.Matches, fmt.Sprintf("^hostname-1-[a-z0-9]{%d}$", NodeGivenNameHashLength), comment) -} - -func (s *Suite) TestSetTags(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "testnode") - c.Assert(err, check.NotNil) + _, err = db.getNode(types.UserID(user.ID), "testnode") + require.Error(t, err) nodeKey := key.NewNode() machineKey := key.NewMachine() @@ -408,31 +157,29 @@ func (s *Suite) TestSetTags(c *check.C) { MachineKey: machineKey.Public(), NodeKey: nodeKey.Public(), Hostname: "testnode", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + AuthKeyID: ptr.To(pak.ID), } - db.db.Save(node) + + trx := db.DB.Save(node) + require.NoError(t, trx.Error) // assign simple tags sTags := []string{"tag:test", "tag:foo"} - err = db.SetTags(node, sTags) - c.Assert(err, check.IsNil) - node, err = db.GetNode("test", "testnode") - c.Assert(err, check.IsNil) - c.Assert(node.ForcedTags, check.DeepEquals, types.StringList(sTags)) + err = db.SetTags(node.ID, sTags) + require.NoError(t, err) + node, err = db.getNode(types.UserID(user.ID), "testnode") + require.NoError(t, err) + assert.Equal(t, sTags, node.Tags) - // assign duplicat tags, expect no errors but no doubles in DB + // assign duplicate tags, expect no errors but no doubles in DB eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"} - err = db.SetTags(node, eTags) - c.Assert(err, check.IsNil) - node, err = db.GetNode("test", "testnode") - c.Assert(err, check.IsNil) - c.Assert( - node.ForcedTags, - check.DeepEquals, - types.StringList([]string{"tag:bar", "tag:test", "tag:unknown"}), - ) + err = db.SetTags(node.ID, eTags) + require.NoError(t, err) + node, err = db.getNode(types.UserID(user.ID), "testnode") + require.NoError(t, err) + assert.Equal(t, []string{"tag:bar", "tag:test", "tag:unknown"}, node.Tags) } func TestHeadscale_generateGivenName(t *testing.T) { @@ -455,6 +202,15 @@ func TestHeadscale_generateGivenName(t *testing.T) { want: regexp.MustCompile("^testnode$"), wantErr: false, }, + { + name: "UPPERCASE node name generation", + args: args{ + suppliedName: "TestNode", + randomSuffix: false, + }, + want: regexp.MustCompile("^testnode$"), + wantErr: false, + }, { name: "node name with 53 chars", args: args{ @@ -542,78 +298,833 @@ func TestHeadscale_generateGivenName(t *testing.T) { } } -func (s *Suite) TestAutoApproveRoutes(c *check.C) { - acl := []byte(` +func TestAutoApproveRoutes(t *testing.T) { + tests := []struct { + name string + acl string + routes []netip.Prefix + want []netip.Prefix + want2 []netip.Prefix + expectChange bool // whether to expect route changes + }{ + { + name: "no-auto-approvers-empty-policy", + acl: ` +{ + "groups": { + "group:admins": ["test@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["group:admins:*"] + } + ] +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + want: []netip.Prefix{}, // Should be empty - no auto-approvers + want2: []netip.Prefix{}, // Should be empty - no 
auto-approvers + expectChange: false, // No changes expected + }, + { + name: "no-auto-approvers-explicit-empty", + acl: ` +{ + "groups": { + "group:admins": ["test@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["group:admins:*"] + } + ], + "autoApprovers": { + "routes": {}, + "exitNode": [] + } +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + want: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers + want2: []netip.Prefix{}, // Should be empty - explicitly empty auto-approvers + expectChange: false, // No changes expected + }, + { + name: "2068-approve-issue-sub-kube", + acl: ` +{ + "groups": { + "group:k8s": ["test@"] + }, + +// "acls": [ +// {"action": "accept", "users": ["*"], "ports": ["*:*"]}, +// ], + + "autoApprovers": { + "routes": { + "10.42.0.0/16": ["test@"], + } + } +}`, + routes: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + want: []netip.Prefix{netip.MustParsePrefix("10.42.7.0/24")}, + expectChange: true, // Routes should be approved + }, + { + name: "2068-approve-issue-sub-exit-tag", + acl: ` { "tagOwners": { - "tag:exit": ["test"], + "tag:exit": ["test@"], }, "groups": { - "group:test": ["test"] + "group:test": ["test@"] }, - "acls": [ - {"action": "accept", "users": ["*"], "ports": ["*:*"]}, - ], +// "acls": [ +// {"action": "accept", "users": ["*"], "ports": ["*:*"]}, +// ], "autoApprovers": { "exitNode": ["tag:exit"], "routes": { "10.10.0.0/16": ["group:test"], - "10.11.0.0/16": ["test"], + "10.11.0.0/16": ["test@"], + "8.11.0.0/24": ["test2@"], // No nodes + } + } +}`, + routes: []netip.Prefix{ + tsaddr.AllIPv4(), + tsaddr.AllIPv6(), + netip.MustParsePrefix("10.10.0.0/16"), + netip.MustParsePrefix("10.11.0.0/24"), + + // Not approved + netip.MustParsePrefix("8.11.0.0/24"), + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.0.0/16"), + netip.MustParsePrefix("10.11.0.0/24"), + }, + want2: []netip.Prefix{ + tsaddr.AllIPv4(), + tsaddr.AllIPv6(), + }, + expectChange: true, // Routes should be approved + }, + } + + for _, tt := range tests { + pmfs := policy.PolicyManagerFuncsForTest([]byte(tt.acl)) + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + adb, err := newSQLiteTestDB() + require.NoError(t, err) + + user, err := adb.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + _, err = adb.CreateUser(types.User{Name: "test2"}) + require.NoError(t, err) + taggedUser, err := adb.CreateUser(types.User{Name: "tagged"}) + require.NoError(t, err) + + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.routes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + } + + err = adb.DB.Save(&node).Error + require.NoError(t, err) + + nodeTagged := types.Node{ + ID: 2, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "taggednode", + UserID: &taggedUser.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.routes, + }, + Tags: []string{"tag:exit"}, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + } + + err = adb.DB.Save(&nodeTagged).Error + require.NoError(t, err) + + users, err := adb.ListUsers() + assert.NoError(t, err) + + nodes, err := adb.ListNodes() + assert.NoError(t, err) + + pm, err := pmf(users, nodes.ViewSlice()) + require.NoError(t, 
err) + require.NotNil(t, pm) + + newRoutes1, changed1 := policy.ApproveRoutesWithPolicy(pm, node.View(), node.ApprovedRoutes, tt.routes) + assert.Equal(t, tt.expectChange, changed1) + + if changed1 { + err = SetApprovedRoutes(adb.DB, node.ID, newRoutes1) + require.NoError(t, err) + } + + newRoutes2, changed2 := policy.ApproveRoutesWithPolicy(pm, nodeTagged.View(), nodeTagged.ApprovedRoutes, tt.routes) + if changed2 { + err = SetApprovedRoutes(adb.DB, nodeTagged.ID, newRoutes2) + require.NoError(t, err) + } + + node1ByID, err := adb.GetNodeByID(1) + require.NoError(t, err) + + // For empty auto-approvers tests, handle nil vs empty slice comparison + expectedRoutes1 := tt.want + if len(expectedRoutes1) == 0 { + expectedRoutes1 = nil + } + if diff := cmp.Diff(expectedRoutes1, node1ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { + t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) + } + + node2ByID, err := adb.GetNodeByID(2) + require.NoError(t, err) + + expectedRoutes2 := tt.want2 + if len(expectedRoutes2) == 0 { + expectedRoutes2 = nil + } + if diff := cmp.Diff(expectedRoutes2, node2ByID.AllApprovedRoutes(), util.Comparers...); diff != "" { + t.Errorf("unexpected enabled routes (-want +got):\n%s", diff) + } + }) } } } - `) - pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(err, check.IsNil) - c.Assert(pol, check.NotNil) +func TestEphemeralGarbageCollectorOrder(t *testing.T) { + want := []types.NodeID{1, 3} + got := []types.NodeID{} + var mu sync.Mutex - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) + deletionCount := make(chan struct{}, 10) - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) + e := NewEphemeralGarbageCollector(func(ni types.NodeID) { + mu.Lock() + defer mu.Unlock() + got = append(got, ni) - nodeKey := key.NewNode() - machineKey := key.NewMachine() + deletionCount <- struct{}{} + }) + go e.Start() - defaultRouteV4 := netip.MustParsePrefix("0.0.0.0/0") - defaultRouteV6 := netip.MustParsePrefix("::/0") - route1 := netip.MustParsePrefix("10.10.0.0/16") - // Check if a subprefix of an autoapproved route is approved - route2 := netip.MustParsePrefix("10.11.0.0/24") + // Use shorter timeouts for faster tests + go e.Schedule(1, 50*time.Millisecond) + go e.Schedule(2, 100*time.Millisecond) + go e.Schedule(3, 150*time.Millisecond) + go e.Schedule(4, 200*time.Millisecond) + + // Wait for first deletion (node 1 at 50ms) + select { + case <-deletionCount: + case <-time.After(time.Second): + t.Fatal("timeout waiting for first deletion") + } + + // Cancel nodes 2 and 4 + go e.Cancel(2) + go e.Cancel(4) + + // Wait for node 3 to be deleted (at 150ms) + select { + case <-deletionCount: + case <-time.After(time.Second): + t.Fatal("timeout waiting for second deletion") + } + + // Give a bit more time for any unexpected deletions + select { + case <-deletionCount: + // Unexpected - more deletions than expected + case <-time.After(300 * time.Millisecond): + // Expected - no more deletions + } + + e.Close() + + mu.Lock() + defer mu.Unlock() + + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("wrong nodes deleted, unexpected result (-want +got):\n%s", diff) + } +} + +func TestEphemeralGarbageCollectorLoads(t *testing.T) { + var got []types.NodeID + var mu sync.Mutex + + want := 1000 + + var deletedCount int64 + + e := NewEphemeralGarbageCollector(func(ni types.NodeID) { + mu.Lock() + defer mu.Unlock() + + // Yield to other goroutines to introduce variability + runtime.Gosched() + got = 
append(got, ni) + + atomic.AddInt64(&deletedCount, 1) + }) + go e.Start() + + // Use shorter expiry for faster tests + for i := range want { + go e.Schedule(types.NodeID(i), 100*time.Millisecond) //nolint:gosec // test code, no overflow risk + } + + // Wait for all deletions to complete + assert.EventuallyWithT(t, func(c *assert.CollectT) { + count := atomic.LoadInt64(&deletedCount) + assert.Equal(c, int64(want), count, "all nodes should be deleted") + }, 10*time.Second, 50*time.Millisecond, "waiting for all deletions") + + e.Close() + + mu.Lock() + defer mu.Unlock() + + if len(got) != want { + t.Errorf("expected %d, got %d", want, len(got)) + } +} + +func generateRandomNumber(t *testing.T, max int64) int64 { + t.Helper() + maxB := big.NewInt(max) + n, err := rand.Int(rand.Reader, maxB) + if err != nil { + t.Fatalf("getting random number: %s", err) + } + + return n.Int64() + 1 +} + +func TestListEphemeralNodes(t *testing.T) { + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) + + pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil) + require.NoError(t, err) node := types.Node{ ID: 0, - MachineKey: machineKey.Public(), - NodeKey: nodeKey.Public(), + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), Hostname: "test", - UserID: user.ID, + UserID: &user.ID, RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:exit"}, - RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2}, - }, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, + AuthKeyID: ptr.To(pak.ID), } - db.db.Save(&node) + nodeEph := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "ephemeral", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pakEph.ID), + } - sendUpdate, err := db.SaveNodeRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(sendUpdate, check.Equals, false) + err = db.DB.Save(&node).Error + require.NoError(t, err) - node0ByID, err := db.GetNodeByID(0) - c.Assert(err, check.IsNil) + err = db.DB.Save(&nodeEph).Error + require.NoError(t, err) - err = db.EnableAutoApprovedRoutes(pol, node0ByID) - c.Assert(err, check.IsNil) + nodes, err := db.ListNodes() + require.NoError(t, err) - enabledRoutes, err := db.GetEnabledRoutes(node0ByID) - c.Assert(err, check.IsNil) - c.Assert(enabledRoutes, check.HasLen, 4) + ephemeralNodes, err := db.ListEphemeralNodes() + require.NoError(t, err) + + assert.Len(t, nodes, 2) + assert.Len(t, ephemeralNodes, 1) + + assert.Equal(t, nodeEph.ID, ephemeralNodes[0].ID) + assert.Equal(t, nodeEph.AuthKeyID, ephemeralNodes[0].AuthKeyID) + assert.Equal(t, nodeEph.UserID, ephemeralNodes[0].UserID) + assert.Equal(t, nodeEph.Hostname, ephemeralNodes[0].Hostname) +} + +func TestNodeNaming(t *testing.T) { + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + user2, err := db.CreateUser(types.User{Name: "user2"}) + require.NoError(t, err) + + node := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test", + UserID: &user.ID, + RegisterMethod: 
util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + node2 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test", + UserID: &user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + // Using non-ASCII characters in the hostname can + // break your network, so they should be replaced when registering + // a node. + // https://github.com/juanfont/headscale/issues/2343 + nodeInvalidHostname := types.Node{ + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "我的电脑", + UserID: &user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + } + + nodeShortHostname := types.Node{ + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "a", + UserID: &user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + } + + err = db.DB.Save(&node).Error + require.NoError(t, err) + + err = db.DB.Save(&node2).Error + require.NoError(t, err) + + err = db.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNodeForTest(tx, node, nil, nil) + if err != nil { + return err + } + _, err = RegisterNodeForTest(tx, node2, nil, nil) + if err != nil { + return err + } + _, err = RegisterNodeForTest(tx, nodeInvalidHostname, ptr.To(mpp("100.64.0.66/32").Addr()), nil) + _, err = RegisterNodeForTest(tx, nodeShortHostname, ptr.To(mpp("100.64.0.67/32").Addr()), nil) + return err + }) + require.NoError(t, err) + + nodes, err := db.ListNodes() + require.NoError(t, err) + + assert.Len(t, nodes, 4) + + t.Logf("node1 %s %s", nodes[0].Hostname, nodes[0].GivenName) + t.Logf("node2 %s %s", nodes[1].Hostname, nodes[1].GivenName) + t.Logf("node3 %s %s", nodes[2].Hostname, nodes[2].GivenName) + t.Logf("node4 %s %s", nodes[3].Hostname, nodes[3].GivenName) + + assert.Equal(t, nodes[0].Hostname, nodes[0].GivenName) + assert.NotEqual(t, nodes[1].Hostname, nodes[1].GivenName) + assert.Equal(t, nodes[0].Hostname, nodes[1].Hostname) + assert.NotEqual(t, nodes[0].Hostname, nodes[1].GivenName) + assert.Contains(t, nodes[1].GivenName, nodes[0].Hostname) + assert.Equal(t, nodes[0].GivenName, nodes[1].Hostname) + assert.Len(t, nodes[0].Hostname, 4) + assert.Len(t, nodes[1].Hostname, 4) + assert.Len(t, nodes[0].GivenName, 4) + assert.Len(t, nodes[1].GivenName, 13) + assert.Contains(t, nodes[2].Hostname, "invalid-") // invalid chars + assert.Contains(t, nodes[2].GivenName, "invalid-") + assert.Contains(t, nodes[3].Hostname, "invalid-") // too short + assert.Contains(t, nodes[3].GivenName, "invalid-") + + // Nodes can be renamed to a unique name + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[0].ID, "newname") + }) + require.NoError(t, err) + + nodes, err = db.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 4) + assert.Equal(t, "test", nodes[0].Hostname) + assert.Equal(t, "newname", nodes[0].GivenName) + + // Nodes can reuse name that is no longer used + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[1].ID, "test") + }) + require.NoError(t, err) + + nodes, err = db.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 4) + assert.Equal(t, "test", nodes[0].Hostname) + assert.Equal(t, "newname", nodes[0].GivenName) + assert.Equal(t, "test", nodes[1].GivenName) + + // Nodes cannot be renamed to used names + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[0].ID, "test") + }) + assert.ErrorContains(t, err, "name is not unique") + + // Rename invalid chars + err = db.Write(func(tx 
*gorm.DB) error { + return RenameNode(tx, nodes[2].ID, "我的电脑") + }) + assert.ErrorContains(t, err, "invalid characters") + + // Rename too short + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[3].ID, "a") + }) + assert.ErrorContains(t, err, "at least 2 characters") + + // Rename with emoji + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[0].ID, "hostname-with-💩") + }) + assert.ErrorContains(t, err, "invalid characters") + + // Rename with only emoji + err = db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[0].ID, "🚀") + }) + assert.ErrorContains(t, err, "invalid characters") +} + +func TestRenameNodeComprehensive(t *testing.T) { + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + node := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + err = db.DB.Save(&node).Error + require.NoError(t, err) + + err = db.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNodeForTest(tx, node, nil, nil) + return err + }) + require.NoError(t, err) + + nodes, err := db.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 1) + + tests := []struct { + name string + newName string + wantErr string + }{ + { + name: "uppercase_rejected", + newName: "User2-Host", + wantErr: "must be lowercase", + }, + { + name: "underscore_rejected", + newName: "test_node", + wantErr: "invalid characters", + }, + { + name: "at_sign_uppercase_rejected", + newName: "Test@Host", + wantErr: "must be lowercase", + }, + { + name: "at_sign_rejected", + newName: "test@host", + wantErr: "invalid characters", + }, + { + name: "chinese_chars_with_dash_rejected", + newName: "server-北京-01", + wantErr: "invalid characters", + }, + { + name: "chinese_only_rejected", + newName: "我的电脑", + wantErr: "invalid characters", + }, + { + name: "emoji_with_text_rejected", + newName: "laptop-🚀", + wantErr: "invalid characters", + }, + { + name: "mixed_chinese_emoji_rejected", + newName: "测试💻机器", + wantErr: "invalid characters", + }, + { + name: "only_emojis_rejected", + newName: "🎉🎊", + wantErr: "invalid characters", + }, + { + name: "only_at_signs_rejected", + newName: "@@@", + wantErr: "invalid characters", + }, + { + name: "starts_with_dash_rejected", + newName: "-test", + wantErr: "cannot start or end with a hyphen", + }, + { + name: "ends_with_dash_rejected", + newName: "test-", + wantErr: "cannot start or end with a hyphen", + }, + { + name: "too_long_hostname_rejected", + newName: "this-is-a-very-long-hostname-that-exceeds-sixty-three-characters-limit", + wantErr: "must not exceed 63 characters", + }, + { + name: "too_short_hostname_rejected", + newName: "a", + wantErr: "at least 2 characters", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := db.Write(func(tx *gorm.DB) error { + return RenameNode(tx, nodes[0].ID, tt.newName) + }) + assert.ErrorContains(t, err, tt.wantErr) + }) + } +} + +func TestListPeers(t *testing.T) { + // Setup test database + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + user2, err := db.CreateUser(types.User{Name: "user2"}) + require.NoError(t, err) + + node1 := types.Node{ + ID: 0, + 
MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test1", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + node2 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test2", + UserID: &user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + err = db.DB.Save(&node1).Error + require.NoError(t, err) + + err = db.DB.Save(&node2).Error + require.NoError(t, err) + + err = db.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNodeForTest(tx, node1, nil, nil) + if err != nil { + return err + } + _, err = RegisterNodeForTest(tx, node2, nil, nil) + + return err + }) + require.NoError(t, err) + + nodes, err := db.ListNodes() + require.NoError(t, err) + + assert.Len(t, nodes, 2) + + // No parameter means no filter, should return all peers + nodes, err = db.ListPeers(1) + require.NoError(t, err) + assert.Len(t, nodes, 1) + assert.Equal(t, "test2", nodes[0].Hostname) + + // Empty node list should return all peers + nodes, err = db.ListPeers(1, types.NodeIDs{}...) + require.NoError(t, err) + assert.Len(t, nodes, 1) + assert.Equal(t, "test2", nodes[0].Hostname) + + // No match in IDs should return empty list and no error + nodes, err = db.ListPeers(1, types.NodeIDs{3, 4, 5}...) + require.NoError(t, err) + assert.Empty(t, nodes) + + // Partial match in IDs + nodes, err = db.ListPeers(1, types.NodeIDs{2, 3}...) + require.NoError(t, err) + assert.Len(t, nodes, 1) + assert.Equal(t, "test2", nodes[0].Hostname) + + // Several matched IDs, but node ID is still filtered out + nodes, err = db.ListPeers(1, types.NodeIDs{1, 2, 3}...) + require.NoError(t, err) + assert.Len(t, nodes, 1) + assert.Equal(t, "test2", nodes[0].Hostname) +} + +func TestListNodes(t *testing.T) { + // Setup test database + db, err := newSQLiteTestDB() + if err != nil { + t.Fatalf("creating db: %s", err) + } + + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + user2, err := db.CreateUser(types.User{Name: "user2"}) + require.NoError(t, err) + + node1 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test1", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + node2 := types.Node{ + ID: 0, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test2", + UserID: &user2.ID, + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{}, + } + + err = db.DB.Save(&node1).Error + require.NoError(t, err) + + err = db.DB.Save(&node2).Error + require.NoError(t, err) + + err = db.DB.Transaction(func(tx *gorm.DB) error { + _, err := RegisterNodeForTest(tx, node1, nil, nil) + if err != nil { + return err + } + _, err = RegisterNodeForTest(tx, node2, nil, nil) + + return err + }) + require.NoError(t, err) + + nodes, err := db.ListNodes() + require.NoError(t, err) + + assert.Len(t, nodes, 2) + + // No parameter means no filter, should return all nodes + nodes, err = db.ListNodes() + require.NoError(t, err) + assert.Len(t, nodes, 2) + assert.Equal(t, "test1", nodes[0].Hostname) + assert.Equal(t, "test2", nodes[1].Hostname) + + // Empty node list should return all nodes + nodes, err = db.ListNodes(types.NodeIDs{}...) 
+ require.NoError(t, err) + assert.Len(t, nodes, 2) + assert.Equal(t, "test1", nodes[0].Hostname) + assert.Equal(t, "test2", nodes[1].Hostname) + + // No match in IDs should return empty list and no error + nodes, err = db.ListNodes(types.NodeIDs{3, 4, 5}...) + require.NoError(t, err) + assert.Empty(t, nodes) + + // Partial match in IDs + nodes, err = db.ListNodes(types.NodeIDs{2, 3}...) + require.NoError(t, err) + assert.Len(t, nodes, 1) + assert.Equal(t, "test2", nodes[0].Hostname) + + // Several matched IDs + nodes, err = db.ListNodes(types.NodeIDs{1, 2, 3}...) + require.NoError(t, err) + assert.Len(t, nodes, 2) + assert.Equal(t, "test1", nodes[0].Hostname) + assert.Equal(t, "test2", nodes[1].Hostname) } diff --git a/hscontrol/db/policy.go b/hscontrol/db/policy.go new file mode 100644 index 00000000..bdc8af41 --- /dev/null +++ b/hscontrol/db/policy.go @@ -0,0 +1,91 @@ +package db + +import ( + "errors" + "os" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "gorm.io/gorm" + "gorm.io/gorm/clause" +) + +// SetPolicy sets the policy in the database. +func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) { + // Create a new policy. + p := types.Policy{ + Data: policy, + } + + if err := hsdb.DB.Clauses(clause.Returning{}).Create(&p).Error; err != nil { + return nil, err + } + + return &p, nil +} + +// GetPolicy returns the latest policy in the database. +func (hsdb *HSDatabase) GetPolicy() (*types.Policy, error) { + return GetPolicy(hsdb.DB) +} + +// GetPolicy returns the latest policy from the database. +// This standalone function can be used in contexts where HSDatabase is not available, +// such as during migrations. +func GetPolicy(tx *gorm.DB) (*types.Policy, error) { + var p types.Policy + + // Query: + // SELECT * FROM policies ORDER BY id DESC LIMIT 1; + err := tx. + Order("id DESC"). + Limit(1). + First(&p).Error + if err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return nil, types.ErrPolicyNotFound + } + + return nil, err + } + + return &p, nil +} + +// PolicyBytes loads policy configuration from file or database based on the configured mode. +// Returns nil if no policy is configured, which is valid. +// This standalone function can be used in contexts where HSDatabase is not available, +// such as during migrations. +func PolicyBytes(tx *gorm.DB, cfg *types.Config) ([]byte, error) { + switch cfg.Policy.Mode { + case types.PolicyModeFile: + path := cfg.Policy.Path + + // It is fine to start headscale without a policy file. 
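+	// An empty path therefore yields a nil policy blob and no error.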
+ if len(path) == 0 { + return nil, nil + } + + absPath := util.AbsolutePathFromConfigPath(path) + + return os.ReadFile(absPath) + + case types.PolicyModeDB: + p, err := GetPolicy(tx) + if err != nil { + if errors.Is(err, types.ErrPolicyNotFound) { + return nil, nil + } + + return nil, err + } + + if p.Data == "" { + return nil, nil + } + + return []byte(p.Data), nil + } + + return nil, nil +} diff --git a/hscontrol/db/preauth_keys.go b/hscontrol/db/preauth_keys.go index e743988f..c5904353 100644 --- a/hscontrol/db/preauth_keys.go +++ b/hscontrol/db/preauth_keys.go @@ -1,42 +1,84 @@ package db import ( - "crypto/rand" - "encoding/hex" "errors" "fmt" + "slices" "strings" "time" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "golang.org/x/crypto/bcrypt" "gorm.io/gorm" + "tailscale.com/util/set" ) var ( - ErrPreAuthKeyNotFound = errors.New("AuthKey not found") - ErrPreAuthKeyExpired = errors.New("AuthKey expired") - ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used") + ErrPreAuthKeyNotFound = errors.New("auth-key not found") + ErrPreAuthKeyExpired = errors.New("auth-key expired") + ErrSingleUseAuthKeyHasBeenUsed = errors.New("auth-key has already been used") ErrUserMismatch = errors.New("user mismatch") - ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid") + ErrPreAuthKeyACLTagInvalid = errors.New("auth-key tag is invalid") ) -// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. func (hsdb *HSDatabase) CreatePreAuthKey( - userName string, + uid *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string, -) (*types.PreAuthKey, error) { - // TODO(kradalby): figure out this lock - // hsdb.mu.Lock() - // defer hsdb.mu.Unlock() +) (*types.PreAuthKeyNew, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.PreAuthKeyNew, error) { + return CreatePreAuthKey(tx, uid, reusable, ephemeral, expiration, aclTags) + }) +} - user, err := hsdb.GetUser(userName) - if err != nil { - return nil, err +const ( + authKeyPrefix = "hskey-auth-" + authKeyPrefixLength = 12 + authKeyLength = 64 +) + +// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it. +// The uid parameter can be nil for system-created tagged keys. +// For tagged keys, uid tracks "created by" (who created the key). +// For user-owned keys, uid tracks the node owner. +func CreatePreAuthKey( + tx *gorm.DB, + uid *types.UserID, + reusable bool, + ephemeral bool, + expiration *time.Time, + aclTags []string, +) (*types.PreAuthKeyNew, error) { + // Validate: must be tagged OR user-owned, not neither + if uid == nil && len(aclTags) == 0 { + return nil, ErrPreAuthKeyNotTaggedOrOwned } + var ( + user *types.User + userID *uint + ) + + if uid != nil { + var err error + + user, err = GetUserByID(tx, *uid) + if err != nil { + return nil, err + } + + userID = &user.ID + } + + // Remove duplicates and sort for consistency + aclTags = set.SetOf(aclTags).Slice() + slices.Sort(aclTags) + + // TODO(kradalby): factor out and create a reusable tag validation, + // check if there is one in Tailscale's lib. 
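+	// Every tag must carry the canonical "tag:" prefix; anything else is rejected below.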
for _, tag := range aclTags { if !strings.HasPrefix(tag, "tag:") { return nil, fmt.Errorf( @@ -48,182 +90,250 @@ func (hsdb *HSDatabase) CreatePreAuthKey( } now := time.Now().UTC() - kstr, err := hsdb.generateKey() + + prefix, err := util.GenerateRandomStringURLSafe(authKeyPrefixLength) + if err != nil { + return nil, err + } + + // Validate generated prefix (should always be valid, but be defensive) + if len(prefix) != authKeyPrefixLength { + return nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyPrefixLength, len(prefix)) + } + + if !isValidBase64URLSafe(prefix) { + return nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrPreAuthKeyFailedToParse) + } + + toBeHashed, err := util.GenerateRandomStringURLSafe(authKeyLength) + if err != nil { + return nil, err + } + + // Validate generated hash (should always be valid, but be defensive) + if len(toBeHashed) != authKeyLength { + return nil, fmt.Errorf("%w: generated hash has invalid length: expected %d, got %d", ErrPreAuthKeyFailedToParse, authKeyLength, len(toBeHashed)) + } + + if !isValidBase64URLSafe(toBeHashed) { + return nil, fmt.Errorf("%w: generated hash contains invalid characters", ErrPreAuthKeyFailedToParse) + } + + keyStr := authKeyPrefix + prefix + "-" + toBeHashed + + hash, err := bcrypt.GenerateFromPassword([]byte(toBeHashed), bcrypt.DefaultCost) if err != nil { return nil, err } key := types.PreAuthKey{ - Key: kstr, - UserID: user.ID, - User: *user, + UserID: userID, // nil for system-created keys, or "created by" for tagged keys + User: user, // nil for system-created keys Reusable: reusable, Ephemeral: ephemeral, CreatedAt: &now, Expiration: expiration, + Tags: aclTags, // empty for user-owned keys + Prefix: prefix, // Store prefix + Hash: hash, // Store hash } - err = hsdb.db.Transaction(func(db *gorm.DB) error { - if err := db.Save(&key).Error; err != nil { - return fmt.Errorf("failed to create key in the database: %w", err) - } + if err := tx.Save(&key).Error; err != nil { + return nil, fmt.Errorf("failed to create key in the database: %w", err) + } - if len(aclTags) > 0 { - seenTags := map[string]bool{} + return &types.PreAuthKeyNew{ + ID: key.ID, + Key: keyStr, + Reusable: key.Reusable, + Ephemeral: key.Ephemeral, + Tags: key.Tags, + Expiration: key.Expiration, + CreatedAt: key.CreatedAt, + User: key.User, + }, nil +} - for _, tag := range aclTags { - if !seenTags[tag] { - if err := db.Save(&types.PreAuthKeyACLTag{PreAuthKeyID: key.ID, Tag: tag}).Error; err != nil { - return fmt.Errorf( - "failed to ceate key tag in the database: %w", - err, - ) - } - seenTags[tag] = true - } - } - } - - return nil +func (hsdb *HSDatabase) ListPreAuthKeys() ([]types.PreAuthKey, error) { + return Read(hsdb.DB, func(rx *gorm.DB) ([]types.PreAuthKey, error) { + return ListPreAuthKeys(rx) }) - - if err != nil { - return nil, err - } - - return &key, nil } -// ListPreAuthKeys returns the list of PreAuthKeys for a user. -func (hsdb *HSDatabase) ListPreAuthKeys(userName string) ([]types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() +// ListPreAuthKeys returns all PreAuthKeys in the database. 
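+// Unlike the earlier per-user ListPreAuthKeys(userName), it is not scoped to a single user.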
+func ListPreAuthKeys(tx *gorm.DB) ([]types.PreAuthKey, error) { + var keys []types.PreAuthKey - return hsdb.listPreAuthKeys(userName) -} - -func (hsdb *HSDatabase) listPreAuthKeys(userName string) ([]types.PreAuthKey, error) { - user, err := hsdb.getUser(userName) + err := tx.Preload("User").Find(&keys).Error if err != nil { return nil, err } - keys := []types.PreAuthKey{} - if err := hsdb.db.Preload("User").Preload("ACLTags").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil { - return nil, err - } - return keys, nil } -// GetPreAuthKey returns a PreAuthKey for a given key. -func (hsdb *HSDatabase) GetPreAuthKey(user string, key string) (*types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() +var ( + ErrPreAuthKeyFailedToParse = errors.New("failed to parse auth-key") + ErrPreAuthKeyNotTaggedOrOwned = errors.New("auth-key must be either tagged or owned by user") +) - pak, err := hsdb.ValidatePreAuthKey(key) +func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) { + var pak types.PreAuthKey + + // Validate input is not empty + if keyStr == "" { + return nil, ErrPreAuthKeyFailedToParse + } + + _, prefixAndHash, found := strings.Cut(keyStr, authKeyPrefix) + + if !found { + // Legacy format (plaintext) - backwards compatibility + err := tx.Preload("User").First(&pak, "key = ?", keyStr).Error + if err != nil { + return nil, ErrPreAuthKeyNotFound + } + + return &pak, nil + } + + // New format: hskey-auth-{12-char-prefix}-{64-char-hash} + // Expected minimum length: 12 (prefix) + 1 (separator) + 64 (hash) = 77 + const expectedMinLength = authKeyPrefixLength + 1 + authKeyLength + if len(prefixAndHash) < expectedMinLength { + return nil, fmt.Errorf( + "%w: key too short, expected at least %d chars after prefix, got %d", + ErrPreAuthKeyFailedToParse, + expectedMinLength, + len(prefixAndHash), + ) + } + + // Use fixed-length parsing instead of separator-based to handle dashes in base64 URL-safe + prefix := prefixAndHash[:authKeyPrefixLength] + + // Validate separator at expected position + if prefixAndHash[authKeyPrefixLength] != '-' { + return nil, fmt.Errorf( + "%w: expected separator '-' at position %d, got '%c'", + ErrPreAuthKeyFailedToParse, + authKeyPrefixLength, + prefixAndHash[authKeyPrefixLength], + ) + } + + hash := prefixAndHash[authKeyPrefixLength+1:] + + // Validate hash length + if len(hash) != authKeyLength { + return nil, fmt.Errorf( + "%w: hash length mismatch, expected %d chars, got %d", + ErrPreAuthKeyFailedToParse, + authKeyLength, + len(hash), + ) + } + + // Validate prefix contains only base64 URL-safe characters + if !isValidBase64URLSafe(prefix) { + return nil, fmt.Errorf( + "%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", + ErrPreAuthKeyFailedToParse, + ) + } + + // Validate hash contains only base64 URL-safe characters + if !isValidBase64URLSafe(hash) { + return nil, fmt.Errorf( + "%w: hash contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)", + ErrPreAuthKeyFailedToParse, + ) + } + + // Look up key by prefix + err := tx.Preload("User").First(&pak, "prefix = ?", prefix).Error if err != nil { - return nil, err + return nil, ErrPreAuthKeyNotFound } - if pak.User.Name != user { - return nil, ErrUserMismatch + // Verify hash matches + err = bcrypt.CompareHashAndPassword(pak.Hash, []byte(hash)) + if err != nil { + return nil, fmt.Errorf("invalid auth key: %w", err) } - return pak, nil + return &pak, nil +} + +// isValidBase64URLSafe checks if a string contains only base64 
URL-safe characters. +func isValidBase64URLSafe(s string) bool { + for _, c := range s { + if (c < 'A' || c > 'Z') && (c < 'a' || c > 'z') && (c < '0' || c > '9') && c != '-' && c != '_' { + return false + } + } + + return true +} + +func (hsdb *HSDatabase) GetPreAuthKey(key string) (*types.PreAuthKey, error) { + return GetPreAuthKey(hsdb.DB, key) +} + +// GetPreAuthKey returns a PreAuthKey for a given key. The caller is responsible +// for checking if the key is usable (expired or used). +func GetPreAuthKey(tx *gorm.DB, key string) (*types.PreAuthKey, error) { + return findAuthKey(tx, key) } // DestroyPreAuthKey destroys a preauthkey. Returns error if the PreAuthKey -// does not exist. -func (hsdb *HSDatabase) DestroyPreAuthKey(pak types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.destroyPreAuthKey(pak) -} - -func (hsdb *HSDatabase) destroyPreAuthKey(pak types.PreAuthKey) error { - return hsdb.db.Transaction(func(db *gorm.DB) error { - if result := db.Unscoped().Where(types.PreAuthKeyACLTag{PreAuthKeyID: pak.ID}).Delete(&types.PreAuthKeyACLTag{}); result.Error != nil { - return result.Error +// does not exist. This also clears the auth_key_id on any nodes that reference +// this key. +func DestroyPreAuthKey(tx *gorm.DB, id uint64) error { + return tx.Transaction(func(db *gorm.DB) error { + // First, clear the foreign key reference on any nodes using this key + err := db.Model(&types.Node{}). + Where("auth_key_id = ?", id). + Update("auth_key_id", nil).Error + if err != nil { + return fmt.Errorf("failed to clear auth_key_id on nodes: %w", err) } - if result := db.Unscoped().Delete(pak); result.Error != nil { - return result.Error + // Then delete the pre-auth key + err = tx.Unscoped().Delete(&types.PreAuthKey{}, id).Error + if err != nil { + return err } return nil }) } -// MarkExpirePreAuthKey marks a PreAuthKey as expired. -func (hsdb *HSDatabase) ExpirePreAuthKey(k *types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() +func (hsdb *HSDatabase) ExpirePreAuthKey(id uint64) error { + return hsdb.Write(func(tx *gorm.DB) error { + return ExpirePreAuthKey(tx, id) + }) +} - if err := hsdb.db.Model(&k).Update("Expiration", time.Now()).Error; err != nil { - return err - } - - return nil +func (hsdb *HSDatabase) DeletePreAuthKey(id uint64) error { + return hsdb.Write(func(tx *gorm.DB) error { + return DestroyPreAuthKey(tx, id) + }) } // UsePreAuthKey marks a PreAuthKey as used. -func (hsdb *HSDatabase) UsePreAuthKey(k *types.PreAuthKey) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - k.Used = true - if err := hsdb.db.Save(k).Error; err != nil { +func UsePreAuthKey(tx *gorm.DB, k *types.PreAuthKey) error { + err := tx.Model(k).Update("used", true).Error + if err != nil { return fmt.Errorf("failed to update key used status in the database: %w", err) } + k.Used = true return nil } -// ValidatePreAuthKey does the heavy lifting for validation of the PreAuthKey coming from a node -// If returns no error and a PreAuthKey, it can be used. 
-func (hsdb *HSDatabase) ValidatePreAuthKey(k string) (*types.PreAuthKey, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - pak := types.PreAuthKey{} - if result := hsdb.db.Preload("User").Preload("ACLTags").First(&pak, "key = ?", k); errors.Is( - result.Error, - gorm.ErrRecordNotFound, - ) { - return nil, ErrPreAuthKeyNotFound - } - - if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { - return nil, ErrPreAuthKeyExpired - } - - if pak.Reusable || pak.Ephemeral { // we don't need to check if has been used before - return &pak, nil - } - - nodes := types.Nodes{} - if err := hsdb.db. - Preload("AuthKey"). - Where(&types.Node{AuthKeyID: uint(pak.ID)}). - Find(&nodes).Error; err != nil { - return nil, err - } - - if len(nodes) != 0 || pak.Used { - return nil, ErrSingleUseAuthKeyHasBeenUsed - } - - return &pak, nil -} - -func (hsdb *HSDatabase) generateKey() (string, error) { - size := 24 - bytes := make([]byte, size) - if _, err := rand.Read(bytes); err != nil { - return "", err - } - - return hex.EncodeToString(bytes), nil +// MarkExpirePreAuthKey marks a PreAuthKey as expired. +func ExpirePreAuthKey(tx *gorm.DB, id uint64) error { + now := time.Now() + return tx.Model(&types.PreAuthKey{}).Where("id = ?", id).Update("expiration", now).Error } diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index df9c2a10..7c5dcbd7 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -1,202 +1,447 @@ package db import ( + "fmt" + "slices" + "strings" + "testing" "time" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "gopkg.in/check.v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/ptr" ) -func (*Suite) TestCreatePreAuthKey(c *check.C) { - _, err := db.CreatePreAuthKey("bogus", true, false, nil, nil) +func TestCreatePreAuthKey(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "error_invalid_user_id", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - c.Assert(err, check.NotNil) + _, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil) + assert.Error(t, err) + }, + }, + { + name: "success_create_and_list", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) - key, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) - c.Assert(err, check.IsNil) + key, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + require.NoError(t, err) + assert.NotEmpty(t, key.Key) - // Did we get a valid key? 
- c.Assert(key.Key, check.NotNil) - c.Assert(len(key.Key), check.Equals, 48) + // List keys for the user + keys, err := db.ListPreAuthKeys() + require.NoError(t, err) + assert.Len(t, keys, 1) - // Make sure the User association is populated - c.Assert(key.User.Name, check.Equals, user.Name) + // Verify User association is populated + assert.Equal(t, user.ID, keys[0].User.ID) + }, + }, + } - _, err = db.ListPreAuthKeys("bogus") - c.Assert(err, check.NotNil) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) - keys, err := db.ListPreAuthKeys(user.Name) - c.Assert(err, check.IsNil) - c.Assert(len(keys), check.Equals, 1) - - // Make sure the User association is populated - c.Assert((keys)[0].User.Name, check.Equals, user.Name) + tt.test(t, db) + }) + } } -func (*Suite) TestExpiredPreAuthKey(c *check.C) { - user, err := db.CreateUser("test2") - c.Assert(err, check.IsNil) +func TestPreAuthKeyACLTags(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "reject_malformed_tags", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + user, err := db.CreateUser(types.User{Name: "test-tags-1"}) + require.NoError(t, err) + + _, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"badtag"}) + assert.Error(t, err) + }, + }, + { + name: "deduplicate_and_sort_tags", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + user, err := db.CreateUser(types.User{Name: "test-tags-2"}) + require.NoError(t, err) + + expectedTags := []string{"tag:test1", "tag:test2"} + tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} + + _, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate) + require.NoError(t, err) + + listedPaks, err := db.ListPreAuthKeys() + require.NoError(t, err) + require.Len(t, listedPaks, 1) + + gotTags := listedPaks[0].Proto().GetAclTags() + slices.Sort(gotTags) + assert.Equal(t, expectedTags, gotTags) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + tt.test(t, db) + }) + } +} + +func TestCannotDeleteAssignedPreAuthKey(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + user, err := db.CreateUser(types.User{Name: "test8"}) + require.NoError(t, err) + + key, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:good"}) + require.NoError(t, err) + + node := types.Node{ + ID: 0, + Hostname: "testest", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(key.ID), + } + db.DB.Save(&node) + + err = db.DB.Delete(&types.PreAuthKey{ID: key.ID}).Error + require.ErrorContains(t, err, "constraint failed: FOREIGN KEY constraint failed") +} + +func TestPreAuthKeyAuthentication(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user := db.CreateUserForTest("test-user") + + tests := []struct { + name string + setupKey func() string // Returns key string to test + wantFindErr bool // Error when finding the key + wantValidateErr bool // Error when validating the key + validateResult func(*testing.T, *types.PreAuthKey) + }{ + { + name: "legacy_key_plaintext", + setupKey: func() string { + // Insert legacy key directly using GORM (simulate existing production key) + // Note: We use raw SQL to bypass GORM's handling and set prefix to empty string + // which simulates how legacy keys exist in production databases + legacyKey := 
"abc123def456ghi789jkl012mno345pqr678stu901vwx234yz" + now := time.Now() + + // Use raw SQL to insert with empty prefix to avoid UNIQUE constraint + err := db.DB.Exec(` + INSERT INTO pre_auth_keys (key, user_id, reusable, ephemeral, used, created_at) + VALUES (?, ?, ?, ?, ?, ?) + `, legacyKey, user.ID, true, false, false, now).Error + require.NoError(t, err) + + return legacyKey + }, + wantFindErr: false, + wantValidateErr: false, + validateResult: func(t *testing.T, pak *types.PreAuthKey) { + t.Helper() + + assert.Equal(t, user.ID, *pak.UserID) + assert.NotEmpty(t, pak.Key) // Legacy keys have Key populated + assert.Empty(t, pak.Prefix) // Legacy keys have empty Prefix + assert.Nil(t, pak.Hash) // Legacy keys have nil Hash + }, + }, + { + name: "new_key_bcrypt", + setupKey: func() string { + // Create new key via API + keyStr, err := db.CreatePreAuthKey( + user.TypedID(), + true, false, nil, []string{"tag:test"}, + ) + require.NoError(t, err) + + return keyStr.Key + }, + wantFindErr: false, + wantValidateErr: false, + validateResult: func(t *testing.T, pak *types.PreAuthKey) { + t.Helper() + + assert.Equal(t, user.ID, *pak.UserID) + assert.Empty(t, pak.Key) // New keys have empty Key + assert.NotEmpty(t, pak.Prefix) // New keys have Prefix + assert.NotNil(t, pak.Hash) // New keys have Hash + assert.Len(t, pak.Prefix, 12) // Prefix is 12 chars + }, + }, + { + name: "new_key_format_validation", + setupKey: func() string { + keyStr, err := db.CreatePreAuthKey( + user.TypedID(), + true, false, nil, nil, + ) + require.NoError(t, err) + + // Verify format: hskey-auth-{12-char-prefix}-{64-char-hash} + // Use fixed-length parsing since prefix/hash can contain dashes (base64 URL-safe) + assert.True(t, strings.HasPrefix(keyStr.Key, "hskey-auth-")) + + // Extract prefix and hash using fixed-length parsing like the real code does + _, prefixAndHash, found := strings.Cut(keyStr.Key, "hskey-auth-") + assert.True(t, found) + assert.GreaterOrEqual(t, len(prefixAndHash), 12+1+64) // prefix + '-' + hash minimum + + prefix := prefixAndHash[:12] + assert.Len(t, prefix, 12) // Prefix is 12 chars + assert.Equal(t, byte('-'), prefixAndHash[12]) // Separator + hash := prefixAndHash[13:] + assert.Len(t, hash, 64) // Hash is 64 chars + + return keyStr.Key + }, + wantFindErr: false, + wantValidateErr: false, + }, + { + name: "invalid_bcrypt_hash", + setupKey: func() string { + // Create valid key + key, err := db.CreatePreAuthKey( + user.TypedID(), + true, false, nil, nil, + ) + require.NoError(t, err) + + keyStr := key.Key + + // Return key with tampered hash using fixed-length parsing + _, prefixAndHash, _ := strings.Cut(keyStr, "hskey-auth-") + prefix := prefixAndHash[:12] + + return "hskey-auth-" + prefix + "-" + "wrong_hash_here_xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "empty_key", + setupKey: func() string { + return "" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "key_too_short", + setupKey: func() string { + return "hskey-auth-short" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "missing_separator", + setupKey: func() string { + return "hskey-auth-ABCDEFGHIJKLabcdefghijklmnopqrstuvwxyz1234567890ABCDEFGHIJKLMNOPQRSTUVWXYZ" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "hash_too_short", + setupKey: func() string { + return "hskey-auth-ABCDEFGHIJKL-short" + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "prefix_with_invalid_chars", + 
setupKey: func() string { + return "hskey-auth-ABC$EF@HIJKL-" + strings.Repeat("a", 64) + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "hash_with_invalid_chars", + setupKey: func() string { + return "hskey-auth-ABCDEFGHIJKL-" + "invalid$chars" + strings.Repeat("a", 54) + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "prefix_not_found_in_db", + setupKey: func() string { + // Create a validly formatted key but with a prefix that doesn't exist + return "hskey-auth-NotInDB12345-" + strings.Repeat("a", 64) + }, + wantFindErr: true, + wantValidateErr: false, + }, + { + name: "expired_legacy_key", + setupKey: func() string { + legacyKey := "expired_legacy_key_123456789012345678901234" + now := time.Now() + expiration := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago + + // Use raw SQL to avoid UNIQUE constraint on empty prefix + err := db.DB.Exec(` + INSERT INTO pre_auth_keys (key, user_id, reusable, ephemeral, used, created_at, expiration) + VALUES (?, ?, ?, ?, ?, ?, ?) + `, legacyKey, user.ID, true, false, false, now, expiration).Error + require.NoError(t, err) + + return legacyKey + }, + wantFindErr: false, + wantValidateErr: true, + }, + { + name: "used_single_use_legacy_key", + setupKey: func() string { + legacyKey := "used_legacy_key_123456789012345678901234567" + now := time.Now() + + // Use raw SQL to avoid UNIQUE constraint on empty prefix + err := db.DB.Exec(` + INSERT INTO pre_auth_keys (key, user_id, reusable, ephemeral, used, created_at) + VALUES (?, ?, ?, ?, ?, ?) + `, legacyKey, user.ID, false, false, true, now).Error + require.NoError(t, err) + + return legacyKey + }, + wantFindErr: false, + wantValidateErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + keyStr := tt.setupKey() + + pak, err := db.GetPreAuthKey(keyStr) + + if tt.wantFindErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + require.NotNil(t, pak) + + // Check validation if needed + if tt.wantValidateErr { + err := pak.Validate() + assert.Error(t, err) + + return + } + + if tt.validateResult != nil { + tt.validateResult(t, pak) + } + }) + } +} + +func TestMultipleLegacyKeysAllowed(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user, err := db.CreateUser(types.User{Name: "test-legacy"}) + require.NoError(t, err) + + // Create multiple legacy keys by directly inserting with empty prefix + // This simulates the migration scenario where existing databases have multiple + // plaintext keys without prefix/hash fields now := time.Now() - pak, err := db.CreatePreAuthKey(user.Name, true, false, &now, nil) - c.Assert(err, check.IsNil) - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrPreAuthKeyExpired) - c.Assert(key, check.IsNil) -} + for i := range 5 { + legacyKey := fmt.Sprintf("legacy_key_%d_%s", i, strings.Repeat("x", 40)) -func (*Suite) TestPreAuthKeyDoesNotExist(c *check.C) { - key, err := db.ValidatePreAuthKey("potatoKey") - c.Assert(err, check.Equals, ErrPreAuthKeyNotFound) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestValidateKeyOk(c *check.C) { - user, err := db.CreateUser("test3") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) - c.Assert(err, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - c.Assert(key.ID, check.Equals, pak.ID) -} - -func (*Suite) TestAlreadyUsedKey(c *check.C) { - user, err := db.CreateUser("test4") - c.Assert(err, check.IsNil) - - pak, 
err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - node := types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + err := db.DB.Exec(` + INSERT INTO pre_auth_keys (key, prefix, hash, user_id, reusable, ephemeral, used, created_at) + VALUES (?, '', NULL, ?, ?, ?, ?, ?) + `, legacyKey, user.ID, true, false, false, now).Error + require.NoError(t, err, "should allow multiple legacy keys with empty prefix") } - db.db.Save(&node) - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestReusableBeingUsedKey(c *check.C) { - user, err := db.CreateUser("test5") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) - c.Assert(err, check.IsNil) - - node := types.Node{ - ID: 1, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - } - db.db.Save(&node) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - c.Assert(key.ID, check.Equals, pak.ID) -} - -func (*Suite) TestNotReusableNotBeingUsedKey(c *check.C) { - user, err := db.CreateUser("test6") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.IsNil) - c.Assert(key.ID, check.Equals, pak.ID) -} - -func (*Suite) TestEphemeralKey(c *check.C) { - user, err := db.CreateUser("test7") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, true, nil, nil) - c.Assert(err, check.IsNil) - - now := time.Now().Add(-time.Second * 30) - node := types.Node{ - ID: 0, - Hostname: "testest", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - LastSeen: &now, - AuthKeyID: uint(pak.ID), - } - db.db.Save(&node) - - _, err = db.ValidatePreAuthKey(pak.Key) - // Ephemeral keys are by definition reusable - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test7", "testest") - c.Assert(err, check.IsNil) - - db.ExpireEphemeralNodes(time.Second * 20) - - // The machine record should have been deleted - _, err = db.GetNode("test7", "testest") - c.Assert(err, check.NotNil) -} - -func (*Suite) TestExpirePreauthKey(c *check.C) { - user, err := db.CreateUser("test3") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, true, false, nil, nil) - c.Assert(err, check.IsNil) - c.Assert(pak.Expiration, check.IsNil) - - err = db.ExpirePreAuthKey(pak) - c.Assert(err, check.IsNil) - c.Assert(pak.Expiration, check.NotNil) - - key, err := db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrPreAuthKeyExpired) - c.Assert(key, check.IsNil) -} - -func (*Suite) TestNotReusableMarkedAsUsed(c *check.C) { - user, err := db.CreateUser("test6") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - pak.Used = true - db.db.Save(&pak) - - _, err = db.ValidatePreAuthKey(pak.Key) - c.Assert(err, check.Equals, ErrSingleUseAuthKeyHasBeenUsed) -} - -func (*Suite) TestPreAuthKeyACLTags(c *check.C) { - user, err := db.CreateUser("test8") - c.Assert(err, check.IsNil) - - _, err = db.CreatePreAuthKey(user.Name, false, false, nil, []string{"badtag"}) - c.Assert(err, check.NotNil) // Confirm that malformed tags are rejected - - tags := []string{"tag:test1", 
"tag:test2"} - tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"} - _, err = db.CreatePreAuthKey(user.Name, false, false, nil, tagsWithDuplicate) - c.Assert(err, check.IsNil) - - listedPaks, err := db.ListPreAuthKeys("test8") - c.Assert(err, check.IsNil) - c.Assert(listedPaks[0].Proto().GetAclTags(), check.DeepEquals, tags) + // Verify all legacy keys can be retrieved + var legacyKeys []types.PreAuthKey + + err = db.DB.Where("prefix = '' OR prefix IS NULL").Find(&legacyKeys).Error + require.NoError(t, err) + assert.Len(t, legacyKeys, 5, "should have created 5 legacy keys") + + // Now create new bcrypt-based keys - these should have unique prefixes + key1, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + require.NoError(t, err) + assert.NotEmpty(t, key1.Key) + + key2, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) + require.NoError(t, err) + assert.NotEmpty(t, key2.Key) + + // Verify the new keys have different prefixes + pak1, err := db.GetPreAuthKey(key1.Key) + require.NoError(t, err) + assert.NotEmpty(t, pak1.Prefix) + + pak2, err := db.GetPreAuthKey(key2.Key) + require.NoError(t, err) + assert.NotEmpty(t, pak2.Prefix) + + assert.NotEqual(t, pak1.Prefix, pak2.Prefix, "new keys should have unique prefixes") + + // Verify we cannot manually insert duplicate non-empty prefixes + duplicatePrefix := "test_prefix1" + hash1 := []byte("hash1") + hash2 := []byte("hash2") + + // First insert should succeed + err = db.DB.Exec(` + INSERT INTO pre_auth_keys (key, prefix, hash, user_id, reusable, ephemeral, used, created_at) + VALUES ('', ?, ?, ?, ?, ?, ?, ?) + `, duplicatePrefix, hash1, user.ID, true, false, false, now).Error + require.NoError(t, err, "first key with prefix should succeed") + + // Second insert with same prefix should fail + err = db.DB.Exec(` + INSERT INTO pre_auth_keys (key, prefix, hash, user_id, reusable, ephemeral, used, created_at) + VALUES ('', ?, ?, ?, ?, ?, ?, ?) + `, duplicatePrefix, hash2, user.ID, true, false, false, now).Error + require.Error(t, err, "duplicate non-empty prefix should be rejected") + assert.Contains(t, err.Error(), "UNIQUE constraint failed", "should fail with UNIQUE constraint error") } diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go deleted file mode 100644 index 51c7f3bc..00000000 --- a/hscontrol/db/routes.go +++ /dev/null @@ -1,714 +0,0 @@ -package db - -import ( - "errors" - "net/netip" - - "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/rs/zerolog/log" - "github.com/samber/lo" - "gorm.io/gorm" - "tailscale.com/types/key" -) - -var ErrRouteIsNotAvailable = errors.New("route is not available") - -func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getRoutes() -} - -func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { - var routes types.Routes - err := hsdb.db. - Preload("Node"). - Preload("Node.User"). - Find(&routes).Error - if err != nil { - return nil, err - } - - return routes, nil -} - -func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) { - var routes types.Routes - err := hsdb.db. - Preload("Node"). - Preload("Node.User"). - Where("advertised = ? AND enabled = ?", true, true). - Find(&routes).Error - if err != nil { - return nil, err - } - - return routes, nil -} - -func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) { - var routes types.Routes - err := hsdb.db. - Preload("Node"). 
- Preload("Node.User"). - Where("prefix = ?", types.IPPrefix(pref)). - Find(&routes).Error - if err != nil { - return nil, err - } - - return routes, nil -} - -func (hsdb *HSDatabase) GetNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeAdvertisedRoutes(node) -} - -func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, error) { - var routes types.Routes - err := hsdb.db. - Preload("Node"). - Preload("Node.User"). - Where("node_id = ? AND advertised = true", node.ID). - Find(&routes).Error - if err != nil { - return nil, err - } - - return routes, nil -} - -func (hsdb *HSDatabase) GetNodeRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getNodeRoutes(node) -} - -func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) { - var routes types.Routes - err := hsdb.db. - Preload("Node"). - Preload("Node.User"). - Where("node_id = ?", node.ID). - Find(&routes).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, err - } - - return routes, nil -} - -func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getRoute(id) -} - -func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { - var route types.Route - err := hsdb.db. - Preload("Node"). - Preload("Node.User"). - First(&route, id).Error - if err != nil { - return nil, err - } - - return &route, nil -} - -func (hsdb *HSDatabase) EnableRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.enableRoute(id) -} - -func (hsdb *HSDatabase) enableRoute(id uint64) error { - route, err := hsdb.getRoute(id) - if err != nil { - return err - } - - // Tailscale requires both IPv4 and IPv6 exit routes to - // be enabled at the same time, as per - // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 - if route.IsExitRoute() { - return hsdb.enableRoutes( - &route.Node, - types.ExitRouteV4.String(), - types.ExitRouteV6.String(), - ) - } - - return hsdb.enableRoutes(&route.Node, netip.Prefix(route.Prefix).String()) -} - -func (hsdb *HSDatabase) DisableRoute(id uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - route, err := hsdb.getRoute(id) - if err != nil { - return err - } - - var routes types.Routes - node := route.Node - - // Tailscale requires both IPv4 and IPv6 exit routes to - // be enabled at the same time, as per - // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 - if !route.IsExitRoute() { - err = hsdb.failoverRouteWithNotify(route) - if err != nil { - return err - } - - route.Enabled = false - route.IsPrimary = false - err = hsdb.db.Save(route).Error - if err != nil { - return err - } - } else { - routes, err = hsdb.getNodeRoutes(&node) - if err != nil { - return err - } - - for i := range routes { - if routes[i].IsExitRoute() { - routes[i].Enabled = false - routes[i].IsPrimary = false - err = hsdb.db.Save(&routes[i]).Error - if err != nil { - return err - } - } - } - } - - if routes == nil { - routes, err = hsdb.getNodeRoutes(&node) - if err != nil { - return err - } - } - - node.Routes = routes - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{&node}, - Message: "called from db.DisableRoute", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - - return nil -} - -func (hsdb *HSDatabase) DeleteRoute(id 
uint64) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - route, err := hsdb.getRoute(id) - if err != nil { - return err - } - - var routes types.Routes - node := route.Node - - // Tailscale requires both IPv4 and IPv6 exit routes to - // be enabled at the same time, as per - // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 - if !route.IsExitRoute() { - err := hsdb.failoverRouteWithNotify(route) - if err != nil { - return nil - } - - if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { - return err - } - } else { - routes, err := hsdb.getNodeRoutes(&node) - if err != nil { - return err - } - - routesToDelete := types.Routes{} - for _, r := range routes { - if r.IsExitRoute() { - routesToDelete = append(routesToDelete, r) - } - } - - if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { - return err - } - } - - if routes == nil { - routes, err = hsdb.getNodeRoutes(&node) - if err != nil { - return err - } - } - - node.Routes = routes - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{&node}, - Message: "called from db.DeleteRoute", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - - return nil -} - -func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error { - routes, err := hsdb.getNodeRoutes(node) - if err != nil { - return err - } - - for i := range routes { - if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil { - return err - } - - // TODO(kradalby): This is a bit too aggressive, we could probably - // figure out which routes needs to be failed over rather than all. - hsdb.failoverRouteWithNotify(&routes[i]) - } - - return nil -} - -// isUniquePrefix returns if there is another node providing the same route already. -func (hsdb *HSDatabase) isUniquePrefix(route types.Route) bool { - var count int64 - hsdb.db. - Model(&types.Route{}). - Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?", - route.Prefix, - route.NodeID, - true, true).Count(&count) - - return count == 0 -} - -func (hsdb *HSDatabase) getPrimaryRoute(prefix netip.Prefix) (*types.Route, error) { - var route types.Route - err := hsdb.db. - Preload("Node"). - Where("prefix = ? AND advertised = ? AND enabled = ? AND is_primary = ?", types.IPPrefix(prefix), true, true, true). - First(&route).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - return nil, err - } - - if errors.Is(err, gorm.ErrRecordNotFound) { - return nil, gorm.ErrRecordNotFound - } - - return &route, nil -} - -// getNodePrimaryRoutes returns the routes that are enabled and marked as primary (for subnet failover) -// Exit nodes are not considered for this, as they are never marked as Primary. -func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - var routes types.Routes - err := hsdb.db. - Preload("Node"). - Where("node_id = ? AND advertised = ? AND enabled = ? AND is_primary = ?", node.ID, true, true, true). - Find(&routes).Error - if err != nil { - return nil, err - } - - return routes, nil -} - -// SaveNodeRoutes takes a node and updates the database with -// the new routes. -// It returns a bool wheter an update should be sent as the -// saved route impacts nodes. 
-func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.saveNodeRoutes(node) -} - -func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { - sendUpdate := false - - currentRoutes := types.Routes{} - err := hsdb.db.Where("node_id = ?", node.ID).Find(¤tRoutes).Error - if err != nil { - return sendUpdate, err - } - - advertisedRoutes := map[netip.Prefix]bool{} - for _, prefix := range node.Hostinfo.RoutableIPs { - advertisedRoutes[prefix] = false - } - - log.Trace(). - Str("node", node.Hostname). - Interface("advertisedRoutes", advertisedRoutes). - Interface("currentRoutes", currentRoutes). - Msg("updating routes") - - for pos, route := range currentRoutes { - if _, ok := advertisedRoutes[netip.Prefix(route.Prefix)]; ok { - if !route.Advertised { - currentRoutes[pos].Advertised = true - err := hsdb.db.Save(¤tRoutes[pos]).Error - if err != nil { - return sendUpdate, err - } - - // If a route that is newly "saved" is already - // enabled, set sendUpdate to true as it is now - // available. - if route.Enabled { - sendUpdate = true - } - } - advertisedRoutes[netip.Prefix(route.Prefix)] = true - } else if route.Advertised { - currentRoutes[pos].Advertised = false - currentRoutes[pos].Enabled = false - err := hsdb.db.Save(¤tRoutes[pos]).Error - if err != nil { - return sendUpdate, err - } - } - } - - for prefix, exists := range advertisedRoutes { - if !exists { - route := types.Route{ - NodeID: node.ID, - Prefix: types.IPPrefix(prefix), - Advertised: true, - Enabled: false, - } - err := hsdb.db.Create(&route).Error - if err != nil { - return sendUpdate, err - } - } - } - - return sendUpdate, nil -} - -// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route -// currently have a functioning host that exposes the network. -func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error { - nodeRoutes, err := hsdb.getNodeRoutes(node) - if err != nil { - return nil - } - - for _, nodeRoute := range nodeRoutes { - routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix)) - if err != nil { - return err - } - - for _, route := range routes { - if route.IsPrimary { - // if we have a primary route, and the node is connected - // nothing needs to be done. - if hsdb.notifier.IsConnected(route.Node.MachineKey) { - continue - } - - // if not, we need to failover the route - err := hsdb.failoverRouteWithNotify(&route) - if err != nil { - return err - } - } - } - } - - return nil -} - -func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error { - routes, err := hsdb.getNodeRoutes(node) - if err != nil { - return nil - } - - var changedKeys []key.MachinePublic - - for _, route := range routes { - changed, err := hsdb.failoverRoute(&route) - if err != nil { - return err - } - - changedKeys = append(changedKeys, changed...) 
- } - - changedKeys = lo.Uniq(changedKeys) - - var nodes types.Nodes - - for _, key := range changedKeys { - node, err := hsdb.GetNodeByMachineKey(key) - if err != nil { - return err - } - - nodes = append(nodes, node) - } - - if nodes != nil { - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: nodes, - Message: "called from db.FailoverNodeRoutesWithNotify", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - } - - return nil -} - -func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { - changedKeys, err := hsdb.failoverRoute(r) - if err != nil { - return err - } - - if len(changedKeys) == 0 { - return nil - } - - var nodes types.Nodes - - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("loading machines with new primary routes from db") - - for _, key := range changedKeys { - node, err := hsdb.getNodeByMachineKey(key) - if err != nil { - return err - } - - nodes = append(nodes, node) - } - - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("notifying peers about primary route change") - - if nodes != nil { - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: nodes, - Message: "called from db.failoverRouteWithNotify", - } - if stateUpdate.Valid() { - hsdb.notifier.NotifyAll(stateUpdate) - } - } - - log.Trace(). - Str("hostname", r.Node.Hostname). - Msg("notified peers about primary route change") - - return nil -} - -// failoverRoute takes a route that is no longer available, -// this can be either from: -// - being disabled -// - being deleted -// - host going offline -// -// and tries to find a new route to take over its place. -// If the given route was not primary, it returns early. -func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) { - if r == nil { - return nil, nil - } - - // This route is not a primary route, and it isnt - // being served to nodes. - if !r.IsPrimary { - return nil, nil - } - - // We do not have to failover exit nodes - if r.IsExitRoute() { - return nil, nil - } - - routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix)) - if err != nil { - return nil, err - } - - var newPrimary *types.Route - - // Find a new suitable route - for idx, route := range routes { - if r.ID == route.ID { - continue - } - - if hsdb.notifier.IsConnected(route.Node.MachineKey) { - newPrimary = &routes[idx] - break - } - } - - // If a new route was not found/available, - // return with an error. - // We do not want to update the database as - // the one currently marked as primary is the - // best we got. - if newPrimary == nil { - return nil, nil - } - - log.Trace(). - Str("hostname", newPrimary.Node.Hostname). - Msg("found new primary, updating db") - - // Remove primary from the old route - r.IsPrimary = false - err = hsdb.db.Save(&r).Error - if err != nil { - log.Error().Err(err).Msg("error disabling new primary route") - - return nil, err - } - - log.Trace(). - Str("hostname", newPrimary.Node.Hostname). - Msg("removed primary from old route") - - // Set primary for the new primary - newPrimary.IsPrimary = true - err = hsdb.db.Save(&newPrimary).Error - if err != nil { - log.Error().Err(err).Msg("error enabling new primary route") - - return nil, err - } - - log.Trace(). - Str("hostname", newPrimary.Node.Hostname). - Msg("set primary to new route") - - // Return a list of the machinekeys of the changed nodes. 
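
In short, the failover step above boils down to: keep the current primary unless another node advertising the same prefix is currently connected, in which case swap the primary flag and report both affected machines. A logic-only sketch of that selection, using simplified stand-in types (no database or notifier involved; names here are illustrative, not the types used in the code above):

```go
package main

import "fmt"

// Route is a simplified stand-in for types.Route.
type Route struct {
	ID        uint64
	Node      string // hostname of the advertising node
	Prefix    string
	IsPrimary bool
}

// pickNewPrimary mirrors the selection step of failoverRoute: skip the failing
// route itself and promote the first candidate whose node is still connected.
// It returns nil if there is no failover to perform or no candidate is online.
func pickNewPrimary(failing Route, candidates []Route, isConnected func(node string) bool) *Route {
	if !failing.IsPrimary {
		return nil // only primary routes need a failover
	}
	for i := range candidates {
		if candidates[i].ID == failing.ID {
			continue
		}
		if isConnected(candidates[i].Node) {
			return &candidates[i]
		}
	}
	return nil // keep the current primary; it is the best we have
}

func main() {
	online := map[string]bool{"router-b": true}
	routes := []Route{
		{ID: 1, Node: "router-a", Prefix: "10.0.0.0/24", IsPrimary: true},
		{ID: 2, Node: "router-b", Prefix: "10.0.0.0/24"},
	}
	if np := pickNewPrimary(routes[0], routes, func(n string) bool { return online[n] }); np != nil {
		fmt.Printf("failing over %s -> %s\n", routes[0].Node, np.Node)
	}
}
```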
- return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil -} - -// EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. -func (hsdb *HSDatabase) EnableAutoApprovedRoutes( - aclPolicy *policy.ACLPolicy, - node *types.Node, -) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - if len(node.IPAddresses) == 0 { - return nil // This node has no IPAddresses, so can't possibly match any autoApprovers ACLs - } - - routes, err := hsdb.getNodeAdvertisedRoutes(node) - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - log.Error(). - Caller(). - Err(err). - Str("node", node.Hostname). - Msg("Could not get advertised routes for node") - - return err - } - - approvedRoutes := types.Routes{} - - for _, advertisedRoute := range routes { - if advertisedRoute.Enabled { - continue - } - - routeApprovers, err := aclPolicy.AutoApprovers.GetRouteApprovers( - netip.Prefix(advertisedRoute.Prefix), - ) - if err != nil { - log.Err(err). - Str("advertisedRoute", advertisedRoute.String()). - Uint64("nodeId", node.ID). - Msg("Failed to resolve autoApprovers for advertised route") - - return err - } - - for _, approvedAlias := range routeApprovers { - if approvedAlias == node.User.Name { - approvedRoutes = append(approvedRoutes, advertisedRoute) - } else { - // TODO(kradalby): figure out how to get this to depend on less stuff - approvedIps, err := aclPolicy.ExpandAlias(types.Nodes{node}, approvedAlias) - if err != nil { - log.Err(err). - Str("alias", approvedAlias). - Msg("Failed to expand alias when processing autoApprovers policy") - - return err - } - - // approvedIPs should contain all of node's IPs if it matches the rule, so check for first - if approvedIps.Contains(node.IPAddresses[0]) { - approvedRoutes = append(approvedRoutes, advertisedRoute) - } - } - } - } - - for _, approvedRoute := range approvedRoutes { - err := hsdb.enableRoute(uint64(approvedRoute.ID)) - if err != nil { - log.Err(err). - Str("approvedRoute", approvedRoute.String()). - Uint64("nodeId", node.ID). 
- Msg("Failed to enable approved route") - - return err - } - } - - return nil -} diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go deleted file mode 100644 index d491b6a3..00000000 --- a/hscontrol/db/routes_test.go +++ /dev/null @@ -1,629 +0,0 @@ -package db - -import ( - "net/netip" - "os" - "testing" - "time" - - "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/notifier" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/stretchr/testify/assert" - "gopkg.in/check.v1" - "gorm.io/gorm" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -func (s *Suite) TestGetRoutes(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "test_get_route_node") - c.Assert(err, check.NotNil) - - route, err := netip.ParsePrefix("10.0.0.0/24") - c.Assert(err, check.IsNil) - - hostInfo := tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{route}, - } - - node := types.Node{ - ID: 0, - Hostname: "test_get_route_node", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - Hostinfo: &hostInfo, - } - db.db.Save(&node) - - su, err := db.SaveNodeRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(su, check.Equals, false) - - advertisedRoutes, err := db.GetAdvertisedRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(len(advertisedRoutes), check.Equals, 1) - - err = db.enableRoutes(&node, "192.168.0.0/24") - c.Assert(err, check.NotNil) - - err = db.enableRoutes(&node, "10.0.0.0/24") - c.Assert(err, check.IsNil) -} - -func (s *Suite) TestGetEnableRoutes(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "test_enable_route_node") - c.Assert(err, check.NotNil) - - route, err := netip.ParsePrefix( - "10.0.0.0/24", - ) - c.Assert(err, check.IsNil) - - route2, err := netip.ParsePrefix( - "150.0.10.0/25", - ) - c.Assert(err, check.IsNil) - - hostInfo := tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{route, route2}, - } - - node := types.Node{ - ID: 0, - Hostname: "test_enable_route_node", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - Hostinfo: &hostInfo, - } - db.db.Save(&node) - - sendUpdate, err := db.SaveNodeRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(sendUpdate, check.Equals, false) - - availableRoutes, err := db.GetAdvertisedRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(err, check.IsNil) - c.Assert(len(availableRoutes), check.Equals, 2) - - noEnabledRoutes, err := db.GetEnabledRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(len(noEnabledRoutes), check.Equals, 0) - - err = db.enableRoutes(&node, "192.168.0.0/24") - c.Assert(err, check.NotNil) - - err = db.enableRoutes(&node, "10.0.0.0/24") - c.Assert(err, check.IsNil) - - enabledRoutes, err := db.GetEnabledRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes), check.Equals, 1) - - // Adding it twice will just let it pass through - err = db.enableRoutes(&node, "10.0.0.0/24") - c.Assert(err, check.IsNil) - - enableRoutesAfterDoubleApply, err := db.GetEnabledRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(len(enableRoutesAfterDoubleApply), check.Equals, 1) - - err = db.enableRoutes(&node, "150.0.10.0/25") - 
c.Assert(err, check.IsNil) - - enabledRoutesWithAdditionalRoute, err := db.GetEnabledRoutes(&node) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutesWithAdditionalRoute), check.Equals, 2) -} - -func (s *Suite) TestIsUniquePrefix(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "test_enable_route_node") - c.Assert(err, check.NotNil) - - route, err := netip.ParsePrefix( - "10.0.0.0/24", - ) - c.Assert(err, check.IsNil) - - route2, err := netip.ParsePrefix( - "150.0.10.0/25", - ) - c.Assert(err, check.IsNil) - - hostInfo1 := tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{route, route2}, - } - node1 := types.Node{ - ID: 1, - Hostname: "test_enable_route_node", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - Hostinfo: &hostInfo1, - } - db.db.Save(&node1) - - sendUpdate, err := db.SaveNodeRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(sendUpdate, check.Equals, false) - - err = db.enableRoutes(&node1, route.String()) - c.Assert(err, check.IsNil) - - err = db.enableRoutes(&node1, route2.String()) - c.Assert(err, check.IsNil) - - hostInfo2 := tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{route2}, - } - node2 := types.Node{ - ID: 2, - Hostname: "test_enable_route_node", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - Hostinfo: &hostInfo2, - } - db.db.Save(&node2) - - sendUpdate, err = db.SaveNodeRoutes(&node2) - c.Assert(err, check.IsNil) - c.Assert(sendUpdate, check.Equals, false) - - err = db.enableRoutes(&node2, route2.String()) - c.Assert(err, check.IsNil) - - enabledRoutes1, err := db.GetEnabledRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes1), check.Equals, 2) - - enabledRoutes2, err := db.GetEnabledRoutes(&node2) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes2), check.Equals, 1) - - routes, err := db.GetNodePrimaryRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 2) - - routes, err = db.GetNodePrimaryRoutes(&node2) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 0) -} - -func (s *Suite) TestDeleteRoutes(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "test_enable_route_node") - c.Assert(err, check.NotNil) - - prefix, err := netip.ParsePrefix( - "10.0.0.0/24", - ) - c.Assert(err, check.IsNil) - - prefix2, err := netip.ParsePrefix( - "150.0.10.0/25", - ) - c.Assert(err, check.IsNil) - - hostInfo1 := tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{prefix, prefix2}, - } - - now := time.Now() - node1 := types.Node{ - ID: 1, - Hostname: "test_enable_route_node", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - Hostinfo: &hostInfo1, - LastSeen: &now, - } - db.db.Save(&node1) - - sendUpdate, err := db.SaveNodeRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(sendUpdate, check.Equals, false) - - err = db.enableRoutes(&node1, prefix.String()) - c.Assert(err, check.IsNil) - - err = db.enableRoutes(&node1, prefix2.String()) - c.Assert(err, check.IsNil) - - routes, err := db.GetNodeRoutes(&node1) - c.Assert(err, check.IsNil) - - err = db.DeleteRoute(uint64(routes[0].ID)) - c.Assert(err, check.IsNil) - - enabledRoutes1, err := 
db.GetEnabledRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes1), check.Equals, 1) -} - -func TestFailoverRoute(t *testing.T) { - ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } - - // TODO(kradalby): Count/verify updates - var sink chan types.StateUpdate - - go func() { - for range sink { - } - }() - - machineKeys := []key.MachinePublic{ - key.NewMachine().Public(), - key.NewMachine().Public(), - key.NewMachine().Public(), - key.NewMachine().Public(), - } - - tests := []struct { - name string - failingRoute types.Route - routes types.Routes - want []key.MachinePublic - wantErr bool - }{ - { - name: "no-route", - failingRoute: types.Route{}, - routes: types.Routes{}, - want: nil, - wantErr: false, - }, - { - name: "no-prime", - failingRoute: types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: false, - }, - routes: types.Routes{}, - want: nil, - wantErr: false, - }, - { - name: "exit-node", - failingRoute: types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("0.0.0.0/0"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - routes: types.Routes{}, - want: nil, - wantErr: false, - }, - { - name: "no-failover-single-route", - failingRoute: types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - routes: types.Routes{ - types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - }, - want: nil, - wantErr: false, - }, - { - name: "failover-primary", - failingRoute: types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - routes: types.Routes{ - types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - types.Route{ - Model: gorm.Model{ - ID: 2, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[1], - }, - IsPrimary: false, - }, - }, - want: []key.MachinePublic{ - machineKeys[0], - machineKeys[1], - }, - wantErr: false, - }, - { - name: "failover-none-primary", - failingRoute: types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: false, - }, - routes: types.Routes{ - types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - types.Route{ - Model: gorm.Model{ - ID: 2, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[1], - }, - IsPrimary: false, - }, - }, - want: nil, - wantErr: false, - }, - { - name: "failover-primary-multi-route", - failingRoute: types.Route{ - Model: gorm.Model{ - ID: 2, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[1], - }, - IsPrimary: true, - }, - routes: types.Routes{ - types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: false, - }, - types.Route{ - Model: gorm.Model{ - ID: 2, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[1], - }, - IsPrimary: true, - }, - types.Route{ - 
Model: gorm.Model{ - ID: 3, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[2], - }, - IsPrimary: false, - }, - }, - want: []key.MachinePublic{ - machineKeys[1], - machineKeys[0], - }, - wantErr: false, - }, - { - name: "failover-primary-no-online", - failingRoute: types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - routes: types.Routes{ - types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - // Offline - types.Route{ - Model: gorm.Model{ - ID: 2, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[3], - }, - IsPrimary: false, - }, - }, - want: nil, - wantErr: false, - }, - { - name: "failover-primary-one-not-online", - failingRoute: types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - routes: types.Routes{ - types.Route{ - Model: gorm.Model{ - ID: 1, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[0], - }, - IsPrimary: true, - }, - // Offline - types.Route{ - Model: gorm.Model{ - ID: 2, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[3], - }, - IsPrimary: false, - }, - types.Route{ - Model: gorm.Model{ - ID: 3, - }, - Prefix: ipp("10.0.0.0/24"), - Node: types.Node{ - MachineKey: machineKeys[1], - }, - IsPrimary: true, - }, - }, - want: []key.MachinePublic{ - machineKeys[0], - machineKeys[1], - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - tmpDir, err := os.MkdirTemp("", "failover-db-test") - assert.NoError(t, err) - - notif := notifier.NewNotifier() - - db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, - notif, - []netip.Prefix{ - netip.MustParsePrefix("10.27.0.0/23"), - }, - "", - ) - assert.NoError(t, err) - - // Pretend that all the nodes are connected to control - for idx, key := range machineKeys { - // Pretend one node is offline - if idx == 3 { - continue - } - - notif.AddNode(key, sink) - } - - for _, route := range tt.routes { - if err := db.db.Save(&route).Error; err != nil { - t.Fatalf("failed to create route: %s", err) - } - } - - got, err := db.failoverRoute(&tt.failingRoute) - - if (err != nil) != tt.wantErr { - t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) - - return - } - - if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { - t.Errorf("failoverRoute() unexpected result (-want +got):\n%s", diff) - } - }) - } -} diff --git a/hscontrol/db/schema.sql b/hscontrol/db/schema.sql new file mode 100644 index 00000000..ef0a2a0e --- /dev/null +++ b/hscontrol/db/schema.sql @@ -0,0 +1,106 @@ +-- This file is the representation of the SQLite schema of Headscale. +-- It is the "source of truth" and is used to validate any migrations +-- that are run against the database to ensure it ends in the expected state. 
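
The header comment above describes schema.sql as the "source of truth" used to validate migrations. How that validation is wired up is not shown in this hunk; purely as an illustration, one way to compare a migrated SQLite database against the file is to read the DDL back out of `sqlite_master` and diff it by hand (paths below are placeholders):

```go
package main

import (
	"database/sql"
	"fmt"
	"log"
	"os"

	_ "modernc.org/sqlite" // driver used elsewhere in this change
)

func main() {
	// The expected schema, as checked into the repository.
	want, err := os.ReadFile("hscontrol/db/schema.sql")
	if err != nil {
		log.Fatal(err)
	}

	// A database that has had all migrations applied.
	db, err := sql.Open("sqlite", "file:headscale_test.db")
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()

	// sqlite_master holds the DDL of every table and index in the live database.
	rows, err := db.Query("SELECT sql FROM sqlite_master WHERE sql IS NOT NULL ORDER BY name")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	fmt.Println("--- schema.sql (expected) ---")
	fmt.Println(string(want))
	fmt.Println("--- live database ---")
	for rows.Next() {
		var ddl string
		if err := rows.Scan(&ddl); err != nil {
			log.Fatal(err)
		}
		fmt.Println(ddl + ";")
	}
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
}
```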
+ +CREATE TABLE migrations(id text,PRIMARY KEY(id)); + +CREATE TABLE users( + id integer PRIMARY KEY AUTOINCREMENT, + name text, + display_name text, + email text, + provider_identifier text, + provider text, + profile_pic_url text, + + created_at datetime, + updated_at datetime, + deleted_at datetime +); +CREATE INDEX idx_users_deleted_at ON users(deleted_at); + + +-- The following three UNIQUE indexes work together to enforce the user identity model: +-- +-- 1. Users can be either local (provider_identifier is NULL) or from external providers (provider_identifier set) +-- 2. Each external provider identifier must be unique across the system +-- 3. Local usernames must be unique among local users +-- 4. The same username can exist across different providers with different identifiers +-- +-- Examples: +-- - Can create local user "alice" (provider_identifier=NULL) +-- - Can create external user "alice" with GitHub (name="alice", provider_identifier="alice_github") +-- - Can create external user "alice" with Google (name="alice", provider_identifier="alice_google") +-- - Cannot create another local user "alice" (blocked by idx_name_no_provider_identifier) +-- - Cannot create another user with provider_identifier="alice_github" (blocked by idx_provider_identifier) +-- - Cannot create user "bob" with provider_identifier="alice_github" (blocked by idx_name_provider_identifier) +CREATE UNIQUE INDEX idx_provider_identifier ON users(provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users(name, provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users(name) WHERE provider_identifier IS NULL; + +CREATE TABLE pre_auth_keys( + id integer PRIMARY KEY AUTOINCREMENT, + key text, + prefix text, + hash blob, + user_id integer, + reusable numeric, + ephemeral numeric DEFAULT false, + used numeric DEFAULT false, + tags text, + expiration datetime, + + created_at datetime, + + CONSTRAINT fk_pre_auth_keys_user FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE SET NULL +); +CREATE UNIQUE INDEX idx_pre_auth_keys_prefix ON pre_auth_keys(prefix) WHERE prefix IS NOT NULL AND prefix != ''; + +CREATE TABLE api_keys( + id integer PRIMARY KEY AUTOINCREMENT, + prefix text, + hash blob, + expiration datetime, + last_seen datetime, + + created_at datetime +); +CREATE UNIQUE INDEX idx_api_keys_prefix ON api_keys(prefix); + +CREATE TABLE nodes( + id integer PRIMARY KEY AUTOINCREMENT, + machine_key text, + node_key text, + disco_key text, + + endpoints text, + host_info text, + ipv4 text, + ipv6 text, + hostname text, + given_name varchar(63), + user_id integer, + register_method text, + tags text, + auth_key_id integer, + last_seen datetime, + expiry datetime, + approved_routes text, + + created_at datetime, + updated_at datetime, + deleted_at datetime, + + CONSTRAINT fk_nodes_user FOREIGN KEY(user_id) REFERENCES users(id) ON DELETE CASCADE, + CONSTRAINT fk_nodes_auth_key FOREIGN KEY(auth_key_id) REFERENCES pre_auth_keys(id) +); + +CREATE TABLE policies( + id integer PRIMARY KEY AUTOINCREMENT, + data text, + + created_at datetime, + updated_at datetime, + deleted_at datetime +); +CREATE INDEX idx_policies_deleted_at ON policies(deleted_at); diff --git a/hscontrol/db/sqliteconfig/config.go b/hscontrol/db/sqliteconfig/config.go new file mode 100644 index 00000000..d27977a4 --- /dev/null +++ b/hscontrol/db/sqliteconfig/config.go @@ -0,0 +1,417 @@ +// Package sqliteconfig provides type-safe configuration for SQLite databases +// with 
proper enum validation and URL generation for modernc.org/sqlite driver. +package sqliteconfig + +import ( + "errors" + "fmt" + "strings" +) + +// Errors returned by config validation. +var ( + ErrPathEmpty = errors.New("path cannot be empty") + ErrBusyTimeoutNegative = errors.New("busy_timeout must be >= 0") + ErrInvalidJournalMode = errors.New("invalid journal_mode") + ErrInvalidAutoVacuum = errors.New("invalid auto_vacuum") + ErrWALAutocheckpoint = errors.New("wal_autocheckpoint must be >= -1") + ErrInvalidSynchronous = errors.New("invalid synchronous") + ErrInvalidTxLock = errors.New("invalid txlock") +) + +const ( + // DefaultBusyTimeout is the default busy timeout in milliseconds. + DefaultBusyTimeout = 10000 +) + +// JournalMode represents SQLite journal_mode pragma values. +// Journal modes control how SQLite handles write transactions and crash recovery. +// +// Performance vs Durability Tradeoffs: +// +// WAL (Write-Ahead Logging) - Recommended for production: +// - Best performance for concurrent reads/writes +// - Readers don't block writers, writers don't block readers +// - Excellent crash recovery with minimal data loss risk +// - Uses additional .wal and .shm files +// - Default choice for Headscale production deployments +// +// DELETE - Traditional rollback journal: +// - Good performance for single-threaded access +// - Readers block writers and vice versa +// - Reliable crash recovery but with exclusive locking +// - Creates temporary journal files during transactions +// - Suitable for low-concurrency scenarios +// +// TRUNCATE - Similar to DELETE but faster cleanup: +// - Slightly better performance than DELETE +// - Same concurrency limitations as DELETE +// - Faster transaction commit by truncating instead of deleting journal +// +// PERSIST - Journal file remains between transactions: +// - Avoids file creation/deletion overhead +// - Same concurrency limitations as DELETE +// - Good for frequent small transactions +// +// MEMORY - Journal kept in memory: +// - Fastest performance but NO crash recovery +// - Data loss risk on power failure or crash +// - Only suitable for temporary or non-critical data +// +// OFF - No journaling: +// - Maximum performance but NO transaction safety +// - High risk of database corruption on crash +// - Should only be used for read-only or disposable databases +type JournalMode string + +const ( + // JournalModeWAL enables Write-Ahead Logging (RECOMMENDED for production). + // Best concurrent performance + crash recovery. Uses additional .wal/.shm files. + JournalModeWAL JournalMode = "WAL" + + // JournalModeDelete uses traditional rollback journaling. + // Good single-threaded performance, readers block writers. Creates temp journal files. + JournalModeDelete JournalMode = "DELETE" + + // JournalModeTruncate is like DELETE but with faster cleanup. + // Slightly better performance than DELETE, same safety with exclusive locking. + JournalModeTruncate JournalMode = "TRUNCATE" + + // JournalModePersist keeps journal file between transactions. + // Good for frequent transactions, avoids file creation/deletion overhead. + JournalModePersist JournalMode = "PERSIST" + + // JournalModeMemory keeps journal in memory (DANGEROUS). + // Fastest performance but NO crash recovery - data loss on power failure. + JournalModeMemory JournalMode = "MEMORY" + + // JournalModeOff disables journaling entirely (EXTREMELY DANGEROUS). + // Maximum performance but high corruption risk. Only for disposable databases. 
+ JournalModeOff JournalMode = "OFF" +) + +// IsValid returns true if the JournalMode is valid. +func (j JournalMode) IsValid() bool { + switch j { + case JournalModeWAL, JournalModeDelete, JournalModeTruncate, + JournalModePersist, JournalModeMemory, JournalModeOff: + return true + default: + return false + } +} + +// String returns the string representation. +func (j JournalMode) String() string { + return string(j) +} + +// AutoVacuum represents SQLite auto_vacuum pragma values. +// Auto-vacuum controls how SQLite reclaims space from deleted data. +// +// Performance vs Storage Tradeoffs: +// +// INCREMENTAL - Recommended for production: +// - Reclaims space gradually during normal operations +// - Minimal performance impact on writes +// - Database size shrinks automatically over time +// - Can manually trigger with PRAGMA incremental_vacuum +// - Good balance of space efficiency and performance +// +// FULL - Automatic space reclamation: +// - Immediately reclaims space on every DELETE/DROP +// - Higher write overhead due to page reorganization +// - Keeps database file size minimal +// - Can cause significant slowdowns on large deletions +// - Best for applications with frequent deletes and limited storage +// +// NONE - No automatic space reclamation: +// - Fastest write performance (no vacuum overhead) +// - Database file only grows, never shrinks +// - Deleted space is reused but file size remains large +// - Requires manual VACUUM to reclaim space +// - Best for write-heavy workloads where storage isn't constrained +type AutoVacuum string + +const ( + // AutoVacuumNone disables automatic space reclamation. + // Fastest writes, file only grows. Requires manual VACUUM to reclaim space. + AutoVacuumNone AutoVacuum = "NONE" + + // AutoVacuumFull immediately reclaims space on every DELETE/DROP. + // Minimal file size but slower writes. Can impact performance on large deletions. + AutoVacuumFull AutoVacuum = "FULL" + + // AutoVacuumIncremental reclaims space gradually (RECOMMENDED for production). + // Good balance: minimal write impact, automatic space management over time. + AutoVacuumIncremental AutoVacuum = "INCREMENTAL" +) + +// IsValid returns true if the AutoVacuum is valid. +func (a AutoVacuum) IsValid() bool { + switch a { + case AutoVacuumNone, AutoVacuumFull, AutoVacuumIncremental: + return true + default: + return false + } +} + +// String returns the string representation. +func (a AutoVacuum) String() string { + return string(a) +} + +// Synchronous represents SQLite synchronous pragma values. +// Synchronous mode controls how aggressively SQLite flushes data to disk. 
+// +// Performance vs Durability Tradeoffs: +// +// NORMAL - Recommended for production: +// - Good balance of performance and safety +// - Syncs at critical moments (transaction commits in WAL mode) +// - Very low risk of corruption, minimal performance impact +// - Safe with WAL mode even with power loss +// - Default choice for most production applications +// +// FULL - Maximum durability: +// - Syncs to disk after every write operation +// - Highest data safety, virtually no corruption risk +// - Significant performance penalty (up to 50% slower) +// - Recommended for critical data where corruption is unacceptable +// +// EXTRA - Paranoid mode: +// - Even more aggressive syncing than FULL +// - Maximum possible data safety +// - Severe performance impact +// - Only for extremely critical scenarios +// +// OFF - Maximum performance, minimum safety: +// - No syncing, relies on OS to flush data +// - Fastest possible performance +// - High risk of corruption on power failure or crash +// - Only suitable for non-critical or easily recreatable data +type Synchronous string + +const ( + // SynchronousOff disables syncing (DANGEROUS). + // Fastest performance but high corruption risk on power failure. Avoid in production. + SynchronousOff Synchronous = "OFF" + + // SynchronousNormal provides balanced performance and safety (RECOMMENDED). + // Good performance with low corruption risk. Safe with WAL mode on power loss. + SynchronousNormal Synchronous = "NORMAL" + + // SynchronousFull provides maximum durability with performance cost. + // Syncs after every write. Up to 50% slower but virtually no corruption risk. + SynchronousFull Synchronous = "FULL" + + // SynchronousExtra provides paranoid-level data safety (EXTREME). + // Maximum safety with severe performance impact. Rarely needed in practice. + SynchronousExtra Synchronous = "EXTRA" +) + +// IsValid returns true if the Synchronous is valid. +func (s Synchronous) IsValid() bool { + switch s { + case SynchronousOff, SynchronousNormal, SynchronousFull, SynchronousExtra: + return true + default: + return false + } +} + +// String returns the string representation. +func (s Synchronous) String() string { + return string(s) +} + +// TxLock represents SQLite transaction lock mode. +// Transaction lock mode determines when write locks are acquired during transactions. +// +// Lock Acquisition Behavior: +// +// DEFERRED - SQLite default, acquire lock lazily: +// - Transaction starts without any lock +// - First read acquires SHARED lock +// - First write attempts to upgrade to RESERVED lock +// - If another transaction holds RESERVED: SQLITE_BUSY (potential deadlock) +// - Can cause deadlocks when multiple connections attempt concurrent writes +// +// IMMEDIATE - Recommended for write-heavy workloads: +// - Transaction immediately acquires RESERVED lock at BEGIN +// - If lock unavailable, waits up to busy_timeout before failing +// - Other writers queue orderly instead of deadlocking +// - Prevents the upgrade-lock deadlock scenario +// - Slight overhead for read-only transactions that don't need locks +// +// EXCLUSIVE - Maximum isolation: +// - Transaction immediately acquires EXCLUSIVE lock at BEGIN +// - No other connections can read or write +// - Highest isolation but lowest concurrency +// - Rarely needed in practice +type TxLock string + +const ( + // TxLockDeferred acquires locks lazily (SQLite default). + // Risk of SQLITE_BUSY deadlocks with concurrent writers. Use for read-heavy workloads. 
+ TxLockDeferred TxLock = "deferred" + + // TxLockImmediate acquires write lock immediately (RECOMMENDED for production). + // Prevents deadlocks by acquiring RESERVED lock at transaction start. + // Writers queue orderly, respecting busy_timeout. + TxLockImmediate TxLock = "immediate" + + // TxLockExclusive acquires exclusive lock immediately. + // Maximum isolation, no concurrent reads or writes. Rarely needed. + TxLockExclusive TxLock = "exclusive" +) + +// IsValid returns true if the TxLock is valid. +func (t TxLock) IsValid() bool { + switch t { + case TxLockDeferred, TxLockImmediate, TxLockExclusive, "": + return true + default: + return false + } +} + +// String returns the string representation. +func (t TxLock) String() string { + return string(t) +} + +// Config holds SQLite database configuration with type-safe enums. +// This configuration balances performance, durability, and operational requirements +// for Headscale's SQLite database usage patterns. +type Config struct { + Path string // file path or ":memory:" + BusyTimeout int // milliseconds (0 = default/disabled) + JournalMode JournalMode // journal mode (affects concurrency and crash recovery) + AutoVacuum AutoVacuum // auto vacuum mode (affects storage efficiency) + WALAutocheckpoint int // pages (-1 = default/not set, 0 = disabled, >0 = enabled) + Synchronous Synchronous // synchronous mode (affects durability vs performance) + ForeignKeys bool // enable foreign key constraints (data integrity) + TxLock TxLock // transaction lock mode (affects write concurrency) +} + +// Default returns the production configuration optimized for Headscale's usage patterns. +// This configuration prioritizes: +// - Concurrent access (WAL mode for multiple readers/writers) +// - Data durability with good performance (NORMAL synchronous) +// - Automatic space management (INCREMENTAL auto-vacuum) +// - Data integrity (foreign key constraints enabled) +// - Safe concurrent writes (IMMEDIATE transaction lock) +// - Reasonable timeout for busy database scenarios (10s) +func Default(path string) *Config { + return &Config{ + Path: path, + BusyTimeout: DefaultBusyTimeout, + JournalMode: JournalModeWAL, + AutoVacuum: AutoVacuumIncremental, + WALAutocheckpoint: 1000, + Synchronous: SynchronousNormal, + ForeignKeys: true, + TxLock: TxLockImmediate, + } +} + +// Memory returns a configuration for in-memory databases. +func Memory() *Config { + return &Config{ + Path: ":memory:", + WALAutocheckpoint: -1, // not set, use driver default + ForeignKeys: true, + } +} + +// Validate checks if all configuration values are valid. 
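
Taken together, the enums above feed into `Config`, and `Default`/`ToURL` turn them into a driver DSN. A minimal usage sketch (the import path is inferred from the file location in this change, the database path and the `SynchronousFull` override are illustrative):

```go
package main

import (
	"database/sql"
	"fmt"
	"log"

	"github.com/juanfont/headscale/hscontrol/db/sqliteconfig" // path inferred from this diff
	_ "modernc.org/sqlite"
)

func main() {
	// Start from the production defaults (WAL, NORMAL synchronous, INCREMENTAL
	// auto-vacuum, immediate txlock, foreign keys on) and override one knob.
	cfg := sqliteconfig.Default("/var/lib/headscale/db.sqlite")
	cfg.Synchronous = sqliteconfig.SynchronousFull // trade write speed for durability

	url, err := cfg.ToURL() // validates the config before building the DSN
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(url)
	// file:/var/lib/headscale/db.sqlite?_txlock=immediate&_pragma=busy_timeout=10000&...

	db, err := sql.Open("sqlite", url)
	if err != nil {
		log.Fatal(err)
	}
	defer db.Close()
}
```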
+func (c *Config) Validate() error { + if c.Path == "" { + return ErrPathEmpty + } + + if c.BusyTimeout < 0 { + return fmt.Errorf("%w, got %d", ErrBusyTimeoutNegative, c.BusyTimeout) + } + + if c.JournalMode != "" && !c.JournalMode.IsValid() { + return fmt.Errorf("%w: %s", ErrInvalidJournalMode, c.JournalMode) + } + + if c.AutoVacuum != "" && !c.AutoVacuum.IsValid() { + return fmt.Errorf("%w: %s", ErrInvalidAutoVacuum, c.AutoVacuum) + } + + if c.WALAutocheckpoint < -1 { + return fmt.Errorf("%w, got %d", ErrWALAutocheckpoint, c.WALAutocheckpoint) + } + + if c.Synchronous != "" && !c.Synchronous.IsValid() { + return fmt.Errorf("%w: %s", ErrInvalidSynchronous, c.Synchronous) + } + + if c.TxLock != "" && !c.TxLock.IsValid() { + return fmt.Errorf("%w: %s", ErrInvalidTxLock, c.TxLock) + } + + return nil +} + +// ToURL builds a properly encoded SQLite connection string using _pragma parameters +// compatible with modernc.org/sqlite driver. +func (c *Config) ToURL() (string, error) { + if err := c.Validate(); err != nil { + return "", fmt.Errorf("invalid config: %w", err) + } + + var pragmas []string + + // Add pragma parameters only if they're set (non-zero/non-empty) + if c.BusyTimeout > 0 { + pragmas = append(pragmas, fmt.Sprintf("busy_timeout=%d", c.BusyTimeout)) + } + if c.JournalMode != "" { + pragmas = append(pragmas, fmt.Sprintf("journal_mode=%s", c.JournalMode)) + } + if c.AutoVacuum != "" { + pragmas = append(pragmas, fmt.Sprintf("auto_vacuum=%s", c.AutoVacuum)) + } + if c.WALAutocheckpoint >= 0 { + pragmas = append(pragmas, fmt.Sprintf("wal_autocheckpoint=%d", c.WALAutocheckpoint)) + } + if c.Synchronous != "" { + pragmas = append(pragmas, fmt.Sprintf("synchronous=%s", c.Synchronous)) + } + if c.ForeignKeys { + pragmas = append(pragmas, "foreign_keys=ON") + } + + // Handle different database types + var baseURL string + if c.Path == ":memory:" { + baseURL = ":memory:" + } else { + baseURL = "file:" + c.Path + } + + // Build query parameters + queryParts := make([]string, 0, 1+len(pragmas)) + + // Add _txlock first (it's a connection parameter, not a pragma) + if c.TxLock != "" { + queryParts = append(queryParts, "_txlock="+string(c.TxLock)) + } + + // Add pragma parameters + for _, pragma := range pragmas { + queryParts = append(queryParts, "_pragma="+pragma) + } + + if len(queryParts) > 0 { + baseURL += "?" 
+ strings.Join(queryParts, "&") + } + + return baseURL, nil +} diff --git a/hscontrol/db/sqliteconfig/config_test.go b/hscontrol/db/sqliteconfig/config_test.go new file mode 100644 index 00000000..66955bb9 --- /dev/null +++ b/hscontrol/db/sqliteconfig/config_test.go @@ -0,0 +1,320 @@ +package sqliteconfig + +import ( + "testing" +) + +func TestJournalMode(t *testing.T) { + tests := []struct { + mode JournalMode + valid bool + }{ + {JournalModeWAL, true}, + {JournalModeDelete, true}, + {JournalModeTruncate, true}, + {JournalModePersist, true}, + {JournalModeMemory, true}, + {JournalModeOff, true}, + {JournalMode("INVALID"), false}, + {JournalMode(""), false}, + } + + for _, tt := range tests { + t.Run(string(tt.mode), func(t *testing.T) { + if got := tt.mode.IsValid(); got != tt.valid { + t.Errorf("JournalMode(%q).IsValid() = %v, want %v", tt.mode, got, tt.valid) + } + }) + } +} + +func TestAutoVacuum(t *testing.T) { + tests := []struct { + mode AutoVacuum + valid bool + }{ + {AutoVacuumNone, true}, + {AutoVacuumFull, true}, + {AutoVacuumIncremental, true}, + {AutoVacuum("INVALID"), false}, + {AutoVacuum(""), false}, + } + + for _, tt := range tests { + t.Run(string(tt.mode), func(t *testing.T) { + if got := tt.mode.IsValid(); got != tt.valid { + t.Errorf("AutoVacuum(%q).IsValid() = %v, want %v", tt.mode, got, tt.valid) + } + }) + } +} + +func TestSynchronous(t *testing.T) { + tests := []struct { + mode Synchronous + valid bool + }{ + {SynchronousOff, true}, + {SynchronousNormal, true}, + {SynchronousFull, true}, + {SynchronousExtra, true}, + {Synchronous("INVALID"), false}, + {Synchronous(""), false}, + } + + for _, tt := range tests { + t.Run(string(tt.mode), func(t *testing.T) { + if got := tt.mode.IsValid(); got != tt.valid { + t.Errorf("Synchronous(%q).IsValid() = %v, want %v", tt.mode, got, tt.valid) + } + }) + } +} + +func TestTxLock(t *testing.T) { + tests := []struct { + mode TxLock + valid bool + }{ + {TxLockDeferred, true}, + {TxLockImmediate, true}, + {TxLockExclusive, true}, + {TxLock(""), true}, // empty is valid (uses driver default) + {TxLock("IMMEDIATE"), false}, // uppercase is invalid + {TxLock("INVALID"), false}, + } + + for _, tt := range tests { + name := string(tt.mode) + if name == "" { + name = "empty" + } + + t.Run(name, func(t *testing.T) { + if got := tt.mode.IsValid(); got != tt.valid { + t.Errorf("TxLock(%q).IsValid() = %v, want %v", tt.mode, got, tt.valid) + } + }) + } +} + +func TestTxLockString(t *testing.T) { + tests := []struct { + mode TxLock + want string + }{ + {TxLockDeferred, "deferred"}, + {TxLockImmediate, "immediate"}, + {TxLockExclusive, "exclusive"}, + } + + for _, tt := range tests { + t.Run(tt.want, func(t *testing.T) { + if got := tt.mode.String(); got != tt.want { + t.Errorf("TxLock.String() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestConfigValidate(t *testing.T) { + tests := []struct { + name string + config *Config + wantErr bool + }{ + { + name: "valid default config", + config: Default("/path/to/db.sqlite"), + }, + { + name: "empty path", + config: &Config{ + Path: "", + }, + wantErr: true, + }, + { + name: "negative busy timeout", + config: &Config{ + Path: "/path/to/db.sqlite", + BusyTimeout: -1, + }, + wantErr: true, + }, + { + name: "invalid journal mode", + config: &Config{ + Path: "/path/to/db.sqlite", + JournalMode: JournalMode("INVALID"), + }, + wantErr: true, + }, + { + name: "invalid txlock", + config: &Config{ + Path: "/path/to/db.sqlite", + TxLock: TxLock("INVALID"), + }, + wantErr: true, + }, + { + name: "valid 
txlock immediate", + config: &Config{ + Path: "/path/to/db.sqlite", + TxLock: TxLockImmediate, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.config.Validate() + if (err != nil) != tt.wantErr { + t.Errorf("Config.Validate() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestConfigToURL(t *testing.T) { + tests := []struct { + name string + config *Config + want string + }{ + { + name: "default config includes txlock immediate", + config: Default("/path/to/db.sqlite"), + want: "file:/path/to/db.sqlite?_txlock=immediate&_pragma=busy_timeout=10000&_pragma=journal_mode=WAL&_pragma=auto_vacuum=INCREMENTAL&_pragma=wal_autocheckpoint=1000&_pragma=synchronous=NORMAL&_pragma=foreign_keys=ON", + }, + { + name: "memory config", + config: Memory(), + want: ":memory:?_pragma=foreign_keys=ON", + }, + { + name: "minimal config", + config: &Config{ + Path: "/simple/db.sqlite", + WALAutocheckpoint: -1, // not set + }, + want: "file:/simple/db.sqlite", + }, + { + name: "custom config", + config: &Config{ + Path: "/custom/db.sqlite", + BusyTimeout: 5000, + JournalMode: JournalModeDelete, + WALAutocheckpoint: -1, // not set + Synchronous: SynchronousFull, + ForeignKeys: true, + }, + want: "file:/custom/db.sqlite?_pragma=busy_timeout=5000&_pragma=journal_mode=DELETE&_pragma=synchronous=FULL&_pragma=foreign_keys=ON", + }, + { + name: "memory with custom timeout", + config: &Config{ + Path: ":memory:", + BusyTimeout: 2000, + WALAutocheckpoint: -1, // not set + ForeignKeys: true, + }, + want: ":memory:?_pragma=busy_timeout=2000&_pragma=foreign_keys=ON", + }, + { + name: "wal autocheckpoint zero", + config: &Config{ + Path: "/test.db", + WALAutocheckpoint: 0, + }, + want: "file:/test.db?_pragma=wal_autocheckpoint=0", + }, + { + name: "all options", + config: &Config{ + Path: "/full.db", + BusyTimeout: 15000, + JournalMode: JournalModeWAL, + AutoVacuum: AutoVacuumFull, + WALAutocheckpoint: 1000, + Synchronous: SynchronousExtra, + ForeignKeys: true, + }, + want: "file:/full.db?_pragma=busy_timeout=15000&_pragma=journal_mode=WAL&_pragma=auto_vacuum=FULL&_pragma=wal_autocheckpoint=1000&_pragma=synchronous=EXTRA&_pragma=foreign_keys=ON", + }, + { + name: "with txlock immediate", + config: &Config{ + Path: "/test.db", + BusyTimeout: 5000, + TxLock: TxLockImmediate, + WALAutocheckpoint: -1, + ForeignKeys: true, + }, + want: "file:/test.db?_txlock=immediate&_pragma=busy_timeout=5000&_pragma=foreign_keys=ON", + }, + { + name: "with txlock deferred", + config: &Config{ + Path: "/test.db", + TxLock: TxLockDeferred, + WALAutocheckpoint: -1, + ForeignKeys: true, + }, + want: "file:/test.db?_txlock=deferred&_pragma=foreign_keys=ON", + }, + { + name: "with txlock exclusive", + config: &Config{ + Path: "/test.db", + TxLock: TxLockExclusive, + WALAutocheckpoint: -1, + }, + want: "file:/test.db?_txlock=exclusive", + }, + { + name: "empty txlock omitted from URL", + config: &Config{ + Path: "/test.db", + TxLock: "", + BusyTimeout: 1000, + WALAutocheckpoint: -1, + ForeignKeys: true, + }, + want: "file:/test.db?_pragma=busy_timeout=1000&_pragma=foreign_keys=ON", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := tt.config.ToURL() + if err != nil { + t.Errorf("Config.ToURL() error = %v", err) + return + } + if got != tt.want { + t.Errorf("Config.ToURL() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestConfigToURLInvalid(t *testing.T) { + config := &Config{ + Path: "", + BusyTimeout: -1, + } + _, err := config.ToURL() + if 
err == nil { + t.Error("Config.ToURL() with invalid config should return error") + } +} + +func TestDefaultConfigHasTxLockImmediate(t *testing.T) { + config := Default("/test.db") + if config.TxLock != TxLockImmediate { + t.Errorf("Default().TxLock = %q, want %q", config.TxLock, TxLockImmediate) + } +} diff --git a/hscontrol/db/sqliteconfig/integration_test.go b/hscontrol/db/sqliteconfig/integration_test.go new file mode 100644 index 00000000..bb54ea1e --- /dev/null +++ b/hscontrol/db/sqliteconfig/integration_test.go @@ -0,0 +1,269 @@ +package sqliteconfig + +import ( + "database/sql" + "path/filepath" + "strings" + "testing" + + _ "modernc.org/sqlite" +) + +const memoryDBPath = ":memory:" + +// TestSQLiteDriverPragmaIntegration verifies that the modernc.org/sqlite driver +// correctly applies all pragma settings from URL parameters, ensuring they work +// the same as the old SQL PRAGMA statements approach. +func TestSQLiteDriverPragmaIntegration(t *testing.T) { + tests := []struct { + name string + config *Config + expected map[string]any + }{ + { + name: "default configuration", + config: Default("/tmp/test.db"), + expected: map[string]any{ + "busy_timeout": 10000, + "journal_mode": "wal", + "auto_vacuum": 2, // INCREMENTAL = 2 + "wal_autocheckpoint": 1000, + "synchronous": 1, // NORMAL = 1 + "foreign_keys": 1, // ON = 1 + }, + }, + { + name: "memory database with foreign keys", + config: Memory(), + expected: map[string]any{ + "foreign_keys": 1, // ON = 1 + }, + }, + { + name: "custom configuration", + config: &Config{ + Path: "/tmp/custom.db", + BusyTimeout: 5000, + JournalMode: JournalModeDelete, + AutoVacuum: AutoVacuumFull, + WALAutocheckpoint: 1000, + Synchronous: SynchronousFull, + ForeignKeys: true, + }, + expected: map[string]any{ + "busy_timeout": 5000, + "journal_mode": "delete", + "auto_vacuum": 1, // FULL = 1 + "wal_autocheckpoint": 1000, + "synchronous": 2, // FULL = 2 + "foreign_keys": 1, // ON = 1 + }, + }, + { + name: "foreign keys disabled", + config: &Config{ + Path: "/tmp/no_fk.db", + ForeignKeys: false, + }, + expected: map[string]any{ + // foreign_keys should not be set (defaults to 0/OFF) + "foreign_keys": 0, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temporary database file if not memory + if tt.config.Path == memoryDBPath { + // For memory databases, no changes needed + } else { + tempDir := t.TempDir() + dbPath := filepath.Join(tempDir, "test.db") + // Update config with actual temp path + configCopy := *tt.config + configCopy.Path = dbPath + tt.config = &configCopy + } + + // Generate URL and open database + url, err := tt.config.ToURL() + if err != nil { + t.Fatalf("Failed to generate URL: %v", err) + } + + t.Logf("Opening database with URL: %s", url) + + db, err := sql.Open("sqlite", url) + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Test connection + if err := db.Ping(); err != nil { + t.Fatalf("Failed to ping database: %v", err) + } + + // Verify each expected pragma setting + for pragma, expectedValue := range tt.expected { + t.Run("pragma_"+pragma, func(t *testing.T) { + var actualValue any + query := "PRAGMA " + pragma + err := db.QueryRow(query).Scan(&actualValue) + if err != nil { + t.Fatalf("Failed to query %s: %v", query, err) + } + + t.Logf("%s: expected=%v, actual=%v", pragma, expectedValue, actualValue) + + // Handle type conversion for comparison + switch expected := expectedValue.(type) { + case int: + if actual, ok := actualValue.(int64); ok { + 
if int64(expected) != actual { + t.Errorf("%s: expected %d, got %d", pragma, expected, actual) + } + } else { + t.Errorf("%s: expected int %d, got %T %v", pragma, expected, actualValue, actualValue) + } + case string: + if actual, ok := actualValue.(string); ok { + if expected != actual { + t.Errorf("%s: expected %q, got %q", pragma, expected, actual) + } + } else { + t.Errorf("%s: expected string %q, got %T %v", pragma, expected, actualValue, actualValue) + } + default: + t.Errorf("Unsupported expected type for %s: %T", pragma, expectedValue) + } + }) + } + }) + } +} + +// TestForeignKeyConstraintEnforcement verifies that foreign key constraints +// are actually enforced when enabled via URL parameters. +func TestForeignKeyConstraintEnforcement(t *testing.T) { + tempDir := t.TempDir() + + dbPath := filepath.Join(tempDir, "fk_test.db") + config := Default(dbPath) + + url, err := config.ToURL() + if err != nil { + t.Fatalf("Failed to generate URL: %v", err) + } + + db, err := sql.Open("sqlite", url) + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + // Create test tables with foreign key relationship + schema := ` + CREATE TABLE parent ( + id INTEGER PRIMARY KEY, + name TEXT NOT NULL + ); + + CREATE TABLE child ( + id INTEGER PRIMARY KEY, + parent_id INTEGER NOT NULL, + name TEXT NOT NULL, + FOREIGN KEY (parent_id) REFERENCES parent(id) + ); + ` + + if _, err := db.Exec(schema); err != nil { + t.Fatalf("Failed to create schema: %v", err) + } + + // Insert parent record + if _, err := db.Exec("INSERT INTO parent (id, name) VALUES (1, 'Parent 1')"); err != nil { + t.Fatalf("Failed to insert parent: %v", err) + } + + // Test 1: Valid foreign key should work + _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (1, 1, 'Child 1')") + if err != nil { + t.Fatalf("Valid foreign key insert failed: %v", err) + } + + // Test 2: Invalid foreign key should fail + _, err = db.Exec("INSERT INTO child (id, parent_id, name) VALUES (2, 999, 'Child 2')") + if err == nil { + t.Error("Expected foreign key constraint violation, but insert succeeded") + } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { + t.Errorf("Expected foreign key constraint error, got: %v", err) + } else { + t.Logf("✓ Foreign key constraint correctly enforced: %v", err) + } + + // Test 3: Deleting referenced parent should fail + _, err = db.Exec("DELETE FROM parent WHERE id = 1") + if err == nil { + t.Error("Expected foreign key constraint violation when deleting referenced parent") + } else if !contains(err.Error(), "FOREIGN KEY constraint failed") { + t.Errorf("Expected foreign key constraint error on delete, got: %v", err) + } else { + t.Logf("✓ Foreign key constraint correctly prevented parent deletion: %v", err) + } +} + +// TestJournalModeValidation verifies that the journal_mode setting is applied correctly. 
+func TestJournalModeValidation(t *testing.T) { + modes := []struct { + mode JournalMode + expected string + }{ + {JournalModeWAL, "wal"}, + {JournalModeDelete, "delete"}, + {JournalModeTruncate, "truncate"}, + {JournalModeMemory, "memory"}, + } + + for _, tt := range modes { + t.Run(string(tt.mode), func(t *testing.T) { + tempDir := t.TempDir() + + dbPath := filepath.Join(tempDir, "journal_test.db") + config := &Config{ + Path: dbPath, + JournalMode: tt.mode, + ForeignKeys: true, + } + + url, err := config.ToURL() + if err != nil { + t.Fatalf("Failed to generate URL: %v", err) + } + + db, err := sql.Open("sqlite", url) + if err != nil { + t.Fatalf("Failed to open database: %v", err) + } + defer db.Close() + + var actualMode string + err = db.QueryRow("PRAGMA journal_mode").Scan(&actualMode) + if err != nil { + t.Fatalf("Failed to query journal_mode: %v", err) + } + + if actualMode != tt.expected { + t.Errorf("journal_mode: expected %q, got %q", tt.expected, actualMode) + } else { + t.Logf("✓ journal_mode correctly set to: %s", actualMode) + } + }) + } +} + +// contains checks if a string contains a substring (helper function). +func contains(str, substr string) bool { + return strings.Contains(str, substr) +} diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 1c384918..15a85cf8 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -2,59 +2,101 @@ package db import ( "log" - "net/netip" + "net/url" "os" + "strconv" + "strings" "testing" - "github.com/juanfont/headscale/hscontrol/notifier" - "gopkg.in/check.v1" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog" + "zombiezen.com/go/postgrestest" ) -func Test(t *testing.T) { - check.TestingT(t) -} - -var _ = check.Suite(&Suite{}) - -type Suite struct{} - -var ( - tmpDir string - db *HSDatabase -) - -func (s *Suite) SetUpTest(c *check.C) { - s.ResetDB(c) -} - -func (s *Suite) TearDownTest(c *check.C) { - // os.RemoveAll(tmpDir) -} - -func (s *Suite) ResetDB(c *check.C) { - // if len(tmpDir) != 0 { - // os.RemoveAll(tmpDir) - // } - - var err error - tmpDir, err = os.MkdirTemp("", "headscale-db-test-*") +func newSQLiteTestDB() (*HSDatabase, error) { + tmpDir, err := os.MkdirTemp("", "headscale-db-test-*") if err != nil { - c.Fatal(err) + return nil, err } log.Printf("database path: %s", tmpDir+"/headscale_test.db") + zerolog.SetGlobalLevel(zerolog.Disabled) - db, err = NewHeadscaleDatabase( - "sqlite3", - tmpDir+"/headscale_test.db", - false, - notifier.NewNotifier(), - []netip.Prefix{ - netip.MustParsePrefix("10.27.0.0/23"), + db, err := NewHeadscaleDatabase( + &types.Config{ + Database: types.DatabaseConfig{ + Type: types.DatabaseSqlite, + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, }, - "", + emptyCache(), ) if err != nil { - c.Fatal(err) + return nil, err } + + return db, nil +} + +func newPostgresTestDB(t *testing.T) *HSDatabase { + t.Helper() + + return newHeadscaleDBFromPostgresURL(t, newPostgresDBForTest(t)) +} + +func newPostgresDBForTest(t *testing.T) *url.URL { + t.Helper() + + ctx := t.Context() + srv, err := postgrestest.Start(ctx) + if err != nil { + t.Fatal(err) + } + t.Cleanup(srv.Cleanup) + + u, err := srv.CreateDatabase(ctx) + if err != nil { + t.Fatal(err) + } + t.Logf("created local postgres: %s", u) + pu, _ := url.Parse(u) + + return pu +} + +func newHeadscaleDBFromPostgresURL(t *testing.T, pu *url.URL) *HSDatabase { + t.Helper() + + pass, _ := 
pu.User.Password() + port, _ := strconv.Atoi(pu.Port()) + + db, err := NewHeadscaleDatabase( + &types.Config{ + Database: types.DatabaseConfig{ + Type: types.DatabasePostgres, + Postgres: types.PostgresConfig{ + Host: pu.Hostname(), + User: pu.User.Username(), + Name: strings.TrimLeft(pu.Path, "/"), + Pass: pass, + Port: port, + Ssl: "disable", + }, + }, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, + }, + emptyCache(), + ) + if err != nil { + t.Fatal(err) + } + + return db } diff --git a/hscontrol/db/testdata/sqlite/failing-node-preauth-constraint_dump.sql b/hscontrol/db/testdata/sqlite/failing-node-preauth-constraint_dump.sql new file mode 100644 index 00000000..68069064 --- /dev/null +++ b/hscontrol/db/testdata/sqlite/failing-node-preauth-constraint_dump.sql @@ -0,0 +1,34 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE IF NOT EXISTS "api_keys" (`id` integer,`prefix` text UNIQUE,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime,PRIMARY KEY (`id`)); +INSERT INTO api_keys VALUES(1,'hFKcRjLyfw',X'243261243130242e68554a6739332e6658333061326457723637464f2e6146424c74726e4542474c6c746437597a4253534d6f3677326d3944664d61','2023-04-09 22:34:28.624250346+00:00','2023-07-08 22:34:28.559681279+00:00',NULL); +INSERT INTO api_keys VALUES(2,'88Wbitubag',X'243261243130246f7932506d53375033334b733861376e7745434f3665674e776e517659374b5474326a30686958446c6c55696c3568513948307665','2024-07-28 21:59:38.786936789+00:00','2024-10-26 21:59:38.724189498+00:00',NULL); +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +INSERT INTO migrations VALUES('202312101416'); +INSERT INTO migrations VALUES('202312101430'); +INSERT INTO migrations VALUES('202402151347'); +INSERT INTO migrations VALUES('2024041121742'); +INSERT INTO migrations VALUES('202406021630'); +INSERT INTO migrations VALUES('202409271400'); +INSERT INTO migrations VALUES('202407191627'); +INSERT INTO migrations VALUES('202408181235'); +INSERT INTO migrations VALUES('202501221827'); +INSERT INTO migrations VALUES('202501311657'); +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +CREATE TABLE IF NOT EXISTS "pre_auth_keys" (`id` integer,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`created_at` datetime,`expiration` datetime,`tags` text,PRIMARY KEY (`id`),CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer,`machine_key` text,`node_key` text,`disco_key` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`last_seen` datetime,`expiry` datetime,`host_info` text,`endpoints` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`ipv4` text,`ipv6` text,PRIMARY KEY (`id`),CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`)); +INSERT INTO nodes VALUES(1,'mkey:a0ab77456320823945ae0331823e3c0d516fae9585bd42698dfa1ac3d7679e63','nodekey:7c84167ab68f494942de14deb83587fd841843de2bac105b6c670048c160554f','discokey:53075b3c6cad3b62a2a29caea61beeb93f66b8c75cb89dac465236a5bbf57759','hostname_1','given_name1',1,'cli','["tag:sshclient","tag:ssh"]',0,'2025-02-05 16:46:13.960213431+00:00','0001-01-01 
00:00:00+00:00','{}','[]','2023-03-30 23:18:17.612740902+00:00','2025-02-05 16:46:13.960284003+00:00',NULL,'100.64.0.1','fd7a:115c:a1e0::1'); +INSERT INTO nodes VALUES(2,'mkey:f63dda7495db68077080364ba4109f48dee7a59310b9ed4968beb40d038eb622','nodekey:8186817337049e092e6ea02507091d8e9686924d46ad0e74a90370ec0113c440','discokey:28a2df7e73b8196c6859c94329443a28f9605b2b83541b685c1db666bd835775','hostname_2','given_name2',1,'cli','["tag:sshclient"]',0,'2024-07-30 17:37:24.266006395+00:00','0001-01-01 00:00:00+00:00','{}','[]','2023-03-30 23:20:01.05202704+00:00','2024-07-30 17:37:24.266082813+00:00',NULL,'100.64.0.2','fd7a:115c:a1e0::2'); +INSERT INTO nodes VALUES(3,'mkey:0af53661fedf5143af3ea79e596928302e51c9fc9f0ea9ed1f2bb7d54778b80e','nodekey:8defd8272fd2851601158b2444fc8d1ab12b6187ec5db154b7a83bb75b2ce952','discokey:ba9d1ffac1997acbd8d281b8711699daa77ed91691772683ebbfdaafa2518a52','hostname_3','given_name3',1,'cli','["tag:ssh"]',0,'2025-02-05 16:48:00.460606473+00:00','0001-01-01 00:00:00+00:00','{}','[]','2023-03-30 23:36:04.930844845+00:00','2025-02-05 16:48:00.460679869+00:00',NULL,'100.64.0.3','fd7a:115c:a1e0::3'); +INSERT INTO nodes VALUES(4,'mkey:365e2055485de89e65e63c13e426b1ec5d5606327d63955b38be1d3f8cbbac6c','nodekey:996b9814e405f572fc0338f91b0c53f3a3a9a5b1ae0d2846d179195778d50909','discokey:ed72cb545b46b3e2ed0332f9cb4d7f4e774ea5834e2cbadc43c9bf7918ef2503','hostname_4','given_name4',1,'cli','["tag:ssh"]',0,'2025-02-05 16:48:00.460607206+00:00','0001-01-01 00:00:00+00:00','{}','[]','2023-03-31 15:51:56.149734121+00:00','2025-02-05 16:48:00.46092239+00:00',NULL,'100.64.0.4','fd7a:115c:a1e0::4'); +INSERT INTO nodes VALUES(5,'mkey:1d04be488182a66cd7df4596ac59a40613eac6465a331af9ac6c91bb70754a25','nodekey:9b617f3e7941ac70b76f0e40c55543173e0432d4a9bb8bcb8b25d93b60a5da0e','discokey:15834557115cb889e8362e7f2cae1cfd7e78e754cb7310cff6b5c5b5d3027e35','hostname_5','given_name5',1,'cli','["tag:sshclient","tag:ssh"]',0,'2023-04-21 15:07:38.796218079+00:00','0001-01-01 00:00:00+00:00','{}','[]','2023-04-21 13:16:19.148836255+00:00','2024-04-17 15:39:21.339518261+00:00',NULL,'100.64.0.5','fd7a:115c:a1e0::5'); +INSERT INTO nodes VALUES(6,'mkey:ed649503734e31eafad7f884ac8ee36ba0922c57cda8b6946cb439b1ed645676','nodekey:200484e66b43012eca81ec8850e4b5d1dd8fa538dfebdaac718f202cd2f1f955','discokey:600651ed2436ce5a49e71b3980f93070d888e6d65d608a64be29fdeed9f7bd6b','hostname_6','given_name6',1,'cli','["tag:ssh"]',0,'2023-07-09 16:56:18.876491583+00:00','0001-01-01 00:00:00+00:00','{}','[]','2023-05-07 10:30:54.520661376+00:00','2024-04-17 15:39:23.182648721+00:00',NULL,'100.64.0.6','fd7a:115c:a1e0::6'); +CREATE TABLE IF NOT EXISTS "routes" (`id` integer,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`node_id` integer NOT NULL,`prefix` text,`advertised` numeric,`enabled` numeric,`is_primary` numeric,PRIMARY KEY (`id`),CONSTRAINT `fk_nodes_routes` FOREIGN KEY (`node_id`) REFERENCES `nodes`(`id`) ON DELETE CASCADE); +CREATE TABLE IF NOT EXISTS "users" (`id` integer,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text UNIQUE,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text,PRIMARY KEY (`id`)); +INSERT INTO users VALUES(1,'2023-03-30 23:08:54.151102578+00:00','2023-03-30 23:08:54.151102578+00:00',NULL,'username_1','display_name_1','email_1@example.com',NULL,NULL,NULL); +DELETE FROM sqlite_sequence; +CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`); +CREATE INDEX `idx_policies_deleted_at` ON 
`policies`(`deleted_at`); +CREATE INDEX `idx_routes_deleted_at` ON `routes`(`deleted_at`); +CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`); +COMMIT; diff --git a/hscontrol/db/testdata/sqlite/headscale_0.26.0-beta.1_dump.sql b/hscontrol/db/testdata/sqlite/headscale_0.26.0-beta.1_dump.sql new file mode 100644 index 00000000..62384198 --- /dev/null +++ b/hscontrol/db/testdata/sqlite/headscale_0.26.0-beta.1_dump.sql @@ -0,0 +1,30 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +INSERT INTO migrations VALUES('202312101416'); +INSERT INTO migrations VALUES('202312101430'); +INSERT INTO migrations VALUES('202402151347'); +INSERT INTO migrations VALUES('2024041121742'); +INSERT INTO migrations VALUES('202406021630'); +INSERT INTO migrations VALUES('202409271400'); +INSERT INTO migrations VALUES('202407191627'); +INSERT INTO migrations VALUES('202408181235'); +INSERT INTO migrations VALUES('202501221827'); +INSERT INTO migrations VALUES('202501311657'); +INSERT INTO migrations VALUES('202502070949'); +INSERT INTO migrations VALUES('202502131714'); +INSERT INTO migrations VALUES('202502171819'); +CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text); +CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); +CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime); +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`),CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE); +DELETE FROM sqlite_sequence; +INSERT INTO sqlite_sequence VALUES('nodes',0); +CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`); +CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`); +CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`); +CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL; +COMMIT; diff --git a/hscontrol/db/testdata/sqlite/headscale_0.26.0-beta.2_dump.sql b/hscontrol/db/testdata/sqlite/headscale_0.26.0-beta.2_dump.sql new file mode 100644 index 00000000..284a4c4f --- /dev/null +++ b/hscontrol/db/testdata/sqlite/headscale_0.26.0-beta.2_dump.sql @@ -0,0 +1,31 @@ +PRAGMA 
foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +INSERT INTO migrations VALUES('202312101416'); +INSERT INTO migrations VALUES('202312101430'); +INSERT INTO migrations VALUES('202402151347'); +INSERT INTO migrations VALUES('2024041121742'); +INSERT INTO migrations VALUES('202406021630'); +INSERT INTO migrations VALUES('202409271400'); +INSERT INTO migrations VALUES('202407191627'); +INSERT INTO migrations VALUES('202408181235'); +INSERT INTO migrations VALUES('202501221827'); +INSERT INTO migrations VALUES('202501311657'); +INSERT INTO migrations VALUES('202502070949'); +INSERT INTO migrations VALUES('202502131714'); +INSERT INTO migrations VALUES('202502171819'); +INSERT INTO migrations VALUES('202505091439'); +CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text); +CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); +CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime); +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`)); +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +DELETE FROM sqlite_sequence; +INSERT INTO sqlite_sequence VALUES('nodes',0); +CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`); +CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`); +CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`); +CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL; +COMMIT; diff --git a/hscontrol/db/testdata/sqlite/headscale_0.26.0_dump.sql b/hscontrol/db/testdata/sqlite/headscale_0.26.0_dump.sql new file mode 100644 index 00000000..d91e38c9 --- /dev/null +++ b/hscontrol/db/testdata/sqlite/headscale_0.26.0_dump.sql @@ -0,0 +1,32 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +INSERT INTO migrations VALUES('202312101416'); +INSERT INTO migrations VALUES('202312101430'); +INSERT INTO migrations VALUES('202402151347'); +INSERT INTO migrations VALUES('2024041121742'); +INSERT INTO migrations VALUES('202406021630'); +INSERT INTO migrations VALUES('202409271400'); +INSERT INTO migrations 
VALUES('202407191627'); +INSERT INTO migrations VALUES('202408181235'); +INSERT INTO migrations VALUES('202501221827'); +INSERT INTO migrations VALUES('202501311657'); +INSERT INTO migrations VALUES('202502070949'); +INSERT INTO migrations VALUES('202502131714'); +INSERT INTO migrations VALUES('202502171819'); +INSERT INTO migrations VALUES('202505091439'); +INSERT INTO migrations VALUES('202505141324'); +CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text); +CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); +CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime); +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`)); +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +DELETE FROM sqlite_sequence; +INSERT INTO sqlite_sequence VALUES('nodes',0); +CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`); +CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`); +CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`); +CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL; +COMMIT; diff --git a/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump-litestream.sql b/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump-litestream.sql new file mode 100644 index 00000000..c8c05755 --- /dev/null +++ b/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump-litestream.sql @@ -0,0 +1,34 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +INSERT INTO migrations VALUES('202312101416'); +INSERT INTO migrations VALUES('202312101430'); +INSERT INTO migrations VALUES('202402151347'); +INSERT INTO migrations VALUES('2024041121742'); +INSERT INTO migrations VALUES('202406021630'); +INSERT INTO migrations VALUES('202409271400'); +INSERT INTO migrations VALUES('202407191627'); +INSERT INTO migrations VALUES('202408181235'); +INSERT INTO migrations VALUES('202501221827'); +INSERT INTO migrations VALUES('202501311657'); +INSERT INTO migrations VALUES('202502070949'); +INSERT INTO migrations VALUES('202502131714'); +INSERT INTO migrations VALUES('202502171819'); +INSERT INTO 
migrations VALUES('202505091439'); +INSERT INTO migrations VALUES('202505141324'); +CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text); +CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); +CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime); +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`)); +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +DELETE FROM sqlite_sequence; +INSERT INTO sqlite_sequence VALUES('nodes',0); +CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`); +CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`); +CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`); +CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL; +CREATE TABLE _litestream_seq (id INTEGER PRIMARY KEY, seq INTEGER); +CREATE TABLE _litestream_lock (id INTEGER); +COMMIT; diff --git a/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump.sql b/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump.sql new file mode 100644 index 00000000..d91e38c9 --- /dev/null +++ b/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump.sql @@ -0,0 +1,32 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +INSERT INTO migrations VALUES('202312101416'); +INSERT INTO migrations VALUES('202312101430'); +INSERT INTO migrations VALUES('202402151347'); +INSERT INTO migrations VALUES('2024041121742'); +INSERT INTO migrations VALUES('202406021630'); +INSERT INTO migrations VALUES('202409271400'); +INSERT INTO migrations VALUES('202407191627'); +INSERT INTO migrations VALUES('202408181235'); +INSERT INTO migrations VALUES('202501221827'); +INSERT INTO migrations VALUES('202501311657'); +INSERT INTO migrations VALUES('202502070949'); +INSERT INTO migrations VALUES('202502131714'); +INSERT INTO migrations VALUES('202502171819'); +INSERT INTO migrations VALUES('202505091439'); +INSERT INTO migrations VALUES('202505141324'); +CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` 
text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text); +CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); +CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime); +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`)); +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +DELETE FROM sqlite_sequence; +INSERT INTO sqlite_sequence VALUES('nodes',0); +CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`); +CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`); +CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`); +CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL; +COMMIT; diff --git a/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump_schema-to-0.27.0-old-table-cleanup.sql b/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump_schema-to-0.27.0-old-table-cleanup.sql new file mode 100644 index 00000000..d911e960 --- /dev/null +++ b/hscontrol/db/testdata/sqlite/headscale_0.26.1_dump_schema-to-0.27.0-old-table-cleanup.sql @@ -0,0 +1,45 @@ +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +INSERT INTO migrations VALUES('202312101416'); +INSERT INTO migrations VALUES('202312101430'); +INSERT INTO migrations VALUES('202402151347'); +INSERT INTO migrations VALUES('2024041121742'); +INSERT INTO migrations VALUES('202406021630'); +INSERT INTO migrations VALUES('202409271400'); +INSERT INTO migrations VALUES('202407191627'); +INSERT INTO migrations VALUES('202408181235'); +INSERT INTO migrations VALUES('202501221827'); +INSERT INTO migrations VALUES('202501311657'); +INSERT INTO migrations VALUES('202502070949'); +INSERT INTO migrations VALUES('202502131714'); +INSERT INTO migrations VALUES('202502171819'); +INSERT INTO migrations VALUES('202505091439'); +INSERT INTO migrations VALUES('202505141324'); +CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text); +CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric 
DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); +CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime); +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`)); +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +DELETE FROM sqlite_sequence; +INSERT INTO sqlite_sequence VALUES('nodes',0); +CREATE INDEX `idx_users_deleted_at` ON `users`(`deleted_at`); +CREATE UNIQUE INDEX `idx_api_keys_prefix` ON `api_keys`(`prefix`); +CREATE INDEX `idx_policies_deleted_at` ON `policies`(`deleted_at`); +CREATE UNIQUE INDEX idx_provider_identifier ON users (provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users (name,provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users (name) WHERE provider_identifier IS NULL; + +-- Create all the old tables we have had and ensure they are clean up. +CREATE TABLE `namespaces` (`id` text,`deleted_at` datetime,PRIMARY KEY (`id`)); +CREATE TABLE `machines` (`id` text,PRIMARY KEY (`id`)); +CREATE TABLE `kvs` (`id` text,PRIMARY KEY (`id`)); +CREATE TABLE `shared_machines` (`id` text,`deleted_at` datetime,PRIMARY KEY (`id`)); +CREATE TABLE `pre_auth_key_acl_tags` (`id` text,PRIMARY KEY (`id`)); +CREATE TABLE `routes` (`id` text,`deleted_at` datetime,PRIMARY KEY (`id`)); + +CREATE INDEX `idx_routes_deleted_at` ON `routes`(`deleted_at`); +CREATE INDEX `idx_namespaces_deleted_at` ON `namespaces`(`deleted_at`); +CREATE INDEX `idx_shared_machines_deleted_at` ON `shared_machines`(`deleted_at`); + +COMMIT; diff --git a/hscontrol/db/testdata/sqlite/request_tags_migration_test.sql b/hscontrol/db/testdata/sqlite/request_tags_migration_test.sql new file mode 100644 index 00000000..6a6c1568 --- /dev/null +++ b/hscontrol/db/testdata/sqlite/request_tags_migration_test.sql @@ -0,0 +1,119 @@ +-- Test SQL dump for RequestTags migration (202601121700-migrate-hostinfo-request-tags) +-- and forced_tags->tags rename migration (202511131445-node-forced-tags-to-tags) +-- +-- This dump simulates a 0.27.x database where: +-- - Tags from --advertise-tags were stored only in host_info.RequestTags +-- - The tags column is still named forced_tags +-- +-- Test scenarios: +-- 1. Node with RequestTags that user is authorized for (should be migrated) +-- 2. Node with RequestTags that user is NOT authorized for (should be rejected) +-- 3. Node with existing forced_tags that should be preserved +-- 4. Node with RequestTags that overlap with existing tags (no duplicates) +-- 5. Node without RequestTags (should be unchanged) +-- 6. 
Node with RequestTags via group membership (should be migrated) + +PRAGMA foreign_keys=OFF; +BEGIN TRANSACTION; + +-- Migrations table - includes all migrations BEFORE the two tag migrations +CREATE TABLE `migrations` (`id` text,PRIMARY KEY (`id`)); +INSERT INTO migrations VALUES('202312101416'); +INSERT INTO migrations VALUES('202312101430'); +INSERT INTO migrations VALUES('202402151347'); +INSERT INTO migrations VALUES('2024041121742'); +INSERT INTO migrations VALUES('202406021630'); +INSERT INTO migrations VALUES('202409271400'); +INSERT INTO migrations VALUES('202407191627'); +INSERT INTO migrations VALUES('202408181235'); +INSERT INTO migrations VALUES('202501221827'); +INSERT INTO migrations VALUES('202501311657'); +INSERT INTO migrations VALUES('202502070949'); +INSERT INTO migrations VALUES('202502131714'); +INSERT INTO migrations VALUES('202502171819'); +INSERT INTO migrations VALUES('202505091439'); +INSERT INTO migrations VALUES('202505141324'); +INSERT INTO migrations VALUES('202507021200'); +INSERT INTO migrations VALUES('202510311551'); +INSERT INTO migrations VALUES('202511101554-drop-old-idx'); +INSERT INTO migrations VALUES('202511011637-preauthkey-bcrypt'); +INSERT INTO migrations VALUES('202511122344-remove-newline-index'); +-- Note: 202511131445-node-forced-tags-to-tags is NOT included - it will run +-- Note: 202601121700-migrate-hostinfo-request-tags is NOT included - it will run + +-- Users table +-- Note: User names must match the usernames in the policy (with @) +CREATE TABLE `users` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`name` text,`display_name` text,`email` text,`provider_identifier` text,`provider` text,`profile_pic_url` text); +INSERT INTO users VALUES(1,'2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL,'user1@example.com','User One','user1@example.com',NULL,NULL,NULL); +INSERT INTO users VALUES(2,'2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL,'user2@example.com','User Two','user2@example.com',NULL,NULL,NULL); +INSERT INTO users VALUES(3,'2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL,'admin1@example.com','Admin One','admin1@example.com',NULL,NULL,NULL); + +-- Pre-auth keys table +CREATE TABLE `pre_auth_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`key` text,`user_id` integer,`reusable` numeric,`ephemeral` numeric DEFAULT false,`used` numeric DEFAULT false,`tags` text,`created_at` datetime,`expiration` datetime,`prefix` text,`hash` blob,CONSTRAINT `fk_pre_auth_keys_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE SET NULL); + +-- API keys table +CREATE TABLE `api_keys` (`id` integer PRIMARY KEY AUTOINCREMENT,`prefix` text,`hash` blob,`created_at` datetime,`expiration` datetime,`last_seen` datetime); + +-- Nodes table - using OLD schema with forced_tags (not tags) +CREATE TABLE IF NOT EXISTS "nodes" (`id` integer PRIMARY KEY AUTOINCREMENT,`machine_key` text,`node_key` text,`disco_key` text,`endpoints` text,`host_info` text,`ipv4` text,`ipv6` text,`hostname` text,`given_name` varchar(63),`user_id` integer,`register_method` text,`forced_tags` text,`auth_key_id` integer,`expiry` datetime,`last_seen` datetime,`approved_routes` text,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,CONSTRAINT `fk_nodes_user` FOREIGN KEY (`user_id`) REFERENCES `users`(`id`) ON DELETE CASCADE,CONSTRAINT `fk_nodes_auth_key` FOREIGN KEY (`auth_key_id`) REFERENCES `pre_auth_keys`(`id`)); + +-- Node 1: user1 owns it, has RequestTags for tag:server 
(user1 is authorized for this tag) +-- Expected: tag:server should be added to tags +INSERT INTO nodes VALUES(1,'mkey:a0ab77456320823945ae0331823e3c0d516fae9585bd42698dfa1ac3d7679e01','nodekey:7c84167ab68f494942de14deb83587fd841843de2bac105b6c670048c1605501','discokey:53075b3c6cad3b62a2a29caea61beeb93f66b8c75cb89dac465236a5bbf57701','[]','{"RequestTags":["tag:server"]}','100.64.0.1','fd7a:115c:a1e0::1','node1','node1',1,'oidc','[]',NULL,'0001-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00','[]','2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL); + +-- Node 2: user1 owns it, has RequestTags for tag:unauthorized (user1 is NOT authorized for this tag) +-- Expected: tag:unauthorized should be rejected, tags stays empty +INSERT INTO nodes VALUES(2,'mkey:a0ab77456320823945ae0331823e3c0d516fae9585bd42698dfa1ac3d7679e02','nodekey:7c84167ab68f494942de14deb83587fd841843de2bac105b6c670048c1605502','discokey:53075b3c6cad3b62a2a29caea61beeb93f66b8c75cb89dac465236a5bbf57702','[]','{"RequestTags":["tag:unauthorized"]}','100.64.0.2','fd7a:115c:a1e0::2','node2','node2',1,'oidc','[]',NULL,'0001-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00','[]','2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL); + +-- Node 3: user2 owns it, has RequestTags for tag:client (user2 is authorized) +-- Also has existing forced_tags that should be preserved +-- Expected: tag:client added, tag:existing preserved +INSERT INTO nodes VALUES(3,'mkey:a0ab77456320823945ae0331823e3c0d516fae9585bd42698dfa1ac3d7679e03','nodekey:7c84167ab68f494942de14deb83587fd841843de2bac105b6c670048c1605503','discokey:53075b3c6cad3b62a2a29caea61beeb93f66b8c75cb89dac465236a5bbf57703','[]','{"RequestTags":["tag:client"]}','100.64.0.3','fd7a:115c:a1e0::3','node3','node3',2,'oidc','["tag:existing"]',NULL,'0001-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00','[]','2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL); + +-- Node 4: user1 owns it, has RequestTags for tag:server which already exists in forced_tags +-- Expected: no duplicates, tags should be ["tag:server"] +INSERT INTO nodes VALUES(4,'mkey:a0ab77456320823945ae0331823e3c0d516fae9585bd42698dfa1ac3d7679e04','nodekey:7c84167ab68f494942de14deb83587fd841843de2bac105b6c670048c1605504','discokey:53075b3c6cad3b62a2a29caea61beeb93f66b8c75cb89dac465236a5bbf57704','[]','{"RequestTags":["tag:server"]}','100.64.0.4','fd7a:115c:a1e0::4','node4','node4',1,'oidc','["tag:server"]',NULL,'0001-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00','[]','2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL); + +-- Node 5: user2 owns it, no RequestTags in host_info +-- Expected: tags unchanged (empty) +INSERT INTO nodes VALUES(5,'mkey:a0ab77456320823945ae0331823e3c0d516fae9585bd42698dfa1ac3d7679e05','nodekey:7c84167ab68f494942de14deb83587fd841843de2bac105b6c670048c1605505','discokey:53075b3c6cad3b62a2a29caea61beeb93f66b8c75cb89dac465236a5bbf57705','[]','{}','100.64.0.5','fd7a:115c:a1e0::5','node5','node5',2,'oidc','[]',NULL,'0001-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00','[]','2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL); + +-- Node 6: admin1 owns it, has RequestTags for tag:admin (admin1 is in group:admins which owns tag:admin) +-- Expected: tag:admin should be added via group membership +INSERT INTO nodes 
VALUES(6,'mkey:a0ab77456320823945ae0331823e3c0d516fae9585bd42698dfa1ac3d7679e06','nodekey:7c84167ab68f494942de14deb83587fd841843de2bac105b6c670048c1605506','discokey:53075b3c6cad3b62a2a29caea61beeb93f66b8c75cb89dac465236a5bbf57706','[]','{"RequestTags":["tag:admin"]}','100.64.0.6','fd7a:115c:a1e0::6','node6','node6',3,'oidc','[]',NULL,'0001-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00','[]','2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL); + +-- Node 7: user1 owns it, has multiple RequestTags (tag:server authorized, tag:forbidden not authorized) +-- Expected: tag:server added, tag:forbidden rejected +INSERT INTO nodes VALUES(7,'mkey:a0ab77456320823945ae0331823e3c0d516fae9585bd42698dfa1ac3d7679e07','nodekey:7c84167ab68f494942de14deb83587fd841843de2bac105b6c670048c1605507','discokey:53075b3c6cad3b62a2a29caea61beeb93f66b8c75cb89dac465236a5bbf57707','[]','{"RequestTags":["tag:server","tag:forbidden"]}','100.64.0.7','fd7a:115c:a1e0::7','node7','node7',1,'oidc','[]',NULL,'0001-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00','[]','2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL); + +-- Policies table with tagOwners defining who can use which tags +-- Note: Usernames in policy must contain @ (e.g., user1@example.com or just user1@) +CREATE TABLE `policies` (`id` integer PRIMARY KEY AUTOINCREMENT,`created_at` datetime,`updated_at` datetime,`deleted_at` datetime,`data` text); +INSERT INTO policies VALUES(1,'2024-01-01 00:00:00+00:00','2024-01-01 00:00:00+00:00',NULL,'{ + "groups": { + "group:admins": ["admin1@example.com"] + }, + "tagOwners": { + "tag:server": ["user1@example.com"], + "tag:client": ["user1@example.com", "user2@example.com"], + "tag:admin": ["group:admins"] + }, + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]} + ] +}'); + +-- Indexes (using exact format expected by schema validation) +DELETE FROM sqlite_sequence; +INSERT INTO sqlite_sequence VALUES('users',3); +INSERT INTO sqlite_sequence VALUES('nodes',7); +INSERT INTO sqlite_sequence VALUES('policies',1); +CREATE INDEX idx_users_deleted_at ON users(deleted_at); +CREATE UNIQUE INDEX idx_api_keys_prefix ON api_keys(prefix); +CREATE INDEX idx_policies_deleted_at ON policies(deleted_at); +CREATE UNIQUE INDEX idx_provider_identifier ON users(provider_identifier) WHERE provider_identifier IS NOT NULL; +CREATE UNIQUE INDEX idx_name_provider_identifier ON users(name, provider_identifier); +CREATE UNIQUE INDEX idx_name_no_provider_identifier ON users(name) WHERE provider_identifier IS NULL; +CREATE UNIQUE INDEX IF NOT EXISTS idx_pre_auth_keys_prefix ON pre_auth_keys(prefix) WHERE prefix IS NOT NULL AND prefix != ''; + +COMMIT; diff --git a/hscontrol/db/text_serialiser.go b/hscontrol/db/text_serialiser.go new file mode 100644 index 00000000..6172e7e0 --- /dev/null +++ b/hscontrol/db/text_serialiser.go @@ -0,0 +1,101 @@ +package db + +import ( + "context" + "encoding" + "fmt" + "reflect" + + "gorm.io/gorm/schema" +) + +// Got from https://github.com/xdg-go/strum/blob/main/types.go +var textUnmarshalerType = reflect.TypeFor[encoding.TextUnmarshaler]() + +func isTextUnmarshaler(rv reflect.Value) bool { + return rv.Type().Implements(textUnmarshalerType) +} + +func maybeInstantiatePtr(rv reflect.Value) { + if rv.Kind() == reflect.Ptr && rv.IsNil() { + np := reflect.New(rv.Type().Elem()) + rv.Set(np) + } +} + +func decodingError(name string, err error) error { + return fmt.Errorf("error decoding to %s: %w", name, err) +} + +// TextSerialiser implements the Serialiser interface for fields that +// have 
a type that implements encoding.TextUnmarshaler. +type TextSerialiser struct{} + +func (TextSerialiser) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue any) error { + fieldValue := reflect.New(field.FieldType) + + // If the field is a pointer, we need to dereference it to get the actual type + // so we do not end with a second pointer. + if fieldValue.Elem().Kind() == reflect.Ptr { + fieldValue = fieldValue.Elem() + } + + if dbValue != nil { + var bytes []byte + switch v := dbValue.(type) { + case []byte: + bytes = v + case string: + bytes = []byte(v) + default: + return fmt.Errorf("failed to unmarshal text value: %#v", dbValue) + } + + if isTextUnmarshaler(fieldValue) { + maybeInstantiatePtr(fieldValue) + f := fieldValue.MethodByName("UnmarshalText") + args := []reflect.Value{reflect.ValueOf(bytes)} + ret := f.Call(args) + if !ret[0].IsNil() { + return decodingError(field.Name, ret[0].Interface().(error)) + } + + // If the underlying field is a pointer type, we need to + // assign the value as a pointer to it. + // If it is not a pointer, we need to assign the value to the + // field. + dstField := field.ReflectValueOf(ctx, dst) + if dstField.Kind() == reflect.Ptr { + dstField.Set(fieldValue) + } else { + dstField.Set(fieldValue.Elem()) + } + + return nil + } else { + return fmt.Errorf("unsupported type: %T", fieldValue.Interface()) + } + } + + return nil +} + +func (TextSerialiser) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue any) (any, error) { + switch v := fieldValue.(type) { + case encoding.TextMarshaler: + // If the value is nil, we return nil; however, Go nil values are not + // always comparable, particularly when reflection is involved: + // https://dev.to/arxeiss/in-go-nil-is-not-equal-to-nil-sometimes-jn8 + if v == nil || (reflect.ValueOf(v).Kind() == reflect.Ptr && reflect.ValueOf(v).IsNil()) { + return nil, nil + } + b, err := v.MarshalText() + if err != nil { + return nil, err + } + + return string(b), nil + default: + return nil, fmt.Errorf("only encoding.TextMarshaler is supported, got %T", v) + } +} diff --git a/hscontrol/db/user_update_test.go b/hscontrol/db/user_update_test.go new file mode 100644 index 00000000..180481e7 --- /dev/null +++ b/hscontrol/db/user_update_test.go @@ -0,0 +1,134 @@ +package db + +import ( + "database/sql" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" +) + +// TestUserUpdatePreservesUnchangedFields verifies that updating a user +// preserves fields that aren't modified. This test validates the fix +// for using Updates() instead of Save() in UpdateUser-like operations.
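+// +// For context, the GORM behaviour being compared (a rough sketch, not part of +// the assertions below): Save(&user) writes every column from the struct, +// including zero values, whereas Updates(&user) with a struct updates only +// non-zero fields. Because the user is loaded from the database first, both +// should leave untouched columns intact; this test pins that down.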
+func TestUserUpdatePreservesUnchangedFields(t *testing.T) { + database := dbForTest(t) + + // Create a user with all fields set + initialUser := types.User{ + Name: "testuser", + DisplayName: "Test User Display", + Email: "test@example.com", + ProviderIdentifier: sql.NullString{ + String: "provider-123", + Valid: true, + }, + } + + createdUser, err := database.CreateUser(initialUser) + require.NoError(t, err) + require.NotNil(t, createdUser) + + // Verify initial state + assert.Equal(t, "testuser", createdUser.Name) + assert.Equal(t, "Test User Display", createdUser.DisplayName) + assert.Equal(t, "test@example.com", createdUser.Email) + assert.True(t, createdUser.ProviderIdentifier.Valid) + assert.Equal(t, "provider-123", createdUser.ProviderIdentifier.String) + + // Simulate what UpdateUser does: load user, modify one field, save + _, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) { + user, err := GetUserByID(tx, types.UserID(createdUser.ID)) + if err != nil { + return nil, err + } + + // Modify ONLY DisplayName + user.DisplayName = "Updated Display Name" + + // This is the line being tested - currently uses Save() which writes ALL fields, potentially overwriting unchanged ones + err = tx.Save(user).Error + if err != nil { + return nil, err + } + + return user, nil + }) + require.NoError(t, err) + + // Read user back from database + updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUserByID(rx, types.UserID(createdUser.ID)) + }) + require.NoError(t, err) + + // Verify that DisplayName was updated + assert.Equal(t, "Updated Display Name", updatedUser.DisplayName) + + // CRITICAL: Verify that other fields were NOT overwritten + // With Save(), these assertions should pass because the user object + // was loaded from DB and has all fields populated. + // But if Updates() is used, these will also pass (and it's safer). + assert.Equal(t, "testuser", updatedUser.Name, "Name should be preserved") + assert.Equal(t, "test@example.com", updatedUser.Email, "Email should be preserved") + assert.True(t, updatedUser.ProviderIdentifier.Valid, "ProviderIdentifier should be preserved") + assert.Equal(t, "provider-123", updatedUser.ProviderIdentifier.String, "ProviderIdentifier value should be preserved") +} + +// TestUserUpdateWithUpdatesMethod tests that using Updates() instead of Save() +// works correctly and only updates modified fields. 
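+// +// Caveat (general GORM behaviour, noted for completeness): Updates() with a +// struct skips zero-value fields such as empty strings, false, and 0, so +// clearing a field requires a map or Select(). This test only covers setting +// fields to new non-zero values.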
+func TestUserUpdateWithUpdatesMethod(t *testing.T) { + database := dbForTest(t) + + // Create a user + initialUser := types.User{ + Name: "testuser", + DisplayName: "Original Display", + Email: "original@example.com", + ProviderIdentifier: sql.NullString{ + String: "provider-abc", + Valid: true, + }, + } + + createdUser, err := database.CreateUser(initialUser) + require.NoError(t, err) + + // Update using Updates() method + _, err = Write(database.DB, func(tx *gorm.DB) (*types.User, error) { + user, err := GetUserByID(tx, types.UserID(createdUser.ID)) + if err != nil { + return nil, err + } + + // Modify multiple fields + user.DisplayName = "New Display" + user.Email = "new@example.com" + + // Use Updates() instead of Save() + err = tx.Updates(user).Error + if err != nil { + return nil, err + } + + return user, nil + }) + require.NoError(t, err) + + // Verify changes + updatedUser, err := Read(database.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUserByID(rx, types.UserID(createdUser.ID)) + }) + require.NoError(t, err) + + // Verify updated fields + assert.Equal(t, "New Display", updatedUser.DisplayName) + assert.Equal(t, "new@example.com", updatedUser.Email) + + // Verify preserved fields + assert.Equal(t, "testuser", updatedUser.Name) + assert.True(t, updatedUser.ProviderIdentifier.Valid) + assert.Equal(t, "provider-abc", updatedUser.ProviderIdentifier.String) +} diff --git a/hscontrol/db/users.go b/hscontrol/db/users.go index 27a1406b..6aff9ed1 100644 --- a/hscontrol/db/users.go +++ b/hscontrol/db/users.go @@ -2,10 +2,12 @@ package db import ( "errors" + "fmt" + "strconv" + "testing" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" "gorm.io/gorm" ) @@ -15,45 +17,40 @@ var ( ErrUserStillHasNodes = errors.New("user not empty: node(s) found") ) +func (hsdb *HSDatabase) CreateUser(user types.User) (*types.User, error) { + return Write(hsdb.DB, func(tx *gorm.DB) (*types.User, error) { + return CreateUser(tx, user) + }) +} + // CreateUser creates a new User. Returns error if could not be created // or another user already exists. -func (hsdb *HSDatabase) CreateUser(name string) (*types.User, error) { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - err := util.CheckForFQDNRules(name) - if err != nil { +func CreateUser(tx *gorm.DB, user types.User) (*types.User, error) { + if err := util.ValidateHostname(user.Name); err != nil { return nil, err } - user := types.User{} - if err := hsdb.db.Where("name = ?", name).First(&user).Error; err == nil { - return nil, ErrUserExists - } - user.Name = name - if err := hsdb.db.Create(&user).Error; err != nil { - log.Error(). - Str("func", "CreateUser"). - Err(err). - Msg("Could not create row") - - return nil, err + if err := tx.Create(&user).Error; err != nil { + return nil, fmt.Errorf("creating user: %w", err) } return &user, nil } +func (hsdb *HSDatabase) DestroyUser(uid types.UserID) error { + return hsdb.Write(func(tx *gorm.DB) error { + return DestroyUser(tx, uid) + }) +} + // DestroyUser destroys a User. Returns error if the User does // not exist or if there are nodes associated with it. 
-func (hsdb *HSDatabase) DestroyUser(name string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - user, err := hsdb.getUser(name) +func DestroyUser(tx *gorm.DB, uid types.UserID) error { + user, err := GetUserByID(tx, uid) if err != nil { - return ErrUserNotFound + return err } - nodes, err := hsdb.listNodesByUser(name) + nodes, err := ListNodesByUser(tx, uid) if err != nil { return err } @@ -61,67 +58,65 @@ func (hsdb *HSDatabase) DestroyUser(name string) error { return ErrUserStillHasNodes } - keys, err := hsdb.listPreAuthKeys(name) + keys, err := ListPreAuthKeys(tx) if err != nil { return err } for _, key := range keys { - err = hsdb.destroyPreAuthKey(key) + err = DestroyPreAuthKey(tx, key.ID) if err != nil { return err } } - if result := hsdb.db.Unscoped().Delete(&user); result.Error != nil { + if result := tx.Unscoped().Delete(&user); result.Error != nil { return result.Error } return nil } +func (hsdb *HSDatabase) RenameUser(uid types.UserID, newName string) error { + return hsdb.Write(func(tx *gorm.DB) error { + return RenameUser(tx, uid, newName) + }) +} + +var ErrCannotChangeOIDCUser = errors.New("cannot edit OIDC user") + // RenameUser renames a User. Returns error if the User does // not exist or if another User exists with the new name. -func (hsdb *HSDatabase) RenameUser(oldName, newName string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - +func RenameUser(tx *gorm.DB, uid types.UserID, newName string) error { var err error - oldUser, err := hsdb.getUser(oldName) + oldUser, err := GetUserByID(tx, uid) if err != nil { return err } - err = util.CheckForFQDNRules(newName) - if err != nil { + if err = util.ValidateHostname(newName); err != nil { return err } - _, err = hsdb.getUser(newName) - if err == nil { - return ErrUserExists - } - if !errors.Is(err, ErrUserNotFound) { - return err + + if oldUser.Provider == util.RegisterMethodOIDC { + return ErrCannotChangeOIDCUser } oldUser.Name = newName - if result := hsdb.db.Save(&oldUser); result.Error != nil { - return result.Error + err = tx.Updates(&oldUser).Error + if err != nil { + return err } return nil } -// GetUser fetches a user by name. -func (hsdb *HSDatabase) GetUser(name string) (*types.User, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.getUser(name) +func (hsdb *HSDatabase) GetUserByID(uid types.UserID) (*types.User, error) { + return GetUserByID(hsdb.DB, uid) } -func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { +func GetUserByID(tx *gorm.DB, uid types.UserID) (*types.User, error) { user := types.User{} - if result := hsdb.db.First(&user, "name = ?", name); errors.Is( + if result := tx.First(&user, "id = ?", uid); errors.Is( result.Error, gorm.ErrRecordNotFound, ) { @@ -131,66 +126,113 @@ func (hsdb *HSDatabase) getUser(name string) (*types.User, error) { return &user, nil } -// ListUsers gets all the existing users. 
-func (hsdb *HSDatabase) ListUsers() ([]types.User, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - return hsdb.listUsers() +func (hsdb *HSDatabase) GetUserByOIDCIdentifier(id string) (*types.User, error) { + return Read(hsdb.DB, func(rx *gorm.DB) (*types.User, error) { + return GetUserByOIDCIdentifier(rx, id) + }) } -func (hsdb *HSDatabase) listUsers() ([]types.User, error) { +func GetUserByOIDCIdentifier(tx *gorm.DB, id string) (*types.User, error) { + user := types.User{} + if result := tx.First(&user, "provider_identifier = ?", id); errors.Is( + result.Error, + gorm.ErrRecordNotFound, + ) { + return nil, ErrUserNotFound + } + + return &user, nil +} + +func (hsdb *HSDatabase) ListUsers(where ...*types.User) ([]types.User, error) { + return ListUsers(hsdb.DB, where...) +} + +// ListUsers gets all the existing users. +func ListUsers(tx *gorm.DB, where ...*types.User) ([]types.User, error) { + if len(where) > 1 { + return nil, fmt.Errorf("expect 0 or 1 where User structs, got %d", len(where)) + } + + var user *types.User + if len(where) == 1 { + user = where[0] + } + users := []types.User{} - if err := hsdb.db.Find(&users).Error; err != nil { + if err := tx.Where(user).Find(&users).Error; err != nil { return nil, err } return users, nil } -// ListNodesByUser gets all the nodes in a given user. -func (hsdb *HSDatabase) ListNodesByUser(name string) (types.Nodes, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() +// GetUserByName returns a user if the provided username is +// unique, and otherwise an error. +func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) { + users, err := hsdb.ListUsers(&types.User{Name: name}) + if err != nil { + return nil, err + } - return hsdb.listNodesByUser(name) + if len(users) == 0 { + return nil, ErrUserNotFound + } + + if len(users) != 1 { + return nil, fmt.Errorf("expected exactly one user, found %d", len(users)) + } + + return &users[0], nil } -func (hsdb *HSDatabase) listNodesByUser(name string) (types.Nodes, error) { - err := util.CheckForFQDNRules(name) - if err != nil { - return nil, err - } - user, err := hsdb.getUser(name) - if err != nil { - return nil, err - } - +// ListNodesByUser gets all the nodes in a given user. +func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) { nodes := types.Nodes{} - if err := hsdb.db.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: user.ID}).Find(&nodes).Error; err != nil { + + uidPtr := uint(uid) + + err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: &uidPtr}).Find(&nodes).Error + if err != nil { return nil, err } return nodes, nil } -// AssignNodeToUser assigns a Node to a user. 
-func (hsdb *HSDatabase) AssignNodeToUser(node *types.Node, username string) error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - err := util.CheckForFQDNRules(username) - if err != nil { - return err - } - user, err := hsdb.getUser(username) - if err != nil { - return err - } - node.User = *user - if result := hsdb.db.Save(&node); result.Error != nil { - return result.Error +func (hsdb *HSDatabase) CreateUserForTest(name ...string) *types.User { + if !testing.Testing() { + panic("CreateUserForTest can only be called during tests") } - return nil + userName := "testuser" + if len(name) > 0 && name[0] != "" { + userName = name[0] + } + + user, err := hsdb.CreateUser(types.User{Name: userName}) + if err != nil { + panic(fmt.Sprintf("failed to create test user: %v", err)) + } + + return user +} + +func (hsdb *HSDatabase) CreateUsersForTest(count int, namePrefix ...string) []*types.User { + if !testing.Testing() { + panic("CreateUsersForTest can only be called during tests") + } + + prefix := "testuser" + if len(namePrefix) > 0 && namePrefix[0] != "" { + prefix = namePrefix[0] + } + + users := make([]*types.User, count) + for i := range count { + name := prefix + "-" + strconv.Itoa(i) + users[i] = hsdb.CreateUserForTest(name) + } + + return users } diff --git a/hscontrol/db/users_test.go b/hscontrol/db/users_test.go index 1ca3b49f..a3fd49b3 100644 --- a/hscontrol/db/users_test.go +++ b/hscontrol/db/users_test.go @@ -1,123 +1,167 @@ package db import ( + "testing" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" - "gopkg.in/check.v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gorm.io/gorm" + "tailscale.com/types/ptr" ) -func (s *Suite) TestCreateAndDestroyUser(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - c.Assert(user.Name, check.Equals, "test") +func TestCreateAndDestroyUser(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + user := db.CreateUserForTest("test") + assert.Equal(t, "test", user.Name) users, err := db.ListUsers() - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) + require.NoError(t, err) + assert.Len(t, users, 1) - err = db.DestroyUser("test") - c.Assert(err, check.IsNil) + err = db.DestroyUser(types.UserID(user.ID)) + require.NoError(t, err) - _, err = db.GetUser("test") - c.Assert(err, check.NotNil) + _, err = db.GetUserByID(types.UserID(user.ID)) + assert.Error(t, err) } -func (s *Suite) TestDestroyUserErrors(c *check.C) { - err := db.DestroyUser("test") - c.Assert(err, check.Equals, ErrUserNotFound) +func TestDestroyUserErrors(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "error_user_not_found", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) + err := db.DestroyUser(9998) + assert.ErrorIs(t, err, ErrUserNotFound) + }, + }, + { + name: "success_deletes_preauthkeys", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) + user := db.CreateUserForTest("test") - err = db.DestroyUser("test") - c.Assert(err, check.IsNil) + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) - result := db.db.Preload("User").First(&pak, "key = ?", pak.Key) - // destroying a user also deletes all associated preauthkeys - c.Assert(result.Error, check.Equals, 
gorm.ErrRecordNotFound) + err = db.DestroyUser(types.UserID(user.ID)) + require.NoError(t, err) - user, err = db.CreateUser("test") - c.Assert(err, check.IsNil) + // Verify preauth key was deleted (need to search by prefix for new keys) + var foundPak types.PreAuthKey - pak, err = db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) + result := db.DB.First(&foundPak, "id = ?", pak.ID) + assert.ErrorIs(t, result.Error, gorm.ErrRecordNotFound) + }, + }, + { + name: "error_user_has_nodes", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() - node := types.Node{ - ID: 0, - Hostname: "testnode", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + user, err := db.CreateUser(types.User{Name: "test"}) + require.NoError(t, err) + + pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) + + node := types.Node{ + ID: 0, + Hostname: "testnode", + UserID: &user.ID, + RegisterMethod: util.RegisterMethodAuthKey, + AuthKeyID: ptr.To(pak.ID), + } + trx := db.DB.Save(&node) + require.NoError(t, trx.Error) + + err = db.DestroyUser(types.UserID(user.ID)) + assert.ErrorIs(t, err, ErrUserStillHasNodes) + }, + }, } - db.db.Save(&node) - err = db.DestroyUser("test") - c.Assert(err, check.Equals, ErrUserStillHasNodes) -} + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) -func (s *Suite) TestRenameUser(c *check.C) { - userTest, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - c.Assert(userTest.Name, check.Equals, "test") - - users, err := db.ListUsers() - c.Assert(err, check.IsNil) - c.Assert(len(users), check.Equals, 1) - - err = db.RenameUser("test", "test-renamed") - c.Assert(err, check.IsNil) - - _, err = db.GetUser("test") - c.Assert(err, check.Equals, ErrUserNotFound) - - _, err = db.GetUser("test-renamed") - c.Assert(err, check.IsNil) - - err = db.RenameUser("test-does-not-exit", "test") - c.Assert(err, check.Equals, ErrUserNotFound) - - userTest2, err := db.CreateUser("test2") - c.Assert(err, check.IsNil) - c.Assert(userTest2.Name, check.Equals, "test2") - - err = db.RenameUser("test2", "test-renamed") - c.Assert(err, check.Equals, ErrUserExists) -} - -func (s *Suite) TestSetMachineUser(c *check.C) { - oldUser, err := db.CreateUser("old") - c.Assert(err, check.IsNil) - - newUser, err := db.CreateUser("new") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(oldUser.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - node := types.Node{ - ID: 0, - Hostname: "testnode", - UserID: oldUser.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), + tt.test(t, db) + }) + } +} + +func TestRenameUser(t *testing.T) { + tests := []struct { + name string + test func(*testing.T, *HSDatabase) + }{ + { + name: "success_rename", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + userTest := db.CreateUserForTest("test") + assert.Equal(t, "test", userTest.Name) + + users, err := db.ListUsers() + require.NoError(t, err) + assert.Len(t, users, 1) + + err = db.RenameUser(types.UserID(userTest.ID), "test-renamed") + require.NoError(t, err) + + users, err = db.ListUsers(&types.User{Name: "test"}) + require.NoError(t, err) + assert.Empty(t, users) + + users, err = db.ListUsers(&types.User{Name: "test-renamed"}) + require.NoError(t, err) + assert.Len(t, users, 1) + }, + }, + { + name: "error_user_not_found", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + 
err := db.RenameUser(99988, "test") + assert.ErrorIs(t, err, ErrUserNotFound) + }, + }, + { + name: "error_duplicate_name", + test: func(t *testing.T, db *HSDatabase) { + t.Helper() + + userTest := db.CreateUserForTest("test") + userTest2 := db.CreateUserForTest("test2") + + assert.Equal(t, "test", userTest.Name) + assert.Equal(t, "test2", userTest2.Name) + + err := db.RenameUser(types.UserID(userTest2.ID), "test") + require.Error(t, err) + assert.Contains(t, err.Error(), "UNIQUE constraint failed") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + db, err := newSQLiteTestDB() + require.NoError(t, err) + + tt.test(t, db) + }) } - db.db.Save(&node) - c.Assert(node.UserID, check.Equals, oldUser.ID) - - err = db.AssignNodeToUser(&node, newUser.Name) - c.Assert(err, check.IsNil) - c.Assert(node.UserID, check.Equals, newUser.ID) - c.Assert(node.User.Name, check.Equals, newUser.Name) - - err = db.AssignNodeToUser(&node, "non-existing-user") - c.Assert(err, check.Equals, ErrUserNotFound) - - err = db.AssignNodeToUser(&node, newUser.Name) - c.Assert(err, check.IsNil) - c.Assert(node.UserID, check.Equals, newUser.ID) - c.Assert(node.User.Name, check.Equals, newUser.Name) } diff --git a/hscontrol/debug.go b/hscontrol/debug.go new file mode 100644 index 00000000..629b7be1 --- /dev/null +++ b/hscontrol/debug.go @@ -0,0 +1,408 @@ +package hscontrol + +import ( + "encoding/json" + "fmt" + "net/http" + "strings" + + "github.com/arl/statsviz" + "github.com/juanfont/headscale/hscontrol/mapper" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/prometheus/client_golang/prometheus/promhttp" + "tailscale.com/tsweb" +) + +func (h *Headscale) debugHTTPServer() *http.Server { + debugMux := http.NewServeMux() + debug := tsweb.Debugger(debugMux) + + // State overview endpoint + debug.Handle("overview", "State overview", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Accept header to determine response format + acceptHeader := r.Header.Get("Accept") + wantsJSON := strings.Contains(acceptHeader, "application/json") + + if wantsJSON { + overview := h.state.DebugOverviewJSON() + overviewJSON, err := json.MarshalIndent(overview, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(overviewJSON) + } else { + // Default to text/plain for backward compatibility + overview := h.state.DebugOverview() + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(overview)) + } + })) + + // Configuration endpoint + debug.Handle("config", "Current configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + config := h.state.DebugConfig() + configJSON, err := json.MarshalIndent(config, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(configJSON) + })) + + // Policy endpoint + debug.Handle("policy", "Current policy", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + policy, err := h.state.DebugPolicy() + if err != nil { + httpError(w, err) + return + } + // Policy data is HuJSON, which is a superset of JSON + // Set content type based on Accept header preference + acceptHeader := r.Header.Get("Accept") + if strings.Contains(acceptHeader, "application/json") { + w.Header().Set("Content-Type", "application/json") + } else { + w.Header().Set("Content-Type", "text/plain") + } + 
w.WriteHeader(http.StatusOK) + w.Write([]byte(policy)) + })) + + // Filter rules endpoint + debug.Handle("filter", "Current filter rules", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + filter, err := h.state.DebugFilter() + if err != nil { + httpError(w, err) + return + } + filterJSON, err := json.MarshalIndent(filter, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(filterJSON) + })) + + // SSH policies endpoint + debug.Handle("ssh", "SSH policies per node", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sshPolicies := h.state.DebugSSHPolicies() + sshJSON, err := json.MarshalIndent(sshPolicies, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(sshJSON) + })) + + // DERP map endpoint + debug.Handle("derp", "DERP map configuration", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Accept header to determine response format + acceptHeader := r.Header.Get("Accept") + wantsJSON := strings.Contains(acceptHeader, "application/json") + + if wantsJSON { + derpInfo := h.state.DebugDERPJSON() + derpJSON, err := json.MarshalIndent(derpInfo, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(derpJSON) + } else { + // Default to text/plain for backward compatibility + derpInfo := h.state.DebugDERPMap() + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(derpInfo)) + } + })) + + // NodeStore endpoint + debug.Handle("nodestore", "NodeStore information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Accept header to determine response format + acceptHeader := r.Header.Get("Accept") + wantsJSON := strings.Contains(acceptHeader, "application/json") + + if wantsJSON { + nodeStoreNodes := h.state.DebugNodeStoreJSON() + nodeStoreJSON, err := json.MarshalIndent(nodeStoreNodes, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(nodeStoreJSON) + } else { + // Default to text/plain for backward compatibility + nodeStoreInfo := h.state.DebugNodeStore() + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(nodeStoreInfo)) + } + })) + + // Registration cache endpoint + debug.Handle("registration-cache", "Registration cache information", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cacheInfo := h.state.DebugRegistrationCache() + cacheJSON, err := json.MarshalIndent(cacheInfo, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(cacheJSON) + })) + + // Routes endpoint + debug.Handle("routes", "Primary routes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Accept header to determine response format + acceptHeader := r.Header.Get("Accept") + wantsJSON := strings.Contains(acceptHeader, "application/json") + + if wantsJSON { + routes := h.state.DebugRoutes() + routesJSON, err := json.MarshalIndent(routes, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(routesJSON) + } else { + // Default 
to text/plain for backward compatibility + routes := h.state.DebugRoutesString() + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(routes)) + } + })) + + // Policy manager endpoint + debug.Handle("policy-manager", "Policy manager state", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Accept header to determine response format + acceptHeader := r.Header.Get("Accept") + wantsJSON := strings.Contains(acceptHeader, "application/json") + + if wantsJSON { + policyManagerInfo := h.state.DebugPolicyManagerJSON() + policyManagerJSON, err := json.MarshalIndent(policyManagerInfo, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(policyManagerJSON) + } else { + // Default to text/plain for backward compatibility + policyManagerInfo := h.state.DebugPolicyManager() + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(policyManagerInfo)) + } + })) + + debug.Handle("mapresponses", "Map responses for all nodes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + res, err := h.mapBatcher.DebugMapResponses() + if err != nil { + httpError(w, err) + return + } + + if res == nil { + w.WriteHeader(http.StatusOK) + w.Write([]byte("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH not set")) + return + } + + resJSON, err := json.MarshalIndent(res, "", " ") + if err != nil { + httpError(w, err) + return + } + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(resJSON) + })) + + // Batcher endpoint + debug.Handle("batcher", "Batcher connected nodes", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Check Accept header to determine response format + acceptHeader := r.Header.Get("Accept") + wantsJSON := strings.Contains(acceptHeader, "application/json") + + if wantsJSON { + batcherInfo := h.debugBatcherJSON() + + batcherJSON, err := json.MarshalIndent(batcherInfo, "", " ") + if err != nil { + httpError(w, err) + return + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + w.Write(batcherJSON) + } else { + // Default to text/plain for backward compatibility + batcherInfo := h.debugBatcher() + + w.Header().Set("Content-Type", "text/plain") + w.WriteHeader(http.StatusOK) + w.Write([]byte(batcherInfo)) + } + })) + + err := statsviz.Register(debugMux) + if err == nil { + debug.URL("/debug/statsviz", "Statsviz (visualise go metrics)") + } + + debug.URL("/metrics", "Prometheus metrics") + debugMux.Handle("/metrics", promhttp.Handler()) + + debugHTTPServer := &http.Server{ + Addr: h.cfg.MetricsAddr, + Handler: debugMux, + ReadTimeout: types.HTTPTimeout, + WriteTimeout: 0, + } + + return debugHTTPServer +} + +// debugBatcher returns debug information about the batcher's connected nodes. 
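+// The output is sorted by node ID so successive snapshots are easy to compare;
+// each line has the form "Node <id>: connected|disconnected", with the active
+// connection count appended when it is greater than zero, followed by a
+// "Summary: <connected> connected, <total> total" line.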
+func (h *Headscale) debugBatcher() string { + var sb strings.Builder + sb.WriteString("=== Batcher Connected Nodes ===\n\n") + + totalNodes := 0 + connectedCount := 0 + + // Collect nodes and sort them by ID + type nodeStatus struct { + id types.NodeID + connected bool + activeConnections int + } + + var nodes []nodeStatus + + // Try to get detailed debug info if we have a LockFreeBatcher + if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok { + debugInfo := batcher.Debug() + for nodeID, info := range debugInfo { + nodes = append(nodes, nodeStatus{ + id: nodeID, + connected: info.Connected, + activeConnections: info.ActiveConnections, + }) + totalNodes++ + if info.Connected { + connectedCount++ + } + } + } else { + // Fallback to basic connection info + connectedMap := h.mapBatcher.ConnectedMap() + connectedMap.Range(func(nodeID types.NodeID, connected bool) bool { + nodes = append(nodes, nodeStatus{ + id: nodeID, + connected: connected, + activeConnections: 0, + }) + totalNodes++ + if connected { + connectedCount++ + } + return true + }) + } + + // Sort by node ID + for i := 0; i < len(nodes); i++ { + for j := i + 1; j < len(nodes); j++ { + if nodes[i].id > nodes[j].id { + nodes[i], nodes[j] = nodes[j], nodes[i] + } + } + } + + // Output sorted nodes + for _, node := range nodes { + status := "disconnected" + if node.connected { + status = "connected" + } + + if node.activeConnections > 0 { + sb.WriteString(fmt.Sprintf("Node %d:\t%s (%d connections)\n", node.id, status, node.activeConnections)) + } else { + sb.WriteString(fmt.Sprintf("Node %d:\t%s\n", node.id, status)) + } + } + + sb.WriteString(fmt.Sprintf("\nSummary: %d connected, %d total\n", connectedCount, totalNodes)) + + return sb.String() +} + +// DebugBatcherInfo represents batcher connection information in a structured format. +type DebugBatcherInfo struct { + ConnectedNodes map[string]DebugBatcherNodeInfo `json:"connected_nodes"` // NodeID -> node connection info + TotalNodes int `json:"total_nodes"` +} + +// DebugBatcherNodeInfo represents connection information for a single node. +type DebugBatcherNodeInfo struct { + Connected bool `json:"connected"` + ActiveConnections int `json:"active_connections"` +} + +// debugBatcherJSON returns structured debug information about the batcher's connected nodes. 
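+// A response has roughly the following shape (illustrative values):
+//
+//	{"connected_nodes": {"1": {"connected": true, "active_connections": 1}}, "total_nodes": 1}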
+func (h *Headscale) debugBatcherJSON() DebugBatcherInfo { + info := DebugBatcherInfo{ + ConnectedNodes: make(map[string]DebugBatcherNodeInfo), + TotalNodes: 0, + } + + // Try to get detailed debug info if we have a LockFreeBatcher + if batcher, ok := h.mapBatcher.(*mapper.LockFreeBatcher); ok { + debugInfo := batcher.Debug() + for nodeID, debugData := range debugInfo { + info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{ + Connected: debugData.Connected, + ActiveConnections: debugData.ActiveConnections, + } + info.TotalNodes++ + } + } else { + // Fallback to basic connection info + connectedMap := h.mapBatcher.ConnectedMap() + connectedMap.Range(func(nodeID types.NodeID, connected bool) bool { + info.ConnectedNodes[fmt.Sprintf("%d", nodeID)] = DebugBatcherNodeInfo{ + Connected: connected, + ActiveConnections: 0, + } + info.TotalNodes++ + return true + }) + } + + return info +} diff --git a/hscontrol/derp/derp.go b/hscontrol/derp/derp.go index 83c200a2..42d74abe 100644 --- a/hscontrol/derp/derp.go +++ b/hscontrol/derp/derp.go @@ -1,15 +1,23 @@ package derp import ( + "cmp" "context" "encoding/json" + "hash/crc64" "io" + "maps" + "math/rand" "net/http" "net/url" "os" + "reflect" + "slices" + "sync" + "time" "github.com/juanfont/headscale/hscontrol/types" - "github.com/rs/zerolog/log" + "github.com/spf13/viper" "gopkg.in/yaml.v3" "tailscale.com/tailcfg" ) @@ -31,7 +39,7 @@ func loadDERPMapFromPath(path string) (*tailcfg.DERPMap, error) { } func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { - ctx, cancel := context.WithTimeout(context.Background(), types.HTTPReadTimeout) + ctx, cancel := context.WithTimeout(context.Background(), types.HTTPTimeout) defer cancel() req, err := http.NewRequestWithContext(ctx, http.MethodGet, addr.String(), nil) @@ -40,7 +48,7 @@ func loadDERPMapFromURL(addr url.URL) (*tailcfg.DERPMap, error) { } client := http.Client{ - Timeout: types.HTTPReadTimeout, + Timeout: types.HTTPTimeout, } resp, err := client.Do(req) @@ -72,63 +80,101 @@ func mergeDERPMaps(derpMaps []*tailcfg.DERPMap) *tailcfg.DERPMap { } for _, derpMap := range derpMaps { - for id, region := range derpMap.Regions { - result.Regions[id] = region + maps.Copy(result.Regions, derpMap.Regions) + } + + for id, region := range result.Regions { + if region == nil { + delete(result.Regions, id) } } return &result } -func GetDERPMap(cfg types.DERPConfig) *tailcfg.DERPMap { - derpMaps := make([]*tailcfg.DERPMap, 0) +func GetDERPMap(cfg types.DERPConfig) (*tailcfg.DERPMap, error) { + var derpMaps []*tailcfg.DERPMap + if cfg.DERPMap != nil { + derpMaps = append(derpMaps, cfg.DERPMap) + } - for _, path := range cfg.Paths { - log.Debug(). - Str("func", "GetDERPMap"). - Str("path", path). - Msg("Loading DERPMap from path") - derpMap, err := loadDERPMapFromPath(path) + for _, addr := range cfg.URLs { + derpMap, err := loadDERPMapFromURL(addr) if err != nil { - log.Error(). - Str("func", "GetDERPMap"). - Str("path", path). - Err(err). - Msg("Could not load DERP map from path") - - break + return nil, err } derpMaps = append(derpMaps, derpMap) } - for _, addr := range cfg.URLs { - derpMap, err := loadDERPMapFromURL(addr) - log.Debug(). - Str("func", "GetDERPMap"). - Str("url", addr.String()). - Msg("Loading DERPMap from path") + for _, path := range cfg.Paths { + derpMap, err := loadDERPMapFromPath(path) if err != nil { - log.Error(). - Str("func", "GetDERPMap"). - Str("url", addr.String()). - Err(err). 
- Msg("Could not load DERP map from path") - - break + return nil, err } derpMaps = append(derpMaps, derpMap) } derpMap := mergeDERPMaps(derpMaps) + shuffleDERPMap(derpMap) - log.Trace().Interface("derpMap", derpMap).Msg("DERPMap loaded") + return derpMap, nil +} - if len(derpMap.Regions) == 0 { - log.Warn(). - Msg("DERP map is empty, not a single DERP map datasource was loaded correctly or contained a region") +func shuffleDERPMap(dm *tailcfg.DERPMap) { + if dm == nil || len(dm.Regions) == 0 { + return } - return derpMap + // Collect region IDs and sort them to ensure deterministic iteration order. + // Map iteration order is non-deterministic in Go, which would cause the + // shuffle to be non-deterministic even with a fixed seed. + ids := make([]int, 0, len(dm.Regions)) + for id := range dm.Regions { + ids = append(ids, id) + } + slices.Sort(ids) + + for _, id := range ids { + region := dm.Regions[id] + if len(region.Nodes) == 0 { + continue + } + + dm.Regions[id] = shuffleRegionNoClone(region) + } +} + +var crc64Table = crc64.MakeTable(crc64.ISO) + +var ( + derpRandomOnce sync.Once + derpRandomInst *rand.Rand + derpRandomMu sync.Mutex +) + +func derpRandom() *rand.Rand { + derpRandomMu.Lock() + defer derpRandomMu.Unlock() + + derpRandomOnce.Do(func() { + seed := cmp.Or(viper.GetString("dns.base_domain"), time.Now().String()) + rnd := rand.New(rand.NewSource(0)) + rnd.Seed(int64(crc64.Checksum([]byte(seed), crc64Table))) + derpRandomInst = rnd + }) + return derpRandomInst +} + +func resetDerpRandomForTesting() { + derpRandomMu.Lock() + defer derpRandomMu.Unlock() + derpRandomOnce = sync.Once{} + derpRandomInst = nil +} + +func shuffleRegionNoClone(r *tailcfg.DERPRegion) *tailcfg.DERPRegion { + derpRandom().Shuffle(len(r.Nodes), reflect.Swapper(r.Nodes)) + return r } diff --git a/hscontrol/derp/derp_test.go b/hscontrol/derp/derp_test.go new file mode 100644 index 00000000..91d605a6 --- /dev/null +++ b/hscontrol/derp/derp_test.go @@ -0,0 +1,350 @@ +package derp + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/spf13/viper" + "tailscale.com/tailcfg" +) + +func TestShuffleDERPMapDeterministic(t *testing.T) { + tests := []struct { + name string + baseDomain string + derpMap *tailcfg.DERPMap + expected *tailcfg.DERPMap + }{ + { + name: "single region with 4 nodes", + baseDomain: "test1.example.com", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "nyc", + RegionName: "New York City", + Nodes: []*tailcfg.DERPNode{ + {Name: "1f", RegionID: 1, HostName: "derp1f.tailscale.com"}, + {Name: "1g", RegionID: 1, HostName: "derp1g.tailscale.com"}, + {Name: "1h", RegionID: 1, HostName: "derp1h.tailscale.com"}, + {Name: "1i", RegionID: 1, HostName: "derp1i.tailscale.com"}, + }, + }, + }, + }, + expected: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "nyc", + RegionName: "New York City", + Nodes: []*tailcfg.DERPNode{ + {Name: "1g", RegionID: 1, HostName: "derp1g.tailscale.com"}, + {Name: "1f", RegionID: 1, HostName: "derp1f.tailscale.com"}, + {Name: "1i", RegionID: 1, HostName: "derp1i.tailscale.com"}, + {Name: "1h", RegionID: 1, HostName: "derp1h.tailscale.com"}, + }, + }, + }, + }, + }, + { + name: "multiple regions with nodes", + baseDomain: "test2.example.com", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 10: { + RegionID: 10, + RegionCode: "sea", + RegionName: "Seattle", + Nodes: []*tailcfg.DERPNode{ + {Name: "10b", RegionID: 10, HostName: 
"derp10b.tailscale.com"}, + {Name: "10c", RegionID: 10, HostName: "derp10c.tailscale.com"}, + {Name: "10d", RegionID: 10, HostName: "derp10d.tailscale.com"}, + }, + }, + 2: { + RegionID: 2, + RegionCode: "sfo", + RegionName: "San Francisco", + Nodes: []*tailcfg.DERPNode{ + {Name: "2d", RegionID: 2, HostName: "derp2d.tailscale.com"}, + {Name: "2e", RegionID: 2, HostName: "derp2e.tailscale.com"}, + {Name: "2f", RegionID: 2, HostName: "derp2f.tailscale.com"}, + }, + }, + }, + }, + expected: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 10: { + RegionID: 10, + RegionCode: "sea", + RegionName: "Seattle", + Nodes: []*tailcfg.DERPNode{ + {Name: "10d", RegionID: 10, HostName: "derp10d.tailscale.com"}, + {Name: "10c", RegionID: 10, HostName: "derp10c.tailscale.com"}, + {Name: "10b", RegionID: 10, HostName: "derp10b.tailscale.com"}, + }, + }, + 2: { + RegionID: 2, + RegionCode: "sfo", + RegionName: "San Francisco", + Nodes: []*tailcfg.DERPNode{ + {Name: "2d", RegionID: 2, HostName: "derp2d.tailscale.com"}, + {Name: "2e", RegionID: 2, HostName: "derp2e.tailscale.com"}, + {Name: "2f", RegionID: 2, HostName: "derp2f.tailscale.com"}, + }, + }, + }, + }, + }, + { + name: "large region with many nodes", + baseDomain: "test3.example.com", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + }, + }, + }, + }, + expected: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + }, + }, + }, + }, + }, + { + name: "same region different base domain", + baseDomain: "different.example.com", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + }, + }, + }, + }, + expected: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + }, + }, + }, + }, + }, + { + name: "same dataset with another base domain", + baseDomain: "another.example.com", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: 
"derp4h.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + }, + }, + }, + }, + expected: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + }, + }, + }, + }, + }, + { + name: "same dataset with yet another base domain", + baseDomain: "yetanother.example.com", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + }, + }, + }, + }, + expected: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 4: { + RegionID: 4, + RegionCode: "fra", + RegionName: "Frankfurt", + Nodes: []*tailcfg.DERPNode{ + {Name: "4i", RegionID: 4, HostName: "derp4i.tailscale.com"}, + {Name: "4h", RegionID: 4, HostName: "derp4h.tailscale.com"}, + {Name: "4f", RegionID: 4, HostName: "derp4f.tailscale.com"}, + {Name: "4g", RegionID: 4, HostName: "derp4g.tailscale.com"}, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + viper.Set("dns.base_domain", tt.baseDomain) + defer viper.Reset() + resetDerpRandomForTesting() + + testMap := tt.derpMap.View().AsStruct() + shuffleDERPMap(testMap) + + if diff := cmp.Diff(tt.expected, testMap); diff != "" { + t.Errorf("Shuffled DERP map doesn't match expected (-expected +actual):\n%s", diff) + } + }) + } +} + +func TestShuffleDERPMapEdgeCases(t *testing.T) { + tests := []struct { + name string + derpMap *tailcfg.DERPMap + }{ + { + name: "nil derp map", + derpMap: nil, + }, + { + name: "empty derp map", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{}, + }, + }, + { + name: "region with no nodes", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "empty", + RegionName: "Empty Region", + Nodes: []*tailcfg.DERPNode{}, + }, + }, + }, + }, + { + name: "region with single node", + derpMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "single", + RegionName: "Single Node Region", + Nodes: []*tailcfg.DERPNode{ + {Name: "1a", RegionID: 1, HostName: "derp1a.tailscale.com"}, + }, + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + shuffleDERPMap(tt.derpMap) + }) + } +} + +func TestShuffleDERPMapWithoutBaseDomain(t *testing.T) { + viper.Reset() + resetDerpRandomForTesting() + + derpMap := &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 1: { + RegionID: 1, + RegionCode: "test", + RegionName: "Test Region", + Nodes: []*tailcfg.DERPNode{ + {Name: "1a", RegionID: 1, HostName: "derp1a.test.com"}, + {Name: "1b", RegionID: 1, HostName: "derp1b.test.com"}, + {Name: "1c", RegionID: 1, HostName: "derp1c.test.com"}, + {Name: "1d", RegionID: 1, HostName: "derp1d.test.com"}, + }, + }, + }, + } + + original := derpMap.View().AsStruct() + shuffleDERPMap(derpMap) + + if len(derpMap.Regions) != 1 || len(derpMap.Regions[1].Nodes) != 4 { + 
t.Error("Shuffle corrupted DERP map structure") + } + + originalNodes := make(map[string]bool) + for _, node := range original.Regions[1].Nodes { + originalNodes[node.Name] = true + } + + shuffledNodes := make(map[string]bool) + for _, node := range derpMap.Regions[1].Nodes { + shuffledNodes[node.Name] = true + } + + if diff := cmp.Diff(originalNodes, shuffledNodes); diff != "" { + t.Errorf("Shuffle changed node set (-original +shuffled):\n%s", diff) + } +} diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index c92595d0..474306e5 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -1,9 +1,12 @@ package server import ( + "bufio" + "bytes" "context" "encoding/json" "fmt" + "io" "net" "net/http" "net/netip" @@ -12,32 +15,38 @@ import ( "strings" "time" + "github.com/coder/websocket" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "tailscale.com/derp" + "tailscale.com/derp/derpserver" + "tailscale.com/envknob" "tailscale.com/net/stun" + "tailscale.com/net/wsconn" "tailscale.com/tailcfg" "tailscale.com/types/key" - "tailscale.com/types/logger" ) // fastStartHeader is the header (with value "1") that signals to the HTTP // server that the DERP HTTP client does not want the HTTP 101 response // headers and it will begin writing & reading the DERP protocol immediately // following its HTTP request. -const fastStartHeader = "Derp-Fast-Start" +const ( + fastStartHeader = "Derp-Fast-Start" + DerpVerifyScheme = "headscale-derp-verify" +) + +// debugUseDERPIP is a debug-only flag that causes the DERP server to resolve +// hostnames to IP addresses when generating the DERP region configuration. +// This is useful for integration testing where DNS resolution may be unreliable. +var debugUseDERPIP = envknob.Bool("HEADSCALE_DEBUG_DERP_USE_IP") type DERPServer struct { serverURL string key key.NodePrivate cfg *types.DERPConfig - tailscaleDERP *derp.Server -} - -func derpLogf() logger.Logf { - return func(format string, args ...any) { - log.Debug().Caller().Msgf(format, args...) 
- } + tailscaleDERP *derpserver.Server } func NewDERPServer( @@ -46,7 +55,12 @@ func NewDERPServer( cfg *types.DERPConfig, ) (*DERPServer, error) { log.Trace().Caller().Msg("Creating new embedded DERP server") - server := derp.NewServer(derpKey, derpLogf()) // nolint // zerolinter complains + server := derpserver.New(derpKey, util.TSLogfWrapper()) // nolint // zerolinter complains + + if cfg.ServerVerifyClients { + server.SetVerifyClientURL(DerpVerifyScheme + "://verify") + server.SetVerifyClientURLFailOpen(false) + } return &DERPServer{ serverURL: serverURL, @@ -63,7 +77,10 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { } var host string var port int - host, portStr, err := net.SplitHostPort(serverURL.Host) + var portStr string + + // Extract hostname and port from URL + host, portStr, err = net.SplitHostPort(serverURL.Host) if err != nil { if serverURL.Scheme == "https" { host = serverURL.Host @@ -79,6 +96,19 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { } } + // If debug flag is set, resolve hostname to IP address + if debugUseDERPIP { + ips, err := net.LookupIP(host) + if err != nil { + log.Error().Caller().Err(err).Msgf("Failed to resolve DERP hostname %s to IP, using hostname", host) + } else if len(ips) > 0 { + // Use the first IP address + ipStr := ips[0].String() + log.Info().Caller().Msgf("HEADSCALE_DEBUG_DERP_USE_IP: Resolved %s to %s", host, ipStr) + host = ipStr + } + } + localDERPregion := tailcfg.DERPRegion{ RegionID: d.cfg.ServerRegionID, RegionCode: d.cfg.ServerRegionCode, @@ -86,10 +116,12 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { Avoid: false, Nodes: []*tailcfg.DERPNode{ { - Name: fmt.Sprintf("%d", d.cfg.ServerRegionID), + Name: strconv.Itoa(d.cfg.ServerRegionID), RegionID: d.cfg.ServerRegionID, HostName: host, DERPPort: port, + IPv4: d.cfg.IPv4, + IPv6: d.cfg.IPv6, }, }, } @@ -105,6 +137,7 @@ func (d *DERPServer) GenerateRegion() (tailcfg.DERPRegion, error) { localDERPregion.Nodes[0].STUNPort = portSTUN log.Info().Caller().Msgf("DERP region: %+v", localDERPregion) + log.Info().Caller().Msgf("DERP Nodes[0]: %+v", localDERPregion.Nodes[0]) return localDERPregion, nil } @@ -129,12 +162,62 @@ func (d *DERPServer) DERPHandler( log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } return } + if strings.Contains(req.Header.Get("Sec-Websocket-Protocol"), "derp") { + d.serveWebsocket(writer, req) + } else { + d.servePlain(writer, req) + } +} + +func (d *DERPServer) serveWebsocket(writer http.ResponseWriter, req *http.Request) { + websocketConn, err := websocket.Accept(writer, req, &websocket.AcceptOptions{ + Subprotocols: []string{"derp"}, + OriginPatterns: []string{"*"}, + // Disable compression because DERP transmits WireGuard messages that + // are not compressible. + // Additionally, Safari has a broken implementation of compression + // (see https://github.com/nhooyr/websocket/issues/218) that makes + // enabling it actively harmful. + CompressionMode: websocket.CompressionDisabled, + }) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to upgrade websocket request") + + writer.Header().Set("Content-Type", "text/plain") + writer.WriteHeader(http.StatusInternalServerError) + + _, err = writer.Write([]byte("Failed to upgrade websocket request")) + if err != nil { + log.Error(). + Caller(). + Err(err). 
+ Msg("Failed to write HTTP response") + } + + return + } + defer websocketConn.Close(websocket.StatusInternalError, "closing") + if websocketConn.Subprotocol() != "derp" { + websocketConn.Close(websocket.StatusPolicyViolation, "client must speak the derp subprotocol") + + return + } + + wc := wsconn.NetConn(req.Context(), websocketConn, websocket.MessageBinary, req.RemoteAddr) + brw := bufio.NewReadWriter(bufio.NewReader(wc), bufio.NewWriter(wc)) + d.tailscaleDERP.Accept(req.Context(), wc, brw, req.RemoteAddr) +} + +func (d *DERPServer) servePlain(writer http.ResponseWriter, req *http.Request) { fastStart := req.Header.Get(fastStartHeader) == "1" hijacker, ok := writer.(http.Hijacker) @@ -147,7 +230,7 @@ func (d *DERPServer) DERPHandler( log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } return @@ -163,7 +246,7 @@ func (d *DERPServer) DERPHandler( log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } return @@ -202,20 +285,21 @@ func DERPProbeHandler( log.Error(). Caller(). Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } } } -// DERPBootstrapDNSHandler implements the /bootsrap-dns endpoint +// DERPBootstrapDNSHandler implements the /bootstrap-dns endpoint // Described in https://github.com/tailscale/tailscale/issues/1405, // this endpoint provides a way to help a client when it fails to start up // because its DNS are broken. // The initial implementation is here https://github.com/tailscale/tailscale/pull/1406 // They have a cache, but not clear if that is really necessary at Headscale, uh, scale. // An example implementation is found here https://derp.tailscale.com/bootstrap-dns +// Coordination server is included automatically, since local DERP is using the same DNS Name in d.serverURL. func DERPBootstrapDNSHandler( - derpMap *tailcfg.DERPMap, + derpMap tailcfg.DERPMapView, ) func(http.ResponseWriter, *http.Request) { return func( writer http.ResponseWriter, @@ -226,18 +310,18 @@ func DERPBootstrapDNSHandler( resolvCtx, cancel := context.WithTimeout(req.Context(), time.Minute) defer cancel() var resolver net.Resolver - for _, region := range derpMap.Regions { - for _, node := range region.Nodes { // we don't care if we override some nodes - addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName) + for _, region := range derpMap.Regions().All() { + for _, node := range region.Nodes().All() { // we don't care if we override some nodes + addrs, err := resolver.LookupIP(resolvCtx, "ip", node.HostName()) if err != nil { log.Trace(). Caller(). Err(err). - Msgf("bootstrap DNS lookup failed %q", node.HostName) + Msgf("bootstrap DNS lookup failed %q", node.HostName()) continue } - dnsEntries[node.HostName] = addrs + dnsEntries[node.HostName()] = addrs } } writer.Header().Set("Content-Type", "application/json") @@ -247,7 +331,7 @@ func DERPBootstrapDNSHandler( log.Error(). Caller(). Err(err). 
- Msg("Failed to write response") + Msg("Failed to write HTTP response") } } } @@ -281,7 +365,13 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) { return } log.Error().Caller().Err(err).Msgf("STUN ReadFrom") - time.Sleep(time.Second) + + // Rate limit error logging - wait before retrying, but respect context cancellation + select { + case <-ctx.Done(): + return + case <-time.After(time.Second): + } continue } @@ -309,3 +399,29 @@ func serverSTUNListener(ctx context.Context, packetConn *net.UDPConn) { } } } + +func NewDERPVerifyTransport(handleVerifyRequest func(*http.Request, io.Writer) error) *DERPVerifyTransport { + return &DERPVerifyTransport{ + handleVerifyRequest: handleVerifyRequest, + } +} + +type DERPVerifyTransport struct { + handleVerifyRequest func(*http.Request, io.Writer) error +} + +func (t *DERPVerifyTransport) RoundTrip(req *http.Request) (*http.Response, error) { + buf := new(bytes.Buffer) + if err := t.handleVerifyRequest(req, buf); err != nil { + log.Error().Caller().Err(err).Msg("Failed to handle client verify request: ") + + return nil, err + } + + resp := &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(buf), + } + + return resp, nil +} diff --git a/hscontrol/dns/extrarecords.go b/hscontrol/dns/extrarecords.go new file mode 100644 index 00000000..82b3078b --- /dev/null +++ b/hscontrol/dns/extrarecords.go @@ -0,0 +1,194 @@ +package dns + +import ( + "context" + "crypto/sha256" + "encoding/json" + "fmt" + "os" + "sync" + + "github.com/cenkalti/backoff/v5" + "github.com/fsnotify/fsnotify" + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" + "tailscale.com/util/set" +) + +type ExtraRecordsMan struct { + mu sync.RWMutex + records set.Set[tailcfg.DNSRecord] + watcher *fsnotify.Watcher + path string + + updateCh chan []tailcfg.DNSRecord + closeCh chan struct{} + hashes map[string][32]byte +} + +// NewExtraRecordsManager creates a new ExtraRecordsMan and starts watching the file at the given path. 
+func NewExtraRecordsManager(path string) (*ExtraRecordsMan, error) {
+	watcher, err := fsnotify.NewWatcher()
+	if err != nil {
+		return nil, fmt.Errorf("creating watcher: %w", err)
+	}
+
+	fi, err := os.Stat(path)
+	if err != nil {
+		return nil, fmt.Errorf("getting file info: %w", err)
+	}
+
+	if fi.IsDir() {
+		return nil, fmt.Errorf("path is a directory, only file is supported: %s", path)
+	}
+
+	records, hash, err := readExtraRecordsFromPath(path)
+	if err != nil {
+		return nil, fmt.Errorf("reading extra records from path: %w", err)
+	}
+
+	er := &ExtraRecordsMan{
+		watcher: watcher,
+		path:    path,
+		records: set.SetOf(records),
+		hashes: map[string][32]byte{
+			path: hash,
+		},
+		closeCh:  make(chan struct{}),
+		updateCh: make(chan []tailcfg.DNSRecord),
+	}
+
+	err = watcher.Add(path)
+	if err != nil {
+		return nil, fmt.Errorf("adding path to watcher: %w", err)
+	}
+
+	log.Trace().Caller().Strs("watching", watcher.WatchList()).Msg("started filewatcher")
+
+	return er, nil
+}
+
+func (e *ExtraRecordsMan) Records() []tailcfg.DNSRecord {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	return e.records.Slice()
+}
+
+func (e *ExtraRecordsMan) Run() {
+	for {
+		select {
+		case <-e.closeCh:
+			return
+		case event, ok := <-e.watcher.Events:
+			if !ok {
+				log.Error().Caller().Msgf("file watcher event channel closing")
+				return
+			}
+			switch event.Op {
+			case fsnotify.Create, fsnotify.Write, fsnotify.Chmod:
+				log.Trace().Caller().Str("path", event.Name).Str("op", event.Op.String()).Msg("extra records received filewatch event")
+				if event.Name != e.path {
+					continue
+				}
+				e.updateRecords()
+
+			// If a file is removed or renamed, fsnotify will lose track of it
+			// and not watch it. We will therefore attempt to re-add it with a backoff.
+			case fsnotify.Remove, fsnotify.Rename:
+				_, err := backoff.Retry(context.Background(), func() (struct{}, error) {
+					if _, err := os.Stat(e.path); err != nil {
+						return struct{}{}, err
+					}
+
+					return struct{}{}, nil
+				}, backoff.WithBackOff(backoff.NewExponentialBackOff()))
+				if err != nil {
+					log.Error().Caller().Err(err).Msgf("extra records filewatcher retrying to find file after delete")
+					continue
+				}
+
+				err = e.watcher.Add(e.path)
+				if err != nil {
+					log.Error().Caller().Err(err).Msgf("extra records filewatcher re-adding file after delete failed, giving up.")
+					return
+				} else {
+					log.Trace().Caller().Str("path", e.path).Msg("extra records file re-added after delete")
+					e.updateRecords()
+				}
+			}
+
+		case err, ok := <-e.watcher.Errors:
+			if !ok {
+				log.Error().Caller().Msgf("file watcher error channel closing")
+				return
+			}
+			log.Error().Caller().Err(err).Msgf("extra records filewatcher returned error: %q", err)
+		}
+	}
+}
+
+func (e *ExtraRecordsMan) Close() {
+	e.watcher.Close()
+	close(e.closeCh)
+}
+
+func (e *ExtraRecordsMan) UpdateCh() <-chan []tailcfg.DNSRecord {
+	return e.updateCh
+}
+
+func (e *ExtraRecordsMan) updateRecords() {
+	records, newHash, err := readExtraRecordsFromPath(e.path)
+	if err != nil {
+		log.Error().Caller().Err(err).Msgf("reading extra records from path: %s", e.path)
+		return
+	}
+
+	// If there are no records, ignore the update.
+	if records == nil {
+		return
+	}
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// If there has not been any change, ignore the update.
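+	// The stored hash is the SHA-256 of the raw file contents, as computed by
+	// readExtraRecordsFromPath, so an unchanged file is skipped cheaply.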
+ if oldHash, ok := e.hashes[e.path]; ok { + if newHash == oldHash { + return + } + } + + oldCount := e.records.Len() + + e.records = set.SetOf(records) + e.hashes[e.path] = newHash + + log.Trace().Caller().Interface("records", e.records).Msgf("extra records updated from path, count old: %d, new: %d", oldCount, e.records.Len()) + e.updateCh <- e.records.Slice() +} + +// readExtraRecordsFromPath reads a JSON file of tailcfg.DNSRecord +// and returns the records and the hash of the file. +func readExtraRecordsFromPath(path string) ([]tailcfg.DNSRecord, [32]byte, error) { + b, err := os.ReadFile(path) + if err != nil { + return nil, [32]byte{}, fmt.Errorf("reading path: %s, err: %w", path, err) + } + + // If the read was triggered too fast, and the file is not complete, ignore the update + // if the file is empty. A consecutive update will be triggered when the file is complete. + if len(b) == 0 { + return nil, [32]byte{}, nil + } + + var records []tailcfg.DNSRecord + err = json.Unmarshal(b, &records) + if err != nil { + return nil, [32]byte{}, fmt.Errorf("unmarshalling records, content: %q: %w", string(b), err) + } + + hash := sha256.Sum256(b) + + return records, hash, nil +} diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index ffd3a576..a35a73af 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -1,20 +1,34 @@ +//go:generate buf generate --template ../buf.gen.yaml -o .. ../proto + // nolint package hscontrol import ( "context" + "errors" "fmt" + "io" + "net/netip" + "os" + "slices" + "sort" "strings" "time" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/views" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/state" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" ) type headscaleV1APIServer struct { // v1.HeadscaleServiceServer @@ -28,26 +42,24 @@ func newHeadscaleV1APIServer(h *Headscale) v1.HeadscaleServiceServer { } } -func (api headscaleV1APIServer) GetUser( - ctx context.Context, - request *v1.GetUserRequest, -) (*v1.GetUserResponse, error) { - user, err := api.h.db.GetUser(request.GetName()) - if err != nil { - return nil, err - } - - return &v1.GetUserResponse{User: user.Proto()}, nil -} - func (api headscaleV1APIServer) CreateUser( ctx context.Context, request *v1.CreateUserRequest, ) (*v1.CreateUserResponse, error) { - user, err := api.h.db.CreateUser(request.GetName()) - if err != nil { - return nil, err + newUser := types.User{ + Name: request.GetName(), + DisplayName: request.GetDisplayName(), + Email: request.GetEmail(), + ProfilePicURL: request.GetPictureUrl(), } + user, policyChanged, err := api.h.state.CreateUser(newUser) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to create user: %s", err) + } + + // CreateUser returns a policy change response if the user creation affected policy. + // This triggers a full policy re-evaluation for all connected nodes. 
+ api.h.Change(policyChanged) return &v1.CreateUserResponse{User: user.Proto()}, nil } @@ -56,28 +68,44 @@ func (api headscaleV1APIServer) RenameUser( ctx context.Context, request *v1.RenameUserRequest, ) (*v1.RenameUserResponse, error) { - err := api.h.db.RenameUser(request.GetOldName(), request.GetNewName()) + oldUser, err := api.h.state.GetUserByID(types.UserID(request.GetOldId())) if err != nil { return nil, err } - user, err := api.h.db.GetUser(request.GetNewName()) + _, c, err := api.h.state.RenameUser(types.UserID(oldUser.ID), request.GetNewName()) if err != nil { return nil, err } - return &v1.RenameUserResponse{User: user.Proto()}, nil + // Send policy update notifications if needed + api.h.Change(c) + + newUser, err := api.h.state.GetUserByName(request.GetNewName()) + if err != nil { + return nil, err + } + + return &v1.RenameUserResponse{User: newUser.Proto()}, nil } func (api headscaleV1APIServer) DeleteUser( ctx context.Context, request *v1.DeleteUserRequest, ) (*v1.DeleteUserResponse, error) { - err := api.h.db.DestroyUser(request.GetName()) + user, err := api.h.state.GetUserByID(types.UserID(request.GetId())) if err != nil { return nil, err } + policyChanged, err := api.h.state.DeleteUser(types.UserID(user.ID)) + if err != nil { + return nil, err + } + + // Use the change returned from DeleteUser which includes proper policy updates + api.h.Change(policyChanged) + return &v1.DeleteUserResponse{}, nil } @@ -85,7 +113,19 @@ func (api headscaleV1APIServer) ListUsers( ctx context.Context, request *v1.ListUsersRequest, ) (*v1.ListUsersResponse, error) { - users, err := api.h.db.ListUsers() + var err error + var users []types.User + + switch { + case request.GetName() != "": + users, err = api.h.state.ListUsersWithFilter(&types.User{Name: request.GetName()}) + case request.GetEmail() != "": + users, err = api.h.state.ListUsersWithFilter(&types.User{Email: request.GetEmail()}) + case request.GetId() != 0: + users, err = api.h.state.ListUsersWithFilter(&types.User{Model: gorm.Model{ID: uint(request.GetId())}}) + default: + users, err = api.h.state.ListAllUsers() + } if err != nil { return nil, err } @@ -95,7 +135,9 @@ func (api headscaleV1APIServer) ListUsers( response[index] = user.Proto() } - log.Trace().Caller().Interface("users", response).Msg("") + sort.Slice(response, func(i, j int) bool { + return response[i].Id < response[j].Id + }) return &v1.ListUsersResponse{Users: response}, nil } @@ -118,8 +160,17 @@ func (api headscaleV1APIServer) CreatePreAuthKey( } } - preAuthKey, err := api.h.db.CreatePreAuthKey( - request.GetUser(), + var userID *types.UserID + if request.GetUser() != 0 { + user, err := api.h.state.GetUserByID(types.UserID(request.GetUser())) + if err != nil { + return nil, err + } + userID = user.TypedID() + } + + preAuthKey, err := api.h.state.CreatePreAuthKey( + userID, request.GetReusable(), request.GetEphemeral(), &expiration, @@ -136,12 +187,7 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( ctx context.Context, request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { - preAuthKey, err := api.h.db.GetPreAuthKey(request.GetUser(), request.Key) - if err != nil { - return nil, err - } - - err = api.h.db.ExpirePreAuthKey(preAuthKey) + err := api.h.state.ExpirePreAuthKey(request.GetId()) if err != nil { return nil, err } @@ -149,11 +195,23 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( return &v1.ExpirePreAuthKeyResponse{}, nil } +func (api headscaleV1APIServer) DeletePreAuthKey( + ctx context.Context, + request 
*v1.DeletePreAuthKeyRequest, +) (*v1.DeletePreAuthKeyResponse, error) { + err := api.h.state.DeletePreAuthKey(request.GetId()) + if err != nil { + return nil, err + } + + return &v1.DeletePreAuthKeyResponse{}, nil +} + func (api headscaleV1APIServer) ListPreAuthKeys( ctx context.Context, request *v1.ListPreAuthKeysRequest, ) (*v1.ListPreAuthKeysResponse, error) { - preAuthKeys, err := api.h.db.ListPreAuthKeys(request.GetUser()) + preAuthKeys, err := api.h.state.ListPreAuthKeys() if err != nil { return nil, err } @@ -163,6 +221,10 @@ func (api headscaleV1APIServer) ListPreAuthKeys( response[index] = key.Proto() } + sort.Slice(response, func(i, j int) bool { + return response[i].Id < response[j].Id + }) + return &v1.ListPreAuthKeysResponse{PreAuthKeys: response}, nil } @@ -170,28 +232,69 @@ func (api headscaleV1APIServer) RegisterNode( ctx context.Context, request *v1.RegisterNodeRequest, ) (*v1.RegisterNodeResponse, error) { + // Generate ephemeral registration key for tracking this registration flow in logs + registrationKey, err := util.GenerateRegistrationKey() + if err != nil { + log.Warn().Err(err).Msg("Failed to generate registration key") + registrationKey = "" // Continue without key if generation fails + } + log.Trace(). + Caller(). Str("user", request.GetUser()). - Str("machine_key", request.GetKey()). + Str("registration_id", request.GetKey()). + Str("registration_key", registrationKey). Msg("Registering node") - var mkey key.MachinePublic - err := mkey.UnmarshalText([]byte(request.GetKey())) + registrationId, err := types.RegistrationIDFromString(request.GetKey()) if err != nil { return nil, err } - node, err := api.h.db.RegisterNodeFromAuthCallback( - api.h.registrationCache, - mkey, - request.GetUser(), + user, err := api.h.state.GetUserByName(request.GetUser()) + if err != nil { + return nil, fmt.Errorf("looking up user: %w", err) + } + + node, nodeChange, err := api.h.state.HandleNodeFromAuthPath( + registrationId, + types.UserID(user.ID), nil, util.RegisterMethodCLI, ) if err != nil { + log.Error(). + Str("registration_key", registrationKey). + Err(err). + Msg("Failed to register node") return nil, err } + log.Info(). + Str("registration_key", registrationKey). + Str("node_id", fmt.Sprintf("%d", node.ID())). + Str("hostname", node.Hostname()). + Msg("Node registered successfully") + + // This is a bit of a back and forth, but we have a bit of a chicken and egg + // dependency here. + // Because the way the policy manager works, we need to have the node + // in the database, then add it to the policy manager and then we can + // approve the route. This means we get this dance where the node is + // first added to the database, then we add it to the policy manager via + // SaveNode (which automatically updates the policy manager) and then we can auto approve the routes. + // As that only approves the struct object, we need to save it again and + // ensure we send an update. + // This works, but might be another good candidate for doing some sort of + // eventbus. + routeChange, err := api.h.state.AutoApproveRoutes(node) + if err != nil { + return nil, fmt.Errorf("auto approving routes: %w", err) + } + + // Send both changes. Empty changes are ignored by Change(). 
+ api.h.Change(nodeChange, routeChange) + return &v1.RegisterNodeResponse{Node: node.Proto()}, nil } @@ -199,17 +302,13 @@ func (api headscaleV1APIServer) GetNode( ctx context.Context, request *v1.GetNodeRequest, ) (*v1.GetNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) - if err != nil { - return nil, err + node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !ok { + return nil, status.Errorf(codes.NotFound, "node not found") } resp := node.Proto() - // Populate the online field based on - // currently connected nodes. - resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) - return &v1.GetNodeResponse{Node: resp}, nil } @@ -217,44 +316,114 @@ func (api headscaleV1APIServer) SetTags( ctx context.Context, request *v1.SetTagsRequest, ) (*v1.SetTagsResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) - if err != nil { - return nil, err + // Validate tags not empty - tagged nodes must have at least one tag + if len(request.GetTags()) == 0 { + return &v1.SetTagsResponse{ + Node: nil, + }, status.Error( + codes.InvalidArgument, + "cannot remove all tags from a node - tagged nodes must have at least one tag", + ) } + // Validate tag format for _, tag := range request.GetTags() { err := validateTag(tag) if err != nil { - return &v1.SetTagsResponse{ - Node: nil, - }, status.Error(codes.InvalidArgument, err.Error()) + return nil, err } } - err = api.h.db.SetTags(node, request.GetTags()) + // User XOR Tags: nodes are either tagged or user-owned, never both. + // Setting tags on a user-owned node converts it to a tagged node. + // Once tagged, a node cannot be converted back to user-owned. + _, found := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !found { + return &v1.SetTagsResponse{ + Node: nil, + }, status.Error(codes.NotFound, "node not found") + } + + node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags()) if err != nil { return &v1.SetTagsResponse{ Node: nil, - }, status.Error(codes.Internal, err.Error()) + }, status.Error(codes.InvalidArgument, err.Error()) } + api.h.Change(nodeChange) + log.Trace(). - Str("node", node.Hostname). + Caller(). + Str("node", node.Hostname()). Strs("tags", request.GetTags()). Msg("Changing tags of node") return &v1.SetTagsResponse{Node: node.Proto()}, nil } +func (api headscaleV1APIServer) SetApprovedRoutes( + ctx context.Context, + request *v1.SetApprovedRoutesRequest, +) (*v1.SetApprovedRoutesResponse, error) { + log.Debug(). + Caller(). + Uint64("node.id", request.GetNodeId()). + Strs("requestedRoutes", request.GetRoutes()). + Msg("gRPC SetApprovedRoutes called") + + var newApproved []netip.Prefix + for _, route := range request.GetRoutes() { + prefix, err := netip.ParsePrefix(route) + if err != nil { + return nil, fmt.Errorf("parsing route: %w", err) + } + + // If the prefix is an exit route, add both. The client expect both + // to annotate the node as an exit node. 
+ if prefix == tsaddr.AllIPv4() || prefix == tsaddr.AllIPv6() { + newApproved = append(newApproved, tsaddr.AllIPv4(), tsaddr.AllIPv6()) + } else { + newApproved = append(newApproved, prefix) + } + } + tsaddr.SortPrefixes(newApproved) + newApproved = slices.Compact(newApproved) + + node, nodeChange, err := api.h.state.SetApprovedRoutes(types.NodeID(request.GetNodeId()), newApproved) + if err != nil { + return nil, status.Error(codes.InvalidArgument, err.Error()) + } + + // Always propagate node changes from SetApprovedRoutes + api.h.Change(nodeChange) + + proto := node.Proto() + // Populate SubnetRoutes with PrimaryRoutes to ensure it includes only the + // routes that are actively served from the node (per architectural requirement in types/node.go) + primaryRoutes := api.h.state.GetNodePrimaryRoutes(node.ID()) + proto.SubnetRoutes = util.PrefixesToString(primaryRoutes) + + log.Debug(). + Caller(). + Uint64("node.id", node.ID().Uint64()). + Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())). + Strs("primaryRoutes", util.PrefixesToString(primaryRoutes)). + Strs("finalSubnetRoutes", proto.SubnetRoutes). + Msg("gRPC SetApprovedRoutes completed") + + return &v1.SetApprovedRoutesResponse{Node: proto}, nil +} + func validateTag(tag string) error { if strings.Index(tag, "tag:") != 0 { - return fmt.Errorf("tag must start with the string 'tag:'") + return errors.New("tag must start with the string 'tag:'") } if strings.ToLower(tag) != tag { - return fmt.Errorf("tag should be lowercase") + return errors.New("tag should be lowercase") } if len(strings.Fields(tag)) > 1 { - return fmt.Errorf("tag should not contains space") + return errors.New("tag should not contains space") } return nil } @@ -263,17 +432,17 @@ func (api headscaleV1APIServer) DeleteNode( ctx context.Context, request *v1.DeleteNodeRequest, ) (*v1.DeleteNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + node, ok := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId())) + if !ok { + return nil, status.Errorf(codes.NotFound, "node not found") + } + + nodeChange, err := api.h.state.DeleteNode(node) if err != nil { return nil, err } - err = api.h.db.DeleteNode( - node, - ) - if err != nil { - return nil, err - } + api.h.Change(nodeChange) return &v1.DeleteNodeResponse{}, nil } @@ -282,21 +451,23 @@ func (api headscaleV1APIServer) ExpireNode( ctx context.Context, request *v1.ExpireNodeRequest, ) (*v1.ExpireNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + expiry := time.Now() + if request.GetExpiry() != nil { + expiry = request.GetExpiry().AsTime() + } + + node, nodeChange, err := api.h.state.SetNodeExpiry(types.NodeID(request.GetNodeId()), expiry) if err != nil { return nil, err } - now := time.Now() - - api.h.db.NodeSetExpiry( - node, - now, - ) + // TODO(kradalby): Ensure that both the selfupdate and peer updates are sent + api.h.Change(nodeChange) log.Trace(). - Str("node", node.Hostname). - Time("expiry", *node.Expiry). + Caller(). + Str("node", node.Hostname()). + Time("expiry", *node.AsStruct().Expiry). 
Msg("node expired") return &v1.ExpireNodeResponse{Node: node.Proto()}, nil @@ -306,21 +477,17 @@ func (api headscaleV1APIServer) RenameNode( ctx context.Context, request *v1.RenameNodeRequest, ) (*v1.RenameNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) + node, nodeChange, err := api.h.state.RenameNode(types.NodeID(request.GetNodeId()), request.GetNewName()) if err != nil { return nil, err } - err = api.h.db.RenameNode( - node, - request.GetNewName(), - ) - if err != nil { - return nil, err - } + // TODO(kradalby): investigate if we need selfupdate + api.h.Change(nodeChange) log.Trace(). - Str("node", node.Hostname). + Caller(). + Str("node", node.Hostname()). Str("new_name", request.GetNewName()). Msg("node renamed") @@ -331,134 +498,66 @@ func (api headscaleV1APIServer) ListNodes( ctx context.Context, request *v1.ListNodesRequest, ) (*v1.ListNodesResponse, error) { + // TODO(kradalby): it looks like this can be simplified a lot, + // the filtering of nodes by user, vs nodes as a whole can + // probably be done once. + // TODO(kradalby): This should be done in one tx. if request.GetUser() != "" { - nodes, err := api.h.db.ListNodesByUser(request.GetUser()) + user, err := api.h.state.GetUserByName(request.GetUser()) if err != nil { return nil, err } - response := make([]*v1.Node, len(nodes)) - for index, node := range nodes { - resp := node.Proto() - - // Populate the online field based on - // currently connected nodes. - resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) - - response[index] = resp - } + nodes := api.h.state.ListNodesByUser(types.UserID(user.ID)) + response := nodesToProto(api.h.state, nodes) return &v1.ListNodesResponse{Nodes: response}, nil } - nodes, err := api.h.db.ListNodes() - if err != nil { - return nil, err - } - - response := make([]*v1.Node, len(nodes)) - for index, node := range nodes { - resp := node.Proto() - - // Populate the online field based on - // currently connected nodes. 
- resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) - - validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( - &node, - ) - resp.InvalidTags = invalidTags - resp.ValidTags = validTags - response[index] = resp - } + nodes := api.h.state.ListNodes() + response := nodesToProto(api.h.state, nodes) return &v1.ListNodesResponse{Nodes: response}, nil } -func (api headscaleV1APIServer) MoveNode( - ctx context.Context, - request *v1.MoveNodeRequest, -) (*v1.MoveNodeResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) - if err != nil { - return nil, err +func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.Node { + response := make([]*v1.Node, nodes.Len()) + for index, node := range nodes.All() { + resp := node.Proto() + + // Tags-as-identity: tagged nodes show as TaggedDevices user in API responses + // (UserID may be set internally for "created by" tracking) + if node.IsTagged() { + resp.User = types.TaggedDevices.Proto() + } + + resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...)) + response[index] = resp } - err = api.h.db.AssignNodeToUser(node, request.GetUser()) - if err != nil { - return nil, err - } + sort.Slice(response, func(i, j int) bool { + return response[i].Id < response[j].Id + }) - return &v1.MoveNodeResponse{Node: node.Proto()}, nil + return response } -func (api headscaleV1APIServer) GetRoutes( +func (api headscaleV1APIServer) BackfillNodeIPs( ctx context.Context, - request *v1.GetRoutesRequest, -) (*v1.GetRoutesResponse, error) { - routes, err := api.h.db.GetRoutes() + request *v1.BackfillNodeIPsRequest, +) (*v1.BackfillNodeIPsResponse, error) { + log.Trace().Caller().Msg("Backfill called") + + if !request.Confirmed { + return nil, errors.New("not confirmed, aborting") + } + + changes, err := api.h.state.BackfillNodeIPs() if err != nil { return nil, err } - return &v1.GetRoutesResponse{ - Routes: types.Routes(routes).Proto(), - }, nil -} - -func (api headscaleV1APIServer) EnableRoute( - ctx context.Context, - request *v1.EnableRouteRequest, -) (*v1.EnableRouteResponse, error) { - err := api.h.db.EnableRoute(request.GetRouteId()) - if err != nil { - return nil, err - } - - return &v1.EnableRouteResponse{}, nil -} - -func (api headscaleV1APIServer) DisableRoute( - ctx context.Context, - request *v1.DisableRouteRequest, -) (*v1.DisableRouteResponse, error) { - err := api.h.db.DisableRoute(request.GetRouteId()) - if err != nil { - return nil, err - } - - return &v1.DisableRouteResponse{}, nil -} - -func (api headscaleV1APIServer) GetNodeRoutes( - ctx context.Context, - request *v1.GetNodeRoutesRequest, -) (*v1.GetNodeRoutesResponse, error) { - node, err := api.h.db.GetNodeByID(request.GetNodeId()) - if err != nil { - return nil, err - } - - routes, err := api.h.db.GetNodeRoutes(node) - if err != nil { - return nil, err - } - - return &v1.GetNodeRoutesResponse{ - Routes: types.Routes(routes).Proto(), - }, nil -} - -func (api headscaleV1APIServer) DeleteRoute( - ctx context.Context, - request *v1.DeleteRouteRequest, -) (*v1.DeleteRouteResponse, error) { - err := api.h.db.DeleteRoute(request.GetRouteId()) - if err != nil { - return nil, err - } - - return &v1.DeleteRouteResponse{}, nil + return &v1.BackfillNodeIPsResponse{Changes: changes}, nil } func (api headscaleV1APIServer) CreateApiKey( @@ -470,9 +569,7 @@ func (api headscaleV1APIServer) CreateApiKey( expiration = request.GetExpiration().AsTime() } - apiKey, _, err := api.h.db.CreateAPIKey( - &expiration, - ) + 
apiKey, _, err := api.h.state.CreateAPIKey(&expiration) if err != nil { return nil, err } @@ -480,19 +577,40 @@ func (api headscaleV1APIServer) CreateApiKey( return &v1.CreateApiKeyResponse{ApiKey: apiKey}, nil } +// apiKeyIdentifier is implemented by requests that identify an API key. +type apiKeyIdentifier interface { + GetId() uint64 + GetPrefix() string +} + +// getAPIKey retrieves an API key by ID or prefix from the request. +// Returns InvalidArgument if neither or both are provided. +func (api headscaleV1APIServer) getAPIKey(req apiKeyIdentifier) (*types.APIKey, error) { + hasID := req.GetId() != 0 + hasPrefix := req.GetPrefix() != "" + + switch { + case hasID && hasPrefix: + return nil, status.Error(codes.InvalidArgument, "provide either id or prefix, not both") + case hasID: + return api.h.state.GetAPIKeyByID(req.GetId()) + case hasPrefix: + return api.h.state.GetAPIKey(req.GetPrefix()) + default: + return nil, status.Error(codes.InvalidArgument, "must provide id or prefix") + } +} + func (api headscaleV1APIServer) ExpireApiKey( ctx context.Context, request *v1.ExpireApiKeyRequest, ) (*v1.ExpireApiKeyResponse, error) { - var apiKey *types.APIKey - var err error - - apiKey, err = api.h.db.GetAPIKey(request.Prefix) + apiKey, err := api.getAPIKey(request) if err != nil { return nil, err } - err = api.h.db.ExpireAPIKey(apiKey) + err = api.h.state.ExpireAPIKey(apiKey) if err != nil { return nil, err } @@ -504,7 +622,7 @@ func (api headscaleV1APIServer) ListApiKeys( ctx context.Context, request *v1.ListApiKeysRequest, ) (*v1.ListApiKeysResponse, error) { - apiKeys, err := api.h.db.ListAPIKeys() + apiKeys, err := api.h.state.ListAPIKeys() if err != nil { return nil, err } @@ -514,15 +632,133 @@ func (api headscaleV1APIServer) ListApiKeys( response[index] = key.Proto() } + sort.Slice(response, func(i, j int) bool { + return response[i].Id < response[j].Id + }) + return &v1.ListApiKeysResponse{ApiKeys: response}, nil } +func (api headscaleV1APIServer) DeleteApiKey( + ctx context.Context, + request *v1.DeleteApiKeyRequest, +) (*v1.DeleteApiKeyResponse, error) { + apiKey, err := api.getAPIKey(request) + if err != nil { + return nil, err + } + + if err := api.h.state.DestroyAPIKey(*apiKey); err != nil { + return nil, err + } + + return &v1.DeleteApiKeyResponse{}, nil +} + +func (api headscaleV1APIServer) GetPolicy( + _ context.Context, + _ *v1.GetPolicyRequest, +) (*v1.GetPolicyResponse, error) { + switch api.h.cfg.Policy.Mode { + case types.PolicyModeDB: + p, err := api.h.state.GetPolicy() + if err != nil { + return nil, fmt.Errorf("loading ACL from database: %w", err) + } + + return &v1.GetPolicyResponse{ + Policy: p.Data, + UpdatedAt: timestamppb.New(p.UpdatedAt), + }, nil + case types.PolicyModeFile: + // Read the file and return the contents as-is. 
+ absPath := util.AbsolutePathFromConfigPath(api.h.cfg.Policy.Path) + f, err := os.Open(absPath) + if err != nil { + return nil, fmt.Errorf("reading policy from path %q: %w", absPath, err) + } + + defer f.Close() + + b, err := io.ReadAll(f) + if err != nil { + return nil, fmt.Errorf("reading policy from file: %w", err) + } + + return &v1.GetPolicyResponse{Policy: string(b)}, nil + } + + return nil, fmt.Errorf("no supported policy mode found in configuration, policy.mode: %q", api.h.cfg.Policy.Mode) +} + +func (api headscaleV1APIServer) SetPolicy( + _ context.Context, + request *v1.SetPolicyRequest, +) (*v1.SetPolicyResponse, error) { + if api.h.cfg.Policy.Mode != types.PolicyModeDB { + return nil, types.ErrPolicyUpdateIsDisabled + } + + p := request.GetPolicy() + + // Validate and reject configuration that would error when applied + // when creating a map response. This requires nodes, so there is still + // a scenario where they might be allowed if the server has no nodes + // yet, but it should help for the general case and for hot reloading + // configurations. + nodes := api.h.state.ListNodes() + + _, err := api.h.state.SetPolicy([]byte(p)) + if err != nil { + return nil, fmt.Errorf("setting policy: %w", err) + } + + if nodes.Len() > 0 { + _, err = api.h.state.SSHPolicy(nodes.At(0)) + if err != nil { + return nil, fmt.Errorf("verifying SSH rules: %w", err) + } + } + + updated, err := api.h.state.SetPolicyInDB(p) + if err != nil { + return nil, err + } + + // Always reload policy to ensure route re-evaluation, even if policy content hasn't changed. + // This ensures that routes are re-evaluated for auto-approval in cases where routes + // were manually disabled but could now be auto-approved with the current policy. + cs, err := api.h.state.ReloadPolicy() + if err != nil { + return nil, fmt.Errorf("reloading policy: %w", err) + } + + if len(cs) > 0 { + api.h.Change(cs...) + } else { + log.Debug(). + Caller(). + Msg("No policy changes to distribute because ReloadPolicy returned empty changeset") + } + + response := &v1.SetPolicyResponse{ + Policy: updated.Data, + UpdatedAt: timestamppb.New(updated.UpdatedAt), + } + + log.Debug(). + Caller(). + Msg("gRPC SetPolicy completed successfully because response prepared") + + return response, nil +} + // The following service calls are for testing and debugging func (api headscaleV1APIServer) DebugCreateNode( ctx context.Context, request *v1.DebugCreateNodeRequest, ) (*v1.DebugCreateNodeResponse, error) { - user, err := api.h.db.GetUser(request.GetUser()) + user, err := api.h.state.GetUserByName(request.GetUser()) if err != nil { return nil, err } @@ -536,51 +772,61 @@ func (api headscaleV1APIServer) DebugCreateNode( Caller(). Interface("route-prefix", routes). Interface("route-str", request.GetRoutes()). 
- Msg("") + Msg("Creating routes for node") hostinfo := tailcfg.Hostinfo{ RoutableIPs: routes, OS: "TestOS", - Hostname: "DebugTestNode", + Hostname: request.GetName(), } - var mkey key.MachinePublic - err = mkey.UnmarshalText([]byte(request.GetKey())) + registrationId, err := types.RegistrationIDFromString(request.GetKey()) if err != nil { return nil, err } - givenName, err := api.h.db.GenerateGivenName(mkey, request.GetName()) - if err != nil { - return nil, err - } + newNode := types.NewRegisterNode( + types.Node{ + NodeKey: key.NewNode().Public(), + MachineKey: key.NewMachine().Public(), + Hostname: request.GetName(), + User: user, - nodeKey := key.NewNode() + Expiry: &time.Time{}, + LastSeen: &time.Time{}, - newNode := types.Node{ - MachineKey: mkey, - NodeKey: nodeKey.Public(), - Hostname: request.GetName(), - GivenName: givenName, - User: *user, - - Expiry: &time.Time{}, - LastSeen: &time.Time{}, - - Hostinfo: &hostinfo, - } - - log.Debug(). - Str("machine_key", mkey.ShortString()). - Msg("adding debug machine via CLI, appending to registration cache") - - api.h.registrationCache.Set( - mkey.String(), - newNode, - registerCacheExpiration, + Hostinfo: &hostinfo, + }, ) - return &v1.DebugCreateNodeResponse{Node: newNode.Proto()}, nil + log.Debug(). + Caller(). + Str("registration_id", registrationId.String()). + Msg("adding debug machine via CLI, appending to registration cache") + + api.h.state.SetRegistrationCacheEntry(registrationId, newNode) + + return &v1.DebugCreateNodeResponse{Node: newNode.Node.Proto()}, nil +} + +func (api headscaleV1APIServer) Health( + ctx context.Context, + request *v1.HealthRequest, +) (*v1.HealthResponse, error) { + var healthErr error + response := &v1.HealthResponse{} + + if err := api.h.state.PingDB(ctx); err != nil { + healthErr = fmt.Errorf("database ping failed: %w", err) + } else { + response.DatabaseConnectivity = true + } + + if healthErr != nil { + log.Error().Err(healthErr).Msg("Health check failed") + } + + return response, healthErr } func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/grpcv1_test.go b/hscontrol/grpcv1_test.go index 1d87bfe0..4cf5b7d4 100644 --- a/hscontrol/grpcv1_test.go +++ b/hscontrol/grpcv1_test.go @@ -1,6 +1,17 @@ package hscontrol -import "testing" +import ( + "context" + "testing" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) func Test_validateTag(t *testing.T) { type args struct { @@ -40,3 +51,418 @@ func Test_validateTag(t *testing.T) { }) } } + +// TestSetTags_Conversion tests the conversion of user-owned nodes to tagged nodes. +// The tags-as-identity model allows one-way conversion from user-owned to tagged. +// Tag authorization is checked via the policy manager - unauthorized tags are rejected. 
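+// The node under test is registered with an untagged pre-auth key, so it starts out user-owned;
+// unknown node IDs are expected to return NotFound.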
+func TestSetTags_Conversion(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create test user and nodes + user := app.state.CreateUserForTest("test-user") + + // Create a pre-auth key WITHOUT tags for user-owned node + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) + require.NoError(t, err) + + machineKey1 := key.NewMachine() + nodeKey1 := key.NewNode() + + // Register a user-owned node (via untagged PreAuthKey) + userOwnedReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey1.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "user-owned-node", + }, + } + _, err = app.handleRegisterWithAuthKey(userOwnedReq, machineKey1.Public()) + require.NoError(t, err) + + // Get the created node + userOwnedNode, found := app.state.GetNodeByNodeKey(nodeKey1.Public()) + require.True(t, found) + + // Create API server instance + apiServer := newHeadscaleV1APIServer(app) + + tests := []struct { + name string + nodeID uint64 + tags []string + wantErr bool + wantCode codes.Code + wantErrMessage string + }{ + { + // Conversion is allowed, but tag authorization fails without tagOwners + name: "reject unauthorized tags on user-owned node", + nodeID: uint64(userOwnedNode.ID()), + tags: []string{"tag:server"}, + wantErr: true, + wantCode: codes.InvalidArgument, + wantErrMessage: "requested tags", + }, + { + // Conversion is allowed, but tag authorization fails without tagOwners + name: "reject multiple unauthorized tags", + nodeID: uint64(userOwnedNode.ID()), + tags: []string{"tag:server", "tag:database"}, + wantErr: true, + wantCode: codes.InvalidArgument, + wantErrMessage: "requested tags", + }, + { + name: "reject non-existent node", + nodeID: 99999, + tags: []string{"tag:server"}, + wantErr: true, + wantCode: codes.NotFound, + wantErrMessage: "node not found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{ + NodeId: tt.nodeID, + Tags: tt.tags, + }) + + if tt.wantErr { + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, tt.wantCode, st.Code()) + assert.Contains(t, st.Message(), tt.wantErrMessage) + assert.Nil(t, resp.GetNode()) + } else { + require.NoError(t, err) + assert.NotNil(t, resp) + assert.NotNil(t, resp.GetNode()) + } + }) + } +} + +// TestSetTags_TaggedNode tests that SetTags correctly identifies tagged nodes +// and doesn't reject them with the "user-owned nodes" error. +// Note: This test doesn't validate ACL tag authorization - that's tested elsewhere. 
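+// The node is registered via a pre-auth key carrying tag:initial, so IsTagged() reports true
+// while UserID stays set for created-by tracking.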
+func TestSetTags_TaggedNode(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create test user and tagged pre-auth key + user := app.state.CreateUserForTest("test-user") + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:initial"}) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Register a tagged node (via tagged PreAuthKey) + taggedReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + } + _, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public()) + require.NoError(t, err) + + // Get the created node + taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, taggedNode.IsTagged(), "Node should be tagged") + assert.True(t, taggedNode.UserID().Valid(), "Tagged node should have UserID for tracking") + + // Create API server instance + apiServer := newHeadscaleV1APIServer(app) + + // Test: SetTags should NOT reject tagged nodes with "user-owned" error + // (Even though they have UserID set, IsTagged() identifies them correctly) + resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{ + NodeId: uint64(taggedNode.ID()), + Tags: []string{"tag:initial"}, // Keep existing tag to avoid ACL validation issues + }) + + // The call should NOT fail with "cannot set tags on user-owned nodes" + if err != nil { + st, ok := status.FromError(err) + require.True(t, ok) + // If error is about unauthorized tags, that's fine - ACL validation is working + // If error is about user-owned nodes, that's the bug we're testing for + assert.NotContains(t, st.Message(), "user-owned nodes", "Should not reject tagged nodes as user-owned") + } else { + // Success is also fine + assert.NotNil(t, resp) + } +} + +// TestSetTags_CannotRemoveAllTags tests that SetTags rejects attempts to remove +// all tags from a tagged node, enforcing Tailscale's requirement that tagged +// nodes must have at least one tag. 
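+// Passing an empty tag list is expected to fail with codes.InvalidArgument and a
+// "cannot remove all tags" message.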
+func TestSetTags_CannotRemoveAllTags(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create test user and tagged pre-auth key + user := app.state.CreateUserForTest("test-user") + pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:server"}) + require.NoError(t, err) + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + // Register a tagged node + taggedReq := tailcfg.RegisterRequest{ + Auth: &tailcfg.RegisterResponseAuth{ + AuthKey: pak.Key, + }, + NodeKey: nodeKey.Public(), + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "tagged-node", + }, + } + _, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public()) + require.NoError(t, err) + + // Get the created node + taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public()) + require.True(t, found) + assert.True(t, taggedNode.IsTagged()) + + // Create API server instance + apiServer := newHeadscaleV1APIServer(app) + + // Attempt to remove all tags (empty array) + resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{ + NodeId: uint64(taggedNode.ID()), + Tags: []string{}, // Empty - attempting to remove all tags + }) + + // Should fail with InvalidArgument error + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "cannot remove all tags") + assert.Nil(t, resp.GetNode()) +} + +// TestDeleteUser_ReturnsProperChangeSignal tests issue #2967 fix: +// When a user is deleted, the state should return a non-empty change signal +// to ensure policy manager is updated and clients are notified immediately. +func TestDeleteUser_ReturnsProperChangeSignal(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + + // Create a user + user := app.state.CreateUserForTest("test-user-to-delete") + require.NotNil(t, user) + + // Delete the user and verify a non-empty change is returned + // Issue #2967: Without the fix, DeleteUser returned an empty change, + // causing stale policy state until another user operation triggered an update. + changeSignal, err := app.state.DeleteUser(*user.TypedID()) + require.NoError(t, err, "DeleteUser should succeed") + assert.False(t, changeSignal.IsEmpty(), "DeleteUser should return a non-empty change signal (issue #2967)") +} + +// TestExpireApiKey_ByID tests that API keys can be expired by ID. +func TestExpireApiKey_ByID(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the ID + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyID := listResp.GetApiKeys()[0].GetId() + + // Expire by ID + _, err = apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{ + Id: keyID, + }) + require.NoError(t, err) + + // Verify key is expired (expiration is set to now or in the past) + listResp, err = apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + assert.NotNil(t, listResp.GetApiKeys()[0].GetExpiration(), "expiration should be set") +} + +// TestExpireApiKey_ByPrefix tests that API keys can still be expired by prefix. 
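+// Expiring by prefix goes through the same getAPIKey dispatch helper as expiring by ID.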
+func TestExpireApiKey_ByPrefix(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the prefix + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyPrefix := listResp.GetApiKeys()[0].GetPrefix() + + // Expire by prefix + _, err = apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{ + Prefix: keyPrefix, + }) + require.NoError(t, err) +} + +// TestDeleteApiKey_ByID tests that API keys can be deleted by ID. +func TestDeleteApiKey_ByID(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the ID + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyID := listResp.GetApiKeys()[0].GetId() + + // Delete by ID + _, err = apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{ + Id: keyID, + }) + require.NoError(t, err) + + // Verify key is deleted + listResp, err = apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + assert.Empty(t, listResp.GetApiKeys()) +} + +// TestDeleteApiKey_ByPrefix tests that API keys can still be deleted by prefix. +func TestDeleteApiKey_ByPrefix(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + // Create an API key + createResp, err := apiServer.CreateApiKey(context.Background(), &v1.CreateApiKeyRequest{}) + require.NoError(t, err) + require.NotEmpty(t, createResp.GetApiKey()) + + // List keys to get the prefix + listResp, err := apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + require.Len(t, listResp.GetApiKeys(), 1) + + keyPrefix := listResp.GetApiKeys()[0].GetPrefix() + + // Delete by prefix + _, err = apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{ + Prefix: keyPrefix, + }) + require.NoError(t, err) + + // Verify key is deleted + listResp, err = apiServer.ListApiKeys(context.Background(), &v1.ListApiKeysRequest{}) + require.NoError(t, err) + assert.Empty(t, listResp.GetApiKeys()) +} + +// TestExpireApiKey_NoIdentifier tests that an error is returned when neither ID nor prefix is provided. +func TestExpireApiKey_NoIdentifier(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{}) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "must provide id or prefix") +} + +// TestDeleteApiKey_NoIdentifier tests that an error is returned when neither ID nor prefix is provided. 
+func TestDeleteApiKey_NoIdentifier(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{}) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "must provide id or prefix") +} + +// TestExpireApiKey_BothIdentifiers tests that an error is returned when both ID and prefix are provided. +func TestExpireApiKey_BothIdentifiers(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.ExpireApiKey(context.Background(), &v1.ExpireApiKeyRequest{ + Id: 1, + Prefix: "test", + }) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "provide either id or prefix, not both") +} + +// TestDeleteApiKey_BothIdentifiers tests that an error is returned when both ID and prefix are provided. +func TestDeleteApiKey_BothIdentifiers(t *testing.T) { + t.Parallel() + + app := createTestApp(t) + apiServer := newHeadscaleV1APIServer(app) + + _, err := apiServer.DeleteApiKey(context.Background(), &v1.DeleteApiKeyRequest{ + Id: 1, + Prefix: "test", + }) + require.Error(t, err) + st, ok := status.FromError(err) + require.True(t, ok, "error should be a gRPC status error") + assert.Equal(t, codes.InvalidArgument, st.Code()) + assert.Contains(t, st.Message(), "provide either id or prefix, not both") +} diff --git a/hscontrol/handlers.go b/hscontrol/handlers.go index ee670733..dc693dae 100644 --- a/hscontrol/handlers.go +++ b/hscontrol/handlers.go @@ -5,15 +5,18 @@ import ( "encoding/json" "errors" "fmt" - "html/template" + "io" "net/http" "strconv" + "strings" "time" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/assets" + "github.com/juanfont/headscale/hscontrol/templates" + "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "tailscale.com/tailcfg" - "tailscale.com/types/key" ) const ( @@ -28,31 +31,110 @@ const ( // See also https://github.com/tailscale/tailscale/blob/main/tailcfg/tailcfg.go NoiseCapabilityVersion = 39 - // TODO(juan): remove this once https://github.com/juanfont/headscale/issues/727 is fixed. - registrationHoldoff = time.Second * 5 reservedResponseHeaderSize = 4 ) +// httpError logs an error and sends an HTTP error response with the given. +func httpError(w http.ResponseWriter, err error) { + var herr HTTPError + if errors.As(err, &herr) { + http.Error(w, herr.Msg, herr.Code) + log.Error().Err(herr.Err).Int("code", herr.Code).Msgf("user msg: %s", herr.Msg) + } else { + http.Error(w, "internal server error", http.StatusInternalServerError) + log.Error().Err(err).Int("code", http.StatusInternalServerError).Msg("http internal server error") + } +} + +// HTTPError represents an error that is surfaced to the user via web. +type HTTPError struct { + Code int // HTTP response code to send to client; 0 means 500 + Msg string // Response body to send to client + Err error // Detailed error to log on the server +} + +func (e HTTPError) Error() string { return fmt.Sprintf("http error[%d]: %s, %s", e.Code, e.Msg, e.Err) } +func (e HTTPError) Unwrap() error { return e.Err } + +// Error returns an HTTPError containing the given information. 
+func NewHTTPError(code int, msg string, err error) HTTPError { + return HTTPError{Code: code, Msg: msg, Err: err} +} + +var errMethodNotAllowed = NewHTTPError(http.StatusMethodNotAllowed, "method not allowed", nil) + var ErrRegisterMethodCLIDoesNotSupportExpire = errors.New( "machines registered with CLI does not support expire", ) -var ErrNoCapabilityVersion = errors.New("no capability version set") -func parseCabailityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) { +func parseCapabilityVersion(req *http.Request) (tailcfg.CapabilityVersion, error) { clientCapabilityStr := req.URL.Query().Get("v") if clientCapabilityStr == "" { - return 0, ErrNoCapabilityVersion + return 0, NewHTTPError(http.StatusBadRequest, "capability version must be set", nil) } clientCapabilityVersion, err := strconv.Atoi(clientCapabilityStr) if err != nil { - return 0, fmt.Errorf("failed to parse capability version: %w", err) + return 0, NewHTTPError(http.StatusBadRequest, "invalid capability version", fmt.Errorf("failed to parse capability version: %w", err)) } return tailcfg.CapabilityVersion(clientCapabilityVersion), nil } +func (h *Headscale) handleVerifyRequest( + req *http.Request, + writer io.Writer, +) error { + body, err := io.ReadAll(req.Body) + if err != nil { + return fmt.Errorf("cannot read request body: %w", err) + } + + var derpAdmitClientRequest tailcfg.DERPAdmitClientRequest + if err := json.Unmarshal(body, &derpAdmitClientRequest); err != nil { + return NewHTTPError(http.StatusBadRequest, "Bad Request: invalid JSON", fmt.Errorf("cannot parse derpAdmitClientRequest: %w", err)) + } + + nodes := h.state.ListNodes() + + // Check if any node has the requested NodeKey + var nodeKeyFound bool + + for _, node := range nodes.All() { + if node.NodeKey() == derpAdmitClientRequest.NodePublic { + nodeKeyFound = true + break + } + } + + resp := &tailcfg.DERPAdmitClientResponse{ + Allow: nodeKeyFound, + } + + return json.NewEncoder(writer).Encode(resp) +} + +// VerifyHandler see https://github.com/tailscale/tailscale/blob/964282d34f06ecc06ce644769c66b0b31d118340/derp/derp_server.go#L1159 +// DERP use verifyClientsURL to verify whether a client is allowed to connect to the DERP server. +func (h *Headscale) VerifyHandler( + writer http.ResponseWriter, + req *http.Request, +) { + if req.Method != http.MethodPost { + httpError(writer, errMethodNotAllowed) + return + } + + err := h.handleVerifyRequest(req, writer) + if err != nil { + httpError(writer, err) + return + } + + writer.Header().Set("Content-Type", "application/json") +} + // KeyHandler provides the Headscale pub key // Listens in /key. func (h *Headscale) KeyHandler( @@ -60,39 +142,9 @@ func (h *Headscale) KeyHandler( req *http.Request, ) { // New Tailscale clients send a 'v' parameter to indicate the CurrentCapabilityVersion - capVer, err := parseCabailityVersion(req) + capVer, err := parseCapabilityVersion(req) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("could not get capability version") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - return - } - - log.Debug(). - Str("handler", "/key"). - Int("cap_ver", int(capVer)). - Msg("New noise client") - if err != nil { - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Wrong params")) - if err != nil { - log.Error(). 
- Caller(). - Err(err). - Msg("Failed to write response") - } - + httpError(writer, err) return } @@ -101,15 +153,9 @@ func (h *Headscale) KeyHandler( resp := tailcfg.OverTLSPublicKeyResponse{ PublicKey: h.noisePrivateKey.Public(), } + writer.Header().Set("Content-Type", "application/json") - writer.WriteHeader(http.StatusOK) - err = json.NewEncoder(writer).Encode(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } + json.NewEncoder(writer).Encode(resp) return } @@ -130,21 +176,14 @@ func (h *Headscale) HealthHandler( if err != nil { writer.WriteHeader(http.StatusInternalServerError) - log.Error().Caller().Err(err).Msg("health check failed") + res.Status = "fail" } - buf, err := json.Marshal(res) - if err != nil { - log.Error().Caller().Err(err).Msg("marshal failed") - } - _, err = writer.Write(buf) - if err != nil { - log.Error().Caller().Err(err).Msg("write failed") - } + json.NewEncoder(writer).Encode(res) } - - if err := h.db.PingDB(req.Context()); err != nil { + err := h.state.PingDB(req.Context()) + if err != nil { respond(err) return @@ -153,90 +192,99 @@ func (h *Headscale) HealthHandler( respond(nil) } -type registerWebAPITemplateConfig struct { - Key string +func (h *Headscale) RobotsHandler( + writer http.ResponseWriter, + req *http.Request, +) { + writer.Header().Set("Content-Type", "text/plain") + writer.WriteHeader(http.StatusOK) + + _, err := writer.Write([]byte("User-agent: *\nDisallow: /")) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write HTTP response") + } } -var registerWebAPITemplate = template.Must( - template.New("registerweb").Parse(` - - - Registration - Headscale - - -

- headscale
- Machine registration
- Run the command below in the headscale server to add this machine to your network:
- headscale nodes register --user USERNAME --key {{.Key}}
- - -`)) +// VersionHandler returns version information about the Headscale server +// Listens in /version. +func (h *Headscale) VersionHandler( + writer http.ResponseWriter, + req *http.Request, +) { + writer.Header().Set("Content-Type", "application/json") + writer.WriteHeader(http.StatusOK) + + versionInfo := types.GetVersionInfo() + err := json.NewEncoder(writer).Encode(versionInfo) + if err != nil { + log.Error(). + Caller(). + Err(err). + Msg("Failed to write version response") + } +} + +type AuthProviderWeb struct { + serverURL string +} + +func NewAuthProviderWeb(serverURL string) *AuthProviderWeb { + return &AuthProviderWeb{ + serverURL: serverURL, + } +} + +func (a *AuthProviderWeb) AuthURL(registrationId types.RegistrationID) string { + return fmt.Sprintf( + "%s/register/%s", + strings.TrimSuffix(a.serverURL, "/"), + registrationId.String()) +} // RegisterWebAPI shows a simple message in the browser to point to the CLI -// Listens in /register/:nkey. +// Listens in /register/:registration_id. // // This is not part of the Tailscale control API, as we could send whatever URL // in the RegisterResponse.AuthURL field. -func (h *Headscale) RegisterWebAPI( +func (a *AuthProviderWeb) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { vars := mux.Vars(req) - machineKeyStr := vars["mkey"] + registrationIdStr := vars["registration_id"] // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render // the template and log an error. - var machineKey key.MachinePublic - err := machineKey.UnmarshalText( - []byte(machineKeyStr), - ) + registrationId, err := types.RegistrationIDFromString(registrationIdStr) if err != nil { - log.Warn().Err(err).Msg("Failed to parse incoming nodekey") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Wrong params")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - return - } - - var content bytes.Buffer - if err := registerWebAPITemplate.Execute(&content, registerWebAPITemplateConfig{ - Key: machineKey.String(), - }); err != nil { - log.Error(). - Str("func", "RegisterWebAPI"). - Err(err). - Msg("Could not render register web API template") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err = writer.Write([]byte("Could not render register web API template")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) return } writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err = writer.Write(content.Bytes()) + writer.Write([]byte(templates.RegisterWeb(registrationId).Render())) +} + +func FaviconHandler(writer http.ResponseWriter, req *http.Request) { + writer.Header().Set("Content-Type", "image/png") + http.ServeContent(writer, req, "favicon.ico", time.Unix(0, 0), bytes.NewReader(assets.Favicon)) +} + +// BlankHandler returns a blank page with favicon linked. +func BlankHandler(writer http.ResponseWriter, res *http.Request) { + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + _, err := writer.Write([]byte(templates.BlankPage().Render())) if err != nil { log.Error(). Caller(). 
Err(err). - Msg("Failed to write response") + Msg("Failed to write HTTP response") } } diff --git a/hscontrol/mapper/batcher.go b/hscontrol/mapper/batcher.go new file mode 100644 index 00000000..0a1e30d0 --- /dev/null +++ b/hscontrol/mapper/batcher.go @@ -0,0 +1,178 @@ +package mapper + +import ( + "errors" + "fmt" + "time" + + "github.com/juanfont/headscale/hscontrol/state" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "github.com/puzpuzpuz/xsync/v4" + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" +) + +var mapResponseGenerated = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: "headscale", + Name: "mapresponse_generated_total", + Help: "total count of mapresponses generated by response type", +}, []string{"response_type"}) + +type batcherFunc func(cfg *types.Config, state *state.State) Batcher + +// Batcher defines the common interface for all batcher implementations. +type Batcher interface { + Start() + Close() + AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error + RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool + IsConnected(id types.NodeID) bool + ConnectedMap() *xsync.Map[types.NodeID, bool] + AddWork(r ...change.Change) + MapResponseFromChange(id types.NodeID, r change.Change) (*tailcfg.MapResponse, error) + DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) +} + +func NewBatcher(batchTime time.Duration, workers int, mapper *mapper) *LockFreeBatcher { + return &LockFreeBatcher{ + mapper: mapper, + workers: workers, + tick: time.NewTicker(batchTime), + + // The size of this channel is arbitrary chosen, the sizing should be revisited. + workCh: make(chan work, workers*200), + nodes: xsync.NewMap[types.NodeID, *multiChannelNodeConn](), + connected: xsync.NewMap[types.NodeID, *time.Time](), + pendingChanges: xsync.NewMap[types.NodeID, []change.Change](), + } +} + +// NewBatcherAndMapper creates a Batcher implementation. +func NewBatcherAndMapper(cfg *types.Config, state *state.State) Batcher { + m := newMapper(cfg, state) + b := NewBatcher(cfg.Tuning.BatchChangeDelay, cfg.Tuning.BatcherWorkers, m) + m.batcher = b + + return b +} + +// nodeConnection interface for different connection implementations. +type nodeConnection interface { + nodeID() types.NodeID + version() tailcfg.CapabilityVersion + send(data *tailcfg.MapResponse) error + // computePeerDiff returns peers that were previously sent but are no longer in the current list. + computePeerDiff(currentPeers []tailcfg.NodeID) (removed []tailcfg.NodeID) + // updateSentPeers updates the tracking of which peers have been sent to this node. + updateSentPeers(resp *tailcfg.MapResponse) +} + +// generateMapResponse generates a [tailcfg.MapResponse] for the given NodeID based on the provided [change.Change]. 
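+// It returns a nil response with a nil error when the change is empty or is a self-only change
+// aimed at another node, so callers must tolerate a nil result.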
+func generateMapResponse(nc nodeConnection, mapper *mapper, r change.Change) (*tailcfg.MapResponse, error) { + nodeID := nc.nodeID() + version := nc.version() + + if r.IsEmpty() { + return nil, nil //nolint:nilnil // Empty response means nothing to send + } + + if nodeID == 0 { + return nil, fmt.Errorf("invalid nodeID: %d", nodeID) + } + + if mapper == nil { + return nil, fmt.Errorf("mapper is nil for nodeID %d", nodeID) + } + + // Handle self-only responses + if r.IsSelfOnly() && r.TargetNode != nodeID { + return nil, nil //nolint:nilnil // No response needed for other nodes when self-only + } + + // Check if this is a self-update (the changed node is the receiving node). + // When true, ensure the response includes the node's self info so it sees + // its own attribute changes (e.g., tags changed via admin API). + isSelfUpdate := r.OriginNode != 0 && r.OriginNode == nodeID + + var ( + mapResp *tailcfg.MapResponse + err error + ) + + // Track metric using categorized type, not free-form reason + mapResponseGenerated.WithLabelValues(r.Type()).Inc() + + // Check if this requires runtime peer visibility computation (e.g., policy changes) + if r.RequiresRuntimePeerComputation { + currentPeers := mapper.state.ListPeers(nodeID) + + currentPeerIDs := make([]tailcfg.NodeID, 0, currentPeers.Len()) + for _, peer := range currentPeers.All() { + currentPeerIDs = append(currentPeerIDs, peer.ID().NodeID()) + } + + removedPeers := nc.computePeerDiff(currentPeerIDs) + // Include self node when this is a self-update (e.g., node's own tags changed) + // so the node sees its updated self info along with new packet filters. + mapResp, err = mapper.policyChangeResponse(nodeID, version, removedPeers, currentPeers, isSelfUpdate) + } else if isSelfUpdate { + // Non-policy self-update: just send the self node info + mapResp, err = mapper.selfMapResponse(nodeID, version) + } else { + mapResp, err = mapper.buildFromChange(nodeID, version, &r) + } + + if err != nil { + return nil, fmt.Errorf("generating map response for nodeID %d: %w", nodeID, err) + } + + return mapResp, nil +} + +// handleNodeChange generates and sends a [tailcfg.MapResponse] for a given node and [change.Change]. +func handleNodeChange(nc nodeConnection, mapper *mapper, r change.Change) error { + if nc == nil { + return errors.New("nodeConnection is nil") + } + + nodeID := nc.nodeID() + + log.Debug().Caller().Uint64("node.id", nodeID.Uint64()).Str("reason", r.Reason).Msg("Node change processing started because change notification received") + + data, err := generateMapResponse(nc, mapper, r) + if err != nil { + return fmt.Errorf("generating map response for node %d: %w", nodeID, err) + } + + if data == nil { + // No data to send is valid for some response types + return nil + } + + // Send the map response + err = nc.send(data) + if err != nil { + return fmt.Errorf("sending map response to node %d: %w", nodeID, err) + } + + // Update peer tracking after successful send + nc.updateSentPeers(data) + + return nil +} + +// workResult represents the result of processing a change. +type workResult struct { + mapResponse *tailcfg.MapResponse + err error +} + +// work represents a unit of work to be processed by workers. 
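+// A nil resultCh marks asynchronous work whose response is sent directly to the node; a non-nil
+// resultCh makes the worker hand the generated map response back to the caller for synchronous generation.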
+type work struct { + c change.Change + nodeID types.NodeID + resultCh chan<- workResult // optional channel for synchronous operations +} diff --git a/hscontrol/mapper/batcher_lockfree.go b/hscontrol/mapper/batcher_lockfree.go new file mode 100644 index 00000000..e00512b6 --- /dev/null +++ b/hscontrol/mapper/batcher_lockfree.go @@ -0,0 +1,829 @@ +package mapper + +import ( + "crypto/rand" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/puzpuzpuz/xsync/v4" + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +var errConnectionClosed = errors.New("connection channel already closed") + +// LockFreeBatcher uses atomic operations and concurrent maps to eliminate mutex contention. +type LockFreeBatcher struct { + tick *time.Ticker + mapper *mapper + workers int + + nodes *xsync.Map[types.NodeID, *multiChannelNodeConn] + connected *xsync.Map[types.NodeID, *time.Time] + + // Work queue channel + workCh chan work + workChOnce sync.Once // Ensures workCh is only closed once + done chan struct{} + doneOnce sync.Once // Ensures done is only closed once + + // Batching state + pendingChanges *xsync.Map[types.NodeID, []change.Change] + + // Metrics + totalNodes atomic.Int64 + workQueuedCount atomic.Int64 + workProcessed atomic.Int64 + workErrors atomic.Int64 +} + +// AddNode registers a new node connection with the batcher and sends an initial map response. +// It creates or updates the node's connection data, validates the initial map generation, +// and notifies other nodes that this node has come online. +func (b *LockFreeBatcher) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { + addNodeStart := time.Now() + + // Generate connection ID + connID := generateConnectionID() + + // Create new connection entry + now := time.Now() + newEntry := &connectionEntry{ + id: connID, + c: c, + version: version, + created: now, + } + // Initialize last used timestamp + newEntry.lastUsed.Store(now.Unix()) + + // Get or create multiChannelNodeConn - this reuses existing offline nodes for rapid reconnection + nodeConn, loaded := b.nodes.LoadOrStore(id, newMultiChannelNodeConn(id, b.mapper)) + + if !loaded { + b.totalNodes.Add(1) + } + + // Add connection to the list (lock-free) + nodeConn.addConnection(newEntry) + + // Use the worker pool for controlled concurrency instead of direct generation + initialMap, err := b.MapResponseFromChange(id, change.FullSelf(id)) + if err != nil { + log.Error().Uint64("node.id", id.Uint64()).Err(err).Msg("Initial map generation failed") + nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to generate initial map for node %d: %w", id, err) + } + + // Use a blocking send with timeout for initial map since the channel should be ready + // and we want to avoid the race condition where the receiver isn't ready yet + select { + case c <- initialMap: + // Success + case <-time.After(5 * time.Second): + log.Error().Uint64("node.id", id.Uint64()).Err(fmt.Errorf("timeout")).Msg("Initial map send timeout") + log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("timeout.duration", 5*time.Second). 
+ Msg("Initial map send timed out because channel was blocked or receiver not ready") + nodeConn.removeConnectionByChannel(c) + return fmt.Errorf("failed to send initial map to node %d: timeout", id) + } + + // Update connection status + b.connected.Store(id, nil) // nil = connected + + // Node will automatically receive updates through the normal flow + // The initial full map already contains all current state + + log.Debug().Caller().Uint64("node.id", id.Uint64()).Dur("total.duration", time.Since(addNodeStart)). + Int("active.connections", nodeConn.getActiveConnectionCount()). + Msg("Node connection established in batcher because AddNode completed successfully") + + return nil +} + +// RemoveNode disconnects a node from the batcher, marking it as offline and cleaning up its state. +// It validates the connection channel matches one of the current connections, closes that specific connection, +// and keeps the node entry alive for rapid reconnections instead of aggressive deletion. +// Reports if the node still has active connections after removal. +func (b *LockFreeBatcher) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { + nodeConn, exists := b.nodes.Load(id) + if !exists { + log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode called for non-existent node because node not found in batcher") + return false + } + + // Remove specific connection + removed := nodeConn.removeConnectionByChannel(c) + if !removed { + log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("RemoveNode: channel not found because connection already removed or invalid") + return false + } + + // Check if node has any remaining active connections + if nodeConn.hasActiveConnections() { + log.Debug().Caller().Uint64("node.id", id.Uint64()). + Int("active.connections", nodeConn.getActiveConnectionCount()). + Msg("Node connection removed but keeping online because other connections remain") + return true // Node still has active connections + } + + // No active connections - keep the node entry alive for rapid reconnections + // The node will get a fresh full map when it reconnects + log.Debug().Caller().Uint64("node.id", id.Uint64()).Msg("Node disconnected from batcher because all connections removed, keeping entry for rapid reconnection") + b.connected.Store(id, ptr.To(time.Now())) + + return false +} + +// AddWork queues a change to be processed by the batcher. +func (b *LockFreeBatcher) AddWork(r ...change.Change) { + b.addWork(r...) +} + +func (b *LockFreeBatcher) Start() { + b.done = make(chan struct{}) + go b.doWork() +} + +func (b *LockFreeBatcher) Close() { + // Signal shutdown to all goroutines, only once + b.doneOnce.Do(func() { + if b.done != nil { + close(b.done) + } + }) + + // Only close workCh once using sync.Once to prevent races + b.workChOnce.Do(func() { + close(b.workCh) + }) + + // Close the underlying channels supplying the data to the clients. 
+ b.nodes.Range(func(nodeID types.NodeID, conn *multiChannelNodeConn) bool { + conn.close() + return true + }) +} + +func (b *LockFreeBatcher) doWork() { + for i := range b.workers { + go b.worker(i + 1) + } + + // Create a cleanup ticker for removing truly disconnected nodes + cleanupTicker := time.NewTicker(5 * time.Minute) + defer cleanupTicker.Stop() + + for { + select { + case <-b.tick.C: + // Process batched changes + b.processBatchedChanges() + case <-cleanupTicker.C: + // Clean up nodes that have been offline for too long + b.cleanupOfflineNodes() + case <-b.done: + log.Info().Msg("batcher done channel closed, stopping to feed workers") + return + } + } +} + +func (b *LockFreeBatcher) worker(workerID int) { + for { + select { + case w, ok := <-b.workCh: + if !ok { + log.Debug().Int("worker.id", workerID).Msgf("worker channel closing, shutting down worker %d", workerID) + return + } + + b.workProcessed.Add(1) + + // If the resultCh is set, it means that this is a work request + // where there is a blocking function waiting for the map that + // is being generated. + // This is used for synchronous map generation. + if w.resultCh != nil { + var result workResult + if nc, exists := b.nodes.Load(w.nodeID); exists { + var err error + + result.mapResponse, err = generateMapResponse(nc, b.mapper, w.c) + result.err = err + if result.err != nil { + b.workErrors.Add(1) + log.Error().Err(result.err). + Int("worker.id", workerID). + Uint64("node.id", w.nodeID.Uint64()). + Str("reason", w.c.Reason). + Msg("failed to generate map response for synchronous work") + } else if result.mapResponse != nil { + // Update peer tracking for synchronous responses too + nc.updateSentPeers(result.mapResponse) + } + } else { + result.err = fmt.Errorf("node %d not found", w.nodeID) + + b.workErrors.Add(1) + log.Error().Err(result.err). + Int("worker.id", workerID). + Uint64("node.id", w.nodeID.Uint64()). + Msg("node not found for synchronous work") + } + + // Send result + select { + case w.resultCh <- result: + case <-b.done: + return + } + + continue + } + + // If resultCh is nil, this is an asynchronous work request + // that should be processed and sent to the node instead of + // returned to the caller. + if nc, exists := b.nodes.Load(w.nodeID); exists { + // Apply change to node - this will handle offline nodes gracefully + // and queue work for when they reconnect + err := nc.change(w.c) + if err != nil { + b.workErrors.Add(1) + log.Error().Err(err). + Int("worker.id", workerID). + Uint64("node.id", w.nodeID.Uint64()). + Str("reason", w.c.Reason). + Msg("failed to apply change") + } + } + case <-b.done: + log.Debug().Int("worker.id", workerID).Msg("batcher shutting down, exiting worker") + return + } + } +} + +func (b *LockFreeBatcher) addWork(r ...change.Change) { + b.addToBatch(r...) +} + +// queueWork safely queues work. +func (b *LockFreeBatcher) queueWork(w work) { + b.workQueuedCount.Add(1) + + select { + case b.workCh <- w: + // Successfully queued + case <-b.done: + // Batcher is shutting down + return + } +} + +// addToBatch adds changes to the pending batch. +func (b *LockFreeBatcher) addToBatch(changes ...change.Change) { + // Clean up any nodes being permanently removed from the system. + // + // This handles the case where a node is deleted from state but the batcher + // still has it registered. By cleaning up here, we prevent "node not found" + // errors when workers try to generate map responses for deleted nodes. 
+ // + // Safety: change.Change.PeersRemoved is ONLY populated when nodes are actually + // deleted from the system (via change.NodeRemoved in state.DeleteNode). Policy + // changes that affect peer visibility do NOT use this field - they set + // RequiresRuntimePeerComputation=true and compute removed peers at runtime, + // putting them in tailcfg.MapResponse.PeersRemoved (a different struct). + // Therefore, this cleanup only removes nodes that are truly being deleted, + // not nodes that are still connected but have lost visibility of certain peers. + // + // See: https://github.com/juanfont/headscale/issues/2924 + for _, ch := range changes { + for _, removedID := range ch.PeersRemoved { + if _, existed := b.nodes.LoadAndDelete(removedID); existed { + b.totalNodes.Add(-1) + log.Debug(). + Uint64("node.id", removedID.Uint64()). + Msg("Removed deleted node from batcher") + } + + b.connected.Delete(removedID) + b.pendingChanges.Delete(removedID) + } + } + + // Short circuit if any of the changes is a full update, which + // means we can skip sending individual changes. + if change.HasFull(changes) { + b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { + b.pendingChanges.Store(nodeID, []change.Change{change.FullUpdate()}) + + return true + }) + + return + } + + broadcast, targeted := change.SplitTargetedAndBroadcast(changes) + + // Handle targeted changes - send only to the specific node + for _, ch := range targeted { + pending, _ := b.pendingChanges.LoadOrStore(ch.TargetNode, []change.Change{}) + pending = append(pending, ch) + b.pendingChanges.Store(ch.TargetNode, pending) + } + + // Handle broadcast changes - send to all nodes, filtering as needed + if len(broadcast) > 0 { + b.nodes.Range(func(nodeID types.NodeID, _ *multiChannelNodeConn) bool { + filtered := change.FilterForNode(nodeID, broadcast) + + if len(filtered) > 0 { + pending, _ := b.pendingChanges.LoadOrStore(nodeID, []change.Change{}) + pending = append(pending, filtered...) + b.pendingChanges.Store(nodeID, pending) + } + + return true + }) + } +} + +// processBatchedChanges processes all pending batched changes. +func (b *LockFreeBatcher) processBatchedChanges() { + if b.pendingChanges == nil { + return + } + + // Process all pending changes + b.pendingChanges.Range(func(nodeID types.NodeID, pending []change.Change) bool { + if len(pending) == 0 { + return true + } + + // Send all batched changes for this node + for _, ch := range pending { + b.queueWork(work{c: ch, nodeID: nodeID, resultCh: nil}) + } + + // Clear the pending changes for this node + b.pendingChanges.Delete(nodeID) + + return true + }) +} + +// cleanupOfflineNodes removes nodes that have been offline for too long to prevent memory leaks. +// TODO(kradalby): reevaluate if we want to keep this. +func (b *LockFreeBatcher) cleanupOfflineNodes() { + cleanupThreshold := 15 * time.Minute + now := time.Now() + + var nodesToCleanup []types.NodeID + + // Find nodes that have been offline for too long + b.connected.Range(func(nodeID types.NodeID, disconnectTime *time.Time) bool { + if disconnectTime != nil && now.Sub(*disconnectTime) > cleanupThreshold { + // Double-check the node doesn't have active connections + if nodeConn, exists := b.nodes.Load(nodeID); exists { + if !nodeConn.hasActiveConnections() { + nodesToCleanup = append(nodesToCleanup, nodeID) + } + } + } + return true + }) + + // Clean up the identified nodes + for _, nodeID := range nodesToCleanup { + log.Info().Uint64("node.id", nodeID.Uint64()). 
+ Dur("offline_duration", cleanupThreshold). + Msg("Cleaning up node that has been offline for too long") + + b.nodes.Delete(nodeID) + b.connected.Delete(nodeID) + b.totalNodes.Add(-1) + } + + if len(nodesToCleanup) > 0 { + log.Info().Int("cleaned_nodes", len(nodesToCleanup)). + Msg("Completed cleanup of long-offline nodes") + } +} + +// IsConnected is lock-free read that checks if a node has any active connections. +func (b *LockFreeBatcher) IsConnected(id types.NodeID) bool { + // First check if we have active connections for this node + if nodeConn, exists := b.nodes.Load(id); exists { + if nodeConn.hasActiveConnections() { + return true + } + } + + // Check disconnected timestamp with grace period + val, ok := b.connected.Load(id) + if !ok { + return false + } + + // nil means connected + if val == nil { + return true + } + + return false +} + +// ConnectedMap returns a lock-free map of all connected nodes. +func (b *LockFreeBatcher) ConnectedMap() *xsync.Map[types.NodeID, bool] { + ret := xsync.NewMap[types.NodeID, bool]() + + // First, add all nodes with active connections + b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { + if nodeConn.hasActiveConnections() { + ret.Store(id, true) + } + return true + }) + + // Then add all entries from the connected map + b.connected.Range(func(id types.NodeID, val *time.Time) bool { + // Only add if not already added as connected above + if _, exists := ret.Load(id); !exists { + if val == nil { + // nil means connected + ret.Store(id, true) + } else { + // timestamp means disconnected + ret.Store(id, false) + } + } + return true + }) + + return ret +} + +// MapResponseFromChange queues work to generate a map response and waits for the result. +// This allows synchronous map generation using the same worker pool. +func (b *LockFreeBatcher) MapResponseFromChange(id types.NodeID, ch change.Change) (*tailcfg.MapResponse, error) { + resultCh := make(chan workResult, 1) + + // Queue the work with a result channel using the safe queueing method + b.queueWork(work{c: ch, nodeID: id, resultCh: resultCh}) + + // Wait for the result + select { + case result := <-resultCh: + return result.mapResponse, result.err + case <-b.done: + return nil, fmt.Errorf("batcher shutting down while generating map response for node %d", id) + } +} + +// connectionEntry represents a single connection to a node. +type connectionEntry struct { + id string // unique connection ID + c chan<- *tailcfg.MapResponse + version tailcfg.CapabilityVersion + created time.Time + lastUsed atomic.Int64 // Unix timestamp of last successful send + closed atomic.Bool // Indicates if this connection has been closed +} + +// multiChannelNodeConn manages multiple concurrent connections for a single node. +type multiChannelNodeConn struct { + id types.NodeID + mapper *mapper + + mutex sync.RWMutex + connections []*connectionEntry + + updateCount atomic.Int64 + + // lastSentPeers tracks which peers were last sent to this node. + // This enables computing diffs for policy changes instead of sending + // full peer lists (which clients interpret as "no change" when empty). + // Using xsync.Map for lock-free concurrent access. + lastSentPeers *xsync.Map[tailcfg.NodeID, struct{}] +} + +// generateConnectionID generates a unique connection identifier. +func generateConnectionID() string { + bytes := make([]byte, 8) + rand.Read(bytes) + return fmt.Sprintf("%x", bytes) +} + +// newMultiChannelNodeConn creates a new multi-channel node connection. 
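+// The returned connection starts with an empty connection list and a fresh lastSentPeers map;
+// channels are attached later via addConnection.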
+func newMultiChannelNodeConn(id types.NodeID, mapper *mapper) *multiChannelNodeConn { + return &multiChannelNodeConn{ + id: id, + mapper: mapper, + lastSentPeers: xsync.NewMap[tailcfg.NodeID, struct{}](), + } +} + +func (mc *multiChannelNodeConn) close() { + mc.mutex.Lock() + defer mc.mutex.Unlock() + + for _, conn := range mc.connections { + // Mark as closed before closing the channel to prevent + // send on closed channel panics from concurrent workers + conn.closed.Store(true) + close(conn.c) + } +} + +// addConnection adds a new connection. +func (mc *multiChannelNodeConn) addConnection(entry *connectionEntry) { + mutexWaitStart := time.Now() + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id). + Msg("addConnection: waiting for mutex - POTENTIAL CONTENTION POINT") + + mc.mutex.Lock() + mutexWaitDur := time.Since(mutexWaitStart) + defer mc.mutex.Unlock() + + mc.connections = append(mc.connections, entry) + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", entry.c)).Str("conn.id", entry.id). + Int("total_connections", len(mc.connections)). + Dur("mutex_wait_time", mutexWaitDur). + Msg("Successfully added connection after mutex wait") +} + +// removeConnectionByChannel removes a connection by matching channel pointer. +func (mc *multiChannelNodeConn) removeConnectionByChannel(c chan<- *tailcfg.MapResponse) bool { + mc.mutex.Lock() + defer mc.mutex.Unlock() + + for i, entry := range mc.connections { + if entry.c == c { + // Remove this connection + mc.connections = append(mc.connections[:i], mc.connections[i+1:]...) + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", c)). + Int("remaining_connections", len(mc.connections)). + Msg("Successfully removed connection") + return true + } + } + return false +} + +// hasActiveConnections checks if the node has any active connections. +func (mc *multiChannelNodeConn) hasActiveConnections() bool { + mc.mutex.RLock() + defer mc.mutex.RUnlock() + + return len(mc.connections) > 0 +} + +// getActiveConnectionCount returns the number of active connections. +func (mc *multiChannelNodeConn) getActiveConnectionCount() int { + mc.mutex.RLock() + defer mc.mutex.RUnlock() + + return len(mc.connections) +} + +// send broadcasts data to all active connections for the node. +func (mc *multiChannelNodeConn) send(data *tailcfg.MapResponse) error { + if data == nil { + return nil + } + + mc.mutex.Lock() + defer mc.mutex.Unlock() + + if len(mc.connections) == 0 { + // During rapid reconnection, nodes may temporarily have no active connections + // This is not an error - the node will receive a full map when it reconnects + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()). + Msg("send: skipping send to node with no active connections (likely rapid reconnection)") + return nil // Return success instead of error + } + + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()). + Int("total_connections", len(mc.connections)). + Msg("send: broadcasting to all connections") + + var lastErr error + successCount := 0 + var failedConnections []int // Track failed connections for removal + + // Send to all connections + for i, conn := range mc.connections { + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). + Str("conn.id", conn.id).Int("connection_index", i). 
+ Msg("send: attempting to send to connection") + + if err := conn.send(data); err != nil { + lastErr = err + failedConnections = append(failedConnections, i) + log.Warn().Err(err). + Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). + Str("conn.id", conn.id).Int("connection_index", i). + Msg("send: connection send failed") + } else { + successCount++ + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()).Str("chan", fmt.Sprintf("%p", conn.c)). + Str("conn.id", conn.id).Int("connection_index", i). + Msg("send: successfully sent to connection") + } + } + + // Remove failed connections (in reverse order to maintain indices) + for i := len(failedConnections) - 1; i >= 0; i-- { + idx := failedConnections[i] + log.Debug().Caller().Uint64("node.id", mc.id.Uint64()). + Str("conn.id", mc.connections[idx].id). + Msg("send: removing failed connection") + mc.connections = append(mc.connections[:idx], mc.connections[idx+1:]...) + } + + mc.updateCount.Add(1) + + log.Debug().Uint64("node.id", mc.id.Uint64()). + Int("successful_sends", successCount). + Int("failed_connections", len(failedConnections)). + Int("remaining_connections", len(mc.connections)). + Msg("send: completed broadcast") + + // Success if at least one send succeeded + if successCount > 0 { + return nil + } + + return fmt.Errorf("node %d: all connections failed, last error: %w", mc.id, lastErr) +} + +// send sends data to a single connection entry with timeout-based stale connection detection. +func (entry *connectionEntry) send(data *tailcfg.MapResponse) error { + if data == nil { + return nil + } + + // Check if the connection has been closed to prevent send on closed channel panic. + // This can happen during shutdown when Close() is called while workers are still processing. + if entry.closed.Load() { + return fmt.Errorf("connection %s: %w", entry.id, errConnectionClosed) + } + + // Use a short timeout to detect stale connections where the client isn't reading the channel. + // This is critical for detecting Docker containers that are forcefully terminated + // but still have channels that appear open. + select { + case entry.c <- data: + // Update last used timestamp on successful send + entry.lastUsed.Store(time.Now().Unix()) + return nil + case <-time.After(50 * time.Millisecond): + // Connection is likely stale - client isn't reading from channel + // This catches the case where Docker containers are killed but channels remain open + return fmt.Errorf("connection %s: timeout sending to channel (likely stale connection)", entry.id) + } +} + +// nodeID returns the node ID. +func (mc *multiChannelNodeConn) nodeID() types.NodeID { + return mc.id +} + +// version returns the capability version from the first active connection. +// All connections for a node should have the same version in practice. +func (mc *multiChannelNodeConn) version() tailcfg.CapabilityVersion { + mc.mutex.RLock() + defer mc.mutex.RUnlock() + + if len(mc.connections) == 0 { + return 0 + } + + return mc.connections[0].version +} + +// updateSentPeers updates the tracked peer state based on a sent MapResponse. +// This must be called after successfully sending a response to keep track of +// what the client knows about, enabling accurate diffs for future updates. 
+func (mc *multiChannelNodeConn) updateSentPeers(resp *tailcfg.MapResponse) { + if resp == nil { + return + } + + // Full peer list replaces tracked state entirely + if resp.Peers != nil { + mc.lastSentPeers.Clear() + + for _, peer := range resp.Peers { + mc.lastSentPeers.Store(peer.ID, struct{}{}) + } + } + + // Incremental additions + for _, peer := range resp.PeersChanged { + mc.lastSentPeers.Store(peer.ID, struct{}{}) + } + + // Incremental removals + for _, id := range resp.PeersRemoved { + mc.lastSentPeers.Delete(id) + } +} + +// computePeerDiff compares the current peer list against what was last sent +// and returns the peers that were removed (in lastSentPeers but not in current). +func (mc *multiChannelNodeConn) computePeerDiff(currentPeers []tailcfg.NodeID) []tailcfg.NodeID { + currentSet := make(map[tailcfg.NodeID]struct{}, len(currentPeers)) + for _, id := range currentPeers { + currentSet[id] = struct{}{} + } + + var removed []tailcfg.NodeID + + // Find removed: in lastSentPeers but not in current + mc.lastSentPeers.Range(func(id tailcfg.NodeID, _ struct{}) bool { + if _, exists := currentSet[id]; !exists { + removed = append(removed, id) + } + + return true + }) + + return removed +} + +// change applies a change to all active connections for the node. +func (mc *multiChannelNodeConn) change(r change.Change) error { + return handleNodeChange(mc, mc.mapper, r) +} + +// DebugNodeInfo contains debug information about a node's connections. +type DebugNodeInfo struct { + Connected bool `json:"connected"` + ActiveConnections int `json:"active_connections"` +} + +// Debug returns a pre-baked map of node debug information for the debug interface. +func (b *LockFreeBatcher) Debug() map[types.NodeID]DebugNodeInfo { + result := make(map[types.NodeID]DebugNodeInfo) + + // Get all nodes with their connection status using immediate connection logic + // (no grace period) for debug purposes + b.nodes.Range(func(id types.NodeID, nodeConn *multiChannelNodeConn) bool { + nodeConn.mutex.RLock() + activeConnCount := len(nodeConn.connections) + nodeConn.mutex.RUnlock() + + // Use immediate connection status: if active connections exist, node is connected + // If not, check the connected map for nil (connected) vs timestamp (disconnected) + connected := false + if activeConnCount > 0 { + connected = true + } else { + // Check connected map for immediate status + if val, ok := b.connected.Load(id); ok && val == nil { + connected = true + } + } + + result[id] = DebugNodeInfo{ + Connected: connected, + ActiveConnections: activeConnCount, + } + return true + }) + + // Add all entries from the connected map to capture both connected and disconnected nodes + b.connected.Range(func(id types.NodeID, val *time.Time) bool { + // Only add if not already processed above + if _, exists := result[id]; !exists { + // Use immediate connection status for debug (no grace period) + connected := (val == nil) // nil means connected, timestamp means disconnected + result[id] = DebugNodeInfo{ + Connected: connected, + ActiveConnections: 0, + } + } + return true + }) + + return result +} + +func (b *LockFreeBatcher) DebugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { + return b.mapper.debugMapResponses() +} + +// WorkErrors returns the count of work errors encountered. +// This is primarily useful for testing and debugging. 
+func (b *LockFreeBatcher) WorkErrors() int64 { + return b.workErrors.Load() +} diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go new file mode 100644 index 00000000..70d5e377 --- /dev/null +++ b/hscontrol/mapper/batcher_test.go @@ -0,0 +1,2773 @@ +package mapper + +import ( + "errors" + "fmt" + "net/netip" + "runtime" + "strings" + "sync" + "sync/atomic" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/derp" + "github.com/juanfont/headscale/hscontrol/state" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/rs/zerolog" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "zgo.at/zcache/v2" +) + +var errNodeNotFoundAfterAdd = errors.New("node not found after adding to batcher") + +// batcherTestCase defines a batcher function with a descriptive name for testing. +type batcherTestCase struct { + name string + fn batcherFunc +} + +// testBatcherWrapper wraps a real batcher to add online/offline notifications +// that would normally be sent by poll.go in production. +type testBatcherWrapper struct { + Batcher + state *state.State +} + +func (t *testBatcherWrapper) AddNode(id types.NodeID, c chan<- *tailcfg.MapResponse, version tailcfg.CapabilityVersion) error { + // Mark node as online in state before AddNode to match production behavior + // This ensures the NodeStore has correct online status for change processing + if t.state != nil { + // Use Connect to properly mark node online in NodeStore but don't send its changes + _ = t.state.Connect(id) + } + + // First add the node to the real batcher + err := t.Batcher.AddNode(id, c, version) + if err != nil { + return err + } + + // Send the online notification that poll.go would normally send + // This ensures other nodes get notified about this node coming online + node, ok := t.state.GetNodeByID(id) + if !ok { + return fmt.Errorf("%w: %d", errNodeNotFoundAfterAdd, id) + } + + t.AddWork(change.NodeOnlineFor(node)) + + return nil +} + +func (t *testBatcherWrapper) RemoveNode(id types.NodeID, c chan<- *tailcfg.MapResponse) bool { + // Mark node as offline in state BEFORE removing from batcher + // This ensures the NodeStore has correct offline status when the change is processed + if t.state != nil { + // Use Disconnect to properly mark node offline in NodeStore but don't send its changes + _, _ = t.state.Disconnect(id) + } + + // Send the offline notification that poll.go would normally send + // Do this BEFORE removing from batcher so the change can be processed + node, ok := t.state.GetNodeByID(id) + if ok { + t.AddWork(change.NodeOfflineFor(node)) + } + + // Finally remove from the real batcher + removed := t.Batcher.RemoveNode(id, c) + if !removed { + return false + } + + return true +} + +// wrapBatcherForTest wraps a batcher with test-specific behavior. +func wrapBatcherForTest(b Batcher, state *state.State) Batcher { + return &testBatcherWrapper{Batcher: b, state: state} +} + +// allBatcherFunctions contains all batcher implementations to test. +var allBatcherFunctions = []batcherTestCase{ + {"LockFree", NewBatcherAndMapper}, +} + +// emptyCache creates an empty registration cache for testing. +func emptyCache() *zcache.Cache[types.RegistrationID, types.RegisterNode] { + return zcache.New[types.RegistrationID, types.RegisterNode](time.Minute, time.Hour) +} + +// Test configuration constants. 
+const ( + // Test data configuration. + TEST_USER_COUNT = 3 + TEST_NODES_PER_USER = 2 + + // Load testing configuration. + HIGH_LOAD_NODES = 25 // Increased from 9 + HIGH_LOAD_CYCLES = 100 // Increased from 20 + HIGH_LOAD_UPDATES = 50 // Increased from 20 + + // Extreme load testing configuration. + EXTREME_LOAD_NODES = 50 + EXTREME_LOAD_CYCLES = 200 + EXTREME_LOAD_UPDATES = 100 + + // Timing configuration. + TEST_TIMEOUT = 120 * time.Second // Increased for more intensive tests + UPDATE_TIMEOUT = 5 * time.Second + DEADLOCK_TIMEOUT = 30 * time.Second + + // Channel configuration. + NORMAL_BUFFER_SIZE = 50 + SMALL_BUFFER_SIZE = 3 + TINY_BUFFER_SIZE = 1 // For maximum contention + LARGE_BUFFER_SIZE = 200 + + reservedResponseHeaderSize = 4 +) + +// TestData contains all test entities created for a test scenario. +type TestData struct { + Database *db.HSDatabase + Users []*types.User + Nodes []node + State *state.State + Config *types.Config + Batcher Batcher +} + +type node struct { + n *types.Node + ch chan *tailcfg.MapResponse + + // Update tracking (all accessed atomically for thread safety) + updateCount int64 + patchCount int64 + fullCount int64 + maxPeersCount atomic.Int64 + lastPeerCount atomic.Int64 + stop chan struct{} + stopped chan struct{} +} + +// setupBatcherWithTestData creates a comprehensive test environment with real +// database test data including users and registered nodes. +// +// This helper creates a database, populates it with test data, then creates +// a state and batcher using the SAME database for testing. This provides real +// node data for testing full map responses and comprehensive update scenarios. +// +// Returns TestData struct containing all created entities and a cleanup function. +func setupBatcherWithTestData( + t *testing.T, + bf batcherFunc, + userCount, nodesPerUser, bufferSize int, +) (*TestData, func()) { + t.Helper() + + // Create database and populate with test data first + tmpDir := t.TempDir() + dbPath := tmpDir + "/headscale_test.db" + + prefixV4 := netip.MustParsePrefix("100.64.0.0/10") + prefixV6 := netip.MustParsePrefix("fd7a:115c:a1e0::/48") + + cfg := &types.Config{ + Database: types.DatabaseConfig{ + Type: types.DatabaseSqlite, + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, + }, + PrefixV4: &prefixV4, + PrefixV6: &prefixV6, + IPAllocation: types.IPAllocationStrategySequential, + BaseDomain: "headscale.test", + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, + }, + DERP: types.DERPConfig{ + ServerEnabled: false, + DERPMap: &tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 999: { + RegionID: 999, + }, + }, + }, + }, + Tuning: types.Tuning{ + BatchChangeDelay: 10 * time.Millisecond, + BatcherWorkers: types.DefaultBatcherWorkers(), // Use same logic as config.go + NodeStoreBatchSize: state.TestBatchSize, + NodeStoreBatchTimeout: state.TestBatchTimeout, + }, + } + + // Create database and populate it with test data + database, err := db.NewHeadscaleDatabase( + cfg, + emptyCache(), + ) + if err != nil { + t.Fatalf("setting up database: %s", err) + } + + // Create test users and nodes in the database + users := database.CreateUsersForTest(userCount, "testuser") + + allNodes := make([]node, 0, userCount*nodesPerUser) + for _, user := range users { + dbNodes := database.CreateRegisteredNodesForTest(user, nodesPerUser, "node") + for i := range dbNodes { + allNodes = append(allNodes, node{ + n: dbNodes[i], + ch: make(chan *tailcfg.MapResponse, bufferSize), + }) + } + } + + // Now create state using the same database + state, 
err := state.NewState(cfg) + if err != nil { + t.Fatalf("Failed to create state: %v", err) + } + + derpMap, err := derp.GetDERPMap(cfg.DERP) + assert.NoError(t, err) + assert.NotNil(t, derpMap) + + state.SetDERPMap(derpMap) + + // Set up a permissive policy that allows all communication for testing + allowAllPolicy := `{ + "acls": [ + { + "action": "accept", + "src": ["*"], + "dst": ["*:*"] + } + ] + }` + + _, err = state.SetPolicy([]byte(allowAllPolicy)) + if err != nil { + t.Fatalf("Failed to set allow-all policy: %v", err) + } + + // Create batcher with the state and wrap it for testing + batcher := wrapBatcherForTest(bf(cfg, state), state) + batcher.Start() + + testData := &TestData{ + Database: database, + Users: users, + Nodes: allNodes, + State: state, + Config: cfg, + Batcher: batcher, + } + + cleanup := func() { + batcher.Close() + state.Close() + database.Close() + } + + return testData, cleanup +} + +type UpdateStats struct { + TotalUpdates int + UpdateSizes []int + LastUpdate time.Time +} + +// updateTracker provides thread-safe tracking of updates per node. +type updateTracker struct { + mu sync.RWMutex + stats map[types.NodeID]*UpdateStats +} + +// newUpdateTracker creates a new update tracker. +func newUpdateTracker() *updateTracker { + return &updateTracker{ + stats: make(map[types.NodeID]*UpdateStats), + } +} + +// recordUpdate records an update for a specific node. +func (ut *updateTracker) recordUpdate(nodeID types.NodeID, updateSize int) { + ut.mu.Lock() + defer ut.mu.Unlock() + + if ut.stats[nodeID] == nil { + ut.stats[nodeID] = &UpdateStats{} + } + + stats := ut.stats[nodeID] + stats.TotalUpdates++ + stats.UpdateSizes = append(stats.UpdateSizes, updateSize) + stats.LastUpdate = time.Now() +} + +// getStats returns a copy of the statistics for a node. +func (ut *updateTracker) getStats(nodeID types.NodeID) UpdateStats { + ut.mu.RLock() + defer ut.mu.RUnlock() + + if stats, exists := ut.stats[nodeID]; exists { + // Return a copy to avoid race conditions + return UpdateStats{ + TotalUpdates: stats.TotalUpdates, + UpdateSizes: append([]int{}, stats.UpdateSizes...), + LastUpdate: stats.LastUpdate, + } + } + + return UpdateStats{} +} + +// getAllStats returns a copy of all statistics. +func (ut *updateTracker) getAllStats() map[types.NodeID]UpdateStats { + ut.mu.RLock() + defer ut.mu.RUnlock() + + result := make(map[types.NodeID]UpdateStats) + for nodeID, stats := range ut.stats { + result[nodeID] = UpdateStats{ + TotalUpdates: stats.TotalUpdates, + UpdateSizes: append([]int{}, stats.UpdateSizes...), + LastUpdate: stats.LastUpdate, + } + } + + return result +} + +func assertDERPMapResponse(t *testing.T, resp *tailcfg.MapResponse) { + t.Helper() + + assert.NotNil(t, resp.DERPMap, "DERPMap should not be nil in response") + assert.Len(t, resp.DERPMap.Regions, 1, "Expected exactly one DERP region in response") + assert.Equal(t, 999, resp.DERPMap.Regions[999].RegionID, "Expected DERP region ID to be 1337") +} + +func assertOnlineMapResponse(t *testing.T, resp *tailcfg.MapResponse, expected bool) { + t.Helper() + + // Check for peer changes patch (new online/offline notifications use patches) + if len(resp.PeersChangedPatch) > 0 { + require.Len(t, resp.PeersChangedPatch, 1) + assert.Equal(t, expected, *resp.PeersChangedPatch[0].Online) + + return + } + + // Fallback to old format for backwards compatibility + require.Len(t, resp.Peers, 1) + assert.Equal(t, expected, resp.Peers[0].Online) +} + +// UpdateInfo contains parsed information about an update. 
+type UpdateInfo struct { + IsFull bool + IsPatch bool + IsDERP bool + PeerCount int + PatchCount int +} + +// parseUpdateAndAnalyze parses an update and returns detailed information. +func parseUpdateAndAnalyze(resp *tailcfg.MapResponse) (UpdateInfo, error) { + info := UpdateInfo{ + PeerCount: len(resp.Peers), + PatchCount: len(resp.PeersChangedPatch), + IsFull: len(resp.Peers) > 0, + IsPatch: len(resp.PeersChangedPatch) > 0, + IsDERP: resp.DERPMap != nil, + } + + return info, nil +} + +// start begins consuming updates from the node's channel and tracking stats. +func (n *node) start() { + // Prevent multiple starts on the same node + if n.stop != nil { + return // Already started + } + + n.stop = make(chan struct{}) + n.stopped = make(chan struct{}) + + go func() { + defer close(n.stopped) + + for { + select { + case data := <-n.ch: + atomic.AddInt64(&n.updateCount, 1) + + // Parse update and track detailed stats + if info, err := parseUpdateAndAnalyze(data); err == nil { + // Track update types + if info.IsFull { + atomic.AddInt64(&n.fullCount, 1) + n.lastPeerCount.Store(int64(info.PeerCount)) + // Update max peers seen using compare-and-swap for thread safety + for { + current := n.maxPeersCount.Load() + if int64(info.PeerCount) <= current { + break + } + + if n.maxPeersCount.CompareAndSwap(current, int64(info.PeerCount)) { + break + } + } + } + + if info.IsPatch { + atomic.AddInt64(&n.patchCount, 1) + // For patches, we track how many patch items using compare-and-swap + for { + current := n.maxPeersCount.Load() + if int64(info.PatchCount) <= current { + break + } + + if n.maxPeersCount.CompareAndSwap(current, int64(info.PatchCount)) { + break + } + } + } + } + + case <-n.stop: + return + } + } + }() +} + +// NodeStats contains final statistics for a node. +type NodeStats struct { + TotalUpdates int64 + PatchUpdates int64 + FullUpdates int64 + MaxPeersSeen int + LastPeerCount int +} + +// cleanup stops the update consumer and returns final stats. +func (n *node) cleanup() NodeStats { + if n.stop != nil { + close(n.stop) + <-n.stopped // Wait for goroutine to finish + } + + return NodeStats{ + TotalUpdates: atomic.LoadInt64(&n.updateCount), + PatchUpdates: atomic.LoadInt64(&n.patchCount), + FullUpdates: atomic.LoadInt64(&n.fullCount), + MaxPeersSeen: int(n.maxPeersCount.Load()), + LastPeerCount: int(n.lastPeerCount.Load()), + } +} + +// validateUpdateContent validates that the update data contains a proper MapResponse. +func validateUpdateContent(resp *tailcfg.MapResponse) (bool, string) { + if resp == nil { + return false, "nil MapResponse" + } + + // Simple validation - just check if it's a valid MapResponse + return true, "valid" +} + +// TestEnhancedNodeTracking verifies that the enhanced node tracking works correctly. 
+func TestEnhancedNodeTracking(t *testing.T) { + // Create a simple test node + testNode := node{ + n: &types.Node{ID: 1}, + ch: make(chan *tailcfg.MapResponse, 10), + } + + // Start the enhanced tracking + testNode.start() + + // Create a simple MapResponse that should be parsed correctly + resp := tailcfg.MapResponse{ + KeepAlive: false, + Peers: []*tailcfg.Node{ + {ID: 2}, + {ID: 3}, + }, + } + + // Send the data to the node's channel + testNode.ch <- &resp + + // Wait for tracking goroutine to process the update + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.GreaterOrEqual(c, atomic.LoadInt64(&testNode.updateCount), int64(1), "should have processed the update") + }, time.Second, 10*time.Millisecond, "waiting for update to be processed") + + // Check stats + stats := testNode.cleanup() + t.Logf("Enhanced tracking stats: Total=%d, Full=%d, Patch=%d, MaxPeers=%d", + stats.TotalUpdates, stats.FullUpdates, stats.PatchUpdates, stats.MaxPeersSeen) + + require.Equal(t, int64(1), stats.TotalUpdates, "Expected 1 total update") + require.Equal(t, int64(1), stats.FullUpdates, "Expected 1 full update") + require.Equal(t, 2, stats.MaxPeersSeen, "Expected 2 max peers seen") +} + +// TestEnhancedTrackingWithBatcher verifies enhanced tracking works with a real batcher. +func TestEnhancedTrackingWithBatcher(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with 1 node + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, 10) + defer cleanup() + + batcher := testData.Batcher + testNode := &testData.Nodes[0] + + t.Logf("Testing enhanced tracking with node ID %d", testNode.n.ID) + + // Start enhanced tracking for the node + testNode.start() + + // Connect the node to the batcher + batcher.AddNode(testNode.n.ID, testNode.ch, tailcfg.CapabilityVersion(100)) + + // Wait for connection to be established + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.True(c, batcher.IsConnected(testNode.n.ID), "node should be connected") + }, time.Second, 10*time.Millisecond, "waiting for node connection") + + // Generate work and wait for updates to be processed + batcher.AddWork(change.FullUpdate()) + batcher.AddWork(change.PolicyChange()) + batcher.AddWork(change.DERPMap()) + + // Wait for updates to be processed (at least 1 update received) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.GreaterOrEqual(c, atomic.LoadInt64(&testNode.updateCount), int64(1), "should have received updates") + }, time.Second, 10*time.Millisecond, "waiting for updates to be processed") + + // Check stats + stats := testNode.cleanup() + t.Logf("Enhanced tracking with batcher: Total=%d, Full=%d, Patch=%d, MaxPeers=%d", + stats.TotalUpdates, stats.FullUpdates, stats.PatchUpdates, stats.MaxPeersSeen) + + if stats.TotalUpdates == 0 { + t.Error( + "Enhanced tracking with batcher received 0 updates - batcher may not be working", + ) + } + }) + } +} + +// TestBatcherScalabilityAllToAll tests the batcher's ability to handle rapid node joins +// and ensure all nodes can see all other nodes. This is a critical test for mesh network +// functionality where every node must be able to communicate with every other node. 
+func TestBatcherScalabilityAllToAll(t *testing.T) { + // Reduce verbose application logging for cleaner test output + originalLevel := zerolog.GlobalLevel() + defer zerolog.SetGlobalLevel(originalLevel) + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + + // Test cases: different node counts to stress test the all-to-all connectivity + testCases := []struct { + name string + nodeCount int + }{ + {"10_nodes", 10}, // Quick baseline test + {"100_nodes", 100}, // Full scalability test ~2 minutes + // Large-scale tests commented out - uncomment for scalability testing + // {"1000_nodes", 1000}, // ~12 minutes + // {"2000_nodes", 2000}, // ~60+ minutes + // {"5000_nodes", 5000}, // Not recommended - database bottleneck + } + + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Logf( + "ALL-TO-ALL TEST: %d nodes with %s batcher", + tc.nodeCount, + batcherFunc.name, + ) + + // Create test environment - all nodes from same user so they can be peers + // We need enough users to support the node count (max 1000 nodes per user) + usersNeeded := max(1, (tc.nodeCount+999)/1000) + nodesPerUser := (tc.nodeCount + usersNeeded - 1) / usersNeeded + + // Use large buffer to avoid blocking during rapid joins + // Buffer needs to handle nodeCount * average_updates_per_node + // Estimate: each node receives ~2*nodeCount updates during all-to-all + // For very large tests (>1000 nodes), limit buffer to avoid excessive memory + bufferSize := max(1000, min(tc.nodeCount*2, 10000)) + + testData, cleanup := setupBatcherWithTestData( + t, + batcherFunc.fn, + usersNeeded, + nodesPerUser, + bufferSize, + ) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes[:tc.nodeCount] // Limit to requested count + + t.Logf( + "Created %d nodes across %d users, buffer size: %d", + len(allNodes), + usersNeeded, + bufferSize, + ) + + // Start enhanced tracking for all nodes + for i := range allNodes { + allNodes[i].start() + } + + // Yield to allow tracking goroutines to start + runtime.Gosched() + + startTime := time.Now() + + // Join all nodes as fast as possible + t.Logf("Joining %d nodes as fast as possible...", len(allNodes)) + + for i := range allNodes { + node := &allNodes[i] + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + + // Issue full update after each join to ensure connectivity + batcher.AddWork(change.FullUpdate()) + + // Yield to scheduler for large node counts to prevent overwhelming the work queue + if tc.nodeCount > 100 && i%50 == 49 { + runtime.Gosched() + } + } + + joinTime := time.Since(startTime) + t.Logf("All nodes joined in %v, waiting for full connectivity...", joinTime) + + // Wait for all updates to propagate until all nodes achieve connectivity + expectedPeers := tc.nodeCount - 1 // Each node should see all others except itself + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + connectedCount := 0 + for i := range allNodes { + node := &allNodes[i] + + currentMaxPeers := int(node.maxPeersCount.Load()) + if currentMaxPeers >= expectedPeers { + connectedCount++ + } + } + + progress := float64(connectedCount) / float64(len(allNodes)) * 100 + t.Logf("Progress: %d/%d nodes (%.1f%%) have seen %d+ peers", + connectedCount, len(allNodes), progress, expectedPeers) + + assert.Equal(c, len(allNodes), connectedCount, "all nodes should achieve full connectivity") + }, 5*time.Minute, 5*time.Second, "waiting for full connectivity") + + t.Logf("✅ 
All nodes achieved full connectivity!") + totalTime := time.Since(startTime) + + // Disconnect all nodes + for i := range allNodes { + node := &allNodes[i] + batcher.RemoveNode(node.n.ID, node.ch) + } + + // Wait for all nodes to be disconnected + assert.EventuallyWithT(t, func(c *assert.CollectT) { + for i := range allNodes { + assert.False(c, batcher.IsConnected(allNodes[i].n.ID), "node should be disconnected") + } + }, 5*time.Second, 50*time.Millisecond, "waiting for nodes to disconnect") + + // Collect final statistics + totalUpdates := int64(0) + totalFull := int64(0) + maxPeersGlobal := 0 + minPeersSeen := tc.nodeCount + successfulNodes := 0 + + nodeDetails := make([]string, 0, min(10, len(allNodes))) + + for i := range allNodes { + node := &allNodes[i] + stats := node.cleanup() + + totalUpdates += stats.TotalUpdates + totalFull += stats.FullUpdates + + if stats.MaxPeersSeen > maxPeersGlobal { + maxPeersGlobal = stats.MaxPeersSeen + } + + if stats.MaxPeersSeen < minPeersSeen { + minPeersSeen = stats.MaxPeersSeen + } + + if stats.MaxPeersSeen >= expectedPeers { + successfulNodes++ + } + + // Collect details for first few nodes or failing nodes + if len(nodeDetails) < 10 || stats.MaxPeersSeen < expectedPeers { + nodeDetails = append(nodeDetails, + fmt.Sprintf( + "Node %d: %d updates (%d full), max %d peers", + node.n.ID, + stats.TotalUpdates, + stats.FullUpdates, + stats.MaxPeersSeen, + )) + } + } + + // Final results + t.Logf("ALL-TO-ALL RESULTS: %d nodes, %d total updates (%d full)", + len(allNodes), totalUpdates, totalFull) + t.Logf( + " Connectivity: %d/%d nodes successful (%.1f%%)", + successfulNodes, + len(allNodes), + float64(successfulNodes)/float64(len(allNodes))*100, + ) + t.Logf(" Peers seen: min=%d, max=%d, expected=%d", + minPeersSeen, maxPeersGlobal, expectedPeers) + t.Logf(" Timing: join=%v, total=%v", joinTime, totalTime) + + // Show sample of node details + if len(nodeDetails) > 0 { + t.Logf(" Node sample:") + + for _, detail := range nodeDetails[:min(5, len(nodeDetails))] { + t.Logf(" %s", detail) + } + + if len(nodeDetails) > 5 { + t.Logf(" ... (%d more nodes)", len(nodeDetails)-5) + } + } + + // Final verification: Since we waited until all nodes achieved connectivity, + // this should always pass, but we verify the final state for completeness + if successfulNodes == len(allNodes) { + t.Logf( + "✅ PASS: All-to-all connectivity achieved for %d nodes", + len(allNodes), + ) + } else { + // This should not happen since we loop until success, but handle it just in case + failedNodes := len(allNodes) - successfulNodes + t.Errorf("❌ UNEXPECTED: %d/%d nodes still failed after waiting for connectivity (expected %d, some saw %d-%d)", + failedNodes, len(allNodes), expectedPeers, minPeersSeen, maxPeersGlobal) + + // Show details of failed nodes for debugging + if len(nodeDetails) > 5 { + t.Logf("Failed nodes details:") + + for _, detail := range nodeDetails[5:] { + if !strings.Contains(detail, fmt.Sprintf("max %d peers", expectedPeers)) { + t.Logf(" %s", detail) + } + } + } + } + }) + } + }) + } +} + +// TestBatcherBasicOperations verifies core batcher functionality by testing +// the basic lifecycle of adding nodes, processing updates, and removing nodes. +// +// Enhanced with real database test data, this test creates a registered node +// and tests both DERP updates and full node updates. It validates the fundamental +// add/remove operations and basic work processing pipeline with actual update +// content validation instead of just byte count checks. 
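+//
+// Note that the batcher under test is wrapped by testBatcherWrapper, which injects
+// the NodeOnlineFor/NodeOfflineFor changes that poll.go would normally emit in
+// production; that is why the first node receives online/offline patches when the
+// second node is added and removed.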
+func TestBatcherBasicOperations(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with real database and nodes + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 8) + defer cleanup() + + batcher := testData.Batcher + tn := testData.Nodes[0] + tn2 := testData.Nodes[1] + + // Test AddNode with real node ID + batcher.AddNode(tn.n.ID, tn.ch, 100) + + if !batcher.IsConnected(tn.n.ID) { + t.Error("Node should be connected after AddNode") + } + + // Test work processing with DERP change + batcher.AddWork(change.DERPMap()) + + // Wait for update and validate content + select { + case data := <-tn.ch: + assertDERPMapResponse(t, data) + case <-time.After(200 * time.Millisecond): + t.Error("Did not receive expected DERP update") + } + + // Drain any initial messages from first node + drainChannelTimeout(tn.ch, "first node before second", 100*time.Millisecond) + + // Add the second node and verify update message + batcher.AddNode(tn2.n.ID, tn2.ch, 100) + assert.True(t, batcher.IsConnected(tn2.n.ID)) + + // First node should get an update that second node has connected. + select { + case data := <-tn.ch: + assertOnlineMapResponse(t, data, true) + case <-time.After(500 * time.Millisecond): + t.Error("Did not receive expected Online response update") + } + + // Second node should receive its initial full map + select { + case data := <-tn2.ch: + // Verify it's a full map response + assert.NotNil(t, data) + assert.True( + t, + len(data.Peers) >= 1 || data.Node != nil, + "Should receive initial full map", + ) + case <-time.After(500 * time.Millisecond): + t.Error("Second node should receive its initial full map") + } + + // Disconnect the second node + batcher.RemoveNode(tn2.n.ID, tn2.ch) + // Note: IsConnected may return true during grace period for DNS resolution + + // First node should get update that second has disconnected. 
+ select { + case data := <-tn.ch: + assertOnlineMapResponse(t, data, false) + case <-time.After(500 * time.Millisecond): + t.Error("Did not receive expected Online response update") + } + + // // Test node-specific update with real node data + // batcher.AddWork(change.NodeKeyChanged(tn.n.ID)) + + // // Wait for node update (may be empty for certain node changes) + // select { + // case data := <-tn.ch: + // t.Logf("Received node update: %d bytes", len(data)) + // if len(data) == 0 { + // t.Logf("Empty node update (expected for some node changes in test environment)") + // } else { + // if valid, updateType := validateUpdateContent(data); !valid { + // t.Errorf("Invalid node update content: %s", updateType) + // } else { + // t.Logf("Valid node update type: %s", updateType) + // } + // } + // case <-time.After(200 * time.Millisecond): + // // Node changes might not always generate updates in test environment + // t.Logf("No node update received (may be expected in test environment)") + // } + + // Test RemoveNode + batcher.RemoveNode(tn.n.ID, tn.ch) + // Note: IsConnected may return true during grace period for DNS resolution + // The node is actually removed from active connections but grace period allows DNS lookups + }) + } +} + +func drainChannelTimeout(ch <-chan *tailcfg.MapResponse, name string, timeout time.Duration) { + count := 0 + + timer := time.NewTimer(timeout) + defer timer.Stop() + + for { + select { + case data := <-ch: + count++ + // Optional: add debug output if needed + _ = data + case <-timer.C: + return + } + } +} + +// TestBatcherUpdateTypes tests different types of updates and verifies +// that the batcher correctly processes them based on their content. +// +// Enhanced with real database test data, this test creates registered nodes +// and tests various update types including DERP changes, node-specific changes, +// and full updates. This validates the change classification logic and ensures +// different update types are handled appropriately with actual node data. 
+// func TestBatcherUpdateTypes(t *testing.T) { +// for _, batcherFunc := range allBatcherFunctions { +// t.Run(batcherFunc.name, func(t *testing.T) { +// // Create test environment with real database and nodes +// testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 8) +// defer cleanup() + +// batcher := testData.Batcher +// testNodes := testData.Nodes + +// ch := make(chan *tailcfg.MapResponse, 10) +// // Use real node ID from test data +// batcher.AddNode(testNodes[0].n.ID, ch, false, "zstd", tailcfg.CapabilityVersion(100)) + +// tests := []struct { +// name string +// changeSet change.ChangeSet +// expectData bool // whether we expect to receive data +// description string +// }{ +// { +// name: "DERP change", +// changeSet: change.DERPMapResponse(), +// expectData: true, +// description: "DERP changes should generate map updates", +// }, +// { +// name: "Node key expiry", +// changeSet: change.KeyExpiryFor(testNodes[1].n.ID), +// expectData: true, +// description: "Node key expiry with real node data", +// }, +// { +// name: "Node new registration", +// changeSet: change.NodeAddedResponse(testNodes[1].n.ID), +// expectData: true, +// description: "New node registration with real data", +// }, +// { +// name: "Full update", +// changeSet: change.FullUpdateResponse(), +// expectData: true, +// description: "Full updates with real node data", +// }, +// { +// name: "Policy change", +// changeSet: change.PolicyChangeResponse(), +// expectData: true, +// description: "Policy updates with real node data", +// }, +// } + +// for _, tt := range tests { +// t.Run(tt.name, func(t *testing.T) { +// t.Logf("Testing: %s", tt.description) + +// // Clear any existing updates +// select { +// case <-ch: +// default: +// } + +// batcher.AddWork(tt.changeSet) + +// select { +// case data := <-ch: +// if !tt.expectData { +// t.Errorf("Unexpected update for %s: %d bytes", tt.name, len(data)) +// } else { +// t.Logf("%s: received %d bytes", tt.name, len(data)) + +// // Validate update content when we have data +// if len(data) > 0 { +// if valid, updateType := validateUpdateContent(data); !valid { +// t.Errorf("Invalid update content for %s: %s", tt.name, updateType) +// } else { +// t.Logf("%s: valid update type: %s", tt.name, updateType) +// } +// } else { +// t.Logf("%s: empty update (may be expected for some node changes)", tt.name) +// } +// } +// case <-time.After(100 * time.Millisecond): +// if tt.expectData { +// t.Errorf("Expected update for %s (%s) but none received", tt.name, tt.description) +// } else { +// t.Logf("%s: no update (expected)", tt.name) +// } +// } +// }) +// } +// }) +// } +// } + +// TestBatcherWorkQueueBatching tests that multiple changes get batched +// together and sent as a single update to reduce network overhead. +// +// Enhanced with real database test data, this test creates registered nodes +// and rapidly submits multiple types of changes including DERP updates and +// node changes. Due to the batching mechanism with BatchChangeDelay, these +// should be combined into fewer updates. This validates that the batching +// system works correctly with real node data and mixed change types. 
+func TestBatcherWorkQueueBatching(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with real database and nodes + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 8) + defer cleanup() + + batcher := testData.Batcher + testNodes := testData.Nodes + + ch := make(chan *tailcfg.MapResponse, 10) + batcher.AddNode(testNodes[0].n.ID, ch, tailcfg.CapabilityVersion(100)) + + // Track update content for validation + var receivedUpdates []*tailcfg.MapResponse + + // Add multiple changes rapidly to test batching + batcher.AddWork(change.DERPMap()) + // Use a valid expiry time for testing since test nodes don't have expiry set + testExpiry := time.Now().Add(24 * time.Hour) + batcher.AddWork(change.KeyExpiryFor(testNodes[1].n.ID, testExpiry)) + batcher.AddWork(change.DERPMap()) + batcher.AddWork(change.NodeAdded(testNodes[1].n.ID)) + batcher.AddWork(change.DERPMap()) + + // Collect updates with timeout + updateCount := 0 + timeout := time.After(200 * time.Millisecond) + + for { + select { + case data := <-ch: + updateCount++ + + receivedUpdates = append(receivedUpdates, data) + + // Validate update content + if data != nil { + if valid, reason := validateUpdateContent(data); valid { + t.Logf("Update %d: valid", updateCount) + } else { + t.Logf("Update %d: invalid: %s", updateCount, reason) + } + } else { + t.Logf("Update %d: nil update", updateCount) + } + case <-timeout: + // Expected: 5 explicit changes + 1 initial from AddNode + 1 NodeOnline from wrapper = 7 updates + expectedUpdates := 7 + t.Logf("Received %d updates from %d changes (expected %d)", + updateCount, 5, expectedUpdates) + + if updateCount != expectedUpdates { + t.Errorf( + "Expected %d updates but received %d", + expectedUpdates, + updateCount, + ) + } + + // Validate that all updates have valid content + validUpdates := 0 + + for _, data := range receivedUpdates { + if data != nil { + if valid, _ := validateUpdateContent(data); valid { + validUpdates++ + } + } + } + + if validUpdates != updateCount { + t.Errorf("Expected all %d updates to be valid, but only %d were valid", + updateCount, validUpdates) + } + + return + } + } + }) + } +} + +// TestBatcherChannelClosingRace tests the fix for the async channel closing +// race condition that previously caused panics and data races. +// +// Enhanced with real database test data, this test simulates rapid node +// reconnections using real registered nodes while processing actual updates. +// The test verifies that channels are closed synchronously and deterministically +// even when real node updates are being processed, ensuring no race conditions +// occur during channel replacement with actual workload. 
+func XTestBatcherChannelClosingRace(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with real database and nodes + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, 8) + defer cleanup() + + batcher := testData.Batcher + testNode := testData.Nodes[0] + + var ( + channelIssues int + mutex sync.Mutex + ) + + // Run rapid connect/disconnect cycles with real updates to test channel closing + + for i := range 100 { + var wg sync.WaitGroup + + // First connection + ch1 := make(chan *tailcfg.MapResponse, 1) + + wg.Go(func() { + batcher.AddNode(testNode.n.ID, ch1, tailcfg.CapabilityVersion(100)) + }) + + // Add real work during connection chaos + if i%10 == 0 { + batcher.AddWork(change.DERPMap()) + } + + // Rapid second connection - should replace ch1 + ch2 := make(chan *tailcfg.MapResponse, 1) + + wg.Go(func() { + runtime.Gosched() // Yield to introduce timing variability + batcher.AddNode(testNode.n.ID, ch2, tailcfg.CapabilityVersion(100)) + }) + + // Remove second connection + + wg.Go(func() { + runtime.Gosched() // Yield to introduce timing variability + runtime.Gosched() // Extra yield to offset from AddNode + batcher.RemoveNode(testNode.n.ID, ch2) + }) + + wg.Wait() + + // Verify ch1 behavior when replaced by ch2 + // The test is checking if ch1 gets closed/replaced properly + select { + case <-ch1: + // Channel received data or was closed, which is expected + case <-time.After(1 * time.Millisecond): + // If no data received, increment issues counter + mutex.Lock() + + channelIssues++ + + mutex.Unlock() + } + + // Clean up ch2 + select { + case <-ch2: + default: + } + } + + mutex.Lock() + defer mutex.Unlock() + + t.Logf("Channel closing issues: %d out of 100 iterations", channelIssues) + + // The main fix prevents panics and race conditions. Some timing variations + // are acceptable as long as there are no crashes or deadlocks. + if channelIssues > 50 { // Allow some timing variations + t.Errorf("Excessive channel closing issues: %d iterations", channelIssues) + } + }) + } +} + +// TestBatcherWorkerChannelSafety tests that worker goroutines handle closed +// channels safely without panicking when processing work items. +// +// Enhanced with real database test data, this test creates rapid connect/disconnect +// cycles using registered nodes while simultaneously queuing real work items. +// This creates a race where workers might try to send to channels that have been +// closed by node removal. The test validates that the safeSend() method properly +// handles closed channels with real update workloads. 
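+//
+// The closed-channel protection exercised here is implemented in close() and
+// connectionEntry.send above: close() marks each entry as closed before closing its
+// channel, and send() checks that flag (and uses a short send timeout) before
+// writing, so workers racing with node removal do not panic on a closed channel.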
+func TestBatcherWorkerChannelSafety(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with real database and nodes + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 1, 8) + defer cleanup() + + batcher := testData.Batcher + testNode := testData.Nodes[0] + + var ( + panics int + channelErrors int + invalidData int + mutex sync.Mutex + ) + + // Test rapid connect/disconnect with work generation + + for i := range 50 { + func() { + defer func() { + if r := recover(); r != nil { + mutex.Lock() + + panics++ + + mutex.Unlock() + t.Logf("Panic caught: %v", r) + } + }() + + ch := make(chan *tailcfg.MapResponse, 5) + + // Add node and immediately queue real work + batcher.AddNode(testNode.n.ID, ch, tailcfg.CapabilityVersion(100)) + batcher.AddWork(change.DERPMap()) + + // Consumer goroutine to validate data and detect channel issues + go func() { + defer func() { + if r := recover(); r != nil { + mutex.Lock() + + channelErrors++ + + mutex.Unlock() + t.Logf("Channel consumer panic: %v", r) + } + }() + + for { + select { + case data, ok := <-ch: + if !ok { + // Channel was closed, which is expected + return + } + // Validate the data we received + if valid, reason := validateUpdateContent(data); !valid { + mutex.Lock() + + invalidData++ + + mutex.Unlock() + t.Logf("Invalid data received: %s", reason) + } + case <-time.After(10 * time.Millisecond): + // Timeout waiting for data + return + } + } + }() + + // Add node-specific work occasionally + if i%10 == 0 { + // Use a valid expiry time for testing since test nodes don't have expiry set + testExpiry := time.Now().Add(24 * time.Hour) + batcher.AddWork(change.KeyExpiryFor(testNode.n.ID, testExpiry)) + } + + // Rapid removal creates race between worker and removal + for range i % 3 { + runtime.Gosched() // Introduce timing variability + } + batcher.RemoveNode(testNode.n.ID, ch) + + // Yield to allow workers to process and close channels + runtime.Gosched() + }() + } + + mutex.Lock() + defer mutex.Unlock() + + t.Logf( + "Worker safety test results: %d panics, %d channel errors, %d invalid data packets", + panics, + channelErrors, + invalidData, + ) + + // Test failure conditions + if panics > 0 { + t.Errorf("Worker channel safety failed with %d panics", panics) + } + + if channelErrors > 0 { + t.Errorf("Channel handling failed with %d channel errors", channelErrors) + } + + if invalidData > 0 { + t.Errorf("Data validation failed with %d invalid data packets", invalidData) + } + }) + } +} + +// TestBatcherConcurrentClients tests that concurrent connection lifecycle changes +// don't affect other stable clients' ability to receive updates. +// +// The test sets up real test data with multiple users and registered nodes, +// then creates stable clients and churning clients that rapidly connect and +// disconnect. Work is generated continuously during these connection churn cycles using +// real node data. The test validates that stable clients continue to function +// normally and receive proper updates despite the connection churn from other clients, +// ensuring system stability under concurrent load. 
+func TestBatcherConcurrentClients(t *testing.T) { + if testing.Short() { + t.Skip("Skipping concurrent client test in short mode") + } + + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create comprehensive test environment with real data + testData, cleanup := setupBatcherWithTestData( + t, + batcherFunc.fn, + TEST_USER_COUNT, + TEST_NODES_PER_USER, + 8, + ) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes + + // Create update tracker for monitoring all updates + tracker := newUpdateTracker() + + // Set up stable clients using real node IDs + stableNodes := allNodes[:len(allNodes)/2] // Use first half as stable + stableChannels := make(map[types.NodeID]chan *tailcfg.MapResponse) + + for _, node := range stableNodes { + ch := make(chan *tailcfg.MapResponse, NORMAL_BUFFER_SIZE) + stableChannels[node.n.ID] = ch + batcher.AddNode(node.n.ID, ch, tailcfg.CapabilityVersion(100)) + + // Monitor updates for each stable client + go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { + for { + select { + case data, ok := <-channel: + if !ok { + // Channel was closed, exit gracefully + return + } + if valid, reason := validateUpdateContent(data); valid { + tracker.recordUpdate( + nodeID, + 1, + ) // Use 1 as update size since we have MapResponse + } else { + t.Errorf("Invalid update received for stable node %d: %s", nodeID, reason) + } + case <-time.After(TEST_TIMEOUT): + return + } + } + }(node.n.ID, ch) + } + + // Use remaining nodes for connection churn testing + churningNodes := allNodes[len(allNodes)/2:] + churningChannels := make(map[types.NodeID]chan *tailcfg.MapResponse) + + var churningChannelsMutex sync.Mutex // Protect concurrent map access + + var wg sync.WaitGroup + + numCycles := 10 // Reduced for simpler test + panicCount := 0 + + var panicMutex sync.Mutex + + // Track deadlock with timeout + done := make(chan struct{}) + + go func() { + defer close(done) + + // Connection churn cycles - rapidly connect/disconnect to test concurrency safety + for i := range numCycles { + for _, node := range churningNodes { + wg.Add(2) + + // Connect churning node + go func(nodeID types.NodeID) { + defer func() { + if r := recover(); r != nil { + panicMutex.Lock() + + panicCount++ + + panicMutex.Unlock() + t.Logf("Panic in churning connect: %v", r) + } + + wg.Done() + }() + + ch := make(chan *tailcfg.MapResponse, SMALL_BUFFER_SIZE) + + churningChannelsMutex.Lock() + churningChannels[nodeID] = ch + churningChannelsMutex.Unlock() + + batcher.AddNode(nodeID, ch, tailcfg.CapabilityVersion(100)) + + // Consume updates to prevent blocking + go func() { + for { + select { + case data, ok := <-ch: + if !ok { + // Channel was closed, exit gracefully + return + } + if valid, _ := validateUpdateContent(data); valid { + tracker.recordUpdate( + nodeID, + 1, + ) // Use 1 as update size since we have MapResponse + } + case <-time.After(500 * time.Millisecond): + // Longer timeout to prevent premature exit during heavy load + return + } + } + }() + }(node.n.ID) + + // Disconnect churning node + go func(nodeID types.NodeID) { + defer func() { + if r := recover(); r != nil { + panicMutex.Lock() + + panicCount++ + + panicMutex.Unlock() + t.Logf("Panic in churning disconnect: %v", r) + } + + wg.Done() + }() + + for range i % 5 { + runtime.Gosched() // Introduce timing variability + } + churningChannelsMutex.Lock() + + ch, exists := churningChannels[nodeID] + + churningChannelsMutex.Unlock() + + if exists { + 
batcher.RemoveNode(nodeID, ch) + } + }(node.n.ID) + } + + // Generate various types of work during racing + if i%3 == 0 { + // DERP changes + batcher.AddWork(change.DERPMap()) + } + + if i%5 == 0 { + // Full updates using real node data + batcher.AddWork(change.FullUpdate()) + } + + if i%7 == 0 && len(allNodes) > 0 { + // Node-specific changes using real nodes + node := allNodes[i%len(allNodes)] + // Use a valid expiry time for testing since test nodes don't have expiry set + testExpiry := time.Now().Add(24 * time.Hour) + batcher.AddWork(change.KeyExpiryFor(node.n.ID, testExpiry)) + } + + // Yield to allow some batching + runtime.Gosched() + } + + wg.Wait() + }() + + // Deadlock detection + select { + case <-done: + t.Logf("Connection churn cycles completed successfully") + case <-time.After(DEADLOCK_TIMEOUT): + t.Error("Test timed out - possible deadlock detected") + return + } + + // Yield to allow any in-flight updates to complete + runtime.Gosched() + + // Validate results + panicMutex.Lock() + + finalPanicCount := panicCount + + panicMutex.Unlock() + + allStats := tracker.getAllStats() + + // Calculate expected vs actual updates + stableUpdateCount := 0 + churningUpdateCount := 0 + + // Count actual update sources to understand the pattern + // Let's track what we observe rather than trying to predict + expectedDerpUpdates := (numCycles + 2) / 3 + expectedFullUpdates := (numCycles + 4) / 5 + expectedKeyUpdates := (numCycles + 6) / 7 + totalGeneratedWork := expectedDerpUpdates + expectedFullUpdates + expectedKeyUpdates + + t.Logf("Work generated: %d DERP + %d Full + %d KeyExpiry = %d total AddWork calls", + expectedDerpUpdates, expectedFullUpdates, expectedKeyUpdates, totalGeneratedWork) + + for _, node := range stableNodes { + if stats, exists := allStats[node.n.ID]; exists { + stableUpdateCount += stats.TotalUpdates + t.Logf("Stable node %d: %d updates", + node.n.ID, stats.TotalUpdates) + } + + // Verify stable clients are still connected + if !batcher.IsConnected(node.n.ID) { + t.Errorf("Stable node %d should still be connected", node.n.ID) + } + } + + for _, node := range churningNodes { + if stats, exists := allStats[node.n.ID]; exists { + churningUpdateCount += stats.TotalUpdates + } + } + + t.Logf("Total updates - Stable clients: %d, Churning clients: %d", + stableUpdateCount, churningUpdateCount) + t.Logf( + "Average per stable client: %.1f updates", + float64(stableUpdateCount)/float64(len(stableNodes)), + ) + t.Logf("Panics during test: %d", finalPanicCount) + + // Validate test success criteria + if finalPanicCount > 0 { + t.Errorf("Test failed with %d panics", finalPanicCount) + } + + // Basic sanity check - stable clients should receive some updates + if stableUpdateCount == 0 { + t.Error("Stable clients received no updates - batcher may not be working") + } + + // Verify all stable clients are still functional + for _, node := range stableNodes { + if !batcher.IsConnected(node.n.ID) { + t.Errorf("Stable node %d lost connection during racing", node.n.ID) + } + } + }) + } +} + +// TestBatcherHighLoadStability tests batcher behavior under high concurrent load +// scenarios with multiple nodes rapidly connecting and disconnecting while +// continuous updates are generated. +// +// This test creates a high-stress environment with many nodes connecting and +// disconnecting rapidly while various types of updates are generated continuously. +// It validates that the system remains stable with no deadlocks, panics, or +// missed updates under sustained high load. 
The test uses real node data to +// generate authentic update scenarios and tracks comprehensive statistics. +func XTestBatcherScalability(t *testing.T) { + if testing.Short() { + t.Skip("Skipping scalability test in short mode") + } + + // Reduce verbose application logging for cleaner test output + originalLevel := zerolog.GlobalLevel() + defer zerolog.SetGlobalLevel(originalLevel) + + zerolog.SetGlobalLevel(zerolog.ErrorLevel) + + // Full test matrix for scalability testing + nodes := []int{25, 50, 100} // 250, 500, 1000, + + cycles := []int{10, 100} // 500 + bufferSizes := []int{1, 200, 1000} + chaosTypes := []string{"connection", "processing", "mixed"} + + type testCase struct { + name string + nodeCount int + cycles int + bufferSize int + chaosType string + expectBreak bool + description string + } + + var testCases []testCase + + // Generate all combinations of the test matrix + for _, nodeCount := range nodes { + for _, cycleCount := range cycles { + for _, bufferSize := range bufferSizes { + for _, chaosType := range chaosTypes { + expectBreak := false + // resourceIntensity := float64(nodeCount*cycleCount) / float64(bufferSize) + + // switch chaosType { + // case "processing": + // resourceIntensity *= 1.1 + // case "mixed": + // resourceIntensity *= 1.15 + // } + + // if resourceIntensity > 500000 { + // expectBreak = true + // } else if nodeCount >= 1000 && cycleCount >= 500 && bufferSize <= 1 { + // expectBreak = true + // } else if nodeCount >= 500 && cycleCount >= 500 && bufferSize <= 1 && chaosType == "mixed" { + // expectBreak = true + // } + + name := fmt.Sprintf( + "%s_%dn_%dc_%db", + chaosType, + nodeCount, + cycleCount, + bufferSize, + ) + description := fmt.Sprintf("%s chaos: %d nodes, %d cycles, %d buffers", + chaosType, nodeCount, cycleCount, bufferSize) + + testCases = append(testCases, testCase{ + name: name, + nodeCount: nodeCount, + cycles: cycleCount, + bufferSize: bufferSize, + chaosType: chaosType, + expectBreak: expectBreak, + description: description, + }) + } + } + } + } + + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + for i, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Create comprehensive test environment with real data using the specific buffer size for this test case + // Need 1000 nodes for largest test case, all from same user so they can be peers + usersNeeded := max(1, tc.nodeCount/1000) // 1 user per 1000 nodes, minimum 1 + nodesPerUser := tc.nodeCount / usersNeeded + + testData, cleanup := setupBatcherWithTestData( + t, + batcherFunc.fn, + usersNeeded, + nodesPerUser, + tc.bufferSize, + ) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes + + t.Logf("[%d/%d] SCALABILITY TEST: %s", i+1, len(testCases), tc.description) + t.Logf( + " Cycles: %d, Buffer Size: %d, Chaos Type: %s", + tc.cycles, + tc.bufferSize, + tc.chaosType, + ) + + // Use provided nodes, limit to requested count + testNodes := allNodes[:min(len(allNodes), tc.nodeCount)] + + tracker := newUpdateTracker() + panicCount := int64(0) + deadlockDetected := false + + startTime := time.Now() + setupTime := time.Since(startTime) + t.Logf( + "Starting scalability test with %d nodes (setup took: %v)", + len(testNodes), + setupTime, + ) + + // Comprehensive stress test + done := make(chan struct{}) + + // Start update consumers for all nodes + for i := range testNodes { + testNodes[i].start() + } + + // Yield to allow tracking goroutines to start + runtime.Gosched() + + // Connect all nodes 
first so they can see each other as peers + connectedNodes := make(map[types.NodeID]bool) + + var connectedNodesMutex sync.RWMutex + + for i := range testNodes { + node := &testNodes[i] + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + connectedNodesMutex.Lock() + + connectedNodes[node.n.ID] = true + + connectedNodesMutex.Unlock() + } + + // Wait for all connections to be established + assert.EventuallyWithT(t, func(c *assert.CollectT) { + for i := range testNodes { + assert.True(c, batcher.IsConnected(testNodes[i].n.ID), "node should be connected") + } + }, 5*time.Second, 50*time.Millisecond, "waiting for nodes to connect") + + batcher.AddWork(change.FullUpdate()) + + // Wait for initial update to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + for i := range testNodes { + assert.GreaterOrEqual(c, atomic.LoadInt64(&testNodes[i].updateCount), int64(1), "should have received initial update") + } + }, 5*time.Second, 50*time.Millisecond, "waiting for initial update") + + go func() { + defer close(done) + + var wg sync.WaitGroup + + t.Logf( + "Starting load generation: %d cycles with %d nodes", + tc.cycles, + len(testNodes), + ) + + // Main load generation - varies by chaos type + for cycle := range tc.cycles { + if cycle%10 == 0 { + t.Logf("Cycle %d/%d completed", cycle, tc.cycles) + } + // Yield for mixed chaos to introduce timing variability + if tc.chaosType == "mixed" && cycle%10 == 0 { + runtime.Gosched() + } + + // For chaos testing, only disconnect/reconnect a subset of nodes + // This ensures some nodes stay connected to continue receiving updates + startIdx := cycle % len(testNodes) + + endIdx := min(startIdx+len(testNodes)/4, len(testNodes)) + + if startIdx >= endIdx { + startIdx = 0 + endIdx = min(len(testNodes)/4, len(testNodes)) + } + + chaosNodes := testNodes[startIdx:endIdx] + if len(chaosNodes) == 0 { + chaosNodes = testNodes[:min(1, len(testNodes))] // At least one node for chaos + } + + // Connection/disconnection cycles for subset of nodes + for i, node := range chaosNodes { + // Only add work if this is connection chaos or mixed + if tc.chaosType == "connection" || tc.chaosType == "mixed" { + wg.Add(2) + + // Disconnection first + go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse) { + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + + wg.Done() + }() + + connectedNodesMutex.RLock() + + isConnected := connectedNodes[nodeID] + + connectedNodesMutex.RUnlock() + + if isConnected { + batcher.RemoveNode(nodeID, channel) + connectedNodesMutex.Lock() + + connectedNodes[nodeID] = false + + connectedNodesMutex.Unlock() + } + }( + node.n.ID, + node.ch, + ) + + // Then reconnection + go func(nodeID types.NodeID, channel chan *tailcfg.MapResponse, index int) { + defer func() { + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + + wg.Done() + }() + + // Yield before reconnecting to introduce timing variability + for range index % 3 { + runtime.Gosched() + } + + _ = batcher.AddNode( + nodeID, + channel, + tailcfg.CapabilityVersion(100), + ) + connectedNodesMutex.Lock() + + connectedNodes[nodeID] = true + + connectedNodesMutex.Unlock() + + // Add work to create load + if index%5 == 0 { + batcher.AddWork(change.FullUpdate()) + } + }( + node.n.ID, + node.ch, + i, + ) + } + } + + // Concurrent work generation - scales with load + updateCount := min(tc.nodeCount/5, 20) // Scale updates with node count + for i := range updateCount { + wg.Add(1) + + go func(index int) { + defer func() 
{ + if r := recover(); r != nil { + atomic.AddInt64(&panicCount, 1) + } + + wg.Done() + }() + + // Generate different types of work to ensure updates are sent + switch index % 4 { + case 0: + batcher.AddWork(change.FullUpdate()) + case 1: + batcher.AddWork(change.PolicyChange()) + case 2: + batcher.AddWork(change.DERPMap()) + default: + // Pick a random node and generate a node change + if len(testNodes) > 0 { + nodeIdx := index % len(testNodes) + batcher.AddWork( + change.NodeAdded(testNodes[nodeIdx].n.ID), + ) + } else { + batcher.AddWork(change.FullUpdate()) + } + } + }(i) + } + } + + t.Logf("Waiting for all goroutines to complete") + wg.Wait() + t.Logf("All goroutines completed") + }() + + // Wait for completion with timeout and progress monitoring + progressTicker := time.NewTicker(10 * time.Second) + defer progressTicker.Stop() + + select { + case <-done: + t.Logf("Test completed successfully") + case <-time.After(TEST_TIMEOUT): + deadlockDetected = true + // Collect diagnostic information + allStats := tracker.getAllStats() + + totalUpdates := 0 + for _, stats := range allStats { + totalUpdates += stats.TotalUpdates + } + + interimPanics := atomic.LoadInt64(&panicCount) + + t.Logf("TIMEOUT DIAGNOSIS: Test timed out after %v", TEST_TIMEOUT) + t.Logf( + " Progress at timeout: %d total updates, %d panics", + totalUpdates, + interimPanics, + ) + t.Logf( + " Possible causes: deadlock, excessive load, or performance bottleneck", + ) + + // Try to detect if workers are still active + if totalUpdates > 0 { + t.Logf( + " System was processing updates - likely performance bottleneck", + ) + } else { + t.Logf(" No updates processed - likely deadlock or startup issue") + } + } + + // Wait for batcher workers to process all work and send updates + // before disconnecting nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Check that at least some updates were processed + var totalUpdates int64 + for i := range testNodes { + totalUpdates += atomic.LoadInt64(&testNodes[i].updateCount) + } + + assert.Positive(c, totalUpdates, "should have processed some updates") + }, 5*time.Second, 50*time.Millisecond, "waiting for updates to be processed") + + // Now disconnect all nodes from batcher to stop new updates + for i := range testNodes { + node := &testNodes[i] + batcher.RemoveNode(node.n.ID, node.ch) + } + + // Wait for nodes to be disconnected + assert.EventuallyWithT(t, func(c *assert.CollectT) { + for i := range testNodes { + assert.False(c, batcher.IsConnected(testNodes[i].n.ID), "node should be disconnected") + } + }, 5*time.Second, 50*time.Millisecond, "waiting for nodes to disconnect") + + // Cleanup nodes and get their final stats + totalUpdates := int64(0) + totalPatches := int64(0) + totalFull := int64(0) + maxPeersGlobal := 0 + nodeStatsReport := make([]string, 0, len(testNodes)) + + for i := range testNodes { + node := &testNodes[i] + stats := node.cleanup() + totalUpdates += stats.TotalUpdates + totalPatches += stats.PatchUpdates + + totalFull += stats.FullUpdates + if stats.MaxPeersSeen > maxPeersGlobal { + maxPeersGlobal = stats.MaxPeersSeen + } + + if stats.TotalUpdates > 0 { + nodeStatsReport = append(nodeStatsReport, + fmt.Sprintf( + "Node %d: %d total (%d patch, %d full), max %d peers", + node.n.ID, + stats.TotalUpdates, + stats.PatchUpdates, + stats.FullUpdates, + stats.MaxPeersSeen, + )) + } + } + + // Comprehensive final summary + t.Logf( + "FINAL RESULTS: %d total updates (%d patch, %d full), max peers seen: %d", + totalUpdates, + totalPatches, + totalFull, + 
maxPeersGlobal, + ) + + if len(nodeStatsReport) <= 10 { // Only log details for smaller tests + for _, report := range nodeStatsReport { + t.Logf(" %s", report) + } + } else { + t.Logf(" (%d nodes had activity, details suppressed for large test)", len(nodeStatsReport)) + } + + // Legacy tracker comparison (optional) + allStats := tracker.getAllStats() + + legacyTotalUpdates := 0 + for _, stats := range allStats { + legacyTotalUpdates += stats.TotalUpdates + } + + if legacyTotalUpdates != int(totalUpdates) { + t.Logf( + "Note: Legacy tracker mismatch - legacy: %d, new: %d", + legacyTotalUpdates, + totalUpdates, + ) + } + + finalPanicCount := atomic.LoadInt64(&panicCount) + + // Validation based on expectation + testPassed := true + + if tc.expectBreak { + // For tests expected to break, we're mainly checking that we don't crash + if finalPanicCount > 0 { + t.Errorf( + "System crashed with %d panics (even breaking point tests shouldn't crash)", + finalPanicCount, + ) + + testPassed = false + } + // Timeout/deadlock is acceptable for breaking point tests + if deadlockDetected { + t.Logf( + "Expected breaking point reached: system overloaded at %d nodes", + len(testNodes), + ) + } + } else { + // For tests expected to pass, validate proper operation + if finalPanicCount > 0 { + t.Errorf("Scalability test failed with %d panics", finalPanicCount) + + testPassed = false + } + + if deadlockDetected { + t.Errorf("Deadlock detected at %d nodes (should handle this load)", len(testNodes)) + + testPassed = false + } + + if totalUpdates == 0 { + t.Error("No updates received - system may be completely stalled") + + testPassed = false + } + } + + // Clear success/failure indication + if testPassed { + t.Logf("✅ PASS: %s | %d nodes, %d updates, 0 panics, no deadlock", + tc.name, len(testNodes), totalUpdates) + } else { + t.Logf("❌ FAIL: %s | %d nodes, %d updates, %d panics, deadlock: %v", + tc.name, len(testNodes), totalUpdates, finalPanicCount, deadlockDetected) + } + }) + } + }) + } +} + +// TestBatcherFullPeerUpdates verifies that when multiple nodes are connected +// and we send a FullSet update, nodes receive the complete peer list. 
+func TestBatcherFullPeerUpdates(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with 3 nodes from same user (so they can be peers) + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, 10) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes + + t.Logf("Created %d nodes in database", len(allNodes)) + + // Connect nodes one at a time and wait for each to be connected + for i, node := range allNodes { + batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + t.Logf("Connected node %d (ID: %d)", i, node.n.ID) + + // Wait for node to be connected + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.True(c, batcher.IsConnected(node.n.ID), "node should be connected") + }, time.Second, 10*time.Millisecond, "waiting for node connection") + } + + // Wait for all NodeCameOnline events to be processed + t.Logf("Waiting for NodeCameOnline events to settle...") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + for i := range allNodes { + assert.True(c, batcher.IsConnected(allNodes[i].n.ID), "all nodes should be connected") + } + }, 5*time.Second, 50*time.Millisecond, "waiting for all nodes to connect") + + // Check how many peers each node should see + for i, node := range allNodes { + peers := testData.State.ListPeers(node.n.ID) + t.Logf("Node %d should see %d peers from state", i, peers.Len()) + } + + // Send a full update - this should generate full peer lists + t.Logf("Sending FullSet update...") + batcher.AddWork(change.FullUpdate()) + + // Wait for FullSet work items to be processed + t.Logf("Waiting for FullSet to be processed...") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Check that some data is available in at least one channel + found := false + + for i := range allNodes { + if len(allNodes[i].ch) > 0 { + found = true + break + } + } + + assert.True(c, found, "no updates received yet") + }, 5*time.Second, 50*time.Millisecond, "waiting for FullSet updates") + + // Check what each node receives - read multiple updates + totalUpdates := 0 + foundFullUpdate := false + + // Read all available updates for each node + for i := range allNodes { + nodeUpdates := 0 + + t.Logf("Reading updates for node %d:", i) + + // Read up to 10 updates per node or until timeout/no more data + for updateNum := range 10 { + select { + case data := <-allNodes[i].ch: + nodeUpdates++ + totalUpdates++ + + // Parse and examine the update - data is already a MapResponse + if data == nil { + t.Errorf("Node %d update %d: nil MapResponse", i, updateNum) + continue + } + + updateType := "unknown" + if len(data.Peers) > 0 { + updateType = "FULL" + foundFullUpdate = true + } else if len(data.PeersChangedPatch) > 0 { + updateType = "PATCH" + } else if data.DERPMap != nil { + updateType = "DERP" + } + + t.Logf( + " Update %d: %s - Peers=%d, PeersChangedPatch=%d, DERPMap=%v", + updateNum, + updateType, + len(data.Peers), + len(data.PeersChangedPatch), + data.DERPMap != nil, + ) + + if len(data.Peers) > 0 { + t.Logf(" Full peer list with %d peers", len(data.Peers)) + + for j, peer := range data.Peers[:min(3, len(data.Peers))] { + t.Logf( + " Peer %d: NodeID=%d, Online=%v", + j, + peer.ID, + peer.Online, + ) + } + } + + if len(data.PeersChangedPatch) > 0 { + t.Logf(" Patch update with %d changes", len(data.PeersChangedPatch)) + + for j, patch := range data.PeersChangedPatch[:min(3, len(data.PeersChangedPatch))] { + t.Logf( + " Patch %d: 
NodeID=%d, Online=%v", + j, + patch.NodeID, + patch.Online, + ) + } + } + + case <-time.After(500 * time.Millisecond): + } + } + + t.Logf("Node %d received %d updates", i, nodeUpdates) + } + + t.Logf("Total updates received across all nodes: %d", totalUpdates) + + if !foundFullUpdate { + t.Errorf("CRITICAL: No FULL updates received despite sending change.FullUpdateResponse()!") + t.Errorf( + "This confirms the bug - FullSet updates are not generating full peer responses", + ) + } + }) + } +} + +// TestBatcherRapidReconnection reproduces the issue where nodes connecting with the same ID +// at the same time cause /debug/batcher to show nodes as disconnected when they should be connected. +// This specifically tests the multi-channel batcher implementation issue. +func TestBatcherRapidReconnection(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, 10) + defer cleanup() + + batcher := testData.Batcher + allNodes := testData.Nodes + + t.Logf("=== RAPID RECONNECTION TEST ===") + t.Logf("Testing rapid connect/disconnect with %d nodes", len(allNodes)) + + // Phase 1: Connect all nodes initially + t.Logf("Phase 1: Connecting all nodes...") + for i, node := range allNodes { + err := batcher.AddNode(node.n.ID, node.ch, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add node %d: %v", i, err) + } + } + + // Wait for all connections to settle + assert.EventuallyWithT(t, func(c *assert.CollectT) { + for i := range allNodes { + assert.True(c, batcher.IsConnected(allNodes[i].n.ID), "node should be connected") + } + }, 5*time.Second, 50*time.Millisecond, "waiting for connections to settle") + + // Phase 2: Rapid disconnect ALL nodes (simulating nodes going down) + t.Logf("Phase 2: Rapid disconnect all nodes...") + for i, node := range allNodes { + removed := batcher.RemoveNode(node.n.ID, node.ch) + t.Logf("Node %d RemoveNode result: %t", i, removed) + } + + // Phase 3: Rapid reconnect with NEW channels (simulating nodes coming back up) + t.Logf("Phase 3: Rapid reconnect with new channels...") + newChannels := make([]chan *tailcfg.MapResponse, len(allNodes)) + for i, node := range allNodes { + newChannels[i] = make(chan *tailcfg.MapResponse, 10) + err := batcher.AddNode(node.n.ID, newChannels[i], tailcfg.CapabilityVersion(100)) + if err != nil { + t.Errorf("Failed to reconnect node %d: %v", i, err) + } + } + + // Wait for all reconnections to settle + assert.EventuallyWithT(t, func(c *assert.CollectT) { + for i := range allNodes { + assert.True(c, batcher.IsConnected(allNodes[i].n.ID), "node should be reconnected") + } + }, 5*time.Second, 50*time.Millisecond, "waiting for reconnections to settle") + + // Phase 4: Check debug status - THIS IS WHERE THE BUG SHOULD APPEAR + t.Logf("Phase 4: Checking debug status...") + + if debugBatcher, ok := batcher.(interface { + Debug() map[types.NodeID]any + }); ok { + debugInfo := debugBatcher.Debug() + disconnectedCount := 0 + + for i, node := range allNodes { + if info, exists := debugInfo[node.n.ID]; exists { + t.Logf("Node %d (ID %d): debug info = %+v", i, node.n.ID, info) + + // Check if the debug info shows the node as connected + if infoMap, ok := info.(map[string]any); ok { + if connected, ok := infoMap["connected"].(bool); ok && !connected { + disconnectedCount++ + t.Logf("BUG REPRODUCED: Node %d shows as disconnected in debug but should be connected", i) + } + } + } else { + disconnectedCount++ + 
t.Logf("Node %d missing from debug info entirely", i) + } + + // Also check IsConnected method + if !batcher.IsConnected(node.n.ID) { + t.Logf("Node %d IsConnected() returns false", i) + } + } + + if disconnectedCount > 0 { + t.Logf("ISSUE REPRODUCED: %d/%d nodes show as disconnected in debug", disconnectedCount, len(allNodes)) + // This is expected behavior for multi-channel batcher according to user + // "it has never worked with the multi" + } else { + t.Logf("All nodes show as connected - working correctly") + } + } else { + t.Logf("Batcher does not implement Debug() method") + } + + // Phase 5: Test if "disconnected" nodes can actually receive updates + t.Logf("Phase 5: Testing if nodes can receive updates despite debug status...") + + // Send a change that should reach all nodes + batcher.AddWork(change.DERPMap()) + + receivedCount := 0 + timeout := time.After(500 * time.Millisecond) + + for i := range allNodes { + select { + case update := <-newChannels[i]: + if update != nil { + receivedCount++ + t.Logf("Node %d received update successfully", i) + } + case <-timeout: + t.Logf("Node %d timed out waiting for update", i) + goto done + } + } + + done: + t.Logf("Update delivery test: %d/%d nodes received updates", receivedCount, len(allNodes)) + + if receivedCount < len(allNodes) { + t.Logf("Some nodes failed to receive updates - confirming the issue") + } + }) + } +} + +func TestBatcherMultiConnection(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 2, 10) + defer cleanup() + + batcher := testData.Batcher + node1 := testData.Nodes[0] + node2 := testData.Nodes[1] + + t.Logf("=== MULTI-CONNECTION TEST ===") + + // Phase 1: Connect first node with initial connection + t.Logf("Phase 1: Connecting node 1 with first connection...") + err := batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add node1: %v", err) + } + + // Connect second node for comparison + err = batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add node2: %v", err) + } + + // Wait for initial connections + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.True(c, batcher.IsConnected(node1.n.ID), "node1 should be connected") + assert.True(c, batcher.IsConnected(node2.n.ID), "node2 should be connected") + }, time.Second, 10*time.Millisecond, "waiting for initial connections") + + // Phase 2: Add second connection for node1 (multi-connection scenario) + t.Logf("Phase 2: Adding second connection for node 1...") + secondChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, secondChannel, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add second connection for node1: %v", err) + } + + // Yield to allow connection to be processed + runtime.Gosched() + + // Phase 3: Add third connection for node1 + t.Logf("Phase 3: Adding third connection for node 1...") + thirdChannel := make(chan *tailcfg.MapResponse, 10) + err = batcher.AddNode(node1.n.ID, thirdChannel, tailcfg.CapabilityVersion(100)) + if err != nil { + t.Fatalf("Failed to add third connection for node1: %v", err) + } + + // Yield to allow connection to be processed + runtime.Gosched() + + // Phase 4: Verify debug status shows correct connection count + t.Logf("Phase 4: Verifying debug status shows multiple connections...") + if debugBatcher, ok := 
batcher.(interface { + Debug() map[types.NodeID]any + }); ok { + debugInfo := debugBatcher.Debug() + + if info, exists := debugInfo[node1.n.ID]; exists { + t.Logf("Node1 debug info: %+v", info) + if infoMap, ok := info.(map[string]any); ok { + if activeConnections, ok := infoMap["active_connections"].(int); ok { + if activeConnections != 3 { + t.Errorf("Node1 should have 3 active connections, got %d", activeConnections) + } else { + t.Logf("SUCCESS: Node1 correctly shows 3 active connections") + } + } + if connected, ok := infoMap["connected"].(bool); ok && !connected { + t.Errorf("Node1 should show as connected with 3 active connections") + } + } + } + + if info, exists := debugInfo[node2.n.ID]; exists { + if infoMap, ok := info.(map[string]any); ok { + if activeConnections, ok := infoMap["active_connections"].(int); ok { + if activeConnections != 1 { + t.Errorf("Node2 should have 1 active connection, got %d", activeConnections) + } + } + } + } + } + + // Phase 5: Send update and verify ALL connections receive it + t.Logf("Phase 5: Testing update distribution to all connections...") + + // Clear any existing updates from all channels + clearChannel := func(ch chan *tailcfg.MapResponse) { + for { + select { + case <-ch: + // drain + default: + return + } + } + } + + clearChannel(node1.ch) + clearChannel(secondChannel) + clearChannel(thirdChannel) + clearChannel(node2.ch) + + // Send a change notification from node2 (so node1 should receive it on all connections) + testChangeSet := change.NodeAdded(node2.n.ID) + + batcher.AddWork(testChangeSet) + + // Wait for updates to propagate to at least one channel + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Positive(c, len(node1.ch)+len(secondChannel)+len(thirdChannel), "should have received updates") + }, 5*time.Second, 50*time.Millisecond, "waiting for updates to propagate") + + // Verify all three connections for node1 receive the update + connection1Received := false + connection2Received := false + connection3Received := false + + select { + case mapResp := <-node1.ch: + connection1Received = (mapResp != nil) + t.Logf("Node1 connection 1 received update: %t", connection1Received) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 1 did not receive update") + } + + select { + case mapResp := <-secondChannel: + connection2Received = (mapResp != nil) + t.Logf("Node1 connection 2 received update: %t", connection2Received) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 2 did not receive update") + } + + select { + case mapResp := <-thirdChannel: + connection3Received = (mapResp != nil) + t.Logf("Node1 connection 3 received update: %t", connection3Received) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 3 did not receive update") + } + + if connection1Received && connection2Received && connection3Received { + t.Logf("SUCCESS: All three connections for node1 received the update") + } else { + t.Errorf("FAILURE: Multi-connection broadcast failed - conn1: %t, conn2: %t, conn3: %t", + connection1Received, connection2Received, connection3Received) + } + + // Phase 6: Test connection removal and verify remaining connections still work + t.Logf("Phase 6: Testing connection removal...") + + // Remove the second connection + removed := batcher.RemoveNode(node1.n.ID, secondChannel) + if !removed { + t.Errorf("Failed to remove second connection for node1") + } + + // Yield to allow removal to be processed + runtime.Gosched() + + // Verify debug status shows 2 
connections now + if debugBatcher, ok := batcher.(interface { + Debug() map[types.NodeID]any + }); ok { + debugInfo := debugBatcher.Debug() + if info, exists := debugInfo[node1.n.ID]; exists { + if infoMap, ok := info.(map[string]any); ok { + if activeConnections, ok := infoMap["active_connections"].(int); ok { + if activeConnections != 2 { + t.Errorf("Node1 should have 2 active connections after removal, got %d", activeConnections) + } else { + t.Logf("SUCCESS: Node1 correctly shows 2 active connections after removal") + } + } + } + } + } + + // Send another update and verify remaining connections still work + clearChannel(node1.ch) + clearChannel(thirdChannel) + + testChangeSet2 := change.NodeAdded(node2.n.ID) + + batcher.AddWork(testChangeSet2) + + // Wait for updates to propagate to remaining channels + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.Positive(c, len(node1.ch)+len(thirdChannel), "should have received updates") + }, 5*time.Second, 50*time.Millisecond, "waiting for updates to propagate") + + // Verify remaining connections still receive updates + remaining1Received := false + remaining3Received := false + + select { + case mapResp := <-node1.ch: + remaining1Received = (mapResp != nil) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 1 did not receive update after removal") + } + + select { + case mapResp := <-thirdChannel: + remaining3Received = (mapResp != nil) + case <-time.After(500 * time.Millisecond): + t.Errorf("Node1 connection 3 did not receive update after removal") + } + + if remaining1Received && remaining3Received { + t.Logf("SUCCESS: Remaining connections still receive updates after removal") + } else { + t.Errorf("FAILURE: Remaining connections failed to receive updates - conn1: %t, conn3: %t", + remaining1Received, remaining3Received) + } + + // Drain secondChannel of any messages received before removal + // (the test wrapper sends NodeOffline before removal, which may have reached this channel) + clearChannel(secondChannel) + + // Verify second channel no longer receives new updates after being removed + select { + case <-secondChannel: + t.Errorf("Removed connection still received update - this should not happen") + case <-time.After(100 * time.Millisecond): + t.Logf("SUCCESS: Removed connection correctly no longer receives updates") + } + }) + } +} + +// TestNodeDeletedWhileChangesPending reproduces issue #2924 where deleting a node +// from state while there are pending changes for that node in the batcher causes +// "node not found" errors. The race condition occurs when: +// 1. Node is connected and changes are queued for it +// 2. Node is deleted from state (NodeStore) but not from batcher +// 3. Batcher worker tries to generate map response for deleted node +// 4. Mapper fails to find node in state, causing repeated "node not found" errors. 
+func TestNodeDeletedWhileChangesPending(t *testing.T) { + for _, batcherFunc := range allBatcherFunctions { + t.Run(batcherFunc.name, func(t *testing.T) { + // Create test environment with 3 nodes + testData, cleanup := setupBatcherWithTestData(t, batcherFunc.fn, 1, 3, NORMAL_BUFFER_SIZE) + defer cleanup() + + batcher := testData.Batcher + st := testData.State + node1 := &testData.Nodes[0] + node2 := &testData.Nodes[1] + node3 := &testData.Nodes[2] + + t.Logf("Testing issue #2924: Node1=%d, Node2=%d, Node3=%d", + node1.n.ID, node2.n.ID, node3.n.ID) + + // Helper to drain channels + drainCh := func(ch chan *tailcfg.MapResponse) { + for { + select { + case <-ch: + // drain + default: + return + } + } + } + + // Start update consumers for all nodes + node1.start() + node2.start() + node3.start() + + defer node1.cleanup() + defer node2.cleanup() + defer node3.cleanup() + + // Connect all nodes to the batcher + require.NoError(t, batcher.AddNode(node1.n.ID, node1.ch, tailcfg.CapabilityVersion(100))) + require.NoError(t, batcher.AddNode(node2.n.ID, node2.ch, tailcfg.CapabilityVersion(100))) + require.NoError(t, batcher.AddNode(node3.n.ID, node3.ch, tailcfg.CapabilityVersion(100))) + + // Wait for all nodes to be connected + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.True(c, batcher.IsConnected(node1.n.ID), "node1 should be connected") + assert.True(c, batcher.IsConnected(node2.n.ID), "node2 should be connected") + assert.True(c, batcher.IsConnected(node3.n.ID), "node3 should be connected") + }, 5*time.Second, 50*time.Millisecond, "waiting for nodes to connect") + + // Get initial work errors count + var initialWorkErrors int64 + if lfb, ok := unwrapBatcher(batcher).(*LockFreeBatcher); ok { + initialWorkErrors = lfb.WorkErrors() + t.Logf("Initial work errors: %d", initialWorkErrors) + } + + // Clear channels to prepare for the test + drainCh(node1.ch) + drainCh(node2.ch) + drainCh(node3.ch) + + // Get node view for deletion + nodeToDelete, ok := st.GetNodeByID(node3.n.ID) + require.True(t, ok, "node3 should exist in state") + + // Delete the node from state - this returns a NodeRemoved change + // In production, this change is sent to batcher via app.Change() + nodeChange, err := st.DeleteNode(nodeToDelete) + require.NoError(t, err, "should be able to delete node from state") + t.Logf("Deleted node %d from state, change: %s", node3.n.ID, nodeChange.Reason) + + // Verify node is deleted from state + _, exists := st.GetNodeByID(node3.n.ID) + require.False(t, exists, "node3 should be deleted from state") + + // Send the NodeRemoved change to batcher (this is what app.Change() does) + // With the fix, this should clean up node3 from batcher's internal state + batcher.AddWork(nodeChange) + + // Wait for the batcher to process the removal and clean up the node + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assert.False(c, batcher.IsConnected(node3.n.ID), "node3 should be disconnected from batcher") + }, 5*time.Second, 50*time.Millisecond, "waiting for node removal to be processed") + + t.Logf("Node %d connected in batcher after NodeRemoved: %v", node3.n.ID, batcher.IsConnected(node3.n.ID)) + + // Now queue changes that would have caused errors before the fix + // With the fix, these should NOT cause "node not found" errors + // because node3 was cleaned up when NodeRemoved was processed + batcher.AddWork(change.FullUpdate()) + batcher.AddWork(change.PolicyChange()) + + // Wait for work to be processed and verify no errors occurred + // With the fix, no new errors should 
occur because the deleted node + // was cleaned up from batcher state when NodeRemoved was processed + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var finalWorkErrors int64 + if lfb, ok := unwrapBatcher(batcher).(*LockFreeBatcher); ok { + finalWorkErrors = lfb.WorkErrors() + } + + newErrors := finalWorkErrors - initialWorkErrors + assert.Zero(c, newErrors, "Fix for #2924: should have no work errors after node deletion") + }, 5*time.Second, 100*time.Millisecond, "waiting for work processing to complete without errors") + + // Verify remaining nodes still work correctly + drainCh(node1.ch) + drainCh(node2.ch) + batcher.AddWork(change.NodeAdded(node1.n.ID)) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Node 1 and 2 should receive updates + stats1 := NodeStats{TotalUpdates: atomic.LoadInt64(&node1.updateCount)} + stats2 := NodeStats{TotalUpdates: atomic.LoadInt64(&node2.updateCount)} + assert.Positive(c, stats1.TotalUpdates, "node1 should have received updates") + assert.Positive(c, stats2.TotalUpdates, "node2 should have received updates") + }, 5*time.Second, 100*time.Millisecond, "waiting for remaining nodes to receive updates") + }) + } +} + +// unwrapBatcher extracts the underlying batcher from wrapper types. +func unwrapBatcher(b Batcher) Batcher { + if wrapper, ok := b.(*testBatcherWrapper); ok { + return unwrapBatcher(wrapper.Batcher) + } + + return b +} diff --git a/hscontrol/mapper/builder.go b/hscontrol/mapper/builder.go new file mode 100644 index 00000000..c666ff24 --- /dev/null +++ b/hscontrol/mapper/builder.go @@ -0,0 +1,298 @@ +package mapper + +import ( + "errors" + "net/netip" + "sort" + "time" + + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" + "tailscale.com/types/views" + "tailscale.com/util/multierr" +) + +// MapResponseBuilder provides a fluent interface for building tailcfg.MapResponse. +type MapResponseBuilder struct { + resp *tailcfg.MapResponse + mapper *mapper + nodeID types.NodeID + capVer tailcfg.CapabilityVersion + errs []error + + debugType debugType +} + +type debugType string + +const ( + fullResponseDebug debugType = "full" + selfResponseDebug debugType = "self" + changeResponseDebug debugType = "change" + policyResponseDebug debugType = "policy" +) + +// NewMapResponseBuilder creates a new builder with basic fields set. +func (m *mapper) NewMapResponseBuilder(nodeID types.NodeID) *MapResponseBuilder { + now := time.Now() + return &MapResponseBuilder{ + resp: &tailcfg.MapResponse{ + KeepAlive: false, + ControlTime: &now, + }, + mapper: m, + nodeID: nodeID, + errs: nil, + } +} + +// addError adds an error to the builder's error list. +func (b *MapResponseBuilder) addError(err error) { + if err != nil { + b.errs = append(b.errs, err) + } +} + +// hasErrors returns true if the builder has accumulated any errors. +func (b *MapResponseBuilder) hasErrors() bool { + return len(b.errs) > 0 +} + +// WithCapabilityVersion sets the capability version for the response. +func (b *MapResponseBuilder) WithCapabilityVersion(capVer tailcfg.CapabilityVersion) *MapResponseBuilder { + b.capVer = capVer + return b +} + +// WithSelfNode adds the requesting node to the response. 
+func (b *MapResponseBuilder) WithSelfNode() *MapResponseBuilder { + nv, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) + return b + } + + _, matchers := b.mapper.state.Filter() + + tailnode, err := nv.TailNode( + b.capVer, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(nv, b.mapper.state.GetNodePrimaryRoutes(id), matchers) + }, + b.mapper.cfg) + if err != nil { + b.addError(err) + return b + } + + b.resp.Node = tailnode + + return b +} + +func (b *MapResponseBuilder) WithDebugType(t debugType) *MapResponseBuilder { + if debugDumpMapResponsePath != "" { + b.debugType = t + } + + return b +} + +// WithDERPMap adds the DERP map to the response. +func (b *MapResponseBuilder) WithDERPMap() *MapResponseBuilder { + b.resp.DERPMap = b.mapper.state.DERPMap().AsStruct() + return b +} + +// WithDomain adds the domain configuration. +func (b *MapResponseBuilder) WithDomain() *MapResponseBuilder { + b.resp.Domain = b.mapper.cfg.Domain() + return b +} + +// WithCollectServicesDisabled sets the collect services flag to false. +func (b *MapResponseBuilder) WithCollectServicesDisabled() *MapResponseBuilder { + b.resp.CollectServices.Set(false) + return b +} + +// WithDebugConfig adds debug configuration +// It disables log tailing if the mapper's LogTail is not enabled. +func (b *MapResponseBuilder) WithDebugConfig() *MapResponseBuilder { + b.resp.Debug = &tailcfg.Debug{ + DisableLogTail: !b.mapper.cfg.LogTail.Enabled, + } + return b +} + +// WithSSHPolicy adds SSH policy configuration for the requesting node. +func (b *MapResponseBuilder) WithSSHPolicy() *MapResponseBuilder { + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) + return b + } + + sshPolicy, err := b.mapper.state.SSHPolicy(node) + if err != nil { + b.addError(err) + return b + } + + b.resp.SSHPolicy = sshPolicy + + return b +} + +// WithDNSConfig adds DNS configuration for the requesting node. +func (b *MapResponseBuilder) WithDNSConfig() *MapResponseBuilder { + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) + return b + } + + b.resp.DNSConfig = generateDNSConfig(b.mapper.cfg, node) + + return b +} + +// WithUserProfiles adds user profiles for the requesting node and given peers. +func (b *MapResponseBuilder) WithUserProfiles(peers views.Slice[types.NodeView]) *MapResponseBuilder { + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) + return b + } + + b.resp.UserProfiles = generateUserProfiles(node, peers) + + return b +} + +// WithPacketFilters adds packet filter rules based on policy. +func (b *MapResponseBuilder) WithPacketFilters() *MapResponseBuilder { + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + b.addError(errors.New("node not found")) + return b + } + + // FilterForNode returns rules already reduced to only those relevant for this node. + // For autogroup:self policies, it returns per-node compiled rules. + // For global policies, it returns the global filter reduced for this node. 
+ filter, err := b.mapper.state.FilterForNode(node) + if err != nil { + b.addError(err) + return b + } + + // CapVer 81: 2023-11-17: MapResponse.PacketFilters (incremental packet filter updates) + // Currently, we do not send incremental package filters, however using the + // new PacketFilters field and "base" allows us to send a full update when we + // have to send an empty list, avoiding the hack in the else block. + b.resp.PacketFilters = map[string][]tailcfg.FilterRule{ + "base": filter, + } + + return b +} + +// WithPeers adds full peer list with policy filtering (for full map response). +func (b *MapResponseBuilder) WithPeers(peers views.Slice[types.NodeView]) *MapResponseBuilder { + tailPeers, err := b.buildTailPeers(peers) + if err != nil { + b.addError(err) + return b + } + + b.resp.Peers = tailPeers + + return b +} + +// WithPeerChanges adds changed peers with policy filtering (for incremental updates). +func (b *MapResponseBuilder) WithPeerChanges(peers views.Slice[types.NodeView]) *MapResponseBuilder { + tailPeers, err := b.buildTailPeers(peers) + if err != nil { + b.addError(err) + return b + } + + b.resp.PeersChanged = tailPeers + + return b +} + +// buildTailPeers converts views.Slice[types.NodeView] to []tailcfg.Node with policy filtering and sorting. +func (b *MapResponseBuilder) buildTailPeers(peers views.Slice[types.NodeView]) ([]*tailcfg.Node, error) { + node, ok := b.mapper.state.GetNodeByID(b.nodeID) + if !ok { + return nil, errors.New("node not found") + } + + // Get unreduced matchers for peer relationship determination. + // MatchersForNode returns unreduced matchers that include all rules where the node + // could be either source or destination. This is different from FilterForNode which + // returns reduced rules for packet filtering (only rules where node is destination). + matchers, err := b.mapper.state.MatchersForNode(node) + if err != nil { + return nil, err + } + + // If there are filter rules present, see if there are any nodes that cannot + // access each-other at all and remove them from the peers. + var changedViews views.Slice[types.NodeView] + if len(matchers) > 0 { + changedViews = policy.ReduceNodes(node, peers, matchers) + } else { + changedViews = peers + } + + tailPeers, err := types.TailNodes( + changedViews, b.capVer, + func(id types.NodeID) []netip.Prefix { + return policy.ReduceRoutes(node, b.mapper.state.GetNodePrimaryRoutes(id), matchers) + }, + b.mapper.cfg) + if err != nil { + return nil, err + } + + // Peers is always returned sorted by Node.ID. + sort.SliceStable(tailPeers, func(x, y int) bool { + return tailPeers[x].ID < tailPeers[y].ID + }) + + return tailPeers, nil +} + +// WithPeerChangedPatch adds peer change patches. +func (b *MapResponseBuilder) WithPeerChangedPatch(changes []*tailcfg.PeerChange) *MapResponseBuilder { + b.resp.PeersChangedPatch = changes + return b +} + +// WithPeersRemoved adds removed peer IDs. +func (b *MapResponseBuilder) WithPeersRemoved(removedIDs ...types.NodeID) *MapResponseBuilder { + var tailscaleIDs []tailcfg.NodeID + for _, id := range removedIDs { + tailscaleIDs = append(tailscaleIDs, id.NodeID()) + } + b.resp.PeersRemoved = tailscaleIDs + + return b +} + +// Build finalizes the response and returns marshaled bytes +func (b *MapResponseBuilder) Build() (*tailcfg.MapResponse, error) { + if len(b.errs) > 0 { + return nil, multierr.New(b.errs...) 
+ } + if debugDumpMapResponsePath != "" { + writeDebugMapResponse(b.resp, b.debugType, b.nodeID) + } + + return b.resp, nil +} diff --git a/hscontrol/mapper/builder_test.go b/hscontrol/mapper/builder_test.go new file mode 100644 index 00000000..978b2c0e --- /dev/null +++ b/hscontrol/mapper/builder_test.go @@ -0,0 +1,347 @@ +package mapper + +import ( + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/state" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" +) + +func TestMapResponseBuilder_Basic(t *testing.T) { + cfg := &types.Config{ + BaseDomain: "example.com", + LogTail: types.LogTailConfig{ + Enabled: true, + }, + } + + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID) + + // Test basic builder creation + assert.NotNil(t, builder) + assert.Equal(t, nodeID, builder.nodeID) + assert.NotNil(t, builder.resp) + assert.False(t, builder.resp.KeepAlive) + assert.NotNil(t, builder.resp.ControlTime) + assert.WithinDuration(t, time.Now(), *builder.resp.ControlTime, time.Second) +} + +func TestMapResponseBuilder_WithCapabilityVersion(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + capVer := tailcfg.CapabilityVersion(42) + + builder := m.NewMapResponseBuilder(nodeID). + WithCapabilityVersion(capVer) + + assert.Equal(t, capVer, builder.capVer) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_WithDomain(t *testing.T) { + domain := "test.example.com" + cfg := &types.Config{ + ServerURL: "https://test.example.com", + BaseDomain: domain, + } + + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). + WithDomain() + + assert.Equal(t, domain, builder.resp.Domain) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_WithCollectServicesDisabled(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). + WithCollectServicesDisabled() + + value, isSet := builder.resp.CollectServices.Get() + assert.True(t, isSet) + assert.False(t, value) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_WithDebugConfig(t *testing.T) { + tests := []struct { + name string + logTailEnabled bool + expected bool + }{ + { + name: "LogTail enabled", + logTailEnabled: true, + expected: false, // DisableLogTail should be false when LogTail is enabled + }, + { + name: "LogTail disabled", + logTailEnabled: false, + expected: true, // DisableLogTail should be true when LogTail is disabled + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := &types.Config{ + LogTail: types.LogTailConfig{ + Enabled: tt.logTailEnabled, + }, + } + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). 
+ WithDebugConfig() + + require.NotNil(t, builder.resp.Debug) + assert.Equal(t, tt.expected, builder.resp.Debug.DisableLogTail) + assert.False(t, builder.hasErrors()) + }) + } +} + +func TestMapResponseBuilder_WithPeerChangedPatch(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + changes := []*tailcfg.PeerChange{ + { + NodeID: 123, + DERPRegion: 1, + }, + { + NodeID: 456, + DERPRegion: 2, + }, + } + + builder := m.NewMapResponseBuilder(nodeID). + WithPeerChangedPatch(changes) + + assert.Equal(t, changes, builder.resp.PeersChangedPatch) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_WithPeersRemoved(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + removedID1 := types.NodeID(123) + removedID2 := types.NodeID(456) + + builder := m.NewMapResponseBuilder(nodeID). + WithPeersRemoved(removedID1, removedID2) + + expected := []tailcfg.NodeID{ + removedID1.NodeID(), + removedID2.NodeID(), + } + assert.Equal(t, expected, builder.resp.PeersRemoved) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_ErrorHandling(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + // Simulate an error in the builder + builder := m.NewMapResponseBuilder(nodeID) + builder.addError(assert.AnError) + + // All subsequent calls should continue to work and accumulate errors + result := builder. + WithDomain(). + WithCollectServicesDisabled(). + WithDebugConfig() + + assert.True(t, result.hasErrors()) + assert.Len(t, result.errs, 1) + assert.Equal(t, assert.AnError, result.errs[0]) + + // Build should return the error + data, err := result.Build() + assert.Nil(t, data) + assert.Error(t, err) +} + +func TestMapResponseBuilder_ChainedCalls(t *testing.T) { + domain := "chained.example.com" + cfg := &types.Config{ + ServerURL: "https://chained.example.com", + BaseDomain: domain, + LogTail: types.LogTailConfig{ + Enabled: false, + }, + } + + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + capVer := tailcfg.CapabilityVersion(99) + + builder := m.NewMapResponseBuilder(nodeID). + WithCapabilityVersion(capVer). + WithDomain(). + WithCollectServicesDisabled(). + WithDebugConfig() + + // Verify all fields are set correctly + assert.Equal(t, capVer, builder.capVer) + assert.Equal(t, domain, builder.resp.Domain) + value, isSet := builder.resp.CollectServices.Get() + assert.True(t, isSet) + assert.False(t, value) + assert.NotNil(t, builder.resp.Debug) + assert.True(t, builder.resp.Debug.DisableLogTail) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_MultipleWithPeersRemoved(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + removedID1 := types.NodeID(100) + removedID2 := types.NodeID(200) + + // Test calling WithPeersRemoved multiple times + builder := m.NewMapResponseBuilder(nodeID). + WithPeersRemoved(removedID1). 
+ WithPeersRemoved(removedID2) + + // Second call should overwrite the first + expected := []tailcfg.NodeID{removedID2.NodeID()} + assert.Equal(t, expected, builder.resp.PeersRemoved) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_EmptyPeerChangedPatch(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). + WithPeerChangedPatch([]*tailcfg.PeerChange{}) + + assert.Empty(t, builder.resp.PeersChangedPatch) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_NilPeerChangedPatch(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + builder := m.NewMapResponseBuilder(nodeID). + WithPeerChangedPatch(nil) + + assert.Nil(t, builder.resp.PeersChangedPatch) + assert.False(t, builder.hasErrors()) +} + +func TestMapResponseBuilder_MultipleErrors(t *testing.T) { + cfg := &types.Config{} + mockState := &state.State{} + m := &mapper{ + cfg: cfg, + state: mockState, + } + + nodeID := types.NodeID(1) + + // Create a builder and add multiple errors + builder := m.NewMapResponseBuilder(nodeID) + builder.addError(assert.AnError) + builder.addError(assert.AnError) + builder.addError(nil) // This should be ignored + + // All subsequent calls should continue to work + result := builder. + WithDomain(). + WithCollectServicesDisabled() + + assert.True(t, result.hasErrors()) + assert.Len(t, result.errs, 2) // nil error should be ignored + + // Build should return a multierr + data, err := result.Build() + assert.Nil(t, data) + assert.Error(t, err) + + // The error should contain information about multiple errors + assert.Contains(t, err.Error(), "multiple errors") +} diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 0a848b8d..616d470f 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -1,7 +1,6 @@ package mapper import ( - "encoding/binary" "encoding/json" "fmt" "io/fs" @@ -9,31 +8,24 @@ import ( "os" "path" "slices" - "sort" + "strconv" "strings" - "sync" - "sync/atomic" "time" - mapset "github.com/deckarep/golang-set/v2" - "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/klauspost/compress/zstd" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/rs/zerolog/log" - "github.com/samber/lo" - "golang.org/x/exp/maps" "tailscale.com/envknob" - "tailscale.com/smallzstd" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/types/views" ) const ( - nextDNSDoHPrefix = "https://dns.nextdns.io" - reservedResponseHeaderSize = 4 - mapperIDLength = 8 - debugMapResponsePerm = 0o755 + nextDNSDoHPrefix = "https://dns.nextdns.io" + mapperIDLength = 8 + debugMapResponsePerm = 0o755 ) var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH") @@ -49,24 +41,13 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_ // - Create a "minifier" that removes info not needed for the node // - some sort of batching, wait for 5 or 60 seconds before sending -type Mapper struct { +type mapper struct { // Configuration - // TODO(kradalby): figure out if this is the format we want this in - derpMap *tailcfg.DERPMap - baseDomain string - dnsCfg 
*tailcfg.DNSConfig - logtail bool - randomClientPort bool + state *state.State + cfg *types.Config + batcher Batcher - uid string created time.Time - seq uint64 - - // Map isnt concurrency safe, so we need to ensure - // only one func is accessing it over time. - mu sync.Mutex - peers map[uint64]*types.Node - patches map[uint64][]patch } type patch struct { @@ -74,107 +55,60 @@ type patch struct { change *tailcfg.PeerChange } -func NewMapper( - node *types.Node, - peers types.Nodes, - derpMap *tailcfg.DERPMap, - baseDomain string, - dnsCfg *tailcfg.DNSConfig, - logtail bool, - randomClientPort bool, -) *Mapper { - log.Debug(). - Caller(). - Str("node", node.Hostname). - Msg("creating new mapper") +func newMapper( + cfg *types.Config, + state *state.State, +) *mapper { + // uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) - uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) + return &mapper{ + state: state, + cfg: cfg, - return &Mapper{ - derpMap: derpMap, - baseDomain: baseDomain, - dnsCfg: dnsCfg, - logtail: logtail, - randomClientPort: randomClientPort, - - uid: uid, created: time.Now(), - seq: 0, - - // TODO: populate - peers: peers.IDMap(), - patches: make(map[uint64][]patch), } } -func (m *Mapper) String() string { - return fmt.Sprintf("Mapper: { seq: %d, uid: %s, created: %s }", m.seq, m.uid, m.created) -} - +// generateUserProfiles creates user profiles for MapResponse. func generateUserProfiles( - node *types.Node, - peers types.Nodes, - baseDomain string, + node types.NodeView, + peers views.Slice[types.NodeView], ) []tailcfg.UserProfile { - userMap := make(map[string]types.User) - userMap[node.User.Name] = node.User - for _, peer := range peers { - userMap[peer.User.Name] = peer.User // not worth checking if already is there + userMap := make(map[uint]*types.UserView) + ids := make([]uint, 0, len(userMap)) + user := node.Owner() + userID := user.Model().ID + userMap[userID] = &user + ids = append(ids, userID) + for _, peer := range peers.All() { + peerUser := peer.Owner() + peerUserID := peerUser.Model().ID + userMap[peerUserID] = &peerUser + ids = append(ids, peerUserID) } - profiles := []tailcfg.UserProfile{} - for _, user := range userMap { - displayName := user.Name - - if baseDomain != "" { - displayName = fmt.Sprintf("%s@%s", user.Name, baseDomain) + slices.Sort(ids) + ids = slices.Compact(ids) + var profiles []tailcfg.UserProfile + for _, id := range ids { + if userMap[id] != nil { + profiles = append(profiles, userMap[id].TailscaleUserProfile()) } - - profiles = append(profiles, - tailcfg.UserProfile{ - ID: tailcfg.UserID(user.ID), - LoginName: user.Name, - DisplayName: displayName, - }) } return profiles } func generateDNSConfig( - base *tailcfg.DNSConfig, - baseDomain string, - node *types.Node, - peers types.Nodes, + cfg *types.Config, + node types.NodeView, ) *tailcfg.DNSConfig { - dnsConfig := base.Clone() - - // if MagicDNS is enabled - if base != nil && base.Proxied { - // Only inject the Search Domain of the current user - // shared nodes should use their full FQDN - dnsConfig.Domains = append( - dnsConfig.Domains, - fmt.Sprintf( - "%s.%s", - node.User.Name, - baseDomain, - ), - ) - - userSet := mapset.NewSet[types.User]() - userSet.Add(node.User) - for _, p := range peers { - userSet.Add(p.User) - } - for _, user := range userSet.ToSlice() { - dnsRoute := fmt.Sprintf("%v.%v", user.Name, baseDomain) - dnsConfig.Routes[dnsRoute] = nil - } - } else { - dnsConfig = base + if cfg.TailcfgDNSConfig == nil { + return nil } + dnsConfig := 
cfg.TailcfgDNSConfig.Clone() + addNextDNSMetadata(dnsConfig.Resolvers, node) return dnsConfig @@ -187,16 +121,16 @@ func generateDNSConfig( // // This will produce a resolver like: // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { +func addNextDNSMetadata(resolvers []*dnstype.Resolver, node types.NodeView) { for _, resolver := range resolvers { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { attrs := url.Values{ - "device_name": []string{node.Hostname}, - "device_model": []string{node.Hostinfo.OS}, + "device_name": []string{node.Hostname()}, + "device_model": []string{node.Hostinfo().OS()}, } - if len(node.IPAddresses) > 0 { - attrs.Add("device_ip", node.IPAddresses[0].String()) + if len(node.IPs()) > 0 { + attrs.Add("device_ip", node.IPs()[0].String()) } resolver.Addr = fmt.Sprintf("%s?%s", resolver.Addr, attrs.Encode()) @@ -204,466 +138,252 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { } } -// fullMapResponse creates a complete MapResponse for a node. -// It is a separate function to make testing easier. -func (m *Mapper) fullMapResponse( - node *types.Node, - pol *policy.ACLPolicy, +// fullMapResponse returns a MapResponse for the given node. +func (m *mapper) fullMapResponse( + nodeID types.NodeID, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { - peers := nodeMapToList(m.peers) + peers := m.state.ListPeers(nodeID) - resp, err := m.baseWithConfigMapResponse(node, pol, capVer) + return m.NewMapResponseBuilder(nodeID). + WithDebugType(fullResponseDebug). + WithCapabilityVersion(capVer). + WithSelfNode(). + WithDERPMap(). + WithDomain(). + WithCollectServicesDisabled(). + WithDebugConfig(). + WithSSHPolicy(). + WithDNSConfig(). + WithUserProfiles(peers). + WithPacketFilters(). + WithPeers(peers). + Build() +} + +func (m *mapper) selfMapResponse( + nodeID types.NodeID, + capVer tailcfg.CapabilityVersion, +) (*tailcfg.MapResponse, error) { + ma, err := m.NewMapResponseBuilder(nodeID). + WithDebugType(selfResponseDebug). + WithCapabilityVersion(capVer). + WithSelfNode(). + Build() if err != nil { return nil, err } - err = appendPeerChanges( - resp, - pol, - node, - capVer, - peers, - peers, - m.baseDomain, - m.dnsCfg, - m.randomClientPort, - ) - if err != nil { - return nil, err + // Set the peers to nil, to ensure the node does not think + // its getting a new list. + ma.Peers = nil + + return ma, err +} + +// policyChangeResponse creates a MapResponse for policy changes. +// It sends: +// - PeersRemoved for peers that are no longer visible after the policy change +// - PeersChanged for remaining peers (their AllowedIPs may have changed due to policy) +// - Updated PacketFilters +// - Updated SSHPolicy (SSH rules may reference users/groups that changed) +// - Optionally, the node's own self info (when includeSelf is true) +// This avoids the issue where an empty Peers slice is interpreted by Tailscale +// clients as "no change" rather than "no peers". +// When includeSelf is true, the node's self info is included so that a node +// whose own attributes changed (e.g., tags via admin API) sees its updated +// self info along with the new packet filters. 
+func (m *mapper) policyChangeResponse( + nodeID types.NodeID, + capVer tailcfg.CapabilityVersion, + removedPeers []tailcfg.NodeID, + currentPeers views.Slice[types.NodeView], + includeSelf bool, +) (*tailcfg.MapResponse, error) { + builder := m.NewMapResponseBuilder(nodeID). + WithDebugType(policyResponseDebug). + WithCapabilityVersion(capVer). + WithPacketFilters(). + WithSSHPolicy() + + if includeSelf { + builder = builder.WithSelfNode() } - return resp, nil -} - -// FullMapResponse returns a MapResponse for the given node. -func (m *Mapper) FullMapResponse( - mapRequest tailcfg.MapRequest, - node *types.Node, - pol *policy.ACLPolicy, -) ([]byte, error) { - m.mu.Lock() - defer m.mu.Unlock() - - peers := maps.Keys(m.peers) - peersWithPatches := maps.Keys(m.patches) - slices.Sort(peers) - slices.Sort(peersWithPatches) - - if len(peersWithPatches) > 0 { - log.Debug(). - Str("node", node.Hostname). - Uints64("peers", peers). - Uints64("pending_patches", peersWithPatches). - Msgf("node requested full map response, but has pending patches") - } - - resp, err := m.fullMapResponse(node, pol, mapRequest.Version) - if err != nil { - return nil, err - } - - return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress) -} - -// LiteMapResponse returns a MapResponse for the given node. -// Lite means that the peers has been omitted, this is intended -// to be used to answer MapRequests with OmitPeers set to true. -func (m *Mapper) LiteMapResponse( - mapRequest tailcfg.MapRequest, - node *types.Node, - pol *policy.ACLPolicy, -) ([]byte, error) { - resp, err := m.baseWithConfigMapResponse(node, pol, mapRequest.Version) - if err != nil { - return nil, err - } - - return m.marshalMapResponse(mapRequest, resp, node, mapRequest.Compress) -} - -func (m *Mapper) KeepAliveResponse( - mapRequest tailcfg.MapRequest, - node *types.Node, -) ([]byte, error) { - resp := m.baseMapResponse() - resp.KeepAlive = true - - return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) -} - -func (m *Mapper) DERPMapResponse( - mapRequest tailcfg.MapRequest, - node *types.Node, - derpMap *tailcfg.DERPMap, -) ([]byte, error) { - m.derpMap = derpMap - - resp := m.baseMapResponse() - resp.DERPMap = derpMap - - return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) -} - -func (m *Mapper) PeerChangedResponse( - mapRequest tailcfg.MapRequest, - node *types.Node, - changed types.Nodes, - pol *policy.ACLPolicy, - messages ...string, -) ([]byte, error) { - m.mu.Lock() - defer m.mu.Unlock() - - // Update our internal map. - for _, node := range changed { - if patches, ok := m.patches[node.ID]; ok { - // preserve online status in case the patch has an outdated one - online := node.IsOnline - - for _, p := range patches { - // TODO(kradalby): Figure if this needs to be sorted by timestamp - node.ApplyPeerChange(p.change) - } - - // Ensure the patches are not applied again later - delete(m.patches, node.ID) - - node.IsOnline = online + if len(removedPeers) > 0 { + // Convert tailcfg.NodeID to types.NodeID for WithPeersRemoved + removedIDs := make([]types.NodeID, len(removedPeers)) + for i, id := range removedPeers { + removedIDs[i] = types.NodeID(id) //nolint:gosec // NodeID types are equivalent } - m.peers[node.ID] = node + builder.WithPeersRemoved(removedIDs...) 
} - resp := m.baseMapResponse() - - err := appendPeerChanges( - &resp, - pol, - node, - mapRequest.Version, - nodeMapToList(m.peers), - changed, - m.baseDomain, - m.dnsCfg, - m.randomClientPort, - ) - if err != nil { - return nil, err + // Send remaining peers in PeersChanged - their AllowedIPs may have + // changed due to the policy update (e.g., different routes allowed). + if currentPeers.Len() > 0 { + builder.WithPeerChanges(currentPeers) } - return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...) + return builder.Build() } -// PeerChangedPatchResponse creates a patch MapResponse with -// incoming update from a state change. -func (m *Mapper) PeerChangedPatchResponse( - mapRequest tailcfg.MapRequest, - node *types.Node, - changed []*tailcfg.PeerChange, - pol *policy.ACLPolicy, -) ([]byte, error) { - m.mu.Lock() - defer m.mu.Unlock() +// buildFromChange builds a MapResponse from a change.Change specification. +// This provides fine-grained control over what gets included in the response. +func (m *mapper) buildFromChange( + nodeID types.NodeID, + capVer tailcfg.CapabilityVersion, + resp *change.Change, +) (*tailcfg.MapResponse, error) { + if resp.IsEmpty() { + return nil, nil //nolint:nilnil // Empty response means nothing to send, not an error + } - sendUpdate := false - // patch the internal map - for _, change := range changed { - if peer, ok := m.peers[uint64(change.NodeID)]; ok { - peer.ApplyPeerChange(change) - sendUpdate = true - } else { - log.Trace().Str("node", node.Hostname).Msgf("Node with ID %s is missing from mapper for Node %s, saving patch for when node is available", change.NodeID, node.Hostname) + // If this is a self-update (the changed node is the receiving node), + // send a self-update response to ensure the node sees its own changes. + if resp.OriginNode != 0 && resp.OriginNode == nodeID { + return m.selfMapResponse(nodeID, capVer) + } - p := patch{ - timestamp: time.Now(), - change: change, - } + builder := m.NewMapResponseBuilder(nodeID). + WithCapabilityVersion(capVer). + WithDebugType(changeResponseDebug) - if patches, ok := m.patches[uint64(change.NodeID)]; ok { - patches := append(patches, p) + if resp.IncludeSelf { + builder.WithSelfNode() + } - m.patches[uint64(change.NodeID)] = patches - } else { - m.patches[uint64(change.NodeID)] = []patch{p} - } + if resp.IncludeDERPMap { + builder.WithDERPMap() + } + + if resp.IncludeDNS { + builder.WithDNSConfig() + } + + if resp.IncludeDomain { + builder.WithDomain() + } + + if resp.IncludePolicy { + builder.WithPacketFilters() + builder.WithSSHPolicy() + } + + if resp.SendAllPeers { + peers := m.state.ListPeers(nodeID) + builder.WithUserProfiles(peers) + builder.WithPeers(peers) + } else { + if len(resp.PeersChanged) > 0 { + peers := m.state.ListPeers(nodeID, resp.PeersChanged...) + builder.WithUserProfiles(peers) + builder.WithPeerChanges(peers) + } + + if len(resp.PeersRemoved) > 0 { + builder.WithPeersRemoved(resp.PeersRemoved...) 
} } - if !sendUpdate { + if len(resp.PeerPatches) > 0 { + builder.WithPeerChangedPatch(resp.PeerPatches) + } + + return builder.Build() +} + +func writeDebugMapResponse( + resp *tailcfg.MapResponse, + t debugType, + nodeID types.NodeID, +) { + body, err := json.MarshalIndent(resp, "", " ") + if err != nil { + panic(err) + } + + perms := fs.FileMode(debugMapResponsePerm) + mPath := path.Join(debugDumpMapResponsePath, fmt.Sprintf("%d", nodeID)) + err = os.MkdirAll(mPath, perms) + if err != nil { + panic(err) + } + + now := time.Now().Format("2006-01-02T15-04-05.999999999") + + mapResponsePath := path.Join( + mPath, + fmt.Sprintf("%s-%s.json", now, t), + ) + + log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) + err = os.WriteFile(mapResponsePath, body, perms) + if err != nil { + panic(err) + } +} + +func (m *mapper) debugMapResponses() (map[types.NodeID][]tailcfg.MapResponse, error) { + if debugDumpMapResponsePath == "" { return nil, nil } - resp := m.baseMapResponse() - resp.PeersChangedPatch = changed - - return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) + return ReadMapResponsesFromDirectory(debugDumpMapResponsePath) } -// TODO(kradalby): We need some integration tests for this. -func (m *Mapper) PeerRemovedResponse( - mapRequest tailcfg.MapRequest, - node *types.Node, - removed []tailcfg.NodeID, -) ([]byte, error) { - m.mu.Lock() - defer m.mu.Unlock() - - // Some nodes might have been removed already - // so we dont want to ask downstream to remove - // twice, than can cause a panic in tailscaled. - notYetRemoved := []tailcfg.NodeID{} - - // remove from our internal map - for _, id := range removed { - if _, ok := m.peers[uint64(id)]; ok { - notYetRemoved = append(notYetRemoved, id) - } - - delete(m.peers, uint64(id)) - delete(m.patches, uint64(id)) - } - - resp := m.baseMapResponse() - resp.PeersRemoved = notYetRemoved - - return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) -} - -func (m *Mapper) marshalMapResponse( - mapRequest tailcfg.MapRequest, - resp *tailcfg.MapResponse, - node *types.Node, - compression string, - messages ...string, -) ([]byte, error) { - atomic.AddUint64(&m.seq, 1) - - jsonBody, err := json.Marshal(resp) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot marshal map response") - } - - if debugDumpMapResponsePath != "" { - data := map[string]interface{}{ - "Messages": messages, - "MapRequest": mapRequest, - "MapResponse": resp, - } - - responseType := "keepalive" - - switch { - case resp.Peers != nil && len(resp.Peers) > 0: - responseType = "full" - case resp.PeersChanged != nil && len(resp.PeersChanged) > 0: - responseType = "changed" - case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0: - responseType = "patch" - case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0: - responseType = "removed" - } - - body, err := json.MarshalIndent(data, "", " ") - if err != nil { - log.Error(). - Caller(). - Err(err). 
- Msg("Cannot marshal map response") - } - - perms := fs.FileMode(debugMapResponsePerm) - mPath := path.Join(debugDumpMapResponsePath, node.Hostname) - err = os.MkdirAll(mPath, perms) - if err != nil { - panic(err) - } - - now := time.Now().UnixNano() - - mapResponsePath := path.Join( - mPath, - fmt.Sprintf("%d-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType), - ) - - log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) - err = os.WriteFile(mapResponsePath, body, perms) - if err != nil { - panic(err) - } - } - - var respBody []byte - if compression == util.ZstdCompression { - respBody = zstdEncode(jsonBody) - } else { - respBody = jsonBody - } - - data := make([]byte, reservedResponseHeaderSize) - binary.LittleEndian.PutUint32(data, uint32(len(respBody))) - data = append(data, respBody...) - - return data, nil -} - -func zstdEncode(in []byte) []byte { - encoder, ok := zstdEncoderPool.Get().(*zstd.Encoder) - if !ok { - panic("invalid type in sync pool") - } - out := encoder.EncodeAll(in, nil) - _ = encoder.Close() - zstdEncoderPool.Put(encoder) - - return out -} - -var zstdEncoderPool = &sync.Pool{ - New: func() any { - encoder, err := smallzstd.NewEncoder( - nil, - zstd.WithEncoderLevel(zstd.SpeedFastest)) - if err != nil { - panic(err) - } - - return encoder - }, -} - -// baseMapResponse returns a tailcfg.MapResponse with -// KeepAlive false and ControlTime set to now. -func (m *Mapper) baseMapResponse() tailcfg.MapResponse { - now := time.Now() - - resp := tailcfg.MapResponse{ - KeepAlive: false, - ControlTime: &now, - // TODO(kradalby): Implement PingRequest? - } - - return resp -} - -// baseWithConfigMapResponse returns a tailcfg.MapResponse struct -// with the basic configuration from headscale set. -// It is used in for bigger updates, such as full and lite, not -// incremental. 
-func (m *Mapper) baseWithConfigMapResponse( - node *types.Node, - pol *policy.ACLPolicy, - capVer tailcfg.CapabilityVersion, -) (*tailcfg.MapResponse, error) { - resp := m.baseMapResponse() - - tailnode, err := tailNode(node, capVer, pol, m.dnsCfg, m.baseDomain, m.randomClientPort) +func ReadMapResponsesFromDirectory(dir string) (map[types.NodeID][]tailcfg.MapResponse, error) { + nodes, err := os.ReadDir(dir) if err != nil { return nil, err } - resp.Node = tailnode - - resp.DERPMap = m.derpMap - - resp.Domain = m.baseDomain - - // Do not instruct clients to collect services we do not - // support or do anything with them - resp.CollectServices = "false" - - resp.KeepAlive = false - - resp.Debug = &tailcfg.Debug{ - DisableLogTail: !m.logtail, - } - - return &resp, nil -} - -func nodeMapToList(nodes map[uint64]*types.Node) types.Nodes { - ret := make(types.Nodes, 0) + result := make(map[types.NodeID][]tailcfg.MapResponse) for _, node := range nodes { - ret = append(ret, node) + if !node.IsDir() { + continue + } + + nodeIDu, err := strconv.ParseUint(node.Name(), 10, 64) + if err != nil { + log.Error().Err(err).Msgf("Parsing node ID from dir %s", node.Name()) + continue + } + + nodeID := types.NodeID(nodeIDu) + + files, err := os.ReadDir(path.Join(dir, node.Name())) + if err != nil { + log.Error().Err(err).Msgf("Reading dir %s", node.Name()) + continue + } + + slices.SortStableFunc(files, func(a, b fs.DirEntry) int { + return strings.Compare(a.Name(), b.Name()) + }) + + for _, file := range files { + if file.IsDir() || !strings.HasSuffix(file.Name(), ".json") { + continue + } + + body, err := os.ReadFile(path.Join(dir, node.Name(), file.Name())) + if err != nil { + log.Error().Err(err).Msgf("Reading file %s", file.Name()) + continue + } + + var resp tailcfg.MapResponse + err = json.Unmarshal(body, &resp) + if err != nil { + log.Error().Err(err).Msgf("Unmarshalling file %s", file.Name()) + continue + } + + result[nodeID] = append(result[nodeID], resp) + } } - return ret -} - -func filterExpiredAndNotReady(peers types.Nodes) types.Nodes { - return lo.Filter(peers, func(item *types.Node, index int) bool { - // Filter out nodes that are expired OR - // nodes that has no endpoints, this typically means they have - // registered, but are not configured. - return !item.IsExpired() || len(item.Endpoints) > 0 - }) -} - -// appendPeerChanges mutates a tailcfg.MapResponse with all the -// necessary changes when peers have changed. -func appendPeerChanges( - resp *tailcfg.MapResponse, - - pol *policy.ACLPolicy, - node *types.Node, - capVer tailcfg.CapabilityVersion, - peers types.Nodes, - changed types.Nodes, - baseDomain string, - dnsCfg *tailcfg.DNSConfig, - randomClientPort bool, -) error { - fullChange := len(peers) == len(changed) - - rules, sshPolicy, err := policy.GenerateFilterAndSSHRules( - pol, - node, - peers, - ) - if err != nil { - return err - } - - // Filter out peers that have expired. - changed = filterExpiredAndNotReady(changed) - - // If there are filter rules present, see if there are any nodes that cannot - // access eachother at all and remove them from the peers. - if len(rules) > 0 { - changed = policy.FilterNodesByACL(node, changed, rules) - } - - profiles := generateUserProfiles(node, changed, baseDomain) - - dnsConfig := generateDNSConfig( - dnsCfg, - baseDomain, - node, - peers, - ) - - tailPeers, err := tailNodes(changed, capVer, pol, dnsCfg, baseDomain, randomClientPort) - if err != nil { - return err - } - - // Peers is always returned sorted by Node.ID. 
- sort.SliceStable(tailPeers, func(x, y int) bool { - return tailPeers[x].ID < tailPeers[y].ID - }) - - if fullChange { - resp.Peers = tailPeers - } else { - resp.PeersChanged = tailPeers - } - resp.DNSConfig = dnsConfig - resp.PacketFilter = policy.ReduceFilterRules(node, rules) - resp.UserProfiles = profiles - resp.SSHPolicy = sshPolicy - - return nil + return result, nil } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index bcc17dd4..1bafd135 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -3,61 +3,23 @@ package mapper import ( "fmt" "net/netip" + "slices" "testing" - "time" - "github.com/davecgh/go-spew/spew" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" - "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" - "tailscale.com/types/key" + "tailscale.com/types/ptr" ) -func (s *Suite) TestGetMapResponseUserProfiles(c *check.C) { - mach := func(hostname, username string, userid uint) *types.Node { - return &types.Node{ - Hostname: hostname, - UserID: userid, - User: types.User{ - Name: username, - }, - } - } - - nodeInShared1 := mach("test_get_shared_nodes_1", "user1", 1) - nodeInShared2 := mach("test_get_shared_nodes_2", "user2", 2) - nodeInShared3 := mach("test_get_shared_nodes_3", "user3", 3) - node2InShared1 := mach("test_get_shared_nodes_4", "user1", 1) - - userProfiles := generateUserProfiles( - nodeInShared1, - types.Nodes{ - nodeInShared2, nodeInShared3, node2InShared1, - }, - "", - ) - - c.Assert(len(userProfiles), check.Equals, 3) - - users := []string{ - "user1", "user2", "user3", - } - - for _, user := range users { - found := false - for _, userProfile := range userProfiles { - if userProfile.DisplayName == user { - found = true - - break - } - } - c.Assert(found, check.Equals, true) - } +var iap = func(ipStr string) *netip.Addr { + ip := netip.MustParseAddr(ipStr) + return &ip } func TestDNSConfigMapResponse(t *testing.T) { @@ -68,14 +30,9 @@ func TestDNSConfigMapResponse(t *testing.T) { { magicDNS: true, want: &tailcfg.DNSConfig{ - Routes: map[string][]*dnstype.Resolver{ - "shared1.foobar.headscale.net": {}, - "shared2.foobar.headscale.net": {}, - "shared3.foobar.headscale.net": {}, - }, + Routes: map[string][]*dnstype.Resolver{}, Domains: []string{ "foobar.headscale.net", - "shared1.foobar.headscale.net", }, Proxied: true, }, @@ -94,8 +51,8 @@ func TestDNSConfigMapResponse(t *testing.T) { mach := func(hostname, username string, userid uint) *types.Node { return &types.Node{ Hostname: hostname, - UserID: userid, - User: types.User{ + UserID: ptr.To(userid), + User: &types.User{ Name: username, }, } @@ -110,22 +67,12 @@ func TestDNSConfigMapResponse(t *testing.T) { } nodeInShared1 := mach("test_get_shared_nodes_1", "shared1", 1) - nodeInShared2 := mach("test_get_shared_nodes_2", "shared2", 2) - nodeInShared3 := mach("test_get_shared_nodes_3", "shared3", 3) - node2InShared1 := mach("test_get_shared_nodes_4", "shared1", 1) - - peersOfNodeInShared1 := types.Nodes{ - nodeInShared1, - nodeInShared2, - nodeInShared3, - node2InShared1, - } got := generateDNSConfig( - &dnsConfigOrig, - baseDomain, - nodeInShared1, - peersOfNodeInShared1, + &types.Config{ + TailcfgDNSConfig: &dnsConfigOrig, + }, + nodeInShared1.View(), ) if diff := cmp.Diff(tt.want, got, 
cmpopts.EquateEmpty()); diff != "" { @@ -135,366 +82,89 @@ func TestDNSConfigMapResponse(t *testing.T) { } } -func Test_fullMapResponse(t *testing.T) { - mustNK := func(str string) key.NodePublic { - var k key.NodePublic - _ = k.UnmarshalText([]byte(str)) - - return k - } - - mustDK := func(str string) key.DiscoPublic { - var k key.DiscoPublic - _ = k.UnmarshalText([]byte(str)) - - return k - } - - mustMK := func(str string) key.MachinePublic { - var k key.MachinePublic - _ = k.UnmarshalText([]byte(str)) - - return k - } - - hiview := func(hoin tailcfg.Hostinfo) tailcfg.HostinfoView { - return hoin.View() - } - - created := time.Date(2009, time.November, 10, 23, 0, 0, 0, time.UTC) - lastSeen := time.Date(2009, time.November, 10, 23, 9, 0, 0, time.UTC) - expire := time.Date(2500, time.November, 11, 23, 0, 0, 0, time.UTC) - - mini := &types.Node{ - ID: 0, - MachineKey: mustMK( - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - ), - NodeKey: mustNK( - "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - ), - DiscoKey: mustDK( - "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", - ), - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - Hostname: "mini", - GivenName: "mini", - UserID: 0, - User: types.User{Name: "mini"}, - ForcedTags: []string{}, - AuthKeyID: 0, - AuthKey: &types.PreAuthKey{}, - LastSeen: &lastSeen, - Expiry: &expire, - Hostinfo: &tailcfg.Hostinfo{}, - Routes: []types.Route{ - { - Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")), - Advertised: true, - Enabled: true, - IsPrimary: false, - }, - { - Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")), - Advertised: true, - Enabled: true, - IsPrimary: true, - }, - { - Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")), - Advertised: true, - Enabled: false, - IsPrimary: true, - }, - }, - CreatedAt: created, - } - - tailMini := &tailcfg.Node{ - ID: 0, - StableID: "0", - Name: "mini", - User: 0, - Key: mustNK( - "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - ), - KeyExpiry: expire, - Machine: mustMK( - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - ), - DiscoKey: mustDK( - "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", - ), - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, - AllowedIPs: []netip.Prefix{ - netip.MustParsePrefix("100.64.0.1/32"), - netip.MustParsePrefix("0.0.0.0/0"), - netip.MustParsePrefix("192.168.0.0/24"), - }, - DERP: "127.3.3.40:0", - Hostinfo: hiview(tailcfg.Hostinfo{}), - Created: created, - Tags: []string{}, - PrimaryRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, - LastSeen: &lastSeen, - MachineAuthorized: true, - Capabilities: []tailcfg.NodeCapability{ - tailcfg.CapabilityFileSharing, - tailcfg.CapabilityAdmin, - tailcfg.CapabilitySSH, - tailcfg.NodeAttrDisableUPnP, - }, - } - - peer1 := &types.Node{ - ID: 1, - MachineKey: mustMK( - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - ), - NodeKey: mustNK( - "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - ), - DiscoKey: mustDK( - "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", - ), - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - Hostname: "peer1", - GivenName: "peer1", - UserID: 0, - User: types.User{Name: "mini"}, - ForcedTags: []string{}, - LastSeen: &lastSeen, - Expiry: &expire, - Hostinfo: &tailcfg.Hostinfo{}, - 
Routes: []types.Route{}, - CreatedAt: created, - } - - tailPeer1 := &tailcfg.Node{ - ID: 1, - StableID: "1", - Name: "peer1", - Key: mustNK( - "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - ), - KeyExpiry: expire, - Machine: mustMK( - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - ), - DiscoKey: mustDK( - "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", - ), - Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, - AllowedIPs: []netip.Prefix{netip.MustParsePrefix("100.64.0.2/32")}, - DERP: "127.3.3.40:0", - Hostinfo: hiview(tailcfg.Hostinfo{}), - Created: created, - Tags: []string{}, - PrimaryRoutes: []netip.Prefix{}, - LastSeen: &lastSeen, - MachineAuthorized: true, - Capabilities: []tailcfg.NodeCapability{ - tailcfg.CapabilityFileSharing, - tailcfg.CapabilityAdmin, - tailcfg.CapabilitySSH, - tailcfg.NodeAttrDisableUPnP, - }, - } - - peer2 := &types.Node{ - ID: 2, - MachineKey: mustMK( - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - ), - NodeKey: mustNK( - "nodekey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - ), - DiscoKey: mustDK( - "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", - ), - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - Hostname: "peer2", - GivenName: "peer2", - UserID: 1, - User: types.User{Name: "peer2"}, - ForcedTags: []string{}, - LastSeen: &lastSeen, - Expiry: &expire, - Hostinfo: &tailcfg.Hostinfo{}, - Routes: []types.Route{}, - CreatedAt: created, - } - - tests := []struct { - name string - pol *policy.ACLPolicy - node *types.Node - peers types.Nodes - - baseDomain string - dnsConfig *tailcfg.DNSConfig - derpMap *tailcfg.DERPMap - logtail bool - randomClientPort bool - want *tailcfg.MapResponse - wantErr bool - }{ - // { - // name: "empty-node", - // node: types.Node{}, - // pol: &policy.ACLPolicy{}, - // dnsConfig: &tailcfg.DNSConfig{}, - // baseDomain: "", - // want: nil, - // wantErr: true, - // }, - { - name: "no-pol-no-peers-map-response", - pol: &policy.ACLPolicy{}, - node: mini, - peers: types.Nodes{}, - baseDomain: "", - dnsConfig: &tailcfg.DNSConfig{}, - derpMap: &tailcfg.DERPMap{}, - logtail: false, - randomClientPort: false, - want: &tailcfg.MapResponse{ - Node: tailMini, - KeepAlive: false, - DERPMap: &tailcfg.DERPMap{}, - Peers: []*tailcfg.Node{}, - DNSConfig: &tailcfg.DNSConfig{}, - Domain: "", - CollectServices: "false", - PacketFilter: []tailcfg.FilterRule{}, - UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, - SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, - ControlTime: &time.Time{}, - Debug: &tailcfg.Debug{ - DisableLogTail: true, - }, - }, - wantErr: false, - }, - { - name: "no-pol-with-peer-map-response", - pol: &policy.ACLPolicy{}, - node: mini, - peers: types.Nodes{ - peer1, - }, - baseDomain: "", - dnsConfig: &tailcfg.DNSConfig{}, - derpMap: &tailcfg.DERPMap{}, - logtail: false, - randomClientPort: false, - want: &tailcfg.MapResponse{ - KeepAlive: false, - Node: tailMini, - DERPMap: &tailcfg.DERPMap{}, - Peers: []*tailcfg.Node{ - tailPeer1, - }, - DNSConfig: &tailcfg.DNSConfig{}, - Domain: "", - CollectServices: "false", - PacketFilter: []tailcfg.FilterRule{}, - UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, - SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, - ControlTime: &time.Time{}, - Debug: &tailcfg.Debug{ - DisableLogTail: true, - }, - }, - wantErr: false, - 
}, - { - name: "with-pol-map-response", - pol: &policy.ACLPolicy{ - ACLs: []policy.ACL{ - { - Action: "accept", - Sources: []string{"100.64.0.2"}, - Destinations: []string{"mini:*"}, - }, - }, - }, - node: mini, - peers: types.Nodes{ - peer1, - peer2, - }, - baseDomain: "", - dnsConfig: &tailcfg.DNSConfig{}, - derpMap: &tailcfg.DERPMap{}, - logtail: false, - randomClientPort: false, - want: &tailcfg.MapResponse{ - KeepAlive: false, - Node: tailMini, - DERPMap: &tailcfg.DERPMap{}, - Peers: []*tailcfg.Node{ - tailPeer1, - }, - DNSConfig: &tailcfg.DNSConfig{}, - Domain: "", - CollectServices: "false", - PacketFilter: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.2/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - UserProfiles: []tailcfg.UserProfile{ - {LoginName: "mini", DisplayName: "mini"}, - }, - SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, - ControlTime: &time.Time{}, - Debug: &tailcfg.Debug{ - DisableLogTail: true, - }, - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - mappy := NewMapper( - tt.node, - tt.peers, - tt.derpMap, - tt.baseDomain, - tt.dnsConfig, - tt.logtail, - tt.randomClientPort, - ) - - got, err := mappy.fullMapResponse( - tt.node, - tt.pol, - 0, - ) - - if (err != nil) != tt.wantErr { - t.Errorf("fullMapResponse() error = %v, wantErr %v", err, tt.wantErr) - - return - } - - spew.Dump(got) - - if diff := cmp.Diff( - tt.want, - got, - cmpopts.EquateEmpty(), - // Ignore ControlTime, it is set to now and we dont really need to mock it. - cmpopts.IgnoreFields(tailcfg.MapResponse{}, "ControlTime"), - ); diff != "" { - t.Errorf("fullMapResponse() unexpected result (-want +got):\n%s", diff) - } - }) - } +// mockState is a mock implementation that provides the required methods. 
+type mockState struct { + polMan policy.PolicyManager + derpMap *tailcfg.DERPMap + primary *routes.PrimaryRoutes + nodes types.Nodes + peers types.Nodes +} + +func (m *mockState) DERPMap() *tailcfg.DERPMap { + return m.derpMap +} + +func (m *mockState) Filter() ([]tailcfg.FilterRule, []matcher.Match) { + if m.polMan == nil { + return tailcfg.FilterAllowAll, nil + } + return m.polMan.Filter() +} + +func (m *mockState) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { + if m.polMan == nil { + return nil, nil + } + return m.polMan.SSHPolicy(node) +} + +func (m *mockState) NodeCanHaveTag(node types.NodeView, tag string) bool { + if m.polMan == nil { + return false + } + return m.polMan.NodeCanHaveTag(node, tag) +} + +func (m *mockState) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix { + if m.primary == nil { + return nil + } + return m.primary.PrimaryRoutes(nodeID) +} + +func (m *mockState) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) (types.Nodes, error) { + if len(peerIDs) > 0 { + // Filter peers by the provided IDs + var filtered types.Nodes + for _, peer := range m.peers { + if slices.Contains(peerIDs, peer.ID) { + filtered = append(filtered, peer) + } + } + + return filtered, nil + } + // Return all peers except the node itself + var filtered types.Nodes + for _, peer := range m.peers { + if peer.ID != nodeID { + filtered = append(filtered, peer) + } + } + + return filtered, nil +} + +func (m *mockState) ListNodes(nodeIDs ...types.NodeID) (types.Nodes, error) { + if len(nodeIDs) > 0 { + // Filter nodes by the provided IDs + var filtered types.Nodes + for _, node := range m.nodes { + if slices.Contains(nodeIDs, node.ID) { + filtered = append(filtered, node) + } + } + + return filtered, nil + } + + return m.nodes, nil +} + +func Test_fullMapResponse(t *testing.T) { + t.Skip("Test needs to be refactored for new state-based architecture") + // TODO: Refactor this test to work with the new state-based mapper + // The test architecture needs to be updated to work with the state interface + // instead of the old direct dependency injection pattern } diff --git a/hscontrol/mapper/suite_test.go b/hscontrol/mapper/suite_test.go deleted file mode 100644 index c9b1a580..00000000 --- a/hscontrol/mapper/suite_test.go +++ /dev/null @@ -1,15 +0,0 @@ -package mapper - -import ( - "testing" - - "gopkg.in/check.v1" -) - -func Test(t *testing.T) { - check.TestingT(t) -} - -var _ = check.Suite(&Suite{}) - -type Suite struct{} diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go deleted file mode 100644 index e213a951..00000000 --- a/hscontrol/mapper/tail.go +++ /dev/null @@ -1,163 +0,0 @@ -package mapper - -import ( - "fmt" - "net/netip" - "strconv" - "time" - - "github.com/juanfont/headscale/hscontrol/policy" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/samber/lo" - "tailscale.com/tailcfg" -) - -func tailNodes( - nodes types.Nodes, - capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, - dnsConfig *tailcfg.DNSConfig, - baseDomain string, - randomClientPort bool, -) ([]*tailcfg.Node, error) { - tNodes := make([]*tailcfg.Node, len(nodes)) - - for index, node := range nodes { - node, err := tailNode( - node, - capVer, - pol, - dnsConfig, - baseDomain, - randomClientPort, - ) - if err != nil { - return nil, err - } - - tNodes[index] = node - } - - return tNodes, nil -} - -// tailNode converts a Node into a Tailscale Node. 
includeRoutes is false for shared nodes -// as per the expected behaviour in the official SaaS. -func tailNode( - node *types.Node, - capVer tailcfg.CapabilityVersion, - pol *policy.ACLPolicy, - dnsConfig *tailcfg.DNSConfig, - baseDomain string, - randomClientPort bool, -) (*tailcfg.Node, error) { - addrs := node.IPAddresses.Prefixes() - - allowedIPs := append( - []netip.Prefix{}, - addrs...) // we append the node own IP, as it is required by the clients - - primaryPrefixes := []netip.Prefix{} - - for _, route := range node.Routes { - if route.Enabled { - if route.IsPrimary { - allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix)) - primaryPrefixes = append(primaryPrefixes, netip.Prefix(route.Prefix)) - } else if route.IsExitRoute() { - allowedIPs = append(allowedIPs, netip.Prefix(route.Prefix)) - } - } - } - - var derp string - if node.Hostinfo.NetInfo != nil { - derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP) - } else { - derp = "127.3.3.40:0" // Zero means disconnected or unknown. - } - - var keyExpiry time.Time - if node.Expiry != nil { - keyExpiry = *node.Expiry - } else { - keyExpiry = time.Time{} - } - - hostname, err := node.GetFQDN(dnsConfig, baseDomain) - if err != nil { - return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) - } - - tags, _ := pol.TagsOfNode(node) - tags = lo.Uniq(append(tags, node.ForcedTags...)) - - tNode := tailcfg.Node{ - ID: tailcfg.NodeID(node.ID), // this is the actual ID - StableID: tailcfg.StableNodeID( - strconv.FormatUint(node.ID, util.Base10), - ), // in headscale, unlike tailcontrol server, IDs are permanent - Name: hostname, - Cap: capVer, - - User: tailcfg.UserID(node.UserID), - - Key: node.NodeKey, - KeyExpiry: keyExpiry, - - Machine: node.MachineKey, - DiscoKey: node.DiscoKey, - Addresses: addrs, - AllowedIPs: allowedIPs, - Endpoints: node.Endpoints, - DERP: derp, - Hostinfo: node.Hostinfo.View(), - Created: node.CreatedAt, - - Online: node.IsOnline, - - Tags: tags, - - PrimaryRoutes: primaryPrefixes, - - MachineAuthorized: !node.IsExpired(), - Expired: node.IsExpired(), - } - - // - 74: 2023-09-18: Client understands NodeCapMap - if capVer >= 74 { - tNode.CapMap = tailcfg.NodeCapMap{ - tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{}, - tailcfg.CapabilityAdmin: []tailcfg.RawMessage{}, - tailcfg.CapabilitySSH: []tailcfg.RawMessage{}, - } - - if randomClientPort { - tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} - } - } else { - tNode.Capabilities = []tailcfg.NodeCapability{ - tailcfg.CapabilityFileSharing, - tailcfg.CapabilityAdmin, - tailcfg.CapabilitySSH, - } - - if randomClientPort { - tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrRandomizeClientPort) - } - } - - // - 72: 2023-08-23: TS-2023-006 UPnP issue fixed; UPnP can now be used again - if capVer < 72 { - tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrDisableUPnP) - } - - if node.IsOnline == nil || !*node.IsOnline { - // LastSeen is only set when node is - // not connected to the control server. 
- tNode.LastSeen = node.LastSeen - } - - return &tNode, nil -} diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index f6e370c4..5b7030de 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -1,16 +1,19 @@ package mapper import ( + "encoding/json" "net/netip" "testing" "time" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/ptr" ) func TestTailNode(t *testing.T) { @@ -46,7 +49,7 @@ func TestTailNode(t *testing.T) { tests := []struct { name string node *types.Node - pol *policy.ACLPolicy + pol []byte dnsConfig *tailcfg.DNSConfig baseDomain string want *tailcfg.Node @@ -55,23 +58,23 @@ func TestTailNode(t *testing.T) { { name: "empty-node", node: &types.Node{ - Hostinfo: &tailcfg.Hostinfo{}, + GivenName: "empty", + Hostinfo: &tailcfg.Hostinfo{}, }, - pol: &policy.ACLPolicy{}, dnsConfig: &tailcfg.DNSConfig{}, baseDomain: "", want: &tailcfg.Node{ + Name: "empty", StableID: "0", - Addresses: []netip.Prefix{}, - AllowedIPs: []netip.Prefix{}, - DERP: "127.3.3.40:0", + HomeDERP: 0, + LegacyDERPString: "127.3.3.40:0", Hostinfo: hiview(tailcfg.Hostinfo{}), - Tags: []string{}, - PrimaryRoutes: []netip.Prefix{}, MachineAuthorized: true, - Capabilities: []tailcfg.NodeCapability{ - "https://tailscale.com/cap/file-sharing", "https://tailscale.com/cap/is-admin", - "https://tailscale.com/cap/ssh", "debug-disable-upnp", + + CapMap: tailcfg.NodeCapMap{ + tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{}, + tailcfg.CapabilityAdmin: []tailcfg.RawMessage{}, + tailcfg.CapabilitySSH: []tailcfg.RawMessage{}, }, }, wantErr: false, @@ -89,44 +92,28 @@ func TestTailNode(t *testing.T) { DiscoKey: mustDK( "discokey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", ), - IPAddresses: []netip.Addr{ - netip.MustParseAddr("100.64.0.1"), - }, + IPv4: iap("100.64.0.1"), Hostname: "mini", GivenName: "mini", - UserID: 0, - User: types.User{ + UserID: ptr.To(uint(0)), + User: &types.User{ Name: "mini", }, - ForcedTags: []string{}, - AuthKeyID: 0, - AuthKey: &types.PreAuthKey{}, - LastSeen: &lastSeen, - Expiry: &expire, - Hostinfo: &tailcfg.Hostinfo{}, - Routes: []types.Route{ - { - Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")), - Advertised: true, - Enabled: true, - IsPrimary: false, - }, - { - Prefix: types.IPPrefix(netip.MustParsePrefix("192.168.0.0/24")), - Advertised: true, - Enabled: true, - IsPrimary: true, - }, - { - Prefix: types.IPPrefix(netip.MustParsePrefix("172.0.0.0/10")), - Advertised: true, - Enabled: false, - IsPrimary: true, + Tags: []string{}, + AuthKey: &types.PreAuthKey{}, + LastSeen: &lastSeen, + Expiry: &expire, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{ + tsaddr.AllIPv4(), + tsaddr.AllIPv6(), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.0.0.0/10"), }, }, - CreatedAt: created, + ApprovedRoutes: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6(), netip.MustParsePrefix("192.168.0.0/24")}, + CreatedAt: created, }, - pol: &policy.ACLPolicy{}, dnsConfig: &tailcfg.DNSConfig{}, baseDomain: "", want: &tailcfg.Node{ @@ -149,28 +136,59 @@ func TestTailNode(t *testing.T) { ), Addresses: []netip.Prefix{netip.MustParsePrefix("100.64.0.1/32")}, AllowedIPs: []netip.Prefix{ - netip.MustParsePrefix("100.64.0.1/32"), - 
netip.MustParsePrefix("0.0.0.0/0"), + tsaddr.AllIPv4(), netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("100.64.0.1/32"), + tsaddr.AllIPv6(), }, - DERP: "127.3.3.40:0", - Hostinfo: hiview(tailcfg.Hostinfo{}), - Created: created, - - Tags: []string{}, - PrimaryRoutes: []netip.Prefix{ netip.MustParsePrefix("192.168.0.0/24"), }, + HomeDERP: 0, + LegacyDERPString: "127.3.3.40:0", + Hostinfo: hiview(tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{ + tsaddr.AllIPv4(), + tsaddr.AllIPv6(), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.0.0.0/10"), + }, + }), + Created: created, + + Tags: []string{}, - LastSeen: &lastSeen, MachineAuthorized: true, - Capabilities: []tailcfg.NodeCapability{ - tailcfg.CapabilityFileSharing, - tailcfg.CapabilityAdmin, - tailcfg.CapabilitySSH, - tailcfg.NodeAttrDisableUPnP, + CapMap: tailcfg.NodeCapMap{ + tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{}, + tailcfg.CapabilityAdmin: []tailcfg.RawMessage{}, + tailcfg.CapabilitySSH: []tailcfg.RawMessage{}, + }, + }, + wantErr: false, + }, + { + name: "check-dot-suffix-on-node-name", + node: &types.Node{ + GivenName: "minimal", + Hostinfo: &tailcfg.Hostinfo{}, + }, + dnsConfig: &tailcfg.DNSConfig{}, + baseDomain: "example.com", + want: &tailcfg.Node{ + // a node name should have a dot appended + Name: "minimal.example.com.", + StableID: "0", + HomeDERP: 0, + LegacyDERPString: "127.3.3.40:0", + Hostinfo: hiview(tailcfg.Hostinfo{}), + MachineAuthorized: true, + + CapMap: tailcfg.NodeCapMap{ + tailcfg.CapabilityFileSharing: []tailcfg.RawMessage{}, + tailcfg.CapabilityAdmin: []tailcfg.RawMessage{}, + tailcfg.CapabilitySSH: []tailcfg.RawMessage{}, }, }, wantErr: false, @@ -182,23 +200,102 @@ func TestTailNode(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := tailNode( - tt.node, + primary := routes.New() + cfg := &types.Config{ + BaseDomain: tt.baseDomain, + TailcfgDNSConfig: tt.dnsConfig, + RandomizeClientPort: false, + Taildrop: types.TaildropConfig{Enabled: true}, + } + _ = primary.SetRoutes(tt.node.ID, tt.node.SubnetRoutes()...) + + // This is a hack to avoid having a second node to test the primary route. + // This should be baked into the test case proper if it is extended in the future. 
+ _ = primary.SetRoutes(2, netip.MustParsePrefix("192.168.0.0/24")) + got, err := tt.node.View().TailNode( 0, - tt.pol, - tt.dnsConfig, - tt.baseDomain, - false, + func(id types.NodeID) []netip.Prefix { + return primary.PrimaryRoutes(id) + }, + cfg, ) if (err != nil) != tt.wantErr { - t.Errorf("tailNode() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("TailNode() error = %v, wantErr %v", err, tt.wantErr) return } if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { - t.Errorf("tailNode() unexpected result (-want +got):\n%s", diff) + t.Errorf("TailNode() unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +func TestNodeExpiry(t *testing.T) { + tp := func(t time.Time) *time.Time { + return &t + } + tests := []struct { + name string + exp *time.Time + wantTime time.Time + wantTimeZero bool + }{ + { + name: "no-expiry", + exp: nil, + wantTimeZero: true, + }, + { + name: "zero-expiry", + exp: &time.Time{}, + wantTimeZero: true, + }, + { + name: "localtime", + exp: tp(time.Time{}.Local()), + wantTimeZero: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + node := &types.Node{ + ID: 0, + GivenName: "test", + Expiry: tt.exp, + } + + tn, err := node.View().TailNode( + 0, + func(id types.NodeID) []netip.Prefix { + return []netip.Prefix{} + }, + &types.Config{Taildrop: types.TaildropConfig{Enabled: true}}, + ) + if err != nil { + t.Fatalf("nodeExpiry() error = %v", err) + } + + // Round trip the node through JSON to ensure the time is serialized correctly + seri, err := json.Marshal(tn) + if err != nil { + t.Fatalf("nodeExpiry() error = %v", err) + } + var deseri tailcfg.Node + err = json.Unmarshal(seri, &deseri) + if err != nil { + t.Fatalf("nodeExpiry() error = %v", err) + } + + if tt.wantTimeZero { + if !deseri.KeyExpiry.IsZero() { + t.Errorf("nodeExpiry() = %v, want zero", deseri.KeyExpiry) + } + } else if deseri.KeyExpiry != tt.wantTime { + t.Errorf("nodeExpiry() = %v, want %v", deseri.KeyExpiry, tt.wantTime) } }) } diff --git a/hscontrol/metrics.go b/hscontrol/metrics.go index fc56f584..749d651e 100644 --- a/hscontrol/metrics.go +++ b/hscontrol/metrics.go @@ -1,25 +1,101 @@ package hscontrol import ( + "net/http" + "strconv" + + "github.com/gorilla/mux" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" + "tailscale.com/envknob" ) +var debugHighCardinalityMetrics = envknob.Bool("HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS") + +var mapResponseLastSentSeconds *prometheus.GaugeVec + +func init() { + if debugHighCardinalityMetrics { + mapResponseLastSentSeconds = promauto.NewGaugeVec(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "mapresponse_last_sent_seconds", + Help: "last sent metric to node.id", + }, []string{"type", "id"}) + } +} + const prometheusNamespace = "headscale" var ( - // This is a high cardinality metric (user x node), we might want to make this - // configurable/opt-in in the future. 
- nodeRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{ + mapResponseSent = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, - Name: "node_registrations_total", - Help: "The total amount of registered node attempts", - }, []string{"action", "auth", "status", "user"}) - - updateRequestsSentToNode = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "mapresponse_sent_total", + Help: "total count of mapresponses sent to clients", + }, []string{"status", "type"}) + mapResponseEndpointUpdates = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, - Name: "update_request_sent_to_node_total", - Help: "The number of calls/messages issued on a specific nodes update channel", - }, []string{"user", "node", "status"}) - // TODO(kradalby): This is very debugging, we might want to remove it. + Name: "mapresponse_endpoint_updates_total", + Help: "total count of endpoint updates received", + }, []string{"status"}) + mapResponseEnded = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "mapresponse_ended_total", + Help: "total count of new mapsessions ended", + }, []string{"reason"}) + httpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "http_duration_seconds", + Help: "Duration of HTTP requests.", + }, []string{"path"}) + httpCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "http_requests_total", + Help: "Total number of http requests processed", + }, []string{"code", "method", "path"}, + ) ) + +// prometheusMiddleware implements mux.MiddlewareFunc. +func prometheusMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + route := mux.CurrentRoute(r) + path, _ := route.GetPathTemplate() + + // Ignore streaming and noise sessions + // it has its own router further down. + if path == "/ts2021" || path == "/machine/map" || path == "/derp" || path == "/derp/probe" || path == "/derp/latency-check" || path == "/bootstrap-dns" { + next.ServeHTTP(w, r) + return + } + + rw := &respWriterProm{ResponseWriter: w} + + timer := prometheus.NewTimer(httpDuration.WithLabelValues(path)) + next.ServeHTTP(rw, r) + timer.ObserveDuration() + httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc() + }) +} + +type respWriterProm struct { + http.ResponseWriter + status int + written int64 + wroteHeader bool +} + +func (r *respWriterProm) WriteHeader(code int) { + r.status = code + r.wroteHeader = true + r.ResponseWriter.WriteHeader(code) +} + +func (r *respWriterProm) Write(b []byte) (int, error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + n, err := r.ResponseWriter.Write(b) + r.written += int64(n) + + return n, err +} diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 0fa28d19..a667cd1f 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -3,16 +3,18 @@ package hscontrol import ( "encoding/binary" "encoding/json" + "errors" + "fmt" "io" "net/http" "github.com/gorilla/mux" + "github.com/juanfont/headscale/hscontrol/capver" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" "golang.org/x/net/http2" - "golang.org/x/net/http2/h2c" "tailscale.com/control/controlbase" - "tailscale.com/control/controlhttp" + "tailscale.com/control/controlhttp/controlhttpserver" "tailscale.com/tailcfg" "tailscale.com/types/key" ) @@ -27,9 +29,6 @@ const ( // of length. 
Then that many bytes of JSON-encoded tailcfg.EarlyNoise. // The early payload is optional. Some servers may not send it... But we do! earlyPayloadMagic = "\xff\xff\xffTS" - - // EarlyNoise was added in protocol version 49. - earlyNoiseCapabilityVersion = 49 ) type noiseServer struct { @@ -72,7 +71,7 @@ func (h *Headscale) NoiseUpgradeHandler( challenge: key.NewChallenge(), } - noiseConn, err := controlhttp.AcceptHTTP( + noiseConn, err := controlhttpserver.AcceptHTTP( req.Context(), writer, req, @@ -80,9 +79,7 @@ func (h *Headscale) NoiseUpgradeHandler( noiseServer.earlyNoise, ) if err != nil { - log.Error().Err(err).Msg("noise upgrade failed") - http.Error(writer, err.Error(), http.StatusInternalServerError) - + httpError(writer, fmt.Errorf("noise upgrade failed: %w", err)) return } @@ -95,23 +92,22 @@ func (h *Headscale) NoiseUpgradeHandler( // The HTTP2 server that exposes this router is created for // a single hijacked connection from /ts2021, using netutil.NewOneConnListener router := mux.NewRouter() + router.Use(prometheusMiddleware) router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler). Methods(http.MethodPost) - router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler) - server := http.Server{ - ReadTimeout: types.HTTPReadTimeout, - } + // Endpoints outside of the register endpoint must use getAndValidateNode to + // get the node to ensure that the MachineKey matches the Node setting up the + // connection. + router.HandleFunc("/machine/map", noiseServer.NoisePollNetMapHandler) noiseServer.httpBaseConfig = &http.Server{ Handler: router, - ReadHeaderTimeout: types.HTTPReadTimeout, + ReadHeaderTimeout: types.HTTPTimeout, } noiseServer.http2Server = &http2.Server{} - server.Handler = h2c.NewHandler(router, noiseServer.http2Server) - noiseServer.http2Server.ServeConn( noiseConn, &http2.ServeConnOpts{ @@ -120,19 +116,13 @@ func (h *Headscale) NoiseUpgradeHandler( ) } +func unsupportedClientError(version tailcfg.CapabilityVersion) error { + return fmt.Errorf("unsupported client version: %s (%d)", capver.TailscaleVersion(version), version) +} + func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error { - log.Trace(). - Caller(). - Int("protocol_version", protocolVersion). - Str("challenge", ns.challenge.Public().String()). - Msg("earlyNoise called") - - if protocolVersion < earlyNoiseCapabilityVersion { - log.Trace(). - Caller(). - Msgf("protocol version %d does not support early noise", protocolVersion) - - return nil + if !isSupportedVersion(tailcfg.CapabilityVersion(protocolVersion)) { + return unsupportedClientError(tailcfg.CapabilityVersion(protocolVersion)) } earlyJSON, err := json.Marshal(&tailcfg.EarlyNoise{ @@ -163,3 +153,154 @@ func (ns *noiseServer) earlyNoise(protocolVersion int, writer io.Writer) error { return nil } + +func isSupportedVersion(version tailcfg.CapabilityVersion) bool { + return version >= capver.MinSupportedCapabilityVersion +} + +func rejectUnsupported( + writer http.ResponseWriter, + version tailcfg.CapabilityVersion, + mkey key.MachinePublic, + nkey key.NodePublic, +) bool { + // Reject unsupported versions + if !isSupportedVersion(version) { + log.Error(). + Caller(). + Int("minimum_cap_ver", int(capver.MinSupportedCapabilityVersion)). + Int("client_cap_ver", int(version)). + Str("minimum_version", capver.TailscaleVersion(capver.MinSupportedCapabilityVersion)). + Str("client_version", capver.TailscaleVersion(version)). + Str("node.key", nkey.ShortString()). + Str("machine.key", mkey.ShortString()). 
+ Msg("unsupported client connected") + http.Error(writer, unsupportedClientError(version).Error(), http.StatusBadRequest) + + return true + } + + return false +} + +// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol +// +// This is the busiest endpoint, as it keeps the HTTP long poll that updates +// the clients when something in the network changes. +// +// The clients POST stuff like HostInfo and their Endpoints here, but +// only after their first request (marked with the ReadOnly field). +// +// At this moment the updates are sent in a quite horrendous way, but they kinda work. +func (ns *noiseServer) NoisePollNetMapHandler( + writer http.ResponseWriter, + req *http.Request, +) { + body, _ := io.ReadAll(req.Body) + + var mapRequest tailcfg.MapRequest + if err := json.Unmarshal(body, &mapRequest); err != nil { + httpError(writer, err) + return + } + + // Reject unsupported versions + if rejectUnsupported(writer, mapRequest.Version, ns.machineKey, mapRequest.NodeKey) { + return + } + + nv, err := ns.getAndValidateNode(mapRequest) + if err != nil { + httpError(writer, err) + return + } + + ns.nodeKey = nv.NodeKey() + + sess := ns.headscale.newMapSession(req.Context(), mapRequest, writer, nv.AsStruct()) + sess.tracef("a node sending a MapRequest with Noise protocol") + if !sess.isStreaming() { + sess.serve() + } else { + sess.serveLongPoll() + } +} + +func regErr(err error) *tailcfg.RegisterResponse { + return &tailcfg.RegisterResponse{Error: err.Error()} +} + +// NoiseRegistrationHandler handles the actual registration process of a node. +func (ns *noiseServer) NoiseRegistrationHandler( + writer http.ResponseWriter, + req *http.Request, +) { + if req.Method != http.MethodPost { + httpError(writer, errMethodNotAllowed) + + return + } + + registerRequest, registerResponse := func() (*tailcfg.RegisterRequest, *tailcfg.RegisterResponse) { + var resp *tailcfg.RegisterResponse + body, err := io.ReadAll(req.Body) + if err != nil { + return &tailcfg.RegisterRequest{}, regErr(err) + } + var regReq tailcfg.RegisterRequest + if err := json.Unmarshal(body, ®Req); err != nil { + return ®Req, regErr(err) + } + + ns.nodeKey = regReq.NodeKey + + resp, err = ns.headscale.handleRegister(req.Context(), regReq, ns.conn.Peer()) + if err != nil { + var httpErr HTTPError + if errors.As(err, &httpErr) { + resp = &tailcfg.RegisterResponse{ + Error: httpErr.Msg, + } + return ®Req, resp + } + + return ®Req, regErr(err) + } + + return ®Req, resp + }() + + // Reject unsupported versions + if rejectUnsupported(writer, registerRequest.Version, ns.machineKey, registerRequest.NodeKey) { + return + } + + writer.Header().Set("Content-Type", "application/json; charset=utf-8") + writer.WriteHeader(http.StatusOK) + + if err := json.NewEncoder(writer).Encode(registerResponse); err != nil { + log.Error().Caller().Err(err).Msg("NoiseRegistrationHandler: failed to encode RegisterResponse") + return + } + + // Ensure response is flushed to client + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } +} + +// getAndValidateNode retrieves the node from the database using the NodeKey +// and validates that it matches the MachineKey from the Noise session. 
+func (ns *noiseServer) getAndValidateNode(mapRequest tailcfg.MapRequest) (types.NodeView, error) { + nv, ok := ns.headscale.state.GetNodeByNodeKey(mapRequest.NodeKey) + if !ok { + return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node not found", nil) + } + + // Validate that the MachineKey in the Noise session matches the one associated with the NodeKey. + if ns.machineKey != nv.MachineKey() { + return types.NodeView{}, NewHTTPError(http.StatusNotFound, "node key in request does not match the one associated with this machine key", nil) + } + + return nv, nil +} diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go deleted file mode 100644 index ae0aad46..00000000 --- a/hscontrol/notifier/notifier.go +++ /dev/null @@ -1,109 +0,0 @@ -package notifier - -import ( - "fmt" - "strings" - "sync" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" - "tailscale.com/types/key" -) - -type Notifier struct { - l sync.RWMutex - nodes map[string]chan<- types.StateUpdate -} - -func NewNotifier() *Notifier { - return &Notifier{} -} - -func (n *Notifier) AddNode(machineKey key.MachinePublic, c chan<- types.StateUpdate) { - log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to add node") - defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to add node") - - n.l.Lock() - defer n.l.Unlock() - - if n.nodes == nil { - n.nodes = make(map[string]chan<- types.StateUpdate) - } - - n.nodes[machineKey.String()] = c - - log.Trace(). - Str("machine_key", machineKey.ShortString()). - Int("open_chans", len(n.nodes)). - Msg("Added new channel") -} - -func (n *Notifier) RemoveNode(machineKey key.MachinePublic) { - log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("acquiring lock to remove node") - defer log.Trace().Caller().Str("key", machineKey.ShortString()).Msg("releasing lock to remove node") - - n.l.Lock() - defer n.l.Unlock() - - if n.nodes == nil { - return - } - - delete(n.nodes, machineKey.String()) - - log.Trace(). - Str("machine_key", machineKey.ShortString()). - Int("open_chans", len(n.nodes)). - Msg("Removed channel") -} - -// IsConnected reports if a node is connected to headscale and has a -// poll session open. -func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool { - n.l.RLock() - defer n.l.RUnlock() - - if _, ok := n.nodes[machineKey.String()]; ok { - return true - } - - return false -} - -func (n *Notifier) NotifyAll(update types.StateUpdate) { - n.NotifyWithIgnore(update) -} - -func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) { - log.Trace().Caller().Interface("type", update.Type).Msg("acquiring lock to notify") - defer log.Trace(). - Caller(). - Interface("type", update.Type). 
- Msg("releasing lock, finished notifing") - - n.l.RLock() - defer n.l.RUnlock() - - for key, c := range n.nodes { - if util.IsStringInSlice(ignore, key) { - continue - } - - log.Trace().Caller().Str("machine", key).Strs("ignoring", ignore).Msg("sending update") - c <- update - } -} - -func (n *Notifier) String() string { - n.l.RLock() - defer n.l.RUnlock() - - str := []string{"Notifier, in map:\n"} - - for k, v := range n.nodes { - str = append(str, fmt.Sprintf("\t%s: %v\n", k, v)) - } - - return strings.Join(str, "") -} diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 568519fd..7013b8ed 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -2,34 +2,38 @@ package hscontrol import ( "bytes" + "cmp" "context" - "crypto/rand" - _ "embed" - "encoding/hex" "errors" "fmt" - "html/template" "net/http" + "slices" "strings" "time" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/mux" "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/templates" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "golang.org/x/oauth2" - "tailscale.com/types/key" + "zgo.at/zcache/v2" ) const ( - randomByteSize = 16 + randomByteSize = 16 + defaultOAuthOptionsCount = 3 + registerCacheExpiration = time.Minute * 15 + registerCacheCleanup = time.Minute * 20 ) var ( errEmptyOIDCCallbackParams = errors.New("empty OIDC callback params") errNoOIDCIDToken = errors.New("could not extract ID Token for OIDC callback") + errNoOIDCRegistrationInfo = errors.New("could not get registration info from cache") errOIDCAllowedDomains = errors.New( "authenticated principal does not match any allowed domain", ) @@ -37,349 +41,377 @@ var ( errOIDCAllowedUsers = errors.New( "authenticated principal does not match any allowed user", ) - errOIDCInvalidNodeState = errors.New( - "requested node state key expired before authorisation completed", - ) - errOIDCNodeKeyMissing = errors.New("could not get node key from cache") + errOIDCUnverifiedEmail = errors.New("authenticated principal has an unverified email") ) -type IDTokenClaims struct { - Name string `json:"name,omitempty"` - Groups []string `json:"groups,omitempty"` - Email string `json:"email"` - Username string `json:"preferred_username,omitempty"` +// RegistrationInfo contains both machine key and verifier information for OIDC validation. +type RegistrationInfo struct { + RegistrationID types.RegistrationID + Verifier *string } -func (h *Headscale) initOIDC() error { +type AuthProviderOIDC struct { + h *Headscale + serverURL string + cfg *types.OIDCConfig + registrationCache *zcache.Cache[string, RegistrationInfo] + + oidcProvider *oidc.Provider + oauth2Config *oauth2.Config +} + +func NewAuthProviderOIDC( + ctx context.Context, + h *Headscale, + serverURL string, + cfg *types.OIDCConfig, +) (*AuthProviderOIDC, error) { var err error // grab oidc config if it hasn't been already - if h.oauth2Config == nil { - h.oidcProvider, err = oidc.NewProvider(context.Background(), h.cfg.OIDC.Issuer) - - if err != nil { - log.Error(). - Err(err). - Caller(). 
- Msgf("Could not retrieve OIDC Config: %s", err.Error()) - - return err - } - - h.oauth2Config = &oauth2.Config{ - ClientID: h.cfg.OIDC.ClientID, - ClientSecret: h.cfg.OIDC.ClientSecret, - Endpoint: h.oidcProvider.Endpoint(), - RedirectURL: fmt.Sprintf( - "%s/oidc/callback", - strings.TrimSuffix(h.cfg.ServerURL, "/"), - ), - Scopes: h.cfg.OIDC.Scope, - } + oidcProvider, err := oidc.NewProvider(context.Background(), cfg.Issuer) + if err != nil { + return nil, fmt.Errorf("creating OIDC provider from issuer config: %w", err) } - return nil -} - -func (h *Headscale) determineTokenExpiration(idTokenExpiration time.Time) time.Time { - if h.cfg.OIDC.UseExpiryFromToken { - return idTokenExpiration + oauth2Config := &oauth2.Config{ + ClientID: cfg.ClientID, + ClientSecret: cfg.ClientSecret, + Endpoint: oidcProvider.Endpoint(), + RedirectURL: strings.TrimSuffix(serverURL, "/") + "/oidc/callback", + Scopes: cfg.Scope, } - return time.Now().Add(h.cfg.OIDC.Expiry) + registrationCache := zcache.New[string, RegistrationInfo]( + registerCacheExpiration, + registerCacheCleanup, + ) + + return &AuthProviderOIDC{ + h: h, + serverURL: serverURL, + cfg: cfg, + registrationCache: registrationCache, + + oidcProvider: oidcProvider, + oauth2Config: oauth2Config, + }, nil } -// RegisterOIDC redirects to the OIDC provider for authentication -// Puts NodeKey in cache so the callback can retrieve it using the oidc state param -// Listens in /oidc/register/:mKey. -func (h *Headscale) RegisterOIDC( +func (a *AuthProviderOIDC) AuthURL(registrationID types.RegistrationID) string { + return fmt.Sprintf( + "%s/register/%s", + strings.TrimSuffix(a.serverURL, "/"), + registrationID.String()) +} + +// RegisterHandler registers the OIDC callback handler with the given router. +// It puts NodeKey in cache so the callback can retrieve it using the oidc state param. +// Listens in /register/:registration_id. +func (a *AuthProviderOIDC) RegisterHandler( writer http.ResponseWriter, req *http.Request, ) { vars := mux.Vars(req) - machineKeyStr, ok := vars["mkey"] - - log.Debug(). - Caller(). - Str("machine_key", machineKeyStr). - Bool("ok", ok). - Msg("Received oidc register call") + registrationIdStr := vars["registration_id"] // We need to make sure we dont open for XSS style injections, if the parameter that // is passed as a key is not parsable/validated as a NodePublic key, then fail to render // the template and log an error. - var machineKey key.MachinePublic - err := machineKey.UnmarshalText( - []byte(machineKeyStr), - ) + registrationId, err := types.RegistrationIDFromString(registrationIdStr) if err != nil { - log.Warn(). - Err(err). 
- Msg("Failed to parse incoming nodekey in OIDC registration") + httpError(writer, NewHTTPError(http.StatusBadRequest, "invalid registration id", err)) + return + } + // Set the state and nonce cookies to protect against CSRF attacks + state, err := setCSRFCookie(writer, req, "state") + if err != nil { + httpError(writer, err) + return + } + + // Set the state and nonce cookies to protect against CSRF attacks + nonce, err := setCSRFCookie(writer, req, "nonce") + if err != nil { + httpError(writer, err) + return + } + + // Initialize registration info with machine key + registrationInfo := RegistrationInfo{ + RegistrationID: registrationId, + } + + extras := make([]oauth2.AuthCodeOption, 0, len(a.cfg.ExtraParams)+defaultOAuthOptionsCount) + // Add PKCE verification if enabled + if a.cfg.PKCE.Enabled { + verifier := oauth2.GenerateVerifier() + registrationInfo.Verifier = &verifier + + extras = append(extras, oauth2.AccessTypeOffline) + + switch a.cfg.PKCE.Method { + case types.PKCEMethodS256: + extras = append(extras, oauth2.S256ChallengeOption(verifier)) + case types.PKCEMethodPlain: + // oauth2 does not have a plain challenge option, so we add it manually + extras = append(extras, oauth2.SetAuthURLParam("code_challenge_method", "plain"), oauth2.SetAuthURLParam("code_challenge", verifier)) + } + } + + // Add any extra parameters from configuration + for k, v := range a.cfg.ExtraParams { + extras = append(extras, oauth2.SetAuthURLParam(k, v)) + } + extras = append(extras, oidc.Nonce(nonce)) + + // Cache the registration info + a.registrationCache.Set(state, registrationInfo) + + authURL := a.oauth2Config.AuthCodeURL(state, extras...) + log.Debug().Caller().Msgf("Redirecting to %s for authentication", authURL) + + http.Redirect(writer, req, authURL, http.StatusFound) +} + +// OIDCCallbackHandler handles the callback from the OIDC endpoint +// Retrieves the nkey from the state cache and adds the node to the users email user +// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities +// TODO: Add groups information from OIDC tokens into node HostInfo +// Listens in /oidc/callback. 
+func (a *AuthProviderOIDC) OIDCCallbackHandler( + writer http.ResponseWriter, + req *http.Request, +) { + code, state, err := extractCodeAndStateParamFromRequest(req) + if err != nil { + httpError(writer, err) + return + } + + stateCookieName := getCookieName("state", state) + cookieState, err := req.Cookie(stateCookieName) + if err != nil { + httpError(writer, NewHTTPError(http.StatusBadRequest, "state not found", err)) + return + } + + if state != cookieState.Value { + httpError(writer, NewHTTPError(http.StatusForbidden, "state did not match", nil)) + return + } + + oauth2Token, err := a.getOauth2Token(req.Context(), code, state) + if err != nil { + httpError(writer, err) + return + } + + idToken, err := a.extractIDToken(req.Context(), oauth2Token) + if err != nil { + httpError(writer, err) + return + } + if idToken.Nonce == "" { + httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found in IDToken", err)) + return + } + + nonceCookieName := getCookieName("nonce", idToken.Nonce) + nonce, err := req.Cookie(nonceCookieName) + if err != nil { + httpError(writer, NewHTTPError(http.StatusBadRequest, "nonce not found", err)) + return + } + if idToken.Nonce != nonce.Value { + httpError(writer, NewHTTPError(http.StatusForbidden, "nonce did not match", nil)) + return + } + + nodeExpiry := a.determineNodeExpiry(idToken.Expiry) + + var claims types.OIDCClaims + if err := idToken.Claims(&claims); err != nil { + httpError(writer, fmt.Errorf("decoding ID token claims: %w", err)) + return + } + + // Fetch user information (email, groups, name, etc) from the userinfo endpoint + // https://openid.net/specs/openid-connect-core-1_0.html#UserInfo + var userinfo *oidc.UserInfo + userinfo, err = a.oidcProvider.UserInfo(req.Context(), oauth2.StaticTokenSource(oauth2Token)) + if err != nil { + util.LogErr(err, "could not get userinfo; only using claims from id token") + } + + // The oidc.UserInfo type only decodes some fields (Subject, Profile, Email, EmailVerified). + // We are interested in other fields too (e.g. groups are required for allowedGroups) so we + // decode into our own OIDCUserInfo type using the underlying claims struct. + var userinfo2 types.OIDCUserInfo + if userinfo != nil && userinfo.Claims(&userinfo2) == nil && userinfo2.Sub == claims.Sub { + // Update the user with the userinfo claims (with id token claims as fallback). + // TODO(kradalby): there might be more interesting fields here that we have not found yet. + claims.Email = cmp.Or(userinfo2.Email, claims.Email) + claims.EmailVerified = cmp.Or(userinfo2.EmailVerified, claims.EmailVerified) + claims.Username = cmp.Or(userinfo2.PreferredUsername, claims.Username) + claims.Name = cmp.Or(userinfo2.Name, claims.Name) + claims.ProfilePictureURL = cmp.Or(userinfo2.Picture, claims.ProfilePictureURL) + if userinfo2.Groups != nil { + claims.Groups = userinfo2.Groups + } + } else { + util.LogErr(err, "could not get userinfo; only using claims from id token") + } + + // The user claims are now updated from the userinfo endpoint so we can verify the user + // against allowed emails, email domains, and groups. + err = doOIDCAuthorization(a.cfg, &claims) + if err != nil { + httpError(writer, err) + return + } + + user, _, err := a.createOrUpdateUserFromClaim(&claims) + if err != nil { + log.Error(). + Err(err). + Caller(). 
+ Msgf("could not create or update user") writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Wrong params")) - if err != nil { - util.LogErr(err, "Failed to write response") + writer.WriteHeader(http.StatusInternalServerError) + _, werr := writer.Write([]byte("Could not create or update user")) + if werr != nil { + log.Error(). + Caller(). + Err(werr). + Msg("Failed to write HTTP response") } return } - randomBlob := make([]byte, randomByteSize) - if _, err := rand.Read(randomBlob); err != nil { - util.LogErr(err, "could not read 16 bytes from rand") + // TODO(kradalby): Is this comment right? + // If the node exists, then the node should be reauthenticated, + // if the node does not exist, and the machine key exists, then + // this is a new node that should be registered. + registrationId := a.getRegistrationIDFromState(state) - http.Error(writer, "Internal server error", http.StatusInternalServerError) + // Register the node if it does not exist. + if registrationId != nil { + verb := "Reauthenticated" + newNode, err := a.handleRegistration(user, *registrationId, nodeExpiry) + if err != nil { + if errors.Is(err, db.ErrNodeNotFoundRegistrationCache) { + log.Debug().Caller().Str("registration_id", registrationId.String()).Msg("registration session expired before authorization completed") + httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", err)) + + return + } + httpError(writer, err) + return + } + + if newNode { + verb = "Authenticated" + } + + // TODO(kradalby): replace with go-elem + content, err := renderOIDCCallbackTemplate(user, verb) + if err != nil { + httpError(writer, err) + return + } + + writer.Header().Set("Content-Type", "text/html; charset=utf-8") + writer.WriteHeader(http.StatusOK) + if _, err := writer.Write(content.Bytes()); err != nil { + util.LogErr(err, "Failed to write HTTP response") + } return } - stateStr := hex.EncodeToString(randomBlob)[:32] - - // place the node key into the state cache, so it can be retrieved later - h.registrationCache.Set( - stateStr, - machineKey, - registerCacheExpiration, - ) - - // Add any extra parameter provided in the configuration to the Authorize Endpoint request - extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)) - - for k, v := range h.cfg.OIDC.ExtraParams { - extras = append(extras, oauth2.SetAuthURLParam(k, v)) - } - - authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...) - log.Debug().Msgf("Redirecting to %s for authentication", authURL) - - http.Redirect(writer, req, authURL, http.StatusFound) + // Neither node nor machine key was found in the state cache meaning + // that we could not reauth nor register the node. 
+ httpError(writer, NewHTTPError(http.StatusGone, "login session expired, try again", nil)) } -type oidcCallbackTemplateConfig struct { - User string - Verb string +func (a *AuthProviderOIDC) determineNodeExpiry(idTokenExpiration time.Time) time.Time { + if a.cfg.UseExpiryFromToken { + return idTokenExpiration + } + + return time.Now().Add(a.cfg.Expiry) } -//go:embed assets/oidc_callback_template.html -var oidcCallbackTemplateContent string - -var oidcCallbackTemplate = template.Must( - template.New("oidccallback").Parse(oidcCallbackTemplateContent), -) - -// OIDCCallback handles the callback from the OIDC endpoint -// Retrieves the nkey from the state cache and adds the node to the users email user -// TODO: A confirmation page for new nodes should be added to avoid phishing vulnerabilities -// TODO: Add groups information from OIDC tokens into node HostInfo -// Listens in /oidc/callback. -func (h *Headscale) OIDCCallback( - writer http.ResponseWriter, - req *http.Request, -) { - code, state, err := validateOIDCCallbackParams(writer, req) - if err != nil { - return - } - - rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state) - if err != nil { - return - } - - idToken, err := h.verifyIDTokenForOIDCCallback(req.Context(), writer, rawIDToken) - if err != nil { - return - } - idTokenExpiry := h.determineTokenExpiration(idToken.Expiry) - - // TODO: we can use userinfo at some point to grab additional information about the user (groups membership, etc) - // userInfo, err := oidcProvider.UserInfo(context.Background(), oauth2.StaticTokenSource(oauth2Token)) - // if err != nil { - // c.String(http.StatusBadRequest, fmt.Sprintf("Failed to retrieve userinfo")) - // return - // } - - claims, err := extractIDTokenClaims(writer, idToken) - if err != nil { - return - } - - if err := validateOIDCAllowedDomains(writer, h.cfg.OIDC.AllowedDomains, claims); err != nil { - return - } - - if err := validateOIDCAllowedGroups(writer, h.cfg.OIDC.AllowedGroups, claims); err != nil { - return - } - - if err := validateOIDCAllowedUsers(writer, h.cfg.OIDC.AllowedUsers, claims); err != nil { - return - } - - machineKey, nodeExists, err := h.validateNodeForOIDCCallback( - writer, - state, - claims, - idTokenExpiry, - ) - if err != nil || nodeExists { - return - } - - userName, err := getUserName(writer, claims, h.cfg.OIDC.StripEmaildomain) - if err != nil { - return - } - - // register the node if it's new - log.Debug().Msg("Registering new node after successful callback") - - user, err := h.findOrCreateNewUserForOIDCCallback(writer, userName) - if err != nil { - return - } - - if err := h.registerNodeForOIDCCallback(writer, user, machineKey, idTokenExpiry); err != nil { - return - } - - content, err := renderOIDCCallbackTemplate(writer, claims) - if err != nil { - return - } - - writer.Header().Set("Content-Type", "text/html; charset=utf-8") - writer.WriteHeader(http.StatusOK) - if _, err := writer.Write(content.Bytes()); err != nil { - util.LogErr(err, "Failed to write response") - } -} - -func validateOIDCCallbackParams( - writer http.ResponseWriter, +func extractCodeAndStateParamFromRequest( req *http.Request, ) (string, string, error) { code := req.URL.Query().Get("code") state := req.URL.Query().Get("state") if code == "" || state == "" { - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Wrong params")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return "", "", 
errEmptyOIDCCallbackParams + return "", "", NewHTTPError(http.StatusBadRequest, "missing code or state parameter", errEmptyOIDCCallbackParams) } return code, state, nil } -func (h *Headscale) getIDTokenForOIDCCallback( +// getOauth2Token exchanges the code from the callback for an oauth2 token. +func (a *AuthProviderOIDC) getOauth2Token( ctx context.Context, - writer http.ResponseWriter, - code, state string, -) (string, error) { - oauth2Token, err := h.oauth2Config.Exchange(ctx, code) + code string, + state string, +) (*oauth2.Token, error) { + var exchangeOpts []oauth2.AuthCodeOption + + if a.cfg.PKCE.Enabled { + regInfo, ok := a.registrationCache.Get(state) + if !ok { + return nil, NewHTTPError(http.StatusNotFound, "registration not found", errNoOIDCRegistrationInfo) + } + if regInfo.Verifier != nil { + exchangeOpts = []oauth2.AuthCodeOption{oauth2.VerifierOption(*regInfo.Verifier)} + } + } + + oauth2Token, err := a.oauth2Config.Exchange(ctx, code, exchangeOpts...) if err != nil { - util.LogErr(err, "Could not exchange code for token") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, werr := writer.Write([]byte("Could not exchange code for token")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } - - return "", err + return nil, NewHTTPError(http.StatusForbidden, "invalid code", fmt.Errorf("could not exchange code for token: %w", err)) } - log.Trace(). - Caller(). - Str("code", code). - Str("state", state). - Msg("Got oidc callback") - - rawIDToken, rawIDTokenOK := oauth2Token.Extra("id_token").(string) - if !rawIDTokenOK { - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("Could not extract ID Token")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return "", errNoOIDCIDToken - } - - return rawIDToken, nil + return oauth2Token, err } -func (h *Headscale) verifyIDTokenForOIDCCallback( +// extractIDToken extracts the ID token from the oauth2 token. 
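+// The raw token is carried in the "id_token" extra field of the oauth2 token;
+// a missing field is reported as 400 Bad Request and a failed verification as
+// 403 Forbidden. Verification uses the provider's published signing keys with
+// the configured ClientID as the expected audience.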
+func (a *AuthProviderOIDC) extractIDToken( ctx context.Context, - writer http.ResponseWriter, - rawIDToken string, + oauth2Token *oauth2.Token, ) (*oidc.IDToken, error) { - verifier := h.oidcProvider.Verifier(&oidc.Config{ClientID: h.cfg.OIDC.ClientID}) + rawIDToken, ok := oauth2Token.Extra("id_token").(string) + if !ok { + return nil, NewHTTPError(http.StatusBadRequest, "no id_token", errNoOIDCIDToken) + } + + verifier := a.oidcProvider.Verifier(&oidc.Config{ClientID: a.cfg.ClientID}) idToken, err := verifier.Verify(ctx, rawIDToken) if err != nil { - util.LogErr(err, "failed to verify id token") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, werr := writer.Write([]byte("Failed to verify id token")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, err + return nil, NewHTTPError(http.StatusForbidden, "failed to verify id_token", fmt.Errorf("failed to verify ID token: %w", err)) } return idToken, nil } -func extractIDTokenClaims( - writer http.ResponseWriter, - idToken *oidc.IDToken, -) (*IDTokenClaims, error) { - var claims IDTokenClaims - if err := idToken.Claims(&claims); err != nil { - util.LogErr(err, "Failed to decode id token claims") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, werr := writer.Write([]byte("Failed to decode id token claims")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, err - } - - return &claims, nil -} - // validateOIDCAllowedDomains checks that if AllowedDomains is provided, // that the authenticated principal ends with @. func validateOIDCAllowedDomains( - writer http.ResponseWriter, allowedDomains []string, - claims *IDTokenClaims, + claims *types.OIDCClaims, ) error { if len(allowedDomains) > 0 { if at := strings.LastIndex(claims.Email, "@"); at < 0 || - !util.IsStringInSlice(allowedDomains, claims.Email[at+1:]) { - log.Trace().Msg("authenticated principal does not match any allowed domain") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("unauthorized principal (domain mismatch)")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return errOIDCAllowedDomains + !slices.Contains(allowedDomains, claims.Email[at+1:]) { + return NewHTTPError(http.StatusUnauthorized, "unauthorised domain", errOIDCAllowedDomains) } } @@ -391,274 +423,197 @@ func validateOIDCAllowedDomains( // claims.Groups can be populated by adding a client scope named // 'groups' that contains group membership. 
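// For example (an illustrative configuration), with
//
//	oidc:
//	  allowed_groups:
//	    - headscale-users
//
// only principals whose groups claim contains "headscale-users" pass the check
// below; how that claim is populated is up to the identity provider.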
func validateOIDCAllowedGroups( - writer http.ResponseWriter, allowedGroups []string, - claims *IDTokenClaims, + claims *types.OIDCClaims, ) error { - if len(allowedGroups) > 0 { - for _, group := range allowedGroups { - if util.IsStringInSlice(claims.Groups, group) { - return nil - } + for _, group := range allowedGroups { + if slices.Contains(claims.Groups, group) { + return nil } - - log.Trace().Msg("authenticated principal not in any allowed groups") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("unauthorized principal (allowed groups)")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return errOIDCAllowedGroups } - return nil + return NewHTTPError(http.StatusUnauthorized, "unauthorised group", errOIDCAllowedGroups) } // validateOIDCAllowedUsers checks that if AllowedUsers is provided, // that the authenticated principal is part of that list. func validateOIDCAllowedUsers( - writer http.ResponseWriter, allowedUsers []string, - claims *IDTokenClaims, + claims *types.OIDCClaims, ) error { - if len(allowedUsers) > 0 && - !util.IsStringInSlice(allowedUsers, claims.Email) { - log.Trace().Msg("authenticated principal does not match any allowed user") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("unauthorized principal (user mismatch)")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return errOIDCAllowedUsers + if !slices.Contains(allowedUsers, claims.Email) { + return NewHTTPError(http.StatusUnauthorized, "unauthorised user", errOIDCAllowedUsers) } return nil } -// validateNode retrieves node information if it exist -// The error is not important, because if it does not -// exist, then this is a new node and we will move -// on to registration. -func (h *Headscale) validateNodeForOIDCCallback( - writer http.ResponseWriter, - state string, - claims *IDTokenClaims, - expiry time.Time, -) (*key.MachinePublic, bool, error) { - // retrieve nodekey from state cache - machineKeyIf, machineKeyFound := h.registrationCache.Get(state) - if !machineKeyFound { - log.Trace(). - Msg("requested node state key expired before authorisation completed") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("state has expired")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, false, errOIDCNodeKeyMissing - } - - var machineKey key.MachinePublic - machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic) - if !machineKeyOK { - log.Trace(). - Interface("got", machineKeyIf). - Msg("requested node state key is not a nodekey") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write([]byte("state is invalid")) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, false, errOIDCInvalidNodeState - } - - // retrieve node information if it exist - // The error is not important, because if it does not - // exist, then this is a new node and we will move - // on to registration. - node, _ := h.db.GetNodeByMachineKey(machineKey) - - if node != nil { - log.Trace(). - Caller(). - Str("node", node.Hostname). 
- Msg("node already registered, reauthenticating") - - err := h.db.NodeSetExpiry(node, expiry) - if err != nil { - util.LogErr(err, "Failed to refresh node") - http.Error( - writer, - "Failed to refresh node", - http.StatusInternalServerError, - ) - - return nil, true, err - } - log.Debug(). - Str("node", node.Hostname). - Str("expiresAt", fmt.Sprintf("%v", expiry)). - Msg("successfully refreshed node") - - var content bytes.Buffer - if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ - User: claims.Email, - Verb: "Reauthenticated", - }); err != nil { - log.Error(). - Str("func", "OIDCCallback"). - Str("type", "reauthenticate"). - Err(err). - Msg("Could not render OIDC callback template") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, werr := writer.Write([]byte("Could not render OIDC callback template")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, true, err - } - - writer.Header().Set("Content-Type", "text/html; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(content.Bytes()) - if err != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, true, nil - } - - return &machineKey, false, nil -} - -func getUserName( - writer http.ResponseWriter, - claims *IDTokenClaims, - stripEmaildomain bool, -) (string, error) { - userName, err := util.NormalizeToFQDNRules( - claims.Email, - stripEmaildomain, - ) - if err != nil { - util.LogErr(err, "couldn't normalize email") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, werr := writer.Write([]byte("couldn't normalize email")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } - - return "", err - } - - return userName, nil -} - -func (h *Headscale) findOrCreateNewUserForOIDCCallback( - writer http.ResponseWriter, - userName string, -) (*types.User, error) { - user, err := h.db.GetUser(userName) - if errors.Is(err, db.ErrUserNotFound) { - user, err = h.db.CreateUser(userName) - if err != nil { - log.Error(). - Err(err). - Caller(). - Msgf("could not create new user '%s'", userName) - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, werr := writer.Write([]byte("could not create user")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, err - } - } else if err != nil { - log.Error(). - Caller(). - Err(err). - Str("user", userName). - Msg("could not find or create user") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, werr := writer.Write([]byte("could not find or create user")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } - - return nil, err - } - - return user, nil -} - -func (h *Headscale) registerNodeForOIDCCallback( - writer http.ResponseWriter, - user *types.User, - machineKey *key.MachinePublic, - expiry time.Time, +// doOIDCAuthorization applies authorization tests to claims. 
+// +// The following tests are always applied: +// +// - validateOIDCAllowedGroups +// +// The following tests are applied if cfg.EmailVerifiedRequired=false +// or claims.email_verified=true: +// +// - validateOIDCAllowedDomains +// - validateOIDCAllowedUsers +// +// NOTE that, contrary to the function name, validateOIDCAllowedUsers +// only checks the email address -- not the username. +func doOIDCAuthorization( + cfg *types.OIDCConfig, + claims *types.OIDCClaims, ) error { - if _, err := h.db.RegisterNodeFromAuthCallback( - // TODO(kradalby): find a better way to use the cache across modules - h.registrationCache, - *machineKey, - user.Name, + if len(cfg.AllowedGroups) > 0 { + err := validateOIDCAllowedGroups(cfg.AllowedGroups, claims) + if err != nil { + return err + } + } + + trustEmail := !cfg.EmailVerifiedRequired || bool(claims.EmailVerified) + + hasEmailTests := len(cfg.AllowedDomains) > 0 || len(cfg.AllowedUsers) > 0 + if !trustEmail && hasEmailTests { + return NewHTTPError(http.StatusUnauthorized, "unverified email", errOIDCUnverifiedEmail) + } + + if len(cfg.AllowedDomains) > 0 { + err := validateOIDCAllowedDomains(cfg.AllowedDomains, claims) + if err != nil { + return err + } + } + + if len(cfg.AllowedUsers) > 0 { + err := validateOIDCAllowedUsers(cfg.AllowedUsers, claims) + if err != nil { + return err + } + } + + return nil +} + +// getRegistrationIDFromState retrieves the registration ID from the state. +func (a *AuthProviderOIDC) getRegistrationIDFromState(state string) *types.RegistrationID { + regInfo, ok := a.registrationCache.Get(state) + if !ok { + return nil + } + + return ®Info.RegistrationID +} + +func (a *AuthProviderOIDC) createOrUpdateUserFromClaim( + claims *types.OIDCClaims, +) (*types.User, change.Change, error) { + var ( + user *types.User + err error + newUser bool + c change.Change + ) + user, err = a.h.state.GetUserByOIDCIdentifier(claims.Identifier()) + if err != nil && !errors.Is(err, db.ErrUserNotFound) { + return nil, change.Change{}, fmt.Errorf("creating or updating user: %w", err) + } + + // if the user is still not found, create a new empty user. + // TODO(kradalby): This context is not inherited from the request, which is probably not ideal. + // However, we need a context to use the OIDC provider. 
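+	// A brand-new identity takes the CreateUser path below, while a user that
+	// was found via GetUserByOIDCIdentifier above is refreshed in place with
+	// UpdateUser, so name, email and picture changes at the provider carry
+	// over on the next login.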
+ if user == nil { + newUser = true + user = &types.User{} + } + + user.FromClaim(claims, a.cfg.EmailVerifiedRequired) + + if newUser { + user, c, err = a.h.state.CreateUser(*user) + if err != nil { + return nil, change.Change{}, fmt.Errorf("creating user: %w", err) + } + } else { + _, c, err = a.h.state.UpdateUser(types.UserID(user.ID), func(u *types.User) error { + *u = *user + return nil + }) + if err != nil { + return nil, change.Change{}, fmt.Errorf("updating user: %w", err) + } + } + + return user, c, nil +} + +func (a *AuthProviderOIDC) handleRegistration( + user *types.User, + registrationID types.RegistrationID, + expiry time.Time, +) (bool, error) { + node, nodeChange, err := a.h.state.HandleNodeFromAuthPath( + registrationID, + types.UserID(user.ID), &expiry, util.RegisterMethodOIDC, - ); err != nil { - util.LogErr(err, "could not register node") - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, werr := writer.Write([]byte("could not register node")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } - - return err + ) + if err != nil { + return false, fmt.Errorf("could not register node: %w", err) } - return nil + // This is a bit of a back and forth, but we have a bit of a chicken and egg + // dependency here. + // Because the way the policy manager works, we need to have the node + // in the database, then add it to the policy manager and then we can + // approve the route. This means we get this dance where the node is + // first added to the database, then we add it to the policy manager via + // SaveNode (which automatically updates the policy manager) and then we can auto approve the routes. + // As that only approves the struct object, we need to save it again and + // ensure we send an update. + // This works, but might be another good candidate for doing some sort of + // eventbus. + routesChange, err := a.h.state.AutoApproveRoutes(node) + if err != nil { + return false, fmt.Errorf("auto approving routes: %w", err) + } + + // Send both changes. Empty changes are ignored by Change(). + a.h.Change(nodeChange, routesChange) + + return !nodeChange.IsEmpty(), nil } func renderOIDCCallbackTemplate( - writer http.ResponseWriter, - claims *IDTokenClaims, + user *types.User, + verb string, ) (*bytes.Buffer, error) { - var content bytes.Buffer - if err := oidcCallbackTemplate.Execute(&content, oidcCallbackTemplateConfig{ - User: claims.Email, - Verb: "Authenticated", - }); err != nil { - log.Error(). - Str("func", "OIDCCallback"). - Str("type", "authenticate"). - Err(err). - Msg("Could not render OIDC callback template") + html := templates.OIDCCallback(user.Display(), verb).Render() + return bytes.NewBufferString(html), nil +} - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, werr := writer.Write([]byte("Could not render OIDC callback template")) - if werr != nil { - util.LogErr(err, "Failed to write response") - } +// getCookieName generates a unique cookie name based on a cookie value. 
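+//
+// For example (hypothetical value):
+//
+//	getCookieName("state", "q1w2e3r4t5...") // -> "state_q1w2e3"
+//
+// Deriving the name from the first characters of the value lets concurrent
+// login attempts keep separate state and nonce cookies instead of overwriting
+// a single shared cookie.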
+func getCookieName(baseName, value string) string { + return fmt.Sprintf("%s_%s", baseName, value[:6]) +} - return nil, err +func setCSRFCookie(w http.ResponseWriter, r *http.Request, name string) (string, error) { + val, err := util.GenerateRandomStringURLSafe(64) + if err != nil { + return val, err } - return &content, nil + c := &http.Cookie{ + Path: "/oidc/callback", + Name: getCookieName(name, val), + Value: val, + MaxAge: int(time.Hour.Seconds()), + Secure: r.TLS != nil, + HttpOnly: true, + } + http.SetCookie(w, c) + + return val, nil } diff --git a/hscontrol/oidc_template_test.go b/hscontrol/oidc_template_test.go new file mode 100644 index 00000000..367451b1 --- /dev/null +++ b/hscontrol/oidc_template_test.go @@ -0,0 +1,51 @@ +package hscontrol + +import ( + "testing" + + "github.com/juanfont/headscale/hscontrol/templates" + "github.com/stretchr/testify/assert" +) + +func TestOIDCCallbackTemplate(t *testing.T) { + tests := []struct { + name string + userName string + verb string + }{ + { + name: "logged_in_user", + userName: "test@example.com", + verb: "Logged in", + }, + { + name: "registered_user", + userName: "newuser@example.com", + verb: "Registered", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Render using the elem-go template + html := templates.OIDCCallback(tt.userName, tt.verb).Render() + + // Verify the HTML contains expected elements + assert.Contains(t, html, "") + assert.Contains(t, html, "Headscale Authentication Succeeded") + assert.Contains(t, html, tt.verb) + assert.Contains(t, html, tt.userName) + assert.Contains(t, html, "You can now close this window") + + // Verify Material for MkDocs design system CSS is present + assert.Contains(t, html, "Material for MkDocs") + assert.Contains(t, html, "Roboto") + assert.Contains(t, html, ".md-typeset") + + // Verify SVG elements are present + assert.Contains(t, html, " want=%v | got=%v", tC.name, tC.wantErr, err) + } + }) + } +} diff --git a/hscontrol/platform_config.go b/hscontrol/platform_config.go index 0404f546..23c4d25d 100644 --- a/hscontrol/platform_config.go +++ b/hscontrol/platform_config.go @@ -9,94 +9,17 @@ import ( "github.com/gofrs/uuid/v5" "github.com/gorilla/mux" - "github.com/rs/zerolog/log" + "github.com/juanfont/headscale/hscontrol/templates" ) -//go:embed templates/apple.html -var appleTemplate string - -//go:embed templates/windows.html -var windowsTemplate string - // WindowsConfigMessage shows a simple message in the browser for how to configure the Windows Tailscale client. func (h *Headscale) WindowsConfigMessage( writer http.ResponseWriter, req *http.Request, ) { - winTemplate := template.Must(template.New("windows").Parse(windowsTemplate)) - config := map[string]interface{}{ - "URL": h.cfg.ServerURL, - } - - var payload bytes.Buffer - if err := winTemplate.Execute(&payload, config); err != nil { - log.Error(). - Str("handler", "WindowsRegConfig"). - Err(err). - Msg("Could not render Windows index template") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Could not render Windows index template")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - return - } - writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write(payload.Bytes()) - if err != nil { - log.Error(). - Caller(). - Err(err). 
- Msg("Failed to write response") - } -} - -// WindowsRegConfig generates and serves a .reg file configured with the Headscale server address. -func (h *Headscale) WindowsRegConfig( - writer http.ResponseWriter, - req *http.Request, -) { - config := WindowsRegistryConfig{ - URL: h.cfg.ServerURL, - } - - var content bytes.Buffer - if err := windowsRegTemplate.Execute(&content, config); err != nil { - log.Error(). - Str("handler", "WindowsRegConfig"). - Err(err). - Msg("Could not render Apple macOS template") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Could not render Windows registry template")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - return - } - - writer.Header().Set("Content-Type", "text/x-ms-regedit; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err := writer.Write(content.Bytes()) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } + writer.Write([]byte(templates.Windows(h.cfg.ServerURL).Render())) } // AppleConfigMessage shows a simple message in the browser to point the user to the iOS/MacOS profile and instructions for how to install it. @@ -104,41 +27,9 @@ func (h *Headscale) AppleConfigMessage( writer http.ResponseWriter, req *http.Request, ) { - appleTemplate := template.Must(template.New("apple").Parse(appleTemplate)) - - config := map[string]interface{}{ - "URL": h.cfg.ServerURL, - } - - var payload bytes.Buffer - if err := appleTemplate.Execute(&payload, config); err != nil { - log.Error(). - Str("handler", "AppleMobileConfig"). - Err(err). - Msg("Could not render Apple index template") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Could not render Apple index template")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - - return - } - writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err := writer.Write(payload.Bytes()) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } + writer.Write([]byte(templates.Apple(h.cfg.ServerURL).Render())) } func (h *Headscale) ApplePlatformConfig( @@ -148,51 +39,19 @@ func (h *Headscale) ApplePlatformConfig( vars := mux.Vars(req) platform, ok := vars["platform"] if !ok { - log.Error(). - Str("handler", "ApplePlatformConfig"). - Msg("No platform specified") - http.Error(writer, "No platform specified", http.StatusBadRequest) - + httpError(writer, NewHTTPError(http.StatusBadRequest, "no platform specified", nil)) return } id, err := uuid.NewV4() if err != nil { - log.Error(). - Str("handler", "ApplePlatformConfig"). - Err(err). - Msg("Failed not create UUID") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Failed to create UUID")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + httpError(writer, err) return } contentID, err := uuid.NewV4() if err != nil { - log.Error(). - Str("handler", "ApplePlatformConfig"). - Err(err). 
- Msg("Failed not create UUID") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Failed to create content UUID")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + httpError(writer, err) return } @@ -202,68 +61,25 @@ func (h *Headscale) ApplePlatformConfig( } var payload bytes.Buffer - handleMacError := func(ierr error) { - log.Error(). - Str("handler", "ApplePlatformConfig"). - Err(ierr). - Msg("Could not render Apple macOS template") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Could not render Apple macOS template")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - } switch platform { case "macos-standalone": if err := macosStandaloneTemplate.Execute(&payload, platformConfig); err != nil { - handleMacError(err) - + httpError(writer, err) return } case "macos-app-store": if err := macosAppStoreTemplate.Execute(&payload, platformConfig); err != nil { - handleMacError(err) - + httpError(writer, err) return } case "ios": if err := iosTemplate.Execute(&payload, platformConfig); err != nil { - log.Error(). - Str("handler", "ApplePlatformConfig"). - Err(err). - Msg("Could not render Apple iOS template") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Could not render Apple iOS template")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + httpError(writer, err) return } default: - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusBadRequest) - _, err := writer.Write( - []byte("Invalid platform. Only ios, macos-app-store and macos-standalone are supported"), - ) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + httpError(writer, NewHTTPError(http.StatusBadRequest, "platform must be ios, macos-app-store or macos-standalone", nil)) return } @@ -275,38 +91,14 @@ func (h *Headscale) ApplePlatformConfig( var content bytes.Buffer if err := commonTemplate.Execute(&content, config); err != nil { - log.Error(). - Str("handler", "ApplePlatformConfig"). - Err(err). - Msg("Could not render Apple platform template") - - writer.Header().Set("Content-Type", "text/plain; charset=utf-8") - writer.WriteHeader(http.StatusInternalServerError) - _, err := writer.Write([]byte("Could not render Apple platform template")) - if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") - } - + httpError(writer, err) return } writer.Header(). Set("Content-Type", "application/x-apple-aspen-config; charset=utf-8") writer.WriteHeader(http.StatusOK) - _, err = writer.Write(content.Bytes()) - if err != nil { - log.Error(). - Caller(). - Err(err). 
- Msg("Failed to write response") - } -} - -type WindowsRegistryConfig struct { - URL string + writer.Write(content.Bytes()) } type AppleMobileConfig struct { @@ -320,14 +112,6 @@ type AppleMobilePlatformConfig struct { URL string } -var windowsRegTemplate = textTemplate.Must( - textTemplate.New("windowsconfig").Parse(`Windows Registry Editor Version 5.00 - -[HKEY_LOCAL_MACHINE\SOFTWARE\Tailscale IPN] -"UnattendedMode"="always" -"LoginURL"="{{.URL}}" -`)) - var commonTemplate = textTemplate.Must( textTemplate.New("mobileconfig").Parse(` diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go deleted file mode 100644 index 11f280ac..00000000 --- a/hscontrol/policy/acls.go +++ /dev/null @@ -1,950 +0,0 @@ -package policy - -import ( - "encoding/json" - "errors" - "fmt" - "io" - "net/netip" - "os" - "path/filepath" - "strconv" - "strings" - "time" - - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" - "github.com/tailscale/hujson" - "go4.org/netipx" - "gopkg.in/yaml.v3" - "tailscale.com/tailcfg" -) - -var ( - ErrEmptyPolicy = errors.New("empty policy") - ErrInvalidAction = errors.New("invalid action") - ErrInvalidGroup = errors.New("invalid group") - ErrInvalidTag = errors.New("invalid tag") - ErrInvalidPortFormat = errors.New("invalid port format") - ErrWildcardIsNeeded = errors.New("wildcard as port is required for the protocol") -) - -const ( - portRangeBegin = 0 - portRangeEnd = 65535 - expectedTokenItems = 2 -) - -// For some reason golang.org/x/net/internal/iana is an internal package. -const ( - protocolICMP = 1 // Internet Control Message - protocolIGMP = 2 // Internet Group Management - protocolIPv4 = 4 // IPv4 encapsulation - protocolTCP = 6 // Transmission Control - protocolEGP = 8 // Exterior Gateway Protocol - protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP) - protocolUDP = 17 // User Datagram - protocolGRE = 47 // Generic Routing Encapsulation - protocolESP = 50 // Encap Security Payload - protocolAH = 51 // Authentication Header - protocolIPv6ICMP = 58 // ICMP for IPv6 - protocolSCTP = 132 // Stream Control Transmission Protocol - ProtocolFC = 133 // Fibre Channel -) - -// LoadACLPolicyFromPath loads the ACL policy from the specify path, and generates the ACL rules. -func LoadACLPolicyFromPath(path string) (*ACLPolicy, error) { - log.Debug(). - Str("func", "LoadACLPolicy"). - Str("path", path). - Msg("Loading ACL policy from path") - - policyFile, err := os.Open(path) - if err != nil { - return nil, err - } - defer policyFile.Close() - - policyBytes, err := io.ReadAll(policyFile) - if err != nil { - return nil, err - } - - log.Debug(). - Str("path", path). - Bytes("file", policyBytes). 
- Msg("Loading ACLs") - - switch filepath.Ext(path) { - case ".yml", ".yaml": - return LoadACLPolicyFromBytes(policyBytes, "yaml") - } - - return LoadACLPolicyFromBytes(policyBytes, "hujson") -} - -func LoadACLPolicyFromBytes(acl []byte, format string) (*ACLPolicy, error) { - var policy ACLPolicy - switch format { - case "yaml": - err := yaml.Unmarshal(acl, &policy) - if err != nil { - return nil, err - } - - default: - ast, err := hujson.Parse(acl) - if err != nil { - return nil, err - } - - ast.Standardize() - acl = ast.Pack() - err = json.Unmarshal(acl, &policy) - if err != nil { - return nil, err - } - } - - if policy.IsZero() { - return nil, ErrEmptyPolicy - } - - return &policy, nil -} - -func GenerateFilterAndSSHRules( - policy *ACLPolicy, - node *types.Node, - peers types.Nodes, -) ([]tailcfg.FilterRule, *tailcfg.SSHPolicy, error) { - // If there is no policy defined, we default to allow all - if policy == nil { - return tailcfg.FilterAllowAll, &tailcfg.SSHPolicy{}, nil - } - - rules, err := policy.generateFilterRules(node, peers) - if err != nil { - return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err - } - - log.Trace().Interface("ACL", rules).Str("node", node.GivenName).Msg("ACL rules") - - var sshPolicy *tailcfg.SSHPolicy - sshRules, err := policy.generateSSHRules(node, peers) - if err != nil { - return []tailcfg.FilterRule{}, &tailcfg.SSHPolicy{}, err - } - - log.Trace(). - Interface("SSH", sshRules). - Str("node", node.GivenName). - Msg("SSH rules") - - if sshPolicy == nil { - sshPolicy = &tailcfg.SSHPolicy{} - } - sshPolicy.Rules = sshRules - - return rules, sshPolicy, nil -} - -// generateFilterRules takes a set of nodes and an ACLPolicy and generates a -// set of Tailscale compatible FilterRules used to allow traffic on clients. -func (pol *ACLPolicy) generateFilterRules( - node *types.Node, - peers types.Nodes, -) ([]tailcfg.FilterRule, error) { - rules := []tailcfg.FilterRule{} - nodes := append(peers, node) - - for index, acl := range pol.ACLs { - if acl.Action != "accept" { - return nil, ErrInvalidAction - } - - srcIPs := []string{} - for srcIndex, src := range acl.Sources { - srcs, err := pol.expandSource(src, nodes) - if err != nil { - log.Error(). - Interface("src", src). - Int("ACL index", index). - Int("Src index", srcIndex). - Msgf("Error parsing ACL") - - return nil, err - } - srcIPs = append(srcIPs, srcs...) - } - - protocols, isWildcard, err := parseProtocol(acl.Protocol) - if err != nil { - log.Error(). - Msgf("Error parsing ACL %d. protocol unknown %s", index, acl.Protocol) - - return nil, err - } - - destPorts := []tailcfg.NetPortRange{} - for _, dest := range acl.Destinations { - alias, port, err := parseDestination(dest) - if err != nil { - return nil, err - } - - expanded, err := pol.ExpandAlias( - nodes, - alias, - ) - if err != nil { - return nil, err - } - - ports, err := expandPorts(port, isWildcard) - if err != nil { - return nil, err - } - - dests := []tailcfg.NetPortRange{} - for _, dest := range expanded.Prefixes() { - for _, port := range *ports { - pr := tailcfg.NetPortRange{ - IP: dest.String(), - Ports: port, - } - dests = append(dests, pr) - } - } - destPorts = append(destPorts, dests...) - } - - rules = append(rules, tailcfg.FilterRule{ - SrcIPs: srcIPs, - DstPorts: destPorts, - IPProto: protocols, - }) - } - - return rules, nil -} - -// ReduceFilterRules takes a node and a set of rules and removes all rules and destinations -// that are not relevant to that particular node. 
-func ReduceFilterRules(node *types.Node, rules []tailcfg.FilterRule) []tailcfg.FilterRule { - ret := []tailcfg.FilterRule{} - - for _, rule := range rules { - // record if the rule is actually relevant for the given node. - dests := []tailcfg.NetPortRange{} - - for _, dest := range rule.DstPorts { - expanded, err := util.ParseIPSet(dest.IP, nil) - // Fail closed, if we cant parse it, then we should not allow - // access. - if err != nil { - continue - } - - if node.IPAddresses.InIPSet(expanded) { - dests = append(dests, dest) - } - } - - if len(dests) > 0 { - ret = append(ret, tailcfg.FilterRule{ - SrcIPs: rule.SrcIPs, - DstPorts: dests, - IPProto: rule.IPProto, - }) - } - } - - return ret -} - -func (pol *ACLPolicy) generateSSHRules( - node *types.Node, - peers types.Nodes, -) ([]*tailcfg.SSHRule, error) { - rules := []*tailcfg.SSHRule{} - - acceptAction := tailcfg.SSHAction{ - Message: "", - Reject: false, - Accept: true, - SessionDuration: 0, - AllowAgentForwarding: false, - HoldAndDelegate: "", - AllowLocalPortForwarding: true, - } - - rejectAction := tailcfg.SSHAction{ - Message: "", - Reject: true, - Accept: false, - SessionDuration: 0, - AllowAgentForwarding: false, - HoldAndDelegate: "", - AllowLocalPortForwarding: false, - } - - for index, sshACL := range pol.SSHs { - var dest netipx.IPSetBuilder - for _, src := range sshACL.Destinations { - expanded, err := pol.ExpandAlias(append(peers, node), src) - if err != nil { - return nil, err - } - dest.AddSet(expanded) - } - - destSet, err := dest.IPSet() - if err != nil { - return nil, err - } - - if !node.IPAddresses.InIPSet(destSet) { - continue - } - - action := rejectAction - switch sshACL.Action { - case "accept": - action = acceptAction - case "check": - checkAction, err := sshCheckAction(sshACL.CheckPeriod) - if err != nil { - log.Error(). - Msgf("Error parsing SSH %d, check action with unparsable duration '%s'", index, sshACL.CheckPeriod) - } else { - action = *checkAction - } - default: - log.Error(). - Msgf("Error parsing SSH %d, unknown action '%s', skipping", index, sshACL.Action) - - continue - } - - principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources)) - for innerIndex, rawSrc := range sshACL.Sources { - if isWildcard(rawSrc) { - principals = append(principals, &tailcfg.SSHPrincipal{ - Any: true, - }) - } else if isGroup(rawSrc) { - users, err := pol.expandUsersFromGroup(rawSrc) - if err != nil { - log.Error(). - Msgf("Error parsing SSH %d, Source %d", index, innerIndex) - - return nil, err - } - - for _, user := range users { - principals = append(principals, &tailcfg.SSHPrincipal{ - UserLogin: user, - }) - } - } else { - expandedSrcs, err := pol.ExpandAlias( - peers, - rawSrc, - ) - if err != nil { - log.Error(). 
- Msgf("Error parsing SSH %d, Source %d", index, innerIndex) - - return nil, err - } - for _, expandedSrc := range expandedSrcs.Prefixes() { - principals = append(principals, &tailcfg.SSHPrincipal{ - NodeIP: expandedSrc.Addr().String(), - }) - } - } - } - - userMap := make(map[string]string, len(sshACL.Users)) - for _, user := range sshACL.Users { - userMap[user] = "=" - } - rules = append(rules, &tailcfg.SSHRule{ - Principals: principals, - SSHUsers: userMap, - Action: &action, - }) - } - - return rules, nil -} - -func sshCheckAction(duration string) (*tailcfg.SSHAction, error) { - sessionLength, err := time.ParseDuration(duration) - if err != nil { - return nil, err - } - - return &tailcfg.SSHAction{ - Message: "", - Reject: false, - Accept: true, - SessionDuration: sessionLength, - AllowAgentForwarding: false, - HoldAndDelegate: "", - AllowLocalPortForwarding: true, - }, nil -} - -func parseDestination(dest string) (string, string, error) { - var tokens []string - - // Check if there is a IPv4/6:Port combination, IPv6 has more than - // three ":". - tokens = strings.Split(dest, ":") - if len(tokens) < expectedTokenItems || len(tokens) > 3 { - port := tokens[len(tokens)-1] - - maybeIPv6Str := strings.TrimSuffix(dest, ":"+port) - log.Trace().Str("maybeIPv6Str", maybeIPv6Str).Msg("") - - filteredMaybeIPv6Str := maybeIPv6Str - if strings.Contains(maybeIPv6Str, "/") { - networkParts := strings.Split(maybeIPv6Str, "/") - filteredMaybeIPv6Str = networkParts[0] - } - - if maybeIPv6, err := netip.ParseAddr(filteredMaybeIPv6Str); err != nil && !maybeIPv6.Is6() { - log.Trace().Err(err).Msg("trying to parse as IPv6") - - return "", "", fmt.Errorf( - "failed to parse destination, tokens %v: %w", - tokens, - ErrInvalidPortFormat, - ) - } else { - tokens = []string{maybeIPv6Str, port} - } - } - - var alias string - // We can have here stuff like: - // git-server:* - // 192.168.1.0/24:22 - // fd7a:115c:a1e0::2:22 - // fd7a:115c:a1e0::2/128:22 - // tag:montreal-webserver:80,443 - // tag:api-server:443 - // example-host-1:* - if len(tokens) == expectedTokenItems { - alias = tokens[0] - } else { - alias = fmt.Sprintf("%s:%s", tokens[0], tokens[1]) - } - - return alias, tokens[len(tokens)-1], nil -} - -// parseProtocol reads the proto field of the ACL and generates a list of -// protocols that will be allowed, following the IANA IP protocol number -// https://www.iana.org/assignments/protocol-numbers/protocol-numbers.xhtml -// -// If the ACL proto field is empty, it allows ICMPv4, ICMPv6, TCP, and UDP, -// as per Tailscale behaviour (see tailcfg.FilterRule). -// -// Also returns a boolean indicating if the protocol -// requires all the destinations to use wildcard as port number (only TCP, -// UDP and SCTP support specifying ports). 
-func parseProtocol(protocol string) ([]int, bool, error) { - switch protocol { - case "": - return nil, false, nil - case "igmp": - return []int{protocolIGMP}, true, nil - case "ipv4", "ip-in-ip": - return []int{protocolIPv4}, true, nil - case "tcp": - return []int{protocolTCP}, false, nil - case "egp": - return []int{protocolEGP}, true, nil - case "igp": - return []int{protocolIGP}, true, nil - case "udp": - return []int{protocolUDP}, false, nil - case "gre": - return []int{protocolGRE}, true, nil - case "esp": - return []int{protocolESP}, true, nil - case "ah": - return []int{protocolAH}, true, nil - case "sctp": - return []int{protocolSCTP}, false, nil - case "icmp": - return []int{protocolICMP, protocolIPv6ICMP}, true, nil - - default: - protocolNumber, err := strconv.Atoi(protocol) - if err != nil { - return nil, false, err - } - needsWildcard := protocolNumber != protocolTCP && - protocolNumber != protocolUDP && - protocolNumber != protocolSCTP - - return []int{protocolNumber}, needsWildcard, nil - } -} - -// expandSource returns a set of Source IPs that would be associated -// with the given src alias. -func (pol *ACLPolicy) expandSource( - src string, - nodes types.Nodes, -) ([]string, error) { - ipSet, err := pol.ExpandAlias(nodes, src) - if err != nil { - return []string{}, err - } - - prefixes := []string{} - - for _, prefix := range ipSet.Prefixes() { - prefixes = append(prefixes, prefix.String()) - } - - return prefixes, nil -} - -// expandalias has an input of either -// - a user -// - a group -// - a tag -// - a host -// - an ip -// - a cidr -// and transform these in IPAddresses. -func (pol *ACLPolicy) ExpandAlias( - nodes types.Nodes, - alias string, -) (*netipx.IPSet, error) { - if isWildcard(alias) { - return util.ParseIPSet("*", nil) - } - - build := netipx.IPSetBuilder{} - - log.Debug(). - Str("alias", alias). - Msg("Expanding") - - // if alias is a group - if isGroup(alias) { - return pol.expandIPsFromGroup(alias, nodes) - } - - // if alias is a tag - if isTag(alias) { - return pol.expandIPsFromTag(alias, nodes) - } - - // if alias is a user - if ips, err := pol.expandIPsFromUser(alias, nodes); ips != nil { - return ips, err - } - - // if alias is an host - // Note, this is recursive. - if h, ok := pol.Hosts[alias]; ok { - log.Trace().Str("host", h.String()).Msg("ExpandAlias got hosts entry") - - return pol.ExpandAlias(nodes, h.String()) - } - - // if alias is an IP - if ip, err := netip.ParseAddr(alias); err == nil { - return pol.expandIPsFromSingleIP(ip, nodes) - } - - // if alias is an IP Prefix (CIDR) - if prefix, err := netip.ParsePrefix(alias); err == nil { - return pol.expandIPsFromIPPrefix(prefix, nodes) - } - - log.Warn().Msgf("No IPs found with the alias %v", alias) - - return build.IPSet() -} - -// excludeCorrectlyTaggedNodes will remove from the list of input nodes the ones -// that are correctly tagged since they should not be listed as being in the user -// we assume in this function that we only have nodes from 1 user. -func excludeCorrectlyTaggedNodes( - aclPolicy *ACLPolicy, - nodes types.Nodes, - user string, -) types.Nodes { - out := types.Nodes{} - tags := []string{} - for tag := range aclPolicy.TagOwners { - owners, _ := expandOwnersFromTag(aclPolicy, user) - ns := append(owners, user) - if util.StringOrPrefixListContains(ns, user) { - tags = append(tags, tag) - } - } - // for each node if tag is in tags list, don't append it. 
- for _, node := range nodes { - found := false - - if node.Hostinfo == nil { - continue - } - - for _, t := range node.Hostinfo.RequestTags { - if util.StringOrPrefixListContains(tags, t) { - found = true - - break - } - } - if len(node.ForcedTags) > 0 { - found = true - } - if !found { - out = append(out, node) - } - } - - return out -} - -func expandPorts(portsStr string, isWild bool) (*[]tailcfg.PortRange, error) { - if isWildcard(portsStr) { - return &[]tailcfg.PortRange{ - {First: portRangeBegin, Last: portRangeEnd}, - }, nil - } - - if isWild { - return nil, ErrWildcardIsNeeded - } - - ports := []tailcfg.PortRange{} - for _, portStr := range strings.Split(portsStr, ",") { - log.Trace().Msgf("parsing portstring: %s", portStr) - rang := strings.Split(portStr, "-") - switch len(rang) { - case 1: - port, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) - if err != nil { - return nil, err - } - ports = append(ports, tailcfg.PortRange{ - First: uint16(port), - Last: uint16(port), - }) - - case expectedTokenItems: - start, err := strconv.ParseUint(rang[0], util.Base10, util.BitSize16) - if err != nil { - return nil, err - } - last, err := strconv.ParseUint(rang[1], util.Base10, util.BitSize16) - if err != nil { - return nil, err - } - ports = append(ports, tailcfg.PortRange{ - First: uint16(start), - Last: uint16(last), - }) - - default: - return nil, ErrInvalidPortFormat - } - } - - return &ports, nil -} - -// expandOwnersFromTag will return a list of user. An owner can be either a user or a group -// a group cannot be composed of groups. -func expandOwnersFromTag( - pol *ACLPolicy, - tag string, -) ([]string, error) { - var owners []string - ows, ok := pol.TagOwners[tag] - if !ok { - return []string{}, fmt.Errorf( - "%w. %v isn't owned by a TagOwner. Please add one first. https://tailscale.com/kb/1018/acls/#tag-owners", - ErrInvalidTag, - tag, - ) - } - for _, owner := range ows { - if isGroup(owner) { - gs, err := pol.expandUsersFromGroup(owner) - if err != nil { - return []string{}, err - } - owners = append(owners, gs...) - } else { - owners = append(owners, owner) - } - } - - return owners, nil -} - -// expandUsersFromGroup will return the list of user inside the group -// after some validation. -func (pol *ACLPolicy) expandUsersFromGroup( - group string, -) ([]string, error) { - users := []string{} - log.Trace().Caller().Interface("pol", pol).Msg("test") - aclGroups, ok := pol.Groups[group] - if !ok { - return []string{}, fmt.Errorf( - "group %v isn't registered. %w", - group, - ErrInvalidGroup, - ) - } - for _, group := range aclGroups { - if isGroup(group) { - return []string{}, fmt.Errorf( - "%w. A group cannot be composed of groups. 
https://tailscale.com/kb/1018/acls/#groups", - ErrInvalidGroup, - ) - } - grp, err := util.NormalizeToFQDNRulesConfigFromViper(group) - if err != nil { - return []string{}, fmt.Errorf( - "failed to normalize group %q, err: %w", - group, - ErrInvalidGroup, - ) - } - users = append(users, grp) - } - - return users, nil -} - -func (pol *ACLPolicy) expandIPsFromGroup( - group string, - nodes types.Nodes, -) (*netipx.IPSet, error) { - build := netipx.IPSetBuilder{} - - users, err := pol.expandUsersFromGroup(group) - if err != nil { - return &netipx.IPSet{}, err - } - for _, user := range users { - filteredNodes := filterNodesByUser(nodes, user) - for _, node := range filteredNodes { - node.IPAddresses.AppendToIPSet(&build) - } - } - - return build.IPSet() -} - -func (pol *ACLPolicy) expandIPsFromTag( - alias string, - nodes types.Nodes, -) (*netipx.IPSet, error) { - build := netipx.IPSetBuilder{} - - // check for forced tags - for _, node := range nodes { - if util.StringOrPrefixListContains(node.ForcedTags, alias) { - node.IPAddresses.AppendToIPSet(&build) - } - } - - // find tag owners - owners, err := expandOwnersFromTag(pol, alias) - if err != nil { - if errors.Is(err, ErrInvalidTag) { - ipSet, _ := build.IPSet() - if len(ipSet.Prefixes()) == 0 { - return ipSet, fmt.Errorf( - "%w. %v isn't owned by a TagOwner and no forced tags are defined", - ErrInvalidTag, - alias, - ) - } - - return build.IPSet() - } else { - return nil, err - } - } - - // filter out nodes per tag owner - for _, user := range owners { - nodes := filterNodesByUser(nodes, user) - for _, node := range nodes { - if node.Hostinfo == nil { - continue - } - - if util.StringOrPrefixListContains(node.Hostinfo.RequestTags, alias) { - node.IPAddresses.AppendToIPSet(&build) - } - } - } - - return build.IPSet() -} - -func (pol *ACLPolicy) expandIPsFromUser( - user string, - nodes types.Nodes, -) (*netipx.IPSet, error) { - build := netipx.IPSetBuilder{} - - filteredNodes := filterNodesByUser(nodes, user) - filteredNodes = excludeCorrectlyTaggedNodes(pol, filteredNodes, user) - - // shortcurcuit if we have no nodes to get ips from. - if len(filteredNodes) == 0 { - return nil, nil //nolint - } - - for _, node := range filteredNodes { - node.IPAddresses.AppendToIPSet(&build) - } - - return build.IPSet() -} - -func (pol *ACLPolicy) expandIPsFromSingleIP( - ip netip.Addr, - nodes types.Nodes, -) (*netipx.IPSet, error) { - log.Trace().Str("ip", ip.String()).Msg("ExpandAlias got ip") - - matches := nodes.FilterByIP(ip) - - build := netipx.IPSetBuilder{} - build.Add(ip) - - for _, node := range matches { - node.IPAddresses.AppendToIPSet(&build) - } - - return build.IPSet() -} - -func (pol *ACLPolicy) expandIPsFromIPPrefix( - prefix netip.Prefix, - nodes types.Nodes, -) (*netipx.IPSet, error) { - log.Trace().Str("prefix", prefix.String()).Msg("expandAlias got prefix") - build := netipx.IPSetBuilder{} - build.AddPrefix(prefix) - - // This is suboptimal and quite expensive, but if we only add the prefix, we will miss all the relevant IPv6 - // addresses for the hosts that belong to tailscale. This doesnt really affect stuff like subnet routers. - for _, node := range nodes { - for _, ip := range node.IPAddresses { - // log.Trace(). 
- // Msgf("checking if node ip (%s) is part of prefix (%s): %v, is single ip prefix (%v), addr: %s", ip.String(), prefix.String(), prefix.Contains(ip), prefix.IsSingleIP(), prefix.Addr().String()) - if prefix.Contains(ip) { - node.IPAddresses.AppendToIPSet(&build) - } - } - } - - return build.IPSet() -} - -func isWildcard(str string) bool { - return str == "*" -} - -func isGroup(str string) bool { - return strings.HasPrefix(str, "group:") -} - -func isTag(str string) bool { - return strings.HasPrefix(str, "tag:") -} - -// TagsOfNode will return the tags of the current node. -// Invalid tags are tags added by a user on a node, and that user doesn't have authority to add this tag. -// Valid tags are tags added by a user that is allowed in the ACL policy to add this tag. -func (pol *ACLPolicy) TagsOfNode( - node *types.Node, -) ([]string, []string) { - validTags := make([]string, 0) - invalidTags := make([]string, 0) - - validTagMap := make(map[string]bool) - invalidTagMap := make(map[string]bool) - for _, tag := range node.Hostinfo.RequestTags { - owners, err := expandOwnersFromTag(pol, tag) - if errors.Is(err, ErrInvalidTag) { - invalidTagMap[tag] = true - - continue - } - var found bool - for _, owner := range owners { - if node.User.Name == owner { - found = true - } - } - if found { - validTagMap[tag] = true - } else { - invalidTagMap[tag] = true - } - } - for tag := range invalidTagMap { - invalidTags = append(invalidTags, tag) - } - for tag := range validTagMap { - validTags = append(validTags, tag) - } - - return validTags, invalidTags -} - -func filterNodesByUser(nodes types.Nodes, user string) types.Nodes { - out := types.Nodes{} - for _, node := range nodes { - if node.User.Name == user { - out = append(out, node) - } - } - - return out -} - -// FilterNodesByACL returns the list of peers authorized to be accessed from a given node. 
-func FilterNodesByACL( - node *types.Node, - nodes types.Nodes, - filter []tailcfg.FilterRule, -) types.Nodes { - result := types.Nodes{} - - for index, peer := range nodes { - if peer.ID == node.ID { - continue - } - - if node.CanAccess(filter, nodes[index]) || peer.CanAccess(filter, node) { - result = append(result, peer) - } - } - - return result -} diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go deleted file mode 100644 index c048778d..00000000 --- a/hscontrol/policy/acls_test.go +++ /dev/null @@ -1,3211 +0,0 @@ -package policy - -import ( - "errors" - "net/netip" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/rs/zerolog/log" - "github.com/spf13/viper" - "github.com/stretchr/testify/assert" - "go4.org/netipx" - "gopkg.in/check.v1" - "tailscale.com/tailcfg" -) - -func Test(t *testing.T) { - check.TestingT(t) -} - -var _ = check.Suite(&Suite{}) - -type Suite struct{} - -func (s *Suite) TestWrongPath(c *check.C) { - _, err := LoadACLPolicyFromPath("asdfg") - c.Assert(err, check.NotNil) -} - -func TestParsing(t *testing.T) { - tests := []struct { - name string - format string - acl string - want []tailcfg.FilterRule - wantErr bool - }{ - { - name: "invalid-hujson", - format: "hujson", - acl: ` -{ - `, - want: []tailcfg.FilterRule{}, - wantErr: true, - }, - { - name: "valid-hujson-invalid-content", - format: "hujson", - acl: ` -{ - "valid_json": true, - "but_a_policy_though": false -} - `, - want: []tailcfg.FilterRule{}, - wantErr: true, - }, - { - name: "invalid-cidr", - format: "hujson", - acl: ` -{"example-host-1": "100.100.100.100/42"} - `, - want: []tailcfg.FilterRule{}, - wantErr: true, - }, - { - name: "basic-rule", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "subnet-1", - "192.168.1.0/24" - ], - "dst": [ - "*:22,3389", - "host-1:*", - ], - }, - ], -} - `, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, - {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, - {IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, - {IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - { - name: "parse-protocol", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "tcp", - "dst": [ - "host-1:*", - ], - }, - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "udp", - "dst": [ - "host-1:53", - ], - }, - { - "Action": "accept", - "src": [ - "*", - ], - "proto": "icmp", - "dst": [ - "host-1:*", - ], - }, - ], -}`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - IPProto: []int{protocolTCP}, - }, - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, - }, - IPProto: []int{protocolUDP}, - }, - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", 
Ports: tailcfg.PortRangeAny}, - }, - IPProto: []int{protocolICMP, protocolIPv6ICMP}, - }, - }, - wantErr: false, - }, - { - name: "port-wildcard", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "Action": "accept", - "src": [ - "*", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - { - name: "port-range", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "subnet-1", - ], - "dst": [ - "host-1:5400-5500", - ], - }, - ], -} -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.100.101.0/24"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.100.100.100/32", - Ports: tailcfg.PortRange{First: 5400, Last: 5500}, - }, - }, - }, - }, - wantErr: false, - }, - { - name: "port-group", - format: "hujson", - acl: ` -{ - "groups": { - "group:example": [ - "testuser", - ], - }, - - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "group:example", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"200.200.200.200/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - { - name: "port-user", - format: "hujson", - acl: ` -{ - "hosts": { - "host-1": "100.100.100.100", - "subnet-1": "100.100.101.100/24", - }, - - "acls": [ - { - "action": "accept", - "src": [ - "testuser", - ], - "dst": [ - "host-1:*", - ], - }, - ], -} -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"200.200.200.200/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - { - name: "port-wildcard-yaml", - format: "yaml", - acl: ` ---- -hosts: - host-1: 100.100.100.100/32 - subnet-1: 100.100.101.100/24 -acls: - - action: accept - src: - - "*" - dst: - - host-1:* -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - { - name: "ipv6-yaml", - format: "yaml", - acl: ` ---- -hosts: - host-1: 100.100.100.100/32 - subnet-1: 100.100.101.100/24 -acls: - - action: accept - src: - - "*" - dst: - - host-1:* -`, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - wantErr: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - pol, err := LoadACLPolicyFromBytes([]byte(tt.acl), tt.format) - - if tt.wantErr && err == nil { - t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) - - return - } else if !tt.wantErr && err != nil { - t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) - - return - } - - if err != nil { - return - } - - rules, err := pol.generateFilterRules(&types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.100.100.100"), - }, - }, types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("200.200.200.200"), - }, - 
User: types.User{ - Name: "testuser", - }, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }) - - if (err != nil) != tt.wantErr { - t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) - - return - } - - if diff := cmp.Diff(tt.want, rules); diff != "" { - t.Errorf("parsing() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func (s *Suite) TestRuleInvalidGeneration(c *check.C) { - acl := []byte(` -{ - // Declare static groups of users beyond those in the identity service. - "groups": { - "group:example": [ - "user1@example.com", - "user2@example.com", - ], - }, - // Declare hostname aliases to use in place of IP addresses or subnets. - "hosts": { - "example-host-1": "100.100.100.100", - "example-host-2": "100.100.101.100/24", - }, - // Define who is allowed to use which tags. - "tagOwners": { - // Everyone in the montreal-admins or global-admins group are - // allowed to tag servers as montreal-webserver. - "tag:montreal-webserver": [ - "group:montreal-admins", - "group:global-admins", - ], - // Only a few admins are allowed to create API servers. - "tag:api-server": [ - "group:global-admins", - "example-host-1", - ], - }, - // Access control lists. - "acls": [ - // Engineering users, plus the president, can access port 22 (ssh) - // and port 3389 (remote desktop protocol) on all servers, and all - // ports on git-server or ci-server. - { - "action": "accept", - "src": [ - "group:engineering", - "president@example.com" - ], - "dst": [ - "*:22,3389", - "git-server:*", - "ci-server:*" - ], - }, - // Allow engineer users to access any port on a device tagged with - // tag:production. - { - "action": "accept", - "src": [ - "group:engineers" - ], - "dst": [ - "tag:production:*" - ], - }, - // Allow servers in the my-subnet host and 192.168.1.0/24 to access hosts - // on both networks. - { - "action": "accept", - "src": [ - "my-subnet", - "192.168.1.0/24" - ], - "dst": [ - "my-subnet:*", - "192.168.1.0/24:*" - ], - }, - // Allow every user of your network to access anything on the network. - // Comment out this section if you want to define specific ACL - // restrictions above. - { - "action": "accept", - "src": [ - "*" - ], - "dst": [ - "*:*" - ], - }, - // All users in Montreal are allowed to access the Montreal web - // servers. - { - "action": "accept", - "src": [ - "group:montreal-users" - ], - "dst": [ - "tag:montreal-webserver:80,443" - ], - }, - // Montreal web servers are allowed to make outgoing connections to - // the API servers, but only on https port 443. - // In contrast, this doesn't grant API servers the right to initiate - // any connections. - { - "action": "accept", - "src": [ - "tag:montreal-webserver" - ], - "dst": [ - "tag:api-server:443" - ], - }, - ], - // Declare tests to check functionality of ACL rules - "tests": [ - { - "src": "user1@example.com", - "accept": [ - "example-host-1:22", - "example-host-2:80" - ], - "deny": [ - "exapmle-host-2:100" - ], - }, - { - "src": "user2@example.com", - "accept": [ - "100.60.3.4:22" - ], - }, - ], -} - `) - pol, err := LoadACLPolicyFromBytes(acl, "hujson") - c.Assert(pol.ACLs, check.HasLen, 6) - c.Assert(err, check.IsNil) - - rules, err := pol.generateFilterRules(&types.Node{}, types.Nodes{}) - c.Assert(err, check.NotNil) - c.Assert(rules, check.IsNil) -} - -// TODO(kradalby): Make tests values safe, independent and descriptive. 
-func (s *Suite) TestInvalidAction(c *check.C) { - pol := &ACLPolicy{ - ACLs: []ACL{ - { - Action: "invalidAction", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, - }, - }, - } - _, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{}) - c.Assert(errors.Is(err, ErrInvalidAction), check.Equals, true) -} - -func (s *Suite) TestInvalidGroupInGroup(c *check.C) { - // this ACL is wrong because the group in Sources sections doesn't exist - pol := &ACLPolicy{ - Groups: Groups{ - "group:test": []string{"foo"}, - "group:error": []string{"foo", "group:test"}, - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"group:error"}, - Destinations: []string{"*:*"}, - }, - }, - } - _, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{}) - c.Assert(errors.Is(err, ErrInvalidGroup), check.Equals, true) -} - -func (s *Suite) TestInvalidTagOwners(c *check.C) { - // this ACL is wrong because no tagOwners own the requested tag for the server - pol := &ACLPolicy{ - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"tag:foo"}, - Destinations: []string{"*:*"}, - }, - }, - } - - _, _, err := GenerateFilterAndSSHRules(pol, &types.Node{}, types.Nodes{}) - c.Assert(errors.Is(err, ErrInvalidTag), check.Equals, true) -} - -func Test_expandGroup(t *testing.T) { - type field struct { - pol ACLPolicy - } - type args struct { - group string - stripEmail bool - } - tests := []struct { - name string - field field - args args - want []string - wantErr bool - }{ - { - name: "simple test", - field: field{ - pol: ACLPolicy{ - Groups: Groups{ - "group:test": []string{"user1", "user2", "user3"}, - "group:foo": []string{"user2", "user3"}, - }, - }, - }, - args: args{ - group: "group:test", - }, - want: []string{"user1", "user2", "user3"}, - wantErr: false, - }, - { - name: "InexistantGroup", - field: field{ - pol: ACLPolicy{ - Groups: Groups{ - "group:test": []string{"user1", "user2", "user3"}, - "group:foo": []string{"user2", "user3"}, - }, - }, - }, - args: args{ - group: "group:undefined", - }, - want: []string{}, - wantErr: true, - }, - { - name: "Expand emails in group strip domains", - field: field{ - pol: ACLPolicy{ - Groups: Groups{ - "group:admin": []string{ - "joe.bar@gmail.com", - "john.doe@yahoo.fr", - }, - }, - }, - }, - args: args{ - group: "group:admin", - stripEmail: true, - }, - want: []string{"joe.bar", "john.doe"}, - wantErr: false, - }, - { - name: "Expand emails in group", - field: field{ - pol: ACLPolicy{ - Groups: Groups{ - "group:admin": []string{ - "joe.bar@gmail.com", - "john.doe@yahoo.fr", - }, - }, - }, - }, - args: args{ - group: "group:admin", - }, - want: []string{"joe.bar.gmail.com", "john.doe.yahoo.fr"}, - wantErr: false, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - viper.Set("oidc.strip_email_domain", test.args.stripEmail) - - got, err := test.field.pol.expandUsersFromGroup( - test.args.group, - ) - - if (err != nil) != test.wantErr { - t.Errorf("expandGroup() error = %v, wantErr %v", err, test.wantErr) - - return - } - - if diff := cmp.Diff(test.want, got); diff != "" { - t.Errorf("expandGroup() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func Test_expandTagOwners(t *testing.T) { - type args struct { - aclPolicy *ACLPolicy - tag string - } - tests := []struct { - name string - args args - want []string - wantErr bool - }{ - { - name: "simple tag expansion", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:test": []string{"user1"}}, - }, - tag: "tag:test", - }, - 
want: []string{"user1"}, - wantErr: false, - }, - { - name: "expand with tag and group", - args: args{ - aclPolicy: &ACLPolicy{ - Groups: Groups{"group:foo": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"group:foo"}}, - }, - tag: "tag:test", - }, - want: []string{"user1", "user2"}, - wantErr: false, - }, - { - name: "expand with user and group", - args: args{ - aclPolicy: &ACLPolicy{ - Groups: Groups{"group:foo": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"group:foo", "user3"}}, - }, - tag: "tag:test", - }, - want: []string{"user1", "user2", "user3"}, - wantErr: false, - }, - { - name: "invalid tag", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:foo": []string{"group:foo", "user1"}}, - }, - tag: "tag:test", - }, - want: []string{}, - wantErr: true, - }, - { - name: "invalid group", - args: args{ - aclPolicy: &ACLPolicy{ - Groups: Groups{"group:bar": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"group:foo", "user2"}}, - }, - tag: "tag:test", - }, - want: []string{}, - wantErr: true, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := expandOwnersFromTag( - test.args.aclPolicy, - test.args.tag, - ) - if (err != nil) != test.wantErr { - t.Errorf("expandTagOwners() error = %v, wantErr %v", err, test.wantErr) - - return - } - if diff := cmp.Diff(test.want, got); diff != "" { - t.Errorf("expandTagOwners() = (-want +got):\n%s", diff) - } - }) - } -} - -func Test_expandPorts(t *testing.T) { - type args struct { - portsStr string - needsWildcard bool - } - tests := []struct { - name string - args args - want *[]tailcfg.PortRange - wantErr bool - }{ - { - name: "wildcard", - args: args{portsStr: "*", needsWildcard: true}, - want: &[]tailcfg.PortRange{ - {First: portRangeBegin, Last: portRangeEnd}, - }, - wantErr: false, - }, - { - name: "needs wildcard but does not require it", - args: args{portsStr: "*", needsWildcard: false}, - want: &[]tailcfg.PortRange{ - {First: portRangeBegin, Last: portRangeEnd}, - }, - wantErr: false, - }, - { - name: "needs wildcard but gets port", - args: args{portsStr: "80,443", needsWildcard: true}, - want: nil, - wantErr: true, - }, - { - name: "two Destinations", - args: args{portsStr: "80,443", needsWildcard: false}, - want: &[]tailcfg.PortRange{ - {First: 80, Last: 80}, - {First: 443, Last: 443}, - }, - wantErr: false, - }, - { - name: "a range and a port", - args: args{portsStr: "80-1024,443", needsWildcard: false}, - want: &[]tailcfg.PortRange{ - {First: 80, Last: 1024}, - {First: 443, Last: 443}, - }, - wantErr: false, - }, - { - name: "out of bounds", - args: args{portsStr: "854038", needsWildcard: false}, - want: nil, - wantErr: true, - }, - { - name: "wrong port", - args: args{portsStr: "85a38", needsWildcard: false}, - want: nil, - wantErr: true, - }, - { - name: "wrong port in first", - args: args{portsStr: "a-80", needsWildcard: false}, - want: nil, - wantErr: true, - }, - { - name: "wrong port in last", - args: args{portsStr: "80-85a38", needsWildcard: false}, - want: nil, - wantErr: true, - }, - { - name: "wrong port format", - args: args{portsStr: "80-85a38-3", needsWildcard: false}, - want: nil, - wantErr: true, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := expandPorts(test.args.portsStr, test.args.needsWildcard) - if (err != nil) != test.wantErr { - t.Errorf("expandPorts() error = %v, wantErr %v", err, test.wantErr) - - return - } - if diff := 
cmp.Diff(test.want, got); diff != "" { - t.Errorf("expandPorts() = (-want +got):\n%s", diff) - } - }) - } -} - -func Test_listNodesInUser(t *testing.T) { - type args struct { - nodes types.Nodes - user string - } - tests := []struct { - name string - args args - want types.Nodes - }{ - { - name: "1 node in user", - args: args{ - nodes: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{User: types.User{Name: "joe"}}, - }, - }, - { - name: "3 nodes, 2 in user", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, - }, - user: "marc", - }, - want: types.Nodes{ - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, - }, - }, - { - name: "5 nodes, 0 in user", - args: args{ - nodes: types.Nodes{ - &types.Node{ID: 1, User: types.User{Name: "joe"}}, - &types.Node{ID: 2, User: types.User{Name: "marc"}}, - &types.Node{ID: 3, User: types.User{Name: "marc"}}, - &types.Node{ID: 4, User: types.User{Name: "marc"}}, - &types.Node{ID: 5, User: types.User{Name: "marc"}}, - }, - user: "mickael", - }, - want: types.Nodes{}, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got := filterNodesByUser(test.args.nodes, test.args.user) - - if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { - t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) - } - }) - } -} - -func Test_expandAlias(t *testing.T) { - set := func(ips []string, prefixes []string) *netipx.IPSet { - var builder netipx.IPSetBuilder - - for _, ip := range ips { - builder.Add(netip.MustParseAddr(ip)) - } - - for _, pre := range prefixes { - builder.AddPrefix(netip.MustParsePrefix(pre)) - } - - s, _ := builder.IPSet() - - return s - } - - type field struct { - pol ACLPolicy - } - type args struct { - nodes types.Nodes - aclPolicy ACLPolicy - alias string - } - tests := []struct { - name string - field field - args args - want *netipx.IPSet - wantErr bool - }{ - { - name: "wildcard", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "*", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.78.84.227"), - }, - }, - }, - }, - want: set([]string{}, []string{ - "0.0.0.0/0", - "::/0", - }), - wantErr: false, - }, - { - name: "simple group", - field: field{ - pol: ACLPolicy{ - Groups: Groups{"group:accountant": []string{"joe", "marc"}}, - }, - }, - args: args{ - alias: "group:accountant", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{ - "100.64.0.1", "100.64.0.2", "100.64.0.3", - }, []string{}), - wantErr: false, - }, - { - name: "wrong group", - field: field{ - pol: ACLPolicy{ - Groups: Groups{"group:accountant": []string{"joe", "marc"}}, - }, - }, - args: args{ - alias: 
"group:hr", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{}, []string{}), - wantErr: true, - }, - { - name: "simple ipaddress", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.3", - nodes: types.Nodes{}, - }, - want: set([]string{ - "10.0.0.3", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by ip passed through", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.1", - nodes: types.Nodes{}, - }, - want: set([]string{ - "10.0.0.1", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by ipv4 single ipv4", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.1", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("10.0.0.1"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{ - "10.0.0.1", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by ipv4 single dual stack", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.1", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("10.0.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{ - "10.0.0.1", "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by ipv6 single dual stack", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("10.0.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{ - "fd7a:115c:a1e0:ab12:4843:2222:6273:2222", "10.0.0.1", - }, []string{}), - wantErr: false, - }, - { - name: "simple host by hostname alias", - field: field{ - pol: ACLPolicy{ - Hosts: Hosts{ - "testy": netip.MustParsePrefix("10.0.0.132/32"), - }, - }, - }, - args: args{ - alias: "testy", - nodes: types.Nodes{}, - }, - want: set([]string{}, []string{"10.0.0.132/32"}), - wantErr: false, - }, - { - name: "private network", - field: field{ - pol: ACLPolicy{ - Hosts: Hosts{ - "homeNetwork": netip.MustParsePrefix("192.168.1.0/24"), - }, - }, - }, - args: args{ - alias: "homeNetwork", - nodes: types.Nodes{}, - }, - want: set([]string{}, []string{"192.168.1.0/24"}), - wantErr: false, - }, - { - name: "simple CIDR", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "10.0.0.0/16", - nodes: types.Nodes{}, - aclPolicy: ACLPolicy{}, - }, - want: set([]string{}, []string{"10.0.0.0/16"}), - wantErr: false, - }, - { - name: "simple tag", - field: field{ - pol: ACLPolicy{ - TagOwners: TagOwners{"tag:hr-webserver": []string{"joe"}}, - }, - }, - args: args{ - alias: "tag:hr-webserver", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - 
User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "joe"}, - }, - }, - }, - want: set([]string{ - "100.64.0.1", "100.64.0.2", - }, []string{}), - wantErr: false, - }, - { - name: "No tag defined", - field: field{ - pol: ACLPolicy{ - Groups: Groups{"group:accountant": []string{"joe", "marc"}}, - TagOwners: TagOwners{ - "tag:accountant-webserver": []string{"group:accountant"}, - }, - }, - }, - args: args{ - alias: "tag:hr-webserver", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{}, []string{}), - wantErr: true, - }, - { - name: "Forced tag defined", - field: field{ - pol: ACLPolicy{}, - }, - args: args{ - alias: "tag:hr-webserver", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - ForcedTags: []string{"tag:hr-webserver"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - ForcedTags: []string{"tag:hr-webserver"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), - wantErr: false, - }, - { - name: "Forced tag with legitimate tagOwner", - field: field{ - pol: ACLPolicy{ - TagOwners: TagOwners{ - "tag:hr-webserver": []string{"joe"}, - }, - }, - }, - args: args{ - alias: "tag:hr-webserver", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - ForcedTags: []string{"tag:hr-webserver"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - want: set([]string{"100.64.0.1", "100.64.0.2"}, []string{}), - wantErr: false, - }, - { - name: "list host in user 
without correctly tagged servers", - field: field{ - pol: ACLPolicy{ - TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, - }, - }, - args: args{ - alias: "joe", - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "marc"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - want: set([]string{"100.64.0.4"}, []string{}), - wantErr: false, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got, err := test.field.pol.ExpandAlias( - test.args.nodes, - test.args.alias, - ) - if (err != nil) != test.wantErr { - t.Errorf("expandAlias() error = %v, wantErr %v", err, test.wantErr) - - return - } - if diff := cmp.Diff(test.want, got); diff != "" { - t.Errorf("expandAlias() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func Test_excludeCorrectlyTaggedNodes(t *testing.T) { - type args struct { - aclPolicy *ACLPolicy - nodes types.Nodes - user string - } - tests := []struct { - name string - args args - want types.Nodes - wantErr bool - }{ - { - name: "exclude nodes with valid tags", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, - }, - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")}, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - { - name: "exclude nodes with valid tags, and owner is in a group", - args: args{ - aclPolicy: &ACLPolicy{ - Groups: Groups{ - "group:accountant": []string{"joe", "bar"}, - }, - TagOwners: TagOwners{ - "tag:accountant-webserver": []string{"group:accountant"}, - }, - }, - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", 
- Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")}, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - { - name: "exclude nodes with valid tags and with forced tags", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, - }, - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "foo", - RequestTags: []string{"tag:accountant-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - ForcedTags: []string{"tag:accountant-webserver"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")}, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - { - name: "all nodes have invalid tags, don't exclude them", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{"tag:accountant-webserver": []string{"joe"}}, - }, - nodes: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "hr-web1", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "hr-web2", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - user: "joe", - }, - want: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "hr-web1", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{ - OS: "centos", - Hostname: "hr-web2", - RequestTags: []string{"tag:hr-webserver"}, - }, - }, - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - }, - User: types.User{Name: "joe"}, - Hostinfo: &tailcfg.Hostinfo{}, - }, - }, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - got := excludeCorrectlyTaggedNodes( - test.args.aclPolicy, - test.args.nodes, - test.args.user, - ) - if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { - t.Errorf("excludeCorrectlyTaggedNodes() (-want +got):\n%s", diff) - } - }) - } -} - -func TestACLPolicy_generateFilterRules(t *testing.T) { - type field struct { - pol ACLPolicy - } - type args struct { - node 
*types.Node - peers types.Nodes - } - tests := []struct { - name string - field field - args args - want []tailcfg.FilterRule - wantErr bool - }{ - { - name: "no-policy", - field: field{}, - args: args{}, - want: []tailcfg.FilterRule{}, - wantErr: false, - }, - { - name: "allow-all", - field: field{ - pol: ACLPolicy{ - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, - }, - }, - }, - }, - args: args{ - node: &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - }, - }, - peers: types.Nodes{}, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "0.0.0.0/0", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - { - IP: "::/0", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - wantErr: false, - }, - { - name: "host1-can-reach-host2-full", - field: field{ - pol: ACLPolicy{ - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"100.64.0.2"}, - Destinations: []string{"100.64.0.1:*"}, - }, - }, - }, - }, - args: args{ - node: &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - }, - User: types.User{Name: "mickael"}, - }, - peers: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - want: []tailcfg.FilterRule{ - { - SrcIPs: []string{ - "100.64.0.2/32", - "fd7a:115c:a1e0:ab12:4843:2222:6273:2222/128", - }, - DstPorts: []tailcfg.NetPortRange{ - { - IP: "100.64.0.1/32", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - { - IP: "fd7a:115c:a1e0:ab12:4843:2222:6273:2221/128", - Ports: tailcfg.PortRange{ - First: 0, - Last: 65535, - }, - }, - }, - }, - }, - wantErr: false, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.field.pol.generateFilterRules( - tt.args.node, - tt.args.peers, - ) - if (err != nil) != tt.wantErr { - t.Errorf("ACLgenerateFilterRules() error = %v, wantErr %v", err, tt.wantErr) - - return - } - - if diff := cmp.Diff(tt.want, got); diff != "" { - log.Trace().Interface("got", got).Msg("result") - t.Errorf("ACLgenerateFilterRules() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func TestReduceFilterRules(t *testing.T) { - tests := []struct { - name string - node *types.Node - peers types.Nodes - pol ACLPolicy - want []tailcfg.FilterRule - }{ - { - name: "host1-can-reach-host2-no-rules", - pol: ACLPolicy{ - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"100.64.0.1"}, - Destinations: []string{"100.64.0.2:*"}, - }, - }, - }, - node: &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), - }, - User: types.User{Name: "mickael"}, - }, - peers: types.Nodes{ - &types.Node{ - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - netip.MustParseAddr("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), - }, - User: types.User{Name: "mickael"}, - }, - }, - want: []tailcfg.FilterRule{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - rules, _ := tt.pol.generateFilterRules( - tt.node, - tt.peers, - ) - - got := ReduceFilterRules(tt.node, 
rules) - - if diff := cmp.Diff(tt.want, got); diff != "" { - log.Trace().Interface("got", got).Msg("result") - t.Errorf("TestReduceFilterRules() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func Test_getTags(t *testing.T) { - type args struct { - aclPolicy *ACLPolicy - node *types.Node - } - tests := []struct { - name string - args args - wantInvalid []string - wantValid []string - }{ - { - name: "valid tag one nodes", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - node: &types.Node{ - User: types.User{ - Name: "joe", - }, - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:valid"}, - }, - }, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: nil, - }, - { - name: "invalid tag and valid tag one nodes", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - node: &types.Node{ - User: types.User{ - Name: "joe", - }, - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:valid", "tag:invalid"}, - }, - }, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: []string{"tag:invalid"}, - }, - { - name: "multiple invalid and identical tags, should return only one invalid tag", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - node: &types.Node{ - User: types.User{ - Name: "joe", - }, - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{ - "tag:invalid", - "tag:valid", - "tag:invalid", - }, - }, - }, - }, - wantValid: []string{"tag:valid"}, - wantInvalid: []string{"tag:invalid"}, - }, - { - name: "only invalid tags", - args: args{ - aclPolicy: &ACLPolicy{ - TagOwners: TagOwners{ - "tag:valid": []string{"joe"}, - }, - }, - node: &types.Node{ - User: types.User{ - Name: "joe", - }, - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:invalid", "very-invalid"}, - }, - }, - }, - wantValid: nil, - wantInvalid: []string{"tag:invalid", "very-invalid"}, - }, - { - name: "empty ACLPolicy should return empty tags and should not panic", - args: args{ - aclPolicy: &ACLPolicy{}, - node: &types.Node{ - User: types.User{ - Name: "joe", - }, - Hostinfo: &tailcfg.Hostinfo{ - RequestTags: []string{"tag:invalid", "very-invalid"}, - }, - }, - }, - wantValid: nil, - wantInvalid: []string{"tag:invalid", "very-invalid"}, - }, - } - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - gotValid, gotInvalid := test.args.aclPolicy.TagsOfNode( - test.args.node, - ) - for _, valid := range gotValid { - if !util.StringOrPrefixListContains(test.wantValid, valid) { - t.Errorf( - "valids: getTags() = %v, want %v", - gotValid, - test.wantValid, - ) - - break - } - } - for _, invalid := range gotInvalid { - if !util.StringOrPrefixListContains(test.wantInvalid, invalid) { - t.Errorf( - "invalids: getTags() = %v, want %v", - gotInvalid, - test.wantInvalid, - ) - - break - } - } - }) - } -} - -func Test_getFilteredByACLPeers(t *testing.T) { - type args struct { - nodes types.Nodes - rules []tailcfg.FilterRule - node *types.Node - } - tests := []struct { - name string - args args - want types.Nodes - }{ - { - name: "all hosts can talk to each other", - args: args{ - nodes: types.Nodes{ // list of all nodess in the database - &types.Node{ - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "marc"}, - }, - 
&types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "*"}, - }, - }, - }, - node: &types.Node{ // current nodes - ID: 1, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")}, - User: types.User{Name: "joe"}, - }, - }, - want: types.Nodes{ - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.2")}, - User: types.User{Name: "marc"}, - }, - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.3")}, - User: types.User{Name: "mickael"}, - }, - }, - }, - { - name: "One host can talk to another, but not all hosts", - args: args{ - nodes: types.Nodes{ // list of all nodess in the database - &types.Node{ - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - node: &types.Node{ // current nodes - ID: 1, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")}, - User: types.User{Name: "joe"}, - }, - }, - want: types.Nodes{ - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.2")}, - User: types.User{Name: "marc"}, - }, - }, - }, - { - name: "host cannot directly talk to destination, but return path is authorized", - args: args{ - nodes: types.Nodes{ // list of all nodess in the database - &types.Node{ - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"100.64.0.3"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - node: &types.Node{ // current nodes - ID: 2, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.2")}, - User: types.User{Name: "marc"}, - }, - }, - want: types.Nodes{ - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.3")}, - User: types.User{Name: "mickael"}, - }, - }, - }, - { - name: "rules allows all hosts to reach one destination", - args: args{ - nodes: types.Nodes{ // list of all nodess in the database - &types.Node{ - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: 
"mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - node: &types.Node{ // current nodes - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - }, - want: types.Nodes{ - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "marc"}, - }, - }, - }, - { - name: "rules allows all hosts to reach one destination, destination can reach all hosts", - args: args{ - nodes: types.Nodes{ // list of all nodess in the database - &types.Node{ - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.2"}, - }, - }, - }, - node: &types.Node{ // current nodes - ID: 2, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "marc"}, - }, - }, - want: types.Nodes{ - &types.Node{ - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "mickael"}, - }, - }, - }, - { - name: "rule allows all hosts to reach all destinations", - args: args{ - nodes: types.Nodes{ // list of all nodess in the database - &types.Node{ - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "mickael"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - SrcIPs: []string{"*"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "*"}, - }, - }, - }, - node: &types.Node{ // current nodes - ID: 2, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.2")}, - User: types.User{Name: "marc"}, - }, - }, - want: types.Nodes{ - &types.Node{ - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.3")}, - User: types.User{Name: "mickael"}, - }, - }, - }, - { - name: "without rule all communications are forbidden", - args: args{ - nodes: types.Nodes{ // list of all nodess in the database - &types.Node{ - ID: 1, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - }, - User: types.User{Name: "joe"}, - }, - &types.Node{ - ID: 2, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - }, - User: types.User{Name: "marc"}, - }, - &types.Node{ - ID: 3, - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - }, - User: types.User{Name: "mickael"}, - }, - }, - rules: 
[]tailcfg.FilterRule{ // list of all ACLRules registered - }, - node: &types.Node{ // current nodes - ID: 2, - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.2")}, - User: types.User{Name: "marc"}, - }, - }, - want: types.Nodes{}, - }, - { - // Investigating 699 - // Found some nodes: [ts-head-8w6paa ts-unstable-lys2ib ts-head-upcrmb ts-unstable-rlwpvr] nodes=ts-head-8w6paa - // ACL rules generated ACL=[{"DstPorts":[{"Bits":null,"IP":"*","Ports":{"First":0,"Last":65535}}],"SrcIPs":["fd7a:115c:a1e0::3","100.64.0.3","fd7a:115c:a1e0::4","100.64.0.4"]}] - // ACL Cache Map={"100.64.0.3":{"*":{}},"100.64.0.4":{"*":{}},"fd7a:115c:a1e0::3":{"*":{}},"fd7a:115c:a1e0::4":{"*":{}}} - name: "issue-699-broken-star", - args: args{ - nodes: types.Nodes{ // - &types.Node{ - ID: 1, - Hostname: "ts-head-upcrmb", - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - netip.MustParseAddr("fd7a:115c:a1e0::3"), - }, - User: types.User{Name: "user1"}, - }, - &types.Node{ - ID: 2, - Hostname: "ts-unstable-rlwpvr", - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - netip.MustParseAddr("fd7a:115c:a1e0::4"), - }, - User: types.User{Name: "user1"}, - }, - &types.Node{ - ID: 3, - Hostname: "ts-head-8w6paa", - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0::1"), - }, - User: types.User{Name: "user2"}, - }, - &types.Node{ - ID: 4, - Hostname: "ts-unstable-lys2ib", - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.2"), - netip.MustParseAddr("fd7a:115c:a1e0::2"), - }, - User: types.User{Name: "user2"}, - }, - }, - rules: []tailcfg.FilterRule{ // list of all ACLRules registered - { - DstPorts: []tailcfg.NetPortRange{ - { - IP: "*", - Ports: tailcfg.PortRange{First: 0, Last: 65535}, - }, - }, - SrcIPs: []string{ - "fd7a:115c:a1e0::3", "100.64.0.3", - "fd7a:115c:a1e0::4", "100.64.0.4", - }, - }, - }, - node: &types.Node{ // current nodes - ID: 3, - Hostname: "ts-head-8w6paa", - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.1"), - netip.MustParseAddr("fd7a:115c:a1e0::1"), - }, - User: types.User{Name: "user2"}, - }, - }, - want: types.Nodes{ - &types.Node{ - ID: 1, - Hostname: "ts-head-upcrmb", - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.3"), - netip.MustParseAddr("fd7a:115c:a1e0::3"), - }, - User: types.User{Name: "user1"}, - }, - &types.Node{ - ID: 2, - Hostname: "ts-unstable-rlwpvr", - IPAddresses: types.NodeAddresses{ - netip.MustParseAddr("100.64.0.4"), - netip.MustParseAddr("fd7a:115c:a1e0::4"), - }, - User: types.User{Name: "user1"}, - }, - }, - }, - { - name: "failing-edge-case-during-p3-refactor", - args: args{ - nodes: []*types.Node{ - { - ID: 1, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - Hostname: "peer1", - User: types.User{Name: "mini"}, - }, - { - ID: 2, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - Hostname: "peer2", - User: types.User{Name: "peer2"}, - }, - }, - rules: []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.3/32", Ports: tailcfg.PortRangeAny}, - {IP: "::/0", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - node: &types.Node{ - ID: 0, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - Hostname: "mini", - User: types.User{Name: "mini"}, - }, - }, - want: []*types.Node{ - { - ID: 2, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - Hostname: "peer2", - User: types.User{Name: "peer2"}, - 
}, - }, - }, - { - name: "p4-host-in-netmap-user2-dest-bug", - args: args{ - nodes: []*types.Node{ - { - ID: 1, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - Hostname: "user1-2", - User: types.User{Name: "user1"}, - }, - { - ID: 0, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - Hostname: "user1-1", - User: types.User{Name: "user1"}, - }, - { - ID: 3, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, - Hostname: "user2-2", - User: types.User{Name: "user2"}, - }, - }, - rules: []tailcfg.FilterRule{ - { - SrcIPs: []string{ - "100.64.0.3/32", - "100.64.0.4/32", - "fd7a:115c:a1e0::3/128", - "fd7a:115c:a1e0::4/128", - }, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.3/32", Ports: tailcfg.PortRangeAny}, - {IP: "100.64.0.4/32", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::3/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::4/128", Ports: tailcfg.PortRangeAny}, - }, - }, - { - SrcIPs: []string{ - "100.64.0.1/32", - "100.64.0.2/32", - "fd7a:115c:a1e0::1/128", - "fd7a:115c:a1e0::2/128", - }, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.3/32", Ports: tailcfg.PortRangeAny}, - {IP: "100.64.0.4/32", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::3/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::4/128", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - node: &types.Node{ - ID: 2, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - Hostname: "user-2-1", - User: types.User{Name: "user2"}, - }, - }, - want: []*types.Node{ - { - ID: 1, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - Hostname: "user1-2", - User: types.User{Name: "user1"}, - }, - { - ID: 0, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - Hostname: "user1-1", - User: types.User{Name: "user1"}, - }, - { - ID: 3, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, - Hostname: "user2-2", - User: types.User{Name: "user2"}, - }, - }, - }, - { - name: "p4-host-in-netmap-user1-dest-bug", - args: args{ - nodes: []*types.Node{ - { - ID: 1, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, - Hostname: "user1-2", - User: types.User{Name: "user1"}, - }, - { - ID: 2, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - Hostname: "user-2-1", - User: types.User{Name: "user2"}, - }, - { - ID: 3, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, - Hostname: "user2-2", - User: types.User{Name: "user2"}, - }, - }, - rules: []tailcfg.FilterRule{ - { - SrcIPs: []string{ - "100.64.0.1/32", - "100.64.0.2/32", - "fd7a:115c:a1e0::1/128", - "fd7a:115c:a1e0::2/128", - }, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}, - {IP: "100.64.0.2/32", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny}, - }, - }, - { - SrcIPs: []string{ - "100.64.0.1/32", - "100.64.0.2/32", - "fd7a:115c:a1e0::1/128", - "fd7a:115c:a1e0::2/128", - }, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.3/32", Ports: tailcfg.PortRangeAny}, - {IP: "100.64.0.4/32", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::3/128", Ports: tailcfg.PortRangeAny}, - {IP: "fd7a:115c:a1e0::4/128", Ports: tailcfg.PortRangeAny}, - }, - }, - }, - node: &types.Node{ - ID: 0, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, - Hostname: "user1-1", - User: types.User{Name: "user1"}, - }, - }, - want: []*types.Node{ - { - ID: 1, - IPAddresses: 
[]netip.Addr{netip.MustParseAddr("100.64.0.2")}, - Hostname: "user1-2", - User: types.User{Name: "user1"}, - }, - { - ID: 2, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, - Hostname: "user-2-1", - User: types.User{Name: "user2"}, - }, - { - ID: 3, - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.4")}, - Hostname: "user2-2", - User: types.User{Name: "user2"}, - }, - }, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := FilterNodesByACL( - tt.args.node, - tt.args.nodes, - tt.args.rules, - ) - if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { - t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func TestSSHRules(t *testing.T) { - tests := []struct { - name string - node types.Node - peers types.Nodes - pol ACLPolicy - want []*tailcfg.SSHRule - }{ - { - name: "peers-can-connect", - node: types.Node{ - Hostname: "testnodes", - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.99.42")}, - UserID: 0, - User: types.User{ - Name: "user1", - }, - }, - peers: types.Nodes{ - &types.Node{ - Hostname: "testnodes2", - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: 0, - User: types.User{ - Name: "user1", - }, - }, - }, - pol: ACLPolicy{ - Groups: Groups{ - "group:test": []string{"user1"}, - }, - Hosts: Hosts{ - "client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32), - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, - }, - }, - SSHs: []SSH{ - { - Action: "accept", - Sources: []string{"group:test"}, - Destinations: []string{"client"}, - Users: []string{"autogroup:nonroot"}, - }, - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"client"}, - Users: []string{"autogroup:nonroot"}, - }, - { - Action: "accept", - Sources: []string{"group:test"}, - Destinations: []string{"100.64.99.42"}, - Users: []string{"autogroup:nonroot"}, - }, - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"100.64.99.42"}, - Users: []string{"autogroup:nonroot"}, - }, - }, - }, - want: []*tailcfg.SSHRule{ - { - Principals: []*tailcfg.SSHPrincipal{ - { - UserLogin: "user1", - }, - }, - SSHUsers: map[string]string{ - "autogroup:nonroot": "=", - }, - Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true}, - }, - { - SSHUsers: map[string]string{ - "autogroup:nonroot": "=", - }, - Principals: []*tailcfg.SSHPrincipal{ - { - Any: true, - }, - }, - Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true}, - }, - { - Principals: []*tailcfg.SSHPrincipal{ - { - UserLogin: "user1", - }, - }, - SSHUsers: map[string]string{ - "autogroup:nonroot": "=", - }, - Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true}, - }, - { - SSHUsers: map[string]string{ - "autogroup:nonroot": "=", - }, - Principals: []*tailcfg.SSHPrincipal{ - { - Any: true, - }, - }, - Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true}, - }, - }, - }, - { - name: "peers-cannot-connect", - node: types.Node{ - Hostname: "testnodes", - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: 0, - User: types.User{ - Name: "user1", - }, - }, - peers: types.Nodes{ - &types.Node{ - Hostname: "testnodes2", - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.99.42")}, - UserID: 0, - User: types.User{ - Name: "user1", - }, - }, - }, - pol: ACLPolicy{ - Groups: Groups{ - "group:test": []string{"user1"}, - }, - Hosts: 
Hosts{ - "client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32), - }, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, - }, - }, - SSHs: []SSH{ - { - Action: "accept", - Sources: []string{"group:test"}, - Destinations: []string{"100.64.99.42"}, - Users: []string{"autogroup:nonroot"}, - }, - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"100.64.99.42"}, - Users: []string{"autogroup:nonroot"}, - }, - }, - }, - want: []*tailcfg.SSHRule{}, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got, err := tt.pol.generateSSHRules(&tt.node, tt.peers) - assert.NoError(t, err) - - if diff := cmp.Diff(tt.want, got); diff != "" { - t.Errorf("TestSSHRules() unexpected result (-want +got):\n%s", diff) - } - }) - } -} - -func TestParseDestination(t *testing.T) { - tests := []struct { - dest string - wantAlias string - wantPort string - }{ - { - dest: "git-server:*", - wantAlias: "git-server", - wantPort: "*", - }, - { - dest: "192.168.1.0/24:22", - wantAlias: "192.168.1.0/24", - wantPort: "22", - }, - { - dest: "192.168.1.1:22", - wantAlias: "192.168.1.1", - wantPort: "22", - }, - { - dest: "fd7a:115c:a1e0::2:22", - wantAlias: "fd7a:115c:a1e0::2", - wantPort: "22", - }, - { - dest: "fd7a:115c:a1e0::2/128:22", - wantAlias: "fd7a:115c:a1e0::2/128", - wantPort: "22", - }, - { - dest: "tag:montreal-webserver:80,443", - wantAlias: "tag:montreal-webserver", - wantPort: "80,443", - }, - { - dest: "tag:api-server:443", - wantAlias: "tag:api-server", - wantPort: "443", - }, - { - dest: "example-host-1:*", - wantAlias: "example-host-1", - wantPort: "*", - }, - } - - for _, tt := range tests { - t.Run(tt.dest, func(t *testing.T) { - alias, port, _ := parseDestination(tt.dest) - - if alias != tt.wantAlias { - t.Errorf("unexpected alias: want(%s) != got(%s)", tt.wantAlias, alias) - } - - if port != tt.wantPort { - t.Errorf("unexpected port: want(%s) != got(%s)", tt.wantPort, port) - } - }) - } -} - -// this test should validate that we can expand a group in a TagOWner section and -// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Sources section. -func TestValidExpandTagOwnersInSources(t *testing.T) { - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testnodes", - RequestTags: []string{"tag:test"}, - } - - node := &types.Node{ - ID: 0, - Hostname: "testnodes", - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: 0, - User: types.User{ - Name: "user1", - }, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo, - } - - pol := &ACLPolicy{ - Groups: Groups{"group:test": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"tag:test"}, - Destinations: []string{"*:*"}, - }, - }, - } - - got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{}) - assert.NoError(t, err) - - want := []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}}, - {IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}}, - }, - }, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("TestValidExpandTagOwnersInSources() unexpected result (-want +got):\n%s", diff) - } -} - -// need a test with: -// tag on a host that isn't owned by a tag owners. So the user -// of the host should be valid. 
-func TestInvalidTagValidUser(t *testing.T) { - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testnodes", - RequestTags: []string{"tag:foo"}, - } - - node := &types.Node{ - ID: 1, - Hostname: "testnodes", - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: 1, - User: types.User{ - Name: "user1", - }, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo, - } - - pol := &ACLPolicy{ - TagOwners: TagOwners{"tag:test": []string{"user1"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"*:*"}, - }, - }, - } - - got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{}) - assert.NoError(t, err) - - want := []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.1/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}}, - {IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}}, - }, - }, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("TestInvalidTagValidUser() unexpected result (-want +got):\n%s", diff) - } -} - -// this test should validate that we can expand a group in a TagOWner section and -// match properly the IP's of the related hosts. The owner is valid and the tag is also valid. -// the tag is matched in the Destinations section. -func TestValidExpandTagOwnersInDestinations(t *testing.T) { - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "testnodes", - RequestTags: []string{"tag:test"}, - } - - node := &types.Node{ - ID: 1, - Hostname: "testnodes", - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: 1, - User: types.User{ - Name: "user1", - }, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo, - } - - pol := &ACLPolicy{ - Groups: Groups{"group:test": []string{"user1", "user2"}}, - TagOwners: TagOwners{"tag:test": []string{"user3", "group:test"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"tag:test:*"}, - }, - }, - } - - // rules, _, err := GenerateFilterRules(pol, &node, peers, false) - // c.Assert(err, check.IsNil) - // - // c.Assert(rules, check.HasLen, 1) - // c.Assert(rules[0].DstPorts, check.HasLen, 1) - // c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32") - - got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{}) - assert.NoError(t, err) - - want := []tailcfg.FilterRule{ - { - SrcIPs: []string{"0.0.0.0/0", "::/0"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.1/32", Ports: tailcfg.PortRange{Last: 65535}}, - }, - }, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf( - "TestValidExpandTagOwnersInDestinations() unexpected result (-want +got):\n%s", - diff, - ) - } -} - -// tag on a host is owned by a tag owner, the tag is valid. -// an ACL rule is matching the tag to a user. It should not be valid since the -// host should be tied to the tag now. 
-func TestValidTagInvalidUser(t *testing.T) { - hostInfo := tailcfg.Hostinfo{ - OS: "centos", - Hostname: "webserver", - RequestTags: []string{"tag:webapp"}, - } - - node := &types.Node{ - ID: 1, - Hostname: "webserver", - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")}, - UserID: 1, - User: types.User{ - Name: "user1", - }, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo, - } - - hostInfo2 := tailcfg.Hostinfo{ - OS: "debian", - Hostname: "Hostname", - } - - nodes2 := &types.Node{ - ID: 2, - Hostname: "user", - IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.2")}, - UserID: 1, - User: types.User{ - Name: "user1", - }, - RegisterMethod: util.RegisterMethodAuthKey, - Hostinfo: &hostInfo2, - } - - pol := &ACLPolicy{ - TagOwners: TagOwners{"tag:webapp": []string{"user1"}}, - ACLs: []ACL{ - { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"tag:webapp:80,443"}, - }, - }, - } - - got, _, err := GenerateFilterAndSSHRules(pol, node, types.Nodes{nodes2}) - assert.NoError(t, err) - - want := []tailcfg.FilterRule{ - { - SrcIPs: []string{"100.64.0.2/32"}, - DstPorts: []tailcfg.NetPortRange{ - {IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 80, Last: 80}}, - {IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 443, Last: 443}}, - }, - }, - } - - if diff := cmp.Diff(want, got); diff != "" { - t.Errorf("TestValidTagInvalidUser() unexpected result (-want +got):\n%s", diff) - } -} diff --git a/hscontrol/policy/acls_types.go b/hscontrol/policy/acls_types.go deleted file mode 100644 index e9c44909..00000000 --- a/hscontrol/policy/acls_types.go +++ /dev/null @@ -1,145 +0,0 @@ -package policy - -import ( - "encoding/json" - "net/netip" - "strings" - - "github.com/tailscale/hujson" - "gopkg.in/yaml.v3" -) - -// ACLPolicy represents a Tailscale ACL Policy. -type ACLPolicy struct { - Groups Groups `json:"groups" yaml:"groups"` - Hosts Hosts `json:"hosts" yaml:"hosts"` - TagOwners TagOwners `json:"tagOwners" yaml:"tagOwners"` - ACLs []ACL `json:"acls" yaml:"acls"` - Tests []ACLTest `json:"tests" yaml:"tests"` - AutoApprovers AutoApprovers `json:"autoApprovers" yaml:"autoApprovers"` - SSHs []SSH `json:"ssh" yaml:"ssh"` -} - -// ACL is a basic rule for the ACL Policy. -type ACL struct { - Action string `json:"action" yaml:"action"` - Protocol string `json:"proto" yaml:"proto"` - Sources []string `json:"src" yaml:"src"` - Destinations []string `json:"dst" yaml:"dst"` -} - -// Groups references a series of alias in the ACL rules. -type Groups map[string][]string - -// Hosts are alias for IP addresses or subnets. -type Hosts map[string]netip.Prefix - -// TagOwners specify what users (users?) are allow to use certain tags. -type TagOwners map[string][]string - -// ACLTest is not implemented, but should be use to check if a certain rule is allowed. -type ACLTest struct { - Source string `json:"src" yaml:"src"` - Accept []string `json:"accept" yaml:"accept"` - Deny []string `json:"deny,omitempty" yaml:"deny,omitempty"` -} - -// AutoApprovers specify which users (users?), groups or tags have their advertised routes -// or exit node status automatically enabled. -type AutoApprovers struct { - Routes map[string][]string `json:"routes" yaml:"routes"` - ExitNode []string `json:"exitNode" yaml:"exitNode"` -} - -// SSH controls who can ssh into which machines. 
-type SSH struct { - Action string `json:"action" yaml:"action"` - Sources []string `json:"src" yaml:"src"` - Destinations []string `json:"dst" yaml:"dst"` - Users []string `json:"users" yaml:"users"` - CheckPeriod string `json:"checkPeriod,omitempty" yaml:"checkPeriod,omitempty"` -} - -// UnmarshalJSON allows to parse the Hosts directly into netip objects. -func (hosts *Hosts) UnmarshalJSON(data []byte) error { - newHosts := Hosts{} - hostIPPrefixMap := make(map[string]string) - ast, err := hujson.Parse(data) - if err != nil { - return err - } - ast.Standardize() - data = ast.Pack() - err = json.Unmarshal(data, &hostIPPrefixMap) - if err != nil { - return err - } - for host, prefixStr := range hostIPPrefixMap { - if !strings.Contains(prefixStr, "/") { - prefixStr += "/32" - } - prefix, err := netip.ParsePrefix(prefixStr) - if err != nil { - return err - } - newHosts[host] = prefix - } - *hosts = newHosts - - return nil -} - -// UnmarshalYAML allows to parse the Hosts directly into netip objects. -func (hosts *Hosts) UnmarshalYAML(data []byte) error { - newHosts := Hosts{} - hostIPPrefixMap := make(map[string]string) - - err := yaml.Unmarshal(data, &hostIPPrefixMap) - if err != nil { - return err - } - for host, prefixStr := range hostIPPrefixMap { - prefix, err := netip.ParsePrefix(prefixStr) - if err != nil { - return err - } - newHosts[host] = prefix - } - *hosts = newHosts - - return nil -} - -// IsZero is perhaps a bit naive here. -func (pol ACLPolicy) IsZero() bool { - if len(pol.Groups) == 0 && len(pol.Hosts) == 0 && len(pol.ACLs) == 0 { - return true - } - - return false -} - -// Returns the list of autoApproving users, groups or tags for a given IPPrefix. -func (autoApprovers *AutoApprovers) GetRouteApprovers( - prefix netip.Prefix, -) ([]string, error) { - if prefix.Bits() == 0 { - return autoApprovers.ExitNode, nil // 0.0.0.0/0, ::/0 or equivalent - } - - approverAliases := []string{} - - for autoApprovedPrefix, autoApproverAliases := range autoApprovers.Routes { - autoApprovedPrefix, err := netip.ParsePrefix(autoApprovedPrefix) - if err != nil { - return nil, err - } - - if prefix.Bits() >= autoApprovedPrefix.Bits() && - autoApprovedPrefix.Contains(prefix.Masked().Addr()) { - approverAliases = append(approverAliases, autoApproverAliases...) 
- } - } - - return approverAliases, nil -} diff --git a/hscontrol/policy/matcher/matcher.go b/hscontrol/policy/matcher/matcher.go index 1905dad2..afc3cf68 100644 --- a/hscontrol/policy/matcher/matcher.go +++ b/hscontrol/policy/matcher/matcher.go @@ -2,15 +2,43 @@ package matcher import ( "net/netip" + "slices" + "strings" "github.com/juanfont/headscale/hscontrol/util" "go4.org/netipx" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" ) type Match struct { - Srcs *netipx.IPSet - Dests *netipx.IPSet + srcs *netipx.IPSet + dests *netipx.IPSet +} + +func (m Match) DebugString() string { + var sb strings.Builder + + sb.WriteString("Match:\n") + sb.WriteString(" Sources:\n") + for _, prefix := range m.srcs.Prefixes() { + sb.WriteString(" " + prefix.String() + "\n") + } + sb.WriteString(" Destinations:\n") + for _, prefix := range m.dests.Prefixes() { + sb.WriteString(" " + prefix.String() + "\n") + } + + return sb.String() +} + +func MatchesFromFilterRules(rules []tailcfg.FilterRule) []Match { + matches := make([]Match, 0, len(rules)) + for _, rule := range rules { + matches = append(matches, MatchFromFilterRule(rule)) + } + + return matches } func MatchFromFilterRule(rule tailcfg.FilterRule) Match { @@ -42,29 +70,34 @@ func MatchFromStrings(sources, destinations []string) Match { destsSet, _ := dests.IPSet() match := Match{ - Srcs: srcsSet, - Dests: destsSet, + srcs: srcsSet, + dests: destsSet, } return match } -func (m *Match) SrcsContainsIPs(ips []netip.Addr) bool { - for _, ip := range ips { - if m.Srcs.Contains(ip) { - return true - } - } - - return false +func (m *Match) SrcsContainsIPs(ips ...netip.Addr) bool { + return slices.ContainsFunc(ips, m.srcs.Contains) } -func (m *Match) DestsContainsIP(ips []netip.Addr) bool { - for _, ip := range ips { - if m.Dests.Contains(ip) { - return true - } - } - - return false +func (m *Match) DestsContainsIP(ips ...netip.Addr) bool { + return slices.ContainsFunc(ips, m.dests.Contains) +} + +func (m *Match) SrcsOverlapsPrefixes(prefixes ...netip.Prefix) bool { + return slices.ContainsFunc(prefixes, m.srcs.OverlapsPrefix) +} + +func (m *Match) DestsOverlapsPrefixes(prefixes ...netip.Prefix) bool { + return slices.ContainsFunc(prefixes, m.dests.OverlapsPrefix) +} + +// DestsIsTheInternet reports if the destination is equal to "the internet" +// which is a IPSet that represents "autogroup:internet" and is special +// cased for exit nodes. +func (m Match) DestsIsTheInternet() bool { + return m.dests.Equal(util.TheInternet()) || + m.dests.ContainsPrefix(tsaddr.AllIPv4()) || + m.dests.ContainsPrefix(tsaddr.AllIPv6()) } diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go new file mode 100644 index 00000000..f4db88a4 --- /dev/null +++ b/hscontrol/policy/pm.go @@ -0,0 +1,76 @@ +package policy + +import ( + "net/netip" + + "github.com/juanfont/headscale/hscontrol/policy/matcher" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" + "tailscale.com/types/views" +) + +type PolicyManager interface { + // Filter returns the current filter rules for the entire tailnet and the associated matchers. 
+	Filter() ([]tailcfg.FilterRule, []matcher.Match) + // FilterForNode returns filter rules for a specific node, handling autogroup:self + FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error) + // MatchersForNode returns matchers for peer relationship determination (unreduced) + MatchersForNode(node types.NodeView) ([]matcher.Match, error) + // BuildPeerMap constructs peer relationship maps for the given nodes + BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView + SSHPolicy(types.NodeView) (*tailcfg.SSHPolicy, error) + SetPolicy([]byte) (bool, error) + SetUsers(users []types.User) (bool, error) + SetNodes(nodes views.Slice[types.NodeView]) (bool, error) + // NodeCanHaveTag reports whether the given node can have the given tag. + NodeCanHaveTag(types.NodeView, string) bool + + // TagExists reports whether the given tag is defined in the policy. + TagExists(tag string) bool + + // NodeCanApproveRoute reports whether the given node can approve the given route. + NodeCanApproveRoute(types.NodeView, netip.Prefix) bool + + Version() int + DebugString() string +} + +// NewPolicyManager returns a new policy manager. +func NewPolicyManager(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) (PolicyManager, error) { + var polMan PolicyManager + var err error + polMan, err = policyv2.NewPolicyManager(pol, users, nodes) + if err != nil { + return nil, err + } + + return polMan, err +} + +// PolicyManagersForTest returns all available PolicyManager implementations, to be used +// in tests that need to validate that the implementations behave the same. +func PolicyManagersForTest(pol []byte, users []types.User, nodes views.Slice[types.NodeView]) ([]PolicyManager, error) { + var polMans []PolicyManager + + for _, pmf := range PolicyManagerFuncsForTest(pol) { + pm, err := pmf(users, nodes) + if err != nil { + return nil, err + } + polMans = append(polMans, pm) + } + + return polMans, nil +} + +func PolicyManagerFuncsForTest(pol []byte) []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) { + var polmanFuncs []func([]types.User, views.Slice[types.NodeView]) (PolicyManager, error) + + polmanFuncs = append(polmanFuncs, func(u []types.User, n views.Slice[types.NodeView]) (PolicyManager, error) { + return policyv2.NewPolicyManager(pol, u, n) + }) + + return polmanFuncs +} diff --git a/hscontrol/policy/policy.go b/hscontrol/policy/policy.go new file mode 100644 index 00000000..677cb854 --- /dev/null +++ b/hscontrol/policy/policy.go @@ -0,0 +1,151 @@ +package policy + +import ( + "net/netip" + "slices" + + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + "github.com/samber/lo" + "tailscale.com/net/tsaddr" + "tailscale.com/types/views" +) + +// ReduceNodes returns the list of peers authorized to be accessed from a given node. +func ReduceNodes( + node types.NodeView, + nodes views.Slice[types.NodeView], + matchers []matcher.Match, +) views.Slice[types.NodeView] { + var result []types.NodeView + + for _, peer := range nodes.All() { + if peer.ID() == node.ID() { + continue + } + + if node.CanAccess(matchers, peer) || peer.CanAccess(matchers, node) { + result = append(result, peer) + } + } + + return views.SliceOf(result) +} + +// ReduceRoutes returns a reduced list of routes for a given node that it can access.
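+// Illustrative sketch only (the prefixes and variable names below are hypothetical and not
+// defined in this package): for a node whose matchers only permit access to 10.0.0.0/24,
+// one would expect something like
+//
+//	routes := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}
+//	reduced := ReduceRoutes(node, routes, matchers) // reduced keeps only 10.0.0.0/24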
+func ReduceRoutes( + node types.NodeView, + routes []netip.Prefix, + matchers []matcher.Match, +) []netip.Prefix { + var result []netip.Prefix + + for _, route := range routes { + if node.CanAccessRoute(matchers, route) { + result = append(result, route) + } + } + + return result +} + +// BuildPeerMap builds a map of all peers that can be accessed by each node. +func BuildPeerMap( + nodes views.Slice[types.NodeView], + matchers []matcher.Match, +) map[types.NodeID][]types.NodeView { + ret := make(map[types.NodeID][]types.NodeView, nodes.Len()) + + // Build the map of all peers according to the matchers. + // Compared to ReduceNodes, which rebuilds the peer list per node and therefore does the + // full O(n^2) work for every node, this records both sides of a relationship as soon as it + // is found, roughly halving the comparisons (n^2/2) and doing less work per node. + for i := range nodes.Len() { + for j := i + 1; j < nodes.Len(); j++ { + if nodes.At(i).ID() == nodes.At(j).ID() { + continue + } + + if nodes.At(i).CanAccess(matchers, nodes.At(j)) || nodes.At(j).CanAccess(matchers, nodes.At(i)) { + ret[nodes.At(i).ID()] = append(ret[nodes.At(i).ID()], nodes.At(j)) + ret[nodes.At(j).ID()] = append(ret[nodes.At(j).ID()], nodes.At(i)) + } + } + } + + return ret +} + +// ApproveRoutesWithPolicy checks if the node can approve the announced routes +// and returns the new list of approved routes. +// The approved routes will include: +// 1. ALL previously approved routes (regardless of whether they're still advertised) +// 2. New routes from announcedRoutes that can be auto-approved by policy +// This ensures that: +// - Previously approved routes are ALWAYS preserved (auto-approval never removes routes) +// - New routes can be auto-approved according to policy +// - Routes can only be removed by explicit admin action (not by auto-approval). +func ApproveRoutesWithPolicy(pm PolicyManager, nv types.NodeView, currentApproved, announcedRoutes []netip.Prefix) ([]netip.Prefix, bool) { + if pm == nil { + return currentApproved, false + } + + // Start with ALL currently approved routes - we never remove approved routes + newApproved := make([]netip.Prefix, len(currentApproved)) + copy(newApproved, currentApproved) + + // Then, check for new routes that can be auto-approved + for _, route := range announcedRoutes { + // Skip if already approved + if slices.Contains(newApproved, route) { + continue + } + + // Check if this new route can be auto-approved by policy + canApprove := pm.NodeCanApproveRoute(nv, route) + if canApprove { + newApproved = append(newApproved, route) + } + } + + // Sort and deduplicate + tsaddr.SortPrefixes(newApproved) + newApproved = slices.Compact(newApproved) + newApproved = lo.Filter(newApproved, func(route netip.Prefix, index int) bool { + return route.IsValid() + }) + + // Sort the current approved for comparison + sortedCurrent := make([]netip.Prefix, len(currentApproved)) + copy(sortedCurrent, currentApproved) + tsaddr.SortPrefixes(sortedCurrent) + + // Only update if the routes actually changed + if !slices.Equal(sortedCurrent, newApproved) { + // Log what changed + var added, kept []netip.Prefix + for _, route := range newApproved { + if !slices.Contains(sortedCurrent, route) { + added = append(added, route) + } else { + kept = append(kept, route) + } + } + + if len(added) > 0 { + log.Debug(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Strs("routes.added", util.PrefixesToString(added)). + Strs("routes.kept", util.PrefixesToString(kept)).
+ Int("routes.total", len(newApproved)). + Msg("Routes auto-approved by policy") + } + + return newApproved, true + } + + return newApproved, false +} diff --git a/hscontrol/policy/policy_autoapprove_test.go b/hscontrol/policy/policy_autoapprove_test.go new file mode 100644 index 00000000..61c69067 --- /dev/null +++ b/hscontrol/policy/policy_autoapprove_test.go @@ -0,0 +1,339 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" + "tailscale.com/types/key" + "tailscale.com/types/ptr" + "tailscale.com/types/views" +) + +func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) { + user1 := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser@", + } + user2 := types.User{ + Model: gorm.Model{ID: 2}, + Name: "otheruser@", + } + users := []types.User{user1, user2} + + node1 := &types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "test-node", + UserID: ptr.To(user1.ID), + User: ptr.To(user1), + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + Tags: []string{"tag:test"}, + } + + node2 := &types.Node{ + ID: 2, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "other-node", + UserID: ptr.To(user2.ID), + User: ptr.To(user2), + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")), + } + + // Create a policy that auto-approves specific routes + policyJSON := `{ + "groups": { + "group:test": ["testuser@"] + }, + "tagOwners": { + "tag:test": ["testuser@"] + }, + "acls": [ + { + "action": "accept", + "src": ["*"], + "dst": ["*:*"] + } + ], + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["testuser@", "tag:test"], + "10.1.0.0/24": ["testuser@"], + "10.2.0.0/24": ["testuser@"], + "192.168.0.0/24": ["tag:test"] + } + } + }` + + pm, err := policyv2.NewPolicyManager([]byte(policyJSON), users, views.SliceOf([]types.NodeView{node1.View(), node2.View()})) + assert.NoError(t, err) + + tests := []struct { + name string + node *types.Node + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + description string + }{ + { + name: "previously_approved_route_no_longer_advertised_should_remain", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Only this one is still advertised + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), // Should still be here! 
+ }, + wantChanged: false, + description: "Previously approved routes should never be removed even when no longer advertised", + }, + { + name: "add_new_auto_approved_route_keeps_old_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.5.0.0/24"), // This was manually approved + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), // New route that should be auto-approved + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), // New auto-approved route (subset of 10.0.0.0/8) + netip.MustParsePrefix("10.5.0.0/24"), // Old approved route kept + }, + wantChanged: true, + description: "New auto-approved routes should be added while keeping old approved routes", + }, + { + name: "no_announced_routes_keeps_all_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + announcedRoutes: []netip.Prefix{}, // No routes announced + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + description: "All approved routes should remain when no routes are announced", + }, + { + name: "no_changes_when_announced_equals_approved", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + description: "No changes should occur when announced routes match approved routes", + }, + { + name: "auto_approve_multiple_new_routes", + node: node1, + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/24"), // This was manually approved + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.2.0.0/24"), // Should be auto-approved (subset of 10.0.0.0/8) + netip.MustParsePrefix("192.168.0.0/24"), // Should be auto-approved for tag:test + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.2.0.0/24"), // New auto-approved + netip.MustParsePrefix("172.16.0.0/24"), // Original kept + netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved + }, + wantChanged: true, + description: "Multiple new routes should be auto-approved while keeping existing approved routes", + }, + { + name: "node_without_permission_no_auto_approval", + node: node2, // Different node without the tag + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // This requires tag:test + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Only the original approved route + }, + wantChanged: false, + description: "Routes should not be auto-approved for nodes without proper permissions", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, tt.node.View(), tt.currentApproved, tt.announcedRoutes) + + assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch: %s", tt.description) + + // Sort for comparison since ApproveRoutesWithPolicy sorts the results + tsaddr.SortPrefixes(tt.wantApproved) + assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch: %s", tt.description) + + // Verify that all previously approved routes are still 
present + for _, prevRoute := range tt.currentApproved { + assert.Contains(t, gotApproved, prevRoute, + "previously approved route %s was removed - this should never happen", prevRoute) + } + }) + } +} + +func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) { + // Create a basic policy for edge case testing + aclPolicy := ` +{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.1.0.0/24": ["test@"], + }, + }, +}` + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + }{ + { + name: "nil_policy_manager", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + }, + { + name: "nil_current_approved", + currentApproved: nil, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantChanged: true, + }, + { + name: "nil_announced_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: nil, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: false, + }, + { + name: "duplicate_approved_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.1.0.0/24"), + }, + wantChanged: true, + }, + { + name: "empty_slices", + currentApproved: []netip.Prefix{}, + announcedRoutes: []netip.Prefix{}, + wantApproved: []netip.Prefix{}, + wantChanged: false, + }, + } + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + users := []types.User{user} + + // Create test node + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: ptr.To(user.ID), + User: ptr.To(user), + RegisterMethod: util.RegisterMethodAuthKey, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + } + nodes := types.Nodes{&node} + + // Create policy manager or use nil if specified + var pm PolicyManager + var err error + if tt.name != "nil_policy_manager" { + pm, err = pmf(users, nodes.ViewSlice()) + assert.NoError(t, err) + } else { + pm = nil + } + + gotApproved, gotChanged := ApproveRoutesWithPolicy(pm, node.View(), tt.currentApproved, tt.announcedRoutes) + + assert.Equal(t, tt.wantChanged, gotChanged, "changed flag mismatch") + + // Handle nil vs empty slice comparison + if tt.wantApproved == nil { + assert.Nil(t, gotApproved, "expected nil approved routes") + } else { + tsaddr.SortPrefixes(tt.wantApproved) + assert.Equal(t, tt.wantApproved, gotApproved, "approved routes mismatch") + } + }) + } + } +} diff --git a/hscontrol/policy/policy_route_approval_test.go b/hscontrol/policy/policy_route_approval_test.go new file mode 100644 index 00000000..70aa6a21 --- /dev/null +++ 
b/hscontrol/policy/policy_route_approval_test.go @@ -0,0 +1,361 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) { + // Test policy that allows specific routes to be auto-approved + aclPolicy := ` +{ + "groups": { + "group:admins": ["test@"], + }, + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.0.0.0/24": ["test@"], + "192.168.0.0/24": ["group:admins"], + "172.16.0.0/16": ["tag:approved"], + }, + }, + "tagOwners": { + "tag:approved": ["test@"], + }, +}` + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + nodeHostname string + nodeUser string + nodeTags []string + wantApproved []netip.Prefix + wantChanged bool + wantRemovedRoutes []netip.Prefix // Routes that should NOT be in the result + }{ + { + name: "previously_approved_route_no_longer_advertised_remains", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // Only this one still advertised + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Should remain! + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + wantRemovedRoutes: []netip.Prefix{}, // Nothing should be removed + }, + { + name: "add_new_auto_approved_route_keeps_existing", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Still advertised + netip.MustParsePrefix("192.168.0.0/24"), // New route + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), // Auto-approved via group + }, + wantChanged: true, + }, + { + name: "no_announced_routes_keeps_all_approved", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + announcedRoutes: []netip.Prefix{}, // No routes announced anymore + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.0.0/24"), + }, + wantChanged: false, + }, + { + name: "manually_approved_route_not_in_policy_remains", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("203.0.113.0/24"), // Not in auto-approvers + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Can be auto-approved + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // New auto-approved + netip.MustParsePrefix("203.0.113.0/24"), // Manual approval preserved + }, + wantChanged: true, + }, + { + name: "tagged_node_gets_tag_approved_routes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), // Tag-approved route + }, + nodeUser: "test", + nodeTags: []string{"tag:approved"}, + 
wantApproved: []netip.Prefix{ + netip.MustParsePrefix("172.16.0.0/16"), // New tag-approved + netip.MustParsePrefix("10.0.0.0/24"), // Previous approval preserved + }, + wantChanged: true, + }, + { + name: "complex_scenario_multiple_changes", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Will not be advertised + netip.MustParsePrefix("203.0.113.0/24"), // Manual, not advertised + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), // New, auto-approvable + netip.MustParsePrefix("172.16.0.0/16"), // New, not approvable (no tag) + netip.MustParsePrefix("198.51.100.0/24"), // New, not in policy + }, + nodeUser: "test", + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), // Kept despite not advertised + netip.MustParsePrefix("192.168.0.0/24"), // New auto-approved + netip.MustParsePrefix("203.0.113.0/24"), // Kept despite not advertised + }, + wantChanged: true, + }, + } + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: tt.nodeUser, + } + users := []types.User{user} + + // Create test node + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: tt.nodeHostname, + UserID: ptr.To(user.ID), + User: ptr.To(user), + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + Tags: tt.nodeTags, + } + nodes := types.Nodes{&node} + + // Create policy manager + pm, err := pmf(users, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, pm) + + // Test ApproveRoutesWithPolicy + gotApproved, gotChanged := ApproveRoutesWithPolicy( + pm, + node.View(), + tt.currentApproved, + tt.announcedRoutes, + ) + + // Check change flag + assert.Equal(t, tt.wantChanged, gotChanged, "change flag mismatch") + + // Check approved routes match expected + if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" { + t.Logf("Want: %v", tt.wantApproved) + t.Logf("Got: %v", gotApproved) + t.Errorf("unexpected approved routes (-want +got):\n%s", diff) + } + + // Verify all previously approved routes are still present + for _, prevRoute := range tt.currentApproved { + assert.Contains(t, gotApproved, prevRoute, + "previously approved route %s was removed - this should NEVER happen", prevRoute) + } + + // Verify no routes were incorrectly removed + for _, removedRoute := range tt.wantRemovedRoutes { + assert.NotContains(t, gotApproved, removedRoute, + "route %s should have been removed but wasn't", removedRoute) + } + }) + } + } +} + +func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) { + aclPolicy := ` +{ + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]}, + ], + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["test@"], + }, + }, +}` + + tests := []struct { + name string + currentApproved []netip.Prefix + announcedRoutes []netip.Prefix + wantApproved []netip.Prefix + wantChanged bool + }{ + { + name: "nil_current_approved", + currentApproved: nil, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, + }, + { + name: "empty_current_approved", + 
currentApproved: []netip.Prefix{}, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, + }, + { + name: "duplicate_routes_handled", + currentApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/24"), // Duplicate + }, + announcedRoutes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantApproved: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + wantChanged: true, // Duplicates are removed, so it's a change + }, + } + + pmfs := PolicyManagerFuncsForTest([]byte(aclPolicy)) + + for _, tt := range tests { + for i, pmf := range pmfs { + t.Run(fmt.Sprintf("%s-policy-index%d", tt.name, i), func(t *testing.T) { + // Create test user + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + users := []types.User{user} + + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: ptr.To(user.ID), + User: ptr.To(user), + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tt.announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: tt.currentApproved, + } + nodes := types.Nodes{&node} + + pm, err := pmf(users, nodes.ViewSlice()) + require.NoError(t, err) + + gotApproved, gotChanged := ApproveRoutesWithPolicy( + pm, + node.View(), + tt.currentApproved, + tt.announcedRoutes, + ) + + assert.Equal(t, tt.wantChanged, gotChanged) + + if diff := cmp.Diff(tt.wantApproved, gotApproved, util.Comparers...); diff != "" { + t.Errorf("unexpected approved routes (-want +got):\n%s", diff) + } + }) + } + } +} + +func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) { + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "test", + } + + currentApproved := []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + } + announcedRoutes := []netip.Prefix{ + netip.MustParsePrefix("192.168.0.0/24"), + } + + node := types.Node{ + ID: 1, + MachineKey: key.NewMachine().Public(), + NodeKey: key.NewNode().Public(), + Hostname: "testnode", + UserID: ptr.To(user.ID), + User: ptr.To(user), + RegisterMethod: util.RegisterMethodAuthKey, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: announcedRoutes, + }, + IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")), + ApprovedRoutes: currentApproved, + } + + // With nil policy manager, should return current approved unchanged + gotApproved, gotChanged := ApproveRoutesWithPolicy(nil, node.View(), currentApproved, announcedRoutes) + + assert.False(t, gotChanged) + assert.Equal(t, currentApproved, gotApproved) +} diff --git a/hscontrol/policy/policy_test.go b/hscontrol/policy/policy_test.go new file mode 100644 index 00000000..da212605 --- /dev/null +++ b/hscontrol/policy/policy_test.go @@ -0,0 +1,2114 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +var ap = func(ipStr string) *netip.Addr { + ip := netip.MustParseAddr(ipStr) + return &ip +} + +var p = func(prefStr string) netip.Prefix { + ip := netip.MustParsePrefix(prefStr) + return ip +} + +func 
TestReduceNodes(t *testing.T) { + type args struct { + nodes types.Nodes + rules []tailcfg.FilterRule + node *types.Node + } + tests := []struct { + name string + args args + want types.Nodes + }{ + { + name: "all hosts can talk to each other", + args: args{ + nodes: types.Nodes{ // list of all nodes in the database + &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + node: &types.Node{ // current nodes + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + }, + want: types.Nodes{ + &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + }, + { + name: "One host can talk to another, but not all hosts", + args: args{ + nodes: types.Nodes{ // list of all nodes in the database + &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"100.64.0.1", "100.64.0.2", "100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + node: &types.Node{ // current nodes + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + }, + want: types.Nodes{ + &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + }, + }, + { + name: "host cannot directly talk to destination, but return path is authorized", + args: args{ + nodes: types.Nodes{ // list of all nodes in the database + &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"100.64.0.3"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + node: &types.Node{ // current nodes + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + }, + want: types.Nodes{ + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + }, + { + name: "rules allows all hosts to reach one destination", + args: args{ + nodes: types.Nodes{ // list of all nodes in the database + &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + node: &types.Node{ // current nodes + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + }, + want: types.Nodes{ + &types.Node{ + ID: 2, + IPv4: 
ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + }, + }, + { + name: "rules allows all hosts to reach one destination, destination can reach all hosts", + args: args{ + nodes: types.Nodes{ // list of all nodes in the database + &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.2"}, + }, + }, + }, + node: &types.Node{ // current nodes + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + }, + want: types.Nodes{ + &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + }, + { + name: "rule allows all hosts to reach all destinations", + args: args{ + nodes: types.Nodes{ // list of all nodes in the database + &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + node: &types.Node{ // current nodes + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + }, + want: types.Nodes{ + &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + }, + { + name: "without rule all communications are forbidden", + args: args{ + nodes: types.Nodes{ // list of all nodes in the database + &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "joe"}, + }, + &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + &types.Node{ + ID: 3, + IPv4: ap("100.64.0.3"), + User: &types.User{Name: "mickael"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + }, + node: &types.Node{ // current nodes + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "marc"}, + }, + }, + want: nil, + }, + { + // Investigating 699 + // Found some nodes: [ts-head-8w6paa ts-unstable-lys2ib ts-head-upcrmb ts-unstable-rlwpvr] nodes=ts-head-8w6paa + // ACL rules generated ACL=[{"DstPorts":[{"Bits":null,"IP":"*","Ports":{"First":0,"Last":65535}}],"SrcIPs":["fd7a:115c:a1e0::3","100.64.0.3","fd7a:115c:a1e0::4","100.64.0.4"]}] + // ACL Cache Map={"100.64.0.3":{"*":{}},"100.64.0.4":{"*":{}},"fd7a:115c:a1e0::3":{"*":{}},"fd7a:115c:a1e0::4":{"*":{}}} + name: "issue-699-broken-star", + args: args{ + nodes: types.Nodes{ // + &types.Node{ + ID: 1, + Hostname: "ts-head-upcrmb", + IPv4: ap("100.64.0.3"), + IPv6: ap("fd7a:115c:a1e0::3"), + User: &types.User{Name: "user1"}, + }, + &types.Node{ + ID: 2, + Hostname: "ts-unstable-rlwpvr", + IPv4: ap("100.64.0.4"), + IPv6: ap("fd7a:115c:a1e0::4"), + User: &types.User{Name: "user1"}, + }, + &types.Node{ + ID: 3, + Hostname: "ts-head-8w6paa", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &types.User{Name: "user2"}, + }, + &types.Node{ + ID: 4, + Hostname: 
"ts-unstable-lys2ib", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: &types.User{Name: "user2"}, + }, + }, + rules: []tailcfg.FilterRule{ // list of all ACLRules registered + { + DstPorts: []tailcfg.NetPortRange{ + { + IP: "*", + Ports: tailcfg.PortRange{First: 0, Last: 65535}, + }, + }, + SrcIPs: []string{ + "fd7a:115c:a1e0::3", "100.64.0.3", + "fd7a:115c:a1e0::4", "100.64.0.4", + }, + }, + }, + node: &types.Node{ // current nodes + ID: 3, + Hostname: "ts-head-8w6paa", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &types.User{Name: "user2"}, + }, + }, + want: types.Nodes{ + &types.Node{ + ID: 1, + Hostname: "ts-head-upcrmb", + IPv4: ap("100.64.0.3"), + IPv6: ap("fd7a:115c:a1e0::3"), + User: &types.User{Name: "user1"}, + }, + &types.Node{ + ID: 2, + Hostname: "ts-unstable-rlwpvr", + IPv4: ap("100.64.0.4"), + IPv6: ap("fd7a:115c:a1e0::4"), + User: &types.User{Name: "user1"}, + }, + }, + }, + { + name: "failing-edge-case-during-p3-refactor", + args: args{ + nodes: []*types.Node{ + { + ID: 1, + IPv4: ap("100.64.0.2"), + Hostname: "peer1", + User: &types.User{Name: "mini"}, + }, + { + ID: 2, + IPv4: ap("100.64.0.3"), + Hostname: "peer2", + User: &types.User{Name: "peer2"}, + }, + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRangeAny}, + {IP: "::/0", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + node: &types.Node{ + ID: 0, + IPv4: ap("100.64.0.1"), + Hostname: "mini", + User: &types.User{Name: "mini"}, + }, + }, + want: []*types.Node{ + { + ID: 2, + IPv4: ap("100.64.0.3"), + Hostname: "peer2", + User: &types.User{Name: "peer2"}, + }, + }, + }, + { + name: "p4-host-in-netmap-user2-dest-bug", + args: args{ + nodes: []*types.Node{ + { + ID: 1, + IPv4: ap("100.64.0.2"), + Hostname: "user1-2", + User: &types.User{Name: "user1"}, + }, + { + ID: 0, + IPv4: ap("100.64.0.1"), + Hostname: "user1-1", + User: &types.User{Name: "user1"}, + }, + { + ID: 3, + IPv4: ap("100.64.0.4"), + Hostname: "user2-2", + User: &types.User{Name: "user2"}, + }, + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.3/32", + "100.64.0.4/32", + "fd7a:115c:a1e0::3/128", + "fd7a:115c:a1e0::4/128", + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRangeAny}, + {IP: "100.64.0.4/32", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::3/128", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::4/128", Ports: tailcfg.PortRangeAny}, + }, + }, + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRangeAny}, + {IP: "100.64.0.4/32", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::3/128", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::4/128", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.3"), + Hostname: "user-2-1", + User: &types.User{Name: "user2"}, + }, + }, + want: []*types.Node{ + { + ID: 1, + IPv4: ap("100.64.0.2"), + Hostname: "user1-2", + User: &types.User{Name: "user1"}, + }, + { + ID: 0, + IPv4: ap("100.64.0.1"), + Hostname: "user1-1", + User: &types.User{Name: "user1"}, + }, + { + ID: 3, + IPv4: ap("100.64.0.4"), + Hostname: "user2-2", + User: &types.User{Name: "user2"}, + }, + }, + }, + { + name: "p4-host-in-netmap-user1-dest-bug", + args: args{ + nodes: []*types.Node{ + { + ID: 1, + IPv4: ap("100.64.0.2"), + Hostname: 
"user1-2", + User: &types.User{Name: "user1"}, + }, + { + ID: 2, + IPv4: ap("100.64.0.3"), + Hostname: "user-2-1", + User: &types.User{Name: "user2"}, + }, + { + ID: 3, + IPv4: ap("100.64.0.4"), + Hostname: "user2-2", + User: &types.User{Name: "user2"}, + }, + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}, + {IP: "100.64.0.2/32", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::1/128", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::2/128", Ports: tailcfg.PortRangeAny}, + }, + }, + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.3/32", Ports: tailcfg.PortRangeAny}, + {IP: "100.64.0.4/32", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::3/128", Ports: tailcfg.PortRangeAny}, + {IP: "fd7a:115c:a1e0::4/128", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + node: &types.Node{ + ID: 0, + IPv4: ap("100.64.0.1"), + Hostname: "user1-1", + User: &types.User{Name: "user1"}, + }, + }, + want: []*types.Node{ + { + ID: 1, + IPv4: ap("100.64.0.2"), + Hostname: "user1-2", + User: &types.User{Name: "user1"}, + }, + { + ID: 2, + IPv4: ap("100.64.0.3"), + Hostname: "user-2-1", + User: &types.User{Name: "user2"}, + }, + { + ID: 3, + IPv4: ap("100.64.0.4"), + Hostname: "user2-2", + User: &types.User{Name: "user2"}, + }, + }, + }, + { + name: "subnet-router-with-only-route", + args: args{ + nodes: []*types.Node{ + { + ID: 1, + IPv4: ap("100.64.0.1"), + Hostname: "user1", + User: &types.User{Name: "user1"}, + }, + { + ID: 2, + IPv4: ap("100.64.0.2"), + Hostname: "router", + User: &types.User{Name: "router"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + }, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + }, + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.1/32", + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.33.0.0/16", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + Hostname: "user1", + User: &types.User{Name: "user1"}, + }, + }, + want: []*types.Node{ + { + ID: 2, + IPv4: ap("100.64.0.2"), + Hostname: "router", + User: &types.User{Name: "router"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + }, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")}, + }, + }, + }, + { + name: "subnet-router-with-only-route-smaller-mask-2181", + args: args{ + nodes: []*types.Node{ + { + ID: 1, + IPv4: ap("100.64.0.1"), + Hostname: "router", + User: &types.User{Name: "router"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, + }, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, + }, + { + ID: 2, + IPv4: ap("100.64.0.2"), + Hostname: "node", + User: &types.User{Name: "node"}, + }, + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.2/32", + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.99.0.2/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + Hostname: "router", + User: &types.User{Name: "router"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, + }, + 
ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, + }, + }, + want: []*types.Node{ + { + ID: 2, + IPv4: ap("100.64.0.2"), + Hostname: "node", + User: &types.User{Name: "node"}, + }, + }, + }, + { + name: "node-to-subnet-router-with-only-route-smaller-mask-2181", + args: args{ + nodes: []*types.Node{ + { + ID: 1, + IPv4: ap("100.64.0.1"), + Hostname: "router", + User: &types.User{Name: "router"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, + }, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, + }, + { + ID: 2, + IPv4: ap("100.64.0.2"), + Hostname: "node", + User: &types.User{Name: "node"}, + }, + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.2/32", + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.99.0.2/32", Ports: tailcfg.PortRangeAny}, + }, + }, + }, + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + Hostname: "node", + User: &types.User{Name: "node"}, + }, + }, + want: []*types.Node{ + { + ID: 1, + IPv4: ap("100.64.0.1"), + Hostname: "router", + User: &types.User{Name: "router"}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, + }, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")}, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matchers := matcher.MatchesFromFilterRules(tt.args.rules) + gotViews := ReduceNodes( + tt.args.node.View(), + tt.args.nodes.ViewSlice(), + matchers, + ) + // Convert views back to nodes for comparison in tests + var got types.Nodes + for _, v := range gotViews.All() { + got = append(got, v.AsStruct()) + } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { + t.Errorf("ReduceNodes() unexpected result (-want +got):\n%s", diff) + t.Log("Matchers: ") + for _, m := range matchers { + t.Log("\t+", m.DebugString()) + } + } + }) + } +} + +func TestReduceNodesFromPolicy(t *testing.T) { + n := func(id types.NodeID, ip, hostname, username string, routess ...string) *types.Node { + var routes []netip.Prefix + for _, route := range routess { + routes = append(routes, netip.MustParsePrefix(route)) + } + + return &types.Node{ + ID: id, + IPv4: ap(ip), + Hostname: hostname, + User: &types.User{Name: username}, + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: routes, + }, + ApprovedRoutes: routes, + } + } + + tests := []struct { + name string + nodes types.Nodes + policy string + node *types.Node + want types.Nodes + wantMatchers int + }{ + { + name: "2788-exit-node-too-visible", + nodes: types.Nodes{ + n(1, "100.64.0.1", "mobile", "mobile"), + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + policy: ` +{ + "hosts": { + "mobile": "100.64.0.1/32", + "server": "100.64.0.2/32", + "exit": "100.64.0.3/32" + }, + + "acls": [ + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "server:80" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "mobile", "mobile"), + want: types.Nodes{ + n(2, "100.64.0.2", "server", "server"), + }, + wantMatchers: 1, + }, + { + name: "2788-exit-node-autogroup:internet", + nodes: types.Nodes{ + n(1, "100.64.0.1", "mobile", "mobile"), + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + policy: ` +{ + "hosts": { + "mobile": "100.64.0.1/32", + "server": "100.64.0.2/32", + "exit": "100.64.0.3/32" + }, + + "acls": [ + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + 
"server:80" + ] + }, + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "mobile", "mobile"), + want: types.Nodes{ + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + wantMatchers: 2, + }, + { + name: "2788-exit-node-0000-route", + nodes: types.Nodes{ + n(1, "100.64.0.1", "mobile", "mobile"), + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + policy: ` +{ + "hosts": { + "mobile": "100.64.0.1/32", + "server": "100.64.0.2/32", + "exit": "100.64.0.3/32" + }, + + "acls": [ + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "server:80" + ] + }, + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "0.0.0.0/0:*" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "mobile", "mobile"), + want: types.Nodes{ + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + wantMatchers: 2, + }, + { + name: "2788-exit-node-::0-route", + nodes: types.Nodes{ + n(1, "100.64.0.1", "mobile", "mobile"), + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + policy: ` +{ + "hosts": { + "mobile": "100.64.0.1/32", + "server": "100.64.0.2/32", + "exit": "100.64.0.3/32" + }, + + "acls": [ + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "server:80" + ] + }, + { + "action": "accept", + "src": [ + "mobile" + ], + "dst": [ + "::0/0:*" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "mobile", "mobile"), + want: types.Nodes{ + n(2, "100.64.0.2", "server", "server"), + n(3, "100.64.0.3", "exit", "server", "0.0.0.0/0", "::/0"), + }, + wantMatchers: 2, + }, + { + name: "2784-split-exit-node-access", + nodes: types.Nodes{ + n(1, "100.64.0.1", "user", "user"), + n(2, "100.64.0.2", "exit1", "exit", "0.0.0.0/0", "::/0"), + n(3, "100.64.0.3", "exit2", "exit", "0.0.0.0/0", "::/0"), + n(4, "100.64.0.4", "otheruser", "otheruser"), + }, + policy: ` +{ + "hosts": { + "user": "100.64.0.1/32", + "exit1": "100.64.0.2/32", + "exit2": "100.64.0.3/32", + "otheruser": "100.64.0.4/32", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "user" + ], + "dst": [ + "exit1:*" + ] + }, + { + "action": "accept", + "src": [ + "otheruser" + ], + "dst": [ + "exit2:*" + ] + } + ] +}`, + node: n(1, "100.64.0.1", "user", "user"), + want: types.Nodes{ + n(2, "100.64.0.2", "exit1", "exit", "0.0.0.0/0", "::/0"), + }, + wantMatchers: 2, + }, + } + + for _, tt := range tests { + for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { + t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { + var pm PolicyManager + var err error + pm, err = pmf(nil, tt.nodes.ViewSlice()) + require.NoError(t, err) + + matchers, err := pm.MatchersForNode(tt.node.View()) + require.NoError(t, err) + assert.Len(t, matchers, tt.wantMatchers) + + gotViews := ReduceNodes( + tt.node.View(), + tt.nodes.ViewSlice(), + matchers, + ) + // Convert views back to nodes for comparison in tests + var got types.Nodes + for _, v := range gotViews.All() { + got = append(got, v.AsStruct()) + } + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { + t.Errorf("TestReduceNodesFromPolicy() unexpected result (-want +got):\n%s", diff) + t.Log("Matchers: ") + for _, m := range matchers { + t.Log("\t+", m.DebugString()) + } + } + }) + } + } +} + +func TestSSHPolicyRules(t *testing.T) { + users := []types.User{ + {Name: "user1", 
Model: gorm.Model{ID: 1}}, + {Name: "user2", Model: gorm.Model{ID: 2}}, + {Name: "user3", Model: gorm.Model{ID: 3}}, + } + + // Create standard node setups used across tests + nodeUser1 := types.Node{ + Hostname: "user1-device", + IPv4: ap("100.64.0.1"), + UserID: ptr.To(uint(1)), + User: ptr.To(users[0]), + } + nodeUser2 := types.Node{ + Hostname: "user2-device", + IPv4: ap("100.64.0.2"), + UserID: ptr.To(uint(2)), + User: ptr.To(users[1]), + } + + taggedClient := types.Node{ + Hostname: "tagged-client", + IPv4: ap("100.64.0.4"), + UserID: ptr.To(uint(2)), + User: ptr.To(users[1]), + Tags: []string{"tag:client"}, + } + + // Create a tagged server node for valid SSH patterns + nodeTaggedServer := types.Node{ + Hostname: "tagged-server", + IPv4: ap("100.64.0.5"), + UserID: ptr.To(uint(1)), + User: ptr.To(users[0]), + Tags: []string{"tag:server"}, + } + + tests := []struct { + name string + targetNode types.Node + peers types.Nodes + policy string + wantSSH *tailcfg.SSHPolicy + expectErr bool + errorMessage string + }{ + { + name: "group-to-tag", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeUser2}, + policy: `{ + "tagOwners": { + "tag:server": ["user1@"] + }, + "groups": { + "group:admins": ["user2@"] + }, + "ssh": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + } + ] + }`, + wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{ + {NodeIP: "100.64.0.2"}, + }, + SSHUsers: map[string]string{ + "*": "=", + "root": "", + }, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }}, + }, + { + name: "check-period-specified", + targetNode: taggedClient, + peers: types.Nodes{&nodeUser2}, + policy: `{ + "tagOwners": { + "tag:client": ["user1@"] + }, + "groups": { + "group:admins": ["user2@"] + }, + "ssh": [ + { + "action": "check", + "checkPeriod": "24h", + "src": ["group:admins"], + "dst": ["tag:client"], + "users": ["autogroup:nonroot"] + } + ] + }`, + wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{ + {NodeIP: "100.64.0.2"}, + }, + SSHUsers: map[string]string{ + "*": "=", + "root": "", + }, + Action: &tailcfg.SSHAction{ + Accept: true, + SessionDuration: 24 * time.Hour, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }}, + }, + { + name: "no-matching-rules", + targetNode: nodeUser2, + peers: types.Nodes{&nodeUser1, &nodeTaggedServer}, + policy: `{ + "tagOwners": { + "tag:server": ["user1@"] + }, + "groups": { + "group:admins": ["user1@"] + }, + "ssh": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + } + ] + }`, + wantSSH: &tailcfg.SSHPolicy{Rules: nil}, + }, + { + name: "invalid-action", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeUser2}, + policy: `{ + "tagOwners": { + "tag:server": ["user1@"] + }, + "groups": { + "group:admins": ["user2@"] + }, + "ssh": [ + { + "action": "invalid", + "src": ["group:admins"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + } + ] + }`, + expectErr: true, + errorMessage: `invalid SSH action "invalid", must be one of: accept, check`, + }, + { + name: "invalid-check-period", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeUser2}, + policy: `{ + "tagOwners": { + "tag:server": ["user1@"] + }, + "groups": { + "group:admins": ["user2@"] + 
}, + "ssh": [ + { + "action": "check", + "checkPeriod": "invalid", + "src": ["group:admins"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + } + ] + }`, + expectErr: true, + errorMessage: "not a valid duration string", + }, + { + name: "unsupported-autogroup", + targetNode: taggedClient, + peers: types.Nodes{&nodeUser2}, + policy: `{ + "tagOwners": { + "tag:client": ["user1@"] + }, + "groups": { + "group:admins": ["user2@"] + }, + "ssh": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["tag:client"], + "users": ["autogroup:invalid"] + } + ] + }`, + expectErr: true, + errorMessage: "autogroup \"autogroup:invalid\" is not supported", + }, + { + name: "autogroup-nonroot-should-use-wildcard-with-root-excluded", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeUser2}, + policy: `{ + "tagOwners": { + "tag:server": ["user1@"] + }, + "groups": { + "group:admins": ["user2@"] + }, + "ssh": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + } + ] + }`, + // autogroup:nonroot should map to wildcard "*" with root excluded + wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{ + {NodeIP: "100.64.0.2"}, + }, + SSHUsers: map[string]string{ + "*": "=", + "root": "", + }, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }}, + }, + { + name: "autogroup-nonroot-plus-root-should-use-wildcard-with-root-mapped", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeUser2}, + policy: `{ + "tagOwners": { + "tag:server": ["user1@"] + }, + "groups": { + "group:admins": ["user2@"] + }, + "ssh": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot", "root"] + } + ] + }`, + // autogroup:nonroot + root should map to wildcard "*" with root mapped to itself + wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{ + {NodeIP: "100.64.0.2"}, + }, + SSHUsers: map[string]string{ + "*": "=", + "root": "root", + }, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }}, + }, + { + name: "specific-users-should-map-to-themselves-not-equals", + targetNode: nodeTaggedServer, + peers: types.Nodes{&nodeUser2}, + policy: `{ + "tagOwners": { + "tag:server": ["user1@"] + }, + "groups": { + "group:admins": ["user2@"] + }, + "ssh": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["tag:server"], + "users": ["ubuntu", "root"] + } + ] + }`, + // specific usernames should map to themselves, not "=" + wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{ + {NodeIP: "100.64.0.2"}, + }, + SSHUsers: map[string]string{ + "root": "root", + "ubuntu": "ubuntu", + }, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }}, + }, + { + name: "2863-allow-predefined-missing-users", + targetNode: taggedClient, + peers: types.Nodes{&nodeUser2}, + policy: `{ + "groups": { + "group:example-infra": [ + "user2@", + "not-created-yet@", + ], + }, + "tagOwners": { + "tag:client": [ + "user2@" + ], + }, + "ssh": [ + // Allow infra to ssh to tag:example-infra server as debian + { + "action": "accept", + "src": [ + "group:example-infra" + ], + "dst": 
[ + "tag:client", + ], + "users": [ + "debian", + ], + }, + ], +}`, + wantSSH: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{ + { + Principals: []*tailcfg.SSHPrincipal{ + {NodeIP: "100.64.0.2"}, + }, + SSHUsers: map[string]string{ + "debian": "debian", + }, + Action: &tailcfg.SSHAction{ + Accept: true, + AllowAgentForwarding: true, + AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + }, + }, + }}, + }, + } + + for _, tt := range tests { + for idx, pmf := range PolicyManagerFuncsForTest([]byte(tt.policy)) { + t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { + var pm PolicyManager + var err error + pm, err = pmf(users, append(tt.peers, &tt.targetNode).ViewSlice()) + + if tt.expectErr { + require.Error(t, err) + require.Contains(t, err.Error(), tt.errorMessage) + return + } + + require.NoError(t, err) + + got, err := pm.SSHPolicy(tt.targetNode.View()) + require.NoError(t, err) + + if diff := cmp.Diff(tt.wantSSH, got); diff != "" { + t.Errorf("SSHPolicy() unexpected result (-want +got):\n%s", diff) + } + }) + } + } +} + +func TestReduceRoutes(t *testing.T) { + type args struct { + node *types.Node + routes []netip.Prefix + rules []tailcfg.FilterRule + } + tests := []struct { + name string + args args + want []netip.Prefix + }{ + { + name: "node-can-access-all-routes", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + }, + { + name: "node-can-access-specific-route", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/24"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + }, + }, + { + name: "node-can-access-multiple-specific-routes", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/24"}, + {IP: "192.168.1.0/24"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "node-can-access-overlapping-routes", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/16"), // Overlaps with the first one + netip.MustParsePrefix("192.168.1.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/16"}, + }, + }, + }, + }, + want: []netip.Prefix{ + 
netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("10.0.0.0/16"), + }, + }, + { + name: "node-with-no-matching-rules", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + User: &types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("192.168.1.0/24"), + netip.MustParsePrefix("172.16.0.0/16"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2"}, // Different source IP + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, + }, + }, + }, + }, + want: nil, + }, + { + name: "node-with-both-ipv4-and-ipv6", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: &types.User{Name: "user1"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("2001:db8::/64"), + netip.MustParsePrefix("192.168.1.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"fd7a:115c:a1e0::1"}, // IPv6 source + DstPorts: []tailcfg.NetPortRange{ + {IP: "2001:db8::/64"}, // IPv6 destination + }, + }, + { + SrcIPs: []string{"100.64.0.1"}, // IPv4 source + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.0.0.0/24"}, // IPv4 destination + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.0.0.0/24"), + netip.MustParsePrefix("2001:db8::/64"), + }, + }, + { + name: "router-with-multiple-routes-and-node-with-specific-access", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), // Node IP + User: &types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"*"}, // Any source + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.1"}, // Router node + }, + }, + { + SrcIPs: []string{"100.64.0.2"}, // Node IP + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, // Only one subnet allowed + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + }, + }, + { + name: "node-with-access-to-one-subnet-and-partial-overlap", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.10.0/16"), // Overlaps with the first one + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, // Only specific subnet + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.10.0/16"), // With current implementation, this is included because it overlaps with the allowed subnet + }, + }, + { + name: "node-with-access-to-wildcard-subnet", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.0.0/16"}, // Broader subnet that includes all three + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + }, + { + name: 
"multiple-nodes-with-different-subnet-permissions", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), + User: &types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1"}, // Different node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.11.0/24"}, + }, + }, + { + SrcIPs: []string{"100.64.0.2"}, // Our node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, + }, + }, + { + SrcIPs: []string{"100.64.0.3"}, // Different node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.12.0/24"}, + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + }, + }, + { + name: "exactly-matching-users-acl-example", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), // node with IP 100.64.0.2 + User: &types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + // This represents the rule: action: accept, src: ["*"], dst: ["router:0"] + SrcIPs: []string{"*"}, // Any source + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.1"}, // Router IP + }, + }, + { + // This represents the rule: action: accept, src: ["node"], dst: ["10.10.10.0/24:*"] + SrcIPs: []string{"100.64.0.2"}, // Node IP + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24", Ports: tailcfg.PortRangeAny}, // All ports on this subnet + }, + }, + }, + }, + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + }, + }, + { + name: "acl-all-source-nodes-can-access-router-only-node-can-access-10.10.10.0-24", + args: args{ + // When testing from router node's perspective + node: &types.Node{ + ID: 1, + IPv4: ap("100.64.0.1"), // router with IP 100.64.0.1 + User: &types.User{Name: "router"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"*"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.64.0.1"}, // Router can be accessed by all + }, + }, + { + SrcIPs: []string{"100.64.0.2"}, // Only node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, // Can access this subnet + }, + }, + // Add a rule for router to access its own routes + { + SrcIPs: []string{"100.64.0.1"}, // Router node + DstPorts: []tailcfg.NetPortRange{ + {IP: "*"}, // Can access everything + }, + }, + }, + }, + // Router needs explicit rules to access routes + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + }, + { + name: "acl-specific-port-ranges-for-subnets", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), // node + User: &types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.2"}, // node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24", Ports: tailcfg.PortRange{First: 22, Last: 22}}, // Only SSH + }, + }, + { + SrcIPs: []string{"100.64.0.2"}, // node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.11.0/24", Ports: 
tailcfg.PortRange{First: 80, Last: 80}}, // Only HTTP + }, + }, + }, + }, + // Should get both subnets with specific port ranges + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + }, + }, + { + name: "acl-order-of-rules-and-rule-specificity", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.64.0.2"), // node + User: &types.User{Name: "node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + rules: []tailcfg.FilterRule{ + // First rule allows all traffic + { + SrcIPs: []string{"*"}, // Any source + DstPorts: []tailcfg.NetPortRange{ + {IP: "*", Ports: tailcfg.PortRangeAny}, // Any destination and any port + }, + }, + // Second rule is more specific but should be overridden by the first rule + { + SrcIPs: []string{"100.64.0.2"}, // node + DstPorts: []tailcfg.NetPortRange{ + {IP: "10.10.10.0/24"}, + }, + }, + }, + }, + // Due to the first rule allowing all traffic, node should have access to all routes + want: []netip.Prefix{ + netip.MustParsePrefix("10.10.10.0/24"), + netip.MustParsePrefix("10.10.11.0/24"), + netip.MustParsePrefix("10.10.12.0/24"), + }, + }, + { + name: "return-path-subnet-router-to-regular-node-issue-2608", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.123.45.89"), // Node B - regular node + User: &types.User{Name: "node-b"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), // Subnet connected to Node A + }, + rules: []tailcfg.FilterRule{ + { + // Policy allows 192.168.1.0/24 and group:routers to access *:* + SrcIPs: []string{ + "192.168.1.0/24", // Subnet behind router + "100.123.45.67", // Node A (router, part of group:routers) + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*", Ports: tailcfg.PortRangeAny}, // Access to everything + }, + }, + }, + }, + // Node B should receive the 192.168.1.0/24 route for return traffic + // even though Node B cannot initiate connections to that network + want: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "return-path-router-perspective-2608", + args: args{ + node: &types.Node{ + ID: 1, + IPv4: ap("100.123.45.67"), // Node A - router node + User: &types.User{Name: "router"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), // Subnet connected to this router + }, + rules: []tailcfg.FilterRule{ + { + // Policy allows 192.168.1.0/24 and group:routers to access *:* + SrcIPs: []string{ + "192.168.1.0/24", // Subnet behind router + "100.123.45.67", // Node A (router, part of group:routers) + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*", Ports: tailcfg.PortRangeAny}, // Access to everything + }, + }, + }, + }, + // Router should have access to its own routes + want: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "subnet-behind-router-bidirectional-connectivity-issue-2608", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.123.45.89"), // Node B - regular node that should be reachable + User: &types.User{Name: "node-b"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), // Subnet behind router + netip.MustParsePrefix("10.0.0.0/24"), // Another subnet + }, + rules: []tailcfg.FilterRule{ + { + // Only 192.168.1.0/24 and routers can access everything + SrcIPs: []string{ + "192.168.1.0/24", // Subnet that can connect to Node B + "100.123.45.67", // Router node + }, + DstPorts: 
[]tailcfg.NetPortRange{ + {IP: "*", Ports: tailcfg.PortRangeAny}, + }, + }, + { + // Node B cannot access anything (no rules with Node B as source) + SrcIPs: []string{"100.123.45.89"}, + DstPorts: []tailcfg.NetPortRange{ + // No destinations - Node B cannot initiate connections + }, + }, + }, + }, + // Node B should still get the 192.168.1.0/24 route for return traffic + // but should NOT get 10.0.0.0/24 since nothing allows that subnet to connect to Node B + want: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), + }, + }, + { + name: "no-route-leakage-when-no-connection-allowed-2608", + args: args{ + node: &types.Node{ + ID: 3, + IPv4: ap("100.123.45.99"), // Node C - isolated node + User: &types.User{Name: "isolated-node"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/24"), // Subnet behind router + netip.MustParsePrefix("10.0.0.0/24"), // Another private subnet + netip.MustParsePrefix("172.16.0.0/24"), // Yet another subnet + }, + rules: []tailcfg.FilterRule{ + { + // Only specific subnets and routers can access specific destinations + SrcIPs: []string{ + "192.168.1.0/24", // This subnet can access everything + "100.123.45.67", // Router node can access everything + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.123.45.89", Ports: tailcfg.PortRangeAny}, // Only to Node B + }, + }, + { + // 10.0.0.0/24 can only access router + SrcIPs: []string{"10.0.0.0/24"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.123.45.67", Ports: tailcfg.PortRangeAny}, // Only to router + }, + }, + { + // 172.16.0.0/24 has no access rules at all + }, + }, + }, + // Node C should get NO routes because: + // - 192.168.1.0/24 can only connect to Node B (not Node C) + // - 10.0.0.0/24 can only connect to router (not Node C) + // - 172.16.0.0/24 has no rules allowing it to connect anywhere + // - Node C is not in any rules as a destination + want: nil, + }, + { + name: "original-issue-2608-with-slash14-network", + args: args{ + node: &types.Node{ + ID: 2, + IPv4: ap("100.123.45.89"), // Node B - regular node + User: &types.User{Name: "node-b"}, + }, + routes: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/14"), // Network 192.168.1.0/14 as mentioned in original issue + }, + rules: []tailcfg.FilterRule{ + { + // Policy allows 192.168.1.0/24 (part of /14) and group:routers to access *:* + SrcIPs: []string{ + "192.168.1.0/24", // Subnet behind router (part of the larger /14 network) + "100.123.45.67", // Node A (router, part of group:routers) + }, + DstPorts: []tailcfg.NetPortRange{ + {IP: "*", Ports: tailcfg.PortRangeAny}, // Access to everything + }, + }, + }, + }, + // Node B should receive the 192.168.1.0/14 route for return traffic + // even though only 192.168.1.0/24 (part of /14) can connect to Node B + // This is the exact scenario from the original issue + want: []netip.Prefix{ + netip.MustParsePrefix("192.168.1.0/14"), + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + matchers := matcher.MatchesFromFilterRules(tt.args.rules) + got := ReduceRoutes( + tt.args.node.View(), + tt.args.routes, + matchers, + ) + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { + t.Errorf("ReduceRoutes() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/policy/policyutil/reduce.go b/hscontrol/policy/policyutil/reduce.go new file mode 100644 index 00000000..e4549c10 --- /dev/null +++ b/hscontrol/policy/policyutil/reduce.go @@ -0,0 +1,71 @@ +package policyutil + +import ( + 
"github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "tailscale.com/tailcfg" +) + +// ReduceFilterRules takes a node and a set of global filter rules and removes all rules +// and destinations that are not relevant to that particular node. +// +// IMPORTANT: This function is designed for global filters only. Per-node filters +// (from autogroup:self policies) are already node-specific and should not be passed +// to this function. Use PolicyManager.FilterForNode() instead, which handles both cases. +func ReduceFilterRules(node types.NodeView, rules []tailcfg.FilterRule) []tailcfg.FilterRule { + ret := []tailcfg.FilterRule{} + + for _, rule := range rules { + // record if the rule is actually relevant for the given node. + var dests []tailcfg.NetPortRange + DEST_LOOP: + for _, dest := range rule.DstPorts { + expanded, err := util.ParseIPSet(dest.IP, nil) + // Fail closed, if we can't parse it, then we should not allow + // access. + if err != nil { + continue DEST_LOOP + } + + if node.InIPSet(expanded) { + dests = append(dests, dest) + continue DEST_LOOP + } + + // If the node exposes routes, ensure they are note removed + // when the filters are reduced. + if node.Hostinfo().Valid() { + routableIPs := node.Hostinfo().RoutableIPs() + if routableIPs.Len() > 0 { + for _, routableIP := range routableIPs.All() { + if expanded.OverlapsPrefix(routableIP) { + dests = append(dests, dest) + continue DEST_LOOP + } + } + } + } + + // Also check approved subnet routes - nodes should have access + // to subnets they're approved to route traffic for. + subnetRoutes := node.SubnetRoutes() + + for _, subnetRoute := range subnetRoutes { + if expanded.OverlapsPrefix(subnetRoute) { + dests = append(dests, dest) + continue DEST_LOOP + } + } + } + + if len(dests) > 0 { + ret = append(ret, tailcfg.FilterRule{ + SrcIPs: rule.SrcIPs, + DstPorts: dests, + IPProto: rule.IPProto, + }) + } + } + + return ret +} diff --git a/hscontrol/policy/policyutil/reduce_test.go b/hscontrol/policy/policyutil/reduce_test.go new file mode 100644 index 00000000..35f5b472 --- /dev/null +++ b/hscontrol/policy/policyutil/reduce_test.go @@ -0,0 +1,842 @@ +package policyutil_test + +import ( + "encoding/json" + "fmt" + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/policy/policyutil" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" + "tailscale.com/util/must" +) + +var ap = func(ipStr string) *netip.Addr { + ip := netip.MustParseAddr(ipStr) + return &ip +} + +var p = func(prefStr string) netip.Prefix { + ip := netip.MustParsePrefix(prefStr) + return ip +} + +// hsExitNodeDestForTest is the list of destination IP ranges that are allowed when +// we use headscale "autogroup:internet". 
+var hsExitNodeDestForTest = []tailcfg.NetPortRange{ + {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "64.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "96.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "100.0.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "100.128.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "101.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "102.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "104.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "112.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "168.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "169.0.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "169.128.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "169.192.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "169.224.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "169.240.0.0/13", Ports: tailcfg.PortRangeAny}, + {IP: "169.248.0.0/14", Ports: tailcfg.PortRangeAny}, + {IP: "169.252.0.0/15", Ports: tailcfg.PortRangeAny}, + {IP: "169.255.0.0/16", Ports: tailcfg.PortRangeAny}, + {IP: "170.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "192.0.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny}, + {IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny}, + {IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny}, + {IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny}, + {IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "224.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "2000::/3", Ports: tailcfg.PortRangeAny}, +} + +func TestTheInternet(t *testing.T) { + internetSet := util.TheInternet() + + internetPrefs := internetSet.Prefixes() + + for i := range internetPrefs { + if internetPrefs[i].String() != hsExitNodeDestForTest[i].IP { + t.Errorf( + "prefix from internet set %q != hsExit list %q", + internetPrefs[i].String(), + hsExitNodeDestForTest[i].IP, + ) + } + } + + if len(internetPrefs) != len(hsExitNodeDestForTest) { + t.Fatalf( + "expected same length of prefixes, internet: %d, hsExit: %d", + len(internetPrefs), + len(hsExitNodeDestForTest), + ) + } +} + +func TestReduceFilterRules(t *testing.T) { + users := types.Users{ + types.User{Model: gorm.Model{ID: 1}, Name: "mickael"}, + types.User{Model: gorm.Model{ID: 2}, Name: "user1"}, + types.User{Model: gorm.Model{ID: 3}, Name: "user2"}, + types.User{Model: gorm.Model{ID: 4}, Name: "user100"}, + types.User{Model: gorm.Model{ID: 5}, Name: "user3"}, + } + + tests := []struct { + name string + node *types.Node + 
peers types.Nodes + pol string + want []tailcfg.FilterRule + }{ + { + name: "host1-can-reach-host2-no-rules", + pol: ` +{ + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "100.64.0.1" + ], + "dst": [ + "100.64.0.2:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"), + User: ptr.To(users[0]), + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"), + User: ptr.To(users[0]), + }, + }, + want: []tailcfg.FilterRule{}, + }, + { + name: "1604-subnet-routers-are-preserved", + pol: ` +{ + "groups": { + "group:admins": [ + "user1@" + ] + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:admins" + ], + "dst": [ + "group:admins:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:admins" + ], + "dst": [ + "10.33.0.0/16:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[1]), + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{ + netip.MustParsePrefix("10.33.0.0/16"), + }, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.1/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::1/128", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{ + "100.64.0.1/32", + "100.64.0.2/32", + "fd7a:115c:a1e0::1/128", + "fd7a:115c:a1e0::2/128", + }, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "10.33.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-the-client", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[1]), + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[2]), + }, + // "internal" exit node + &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: ptr.To(users[3]), + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tsaddr.ExitRoutes(), + }, + }, + }, + want: []tailcfg.FilterRule{}, + }, + { + name: "1786-reducing-breaks-exit-nodes-the-exit", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: ptr.To(users[3]), + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tsaddr.ExitRoutes(), + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: 
ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[2]), + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[1]), + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: hsExitNodeDestForTest, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-the-example-from-issue", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "0.0.0.0/5:*", + "8.0.0.0/7:*", + "11.0.0.0/8:*", + "12.0.0.0/6:*", + "16.0.0.0/4:*", + "32.0.0.0/3:*", + "64.0.0.0/2:*", + "128.0.0.0/3:*", + "160.0.0.0/5:*", + "168.0.0.0/6:*", + "172.0.0.0/12:*", + "172.32.0.0/11:*", + "172.64.0.0/10:*", + "172.128.0.0/9:*", + "173.0.0.0/8:*", + "174.0.0.0/7:*", + "176.0.0.0/4:*", + "192.0.0.0/9:*", + "192.128.0.0/11:*", + "192.160.0.0/13:*", + "192.169.0.0/16:*", + "192.170.0.0/15:*", + "192.172.0.0/14:*", + "192.176.0.0/12:*", + "192.192.0.0/10:*", + "193.0.0.0/8:*", + "194.0.0.0/7:*", + "196.0.0.0/6:*", + "200.0.0.0/5:*", + "208.0.0.0/4:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: ptr.To(users[3]), + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: tsaddr.ExitRoutes(), + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[2]), + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[1]), + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "0.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "8.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "11.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "12.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "16.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "32.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "64.0.0.0/2", Ports: tailcfg.PortRangeAny}, + {IP: "128.0.0.0/3", Ports: tailcfg.PortRangeAny}, + {IP: "160.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "168.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "172.0.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "172.32.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "172.64.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "172.128.0.0/9", Ports: tailcfg.PortRangeAny}, + {IP: "173.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "174.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "176.0.0.0/4", Ports: tailcfg.PortRangeAny}, + {IP: "192.0.0.0/9", Ports: 
tailcfg.PortRangeAny}, + {IP: "192.128.0.0/11", Ports: tailcfg.PortRangeAny}, + {IP: "192.160.0.0/13", Ports: tailcfg.PortRangeAny}, + {IP: "192.169.0.0/16", Ports: tailcfg.PortRangeAny}, + {IP: "192.170.0.0/15", Ports: tailcfg.PortRangeAny}, + {IP: "192.172.0.0/14", Ports: tailcfg.PortRangeAny}, + {IP: "192.176.0.0/12", Ports: tailcfg.PortRangeAny}, + {IP: "192.192.0.0/10", Ports: tailcfg.PortRangeAny}, + {IP: "193.0.0.0/8", Ports: tailcfg.PortRangeAny}, + {IP: "194.0.0.0/7", Ports: tailcfg.PortRangeAny}, + {IP: "196.0.0.0/6", Ports: tailcfg.PortRangeAny}, + {IP: "200.0.0.0/5", Ports: tailcfg.PortRangeAny}, + {IP: "208.0.0.0/4", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-app-connector-like", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "8.0.0.0/8:*", + "16.0.0.0/8:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: ptr.To(users[3]), + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")}, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[2]), + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[1]), + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "8.0.0.0/8", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "16.0.0.0/8", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1786-reducing-breaks-exit-nodes-app-connector-like2", + pol: ` +{ + "groups": { + "group:team": [ + "user3@", + "user2@", + "user1@" + ] + }, + "hosts": { + "internal": "100.64.0.100/32" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "internal:*" + ] + }, + { + "action": "accept", + "proto": "", + "src": [ + "group:team" + ], + "dst": [ + "8.0.0.0/16:*", + "16.0.0.0/16:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: ptr.To(users[3]), + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")}, + }, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[2]), + }, + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[1]), + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + 
Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + { + SrcIPs: []string{"100.64.0.1/32", "100.64.0.2/32", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::2/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "8.0.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "16.0.0.0/16", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "1817-reduce-breaks-32-mask", + pol: ` +{ + "tagOwners": { + "tag:access-servers": ["user100@"], + }, + "groups": { + "group:access": [ + "user1@" + ] + }, + "hosts": { + "dns1": "172.16.0.21/32", + "vlan1": "172.16.0.0/24" + }, + "acls": [ + { + "action": "accept", + "proto": "", + "src": [ + "group:access" + ], + "dst": [ + "tag:access-servers:*", + "dns1:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.100"), + IPv6: ap("fd7a:115c:a1e0::100"), + User: ptr.To(users[3]), + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")}, + }, + Tags: []string{"tag:access-servers"}, + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[1]), + }, + }, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.64.0.1/32", "fd7a:115c:a1e0::1/128"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.64.0.100/32", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "fd7a:115c:a1e0::100/128", + Ports: tailcfg.PortRangeAny, + }, + { + IP: "172.16.0.21/32", + Ports: tailcfg.PortRangeAny, + }, + }, + IPProto: []int{6, 17}, + }, + }, + }, + { + name: "2365-only-route-policy", + pol: ` +{ + "hosts": { + "router": "100.64.0.1/32", + "node": "100.64.0.2/32" + }, + "acls": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "router:8000" + ] + }, + { + "action": "accept", + "src": [ + "node" + ], + "dst": [ + "172.26.0.0/16:*" + ] + } + ], +} +`, + node: &types.Node{ + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[3]), + }, + peers: types.Nodes{ + &types.Node{ + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[1]), + Hostinfo: &tailcfg.Hostinfo{ + RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, + }, + ApprovedRoutes: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")}, + }, + }, + want: []tailcfg.FilterRule{}, + }, + } + + for _, tt := range tests { + for idx, pmf := range policy.PolicyManagerFuncsForTest([]byte(tt.pol)) { + t.Run(fmt.Sprintf("%s-index%d", tt.name, idx), func(t *testing.T) { + var pm policy.PolicyManager + var err error + pm, err = pmf(users, append(tt.peers, tt.node).ViewSlice()) + require.NoError(t, err) + got, _ := pm.Filter() + t.Logf("full filter:\n%s", must.Get(json.MarshalIndent(got, "", " "))) + got = policyutil.ReduceFilterRules(tt.node.View(), got) + + if diff := cmp.Diff(tt.want, got); diff != "" { + log.Trace().Interface("got", got).Msg("result") + t.Errorf("TestReduceFilterRules() unexpected result (-want +got):\n%s", diff) + } + }) + } + } +} diff --git a/hscontrol/policy/route_approval_test.go b/hscontrol/policy/route_approval_test.go new file mode 100644 index 00000000..39b15cee --- /dev/null +++ b/hscontrol/policy/route_approval_test.go @@ -0,0 +1,852 @@ +package policy + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/types/ptr" +) + +func 
TestNodeCanApproveRoute(t *testing.T) { + users := []types.User{ + {Name: "user1", Model: gorm.Model{ID: 1}}, + {Name: "user2", Model: gorm.Model{ID: 2}}, + {Name: "user3", Model: gorm.Model{ID: 3}}, + } + + // Create standard node setups used across tests + normalNode := types.Node{ + ID: 1, + Hostname: "user1-device", + IPv4: ap("100.64.0.1"), + UserID: ptr.To(uint(1)), + User: ptr.To(users[0]), + } + + exitNode := types.Node{ + ID: 2, + Hostname: "user2-device", + IPv4: ap("100.64.0.2"), + UserID: ptr.To(uint(2)), + User: ptr.To(users[1]), + } + + taggedNode := types.Node{ + ID: 3, + Hostname: "tagged-server", + IPv4: ap("100.64.0.3"), + UserID: ptr.To(uint(3)), + User: ptr.To(users[2]), + Tags: []string{"tag:router"}, + } + + multiTagNode := types.Node{ + ID: 4, + Hostname: "multi-tag-node", + IPv4: ap("100.64.0.4"), + UserID: ptr.To(uint(2)), + User: ptr.To(users[1]), + Tags: []string{"tag:router", "tag:server"}, + } + + tests := []struct { + name string + node types.Node + route netip.Prefix + policy string + canApprove bool + }{ + { + name: "allow-all-routes-for-admin-user", + node: normalNode, + route: p("192.168.1.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "192.168.0.0/16": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "deny-route-that-doesnt-match-autoApprovers", + node: normalNode, + route: p("10.0.0.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "192.168.0.0/16": ["group:admin"] + } + } + }`, + canApprove: false, + }, + { + name: "user-not-in-group", + node: exitNode, + route: p("192.168.1.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "192.168.0.0/16": ["group:admin"] + } + } + }`, + canApprove: false, + }, + { + name: "tagged-node-can-approve", + node: taggedNode, + route: p("10.0.0.0/8"), + policy: `{ + "tagOwners": { + "tag:router": ["user3@"] + }, + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["tag:router"] + } + } + }`, + canApprove: true, + }, + { + name: "multiple-routes-in-policy", + node: normalNode, + route: p("172.16.10.0/24"), + policy: `{ + "tagOwners": { + "tag:router": ["user3@"] + }, + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "192.168.0.0/16": ["group:admin"], + "172.16.0.0/12": ["group:admin"], + "10.0.0.0/8": ["tag:router"] + } + } + }`, + canApprove: true, + }, + { + name: "match-specific-route-within-range", + node: normalNode, + route: p("192.168.5.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "192.168.0.0/16": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "ip-address-within-range", + node: normalNode, + route: p("192.168.1.5/32"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + 
"192.168.1.0/24": ["group:admin"], + "192.168.1.128/25": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "all-IPv4-routes-(0.0.0.0/0)-approval", + node: normalNode, + route: p("0.0.0.0/0"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "0.0.0.0/0": ["group:admin"] + } + } + }`, + canApprove: false, + }, + { + name: "all-IPv4-routes-exitnode-approval", + node: normalNode, + route: p("0.0.0.0/0"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "exitNode": ["group:admin"] + } + }`, + canApprove: true, + }, + { + name: "all-IPv6-routes-exitnode-approval", + node: normalNode, + route: p("::/0"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "exitNode": ["group:admin"] + } + }`, + canApprove: true, + }, + { + name: "specific-IPv4-route-with-exitnode-only-approval", + node: normalNode, + route: p("192.168.1.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "exitNode": ["group:admin"] + } + }`, + canApprove: false, + }, + { + name: "specific-IPv6-route-with-exitnode-only-approval", + node: normalNode, + route: p("fd00::/8"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "exitNode": ["group:admin"] + } + }`, + canApprove: false, + }, + { + name: "specific-IPv4-route-with-all-routes-policy", + node: normalNode, + route: p("10.0.0.0/8"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "0.0.0.0/0": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "all-IPv6-routes-(::0/0)-approval", + node: normalNode, + route: p("::/0"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "::/0": ["group:admin"] + } + } + }`, + canApprove: false, + }, + { + name: "specific-IPv6-route-with-all-routes-policy", + node: normalNode, + route: p("fd00::/8"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "::/0": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "IPv6-route-with-IPv4-all-routes-policy", + node: normalNode, + route: p("fd00::/8"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "0.0.0.0/0": ["group:admin"] + } + } + }`, + canApprove: false, + }, + { + name: "IPv4-route-with-IPv6-all-routes-policy", + node: normalNode, + route: p("10.0.0.0/8"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "::/0": ["group:admin"] + } + } + }`, + canApprove: false, + }, + { + name: "both-IPv4-and-IPv6-all-routes-policy", + node: normalNode, + route: 
p("192.168.1.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "0.0.0.0/0": ["group:admin"], + "::/0": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "ip-address-with-all-routes-policy", + node: normalNode, + route: p("192.168.101.5/32"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "0.0.0.0/0": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "specific-IPv6-host-route-with-all-routes-policy", + node: normalNode, + route: p("2001:db8::1/128"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "::/0": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "multiple-groups-allowed-to-approve-same-route", + node: normalNode, + route: p("192.168.1.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"], + "group:netadmin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "192.168.1.0/24": ["group:admin", "group:netadmin"] + } + } + }`, + canApprove: true, + }, + { + name: "overlapping-routes-with-different-groups", + node: normalNode, + route: p("192.168.1.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"], + "group:restricted": ["user2@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "192.168.0.0/16": ["group:restricted"], + "192.168.1.0/24": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "unique-local-IPv6-address-with-all-routes-policy", + node: normalNode, + route: p("fc00::/7"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "::/0": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "exact-prefix-match-in-policy", + node: normalNode, + route: p("203.0.113.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "203.0.113.0/24": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "narrower-range-than-policy", + node: normalNode, + route: p("203.0.113.0/26"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "203.0.113.0/24": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "wider-range-than-policy-should-fail", + node: normalNode, + route: p("203.0.113.0/23"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "203.0.113.0/24": ["group:admin"] + } + } + }`, + canApprove: false, + }, + { + name: "adjacent-route-to-policy-route-should-fail", + node: normalNode, + route: p("203.0.114.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "203.0.113.0/24": ["group:admin"] + } + } + }`, 
+ canApprove: false, + }, + { + name: "combined-routes-and-exitnode-approvers-specific-route", + node: normalNode, + route: p("192.168.1.0/24"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "exitNode": ["group:admin"], + "routes": { + "192.168.1.0/24": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "partly-overlapping-route-with-policy-should-fail", + node: normalNode, + route: p("203.0.113.128/23"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "203.0.113.0/24": ["group:admin"] + } + } + }`, + canApprove: false, + }, + { + name: "multiple-routes-with-aggregatable-ranges", + node: normalNode, + route: p("10.0.0.0/8"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "10.0.0.0/9": ["group:admin"], + "10.128.0.0/9": ["group:admin"] + } + } + }`, + canApprove: false, + }, + { + name: "non-standard-IPv6-notation", + node: normalNode, + route: p("2001:db8::1/128"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "2001:db8::/32": ["group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "node-with-multiple-tags-all-required", + node: multiTagNode, + route: p("10.10.0.0/16"), + policy: `{ + "tagOwners": { + "tag:router": ["user2@"], + "tag:server": ["user2@"] + }, + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "10.10.0.0/16": ["tag:router", "tag:server"] + } + } + }`, + canApprove: true, + }, + { + name: "node-with-multiple-tags-one-matching-is-sufficient", + node: multiTagNode, + route: p("10.10.0.0/16"), + policy: `{ + "tagOwners": { + "tag:router": ["user2@"], + "tag:server": ["user2@"] + }, + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "10.10.0.0/16": ["tag:router", "group:admin"] + } + } + }`, + canApprove: true, + }, + { + name: "node-with-multiple-tags-missing-required-tag", + node: multiTagNode, + route: p("10.10.0.0/16"), + policy: `{ + "tagOwners": { + "tag:othertag": ["user1@"] + }, + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "10.10.0.0/16": ["tag:othertag"] + } + } + }`, + canApprove: false, + }, + { + name: "node-with-tag-and-group-membership", + node: normalNode, + route: p("10.20.0.0/16"), + policy: `{ + "tagOwners": { + "tag:router": ["user3@"] + }, + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "10.20.0.0/16": ["group:admin", "tag:router"] + } + } + }`, + canApprove: true, + }, + { + // Tags-as-identity: Tagged nodes are identified by their tags, not by the + // user who created them. Group membership of the creator is irrelevant. + // A tagged node can only be auto-approved via tag-based autoApprovers, + // not group-based ones (even if the creator is in the group). 
+ name: "tagged-node-with-group-autoapprover-not-approved", + node: taggedNode, // Has tag:router, owned by user3 + route: p("10.30.0.0/16"), + policy: `{ + "tagOwners": { + "tag:router": ["user3@"] + }, + "groups": { + "group:ops": ["user3@"] + }, + "acls": [ + {"action": "accept", "src": ["*"], "dst": ["*:*"]} + ], + "autoApprovers": { + "routes": { + "10.30.0.0/16": ["group:ops"] + } + } + }`, + canApprove: false, // Tagged nodes don't inherit group membership for auto-approval + }, + { + name: "small-subnet-with-exitnode-only-approval", + node: normalNode, + route: p("192.168.1.1/32"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + {"action": "accept", "src": ["group:admin"], "dst": ["*:*"]} + ], + "autoApprovers": { + "exitNode": ["group:admin"] + } + }`, + canApprove: false, + }, + { + name: "empty-policy", + node: normalNode, + route: p("192.168.1.0/24"), + policy: `{"acls":[{"action":"accept","src":["*"],"dst":["*:*"]}]}`, + canApprove: false, + }, + { + name: "policy-without-autoApprovers-section", + node: normalNode, + route: p("10.33.0.0/16"), + policy: `{ + "groups": { + "group:admin": ["user1@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admin"], + "dst": ["group:admin:*"] + }, + { + "action": "accept", + "src": ["group:admin"], + "dst": ["10.33.0.0/16:*"] + } + ] + }`, + canApprove: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Initialize all policy manager implementations + policyManagers, err := PolicyManagersForTest([]byte(tt.policy), users, types.Nodes{&tt.node}.ViewSlice()) + if tt.name == "empty policy" { + // We expect this one to have a valid but empty policy + require.NoError(t, err) + if err != nil { + return + } + } else { + require.NoError(t, err) + } + + for i, pm := range policyManagers { + t.Run(fmt.Sprintf("policy-index%d", i), func(t *testing.T) { + result := pm.NodeCanApproveRoute(tt.node.View(), tt.route) + + if diff := cmp.Diff(tt.canApprove, result); diff != "" { + t.Errorf("NodeCanApproveRoute() mismatch (-want +got):\n%s", diff) + } + assert.Equal(t, tt.canApprove, result, "Unexpected route approval result") + }) + } + }) + } +} diff --git a/hscontrol/policy/v2/filter.go b/hscontrol/policy/v2/filter.go new file mode 100644 index 00000000..78c6ebc5 --- /dev/null +++ b/hscontrol/policy/v2/filter.go @@ -0,0 +1,463 @@ +package v2 + +import ( + "errors" + "fmt" + "slices" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/types/views" +) + +var ErrInvalidAction = errors.New("invalid action") + +// compileFilterRules takes a set of nodes and an ACLPolicy and generates a +// set of Tailscale compatible FilterRules used to allow traffic on clients. 
+func (pol *Policy) compileFilterRules( + users types.Users, + nodes views.Slice[types.NodeView], +) ([]tailcfg.FilterRule, error) { + if pol == nil || pol.ACLs == nil { + return tailcfg.FilterAllowAll, nil + } + + var rules []tailcfg.FilterRule + + for _, acl := range pol.ACLs { + if acl.Action != ActionAccept { + return nil, ErrInvalidAction + } + + srcIPs, err := acl.Sources.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Caller().Err(err).Msgf("resolving source ips") + } + + if srcIPs == nil || len(srcIPs.Prefixes()) == 0 { + continue + } + + protocols, _ := acl.Protocol.parseProtocol() + + var destPorts []tailcfg.NetPortRange + for _, dest := range acl.Destinations { + ips, err := dest.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Caller().Err(err).Msgf("resolving destination ips") + } + + if ips == nil { + log.Debug().Caller().Msgf("destination resolved to nil ips: %v", dest) + continue + } + + prefixes := ips.Prefixes() + + for _, pref := range prefixes { + for _, port := range dest.Ports { + pr := tailcfg.NetPortRange{ + IP: pref.String(), + Ports: port, + } + destPorts = append(destPorts, pr) + } + } + } + + if len(destPorts) == 0 { + continue + } + + rules = append(rules, tailcfg.FilterRule{ + SrcIPs: ipSetToPrefixStringList(srcIPs), + DstPorts: destPorts, + IPProto: protocols, + }) + } + + return rules, nil +} + +// compileFilterRulesForNode compiles filter rules for a specific node. +func (pol *Policy) compileFilterRulesForNode( + users types.Users, + node types.NodeView, + nodes views.Slice[types.NodeView], +) ([]tailcfg.FilterRule, error) { + if pol == nil { + return tailcfg.FilterAllowAll, nil + } + + var rules []tailcfg.FilterRule + + for _, acl := range pol.ACLs { + if acl.Action != ActionAccept { + return nil, ErrInvalidAction + } + + aclRules, err := pol.compileACLWithAutogroupSelf(acl, users, node, nodes) + if err != nil { + log.Trace().Err(err).Msgf("compiling ACL") + continue + } + + for _, rule := range aclRules { + if rule != nil { + rules = append(rules, *rule) + } + } + } + + return rules, nil +} + +// compileACLWithAutogroupSelf compiles a single ACL rule, handling +// autogroup:self per-node while supporting all other alias types normally. +// It returns a slice of filter rules because when an ACL has both autogroup:self +// and other destinations, they need to be split into separate rules with different +// source filtering logic. 
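+// autogroup:self is only valid as a destination; if it appears in a source the
+// ACL is rejected with an error. For autogroup:self destinations, both the
+// sources and the destinations are restricted to untagged devices owned by the
+// same user as the target node.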
+func (pol *Policy) compileACLWithAutogroupSelf( + acl ACL, + users types.Users, + node types.NodeView, + nodes views.Slice[types.NodeView], +) ([]*tailcfg.FilterRule, error) { + var autogroupSelfDests []AliasWithPorts + var otherDests []AliasWithPorts + + for _, dest := range acl.Destinations { + if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { + autogroupSelfDests = append(autogroupSelfDests, dest) + } else { + otherDests = append(otherDests, dest) + } + } + + protocols, _ := acl.Protocol.parseProtocol() + var rules []*tailcfg.FilterRule + + var resolvedSrcIPs []*netipx.IPSet + + for _, src := range acl.Sources { + if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { + return nil, fmt.Errorf("autogroup:self cannot be used in sources") + } + + ips, err := src.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Err(err).Msgf("resolving source ips") + continue + } + + if ips != nil { + resolvedSrcIPs = append(resolvedSrcIPs, ips) + } + } + + if len(resolvedSrcIPs) == 0 { + return rules, nil + } + + // Handle autogroup:self destinations (if any) + if len(autogroupSelfDests) > 0 { + // Pre-filter to same-user untagged devices once - reuse for both sources and destinations + sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { + if n.User().ID() == node.User().ID() && !n.IsTagged() { + sameUserNodes = append(sameUserNodes, n) + } + } + + if len(sameUserNodes) > 0 { + // Filter sources to only same-user untagged devices + var srcIPs netipx.IPSetBuilder + for _, ips := range resolvedSrcIPs { + for _, n := range sameUserNodes { + // Check if any of this node's IPs are in the source set + if slices.ContainsFunc(n.IPs(), ips.Contains) { + n.AppendToIPSet(&srcIPs) + } + } + } + + srcSet, err := srcIPs.IPSet() + if err != nil { + return nil, err + } + + if srcSet != nil && len(srcSet.Prefixes()) > 0 { + var destPorts []tailcfg.NetPortRange + for _, dest := range autogroupSelfDests { + for _, n := range sameUserNodes { + for _, port := range dest.Ports { + for _, ip := range n.IPs() { + destPorts = append(destPorts, tailcfg.NetPortRange{ + IP: ip.String(), + Ports: port, + }) + } + } + } + } + + if len(destPorts) > 0 { + rules = append(rules, &tailcfg.FilterRule{ + SrcIPs: ipSetToPrefixStringList(srcSet), + DstPorts: destPorts, + IPProto: protocols, + }) + } + } + } + } + + if len(otherDests) > 0 { + var srcIPs netipx.IPSetBuilder + + for _, ips := range resolvedSrcIPs { + srcIPs.AddSet(ips) + } + + srcSet, err := srcIPs.IPSet() + if err != nil { + return nil, err + } + + if srcSet != nil && len(srcSet.Prefixes()) > 0 { + var destPorts []tailcfg.NetPortRange + + for _, dest := range otherDests { + ips, err := dest.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Err(err).Msgf("resolving destination ips") + continue + } + + if ips == nil { + log.Debug().Msgf("destination resolved to nil ips: %v", dest) + continue + } + + prefixes := ips.Prefixes() + + for _, pref := range prefixes { + for _, port := range dest.Ports { + pr := tailcfg.NetPortRange{ + IP: pref.String(), + Ports: port, + } + destPorts = append(destPorts, pr) + } + } + } + + if len(destPorts) > 0 { + rules = append(rules, &tailcfg.FilterRule{ + SrcIPs: ipSetToPrefixStringList(srcSet), + DstPorts: destPorts, + IPProto: protocols, + }) + } + } + } + + return rules, nil +} + +func sshAction(accept bool, duration time.Duration) tailcfg.SSHAction { + return tailcfg.SSHAction{ + Reject: !accept, + Accept: accept, + SessionDuration: duration, + AllowAgentForwarding: true, + 
AllowLocalPortForwarding: true, + AllowRemotePortForwarding: true, + } +} + +func (pol *Policy) compileSSHPolicy( + users types.Users, + node types.NodeView, + nodes views.Slice[types.NodeView], +) (*tailcfg.SSHPolicy, error) { + if pol == nil || pol.SSHs == nil || len(pol.SSHs) == 0 { + return nil, nil + } + + log.Trace().Caller().Msgf("compiling SSH policy for node %q", node.Hostname()) + + var rules []*tailcfg.SSHRule + + for index, rule := range pol.SSHs { + // Separate destinations into autogroup:self and others + // This is needed because autogroup:self requires filtering sources to same-user only, + // while other destinations should use all resolved sources + var autogroupSelfDests []Alias + var otherDests []Alias + + for _, dst := range rule.Destinations { + if ag, ok := dst.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { + autogroupSelfDests = append(autogroupSelfDests, dst) + } else { + otherDests = append(otherDests, dst) + } + } + + // Note: Tagged nodes can't match autogroup:self destinations, but can still match other destinations + + // Resolve sources once - we'll use them differently for each destination type + srcIPs, err := rule.Sources.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Caller().Err(err).Msgf("SSH policy compilation failed resolving source ips for rule %+v", rule) + } + + if srcIPs == nil || len(srcIPs.Prefixes()) == 0 { + continue + } + + var action tailcfg.SSHAction + switch rule.Action { + case SSHActionAccept: + action = sshAction(true, 0) + case SSHActionCheck: + action = sshAction(true, time.Duration(rule.CheckPeriod)) + default: + return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", rule.Action, index, err) + } + + userMap := make(map[string]string, len(rule.Users)) + if rule.Users.ContainsNonRoot() { + userMap["*"] = "=" + // by default, we do not allow root unless explicitly stated + userMap["root"] = "" + } + if rule.Users.ContainsRoot() { + userMap["root"] = "root" + } + for _, u := range rule.Users.NormalUsers() { + userMap[u.String()] = u.String() + } + + // Handle autogroup:self destinations (if any) + // Note: Tagged nodes can't match autogroup:self, so skip this block for tagged nodes + if len(autogroupSelfDests) > 0 && !node.IsTagged() { + // Build destination set for autogroup:self (same-user untagged devices only) + var dest netipx.IPSetBuilder + for _, n := range nodes.All() { + if n.User().ID() == node.User().ID() && !n.IsTagged() { + n.AppendToIPSet(&dest) + } + } + + destSet, err := dest.IPSet() + if err != nil { + return nil, err + } + + // Only create rule if this node is in the destination set + if node.InIPSet(destSet) { + // Filter sources to only same-user untagged devices + // Pre-filter to same-user untagged devices for efficiency + sameUserNodes := make([]types.NodeView, 0) + for _, n := range nodes.All() { + if n.User().ID() == node.User().ID() && !n.IsTagged() { + sameUserNodes = append(sameUserNodes, n) + } + } + + var filteredSrcIPs netipx.IPSetBuilder + for _, n := range sameUserNodes { + // Check if any of this node's IPs are in the source set + if slices.ContainsFunc(n.IPs(), srcIPs.Contains) { + n.AppendToIPSet(&filteredSrcIPs) // Found this node, move to next + } + } + + filteredSrcSet, err := filteredSrcIPs.IPSet() + if err != nil { + return nil, err + } + + if filteredSrcSet != nil && len(filteredSrcSet.Prefixes()) > 0 { + var principals []*tailcfg.SSHPrincipal + for addr := range util.IPSetAddrIter(filteredSrcSet) { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: 
addr.String(), + }) + } + + if len(principals) > 0 { + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: userMap, + Action: &action, + }) + } + } + } + } + + // Handle other destinations (if any) + if len(otherDests) > 0 { + // Build destination set for other destinations + var dest netipx.IPSetBuilder + for _, dst := range otherDests { + ips, err := dst.Resolve(pol, users, nodes) + if err != nil { + log.Trace().Caller().Err(err).Msgf("resolving destination ips") + continue + } + if ips != nil { + dest.AddSet(ips) + } + } + + destSet, err := dest.IPSet() + if err != nil { + return nil, err + } + + // Only create rule if this node is in the destination set + if node.InIPSet(destSet) { + // For non-autogroup:self destinations, use all resolved sources (no filtering) + var principals []*tailcfg.SSHPrincipal + for addr := range util.IPSetAddrIter(srcIPs) { + principals = append(principals, &tailcfg.SSHPrincipal{ + NodeIP: addr.String(), + }) + } + + if len(principals) > 0 { + rules = append(rules, &tailcfg.SSHRule{ + Principals: principals, + SSHUsers: userMap, + Action: &action, + }) + } + } + } + } + + return &tailcfg.SSHPolicy{ + Rules: rules, + }, nil +} + +func ipSetToPrefixStringList(ips *netipx.IPSet) []string { + var out []string + + if ips == nil { + return out + } + + for _, pref := range ips.Prefixes() { + out = append(out, pref.String()) + } + + return out +} diff --git a/hscontrol/policy/v2/filter_test.go b/hscontrol/policy/v2/filter_test.go new file mode 100644 index 00000000..0df1e147 --- /dev/null +++ b/hscontrol/policy/v2/filter_test.go @@ -0,0 +1,1689 @@ +package v2 + +import ( + "encoding/json" + "net/netip" + "slices" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/prometheus/common/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +// aliasWithPorts creates an AliasWithPorts structure from an alias and ports. 
+func aliasWithPorts(alias Alias, ports ...tailcfg.PortRange) AliasWithPorts { + return AliasWithPorts{ + Alias: alias, + Ports: ports, + } +} + +func TestParsing(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "testuser"}, + } + tests := []struct { + name string + format string + acl string + want []tailcfg.FilterRule + wantErr bool + }{ + { + name: "invalid-hujson", + format: "hujson", + acl: ` +{ + `, + want: []tailcfg.FilterRule{}, + wantErr: true, + }, + // The new parser will ignore all that is irrelevant + // { + // name: "valid-hujson-invalid-content", + // format: "hujson", + // acl: ` + // { + // "valid_json": true, + // "but_a_policy_though": false + // } + // `, + // want: []tailcfg.FilterRule{}, + // wantErr: true, + // }, + // { + // name: "invalid-cidr", + // format: "hujson", + // acl: ` + // {"example-host-1": "100.100.100.100/42"} + // `, + // want: []tailcfg.FilterRule{}, + // wantErr: true, + // }, + { + name: "basic-rule", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + "192.168.1.0/24" + ], + "dst": [ + "*:22,3389", + "host-1:*", + ], + }, + ], +} + `, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.100.101.0/24", "192.168.1.0/24"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "0.0.0.0/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, + {IP: "::/0", Ports: tailcfg.PortRange{First: 22, Last: 22}}, + {IP: "::/0", Ports: tailcfg.PortRange{First: 3389, Last: 3389}}, + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolTCP, protocolUDP}, + }, + }, + wantErr: false, + }, + { + name: "parse-protocol", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "tcp", + "dst": [ + "host-1:*", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "udp", + "dst": [ + "host-1:53", + ], + }, + { + "Action": "accept", + "src": [ + "*", + ], + "proto": "icmp", + "dst": [ + "host-1:*", + ], + }, + ], +}`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolTCP}, + }, + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRange{First: 53, Last: 53}}, + }, + IPProto: []int{protocolUDP}, + }, + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolICMP, protocolIPv6ICMP}, + }, + }, + wantErr: false, + }, + { + name: "port-wildcard", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "Action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolTCP, protocolUDP}, + }, + }, + wantErr: false, + }, + { + name: "port-range", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": 
"100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "subnet-1", + ], + "dst": [ + "host-1:5400-5500", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"100.100.101.0/24"}, + DstPorts: []tailcfg.NetPortRange{ + { + IP: "100.100.100.100/32", + Ports: tailcfg.PortRange{First: 5400, Last: 5500}, + }, + }, + IPProto: []int{protocolTCP, protocolUDP}, + }, + }, + wantErr: false, + }, + { + name: "port-group", + format: "hujson", + acl: ` +{ + "groups": { + "group:example": [ + "testuser@", + ], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "group:example", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"200.200.200.200/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolTCP, protocolUDP}, + }, + }, + wantErr: false, + }, + { + name: "port-user", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "testuser@", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"200.200.200.200/32"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolTCP, protocolUDP}, + }, + }, + wantErr: false, + }, + { + name: "ipv6", + format: "hujson", + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100/32", + "subnet-1": "100.100.101.100/24", + }, + + "acls": [ + { + "action": "accept", + "src": [ + "*", + ], + "dst": [ + "host-1:*", + ], + }, + ], +} +`, + want: []tailcfg.FilterRule{ + { + SrcIPs: []string{"0.0.0.0/0", "::/0"}, + DstPorts: []tailcfg.NetPortRange{ + {IP: "100.100.100.100/32", Ports: tailcfg.PortRangeAny}, + }, + IPProto: []int{protocolTCP, protocolUDP}, + }, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol, err := unmarshalPolicy([]byte(tt.acl)) + if tt.wantErr && err == nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } else if !tt.wantErr && err != nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if err != nil { + return + } + + rules, err := pol.compileFilterRules( + users, + types.Nodes{ + &types.Node{ + IPv4: ap("100.100.100.100"), + }, + &types.Node{ + IPv4: ap("200.200.200.200"), + User: &users[0], + Hostinfo: &tailcfg.Hostinfo{}, + }, + }.ViewSlice()) + + if (err != nil) != tt.wantErr { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if diff := cmp.Diff(tt.want, rules); diff != "" { + t.Errorf("parsing() unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +func TestCompileSSHPolicy_UserMapping(t *testing.T) { + users := types.Users{ + {Name: "user1", Model: gorm.Model{ID: 1}}, + {Name: "user2", Model: gorm.Model{ID: 2}}, + } + + // Create test nodes - use tagged nodes as SSH destinations + // and untagged nodes as SSH sources (since group->username destinations + // are not allowed per Tailscale security model, but groups can SSH to tags) + nodeTaggedServer := types.Node{ + Hostname: "tagged-server", + IPv4: createAddr("100.64.0.1"), + UserID: ptr.To(users[0].ID), + User: ptr.To(users[0]), + Tags: []string{"tag:server"}, + } + nodeTaggedDB := types.Node{ + Hostname: "tagged-db", + IPv4: 
createAddr("100.64.0.2"), + UserID: ptr.To(users[1].ID), + User: ptr.To(users[1]), + Tags: []string{"tag:database"}, + } + // Add untagged node for user2 - this will be the SSH source + // (group:admins contains user2, so user2's untagged node provides the source IPs) + nodeUser2Untagged := types.Node{ + Hostname: "user2-device", + IPv4: createAddr("100.64.0.3"), + UserID: ptr.To(users[1].ID), + User: ptr.To(users[1]), + } + + nodes := types.Nodes{&nodeTaggedServer, &nodeTaggedDB, &nodeUser2Untagged} + + tests := []struct { + name string + targetNode types.Node + policy *Policy + wantSSHUsers map[string]string + wantEmpty bool + }{ + { + name: "specific user mapping", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + }, + Groups: Groups{ + Group("group:admins"): []Username{Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{"ssh-it-user"}, + }, + }, + }, + wantSSHUsers: map[string]string{ + "ssh-it-user": "ssh-it-user", + }, + }, + { + name: "multiple specific users", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + }, + Groups: Groups{ + Group("group:admins"): []Username{Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{"ubuntu", "admin", "deploy"}, + }, + }, + }, + wantSSHUsers: map[string]string{ + "ubuntu": "ubuntu", + "admin": "admin", + "deploy": "deploy", + }, + }, + { + name: "autogroup:nonroot only", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + }, + Groups: Groups{ + Group("group:admins"): []Username{Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + wantSSHUsers: map[string]string{ + "*": "=", + "root": "", + }, + }, + { + name: "root only", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + }, + Groups: Groups{ + Group("group:admins"): []Username{Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{"root"}, + }, + }, + }, + wantSSHUsers: map[string]string{ + "root": "root", + }, + }, + { + name: "autogroup:nonroot plus root", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + }, + Groups: Groups{ + Group("group:admins"): []Username{Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser(AutoGroupNonRoot), "root"}, + }, + }, + }, + wantSSHUsers: map[string]string{ + "*": "=", + "root": "root", + }, + }, + { + name: "mixed specific users and autogroups", + targetNode: nodeTaggedServer, + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + }, + Groups: Groups{ + Group("group:admins"): []Username{Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: 
SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser(AutoGroupNonRoot), "root", "ubuntu", "admin"}, + }, + }, + }, + wantSSHUsers: map[string]string{ + "*": "=", + "root": "root", + "ubuntu": "ubuntu", + "admin": "admin", + }, + }, + { + name: "no matching destination", + targetNode: nodeTaggedDB, // Target tag:database, but policy only allows tag:server + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + Tag("tag:database"): Owners{up("user1@")}, + }, + Groups: Groups{ + Group("group:admins"): []Username{Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, // Only tag:server, not tag:database + Users: []SSHUser{"ssh-it-user"}, + }, + }, + }, + wantEmpty: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Validate the policy + err := tt.policy.validate() + require.NoError(t, err) + + // Compile SSH policy + sshPolicy, err := tt.policy.compileSSHPolicy(users, tt.targetNode.View(), nodes.ViewSlice()) + require.NoError(t, err) + + if tt.wantEmpty { + if sshPolicy == nil { + return // Expected empty result + } + assert.Empty(t, sshPolicy.Rules, "SSH policy should be empty when no rules match") + return + } + + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 1, "Should have exactly one SSH rule") + + rule := sshPolicy.Rules[0] + assert.Equal(t, tt.wantSSHUsers, rule.SSHUsers, "SSH users mapping should match expected") + + // Verify principals are set correctly (should contain user2's untagged device IP since that's the source) + require.Len(t, rule.Principals, 1) + assert.Equal(t, "100.64.0.3", rule.Principals[0].NodeIP) + + // Verify action is set correctly + assert.True(t, rule.Action.Accept) + assert.True(t, rule.Action.AllowAgentForwarding) + assert.True(t, rule.Action.AllowLocalPortForwarding) + assert.True(t, rule.Action.AllowRemotePortForwarding) + }) + } +} + +func TestCompileSSHPolicy_CheckAction(t *testing.T) { + users := types.Users{ + {Name: "user1", Model: gorm.Model{ID: 1}}, + {Name: "user2", Model: gorm.Model{ID: 2}}, + } + + // Use tagged nodes for SSH user mapping tests + nodeTaggedServer := types.Node{ + Hostname: "tagged-server", + IPv4: createAddr("100.64.0.1"), + UserID: ptr.To(users[0].ID), + User: ptr.To(users[0]), + Tags: []string{"tag:server"}, + } + nodeUser2 := types.Node{ + Hostname: "user2-device", + IPv4: createAddr("100.64.0.2"), + UserID: ptr.To(users[1].ID), + User: ptr.To(users[1]), + } + + nodes := types.Nodes{&nodeTaggedServer, &nodeUser2} + + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + }, + Groups: Groups{ + Group("group:admins"): []Username{Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "check", + CheckPeriod: model.Duration(24 * time.Hour), + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{"ssh-it-user"}, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + sshPolicy, err := policy.compileSSHPolicy(users, nodeTaggedServer.View(), nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 1) + + rule := sshPolicy.Rules[0] + + // Verify SSH users are correctly mapped + expectedUsers := map[string]string{ + "ssh-it-user": "ssh-it-user", + } + assert.Equal(t, expectedUsers, rule.SSHUsers) + + // Verify check action with session duration + assert.True(t, rule.Action.Accept) + 
assert.Equal(t, 24*time.Hour, rule.Action.SessionDuration) +} + +// TestSSHIntegrationReproduction reproduces the exact scenario from the integration test +// TestSSHOneUserToAll that was failing with empty sshUsers +func TestSSHIntegrationReproduction(t *testing.T) { + // Create users matching the integration test + users := types.Users{ + {Name: "user1", Model: gorm.Model{ID: 1}}, + {Name: "user2", Model: gorm.Model{ID: 2}}, + } + + // Create simple nodes for testing + node1 := &types.Node{ + Hostname: "user1-node", + IPv4: createAddr("100.64.0.1"), + UserID: ptr.To(users[0].ID), + User: ptr.To(users[0]), + } + + node2 := &types.Node{ + Hostname: "user2-node", + IPv4: createAddr("100.64.0.2"), + UserID: ptr.To(users[1].ID), + User: ptr.To(users[1]), + } + + nodes := types.Nodes{node1, node2} + + // Create a simple policy that reproduces the issue + // Updated to use autogroup:self instead of username destination (per Tailscale security model) + policy := &Policy{ + Groups: Groups{ + Group("group:integration-test"): []Username{Username("user1@"), Username("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:integration-test")}, + Destinations: SSHDstAliases{agp("autogroup:self")}, // Users can SSH to their own devices + Users: []SSHUser{SSHUser("ssh-it-user")}, // This is the key - specific user + }, + }, + } + + // Validate policy + err := policy.validate() + require.NoError(t, err) + + // Test SSH policy compilation for node2 (owned by user2, who is in the group) + sshPolicy, err := policy.compileSSHPolicy(users, node2.View(), nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 1) + + rule := sshPolicy.Rules[0] + + // This was the failing assertion in integration test - sshUsers was empty + assert.NotEmpty(t, rule.SSHUsers, "SSH users should not be empty") + assert.Contains(t, rule.SSHUsers, "ssh-it-user", "ssh-it-user should be present in SSH users") + assert.Equal(t, "ssh-it-user", rule.SSHUsers["ssh-it-user"], "ssh-it-user should map to itself") + + // Verify that ssh-it-user is correctly mapped + expectedUsers := map[string]string{ + "ssh-it-user": "ssh-it-user", + } + assert.Equal(t, expectedUsers, rule.SSHUsers, "ssh-it-user should be mapped to itself") +} + +// TestSSHJSONSerialization verifies that the SSH policy can be properly serialized +// to JSON and that the sshUsers field is not empty +func TestSSHJSONSerialization(t *testing.T) { + users := types.Users{ + {Name: "user1", Model: gorm.Model{ID: 1}}, + } + + uid := uint(1) + node := &types.Node{ + Hostname: "test-node", + IPv4: createAddr("100.64.0.1"), + UserID: &uid, + User: &users[0], + } + + nodes := types.Nodes{node} + + policy := &Policy{ + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{up("user1@")}, + Destinations: SSHDstAliases{up("user1@")}, + Users: []SSHUser{"ssh-it-user", "ubuntu", "admin"}, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + sshPolicy, err := policy.compileSSHPolicy(users, node.View(), nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + + // Serialize to JSON to verify structure + jsonData, err := json.MarshalIndent(sshPolicy, "", " ") + require.NoError(t, err) + + // Parse back to verify structure + var parsed tailcfg.SSHPolicy + err = json.Unmarshal(jsonData, &parsed) + require.NoError(t, err) + + // Verify the parsed structure has the expected SSH users + require.Len(t, parsed.Rules, 1) + rule := parsed.Rules[0] + + expectedUsers := 
map[string]string{ + "ssh-it-user": "ssh-it-user", + "ubuntu": "ubuntu", + "admin": "admin", + } + assert.Equal(t, expectedUsers, rule.SSHUsers, "SSH users should survive JSON round-trip") + + // Verify JSON contains the SSH users (not empty) + assert.Contains(t, string(jsonData), `"ssh-it-user"`) + assert.Contains(t, string(jsonData), `"ubuntu"`) + assert.Contains(t, string(jsonData), `"admin"`) + assert.NotContains(t, string(jsonData), `"sshUsers": {}`, "SSH users should not be empty") + assert.NotContains(t, string(jsonData), `"sshUsers": null`, "SSH users should not be null") +} + +func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + } + + nodes := types.Nodes{ + { + User: ptr.To(users[0]), + IPv4: ap("100.64.0.1"), + }, + { + User: ptr.To(users[0]), + IPv4: ap("100.64.0.2"), + }, + { + User: ptr.To(users[1]), + IPv4: ap("100.64.0.3"), + }, + { + User: ptr.To(users[1]), + IPv4: ap("100.64.0.4"), + }, + // Tagged device for user1 + { + User: &users[0], + IPv4: ap("100.64.0.5"), + Tags: []string{"tag:test"}, + }, + // Tagged device for user2 + { + User: &users[1], + IPv4: ap("100.64.0.6"), + Tags: []string{"tag:test"}, + }, + } + + // Test: Tailscale intended usage pattern (autogroup:member + autogroup:self) + policy2 := &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Sources: []Alias{agp("autogroup:member")}, + Destinations: []AliasWithPorts{ + aliasWithPorts(agp("autogroup:self"), tailcfg.PortRangeAny), + }, + }, + }, + } + + err := policy2.validate() + if err != nil { + t.Fatalf("policy validation failed: %v", err) + } + + // Test compilation for user1's first node + node1 := nodes[0].View() + + rules, err := policy2.compileFilterRulesForNode(users, node1, nodes.ViewSlice()) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(rules) != 1 { + t.Fatalf("expected 1 rule, got %d", len(rules)) + } + + // Check that the rule includes: + // - Sources: only user1's untagged devices (filtered by autogroup:self semantics) + // - Destinations: only user1's untagged devices (autogroup:self) + rule := rules[0] + + // Sources should ONLY include user1's untagged devices (100.64.0.1, 100.64.0.2) + expectedSourceIPs := []string{"100.64.0.1", "100.64.0.2"} + + for _, expectedIP := range expectedSourceIPs { + found := false + + addr := netip.MustParseAddr(expectedIP) + for _, prefix := range rule.SrcIPs { + pref := netip.MustParsePrefix(prefix) + if pref.Contains(addr) { + found = true + break + } + } + + if !found { + t.Errorf("expected source IP %s to be covered by generated prefixes %v", expectedIP, rule.SrcIPs) + } + } + + // Verify that other users' devices and tagged devices are not included in sources + excludedSourceIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"} + for _, excludedIP := range excludedSourceIPs { + addr := netip.MustParseAddr(excludedIP) + for _, prefix := range rule.SrcIPs { + pref := netip.MustParsePrefix(prefix) + if pref.Contains(addr) { + t.Errorf("SECURITY VIOLATION: source IP %s should not be included but found in prefix %s", excludedIP, prefix) + } + } + } + + expectedDestIPs := []string{"100.64.0.1", "100.64.0.2"} + + actualDestIPs := make([]string, 0, len(rule.DstPorts)) + for _, dst := range rule.DstPorts { + actualDestIPs = append(actualDestIPs, dst.IP) + } + + for _, expectedIP := range expectedDestIPs { + found := slices.Contains(actualDestIPs, expectedIP) + + if !found { + t.Errorf("expected 
destination IP %s to be included, got: %v", expectedIP, actualDestIPs) + } + } + + // Verify that other users' devices and tagged devices are not in destinations + excludedDestIPs := []string{"100.64.0.3", "100.64.0.4", "100.64.0.5", "100.64.0.6"} + for _, excludedIP := range excludedDestIPs { + for _, actualIP := range actualDestIPs { + if actualIP == excludedIP { + t.Errorf("SECURITY: destination IP %s should not be included but found in destinations", excludedIP) + } + } + } +} + +// TestTagUserMutualExclusivity tests that user-owned nodes and tagged nodes +// are treated as separate identity classes and cannot inadvertently access each other. +func TestTagUserMutualExclusivity(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + } + + nodes := types.Nodes{ + // User-owned nodes + { + User: ptr.To(users[0]), + IPv4: ap("100.64.0.1"), + }, + { + User: ptr.To(users[1]), + IPv4: ap("100.64.0.2"), + }, + // Tagged nodes + { + User: &users[0], // "created by" tracking + IPv4: ap("100.64.0.10"), + Tags: []string{"tag:server"}, + }, + { + User: &users[1], // "created by" tracking + IPv4: ap("100.64.0.11"), + Tags: []string{"tag:database"}, + }, + } + + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, + }, + ACLs: []ACL{ + // Rule 1: user1 (user-owned) should NOT be able to reach tagged nodes + { + Action: "accept", + Sources: []Alias{up("user1@")}, + Destinations: []AliasWithPorts{ + aliasWithPorts(tp("tag:server"), tailcfg.PortRangeAny), + }, + }, + // Rule 2: tag:server should be able to reach tag:database + { + Action: "accept", + Sources: []Alias{tp("tag:server")}, + Destinations: []AliasWithPorts{ + aliasWithPorts(tp("tag:database"), tailcfg.PortRangeAny), + }, + }, + }, + } + + err := policy.validate() + if err != nil { + t.Fatalf("policy validation failed: %v", err) + } + + // Test user1's user-owned node (100.64.0.1) + userNode := nodes[0].View() + + userRules, err := policy.compileFilterRulesForNode(users, userNode, nodes.ViewSlice()) + if err != nil { + t.Fatalf("unexpected error for user node: %v", err) + } + + // User1's user-owned node should NOT reach tag:server (100.64.0.10) + // because user1@ as a source only matches user1's user-owned devices, NOT tagged devices + for _, rule := range userRules { + for _, dst := range rule.DstPorts { + if dst.IP == "100.64.0.10" { + t.Errorf("SECURITY: user-owned node should NOT reach tagged node (got dest %s in rule)", dst.IP) + } + } + } + + // Test tag:server node (100.64.0.10) + // compileFilterRulesForNode returns rules for what the node can ACCESS (as source) + taggedNode := nodes[2].View() + + taggedRules, err := policy.compileFilterRulesForNode(users, taggedNode, nodes.ViewSlice()) + if err != nil { + t.Fatalf("unexpected error for tagged node: %v", err) + } + + // Tag:server (as source) should be able to reach tag:database (100.64.0.11) + // Check destinations in the rules for this node + foundDatabaseDest := false + + for _, rule := range taggedRules { + // Check if this rule applies to tag:server as source + if !slices.Contains(rule.SrcIPs, "100.64.0.10/32") { + continue + } + + // Check if tag:database is in destinations + for _, dst := range rule.DstPorts { + if dst.IP == "100.64.0.11/32" { + foundDatabaseDest = true + break + } + } + + if foundDatabaseDest { + break + } + } + + if !foundDatabaseDest { + t.Errorf("tag:server should reach 
tag:database but didn't find 100.64.0.11 in destinations") + } +} + +// TestAutogroupTagged tests that autogroup:tagged correctly selects all devices +// with tag-based identity (IsTagged() == true or has requested tags in tagOwners). +func TestAutogroupTagged(t *testing.T) { + t.Parallel() + + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + } + + nodes := types.Nodes{ + // User-owned nodes (not tagged) + { + User: ptr.To(users[0]), + IPv4: ap("100.64.0.1"), + }, + { + User: ptr.To(users[1]), + IPv4: ap("100.64.0.2"), + }, + // Tagged nodes + { + User: &users[0], // "created by" tracking + IPv4: ap("100.64.0.10"), + Tags: []string{"tag:server"}, + }, + { + User: &users[1], // "created by" tracking + IPv4: ap("100.64.0.11"), + Tags: []string{"tag:database"}, + }, + { + User: &users[0], + IPv4: ap("100.64.0.12"), + Tags: []string{"tag:web", "tag:prod"}, + }, + } + + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, + Tag("tag:web"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + }, + ACLs: []ACL{ + // Rule: autogroup:tagged can reach user-owned nodes + { + Action: "accept", + Sources: []Alias{agp("autogroup:tagged")}, + Destinations: []AliasWithPorts{ + aliasWithPorts(up("user1@"), tailcfg.PortRangeAny), + aliasWithPorts(up("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + // Verify autogroup:tagged includes all tagged nodes + taggedIPs, err := AutoGroupTagged.Resolve(policy, users, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, taggedIPs) + + // Should contain all tagged nodes + assert.True(t, taggedIPs.Contains(*ap("100.64.0.10")), "should include tag:server") + assert.True(t, taggedIPs.Contains(*ap("100.64.0.11")), "should include tag:database") + assert.True(t, taggedIPs.Contains(*ap("100.64.0.12")), "should include tag:web,tag:prod") + + // Should NOT contain user-owned nodes + assert.False(t, taggedIPs.Contains(*ap("100.64.0.1")), "should not include user1 node") + assert.False(t, taggedIPs.Contains(*ap("100.64.0.2")), "should not include user2 node") + + // Test ACL filtering: all tagged nodes should be able to reach user nodes + tests := []struct { + name string + sourceNode types.NodeView + shouldReach []string // IP strings for comparison + }{ + { + name: "tag:server can reach user-owned nodes", + sourceNode: nodes[2].View(), + shouldReach: []string{"100.64.0.1", "100.64.0.2"}, + }, + { + name: "tag:database can reach user-owned nodes", + sourceNode: nodes[3].View(), + shouldReach: []string{"100.64.0.1", "100.64.0.2"}, + }, + { + name: "tag:web,tag:prod can reach user-owned nodes", + sourceNode: nodes[4].View(), + shouldReach: []string{"100.64.0.1", "100.64.0.2"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + rules, err := policy.compileFilterRulesForNode(users, tt.sourceNode, nodes.ViewSlice()) + require.NoError(t, err) + + // Verify all expected destinations are reachable + for _, expectedDest := range tt.shouldReach { + found := false + + for _, rule := range rules { + for _, dstPort := range rule.DstPorts { + // DstPort.IP is CIDR notation like "100.64.0.1/32" + if strings.HasPrefix(dstPort.IP, expectedDest+"/") || dstPort.IP == expectedDest { + found = true + break + } + } + + if found { + break + } + } + + assert.True(t, found, 
"Expected to find destination %s in rules", expectedDest) + } + }) + } +} + +func TestAutogroupSelfInSourceIsRejected(t *testing.T) { + // Test that autogroup:self cannot be used in sources (per Tailscale spec) + policy := &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Sources: []Alias{agp("autogroup:self")}, + Destinations: []AliasWithPorts{ + aliasWithPorts(agp("autogroup:member"), tailcfg.PortRangeAny), + }, + }, + }, + } + + err := policy.validate() + if err == nil { + t.Error("expected validation error when using autogroup:self in sources") + } + + if !strings.Contains(err.Error(), "autogroup:self") { + t.Errorf("expected error message to mention autogroup:self, got: %v", err) + } +} + +// TestAutogroupSelfWithSpecificUserSource verifies that when autogroup:self is in +// the destination and a specific user is in the source, only that user's devices +// are allowed (and only if they match the target user). +func TestAutogroupSelfWithSpecificUserSource(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + } + + nodes := types.Nodes{ + {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, + {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, + } + + policy := &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Sources: []Alias{up("user1@")}, + Destinations: []AliasWithPorts{ + aliasWithPorts(agp("autogroup:self"), tailcfg.PortRangeAny), + }, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + // For user1's node: sources should be user1's devices + node1 := nodes[0].View() + rules, err := policy.compileFilterRulesForNode(users, node1, nodes.ViewSlice()) + require.NoError(t, err) + require.Len(t, rules, 1) + + expectedSourceIPs := []string{"100.64.0.1", "100.64.0.2"} + for _, expectedIP := range expectedSourceIPs { + found := false + addr := netip.MustParseAddr(expectedIP) + + for _, prefix := range rules[0].SrcIPs { + pref := netip.MustParsePrefix(prefix) + if pref.Contains(addr) { + found = true + break + } + } + + assert.True(t, found, "expected source IP %s to be present", expectedIP) + } + + actualDestIPs := make([]string, 0, len(rules[0].DstPorts)) + for _, dst := range rules[0].DstPorts { + actualDestIPs = append(actualDestIPs, dst.IP) + } + + assert.ElementsMatch(t, expectedSourceIPs, actualDestIPs) + + node2 := nodes[2].View() + rules2, err := policy.compileFilterRulesForNode(users, node2, nodes.ViewSlice()) + require.NoError(t, err) + assert.Empty(t, rules2, "user2's node should have no rules (user1@ devices can't match user2's self)") +} + +// TestAutogroupSelfWithGroupSource verifies that when a group is used as source +// and autogroup:self as destination, only group members who are the same user +// as the target are allowed. 
+func TestAutogroupSelfWithGroupSource(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + {Model: gorm.Model{ID: 3}, Name: "user3"}, + } + + nodes := types.Nodes{ + {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, + {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, + {User: ptr.To(users[2]), IPv4: ap("100.64.0.5")}, + } + + policy := &Policy{ + Groups: Groups{ + Group("group:admins"): []Username{Username("user1@"), Username("user2@")}, + }, + ACLs: []ACL{ + { + Action: "accept", + Sources: []Alias{gp("group:admins")}, + Destinations: []AliasWithPorts{ + aliasWithPorts(agp("autogroup:self"), tailcfg.PortRangeAny), + }, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + // (group:admins has user1+user2, but autogroup:self filters to same user) + node1 := nodes[0].View() + rules, err := policy.compileFilterRulesForNode(users, node1, nodes.ViewSlice()) + require.NoError(t, err) + require.Len(t, rules, 1) + + expectedSrcIPs := []string{"100.64.0.1", "100.64.0.2"} + for _, expectedIP := range expectedSrcIPs { + found := false + addr := netip.MustParseAddr(expectedIP) + + for _, prefix := range rules[0].SrcIPs { + pref := netip.MustParsePrefix(prefix) + if pref.Contains(addr) { + found = true + break + } + } + + assert.True(t, found, "expected source IP %s for user1", expectedIP) + } + + node3 := nodes[4].View() + rules3, err := policy.compileFilterRulesForNode(users, node3, nodes.ViewSlice()) + require.NoError(t, err) + assert.Empty(t, rules3, "user3 should have no rules") +} + +// Helper function to create IP addresses for testing +func createAddr(ip string) *netip.Addr { + addr, _ := netip.ParseAddr(ip) + return &addr +} + +// TestSSHWithAutogroupSelfInDestination verifies that SSH policies work correctly +// with autogroup:self in destinations +func TestSSHWithAutogroupSelfInDestination(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + } + + nodes := types.Nodes{ + // User1's nodes + {User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"}, + {User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"}, + // User2's nodes + {User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"}, + // Tagged node for user1 (should be excluded) + {User: ptr.To(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}}, + } + + policy := &Policy{ + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{agp("autogroup:self")}, + Users: []SSHUser{"autogroup:nonroot"}, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + // Test for user1's first node + node1 := nodes[0].View() + sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 1) + + rule := sshPolicy.Rules[0] + + // Principals should only include user1's untagged devices + require.Len(t, rule.Principals, 2, "should have 2 principals (user1's 2 untagged nodes)") + + principalIPs := make([]string, len(rule.Principals)) + for i, p := range rule.Principals { + principalIPs[i] = p.NodeIP + } + assert.ElementsMatch(t, 
[]string{"100.64.0.1", "100.64.0.2"}, principalIPs) + + // Test for user2's first node + node3 := nodes[2].View() + sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy2) + require.Len(t, sshPolicy2.Rules, 1) + + rule2 := sshPolicy2.Rules[0] + + // Principals should only include user2's untagged devices + require.Len(t, rule2.Principals, 2, "should have 2 principals (user2's 2 untagged nodes)") + + principalIPs2 := make([]string, len(rule2.Principals)) + for i, p := range rule2.Principals { + principalIPs2[i] = p.NodeIP + } + assert.ElementsMatch(t, []string{"100.64.0.3", "100.64.0.4"}, principalIPs2) + + // Test for tagged node (should have no SSH rules) + node5 := nodes[4].View() + sshPolicy3, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + require.NoError(t, err) + if sshPolicy3 != nil { + assert.Empty(t, sshPolicy3.Rules, "tagged nodes should not get SSH rules with autogroup:self") + } +} + +// TestSSHWithAutogroupSelfAndSpecificUser verifies that when a specific user +// is in the source and autogroup:self in destination, only that user's devices +// can SSH (and only if they match the target user) +func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + } + + nodes := types.Nodes{ + {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, + {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, + } + + policy := &Policy{ + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{up("user1@")}, + Destinations: SSHDstAliases{agp("autogroup:self")}, + Users: []SSHUser{"ubuntu"}, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + // For user1's node: should allow SSH from user1's devices + node1 := nodes[0].View() + sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 1) + + rule := sshPolicy.Rules[0] + require.Len(t, rule.Principals, 2, "user1 should have 2 principals") + + principalIPs := make([]string, len(rule.Principals)) + for i, p := range rule.Principals { + principalIPs[i] = p.NodeIP + } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) + + // For user2's node: should have no rules (user1's devices can't match user2's self) + node3 := nodes[2].View() + sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + require.NoError(t, err) + if sshPolicy2 != nil { + assert.Empty(t, sshPolicy2.Rules, "user2 should have no SSH rules since source is user1") + } +} + +// TestSSHWithAutogroupSelfAndGroup verifies SSH with group sources and autogroup:self destinations +func TestSSHWithAutogroupSelfAndGroup(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + {Model: gorm.Model{ID: 3}, Name: "user3"}, + } + + nodes := types.Nodes{ + {User: ptr.To(users[0]), IPv4: ap("100.64.0.1")}, + {User: ptr.To(users[0]), IPv4: ap("100.64.0.2")}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.3")}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.4")}, + {User: ptr.To(users[2]), IPv4: ap("100.64.0.5")}, + } + + policy := &Policy{ + Groups: Groups{ + Group("group:admins"): []Username{Username("user1@"), Username("user2@")}, + }, + SSHs: []SSH{ + 
{ + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{agp("autogroup:self")}, + Users: []SSHUser{"root"}, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + // For user1's node: should allow SSH from user1's devices only (not user2's) + node1 := nodes[0].View() + sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 1) + + rule := sshPolicy.Rules[0] + require.Len(t, rule.Principals, 2, "user1 should have 2 principals (only user1's nodes)") + + principalIPs := make([]string, len(rule.Principals)) + for i, p := range rule.Principals { + principalIPs[i] = p.NodeIP + } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs) + + // For user3's node: should have no rules (not in group:admins) + node5 := nodes[4].View() + sshPolicy2, err := policy.compileSSHPolicy(users, node5, nodes.ViewSlice()) + require.NoError(t, err) + if sshPolicy2 != nil { + assert.Empty(t, sshPolicy2.Rules, "user3 should have no SSH rules (not in group)") + } +} + +// TestSSHWithAutogroupSelfExcludesTaggedDevices verifies that tagged devices +// are excluded from both sources and destinations when autogroup:self is used +func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + } + + nodes := types.Nodes{ + {User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"}, + {User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"}, + {User: ptr.To(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}}, + {User: ptr.To(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}}, + } + + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("user1@")}, + Tag("tag:web"): Owners{up("user1@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{agp("autogroup:self")}, + Users: []SSHUser{"admin"}, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + // For untagged node: should only get principals from other untagged nodes + node1 := nodes[0].View() + sshPolicy, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy) + require.Len(t, sshPolicy.Rules, 1) + + rule := sshPolicy.Rules[0] + require.Len(t, rule.Principals, 2, "should only have 2 principals (untagged nodes)") + + principalIPs := make([]string, len(rule.Principals)) + for i, p := range rule.Principals { + principalIPs[i] = p.NodeIP + } + assert.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, principalIPs, + "should only include untagged devices") + + // For tagged node: should get no SSH rules + node3 := nodes[2].View() + sshPolicy2, err := policy.compileSSHPolicy(users, node3, nodes.ViewSlice()) + require.NoError(t, err) + if sshPolicy2 != nil { + assert.Empty(t, sshPolicy2.Rules, "tagged node should get no SSH rules with autogroup:self") + } +} + +// TestSSHWithAutogroupSelfAndMixedDestinations tests that SSH rules can have both +// autogroup:self and other destinations (like tag:router) in the same rule, and that +// autogroup:self filtering only applies to autogroup:self destinations, not others. 
+func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + } + + nodes := types.Nodes{ + {User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"}, + {User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"}, + {User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}}, + } + + policy := &Policy{ + TagOwners: TagOwners{ + Tag("tag:router"): Owners{up("user2@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{agp("autogroup:self"), tp("tag:router")}, + Users: []SSHUser{"admin"}, + }, + }, + } + + err := policy.validate() + require.NoError(t, err) + + // Test 1: Compile for user1's device (should only match autogroup:self destination) + node1 := nodes[0].View() + sshPolicy1, err := policy.compileSSHPolicy(users, node1, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicy1) + require.Len(t, sshPolicy1.Rules, 1, "user1's device should have 1 SSH rule (autogroup:self)") + + // Verify autogroup:self rule has filtered sources (only same-user devices) + selfRule := sshPolicy1.Rules[0] + require.Len(t, selfRule.Principals, 2, "autogroup:self rule should only have user1's devices") + selfPrincipals := make([]string, len(selfRule.Principals)) + for i, p := range selfRule.Principals { + selfPrincipals[i] = p.NodeIP + } + require.ElementsMatch(t, []string{"100.64.0.1", "100.64.0.2"}, selfPrincipals, + "autogroup:self rule should only include same-user untagged devices") + + // Test 2: Compile for router (should only match tag:router destination) + routerNode := nodes[3].View() // user2-router + sshPolicyRouter, err := policy.compileSSHPolicy(users, routerNode, nodes.ViewSlice()) + require.NoError(t, err) + require.NotNil(t, sshPolicyRouter) + require.Len(t, sshPolicyRouter.Rules, 1, "router should have 1 SSH rule (tag:router)") + + routerRule := sshPolicyRouter.Rules[0] + routerPrincipals := make([]string, len(routerRule.Principals)) + for i, p := range routerRule.Principals { + routerPrincipals[i] = p.NodeIP + } + require.Contains(t, routerPrincipals, "100.64.0.1", "router rule should include user1's device (unfiltered sources)") + require.Contains(t, routerPrincipals, "100.64.0.2", "router rule should include user1's other device (unfiltered sources)") + require.Contains(t, routerPrincipals, "100.64.0.3", "router rule should include user2's device (unfiltered sources)") +} diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go new file mode 100644 index 00000000..54196e6b --- /dev/null +++ b/hscontrol/policy/v2/policy.go @@ -0,0 +1,1076 @@ +package v2 + +import ( + "cmp" + "encoding/json" + "errors" + "fmt" + "net/netip" + "slices" + "strings" + "sync" + + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/policy/policyutil" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" + "go4.org/netipx" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/views" + "tailscale.com/util/deephash" +) + +// ErrInvalidTagOwner is returned when a tag owner is not an Alias type. 
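+// It is returned by resolveTagOwners when an owner survives flattening but is neither a Tag nor an Alias.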
+var ErrInvalidTagOwner = errors.New("tag owner is not an Alias") + +type PolicyManager struct { + mu sync.Mutex + pol *Policy + users []types.User + nodes views.Slice[types.NodeView] + + filterHash deephash.Sum + filter []tailcfg.FilterRule + matchers []matcher.Match + + tagOwnerMapHash deephash.Sum + tagOwnerMap map[Tag]*netipx.IPSet + + exitSetHash deephash.Sum + exitSet *netipx.IPSet + autoApproveMapHash deephash.Sum + autoApproveMap map[netip.Prefix]*netipx.IPSet + + // Lazy map of SSH policies + sshPolicyMap map[types.NodeID]*tailcfg.SSHPolicy + + // Lazy map of per-node compiled filter rules (unreduced, for autogroup:self) + compiledFilterRulesMap map[types.NodeID][]tailcfg.FilterRule + // Lazy map of per-node filter rules (reduced, for packet filters) + filterRulesMap map[types.NodeID][]tailcfg.FilterRule + usesAutogroupSelf bool +} + +// filterAndPolicy combines the compiled filter rules with policy content for hashing. +// This ensures filterHash changes when policy changes, even for autogroup:self where +// the compiled filter is always empty. +type filterAndPolicy struct { + Filter []tailcfg.FilterRule + Policy *Policy +} + +// NewPolicyManager creates a new PolicyManager from a policy file and a list of users and nodes. +// It returns an error if the policy file is invalid. +// The policy manager will update the filter rules based on the users and nodes. +func NewPolicyManager(b []byte, users []types.User, nodes views.Slice[types.NodeView]) (*PolicyManager, error) { + policy, err := unmarshalPolicy(b) + if err != nil { + return nil, fmt.Errorf("parsing policy: %w", err) + } + + pm := PolicyManager{ + pol: policy, + users: users, + nodes: nodes, + sshPolicyMap: make(map[types.NodeID]*tailcfg.SSHPolicy, nodes.Len()), + compiledFilterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()), + filterRulesMap: make(map[types.NodeID][]tailcfg.FilterRule, nodes.Len()), + usesAutogroupSelf: policy.usesAutogroupSelf(), + } + + _, err = pm.updateLocked() + if err != nil { + return nil, err + } + + return &pm, nil +} + +// updateLocked updates the filter rules based on the current policy and nodes. +// It must be called with the lock held. +func (pm *PolicyManager) updateLocked() (bool, error) { + // Check if policy uses autogroup:self + pm.usesAutogroupSelf = pm.pol.usesAutogroupSelf() + + var filter []tailcfg.FilterRule + + var err error + + // Standard compilation for all policies + filter, err = pm.pol.compileFilterRules(pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("compiling filter rules: %w", err) + } + + // Hash both the compiled filter AND the policy content together. + // This ensures filterHash changes when policy changes, even for autogroup:self + // where the compiled filter is always empty. This eliminates the need for + // a separate policyHash field. + filterHash := deephash.Hash(&filterAndPolicy{ + Filter: filter, + Policy: pm.pol, + }) + filterChanged := filterHash != pm.filterHash + if filterChanged { + log.Debug(). + Str("filter.hash.old", pm.filterHash.String()[:8]). + Str("filter.hash.new", filterHash.String()[:8]). + Int("filter.rules", len(pm.filter)). + Int("filter.rules.new", len(filter)). + Msg("Policy filter hash changed") + } + pm.filter = filter + pm.filterHash = filterHash + if filterChanged { + pm.matchers = matcher.MatchesFromFilterRules(pm.filter) + } + + // Order matters, tags might be used in autoapprovers, so we need to ensure + // that the map for tag owners is resolved before resolving autoapprovers. 
+ // TODO(kradalby): Order might not matter after #2417 + tagMap, err := resolveTagOwners(pm.pol, pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("resolving tag owners map: %w", err) + } + + tagOwnerMapHash := deephash.Hash(&tagMap) + tagOwnerChanged := tagOwnerMapHash != pm.tagOwnerMapHash + if tagOwnerChanged { + log.Debug(). + Str("tagOwner.hash.old", pm.tagOwnerMapHash.String()[:8]). + Str("tagOwner.hash.new", tagOwnerMapHash.String()[:8]). + Int("tagOwners.old", len(pm.tagOwnerMap)). + Int("tagOwners.new", len(tagMap)). + Msg("Tag owner hash changed") + } + pm.tagOwnerMap = tagMap + pm.tagOwnerMapHash = tagOwnerMapHash + + autoMap, exitSet, err := resolveAutoApprovers(pm.pol, pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("resolving auto approvers map: %w", err) + } + + autoApproveMapHash := deephash.Hash(&autoMap) + autoApproveChanged := autoApproveMapHash != pm.autoApproveMapHash + if autoApproveChanged { + log.Debug(). + Str("autoApprove.hash.old", pm.autoApproveMapHash.String()[:8]). + Str("autoApprove.hash.new", autoApproveMapHash.String()[:8]). + Int("autoApprovers.old", len(pm.autoApproveMap)). + Int("autoApprovers.new", len(autoMap)). + Msg("Auto-approvers hash changed") + } + pm.autoApproveMap = autoMap + pm.autoApproveMapHash = autoApproveMapHash + + exitSetHash := deephash.Hash(&exitSet) + exitSetChanged := exitSetHash != pm.exitSetHash + if exitSetChanged { + log.Debug(). + Str("exitSet.hash.old", pm.exitSetHash.String()[:8]). + Str("exitSet.hash.new", exitSetHash.String()[:8]). + Msg("Exit node set hash changed") + } + pm.exitSet = exitSet + pm.exitSetHash = exitSetHash + + // Determine if we need to send updates to nodes + // filterChanged now includes policy content changes (via combined hash), + // so it will detect changes even for autogroup:self where compiled filter is empty + needsUpdate := filterChanged || tagOwnerChanged || autoApproveChanged || exitSetChanged + + // Only clear caches if we're actually going to send updates + // This prevents clearing caches when nothing changed, which would leave nodes + // with stale filters until they reconnect. This is critical for autogroup:self + // where even reloading the same policy would clear caches but not send updates. + if needsUpdate { + // Clear the SSH policy map to ensure it's recalculated with the new policy. + // TODO(kradalby): This could potentially be optimized by only clearing the + // policies for nodes that have changed. Particularly if the only difference is + // that nodes has been added or removed. + clear(pm.sshPolicyMap) + clear(pm.compiledFilterRulesMap) + clear(pm.filterRulesMap) + } + + // If nothing changed, no need to update nodes + if !needsUpdate { + log.Trace(). + Msg("Policy evaluation detected no changes - all hashes match") + return false, nil + } + + log.Debug(). + Bool("filter.changed", filterChanged). + Bool("tagOwners.changed", tagOwnerChanged). + Bool("autoApprovers.changed", autoApproveChanged). + Bool("exitNodes.changed", exitSetChanged). 
+ Msg("Policy changes require node updates") + + return true, nil +} + +func (pm *PolicyManager) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + if sshPol, ok := pm.sshPolicyMap[node.ID()]; ok { + return sshPol, nil + } + + sshPol, err := pm.pol.compileSSHPolicy(pm.users, node, pm.nodes) + if err != nil { + return nil, fmt.Errorf("compiling SSH policy: %w", err) + } + pm.sshPolicyMap[node.ID()] = sshPol + + return sshPol, nil +} + +func (pm *PolicyManager) SetPolicy(polB []byte) (bool, error) { + if len(polB) == 0 { + return false, nil + } + + pol, err := unmarshalPolicy(polB) + if err != nil { + return false, fmt.Errorf("parsing policy: %w", err) + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // Log policy metadata for debugging + log.Debug(). + Int("policy.bytes", len(polB)). + Int("acls.count", len(pol.ACLs)). + Int("groups.count", len(pol.Groups)). + Int("hosts.count", len(pol.Hosts)). + Int("tagOwners.count", len(pol.TagOwners)). + Int("autoApprovers.routes.count", len(pol.AutoApprovers.Routes)). + Msg("Policy parsed successfully") + + pm.pol = pol + + return pm.updateLocked() +} + +// Filter returns the current filter rules for the entire tailnet and the associated matchers. +func (pm *PolicyManager) Filter() ([]tailcfg.FilterRule, []matcher.Match) { + if pm == nil { + return nil, nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + return pm.filter, pm.matchers +} + +// BuildPeerMap constructs peer relationship maps for the given nodes. +// For global filters, it uses the global filter matchers for all nodes. +// For autogroup:self policies (empty global filter), it builds per-node +// peer maps using each node's specific filter rules. +func (pm *PolicyManager) BuildPeerMap(nodes views.Slice[types.NodeView]) map[types.NodeID][]types.NodeView { + if pm == nil { + return nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // If we have a global filter, use it for all nodes (normal case) + if !pm.usesAutogroupSelf { + ret := make(map[types.NodeID][]types.NodeView, nodes.Len()) + + // Build the map of all peers according to the matchers. + // Compared to ReduceNodes, which builds the list per node, we end up with doing + // the full work for every node O(n^2), while this will reduce the list as we see + // relationships while building the map, making it O(n^2/2) in the end, but with less work per node. + for i := range nodes.Len() { + for j := i + 1; j < nodes.Len(); j++ { + if nodes.At(i).ID() == nodes.At(j).ID() { + continue + } + + if nodes.At(i).CanAccess(pm.matchers, nodes.At(j)) || nodes.At(j).CanAccess(pm.matchers, nodes.At(i)) { + ret[nodes.At(i).ID()] = append(ret[nodes.At(i).ID()], nodes.At(j)) + ret[nodes.At(j).ID()] = append(ret[nodes.At(j).ID()], nodes.At(i)) + } + } + } + + return ret + } + + // For autogroup:self (empty global filter), build per-node peer relationships + ret := make(map[types.NodeID][]types.NodeView, nodes.Len()) + + // Pre-compute per-node matchers using unreduced compiled rules + // We need unreduced rules to determine peer relationships correctly. + // Reduced rules only show destinations where the node is the target, + // but peer relationships require the full bidirectional access rules. 
+ nodeMatchers := make(map[types.NodeID][]matcher.Match, nodes.Len()) + for _, node := range nodes.All() { + filter, err := pm.compileFilterRulesForNodeLocked(node) + if err != nil || len(filter) == 0 { + continue + } + nodeMatchers[node.ID()] = matcher.MatchesFromFilterRules(filter) + } + + // Check each node pair for peer relationships. + // Start j at i+1 to avoid checking the same pair twice and creating duplicates. + // We use symmetric visibility: if EITHER node can access the other, BOTH see + // each other. This matches the global filter path behavior and ensures that + // one-way access rules (e.g., admin -> tagged server) still allow both nodes + // to see each other as peers, which is required for network connectivity. + for i := range nodes.Len() { + nodeI := nodes.At(i) + matchersI, hasFilterI := nodeMatchers[nodeI.ID()] + + for j := i + 1; j < nodes.Len(); j++ { + nodeJ := nodes.At(j) + matchersJ, hasFilterJ := nodeMatchers[nodeJ.ID()] + + // If either node can access the other, both should see each other as peers. + // This symmetric visibility is required for proper network operation: + // - Admin with *:* rule should see tagged servers (even if servers + // can't access admin) + // - Servers should see admin so they can respond to admin's connections + canIAccessJ := hasFilterI && nodeI.CanAccess(matchersI, nodeJ) + canJAccessI := hasFilterJ && nodeJ.CanAccess(matchersJ, nodeI) + + if canIAccessJ || canJAccessI { + ret[nodeI.ID()] = append(ret[nodeI.ID()], nodeJ) + ret[nodeJ.ID()] = append(ret[nodeJ.ID()], nodeI) + } + } + } + + return ret +} + +// compileFilterRulesForNodeLocked returns the unreduced compiled filter rules for a node +// when using autogroup:self. This is used by BuildPeerMap to determine peer relationships. +// For packet filters sent to nodes, use filterForNodeLocked which returns reduced rules. +func (pm *PolicyManager) compileFilterRulesForNodeLocked(node types.NodeView) ([]tailcfg.FilterRule, error) { + if pm == nil { + return nil, nil + } + + // Check if we have cached compiled rules + if rules, ok := pm.compiledFilterRulesMap[node.ID()]; ok { + return rules, nil + } + + // Compile per-node rules with autogroup:self expanded + rules, err := pm.pol.compileFilterRulesForNode(pm.users, node, pm.nodes) + if err != nil { + return nil, fmt.Errorf("compiling filter rules for node: %w", err) + } + + // Cache the unreduced compiled rules + pm.compiledFilterRulesMap[node.ID()] = rules + + return rules, nil +} + +// filterForNodeLocked returns the filter rules for a specific node, already reduced +// to only include rules relevant to that node. +// This is a lock-free version of FilterForNode for internal use when the lock is already held. +// BuildPeerMap already holds the lock, so we need a version that doesn't re-acquire it. +func (pm *PolicyManager) filterForNodeLocked(node types.NodeView) ([]tailcfg.FilterRule, error) { + if pm == nil { + return nil, nil + } + + if !pm.usesAutogroupSelf { + // For global filters, reduce to only rules relevant to this node. + // Cache the reduced filter per node for efficiency. + if rules, ok := pm.filterRulesMap[node.ID()]; ok { + return rules, nil + } + + // Use policyutil.ReduceFilterRules for global filter reduction. + reducedFilter := policyutil.ReduceFilterRules(node, pm.filter) + + pm.filterRulesMap[node.ID()] = reducedFilter + return reducedFilter, nil + } + + // For autogroup:self, compile per-node rules then reduce them. + // Check if we have cached reduced rules for this node. 
+ if rules, ok := pm.filterRulesMap[node.ID()]; ok { + return rules, nil + } + + // Get unreduced compiled rules + compiledRules, err := pm.compileFilterRulesForNodeLocked(node) + if err != nil { + return nil, err + } + + // Reduce the compiled rules to only destinations relevant to this node + reducedFilter := policyutil.ReduceFilterRules(node, compiledRules) + + // Cache the reduced filter + pm.filterRulesMap[node.ID()] = reducedFilter + + return reducedFilter, nil +} + +// FilterForNode returns the filter rules for a specific node, already reduced +// to only include rules relevant to that node. +// If the policy uses autogroup:self, this returns node-specific compiled rules. +// Otherwise, it returns the global filter reduced for this node. +func (pm *PolicyManager) FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error) { + if pm == nil { + return nil, nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + return pm.filterForNodeLocked(node) +} + +// MatchersForNode returns the matchers for peer relationship determination for a specific node. +// These are UNREDUCED matchers - they include all rules where the node could be either source or destination. +// This is different from FilterForNode which returns REDUCED rules for packet filtering. +// +// For global policies: returns the global matchers (same for all nodes) +// For autogroup:self: returns node-specific matchers from unreduced compiled rules +func (pm *PolicyManager) MatchersForNode(node types.NodeView) ([]matcher.Match, error) { + if pm == nil { + return nil, nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // For global policies, return the shared global matchers + if !pm.usesAutogroupSelf { + return pm.matchers, nil + } + + // For autogroup:self, get unreduced compiled rules and create matchers + compiledRules, err := pm.compileFilterRulesForNodeLocked(node) + if err != nil { + return nil, err + } + + // Create matchers from unreduced rules for peer relationship determination + return matcher.MatchesFromFilterRules(compiledRules), nil +} + +// SetUsers updates the users in the policy manager and updates the filter rules. +func (pm *PolicyManager) SetUsers(users []types.User) (bool, error) { + if pm == nil { + return false, nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + pm.users = users + + // Clear SSH policy map when users change to force SSH policy recomputation + // This ensures that if SSH policy compilation previously failed due to missing users, + // it will be retried with the new user list + clear(pm.sshPolicyMap) + + changed, err := pm.updateLocked() + if err != nil { + return false, err + } + + // If SSH policies exist, force a policy change when users are updated + // This ensures nodes get updated SSH policies even if other policy hashes didn't change + if pm.pol != nil && pm.pol.SSHs != nil && len(pm.pol.SSHs) > 0 { + return true, nil + } + + return changed, nil +} + +// SetNodes updates the nodes in the policy manager and updates the filter rules. +func (pm *PolicyManager) SetNodes(nodes views.Slice[types.NodeView]) (bool, error) { + if pm == nil { + return false, nil + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + policyChanged := pm.nodesHavePolicyAffectingChanges(nodes) + + // Invalidate cache entries for nodes that changed. + // For autogroup:self: invalidate all nodes belonging to affected users (peer changes). + // For global policies: invalidate only nodes whose properties changed (IPs, routes). 
+ pm.invalidateNodeCache(nodes) + + pm.nodes = nodes + + // When policy-affecting node properties change, we must recompile filters because: + // 1. User/group aliases (like "user1@") resolve to node IPs + // 2. Tag aliases (like "tag:server") match nodes based on their tags + // 3. Filter compilation needs nodes to generate rules + // + // For autogroup:self: return true when nodes change even if the global filter + // hash didn't change. The global filter is empty for autogroup:self (each node + // has its own filter), so the hash never changes. But peer relationships DO + // change when nodes are added/removed, so we must signal this to trigger updates. + // For global policies: the filter must be recompiled to include the new nodes. + if policyChanged { + // Recompile filter with the new node list + needsUpdate, err := pm.updateLocked() + if err != nil { + return false, err + } + + if !needsUpdate { + // This ensures fresh filter rules are generated for all nodes + clear(pm.sshPolicyMap) + clear(pm.compiledFilterRulesMap) + clear(pm.filterRulesMap) + } + // Always return true when nodes changed, even if filter hash didn't change + // (can happen with autogroup:self or when nodes are added but don't affect rules) + return true, nil + } + + return false, nil +} + +func (pm *PolicyManager) nodesHavePolicyAffectingChanges(newNodes views.Slice[types.NodeView]) bool { + if pm.nodes.Len() != newNodes.Len() { + return true + } + + oldNodes := make(map[types.NodeID]types.NodeView, pm.nodes.Len()) + for _, node := range pm.nodes.All() { + oldNodes[node.ID()] = node + } + + for _, newNode := range newNodes.All() { + oldNode, exists := oldNodes[newNode.ID()] + if !exists { + return true + } + + if newNode.HasPolicyChange(oldNode) { + return true + } + } + + return false +} + +// NodeCanHaveTag checks if a node can have the specified tag during client-initiated +// registration or reauth flows (e.g., tailscale up --advertise-tags). +// +// This function is NOT used by the admin API's SetNodeTags - admins can set any +// existing tag on any node by calling State.SetNodeTags directly, which bypasses +// this authorization check. +func (pm *PolicyManager) NodeCanHaveTag(node types.NodeView, tag string) bool { + if pm == nil || pm.pol == nil { + return false + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // Check if tag exists in policy + owners, exists := pm.pol.TagOwners[Tag(tag)] + if !exists { + return false + } + + // Check if node's owner can assign this tag via the pre-resolved tagOwnerMap. + // The tagOwnerMap contains IP sets built from resolving TagOwners entries + // (usernames/groups) to their nodes' IPs, so checking if the node's IP + // is in the set answers "does this node's owner own this tag?" + if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok { + if slices.ContainsFunc(node.IPs(), ips.Contains) { + return true + } + } + + // For new nodes being registered, their IP may not yet be in the tagOwnerMap. + // Fall back to checking the node's user directly against the TagOwners. + // This handles the case where a user registers a new node with --advertise-tags. + if node.User().Valid() { + for _, owner := range owners { + if pm.userMatchesOwner(node.User(), owner) { + return true + } + } + } + + return false +} + +// userMatchesOwner checks if a user matches a tag owner entry. +// This is used as a fallback when the node's IP is not in the tagOwnerMap. 
+func (pm *PolicyManager) userMatchesOwner(user types.UserView, owner Owner) bool { + switch o := owner.(type) { + case *Username: + if o == nil { + return false + } + // Resolve the username to find the user it refers to + resolvedUser, err := o.resolveUser(pm.users) + if err != nil { + return false + } + + return user.ID() == resolvedUser.ID + + case *Group: + if o == nil || pm.pol == nil { + return false + } + // Resolve the group to get usernames + usernames, ok := pm.pol.Groups[*o] + if !ok { + return false + } + // Check if the user matches any username in the group + for _, uname := range usernames { + resolvedUser, err := uname.resolveUser(pm.users) + if err != nil { + continue + } + + if user.ID() == resolvedUser.ID { + return true + } + } + + return false + + default: + return false + } +} + +// TagExists reports whether the given tag is defined in the policy. +func (pm *PolicyManager) TagExists(tag string) bool { + if pm == nil || pm.pol == nil { + return false + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + _, exists := pm.pol.TagOwners[Tag(tag)] + + return exists +} + +func (pm *PolicyManager) NodeCanApproveRoute(node types.NodeView, route netip.Prefix) bool { + if pm == nil { + return false + } + + // If the route to-be-approved is an exit route, then we need to check + // if the node is in allowed to approve it. This is treated differently + // than the auto-approvers, as the auto-approvers are not allowed to + // approve the whole /0 range. + // However, an auto approver might be /0, meaning that they can approve + // all routes available, just not exit nodes. + if tsaddr.IsExitRoute(route) { + if pm.exitSet == nil { + return false + } + if slices.ContainsFunc(node.IPs(), pm.exitSet.Contains) { + return true + } + + return false + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // The fast path is that a node requests to approve a prefix + // where there is an exact entry, e.g. 10.0.0.0/8, then + // check and return quickly + if approvers, ok := pm.autoApproveMap[route]; ok { + canApprove := slices.ContainsFunc(node.IPs(), approvers.Contains) + if canApprove { + return true + } + } + + // The slow path is that the node tries to approve + // 10.0.10.0/24, which is a part of 10.0.0.0/8, then we + // cannot just lookup in the prefix map and have to check + // if there is a "parent" prefix available. 
+ for prefix, approveAddrs := range pm.autoApproveMap { + // Check if prefix is larger (so containing) and then overlaps + // the route to see if the node can approve a subset of an autoapprover + if prefix.Bits() <= route.Bits() && prefix.Overlaps(route) { + canApprove := slices.ContainsFunc(node.IPs(), approveAddrs.Contains) + if canApprove { + return true + } + } + } + + return false +} + +func (pm *PolicyManager) Version() int { + return 2 +} + +func (pm *PolicyManager) DebugString() string { + if pm == nil { + return "PolicyManager is not setup" + } + + var sb strings.Builder + + fmt.Fprintf(&sb, "PolicyManager (v%d):\n\n", pm.Version()) + + sb.WriteString("\n\n") + + if pm.pol != nil { + pol, err := json.MarshalIndent(pm.pol, "", " ") + if err == nil { + sb.WriteString("Policy:\n") + sb.Write(pol) + sb.WriteString("\n\n") + } + } + + fmt.Fprintf(&sb, "AutoApprover (%d):\n", len(pm.autoApproveMap)) + for prefix, approveAddrs := range pm.autoApproveMap { + fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range approveAddrs.Ranges() { + fmt.Fprintf(&sb, "\t\t%s\n", iprange) + } + } + + sb.WriteString("\n\n") + + fmt.Fprintf(&sb, "TagOwner (%d):\n", len(pm.tagOwnerMap)) + for prefix, tagOwners := range pm.tagOwnerMap { + fmt.Fprintf(&sb, "\t%s:\n", prefix) + for _, iprange := range tagOwners.Ranges() { + fmt.Fprintf(&sb, "\t\t%s\n", iprange) + } + } + + sb.WriteString("\n\n") + if pm.filter != nil { + filter, err := json.MarshalIndent(pm.filter, "", " ") + if err == nil { + sb.WriteString("Compiled filter:\n") + sb.Write(filter) + sb.WriteString("\n\n") + } + } + + sb.WriteString("\n\n") + sb.WriteString("Matchers:\n") + sb.WriteString("an internal structure used to filter nodes and routes\n") + for _, match := range pm.matchers { + sb.WriteString(match.DebugString()) + sb.WriteString("\n") + } + + sb.WriteString("\n\n") + sb.WriteString("Nodes:\n") + for _, node := range pm.nodes.All() { + sb.WriteString(node.String()) + sb.WriteString("\n") + } + + return sb.String() +} + +// invalidateAutogroupSelfCache intelligently clears only the cache entries that need to be +// invalidated when using autogroup:self policies. This is much more efficient than clearing +// the entire cache. 
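+// Only entries for affected users are dropped: users whose nodes were added, removed,
+// moved to another user, or had their tag status or IPs changed.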
+func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.Slice[types.NodeView]) { + // Build maps for efficient lookup + oldNodeMap := make(map[types.NodeID]types.NodeView) + for _, node := range oldNodes.All() { + oldNodeMap[node.ID()] = node + } + + newNodeMap := make(map[types.NodeID]types.NodeView) + for _, node := range newNodes.All() { + newNodeMap[node.ID()] = node + } + + // Track which users are affected by changes + affectedUsers := make(map[uint]struct{}) + + // Check for removed nodes + for nodeID, oldNode := range oldNodeMap { + if _, exists := newNodeMap[nodeID]; !exists { + affectedUsers[oldNode.User().ID()] = struct{}{} + } + } + + // Check for added nodes + for nodeID, newNode := range newNodeMap { + if _, exists := oldNodeMap[nodeID]; !exists { + affectedUsers[newNode.User().ID()] = struct{}{} + } + } + + // Check for modified nodes (user changes, tag changes, IP changes) + for nodeID, newNode := range newNodeMap { + if oldNode, exists := oldNodeMap[nodeID]; exists { + // Check if user changed + if oldNode.User().ID() != newNode.User().ID() { + affectedUsers[oldNode.User().ID()] = struct{}{} + affectedUsers[newNode.User().ID()] = struct{}{} + } + + // Check if tag status changed + if oldNode.IsTagged() != newNode.IsTagged() { + affectedUsers[newNode.User().ID()] = struct{}{} + } + + // Check if IPs changed (simple check - could be more sophisticated) + oldIPs := oldNode.IPs() + newIPs := newNode.IPs() + if len(oldIPs) != len(newIPs) { + affectedUsers[newNode.User().ID()] = struct{}{} + } else { + // Check if any IPs are different + for i, oldIP := range oldIPs { + if i >= len(newIPs) || oldIP != newIPs[i] { + affectedUsers[newNode.User().ID()] = struct{}{} + break + } + } + } + } + } + + // Clear cache entries for affected users only + // For autogroup:self, we need to clear all nodes belonging to affected users + // because autogroup:self rules depend on the entire user's device set + for nodeID := range pm.filterRulesMap { + // Find the user for this cached node + var nodeUserID uint + found := false + + // Check in new nodes first + for _, node := range newNodes.All() { + if node.ID() == nodeID { + nodeUserID = node.User().ID() + found = true + break + } + } + + // If not found in new nodes, check old nodes + if !found { + for _, node := range oldNodes.All() { + if node.ID() == nodeID { + nodeUserID = node.User().ID() + found = true + break + } + } + } + + // If we found the user and they're affected, clear this cache entry + if found { + if _, affected := affectedUsers[nodeUserID]; affected { + delete(pm.compiledFilterRulesMap, nodeID) + delete(pm.filterRulesMap, nodeID) + } + } else { + // Node not found in either old or new list, clear it + delete(pm.compiledFilterRulesMap, nodeID) + delete(pm.filterRulesMap, nodeID) + } + } + + if len(affectedUsers) > 0 { + log.Debug(). + Int("affected_users", len(affectedUsers)). + Int("remaining_cache_entries", len(pm.filterRulesMap)). + Msg("Selectively cleared autogroup:self cache for affected users") + } +} + +// invalidateNodeCache invalidates cache entries based on what changed. +func (pm *PolicyManager) invalidateNodeCache(newNodes views.Slice[types.NodeView]) { + if pm.usesAutogroupSelf { + // For autogroup:self, a node's filter depends on its peers (same user). + // When any node in a user changes, all nodes for that user need invalidation. + pm.invalidateAutogroupSelfCache(pm.nodes, newNodes) + } else { + // For global policies, a node's filter depends only on its own properties. 
+ // Only invalidate nodes whose properties actually changed. + pm.invalidateGlobalPolicyCache(newNodes) + } +} + +// invalidateGlobalPolicyCache invalidates only nodes whose properties affecting +// ReduceFilterRules changed. For global policies, each node's filter is independent. +func (pm *PolicyManager) invalidateGlobalPolicyCache(newNodes views.Slice[types.NodeView]) { + oldNodeMap := make(map[types.NodeID]types.NodeView) + for _, node := range pm.nodes.All() { + oldNodeMap[node.ID()] = node + } + + newNodeMap := make(map[types.NodeID]types.NodeView) + for _, node := range newNodes.All() { + newNodeMap[node.ID()] = node + } + + // Invalidate nodes whose properties changed + for nodeID, newNode := range newNodeMap { + oldNode, existed := oldNodeMap[nodeID] + if !existed { + // New node - no cache entry yet, will be lazily calculated + continue + } + + if newNode.HasNetworkChanges(oldNode) { + delete(pm.filterRulesMap, nodeID) + } + } + + // Remove deleted nodes from cache + for nodeID := range pm.filterRulesMap { + if _, exists := newNodeMap[nodeID]; !exists { + delete(pm.filterRulesMap, nodeID) + } + } +} + +// flattenTags flattens the TagOwners by resolving nested tags and detecting cycles. +// It will return a Owners list where all the Tag types have been resolved to their underlying Owners. +func flattenTags(tagOwners TagOwners, tag Tag, visiting map[Tag]bool, chain []Tag) (Owners, error) { + if visiting[tag] { + cycleStart := 0 + + for i, t := range chain { + if t == tag { + cycleStart = i + break + } + } + + cycleTags := make([]string, len(chain[cycleStart:])) + for i, t := range chain[cycleStart:] { + cycleTags[i] = string(t) + } + + slices.Sort(cycleTags) + + return nil, fmt.Errorf("%w: %s", ErrCircularReference, strings.Join(cycleTags, " -> ")) + } + + visiting[tag] = true + + chain = append(chain, tag) + defer delete(visiting, tag) + + var result Owners + + for _, owner := range tagOwners[tag] { + switch o := owner.(type) { + case *Tag: + if _, ok := tagOwners[*o]; !ok { + return nil, fmt.Errorf("tag %q %w %q", tag, ErrUndefinedTagReference, *o) + } + + nested, err := flattenTags(tagOwners, *o, visiting, chain) + if err != nil { + return nil, err + } + + result = append(result, nested...) + default: + result = append(result, owner) + } + } + + return result, nil +} + +// flattenTagOwners flattens all TagOwners by resolving nested tags and detecting cycles. +// It will return a new TagOwners map where all the Tag types have been resolved to their underlying Owners. +func flattenTagOwners(tagOwners TagOwners) (TagOwners, error) { + ret := make(TagOwners) + + for tag := range tagOwners { + flattened, err := flattenTags(tagOwners, tag, make(map[Tag]bool), nil) + if err != nil { + return nil, err + } + + slices.SortFunc(flattened, func(a, b Owner) int { + return cmp.Compare(a.String(), b.String()) + }) + ret[tag] = slices.CompactFunc(flattened, func(a, b Owner) bool { + return a.String() == b.String() + }) + } + + return ret, nil +} + +// resolveTagOwners resolves the TagOwners to a map of Tag to netipx.IPSet. +// The resulting map can be used to quickly look up the IPSet for a given Tag. +// It is intended for internal use in a PolicyManager. 
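+// Nested tag owners are flattened first, so a tag owned by another tag resolves down
+// to the underlying users and groups before their IPs are collected.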
+func resolveTagOwners(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[Tag]*netipx.IPSet, error) { + if p == nil { + return make(map[Tag]*netipx.IPSet), nil + } + + if len(p.TagOwners) == 0 { + return make(map[Tag]*netipx.IPSet), nil + } + + ret := make(map[Tag]*netipx.IPSet) + + tagOwners, err := flattenTagOwners(p.TagOwners) + if err != nil { + return nil, err + } + + for tag, owners := range tagOwners { + var ips netipx.IPSetBuilder + + for _, owner := range owners { + switch o := owner.(type) { + case *Tag: + // After flattening, Tag types should not appear in the owners list. + // If they do, skip them as they represent already-resolved references. + + case Alias: + // If it does not resolve, that means the tag is not associated with any IP addresses. + resolved, _ := o.Resolve(p, users, nodes) + ips.AddSet(resolved) + + default: + // Should never happen - after flattening, all owners should be Alias types + return nil, fmt.Errorf("%w: %v", ErrInvalidTagOwner, owner) + } + } + + ipSet, err := ips.IPSet() + if err != nil { + return nil, err + } + + ret[tag] = ipSet + } + + return ret, nil +} diff --git a/hscontrol/policy/v2/policy_test.go b/hscontrol/policy/v2/policy_test.go new file mode 100644 index 00000000..26b0d141 --- /dev/null +++ b/hscontrol/policy/v2/policy_test.go @@ -0,0 +1,890 @@ +package v2 + +import ( + "net/netip" + "slices" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node { + return &types.Node{ + ID: 0, + Hostname: name, + IPv4: ap(ipv4), + IPv6: ap(ipv6), + User: ptr.To(user), + UserID: ptr.To(user.ID), + Hostinfo: hostinfo, + } +} + +func TestPolicyManager(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "testuser", Email: "testuser@headscale.net"}, + {Model: gorm.Model{ID: 2}, Name: "otheruser", Email: "otheruser@headscale.net"}, + } + + tests := []struct { + name string + pol string + nodes types.Nodes + wantFilter []tailcfg.FilterRule + wantMatchers []matcher.Match + }{ + { + name: "empty-policy", + pol: "{}", + nodes: types.Nodes{}, + wantFilter: tailcfg.FilterAllowAll, + wantMatchers: matcher.MatchesFromFilterRules(tailcfg.FilterAllowAll), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm, err := NewPolicyManager([]byte(tt.pol), users, tt.nodes.ViewSlice()) + require.NoError(t, err) + + filter, matchers := pm.Filter() + if diff := cmp.Diff(tt.wantFilter, filter); diff != "" { + t.Errorf("Filter() filter mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff( + tt.wantMatchers, + matchers, + cmp.AllowUnexported(matcher.Match{}), + ); diff != "" { + t.Errorf("Filter() matchers mismatch (-want +got):\n%s", diff) + } + + // TODO(kradalby): Test SSH Policy + }) + } +} + +func TestInvalidateAutogroupSelfCache(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}, + {Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}, + {Model: gorm.Model{ID: 3}, Name: "user3", Email: "user3@headscale.net"}, + } + + policy := `{ + "acls": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + } + ] + }` + + initialNodes := types.Nodes{ + node("user1-node1", "100.64.0.1", 
"fd7a:115c:a1e0::1", users[0], nil), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + } + + for i, n := range initialNodes { + n.ID = types.NodeID(i + 1) + } + + pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice()) + require.NoError(t, err) + + // Add to cache by calling FilterForNode for each node + for _, n := range initialNodes { + _, err := pm.FilterForNode(n.View()) + require.NoError(t, err) + } + + require.Equal(t, len(initialNodes), len(pm.filterRulesMap)) + + tests := []struct { + name string + newNodes types.Nodes + expectedCleared int + description string + }{ + { + name: "no_changes", + newNodes: types.Nodes{ + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + }, + expectedCleared: 0, + description: "No changes should clear no cache entries", + }, + { + name: "node_added", + newNodes: types.Nodes{ + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), + node("user1-node3", "100.64.0.5", "fd7a:115c:a1e0::5", users[0], nil), // New node + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + }, + expectedCleared: 2, // user1's existing nodes should be cleared + description: "Adding a node should clear cache for that user's existing nodes", + }, + { + name: "node_removed", + newNodes: types.Nodes{ + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), + // user1-node2 removed + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + }, + expectedCleared: 2, // user1's remaining node + removed node should be cleared + description: "Removing a node should clear cache for that user's remaining nodes", + }, + { + name: "user_changed", + newNodes: types.Nodes{ + node("user1-node1", "100.64.0.1", "fd7a:115c:a1e0::1", users[0], nil), + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[2], nil), // Changed to user3 + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + }, + expectedCleared: 3, // user1's node + user2's node + user3's nodes should be cleared + description: "Changing a node's user should clear cache for both old and new users", + }, + { + name: "ip_changed", + newNodes: types.Nodes{ + node("user1-node1", "100.64.0.10", "fd7a:115c:a1e0::10", users[0], nil), // IP changed + node("user1-node2", "100.64.0.2", "fd7a:115c:a1e0::2", users[0], nil), + node("user2-node1", "100.64.0.3", "fd7a:115c:a1e0::3", users[1], nil), + node("user3-node1", "100.64.0.4", "fd7a:115c:a1e0::4", users[2], nil), + }, + expectedCleared: 2, // user1's nodes should be cleared + description: "Changing a node's IP should clear cache for that user's nodes", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for i, n := range tt.newNodes { + found := false + for _, origNode := range initialNodes { + if n.Hostname == origNode.Hostname { + n.ID = origNode.ID + found 
= true + break + } + } + if !found { + n.ID = types.NodeID(len(initialNodes) + i + 1) + } + } + + pm.filterRulesMap = make(map[types.NodeID][]tailcfg.FilterRule) + for _, n := range initialNodes { + _, err := pm.FilterForNode(n.View()) + require.NoError(t, err) + } + + initialCacheSize := len(pm.filterRulesMap) + require.Equal(t, len(initialNodes), initialCacheSize) + + pm.invalidateAutogroupSelfCache(initialNodes.ViewSlice(), tt.newNodes.ViewSlice()) + + // Verify the expected number of cache entries were cleared + finalCacheSize := len(pm.filterRulesMap) + clearedEntries := initialCacheSize - finalCacheSize + require.Equal(t, tt.expectedCleared, clearedEntries, tt.description) + }) + } +} + +// TestInvalidateGlobalPolicyCache tests the cache invalidation logic for global policies. +func TestInvalidateGlobalPolicyCache(t *testing.T) { + mustIPPtr := func(s string) *netip.Addr { + ip := netip.MustParseAddr(s) + return &ip + } + + tests := []struct { + name string + oldNodes types.Nodes + newNodes types.Nodes + initialCache map[types.NodeID][]tailcfg.FilterRule + expectedCacheAfter map[types.NodeID]bool // true = should exist, false = should not exist + }{ + { + name: "node property changed - invalidates only that node", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.99")}, // Changed + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // Unchanged + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: false, // Invalidated + 2: true, // Preserved + }, + }, + { + name: "multiple nodes changed", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + &types.Node{ID: 3, IPv4: mustIPPtr("100.64.0.3")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.99")}, // Changed + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // Unchanged + &types.Node{ID: 3, IPv4: mustIPPtr("100.64.0.88")}, // Changed + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + 3: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: false, // Invalidated + 2: true, // Preserved + 3: false, // Invalidated + }, + }, + { + name: "node deleted - removes from cache", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: false, // Deleted + 2: true, // Preserved + }, + }, + { + name: "node added - no cache invalidation needed", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, // New + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: true, // Preserved + 2: false, // Not in cache (new node) + }, + }, + { + name: "no changes - preserves all cache", + oldNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + newNodes: types.Nodes{ + &types.Node{ID: 1, IPv4: 
mustIPPtr("100.64.0.1")}, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: true, + 2: true, + }, + }, + { + name: "routes changed - invalidates that node only", + oldNodes: types.Nodes{ + &types.Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + }, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + newNodes: types.Nodes{ + &types.Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, // Changed + }, + &types.Node{ID: 2, IPv4: mustIPPtr("100.64.0.2")}, + }, + initialCache: map[types.NodeID][]tailcfg.FilterRule{ + 1: {}, + 2: {}, + }, + expectedCacheAfter: map[types.NodeID]bool{ + 1: false, // Invalidated + 2: true, // Preserved + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pm := &PolicyManager{ + nodes: tt.oldNodes.ViewSlice(), + filterRulesMap: tt.initialCache, + usesAutogroupSelf: false, + } + + pm.invalidateGlobalPolicyCache(tt.newNodes.ViewSlice()) + + // Verify cache state + for nodeID, shouldExist := range tt.expectedCacheAfter { + _, exists := pm.filterRulesMap[nodeID] + require.Equal(t, shouldExist, exists, "node %d cache existence mismatch", nodeID) + } + }) + } +} + +// TestAutogroupSelfReducedVsUnreducedRules verifies that: +// 1. BuildPeerMap uses unreduced compiled rules for determining peer relationships +// 2. 
FilterForNode returns reduced compiled rules for packet filters +func TestAutogroupSelfReducedVsUnreducedRules(t *testing.T) { + user1 := types.User{Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"} + user2 := types.User{Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"} + users := types.Users{user1, user2} + + // Create two nodes + node1 := node("node1", "100.64.0.1", "fd7a:115c:a1e0::1", user1, nil) + node1.ID = 1 + node2 := node("node2", "100.64.0.2", "fd7a:115c:a1e0::2", user2, nil) + node2.ID = 2 + nodes := types.Nodes{node1, node2} + + // Policy with autogroup:self - all members can reach their own devices + policyStr := `{ + "acls": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policyStr), users, nodes.ViewSlice()) + require.NoError(t, err) + require.True(t, pm.usesAutogroupSelf, "policy should use autogroup:self") + + // Test FilterForNode returns reduced rules + // For node1: should have rules where node1 is in destinations (its own IP) + filterNode1, err := pm.FilterForNode(nodes[0].View()) + require.NoError(t, err) + + // For node2: should have rules where node2 is in destinations (its own IP) + filterNode2, err := pm.FilterForNode(nodes[1].View()) + require.NoError(t, err) + + // FilterForNode should return reduced rules - verify they only contain the node's own IPs as destinations + // For node1, destinations should only be node1's IPs + node1IPs := []string{"100.64.0.1/32", "100.64.0.1", "fd7a:115c:a1e0::1/128", "fd7a:115c:a1e0::1"} + for _, rule := range filterNode1 { + for _, dst := range rule.DstPorts { + require.Contains(t, node1IPs, dst.IP, + "node1 filter should only contain node1's IPs as destinations") + } + } + + // For node2, destinations should only be node2's IPs + node2IPs := []string{"100.64.0.2/32", "100.64.0.2", "fd7a:115c:a1e0::2/128", "fd7a:115c:a1e0::2"} + for _, rule := range filterNode2 { + for _, dst := range rule.DstPorts { + require.Contains(t, node2IPs, dst.IP, + "node2 filter should only contain node2's IPs as destinations") + } + } + + // Test BuildPeerMap uses unreduced rules + peerMap := pm.BuildPeerMap(nodes.ViewSlice()) + + // According to the policy, user1 can reach autogroup:self (which expands to node1's own IPs for node1) + // So node1 should be able to reach itself, but since we're looking at peer relationships, + // node1 should NOT have itself in the peer map (nodes don't peer with themselves) + // node2 should also not have any peers since user2 has no rules allowing it to reach anyone + + // Verify peer relationships based on unreduced rules + // With unreduced rules, BuildPeerMap can properly determine that: + // - node1 can access autogroup:self (its own IPs) + // - node2 cannot access node1 + require.Empty(t, peerMap[node1.ID], "node1 should have no peers (can only reach itself)") + require.Empty(t, peerMap[node2.ID], "node2 should have no peers") +} + +// When separate ACL rules exist (one with autogroup:self, one with tag:router), +// the autogroup:self rule should not prevent the tag:router rule from working. +// This ensures that autogroup:self doesn't interfere with other ACL rules. 
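+// Regression test for issue #2838; the policy below matches the one reported there.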
+func TestAutogroupSelfWithOtherRules(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "test-1", Email: "test-1@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "test-2", Email: "test-2@example.com"}, + } + + // test-1 has a regular device + test1Node := &types.Node{ + ID: 1, + Hostname: "test-1-device", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + Hostinfo: &tailcfg.Hostinfo{}, + } + + // test-2 has a router device with tag:node-router + test2RouterNode := &types.Node{ + ID: 2, + Hostname: "test-2-router", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + Tags: []string{"tag:node-router"}, + Hostinfo: &tailcfg.Hostinfo{}, + } + + nodes := types.Nodes{test1Node, test2RouterNode} + + // This matches the exact policy from issue #2838: + // - First rule: autogroup:member -> autogroup:self (allows users to see their own devices) + // - Second rule: group:home -> tag:node-router (should allow group members to see router) + policy := `{ + "groups": { + "group:home": ["test-1@example.com", "test-2@example.com"] + }, + "tagOwners": { + "tag:node-router": ["group:home"] + }, + "acls": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + }, + { + "action": "accept", + "src": ["group:home"], + "dst": ["tag:node-router:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + peerMap := pm.BuildPeerMap(nodes.ViewSlice()) + + // test-1 (in group:home) should see: + // 1. Their own node (from autogroup:self rule) + // 2. The router node (from group:home -> tag:node-router rule) + test1Peers := peerMap[test1Node.ID] + + // Verify test-1 can see the router (group:home -> tag:node-router rule) + require.True(t, slices.ContainsFunc(test1Peers, func(n types.NodeView) bool { + return n.ID() == test2RouterNode.ID + }), "test-1 should see test-2's router via group:home -> tag:node-router rule, even when autogroup:self rule exists (issue #2838)") + + // Verify that test-1 has filter rules (including autogroup:self and tag:node-router access) + rules, err := pm.FilterForNode(test1Node.View()) + require.NoError(t, err) + require.NotEmpty(t, rules, "test-1 should have filter rules from both ACL rules") +} + +// TestAutogroupSelfPolicyUpdateTriggersMapResponse verifies that when a policy with +// autogroup:self is updated, SetPolicy returns true to trigger MapResponse updates, +// even if the global filter hash didn't change (which is always empty for autogroup:self). +// This fixes the issue where policy updates would clear caches but not trigger updates, +// leaving nodes with stale filter rules until reconnect. 
+func TestAutogroupSelfPolicyUpdateTriggersMapResponse(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "test-1", Email: "test-1@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "test-2", Email: "test-2@example.com"}, + } + + test1Node := &types.Node{ + ID: 1, + Hostname: "test-1-device", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + Hostinfo: &tailcfg.Hostinfo{}, + } + + test2Node := &types.Node{ + ID: 2, + Hostname: "test-2-device", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + Hostinfo: &tailcfg.Hostinfo{}, + } + + nodes := types.Nodes{test1Node, test2Node} + + // Initial policy with autogroup:self + initialPolicy := `{ + "acls": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(initialPolicy), users, nodes.ViewSlice()) + require.NoError(t, err) + require.True(t, pm.usesAutogroupSelf, "policy should use autogroup:self") + + // Get initial filter rules for test-1 (should be cached) + rules1, err := pm.FilterForNode(test1Node.View()) + require.NoError(t, err) + require.NotEmpty(t, rules1, "test-1 should have filter rules") + + // Update policy with a different ACL that still results in empty global filter + // (only autogroup:self rules, which compile to empty global filter) + // We add a comment/description change by adding groups (which don't affect filter compilation) + updatedPolicy := `{ + "groups": { + "group:test": ["test-1@example.com"] + }, + "acls": [ + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + } + ] + }` + + // SetPolicy should return true even though global filter hash didn't change + policyChanged, err := pm.SetPolicy([]byte(updatedPolicy)) + require.NoError(t, err) + require.True(t, policyChanged, "SetPolicy should return true when policy content changes, even if global filter hash unchanged (autogroup:self)") + + // Verify that caches were cleared and new rules are generated + // The cache should be empty, so FilterForNode will recompile + rules2, err := pm.FilterForNode(test1Node.View()) + require.NoError(t, err) + require.NotEmpty(t, rules2, "test-1 should have filter rules after policy update") + + // Verify that the policy hash tracking works - a second identical update should return false + policyChanged2, err := pm.SetPolicy([]byte(updatedPolicy)) + require.NoError(t, err) + require.False(t, policyChanged2, "SetPolicy should return false when policy content hasn't changed") +} + +// TestTagPropagationToPeerMap tests that when a node's tags change, +// the peer map is correctly updated. 
This is a regression test for +// https://github.com/juanfont/headscale/issues/2389 +func TestTagPropagationToPeerMap(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@headscale.net"}, + {Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@headscale.net"}, + } + + // Policy: user2 can access tag:web nodes + policy := `{ + "tagOwners": { + "tag:web": ["user1@headscale.net"], + "tag:internal": ["user1@headscale.net"] + }, + "acls": [ + { + "action": "accept", + "src": ["user2@headscale.net"], + "dst": ["user2@headscale.net:*"] + }, + { + "action": "accept", + "src": ["user2@headscale.net"], + "dst": ["tag:web:*"] + }, + { + "action": "accept", + "src": ["tag:web"], + "dst": ["user2@headscale.net:*"] + } + ] + }` + + // user1's node starts with tag:web and tag:internal + user1Node := &types.Node{ + ID: 1, + Hostname: "user1-node", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + Tags: []string{"tag:web", "tag:internal"}, + } + + // user2's node (no tags) + user2Node := &types.Node{ + ID: 2, + Hostname: "user2-node", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + } + + initialNodes := types.Nodes{user1Node, user2Node} + + pm, err := NewPolicyManager([]byte(policy), users, initialNodes.ViewSlice()) + require.NoError(t, err) + + // Initial state: user2 should see user1 as a peer (user1 has tag:web) + initialPeerMap := pm.BuildPeerMap(initialNodes.ViewSlice()) + + // Check user2's peers - should include user1 + user2Peers := initialPeerMap[user2Node.ID] + require.Len(t, user2Peers, 1, "user2 should have 1 peer initially (user1 with tag:web)") + require.Equal(t, user1Node.ID, user2Peers[0].ID(), "user2's peer should be user1") + + // Check user1's peers - should include user2 (bidirectional ACL) + user1Peers := initialPeerMap[user1Node.ID] + require.Len(t, user1Peers, 1, "user1 should have 1 peer initially (user2)") + require.Equal(t, user2Node.ID, user1Peers[0].ID(), "user1's peer should be user2") + + // Now change user1's tags: remove tag:web, keep only tag:internal + user1NodeUpdated := &types.Node{ + ID: 1, + Hostname: "user1-node", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + Tags: []string{"tag:internal"}, // tag:web removed! 
+ } + + updatedNodes := types.Nodes{user1NodeUpdated, user2Node} + + // SetNodes should detect the tag change + changed, err := pm.SetNodes(updatedNodes.ViewSlice()) + require.NoError(t, err) + require.True(t, changed, "SetNodes should return true when tags change") + + // After tag change: user2 should NOT see user1 as a peer anymore + // (no ACL allows user2 to access tag:internal) + updatedPeerMap := pm.BuildPeerMap(updatedNodes.ViewSlice()) + + // Check user2's peers - should be empty now + user2PeersAfter := updatedPeerMap[user2Node.ID] + require.Empty(t, user2PeersAfter, "user2 should have no peers after tag:web is removed from user1") + + // Check user1's peers - should also be empty + user1PeersAfter := updatedPeerMap[user1Node.ID] + require.Empty(t, user1PeersAfter, "user1 should have no peers after tag:web is removed") + + // Also verify MatchersForNode returns non-empty matchers and ReduceNodes filters correctly + // This simulates what buildTailPeers does in the mapper + matchersForUser2, err := pm.MatchersForNode(user2Node.View()) + require.NoError(t, err) + require.NotEmpty(t, matchersForUser2, "MatchersForNode should return non-empty matchers (at least self-access rule)") + + // Test ReduceNodes logic with the updated nodes and matchers + // This is what buildTailPeers does - it takes peers from ListPeers (which might include user1) + // and filters them using ReduceNodes with the updated matchers + // Inline the ReduceNodes logic to avoid import cycle + user2View := user2Node.View() + user1UpdatedView := user1NodeUpdated.View() + + // Check if user2 can access user1 OR user1 can access user2 + canAccess := user2View.CanAccess(matchersForUser2, user1UpdatedView) || + user1UpdatedView.CanAccess(matchersForUser2, user2View) + + require.False(t, canAccess, "user2 should NOT be able to access user1 after tag:web is removed (ReduceNodes should filter out)") +} + +// TestAutogroupSelfWithAdminOverride reproduces issue #2990: +// When autogroup:self is combined with an admin rule (group:admin -> *:*), +// tagged nodes become invisible to admins because BuildPeerMap uses asymmetric +// peer visibility in the autogroup:self path. +// +// The fix requires symmetric visibility: if admin can access tagged node, +// BOTH admin and tagged node should see each other as peers. +func TestAutogroupSelfWithAdminOverride(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "admin", Email: "admin@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "user1", Email: "user1@example.com"}, + } + + // Admin has a regular device + adminNode := &types.Node{ + ID: 1, + Hostname: "admin-device", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + Hostinfo: &tailcfg.Hostinfo{}, + } + + // user1 has a tagged server + user1TaggedNode := &types.Node{ + ID: 2, + Hostname: "user1-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + Tags: []string{"tag:server"}, + Hostinfo: &tailcfg.Hostinfo{}, + } + + nodes := types.Nodes{adminNode, user1TaggedNode} + + // Policy from issue #2990: + // - group:admin has full access to everything (*:*) + // - autogroup:member -> autogroup:self (allows users to see their own devices) + // + // Bug: The tagged server becomes invisible to admin because: + // 1. Admin can access tagged server (via *:* rule) + // 2. Tagged server CANNOT access admin (no rule for that) + // 3. 
With asymmetric logic, tagged server is not added to admin's peer list + policy := `{ + "groups": { + "group:admin": ["admin@example.com"] + }, + "tagOwners": { + "tag:server": ["user1@example.com"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admin"], + "dst": ["*:*"] + }, + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + peerMap := pm.BuildPeerMap(nodes.ViewSlice()) + + // Admin should see the tagged server as a peer (via group:admin -> *:* rule) + adminPeers := peerMap[adminNode.ID] + require.True(t, slices.ContainsFunc(adminPeers, func(n types.NodeView) bool { + return n.ID() == user1TaggedNode.ID + }), "admin should see tagged server as peer via *:* rule (issue #2990)") + + // Tagged server should also see admin as a peer (symmetric visibility) + // Even though tagged server cannot ACCESS admin, it should still SEE admin + // because admin CAN access it. This is required for proper network operation. + taggedPeers := peerMap[user1TaggedNode.ID] + require.True(t, slices.ContainsFunc(taggedPeers, func(n types.NodeView) bool { + return n.ID() == adminNode.ID + }), "tagged server should see admin as peer (symmetric visibility)") +} + +// TestAutogroupSelfSymmetricVisibility verifies that peer visibility is symmetric: +// if node A can access node B, then both A and B should see each other as peers. +// This is the same behavior as the global filter path. +func TestAutogroupSelfSymmetricVisibility(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1", Email: "user1@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "user2", Email: "user2@example.com"}, + } + + // user1 has device A + deviceA := &types.Node{ + ID: 1, + Hostname: "device-a", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + Hostinfo: &tailcfg.Hostinfo{}, + } + + // user2 has device B (tagged) + deviceB := &types.Node{ + ID: 2, + Hostname: "device-b", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + Tags: []string{"tag:web"}, + Hostinfo: &tailcfg.Hostinfo{}, + } + + nodes := types.Nodes{deviceA, deviceB} + + // One-way rule: user1 can access tag:web, but tag:web cannot access user1 + policy := `{ + "tagOwners": { + "tag:web": ["user2@example.com"] + }, + "acls": [ + { + "action": "accept", + "src": ["user1@example.com"], + "dst": ["tag:web:*"] + }, + { + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:self:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + peerMap := pm.BuildPeerMap(nodes.ViewSlice()) + + // Device A (user1) should see device B (tag:web) as peer + aPeers := peerMap[deviceA.ID] + require.True(t, slices.ContainsFunc(aPeers, func(n types.NodeView) bool { + return n.ID() == deviceB.ID + }), "device A should see device B as peer (user1 -> tag:web rule)") + + // Device B (tag:web) should ALSO see device A as peer (symmetric visibility) + // Even though B cannot ACCESS A, B should still SEE A as a peer + bPeers := peerMap[deviceB.ID] + require.True(t, slices.ContainsFunc(bPeers, func(n types.NodeView) bool { + return n.ID() == deviceA.ID + }), "device B should see device A as peer (symmetric visibility)") +} diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go new file mode 
100644 index 00000000..ce968225 --- /dev/null +++ b/hscontrol/policy/v2/types.go @@ -0,0 +1,2171 @@ +package v2 + +import ( + "errors" + "fmt" + "net/netip" + "slices" + "strconv" + "strings" + + "github.com/go-json-experiment/json" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/prometheus/common/model" + "github.com/tailscale/hujson" + "go4.org/netipx" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" + "tailscale.com/types/views" + "tailscale.com/util/multierr" + "tailscale.com/util/slicesx" +) + +// Global JSON options for consistent parsing across all struct unmarshaling +var policyJSONOpts = []json.Options{ + json.DefaultOptionsV2(), + json.MatchCaseInsensitiveNames(true), + json.RejectUnknownMembers(true), +} + +const Wildcard = Asterix(0) + +var ErrAutogroupSelfRequiresPerNodeResolution = errors.New("autogroup:self requires per-node resolution and cannot be resolved in this context") + +var ErrCircularReference = errors.New("circular reference detected") + +var ErrUndefinedTagReference = errors.New("references undefined tag") + +// SSH validation errors. +var ( + ErrSSHTagSourceToUserDest = errors.New("tags in SSH source cannot access user-owned devices") + ErrSSHUserDestRequiresSameUser = errors.New("user destination requires source to contain only that same user") + ErrSSHAutogroupSelfRequiresUserSource = errors.New("autogroup:self destination requires source to contain only users or groups, not tags or autogroup:tagged") + ErrSSHTagSourceToAutogroupMember = errors.New("tags in SSH source cannot access autogroup:member (user-owned devices)") + ErrSSHWildcardDestination = errors.New("wildcard (*) is not supported as SSH destination") +) + +type Asterix int + +func (a Asterix) Validate() error { + return nil +} + +func (a Asterix) String() string { + return "*" +} + +// MarshalJSON marshals the Asterix to JSON. +func (a Asterix) MarshalJSON() ([]byte, error) { + return []byte(`"*"`), nil +} + +// MarshalJSON marshals the AliasWithPorts to JSON. +func (a AliasWithPorts) MarshalJSON() ([]byte, error) { + if a.Alias == nil { + return []byte(`""`), nil + } + + var alias string + switch v := a.Alias.(type) { + case *Username: + alias = string(*v) + case *Group: + alias = string(*v) + case *Tag: + alias = string(*v) + case *Host: + alias = string(*v) + case *Prefix: + alias = v.String() + case *AutoGroup: + alias = string(*v) + case Asterix: + alias = "*" + default: + return nil, fmt.Errorf("unknown alias type: %T", v) + } + + // If no ports are specified + if len(a.Ports) == 0 { + return json.Marshal(alias) + } + + // Check if it's the wildcard port range + if len(a.Ports) == 1 && a.Ports[0].First == 0 && a.Ports[0].Last == 65535 { + return json.Marshal(alias + ":*") + } + + // Otherwise, format as "alias:ports" + var ports []string + for _, port := range a.Ports { + if port.First == port.Last { + ports = append(ports, strconv.FormatUint(uint64(port.First), 10)) + } else { + ports = append(ports, fmt.Sprintf("%d-%d", port.First, port.Last)) + } + } + + return json.Marshal(fmt.Sprintf("%s:%s", alias, strings.Join(ports, ","))) +} + +func (a Asterix) UnmarshalJSON(b []byte) error { + return nil +} + +func (a Asterix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + // TODO(kradalby): + // Should this actually only be the CGNAT spaces? I do not think so, because + // we also want to include subnet routers right? 
+	ips.AddPrefix(tsaddr.AllIPv4())
+	ips.AddPrefix(tsaddr.AllIPv6())
+
+	return ips.IPSet()
+}
+
+// Username is a string that represents a username; it must contain an @.
+type Username string
+
+func (u Username) Validate() error {
+	if isUser(string(u)) {
+		return nil
+	}
+	return fmt.Errorf("Username has to contain @, got: %q", u)
+}
+
+func (u *Username) String() string {
+	return string(*u)
+}
+
+// MarshalJSON marshals the Username to JSON.
+func (u Username) MarshalJSON() ([]byte, error) {
+	return json.Marshal(string(u))
+}
+
+// MarshalJSON marshals the Prefix to JSON.
+func (p Prefix) MarshalJSON() ([]byte, error) {
+	return json.Marshal(p.String())
+}
+
+func (u *Username) UnmarshalJSON(b []byte) error {
+	*u = Username(strings.Trim(string(b), `"`))
+	if err := u.Validate(); err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (u Username) CanBeTagOwner() bool {
+	return true
+}
+
+func (u Username) CanBeAutoApprover() bool {
+	return true
+}
+
+// resolveUser attempts to find a user in the provided [types.Users] slice that matches the Username.
+// It prioritizes matching the ProviderIdentifier, and if not found, it falls back to matching the Email or Name.
+// If no matching user is found, it returns an error indicating that no user matched.
+// If multiple matching users are found, it returns an error indicating that multiple users matched.
+// It returns the matched types.User and a nil error if exactly one match is found.
+func (u Username) resolveUser(users types.Users) (types.User, error) {
+	var potentialUsers types.Users
+
+	// At parse time, we require all usernames to contain an "@" character. If the
+	// username token does not naturally do so (like an email), the user has to
+	// add it to the end of the username. We strip it here as we do not expect the
+	// usernames to be stored with the "@".
+	uTrimmed := strings.TrimSuffix(u.String(), "@")
+
+	for _, user := range users {
+		if user.ProviderIdentifier.Valid && user.ProviderIdentifier.String == uTrimmed {
+			// Prioritize ProviderIdentifier match and exit early
+			return user, nil
+		}
+
+		if user.Email == uTrimmed || user.Name == uTrimmed {
+			potentialUsers = append(potentialUsers, user)
+		}
+	}
+
+	if len(potentialUsers) == 0 {
+		return types.User{}, fmt.Errorf("user with token %q not found", u.String())
+	}
+
+	if len(potentialUsers) > 1 {
+		return types.User{}, fmt.Errorf("multiple users with token %q found: %s", u.String(), potentialUsers.String())
+	}
+
+	return potentialUsers[0], nil
+}
+
+func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) {
+	var ips netipx.IPSetBuilder
+	var errs []error
+
+	user, err := u.resolveUser(users)
+	if err != nil {
+		errs = append(errs, err)
+	}
+
+	for _, node := range nodes.All() {
+		// Skip tagged nodes - they are identified by tags, not users
+		if node.IsTagged() {
+			continue
+		}
+
+		// Skip nodes without a user (defensive check for tests)
+		if !node.User().Valid() {
+			continue
+		}
+
+		if node.User().ID() == user.ID {
+			node.AppendToIPSet(&ips)
+		}
+	}
+
+	return buildIPSetMultiErr(&ips, errs)
+}
+
+// Group is a special string which is always prefixed with `group:`.
+type Group string + +func (g Group) Validate() error { + if isGroup(string(g)) { + return nil + } + return fmt.Errorf(`Group has to start with "group:", got: %q`, g) +} + +func (g *Group) UnmarshalJSON(b []byte) error { + *g = Group(strings.Trim(string(b), `"`)) + if err := g.Validate(); err != nil { + return err + } + + return nil +} + +func (g Group) CanBeTagOwner() bool { + return true +} + +func (g Group) CanBeAutoApprover() bool { + return true +} + +// String returns the string representation of the Group. +func (g Group) String() string { + return string(g) +} + +func (h Host) String() string { + return string(h) +} + +// MarshalJSON marshals the Host to JSON. +func (h Host) MarshalJSON() ([]byte, error) { + return json.Marshal(string(h)) +} + +// MarshalJSON marshals the Group to JSON. +func (g Group) MarshalJSON() ([]byte, error) { + return json.Marshal(string(g)) +} + +func (g Group) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + var errs []error + + for _, user := range p.Groups[g] { + uips, err := user.Resolve(nil, users, nodes) + if err != nil { + errs = append(errs, err) + } + + ips.AddSet(uips) + } + + return buildIPSetMultiErr(&ips, errs) +} + +// Tag is a special string which is always prefixed with `tag:`. +type Tag string + +func (t Tag) Validate() error { + if isTag(string(t)) { + return nil + } + return fmt.Errorf(`tag has to start with "tag:", got: %q`, t) +} + +func (t *Tag) UnmarshalJSON(b []byte) error { + *t = Tag(strings.Trim(string(b), `"`)) + if err := t.Validate(); err != nil { + return err + } + + return nil +} + +func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + + for _, node := range nodes.All() { + // Check if node has this tag + if node.HasTag(string(t)) { + node.AppendToIPSet(&ips) + } + } + + return ips.IPSet() +} + +func (t Tag) CanBeAutoApprover() bool { + return true +} + +func (t Tag) CanBeTagOwner() bool { + return true +} + +func (t Tag) String() string { + return string(t) +} + +// MarshalJSON marshals the Tag to JSON. +func (t Tag) MarshalJSON() ([]byte, error) { + return json.Marshal(string(t)) +} + +// Host is a string that represents a hostname. +type Host string + +func (h Host) Validate() error { + if isHost(string(h)) { + return nil + } + return fmt.Errorf("Hostname %q is invalid", h) +} + +func (h *Host) UnmarshalJSON(b []byte) error { + *h = Host(strings.Trim(string(b), `"`)) + if err := h.Validate(); err != nil { + return err + } + + return nil +} + +func (h Host) Resolve(p *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + var errs []error + + pref, ok := p.Hosts[h] + if !ok { + return nil, fmt.Errorf("unable to resolve host: %q", h) + } + err := pref.Validate() + if err != nil { + errs = append(errs, err) + } + + ips.AddPrefix(netip.Prefix(pref)) + + // If the IP is a single host, look for a node to ensure we add all the IPs of + // the node to the IPSet. + appendIfNodeHasIP(nodes, &ips, netip.Prefix(pref)) + + // TODO(kradalby): I am a bit unsure what is the correct way to do this, + // should a host with a non single IP be able to resolve the full host (inc all IPs). 
+ ipsTemp, err := ips.IPSet() + if err != nil { + errs = append(errs, err) + } + for _, node := range nodes.All() { + if node.InIPSet(ipsTemp) { + node.AppendToIPSet(&ips) + } + } + + return buildIPSetMultiErr(&ips, errs) +} + +type Prefix netip.Prefix + +func (p Prefix) Validate() error { + if netip.Prefix(p).IsValid() { + return nil + } + return fmt.Errorf("Prefix %q is invalid", p) +} + +func (p Prefix) String() string { + return netip.Prefix(p).String() +} + +func (p *Prefix) parseString(addr string) error { + if !strings.Contains(addr, "/") { + addr, err := netip.ParseAddr(addr) + if err != nil { + return err + } + addrPref, err := addr.Prefix(addr.BitLen()) + if err != nil { + return err + } + + *p = Prefix(addrPref) + + return nil + } + + pref, err := netip.ParsePrefix(addr) + if err != nil { + return err + } + *p = Prefix(pref) + + return nil +} + +func (p *Prefix) UnmarshalJSON(b []byte) error { + err := p.parseString(strings.Trim(string(b), `"`)) + if err != nil { + return err + } + if err := p.Validate(); err != nil { + return err + } + + return nil +} + +// Resolve resolves the Prefix to an IPSet. The IPSet will contain all the IP +// addresses that the Prefix represents within Headscale. It is the product +// of the Prefix and the Policy, Users, and Nodes. +// +// See [Policy], [types.Users], and [types.Nodes] for more details. +func (p Prefix) Resolve(_ *Policy, _ types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + var errs []error + + ips.AddPrefix(netip.Prefix(p)) + // If the IP is a single host, look for a node to ensure we add all the IPs of + // the node to the IPSet. + appendIfNodeHasIP(nodes, &ips, netip.Prefix(p)) + + return buildIPSetMultiErr(&ips, errs) +} + +// appendIfNodeHasIP appends the IPs of the nodes to the IPSet if the node has the +// IP address in the prefix. +func appendIfNodeHasIP(nodes views.Slice[types.NodeView], ips *netipx.IPSetBuilder, pref netip.Prefix) { + if !pref.IsSingleIP() && !tsaddr.IsTailscaleIP(pref.Addr()) { + return + } + + for _, node := range nodes.All() { + if node.HasIP(pref.Addr()) { + node.AppendToIPSet(ips) + } + } +} + +// AutoGroup is a special string which is always prefixed with `autogroup:`. +type AutoGroup string + +const ( + AutoGroupInternet AutoGroup = "autogroup:internet" + AutoGroupMember AutoGroup = "autogroup:member" + AutoGroupNonRoot AutoGroup = "autogroup:nonroot" + AutoGroupTagged AutoGroup = "autogroup:tagged" + AutoGroupSelf AutoGroup = "autogroup:self" +) + +var autogroups = []AutoGroup{ + AutoGroupInternet, + AutoGroupMember, + AutoGroupNonRoot, + AutoGroupTagged, + AutoGroupSelf, +} + +func (ag AutoGroup) Validate() error { + if slices.Contains(autogroups, ag) { + return nil + } + + return fmt.Errorf("AutoGroup is invalid, got: %q, must be one of %v", ag, autogroups) +} + +func (ag *AutoGroup) UnmarshalJSON(b []byte) error { + *ag = AutoGroup(strings.Trim(string(b), `"`)) + if err := ag.Validate(); err != nil { + return err + } + + return nil +} + +func (ag AutoGroup) String() string { + return string(ag) +} + +// MarshalJSON marshals the AutoGroup to JSON. 
+func (ag AutoGroup) MarshalJSON() ([]byte, error) { + return json.Marshal(string(ag)) +} + +func (ag AutoGroup) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var build netipx.IPSetBuilder + + switch ag { + case AutoGroupInternet: + return util.TheInternet(), nil + + case AutoGroupMember: + for _, node := range nodes.All() { + // Skip if node is tagged + if node.IsTagged() { + continue + } + + // Node is a member if it is not tagged + node.AppendToIPSet(&build) + } + + return build.IPSet() + + case AutoGroupTagged: + for _, node := range nodes.All() { + // Include if node is tagged + if !node.IsTagged() { + continue + } + + node.AppendToIPSet(&build) + } + + return build.IPSet() + + case AutoGroupSelf: + // autogroup:self represents all devices owned by the same user. + // This cannot be resolved in the general context and should be handled + // specially during policy compilation per-node for security. + return nil, ErrAutogroupSelfRequiresPerNodeResolution + + default: + return nil, fmt.Errorf("unknown autogroup %q", ag) + } +} + +func (ag *AutoGroup) Is(c AutoGroup) bool { + if ag == nil { + return false + } + + return *ag == c +} + +type Alias interface { + Validate() error + UnmarshalJSON([]byte) error + + // Resolve resolves the Alias to an IPSet. The IPSet will contain all the IP + // addresses that the Alias represents within Headscale. It is the product + // of the Alias and the Policy, Users and Nodes. + // This is an interface definition and the implementation is independent of + // the Alias type. + Resolve(*Policy, types.Users, views.Slice[types.NodeView]) (*netipx.IPSet, error) +} + +type AliasWithPorts struct { + Alias + Ports []tailcfg.PortRange +} + +func (ve *AliasWithPorts) UnmarshalJSON(b []byte) error { + var v any + if err := json.Unmarshal(b, &v); err != nil { + return err + } + + switch vs := v.(type) { + case string: + var portsPart string + var err error + + if strings.Contains(vs, ":") { + vs, portsPart, err = splitDestinationAndPort(vs) + if err != nil { + return err + } + + ports, err := parsePortRange(portsPart) + if err != nil { + return err + } + ve.Ports = ports + } else { + return errors.New(`hostport must contain a colon (":")`) + } + + ve.Alias, err = parseAlias(vs) + if err != nil { + return err + } + if err := ve.Validate(); err != nil { + return err + } + + default: + return fmt.Errorf("type %T not supported", vs) + } + + return nil +} + +func isWildcard(str string) bool { + return str == "*" +} + +func isUser(str string) bool { + return strings.Contains(str, "@") +} + +func isGroup(str string) bool { + return strings.HasPrefix(str, "group:") +} + +func isTag(str string) bool { + return strings.HasPrefix(str, "tag:") +} + +func isAutoGroup(str string) bool { + return strings.HasPrefix(str, "autogroup:") +} + +func isHost(str string) bool { + return !isUser(str) && !strings.Contains(str, ":") +} + +func parseAlias(vs string) (Alias, error) { + var pref Prefix + err := pref.parseString(vs) + if err == nil { + return &pref, nil + } + + switch { + case isWildcard(vs): + return Wildcard, nil + case isUser(vs): + return ptr.To(Username(vs)), nil + case isGroup(vs): + return ptr.To(Group(vs)), nil + case isTag(vs): + return ptr.To(Tag(vs)), nil + case isAutoGroup(vs): + return ptr.To(AutoGroup(vs)), nil + } + + if isHost(vs) { + return ptr.To(Host(vs)), nil + } + + return nil, fmt.Errorf(`Invalid alias %q. 
An alias must be one of the following types: +- wildcard (*) +- user (containing an "@") +- group (starting with "group:") +- tag (starting with "tag:") +- autogroup (starting with "autogroup:") +- host + +Please check the format and try again.`, vs) +} + +// AliasEnc is used to deserialize a Alias. +type AliasEnc struct{ Alias } + +func (ve *AliasEnc) UnmarshalJSON(b []byte) error { + ptr, err := unmarshalPointer( + b, + parseAlias, + ) + if err != nil { + return err + } + ve.Alias = ptr + + return nil +} + +type Aliases []Alias + +func (a *Aliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + (*a)[i] = alias.Alias + } + + return nil +} + +// MarshalJSON marshals the Aliases to JSON. +func (a Aliases) MarshalJSON() ([]byte, error) { + if a == nil { + return []byte("[]"), nil + } + + aliases := make([]string, len(a)) + for i, alias := range a { + switch v := alias.(type) { + case *Username: + aliases[i] = string(*v) + case *Group: + aliases[i] = string(*v) + case *Tag: + aliases[i] = string(*v) + case *Host: + aliases[i] = string(*v) + case *Prefix: + aliases[i] = v.String() + case *AutoGroup: + aliases[i] = string(*v) + case Asterix: + aliases[i] = "*" + default: + return nil, fmt.Errorf("unknown alias type: %T", v) + } + } + + return json.Marshal(aliases) +} + +func (a Aliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + var errs []error + + for _, alias := range a { + aips, err := alias.Resolve(p, users, nodes) + if err != nil { + errs = append(errs, err) + } + + ips.AddSet(aips) + } + + return buildIPSetMultiErr(&ips, errs) +} + +func buildIPSetMultiErr(ipBuilder *netipx.IPSetBuilder, errs []error) (*netipx.IPSet, error) { + ips, err := ipBuilder.IPSet() + return ips, multierr.New(append(errs, err)...) +} + +// Helper function to unmarshal a JSON string into either an AutoApprover or Owner pointer. +func unmarshalPointer[T any]( + b []byte, + parseFunc func(string) (T, error), +) (T, error) { + var s string + err := json.Unmarshal(b, &s) + if err != nil { + var t T + return t, err + } + + return parseFunc(s) +} + +type AutoApprover interface { + CanBeAutoApprover() bool + UnmarshalJSON([]byte) error + String() string +} + +type AutoApprovers []AutoApprover + +func (aa *AutoApprovers) UnmarshalJSON(b []byte) error { + var autoApprovers []AutoApproverEnc + err := json.Unmarshal(b, &autoApprovers, policyJSONOpts...) + if err != nil { + return err + } + + *aa = make([]AutoApprover, len(autoApprovers)) + for i, autoApprover := range autoApprovers { + (*aa)[i] = autoApprover.AutoApprover + } + + return nil +} + +// MarshalJSON marshals the AutoApprovers to JSON. 
+func (aa AutoApprovers) MarshalJSON() ([]byte, error) { + if aa == nil { + return []byte("[]"), nil + } + + approvers := make([]string, len(aa)) + for i, approver := range aa { + switch v := approver.(type) { + case *Username: + approvers[i] = string(*v) + case *Tag: + approvers[i] = string(*v) + case *Group: + approvers[i] = string(*v) + default: + return nil, fmt.Errorf("unknown auto approver type: %T", v) + } + } + + return json.Marshal(approvers) +} + +func parseAutoApprover(s string) (AutoApprover, error) { + switch { + case isUser(s): + return ptr.To(Username(s)), nil + case isGroup(s): + return ptr.To(Group(s)), nil + case isTag(s): + return ptr.To(Tag(s)), nil + } + + return nil, fmt.Errorf(`Invalid AutoApprover %q. An alias must be one of the following types: +- user (containing an "@") +- group (starting with "group:") +- tag (starting with "tag:") + +Please check the format and try again.`, s) +} + +// AutoApproverEnc is used to deserialize a AutoApprover. +type AutoApproverEnc struct{ AutoApprover } + +func (ve *AutoApproverEnc) UnmarshalJSON(b []byte) error { + ptr, err := unmarshalPointer( + b, + parseAutoApprover, + ) + if err != nil { + return err + } + ve.AutoApprover = ptr + + return nil +} + +type Owner interface { + CanBeTagOwner() bool + UnmarshalJSON([]byte) error + String() string +} + +// OwnerEnc is used to deserialize a Owner. +type OwnerEnc struct{ Owner } + +func (ve *OwnerEnc) UnmarshalJSON(b []byte) error { + ptr, err := unmarshalPointer( + b, + parseOwner, + ) + if err != nil { + return err + } + ve.Owner = ptr + + return nil +} + +type Owners []Owner + +func (o *Owners) UnmarshalJSON(b []byte) error { + var owners []OwnerEnc + err := json.Unmarshal(b, &owners, policyJSONOpts...) + if err != nil { + return err + } + + *o = make([]Owner, len(owners)) + for i, owner := range owners { + (*o)[i] = owner.Owner + } + + return nil +} + +// MarshalJSON marshals the Owners to JSON. +func (o Owners) MarshalJSON() ([]byte, error) { + if o == nil { + return []byte("[]"), nil + } + + owners := make([]string, len(o)) + for i, owner := range o { + switch v := owner.(type) { + case *Username: + owners[i] = string(*v) + case *Group: + owners[i] = string(*v) + case *Tag: + owners[i] = string(*v) + default: + return nil, fmt.Errorf("unknown owner type: %T", v) + } + } + + return json.Marshal(owners) +} + +func parseOwner(s string) (Owner, error) { + switch { + case isUser(s): + return ptr.To(Username(s)), nil + case isGroup(s): + return ptr.To(Group(s)), nil + case isTag(s): + return ptr.To(Tag(s)), nil + } + + return nil, fmt.Errorf(`Invalid Owner %q. An alias must be one of the following types: +- user (containing an "@") +- group (starting with "group:") +- tag (starting with "tag:") + +Please check the format and try again.`, s) +} + +type Usernames []Username + +// Groups are a map of Group to a list of Username. +type Groups map[Group]Usernames + +func (g Groups) Contains(group *Group) error { + if group == nil { + return nil + } + + for defined := range map[Group]Usernames(g) { + if defined == *group { + return nil + } + } + + return fmt.Errorf(`Group %q is not defined in the Policy, please define or remove the reference to it`, group) +} + +// UnmarshalJSON overrides the default JSON unmarshalling for Groups to ensure +// that each group name is validated using the isGroup function. This ensures +// that all group names conform to the expected format, which is always prefixed +// with "group:". If any group name is invalid, an error is returned. 
+func (g *Groups) UnmarshalJSON(b []byte) error { + // First unmarshal as a generic map to validate group names first + var rawMap map[string]any + if err := json.Unmarshal(b, &rawMap); err != nil { + return err + } + + // Validate group names first before checking data types + for key := range rawMap { + group := Group(key) + if err := group.Validate(); err != nil { + return err + } + } + + // Then validate each field can be converted to []string + rawGroups := make(map[string][]string) + for key, value := range rawMap { + switch v := value.(type) { + case []any: + // Convert []interface{} to []string + var stringSlice []string + for _, item := range v { + if str, ok := item.(string); ok { + stringSlice = append(stringSlice, str) + } else { + return fmt.Errorf(`Group "%s" contains invalid member type, expected string but got %T`, key, item) + } + } + rawGroups[key] = stringSlice + case string: + return fmt.Errorf(`Group "%s" value must be an array of users, got string: "%s"`, key, v) + default: + return fmt.Errorf(`Group "%s" value must be an array of users, got %T`, key, v) + } + } + + *g = make(Groups) + for key, value := range rawGroups { + group := Group(key) + // Group name already validated above + var usernames Usernames + + for _, u := range value { + username := Username(u) + if err := username.Validate(); err != nil { + if isGroup(u) { + return fmt.Errorf("Nested groups are not allowed, found %q inside %q", u, group) + } + + return err + } + usernames = append(usernames, username) + } + + (*g)[group] = usernames + } + + return nil +} + +// Hosts are alias for IP addresses or subnets. +type Hosts map[Host]Prefix + +func (h *Hosts) UnmarshalJSON(b []byte) error { + var rawHosts map[string]string + if err := json.Unmarshal(b, &rawHosts, policyJSONOpts...); err != nil { + return err + } + + *h = make(Hosts) + for key, value := range rawHosts { + host := Host(key) + if err := host.Validate(); err != nil { + return err + } + + var prefix Prefix + if err := prefix.parseString(value); err != nil { + return fmt.Errorf(`Hostname "%s" contains an invalid IP address: "%s"`, key, value) + } + + (*h)[host] = prefix + } + + return nil +} + +// MarshalJSON marshals the Hosts to JSON. +func (h Hosts) MarshalJSON() ([]byte, error) { + if h == nil { + return []byte("{}"), nil + } + + rawHosts := make(map[string]string) + for host, prefix := range h { + rawHosts[string(host)] = prefix.String() + } + + return json.Marshal(rawHosts) +} + +func (h Hosts) exist(name Host) bool { + _, ok := h[name] + return ok +} + +// MarshalJSON marshals the TagOwners to JSON. +func (to TagOwners) MarshalJSON() ([]byte, error) { + if to == nil { + return []byte("{}"), nil + } + + rawTagOwners := make(map[string][]string) + for tag, owners := range to { + tagStr := string(tag) + ownerStrs := make([]string, len(owners)) + + for i, owner := range owners { + switch v := owner.(type) { + case *Username: + ownerStrs[i] = string(*v) + case *Group: + ownerStrs[i] = string(*v) + case *Tag: + ownerStrs[i] = string(*v) + default: + return nil, fmt.Errorf("unknown owner type: %T", v) + } + } + + rawTagOwners[tagStr] = ownerStrs + } + + return json.Marshal(rawTagOwners) +} + +// TagOwners are a map of Tag to a list of the UserEntities that own the tag. 
+type TagOwners map[Tag]Owners + +func (to TagOwners) Contains(tagOwner *Tag) error { + if tagOwner == nil { + return nil + } + + for defined := range map[Tag]Owners(to) { + if defined == *tagOwner { + return nil + } + } + + return fmt.Errorf(`Tag %q is not defined in the Policy, please define or remove the reference to it`, tagOwner) +} + +type AutoApproverPolicy struct { + Routes map[netip.Prefix]AutoApprovers `json:"routes,omitempty"` + ExitNode AutoApprovers `json:"exitNode,omitempty"` +} + +// MarshalJSON marshals the AutoApproverPolicy to JSON. +func (ap AutoApproverPolicy) MarshalJSON() ([]byte, error) { + // Marshal empty policies as empty object + if ap.Routes == nil && ap.ExitNode == nil { + return []byte("{}"), nil + } + + type Alias AutoApproverPolicy + + // Create a new object to avoid marshalling nil slices as null instead of empty arrays + obj := Alias(ap) + + // Initialize empty maps/slices to ensure they're marshalled as empty objects/arrays instead of null + if obj.Routes == nil { + obj.Routes = make(map[netip.Prefix]AutoApprovers) + } + + if obj.ExitNode == nil { + obj.ExitNode = AutoApprovers{} + } + + return json.Marshal(&obj) +} + +// resolveAutoApprovers resolves the AutoApprovers to a map of netip.Prefix to netipx.IPSet. +// The resulting map can be used to quickly look up if a node can self-approve a route. +// It is intended for internal use in a PolicyManager. +func resolveAutoApprovers(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (map[netip.Prefix]*netipx.IPSet, *netipx.IPSet, error) { + if p == nil { + return nil, nil, nil + } + var err error + + routes := make(map[netip.Prefix]*netipx.IPSetBuilder) + + for prefix, autoApprovers := range p.AutoApprovers.Routes { + if _, ok := routes[prefix]; !ok { + routes[prefix] = new(netipx.IPSetBuilder) + } + for _, autoApprover := range autoApprovers { + aa, ok := autoApprover.(Alias) + if !ok { + // Should never happen + return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + } + // If it does not resolve, that means the autoApprover is not associated with any IP addresses. + ips, _ := aa.Resolve(p, users, nodes) + routes[prefix].AddSet(ips) + } + } + + var exitNodeSetBuilder netipx.IPSetBuilder + if len(p.AutoApprovers.ExitNode) > 0 { + for _, autoApprover := range p.AutoApprovers.ExitNode { + aa, ok := autoApprover.(Alias) + if !ok { + // Should never happen + return nil, nil, fmt.Errorf("autoApprover %v is not an Alias", autoApprover) + } + // If it does not resolve, that means the autoApprover is not associated with any IP addresses. + ips, _ := aa.Resolve(p, users, nodes) + exitNodeSetBuilder.AddSet(ips) + } + } + + ret := make(map[netip.Prefix]*netipx.IPSet) + for prefix, builder := range routes { + ipSet, err := builder.IPSet() + if err != nil { + return nil, nil, err + } + ret[prefix] = ipSet + } + + var exitNodeSet *netipx.IPSet + if len(p.AutoApprovers.ExitNode) > 0 { + exitNodeSet, err = exitNodeSetBuilder.IPSet() + if err != nil { + return nil, nil, err + } + } + + return ret, exitNodeSet, nil +} + +// Action represents the action to take for an ACL rule. +type Action string + +const ( + ActionAccept Action = "accept" +) + +// SSHAction represents the action to take for an SSH rule. +type SSHAction string + +const ( + SSHActionAccept SSHAction = "accept" + SSHActionCheck SSHAction = "check" +) + +// String returns the string representation of the Action. 
+func (a Action) String() string { + return string(a) +} + +// UnmarshalJSON implements JSON unmarshaling for Action. +func (a *Action) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + switch str { + case "accept": + *a = ActionAccept + default: + return fmt.Errorf("invalid action %q, must be %q", str, ActionAccept) + } + return nil +} + +// MarshalJSON implements JSON marshaling for Action. +func (a Action) MarshalJSON() ([]byte, error) { + return json.Marshal(string(a)) +} + +// String returns the string representation of the SSHAction. +func (a SSHAction) String() string { + return string(a) +} + +// UnmarshalJSON implements JSON unmarshaling for SSHAction. +func (a *SSHAction) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + switch str { + case "accept": + *a = SSHActionAccept + case "check": + *a = SSHActionCheck + default: + return fmt.Errorf("invalid SSH action %q, must be one of: accept, check", str) + } + return nil +} + +// MarshalJSON implements JSON marshaling for SSHAction. +func (a SSHAction) MarshalJSON() ([]byte, error) { + return json.Marshal(string(a)) +} + +// Protocol represents a network protocol with its IANA number and descriptions. +type Protocol string + +const ( + ProtocolICMP Protocol = "icmp" + ProtocolIGMP Protocol = "igmp" + ProtocolIPv4 Protocol = "ipv4" + ProtocolIPInIP Protocol = "ip-in-ip" + ProtocolTCP Protocol = "tcp" + ProtocolEGP Protocol = "egp" + ProtocolIGP Protocol = "igp" + ProtocolUDP Protocol = "udp" + ProtocolGRE Protocol = "gre" + ProtocolESP Protocol = "esp" + ProtocolAH Protocol = "ah" + ProtocolIPv6ICMP Protocol = "ipv6-icmp" + ProtocolSCTP Protocol = "sctp" + ProtocolFC Protocol = "fc" + ProtocolWildcard Protocol = "*" +) + +// String returns the string representation of the Protocol. +func (p Protocol) String() string { + return string(p) +} + +// Description returns the human-readable description of the Protocol. +func (p Protocol) Description() string { + switch p { + case ProtocolICMP: + return "Internet Control Message Protocol" + case ProtocolIGMP: + return "Internet Group Management Protocol" + case ProtocolIPv4: + return "IPv4 encapsulation" + case ProtocolTCP: + return "Transmission Control Protocol" + case ProtocolEGP: + return "Exterior Gateway Protocol" + case ProtocolIGP: + return "Interior Gateway Protocol" + case ProtocolUDP: + return "User Datagram Protocol" + case ProtocolGRE: + return "Generic Routing Encapsulation" + case ProtocolESP: + return "Encapsulating Security Payload" + case ProtocolAH: + return "Authentication Header" + case ProtocolIPv6ICMP: + return "Internet Control Message Protocol for IPv6" + case ProtocolSCTP: + return "Stream Control Transmission Protocol" + case ProtocolFC: + return "Fibre Channel" + case ProtocolWildcard: + return "Wildcard (not supported - use specific protocol)" + default: + return "Unknown Protocol" + } +} + +// parseProtocol converts a Protocol to its IANA protocol numbers and wildcard requirement. +// Since validation happens during UnmarshalJSON, this method should not fail for valid Protocol values. 
+func (p Protocol) parseProtocol() ([]int, bool) { + switch p { + case "": + // Empty protocol applies to TCP and UDP traffic only + return []int{protocolTCP, protocolUDP}, false + case ProtocolWildcard: + // Wildcard protocol - defensive handling (should not reach here due to validation) + return nil, false + case ProtocolIGMP: + return []int{protocolIGMP}, true + case ProtocolIPv4, ProtocolIPInIP: + return []int{protocolIPv4}, true + case ProtocolTCP: + return []int{protocolTCP}, false + case ProtocolEGP: + return []int{protocolEGP}, true + case ProtocolIGP: + return []int{protocolIGP}, true + case ProtocolUDP: + return []int{protocolUDP}, false + case ProtocolGRE: + return []int{protocolGRE}, true + case ProtocolESP: + return []int{protocolESP}, true + case ProtocolAH: + return []int{protocolAH}, true + case ProtocolSCTP: + return []int{protocolSCTP}, false + case ProtocolICMP: + return []int{protocolICMP, protocolIPv6ICMP}, true + default: + // Try to parse as a numeric protocol number + // This should not fail since validation happened during unmarshaling + protocolNumber, _ := strconv.Atoi(string(p)) + + // Determine if wildcard is needed based on protocol number + needsWildcard := protocolNumber != protocolTCP && + protocolNumber != protocolUDP && + protocolNumber != protocolSCTP + + return []int{protocolNumber}, needsWildcard + } +} + +// UnmarshalJSON implements JSON unmarshaling for Protocol. +func (p *Protocol) UnmarshalJSON(b []byte) error { + str := strings.Trim(string(b), `"`) + + // Normalize to lowercase for case-insensitive matching + *p = Protocol(strings.ToLower(str)) + + // Validate the protocol + if err := p.validate(); err != nil { + return err + } + + return nil +} + +// validate checks if the Protocol is valid. +func (p Protocol) validate() error { + switch p { + case "", ProtocolICMP, ProtocolIGMP, ProtocolIPv4, ProtocolIPInIP, + ProtocolTCP, ProtocolEGP, ProtocolIGP, ProtocolUDP, ProtocolGRE, + ProtocolESP, ProtocolAH, ProtocolSCTP: + return nil + case ProtocolWildcard: + // Wildcard "*" is not allowed - Tailscale rejects it + return fmt.Errorf("proto name \"*\" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)") + default: + // Try to parse as a numeric protocol number + str := string(p) + + // Check for leading zeros (not allowed by Tailscale) + if str == "0" || (len(str) > 1 && str[0] == '0') { + return fmt.Errorf("leading 0 not permitted in protocol number \"%s\"", str) + } + + protocolNumber, err := strconv.Atoi(str) + if err != nil { + return fmt.Errorf("invalid protocol %q: must be a known protocol name or valid protocol number 0-255", p) + } + + if protocolNumber < 0 || protocolNumber > 255 { + return fmt.Errorf("protocol number %d out of range (0-255)", protocolNumber) + } + + return nil + } +} + +// MarshalJSON implements JSON marshaling for Protocol. 
+func (p Protocol) MarshalJSON() ([]byte, error) {
+	return json.Marshal(string(p))
+}
+
+// Protocol constants matching the IANA numbers
+const (
+	protocolICMP = 1 // Internet Control Message
+	protocolIGMP = 2 // Internet Group Management
+	protocolIPv4 = 4 // IPv4 encapsulation
+	protocolTCP = 6 // Transmission Control
+	protocolEGP = 8 // Exterior Gateway Protocol
+	protocolIGP = 9 // any private interior gateway (used by Cisco for their IGRP)
+	protocolUDP = 17 // User Datagram
+	protocolGRE = 47 // Generic Routing Encapsulation
+	protocolESP = 50 // Encap Security Payload
+	protocolAH = 51 // Authentication Header
+	protocolIPv6ICMP = 58 // ICMP for IPv6
+	protocolSCTP = 132 // Stream Control Transmission Protocol
+	protocolFC = 133 // Fibre Channel
+)
+
+type ACL struct {
+	Action Action `json:"action"`
+	Protocol Protocol `json:"proto"`
+	Sources Aliases `json:"src"`
+	Destinations []AliasWithPorts `json:"dst"`
+}
+
+// UnmarshalJSON implements custom unmarshalling for ACL that ignores fields starting with '#'.
+// headscale-admin uses # in some field names to add metadata, so we will ignore
+// those to ensure it doesn't break.
+// https://github.com/GoodiesHQ/headscale-admin/blob/214a44a9c15c92d2b42383f131b51df10c84017c/src/lib/common/acl.svelte.ts#L38
+func (a *ACL) UnmarshalJSON(b []byte) error {
+	// First unmarshal into a map to filter out comment fields
+	var raw map[string]any
+	if err := json.Unmarshal(b, &raw, policyJSONOpts...); err != nil {
+		return err
+	}
+
+	// Remove any fields that start with '#'
+	filtered := make(map[string]any)
+	for key, value := range raw {
+		if !strings.HasPrefix(key, "#") {
+			filtered[key] = value
+		}
+	}
+
+	// Marshal the filtered map back to JSON
+	filteredBytes, err := json.Marshal(filtered)
+	if err != nil {
+		return err
+	}
+
+	// Create a type alias to avoid infinite recursion
+	type aclAlias ACL
+	var temp aclAlias
+
+	// Unmarshal into the temporary struct using the v2 JSON options
+	if err := json.Unmarshal(filteredBytes, &temp, policyJSONOpts...); err != nil {
+		return err
+	}
+
+	// Copy the result back to the original struct
+	*a = ACL(temp)
+	return nil
+}
+
+// Policy represents a Tailscale Network Policy.
+// TODO(kradalby):
+// Add validation method checking:
+// All users exist
+// All groups and users are valid tag TagOwners
+// Everything referred to in ACLs exists in other
+// entities.
+type Policy struct {
+	// validated is set if the policy has been validated.
+	// It is not safe to use before it is validated, and
+	// callers using it should panic if not
+	validated bool `json:"-"`
+
+	Groups Groups `json:"groups,omitempty"`
+	Hosts Hosts `json:"hosts,omitempty"`
+	TagOwners TagOwners `json:"tagOwners,omitempty"`
+	ACLs []ACL `json:"acls,omitempty"`
+	AutoApprovers AutoApproverPolicy `json:"autoApprovers"`
+	SSHs []SSH `json:"ssh,omitempty"`
+}
+
+// MarshalJSON is deliberately not implemented for Policy.
+// We use the default JSON marshalling behavior provided by the Go runtime.
+
+var (
+	// TODO(kradalby): Add these checks for tagOwners and autoApprovers.
+	autogroupForSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
+	autogroupForDst = []AutoGroup{AutoGroupInternet, AutoGroupMember, AutoGroupTagged, AutoGroupSelf}
+	autogroupForSSHSrc = []AutoGroup{AutoGroupMember, AutoGroupTagged}
+	autogroupForSSHDst = []AutoGroup{AutoGroupMember, AutoGroupTagged, AutoGroupSelf}
+	autogroupForSSHUser = []AutoGroup{AutoGroupNonRoot}
+	autogroupNotSupported = []AutoGroup{}
+)
+
+func validateAutogroupSupported(ag *AutoGroup) error {
+	if ag == nil {
+		return nil
+	}
+
+	if slices.Contains(autogroupNotSupported, *ag) {
+		return fmt.Errorf("autogroup %q is not supported in headscale", *ag)
+	}
+
+	return nil
+}
+
+func validateAutogroupForSrc(src *AutoGroup) error {
+	if src == nil {
+		return nil
+	}
+
+	if src.Is(AutoGroupInternet) {
+		return errors.New(`"autogroup:internet" used in source, it can only be used in ACL destinations`)
+	}
+
+	if src.Is(AutoGroupSelf) {
+		return errors.New(`"autogroup:self" used in source, it can only be used in ACL destinations`)
+	}
+
+	if !slices.Contains(autogroupForSrc, *src) {
+		return fmt.Errorf("autogroup %q is not supported for ACL sources, can be %v", *src, autogroupForSrc)
+	}
+
+	return nil
+}
+
+func validateAutogroupForDst(dst *AutoGroup) error {
+	if dst == nil {
+		return nil
+	}
+
+	if !slices.Contains(autogroupForDst, *dst) {
+		return fmt.Errorf("autogroup %q is not supported for ACL destinations, can be %v", *dst, autogroupForDst)
+	}
+
+	return nil
+}
+
+func validateAutogroupForSSHSrc(src *AutoGroup) error {
+	if src == nil {
+		return nil
+	}
+
+	if src.Is(AutoGroupInternet) {
+		return errors.New(`"autogroup:internet" used in SSH source, it can only be used in ACL destinations`)
+	}
+
+	if !slices.Contains(autogroupForSSHSrc, *src) {
+		return fmt.Errorf("autogroup %q is not supported for SSH sources, can be %v", *src, autogroupForSSHSrc)
+	}
+
+	return nil
+}
+
+func validateAutogroupForSSHDst(dst *AutoGroup) error {
+	if dst == nil {
+		return nil
+	}
+
+	if dst.Is(AutoGroupInternet) {
+		return errors.New(`"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`)
+	}
+
+	if !slices.Contains(autogroupForSSHDst, *dst) {
+		return fmt.Errorf("autogroup %q is not supported for SSH destinations, can be %v", *dst, autogroupForSSHDst)
+	}
+
+	return nil
+}
+
+func validateAutogroupForSSHUser(user *AutoGroup) error {
+	if user == nil {
+		return nil
+	}
+
+	if !slices.Contains(autogroupForSSHUser, *user) {
+		return fmt.Errorf("autogroup %q is not supported for SSH user, can be %v", *user, autogroupForSSHUser)
+	}
+
+	return nil
+}
+
+// validateSSHSrcDstCombination validates that SSH source/destination combinations
+// follow Tailscale's security model:
+// - Destination can be: tags, autogroup:self (if source is users/groups), or same-user
+// - Tags/autogroup:tagged CANNOT SSH to user destinations
+// - Username destinations require the source to be that same single user only.
+func validateSSHSrcDstCombination(sources SSHSrcAliases, destinations SSHDstAliases) error { + // Categorize source types + srcHasTaggedEntities := false + srcHasGroups := false + srcUsernames := make(map[string]bool) + + for _, src := range sources { + switch v := src.(type) { + case *Tag: + srcHasTaggedEntities = true + case *AutoGroup: + if v.Is(AutoGroupTagged) { + srcHasTaggedEntities = true + } else if v.Is(AutoGroupMember) { + srcHasGroups = true // autogroup:member is like a group of users + } + case *Group: + srcHasGroups = true + case *Username: + srcUsernames[string(*v)] = true + } + } + + // Check destinations against source constraints + for _, dst := range destinations { + switch v := dst.(type) { + case *Username: + // Rule: Tags/autogroup:tagged CANNOT SSH to user destinations + if srcHasTaggedEntities { + return fmt.Errorf("%w (%s); use autogroup:tagged or specific tags as destinations instead", + ErrSSHTagSourceToUserDest, *v) + } + // Rule: Username destination requires source to be that same single user only + if srcHasGroups || len(srcUsernames) != 1 || !srcUsernames[string(*v)] { + return fmt.Errorf("%w %q; use autogroup:self instead for same-user SSH access", + ErrSSHUserDestRequiresSameUser, *v) + } + case *AutoGroup: + // Rule: autogroup:self requires source to NOT contain tags + if v.Is(AutoGroupSelf) && srcHasTaggedEntities { + return ErrSSHAutogroupSelfRequiresUserSource + } + // Rule: autogroup:member (user-owned devices) cannot be accessed by tagged entities + if v.Is(AutoGroupMember) && srcHasTaggedEntities { + return ErrSSHTagSourceToAutogroupMember + } + } + } + + return nil +} + +// validate reports if there are any errors in a policy after +// the unmarshaling process. +// It runs through all rules and checks if there are any inconsistencies +// in the policy that needs to be addressed before it can be used. +func (p *Policy) validate() error { + if p == nil { + panic("passed nil policy") + } + + // All errors are collected and presented to the user, + // when adding more validation, please add to the list of errors. 
+ var errs []error + + for _, acl := range p.ACLs { + for _, src := range acl.Sources { + switch src := src.(type) { + case *Host: + h := src + if !p.Hosts.exist(*h) { + errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + } + case *AutoGroup: + ag := src + + if err := validateAutogroupSupported(ag); err != nil { + errs = append(errs, err) + continue + } + + if err := validateAutogroupForSrc(ag); err != nil { + errs = append(errs, err) + continue + } + case *Group: + g := src + if err := p.Groups.Contains(g); err != nil { + errs = append(errs, err) + } + case *Tag: + tagOwner := src + if err := p.TagOwners.Contains(tagOwner); err != nil { + errs = append(errs, err) + } + } + } + + for _, dst := range acl.Destinations { + switch dst.Alias.(type) { + case *Host: + h := dst.Alias.(*Host) + if !p.Hosts.exist(*h) { + errs = append(errs, fmt.Errorf(`Host %q is not defined in the Policy, please define or remove the reference to it`, *h)) + } + case *AutoGroup: + ag := dst.Alias.(*AutoGroup) + + if err := validateAutogroupSupported(ag); err != nil { + errs = append(errs, err) + continue + } + + if err := validateAutogroupForDst(ag); err != nil { + errs = append(errs, err) + continue + } + case *Group: + g := dst.Alias.(*Group) + if err := p.Groups.Contains(g); err != nil { + errs = append(errs, err) + } + case *Tag: + tagOwner := dst.Alias.(*Tag) + if err := p.TagOwners.Contains(tagOwner); err != nil { + errs = append(errs, err) + } + } + } + + // Validate protocol-port compatibility + if err := validateProtocolPortCompatibility(acl.Protocol, acl.Destinations); err != nil { + errs = append(errs, err) + } + } + + for _, ssh := range p.SSHs { + for _, user := range ssh.Users { + if strings.HasPrefix(string(user), "autogroup:") { + maybeAuto := AutoGroup(user) + if err := validateAutogroupForSSHUser(&maybeAuto); err != nil { + errs = append(errs, err) + continue + } + } + } + + for _, src := range ssh.Sources { + switch src := src.(type) { + case *AutoGroup: + ag := src + + if err := validateAutogroupSupported(ag); err != nil { + errs = append(errs, err) + continue + } + + if err := validateAutogroupForSSHSrc(ag); err != nil { + errs = append(errs, err) + continue + } + case *Group: + g := src + if err := p.Groups.Contains(g); err != nil { + errs = append(errs, err) + } + case *Tag: + tagOwner := src + if err := p.TagOwners.Contains(tagOwner); err != nil { + errs = append(errs, err) + } + } + } + for _, dst := range ssh.Destinations { + switch dst := dst.(type) { + case *AutoGroup: + ag := dst + if err := validateAutogroupSupported(ag); err != nil { + errs = append(errs, err) + continue + } + + if err := validateAutogroupForSSHDst(ag); err != nil { + errs = append(errs, err) + continue + } + case *Tag: + tagOwner := dst + if err := p.TagOwners.Contains(tagOwner); err != nil { + errs = append(errs, err) + } + } + } + + // Validate SSH source/destination combinations follow Tailscale's security model + err := validateSSHSrcDstCombination(ssh.Sources, ssh.Destinations) + if err != nil { + errs = append(errs, err) + } + } + + for _, tagOwners := range p.TagOwners { + for _, tagOwner := range tagOwners { + switch tagOwner := tagOwner.(type) { + case *Group: + g := tagOwner + if err := p.Groups.Contains(g); err != nil { + errs = append(errs, err) + } + case *Tag: + t := tagOwner + + err := p.TagOwners.Contains(t) + if err != nil { + errs = append(errs, err) + } + } + } + } + + // Validate tag ownership chains for circular references and 
undefined tags. + _, err := flattenTagOwners(p.TagOwners) + if err != nil { + errs = append(errs, err) + } + + for _, approvers := range p.AutoApprovers.Routes { + for _, approver := range approvers { + switch approver := approver.(type) { + case *Group: + g := approver + if err := p.Groups.Contains(g); err != nil { + errs = append(errs, err) + } + case *Tag: + tagOwner := approver + if err := p.TagOwners.Contains(tagOwner); err != nil { + errs = append(errs, err) + } + } + } + } + + for _, approver := range p.AutoApprovers.ExitNode { + switch approver := approver.(type) { + case *Group: + g := approver + if err := p.Groups.Contains(g); err != nil { + errs = append(errs, err) + } + case *Tag: + tagOwner := approver + if err := p.TagOwners.Contains(tagOwner); err != nil { + errs = append(errs, err) + } + } + } + + if len(errs) > 0 { + return multierr.New(errs...) + } + + p.validated = true + + return nil +} + +// SSH controls who can ssh into which machines. +type SSH struct { + Action SSHAction `json:"action"` + Sources SSHSrcAliases `json:"src"` + Destinations SSHDstAliases `json:"dst"` + Users SSHUsers `json:"users"` + CheckPeriod model.Duration `json:"checkPeriod,omitempty"` +} + +// SSHSrcAliases is a list of aliases that can be used as sources in an SSH rule. +// It can be a list of usernames, groups, tags or autogroups. +type SSHSrcAliases []Alias + +// MarshalJSON marshals the Groups to JSON. +func (g Groups) MarshalJSON() ([]byte, error) { + if g == nil { + return []byte("{}"), nil + } + + raw := make(map[string][]string) + for group, usernames := range g { + users := make([]string, len(usernames)) + for i, username := range usernames { + users[i] = string(username) + } + raw[string(group)] = users + } + + return json.Marshal(raw) +} + +func (a *SSHSrcAliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + switch alias.Alias.(type) { + case *Username, *Group, *Tag, *AutoGroup: + (*a)[i] = alias.Alias + default: + return fmt.Errorf( + "alias %T is not supported for SSH source", + alias.Alias, + ) + } + } + + return nil +} + +func (a *SSHDstAliases) UnmarshalJSON(b []byte) error { + var aliases []AliasEnc + err := json.Unmarshal(b, &aliases, policyJSONOpts...) + if err != nil { + return err + } + + *a = make([]Alias, len(aliases)) + for i, alias := range aliases { + switch alias.Alias.(type) { + case *Username, *Tag, *AutoGroup, *Host: + (*a)[i] = alias.Alias + case Asterix: + return fmt.Errorf("%w; use 'autogroup:member' for user-owned devices, "+ + "'autogroup:tagged' for tagged devices, or specific tags/users", + ErrSSHWildcardDestination) + default: + return fmt.Errorf( + "alias %T is not supported for SSH destination", + alias.Alias, + ) + } + } + + return nil +} + +// MarshalJSON marshals the SSHDstAliases to JSON. 
+func (a SSHDstAliases) MarshalJSON() ([]byte, error) { + if a == nil { + return []byte("[]"), nil + } + + aliases := make([]string, len(a)) + for i, alias := range a { + switch v := alias.(type) { + case *Username: + aliases[i] = string(*v) + case *Tag: + aliases[i] = string(*v) + case *AutoGroup: + aliases[i] = string(*v) + case *Host: + aliases[i] = string(*v) + case Asterix: + // Marshal wildcard as "*" so it gets rejected during unmarshal + // with a proper error message explaining alternatives + aliases[i] = "*" + default: + return nil, fmt.Errorf("unknown SSH destination alias type: %T", v) + } + } + + return json.Marshal(aliases) +} + +// MarshalJSON marshals the SSHSrcAliases to JSON. +func (a SSHSrcAliases) MarshalJSON() ([]byte, error) { + if a == nil { + return []byte("[]"), nil + } + + aliases := make([]string, len(a)) + for i, alias := range a { + switch v := alias.(type) { + case *Username: + aliases[i] = string(*v) + case *Group: + aliases[i] = string(*v) + case *Tag: + aliases[i] = string(*v) + case *AutoGroup: + aliases[i] = string(*v) + case Asterix: + aliases[i] = "*" + default: + return nil, fmt.Errorf("unknown SSH source alias type: %T", v) + } + } + + return json.Marshal(aliases) +} + +func (a SSHSrcAliases) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeView]) (*netipx.IPSet, error) { + var ips netipx.IPSetBuilder + var errs []error + + for _, alias := range a { + aips, err := alias.Resolve(p, users, nodes) + if err != nil { + errs = append(errs, err) + } + + ips.AddSet(aips) + } + + return buildIPSetMultiErr(&ips, errs) +} + +// SSHDstAliases is a list of aliases that can be used as destinations in an SSH rule. +// It can be a list of usernames, tags or autogroups. +type SSHDstAliases []Alias + +type SSHUsers []SSHUser + +func (u SSHUsers) ContainsRoot() bool { + return slices.Contains(u, "root") +} + +func (u SSHUsers) ContainsNonRoot() bool { + return slices.Contains(u, SSHUser(AutoGroupNonRoot)) +} + +func (u SSHUsers) NormalUsers() []SSHUser { + return slicesx.Filter(nil, u, func(user SSHUser) bool { + return user != "root" && user != SSHUser(AutoGroupNonRoot) + }) +} + +type SSHUser string + +func (u SSHUser) String() string { + return string(u) +} + +// MarshalJSON marshals the SSHUser to JSON. +func (u SSHUser) MarshalJSON() ([]byte, error) { + return json.Marshal(string(u)) +} + +// unmarshalPolicy takes a byte slice and unmarshals it into a Policy struct. +// In addition to unmarshalling, it will also validate the policy. +// This is the only entrypoint of reading a policy from a file or other source. +func unmarshalPolicy(b []byte) (*Policy, error) { + if len(b) == 0 { + return nil, nil + } + + var policy Policy + ast, err := hujson.Parse(b) + if err != nil { + return nil, fmt.Errorf("parsing HuJSON: %w", err) + } + + ast.Standardize() + if err = json.Unmarshal(ast.Pack(), &policy, policyJSONOpts...); err != nil { + var serr *json.SemanticError + if errors.As(err, &serr) && serr.Err == json.ErrUnknownName { + ptr := serr.JSONPointer + name := ptr.LastToken() + return nil, fmt.Errorf("unknown field %q", name) + } + return nil, fmt.Errorf("parsing policy from bytes: %w", err) + } + + if err := policy.validate(); err != nil { + return nil, err + } + + return &policy, nil +} + +// validateProtocolPortCompatibility checks that only TCP, UDP, and SCTP protocols +// can have specific ports. All other protocols should only use wildcard ports. 
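+//
+// Illustrative examples (mirroring cases exercised in types_test.go):
+//
+//	{"action": "accept", "proto": "icmp", "src": ["*"], "dst": ["*:80"]} // rejected: icmp cannot use a specific port
+//	{"action": "accept", "proto": "icmp", "src": ["*"], "dst": ["*:*"]}  // allowed: wildcard port
+//	{"action": "accept", "proto": "tcp",  "src": ["*"], "dst": ["*:80"]} // allowed: tcp supports specific ports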
+func validateProtocolPortCompatibility(protocol Protocol, destinations []AliasWithPorts) error { + // Only TCP, UDP, and SCTP support specific ports + supportsSpecificPorts := protocol == ProtocolTCP || protocol == ProtocolUDP || protocol == ProtocolSCTP || protocol == "" + + if supportsSpecificPorts { + return nil // No validation needed for these protocols + } + + // For all other protocols, check that all destinations use wildcard ports + for _, dst := range destinations { + for _, portRange := range dst.Ports { + // Check if it's not a wildcard port (0-65535) + if !(portRange.First == 0 && portRange.Last == 65535) { + return fmt.Errorf("protocol %q does not support specific ports; only \"*\" is allowed", protocol) + } + } + } + + return nil +} + +// usesAutogroupSelf checks if the policy uses autogroup:self in any ACL or SSH rules. +func (p *Policy) usesAutogroupSelf() bool { + if p == nil { + return false + } + + // Check ACL rules + for _, acl := range p.ACLs { + for _, src := range acl.Sources { + if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { + return true + } + } + for _, dest := range acl.Destinations { + if ag, ok := dest.Alias.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { + return true + } + } + } + + // Check SSH rules + for _, ssh := range p.SSHs { + for _, src := range ssh.Sources { + if ag, ok := src.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { + return true + } + } + for _, dest := range ssh.Destinations { + if ag, ok := dest.(*AutoGroup); ok && ag.Is(AutoGroupSelf) { + return true + } + } + } + + return false +} diff --git a/hscontrol/policy/v2/types_test.go b/hscontrol/policy/v2/types_test.go new file mode 100644 index 00000000..542c9b2c --- /dev/null +++ b/hscontrol/policy/v2/types_test.go @@ -0,0 +1,3641 @@ +package v2 + +import ( + "encoding/json" + "net/netip" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/prometheus/common/model" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go4.org/netipx" + xmaps "golang.org/x/exp/maps" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +// TestUnmarshalPolicy tests the unmarshalling of JSON into Policy objects and the marshalling +// back to JSON (round-trip). It ensures that: +// 1. JSON can be correctly unmarshalled into a Policy object +// 2. A Policy object can be correctly marshalled back to JSON +// 3. The unmarshalled Policy matches the expected Policy +// 4. The marshalled and then unmarshalled Policy is semantically equivalent to the original +// (accounting for nil vs empty map/slice differences) +// +// This test also verifies that all the required struct fields are properly marshalled and +// unmarshalled, maintaining semantic equivalence through a complete JSON round-trip. + +// TestMarshalJSON tests explicit marshalling of Policy objects to JSON. +// This test ensures our custom MarshalJSON methods properly encode +// the various data structures used in the Policy. 
+func TestMarshalJSON(t *testing.T) { + // Create a complex test policy + policy := &Policy{ + Groups: Groups{ + Group("group:example"): []Username{Username("user@example.com")}, + }, + Hosts: Hosts{ + "host-1": Prefix(mp("100.100.100.100/32")), + }, + TagOwners: TagOwners{ + Tag("tag:test"): Owners{up("user@example.com")}, + }, + ACLs: []ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + ptr.To(Username("user@example.com")), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Username("other@example.com")), + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + }, + } + + // Marshal the policy to JSON + marshalled, err := json.MarshalIndent(policy, "", " ") + require.NoError(t, err) + + // Make sure all expected fields are present in the JSON + jsonString := string(marshalled) + assert.Contains(t, jsonString, "group:example") + assert.Contains(t, jsonString, "user@example.com") + assert.Contains(t, jsonString, "host-1") + assert.Contains(t, jsonString, "100.100.100.100/32") + assert.Contains(t, jsonString, "tag:test") + assert.Contains(t, jsonString, "accept") + assert.Contains(t, jsonString, "tcp") + assert.Contains(t, jsonString, "80") + + // Unmarshal back to verify round trip + var roundTripped Policy + err = json.Unmarshal(marshalled, &roundTripped) + require.NoError(t, err) + + // Compare the original and round-tripped policies + cmps := append(util.Comparers, + cmp.Comparer(func(x, y Prefix) bool { + return x == y + }), + cmpopts.IgnoreUnexported(Policy{}), + cmpopts.EquateEmpty(), + ) + + if diff := cmp.Diff(policy, &roundTripped, cmps...); diff != "" { + t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff) + } +} + +func TestUnmarshalPolicy(t *testing.T) { + tests := []struct { + name string + input string + want *Policy + wantErr string + }{ + { + name: "empty", + input: "{}", + want: &Policy{}, + }, + { + name: "groups", + input: ` +{ + "groups": { + "group:example": [ + "derp@headscale.net", + ], + }, +} +`, + want: &Policy{ + Groups: Groups{ + Group("group:example"): []Username{Username("derp@headscale.net")}, + }, + }, + }, + { + name: "basic-types", + input: ` +{ + "groups": { + "group:example": [ + "testuser@headscale.net", + ], + "group:other": [ + "otheruser@headscale.net", + ], + "group:noat": [ + "noat@", + ], + }, + + "tagOwners": { + "tag:user": ["testuser@headscale.net"], + "tag:group": ["group:other"], + "tag:userandgroup": ["testuser@headscale.net", "group:other"], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + "outside": "192.168.0.0/16", + }, + + "acls": [ + // All + { + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"], + }, + // Users + { + "action": "accept", + "proto": "tcp", + "src": ["testuser@headscale.net"], + "dst": ["otheruser@headscale.net:80"], + }, + // Groups + { + "action": "accept", + "proto": "tcp", + "src": ["group:example"], + "dst": ["group:other:80"], + }, + // Tailscale IP + { + "action": "accept", + "proto": "tcp", + "src": ["100.101.102.103"], + "dst": ["100.101.102.104:80"], + }, + // Subnet + { + "action": "accept", + "proto": "udp", + "src": ["10.0.0.0/8"], + "dst": ["172.16.0.0/16:80"], + }, + // Hosts + { + "action": "accept", + "proto": "tcp", + "src": ["subnet-1"], + "dst": ["host-1:80-88"], + }, + // Tags + { + "action": "accept", + "proto": "tcp", + "src": ["tag:group"], + "dst": ["tag:user:80,443"], + }, + // Autogroup + { + "action": "accept", + "proto": "tcp", + "src": ["tag:group"], + "dst": 
["autogroup:internet:80"], + }, + ], +} +`, + want: &Policy{ + Groups: Groups{ + Group("group:example"): []Username{Username("testuser@headscale.net")}, + Group("group:other"): []Username{Username("otheruser@headscale.net")}, + Group("group:noat"): []Username{Username("noat@")}, + }, + TagOwners: TagOwners{ + Tag("tag:user"): Owners{up("testuser@headscale.net")}, + Tag("tag:group"): Owners{gp("group:other")}, + Tag("tag:userandgroup"): Owners{up("testuser@headscale.net"), gp("group:other")}, + }, + Hosts: Hosts{ + "host-1": Prefix(mp("100.100.100.100/32")), + "subnet-1": Prefix(mp("100.100.101.100/24")), + "outside": Prefix(mp("192.168.0.0/16")), + }, + ACLs: []ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + // TODO(kradalby): Should this be host? + // It is: + // Includes any destination (no restrictions). + Alias: Wildcard, + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + ptr.To(Username("testuser@headscale.net")), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Username("otheruser@headscale.net")), + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + gp("group:example"), + }, + Destinations: []AliasWithPorts{ + { + Alias: gp("group:other"), + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + pp("100.101.102.103/32"), + }, + Destinations: []AliasWithPorts{ + { + Alias: pp("100.101.102.104/32"), + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "udp", + Sources: Aliases{ + pp("10.0.0.0/8"), + }, + Destinations: []AliasWithPorts{ + { + Alias: pp("172.16.0.0/16"), + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + hp("subnet-1"), + }, + Destinations: []AliasWithPorts{ + { + Alias: hp("host-1"), + Ports: []tailcfg.PortRange{{First: 80, Last: 88}}, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + tp("tag:group"), + }, + Destinations: []AliasWithPorts{ + { + Alias: tp("tag:user"), + Ports: []tailcfg.PortRange{ + {First: 80, Last: 80}, + {First: 443, Last: 443}, + }, + }, + }, + }, + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + tp("tag:group"), + }, + Destinations: []AliasWithPorts{ + { + Alias: agp("autogroup:internet"), + Ports: []tailcfg.PortRange{ + {First: 80, Last: 80}, + }, + }, + }, + }, + }, + }, + }, + { + name: "2652-asterix-error-better-explain", + input: ` +{ + "ssh": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "*" + ], + "users": ["root"] + } + ] +} + `, + wantErr: "alias v2.Asterix is not supported for SSH source", + }, + { + name: "invalid-username", + input: ` +{ + "groups": { + "group:example": [ + "valid@", + "invalid", + ], + }, +} +`, + wantErr: `Username has to contain @, got: "invalid"`, + }, + { + name: "invalid-group", + input: ` +{ + "groups": { + "grou:example": [ + "valid@", + ], + }, +} +`, + wantErr: `Group has to start with "group:", got: "grou:example"`, + }, + { + name: "group-in-group", + input: ` +{ + "groups": { + "group:inner": [], + "group:example": [ + "group:inner", + ], + }, +} +`, + // wantErr: `Username has to contain @, got: "group:inner"`, + wantErr: `Nested groups are not allowed, found "group:inner" inside 
"group:example"`, + }, + { + name: "invalid-addr", + input: ` +{ + "hosts": { + "derp": "10.0", + }, +} +`, + wantErr: `Hostname "derp" contains an invalid IP address: "10.0"`, + }, + { + name: "invalid-prefix", + input: ` +{ + "hosts": { + "derp": "10.0/42", + }, +} +`, + wantErr: `Hostname "derp" contains an invalid IP address: "10.0/42"`, + }, + // TODO(kradalby): Figure out why this doesn't work. + // { + // name: "invalid-hostname", + // input: ` + // { + // "hosts": { + // "derp:merp": "10.0.0.0/31", + // }, + // } + // `, + // wantErr: `Hostname "derp:merp" is invalid`, + // }, + { + name: "invalid-auto-group", + input: ` +{ + "acls": [ + // Autogroup + { + "action": "accept", + "proto": "tcp", + "src": ["tag:group"], + "dst": ["autogroup:invalid:80"], + }, + ], +} +`, + wantErr: `AutoGroup is invalid, got: "autogroup:invalid", must be one of [autogroup:internet autogroup:member autogroup:nonroot autogroup:tagged autogroup:self]`, + }, + { + name: "undefined-hostname-errors-2490", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": [ + "user1" + ], + "dst": [ + "user1:*" + ] + } + ] +} +`, + wantErr: `Host "user1" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "defined-hostname-does-not-err-2490", + input: ` +{ + "hosts": { + "user1": "100.100.100.100", + }, + "acls": [ + { + "action": "accept", + "src": [ + "user1" + ], + "dst": [ + "user1:*" + ] + } + ] +} +`, + want: &Policy{ + Hosts: Hosts{ + "user1": Prefix(mp("100.100.100.100/32")), + }, + ACLs: []ACL{ + { + Action: "accept", + Sources: Aliases{ + hp("user1"), + }, + Destinations: []AliasWithPorts{ + { + Alias: hp("user1"), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "autogroup:internet-in-dst-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": [ + "10.0.0.1" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Sources: Aliases{ + pp("10.0.0.1/32"), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(AutoGroup("autogroup:internet")), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "autogroup:internet-in-src-not-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": [ + "autogroup:internet" + ], + "dst": [ + "10.0.0.1:*" + ] + } + ] +} +`, + wantErr: `"autogroup:internet" used in source, it can only be used in ACL destinations`, + }, + { + name: "autogroup:internet-in-ssh-src-not-allowed", + input: ` +{ + "ssh": [ + { + "action": "accept", + "src": [ + "autogroup:internet" + ], + "dst": [ + "tag:test" + ] + } + ] +} +`, + wantErr: `"autogroup:internet" used in SSH source, it can only be used in ACL destinations`, + }, + { + name: "autogroup:internet-in-ssh-dst-not-allowed", + input: ` +{ + "ssh": [ + { + "action": "accept", + "src": [ + "tag:test" + ], + "dst": [ + "autogroup:internet" + ] + } + ] +} +`, + wantErr: `"autogroup:internet" used in SSH destination, it can only be used in ACL destinations`, + }, + { + name: "ssh-basic", + input: ` +{ + "groups": { + "group:admins": ["admin@example.com"] + }, + "tagOwners": { + "tag:servers": ["group:admins"] + }, + "ssh": [ + { + "action": "accept", + "src": [ + "group:admins" + ], + "dst": [ + "tag:servers" + ], + "users": ["root", "admin"] + } + ] +} +`, + want: &Policy{ + Groups: Groups{ + Group("group:admins"): []Username{Username("admin@example.com")}, + }, + TagOwners: TagOwners{ + Tag("tag:servers"): 
Owners{gp("group:admins")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{ + gp("group:admins"), + }, + Destinations: SSHDstAliases{ + tp("tag:servers"), + }, + Users: []SSHUser{ + SSHUser("root"), + SSHUser("admin"), + }, + }, + }, + }, + }, + { + name: "ssh-with-tag-and-user", + input: ` +{ + "tagOwners": { + "tag:web": ["admin@example.com"], + "tag:server": ["admin@example.com"] + }, + "ssh": [ + { + "action": "accept", + "src": [ + "tag:web" + ], + "dst": [ + "tag:server" + ], + "users": ["*"] + } + ] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:web"): Owners{ptr.To(Username("admin@example.com"))}, + Tag("tag:server"): Owners{ptr.To(Username("admin@example.com"))}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{ + tp("tag:web"), + }, + Destinations: SSHDstAliases{ + tp("tag:server"), + }, + Users: []SSHUser{ + SSHUser("*"), + }, + }, + }, + }, + }, + { + name: "ssh-with-check-period", + input: ` +{ + "groups": { + "group:admins": ["admin@example.com"] + }, + "ssh": [ + { + "action": "accept", + "src": [ + "group:admins" + ], + "dst": [ + "autogroup:self" + ], + "users": ["root"], + "checkPeriod": "24h" + } + ] +} +`, + want: &Policy{ + Groups: Groups{ + Group("group:admins"): []Username{Username("admin@example.com")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{ + gp("group:admins"), + }, + Destinations: SSHDstAliases{ + agp("autogroup:self"), + }, + Users: []SSHUser{ + SSHUser("root"), + }, + CheckPeriod: model.Duration(24 * time.Hour), + }, + }, + }, + }, + { + name: "group-must-be-defined-acl-src", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": [ + "group:notdefined" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ] +} +`, + wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "group-must-be-defined-acl-dst", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "group:notdefined:*" + ] + } + ] +} +`, + wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "group-must-be-defined-acl-ssh-src", + input: ` +{ + "ssh": [ + { + "action": "accept", + "src": [ + "group:notdefined" + ], + "dst": [ + "user@" + ] + } + ] +} +`, + wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "group-must-be-defined-acl-tagOwner", + input: ` +{ + "tagOwners": { + "tag:test": ["group:notdefined"], + }, +} +`, + wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "group-must-be-defined-acl-autoapprover-route", + input: ` +{ + "autoApprovers": { + "routes": { + "10.0.0.0/16": ["group:notdefined"] + } + }, +} +`, + wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "group-must-be-defined-acl-autoapprover-exitnode", + input: ` +{ + "autoApprovers": { + "exitNode": ["group:notdefined"] + }, +} +`, + wantErr: `Group "group:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "tag-must-be-defined-acl-src", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": [ + "tag:notdefined" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ] +} +`, + wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + 
{ + name: "tag-must-be-defined-acl-dst", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "tag:notdefined:*" + ] + } + ] +} +`, + wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "tag-must-be-defined-acl-ssh-src", + input: ` +{ + "ssh": [ + { + "action": "accept", + "src": [ + "tag:notdefined" + ], + "dst": [ + "user@" + ] + } + ] +} +`, + wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "tag-must-be-defined-acl-ssh-dst", + input: ` +{ + "groups": { + "group:defined": ["user@"], + }, + "ssh": [ + { + "action": "accept", + "src": [ + "group:defined" + ], + "dst": [ + "tag:notdefined", + ], + } + ] +} +`, + wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "tag-must-be-defined-acl-autoapprover-route", + input: ` +{ + "autoApprovers": { + "routes": { + "10.0.0.0/16": ["tag:notdefined"] + } + }, +} +`, + wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "tag-must-be-defined-acl-autoapprover-exitnode", + input: ` +{ + "autoApprovers": { + "exitNode": ["tag:notdefined"] + }, +} +`, + wantErr: `Tag "tag:notdefined" is not defined in the Policy, please define or remove the reference to it`, + }, + { + name: "missing-dst-port-is-err", + input: ` + { + "acls": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "100.64.0.1" + ] + } + ] +} +`, + wantErr: `hostport must contain a colon (":")`, + }, + { + name: "dst-port-zero-is-err", + input: ` + { + "acls": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "100.64.0.1:0" + ] + } + ] +} +`, + wantErr: `first port must be >0, or use '*' for wildcard`, + }, + { + name: "disallow-unsupported-fields", + input: ` +{ + // rules doesnt exists, we have "acls" + "rules": [ + ] +} +`, + wantErr: `unknown field "rules"`, + }, + { + name: "disallow-unsupported-fields-nested", + input: ` +{ + "acls": [ + { "action": "accept", "BAD": ["FOO:BAR:FOO:BAR"], "NOT": ["BAD:BAD:BAD:BAD"] } + ] +} +`, + wantErr: `unknown field`, + }, + { + name: "invalid-group-name", + input: ` +{ + "groups": { + "group:test": ["user@example.com"], + "INVALID_GROUP_FIELD": ["user@example.com"] + } +} +`, + wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`, + }, + { + name: "invalid-group-datatype", + input: ` +{ + "groups": { + "group:test": ["user@example.com"], + "group:invalid": "should fail" + } +} +`, + wantErr: `Group "group:invalid" value must be an array of users, got string: "should fail"`, + }, + { + name: "invalid-group-name-and-datatype-fails-on-name-first", + input: ` +{ + "groups": { + "group:test": ["user@example.com"], + "INVALID_GROUP_FIELD": "should fail" + } +} +`, + wantErr: `Group has to start with "group:", got: "INVALID_GROUP_FIELD"`, + }, + { + name: "disallow-unsupported-fields-hosts-level", + input: ` +{ + "hosts": { + "host1": "10.0.0.1", + "INVALID_HOST_FIELD": "should fail" + } +} +`, + wantErr: `Hostname "INVALID_HOST_FIELD" contains an invalid IP address: "should fail"`, + }, + { + name: "disallow-unsupported-fields-tagowners-level", + input: ` +{ + "tagOwners": { + "tag:test": ["user@example.com"], + "INVALID_TAG_FIELD": "should fail" + } +} +`, + wantErr: `tag has to start with "tag:", got: "INVALID_TAG_FIELD"`, + }, + { + name: "disallow-unsupported-fields-acls-level", + input: ` 
+{ + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"], + "INVALID_ACL_FIELD": "should fail" + } + ] +} +`, + wantErr: `unknown field "INVALID_ACL_FIELD"`, + }, + { + name: "disallow-unsupported-fields-ssh-level", + input: ` +{ + "ssh": [ + { + "action": "accept", + "src": ["user@example.com"], + "dst": ["user@example.com"], + "users": ["root"], + "INVALID_SSH_FIELD": "should fail" + } + ] +} +`, + wantErr: `unknown field "INVALID_SSH_FIELD"`, + }, + { + name: "disallow-unsupported-fields-policy-level", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + } + ], + "INVALID_POLICY_FIELD": "should fail at policy level" +} +`, + wantErr: `unknown field "INVALID_POLICY_FIELD"`, + }, + { + name: "disallow-unsupported-fields-autoapprovers-level", + input: ` +{ + "autoApprovers": { + "routes": { + "10.0.0.0/8": ["user@example.com"] + }, + "exitNode": ["user@example.com"], + "INVALID_AUTO_APPROVER_FIELD": "should fail" + } +} +`, + wantErr: `unknown field "INVALID_AUTO_APPROVER_FIELD"`, + }, + // headscale-admin uses # in some field names to add metadata, so we will ignore + // those to ensure it doesnt break. + // https://github.com/GoodiesHQ/headscale-admin/blob/214a44a9c15c92d2b42383f131b51df10c84017c/src/lib/common/acl.svelte.ts#L38 + { + name: "hash-fields-are-allowed-but-ignored", + input: ` +{ + "acls": [ + { + "#ha-test": "SOME VALUE", + "action": "accept", + "src": [ + "10.0.0.1" + ], + "dst": [ + "autogroup:internet:*" + ] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Sources: Aliases{ + pp("10.0.0.1/32"), + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(AutoGroup("autogroup:internet")), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "ssh-asterix-invalid-acl-input", + input: ` +{ + "ssh": [ + { + "action": "accept", + "src": [ + "user@example.com" + ], + "dst": [ + "user@example.com" + ], + "users": ["root"], + "proto": "tcp" + } + ] +} +`, + wantErr: `unknown field "proto"`, + }, + { + name: "protocol-wildcard-not-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "*", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + wantErr: `proto name "*" not known; use protocol number 0-255 or protocol name (icmp, tcp, udp, etc.)`, + }, + { + name: "protocol-case-insensitive-uppercase", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "ICMP", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "icmp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-case-insensitive-mixed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "IcmP", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "icmp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-leading-zero-not-permitted", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "0", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + wantErr: `leading 0 not permitted in protocol number "0"`, + }, + { + name: "protocol-empty-applies-to-tcp-udp-only", + input: ` +{ + "acls": [ + { + "action": "accept", + "src": 
["*"], + "dst": ["*:80"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-icmp-with-specific-port-not-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "icmp", + "src": ["*"], + "dst": ["*:80"] + } + ] +} +`, + wantErr: `protocol "icmp" does not support specific ports; only "*" is allowed`, + }, + { + name: "protocol-icmp-with-wildcard-port-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "icmp", + "src": ["*"], + "dst": ["*:*"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "icmp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-gre-with-specific-port-not-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "gre", + "src": ["*"], + "dst": ["*:443"] + } + ] +} +`, + wantErr: `protocol "gre" does not support specific ports; only "*" is allowed`, + }, + { + name: "protocol-tcp-with-specific-port-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["*:80"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-udp-with-specific-port-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "udp", + "src": ["*"], + "dst": ["*:53"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "udp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{{First: 53, Last: 53}}, + }, + }, + }, + }, + }, + }, + { + name: "protocol-sctp-with-specific-port-allowed", + input: ` +{ + "acls": [ + { + "action": "accept", + "proto": "sctp", + "src": ["*"], + "dst": ["*:9000"] + } + ] +} +`, + want: &Policy{ + ACLs: []ACL{ + { + Action: "accept", + Protocol: "sctp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: Wildcard, + Ports: []tailcfg.PortRange{{First: 9000, Last: 9000}}, + }, + }, + }, + }, + }, + }, + { + name: "tags-can-own-other-tags", + input: ` +{ + "tagOwners": { + "tag:bigbrother": [], + "tag:smallbrother": ["tag:bigbrother"], + }, + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["tag:smallbrother:9000"] + } + ] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:bigbrother"): {}, + Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + }, + ACLs: []ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: Aliases{ + Wildcard, + }, + Destinations: []AliasWithPorts{ + { + Alias: ptr.To(Tag("tag:smallbrother")), + Ports: []tailcfg.PortRange{{First: 9000, Last: 9000}}, + }, + }, + }, + }, + }, + }, + { + name: "tag-owner-references-undefined-tag", + input: ` +{ + "tagOwners": { + "tag:child": ["tag:nonexistent"], + }, +} +`, + wantErr: `tag "tag:child" references undefined tag "tag:nonexistent"`, + }, + // SSH source/destination validation tests (#3009, #3010) + { + name: "ssh-tag-to-user-rejected", + input: ` +{ + "tagOwners": {"tag:server": 
["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["tag:server"], + "dst": ["admin@"], + "users": ["autogroup:nonroot"] + }] +} +`, + wantErr: "tags in SSH source cannot access user-owned devices", + }, + { + name: "ssh-autogroup-tagged-to-user-rejected", + input: ` +{ + "ssh": [{ + "action": "accept", + "src": ["autogroup:tagged"], + "dst": ["admin@"], + "users": ["autogroup:nonroot"] + }] +} +`, + wantErr: "tags in SSH source cannot access user-owned devices", + }, + { + name: "ssh-tag-to-autogroup-self-rejected", + input: ` +{ + "tagOwners": {"tag:server": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["tag:server"], + "dst": ["autogroup:self"], + "users": ["autogroup:nonroot"] + }] +} +`, + wantErr: "autogroup:self destination requires source to contain only users or groups", + }, + { + name: "ssh-group-to-user-rejected", + input: ` +{ + "groups": {"group:admins": ["admin@", "user1@"]}, + "ssh": [{ + "action": "accept", + "src": ["group:admins"], + "dst": ["admin@"], + "users": ["autogroup:nonroot"] + }] +} +`, + wantErr: `user destination requires source to contain only that same user "admin@"`, + }, + { + name: "ssh-same-user-to-user-allowed", + input: ` +{ + "ssh": [{ + "action": "accept", + "src": ["admin@"], + "dst": ["admin@"], + "users": ["autogroup:nonroot"] + }] +} +`, + want: &Policy{ + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{up("admin@")}, + Destinations: SSHDstAliases{up("admin@")}, + Users: []SSHUser{SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + }, + { + name: "ssh-group-to-autogroup-self-allowed", + input: ` +{ + "groups": {"group:admins": ["admin@", "user1@"]}, + "ssh": [{ + "action": "accept", + "src": ["group:admins"], + "dst": ["autogroup:self"], + "users": ["autogroup:nonroot"] + }] +} +`, + want: &Policy{ + Groups: Groups{ + Group("group:admins"): []Username{Username("admin@"), Username("user1@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{agp("autogroup:self")}, + Users: []SSHUser{SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + }, + { + name: "ssh-autogroup-tagged-to-autogroup-member-rejected", + input: ` +{ + "ssh": [{ + "action": "accept", + "src": ["autogroup:tagged"], + "dst": ["autogroup:member"], + "users": ["autogroup:nonroot"] + }] +} +`, + wantErr: "tags in SSH source cannot access autogroup:member", + }, + { + name: "ssh-autogroup-tagged-to-autogroup-tagged-allowed", + input: ` +{ + "ssh": [{ + "action": "accept", + "src": ["autogroup:tagged"], + "dst": ["autogroup:tagged"], + "users": ["autogroup:nonroot"] + }] +} +`, + want: &Policy{ + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:tagged")}, + Destinations: SSHDstAliases{agp("autogroup:tagged")}, + Users: []SSHUser{SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + }, + { + name: "ssh-wildcard-destination-rejected", + input: ` +{ + "groups": {"group:admins": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["group:admins"], + "dst": ["*"], + "users": ["autogroup:nonroot"] + }] +} +`, + wantErr: "wildcard (*) is not supported as SSH destination", + }, + { + name: "ssh-group-to-tag-allowed", + input: ` +{ + "tagOwners": {"tag:server": ["admin@"]}, + "groups": {"group:admins": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["group:admins"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + }] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("admin@")}, + }, + Groups: Groups{ + Group("group:admins"): 
[]Username{Username("admin@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{gp("group:admins")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + }, + { + name: "ssh-user-to-tag-allowed", + input: ` +{ + "tagOwners": {"tag:server": ["admin@"]}, + "ssh": [{ + "action": "accept", + "src": ["admin@"], + "dst": ["tag:server"], + "users": ["autogroup:nonroot"] + }] +} +`, + want: &Policy{ + TagOwners: TagOwners{ + Tag("tag:server"): Owners{up("admin@")}, + }, + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{up("admin@")}, + Destinations: SSHDstAliases{tp("tag:server")}, + Users: []SSHUser{SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + }, + { + name: "ssh-autogroup-member-to-autogroup-tagged-allowed", + input: ` +{ + "ssh": [{ + "action": "accept", + "src": ["autogroup:member"], + "dst": ["autogroup:tagged"], + "users": ["autogroup:nonroot"] + }] +} +`, + want: &Policy{ + SSHs: []SSH{ + { + Action: "accept", + Sources: SSHSrcAliases{agp("autogroup:member")}, + Destinations: SSHDstAliases{agp("autogroup:tagged")}, + Users: []SSHUser{SSHUser(AutoGroupNonRoot)}, + }, + }, + }, + }, + } + + cmps := append(util.Comparers, + cmp.Comparer(func(x, y Prefix) bool { + return x == y + }), + cmpopts.IgnoreUnexported(Policy{}), + ) + + // For round-trip testing, we'll normalize the policies before comparing + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Test unmarshalling + policy, err := unmarshalPolicy([]byte(tt.input)) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("unmarshalling: got %v; want no error", err) + } + } else { + if err == nil { + t.Fatalf("unmarshalling: got nil; want error %q", tt.wantErr) + } else if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("unmarshalling: got err %v; want error %q", err, tt.wantErr) + } + + return // Skip the rest of the test if we expected an error + } + + if diff := cmp.Diff(tt.want, policy, cmps...); diff != "" { + t.Fatalf("unexpected policy (-want +got):\n%s", diff) + } + + // Test round-trip marshalling/unmarshalling + if policy != nil { + // Marshal the policy back to JSON + marshalled, err := json.MarshalIndent(policy, "", " ") + if err != nil { + t.Fatalf("marshalling: %v", err) + } + + // Unmarshal it again + roundTripped, err := unmarshalPolicy(marshalled) + if err != nil { + t.Fatalf("round-trip unmarshalling: %v", err) + } + + // Add EquateEmpty to handle nil vs empty maps/slices + roundTripCmps := append(cmps, + cmpopts.EquateEmpty(), + cmpopts.IgnoreUnexported(Policy{}), + ) + + // Compare using the enhanced comparers for round-trip testing + if diff := cmp.Diff(policy, roundTripped, roundTripCmps...); diff != "" { + t.Fatalf("round trip policy (-original +roundtripped):\n%s", diff) + } + } + }) + } +} + +func gp(s string) *Group { return ptr.To(Group(s)) } +func up(s string) *Username { return ptr.To(Username(s)) } +func hp(s string) *Host { return ptr.To(Host(s)) } +func tp(s string) *Tag { return ptr.To(Tag(s)) } +func agp(s string) *AutoGroup { return ptr.To(AutoGroup(s)) } +func mp(pref string) netip.Prefix { return netip.MustParsePrefix(pref) } +func ap(addr string) *netip.Addr { return ptr.To(netip.MustParseAddr(addr)) } +func pp(pref string) *Prefix { return ptr.To(Prefix(mp(pref))) } +func p(pref string) Prefix { return Prefix(mp(pref)) } + +func TestResolvePolicy(t *testing.T) { + users := map[string]types.User{ + "testuser": {Model: gorm.Model{ID: 1}, Name: "testuser"}, + "groupuser": {Model: 
gorm.Model{ID: 2}, Name: "groupuser"}, + "groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"}, + "groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"}, + "notme": {Model: gorm.Model{ID: 5}, Name: "notme"}, + "testuser2": {Model: gorm.Model{ID: 6}, Name: "testuser2"}, + } + + // Extract users to variables so we can take their addresses + testuser := users["testuser"] + groupuser := users["groupuser"] + groupuser1 := users["groupuser1"] + groupuser2 := users["groupuser2"] + notme := users["notme"] + testuser2 := users["testuser2"] + + tests := []struct { + name string + nodes types.Nodes + pol *Policy + toResolve Alias + want []netip.Prefix + wantErr string + }{ + { + name: "prefix", + toResolve: pp("100.100.101.101/32"), + want: []netip.Prefix{mp("100.100.101.101/32")}, + }, + { + name: "host", + pol: &Policy{ + Hosts: Hosts{ + "testhost": p("100.100.101.102/32"), + }, + }, + toResolve: hp("testhost"), + want: []netip.Prefix{mp("100.100.101.102/32")}, + }, + { + name: "username", + toResolve: ptr.To(Username("testuser@")), + nodes: types.Nodes{ + // Not matching other user + { + User: ptr.To(notme), + IPv4: ap("100.100.101.1"), + }, + // Not matching forced tags + { + User: ptr.To(testuser), + Tags: []string{"tag:anything"}, + IPv4: ap("100.100.101.2"), + }, + // not matching because it's tagged (tags copied from AuthKey) + { + User: ptr.To(testuser), + Tags: []string{"alsotagged"}, + IPv4: ap("100.100.101.3"), + }, + { + User: ptr.To(testuser), + IPv4: ap("100.100.101.103"), + }, + { + User: ptr.To(testuser), + IPv4: ap("100.100.101.104"), + }, + }, + want: []netip.Prefix{mp("100.100.101.103/32"), mp("100.100.101.104/32")}, + }, + { + name: "group", + toResolve: ptr.To(Group("group:testgroup")), + nodes: types.Nodes{ + // Not matching other user + { + User: ptr.To(notme), + IPv4: ap("100.100.101.4"), + }, + // Not matching forced tags + { + User: ptr.To(groupuser), + Tags: []string{"tag:anything"}, + IPv4: ap("100.100.101.5"), + }, + // not matching because it's tagged (tags copied from AuthKey) + { + User: ptr.To(groupuser), + Tags: []string{"tag:alsotagged"}, + IPv4: ap("100.100.101.6"), + }, + { + User: ptr.To(groupuser), + IPv4: ap("100.100.101.203"), + }, + { + User: ptr.To(groupuser), + IPv4: ap("100.100.101.204"), + }, + }, + pol: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"groupuser"}, + "group:othergroup": Usernames{"notmetoo"}, + }, + }, + want: []netip.Prefix{mp("100.100.101.203/32"), mp("100.100.101.204/32")}, + }, + { + name: "tag", + toResolve: tp("tag:test"), + nodes: types.Nodes{ + // Not matching other user + { + User: ptr.To(notme), + IPv4: ap("100.100.101.9"), + }, + // Not matching forced tags + { + Tags: []string{"tag:anything"}, + IPv4: ap("100.100.101.10"), + }, + // not matching pak tag + { + AuthKey: &types.PreAuthKey{ + Tags: []string{"tag:alsotagged"}, + }, + IPv4: ap("100.100.101.11"), + }, + // Not matching forced tags + { + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.234"), + }, + // matching tag (tags copied from AuthKey during registration) + { + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.239"), + }, + }, + // TODO(kradalby): tests handling TagOwners + hostinfo + pol: &Policy{}, + want: []netip.Prefix{mp("100.100.101.234/32"), mp("100.100.101.239/32")}, + }, + { + name: "tag-owned-by-tag-call-child", + toResolve: tp("tag:smallbrother"), + pol: &Policy{ + TagOwners: TagOwners{ + Tag("tag:bigbrother"): {}, + Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + }, + }, + nodes: types.Nodes{ + // Should not 
match as we resolve the "child" tag. + { + Tags: []string{"tag:bigbrother"}, + IPv4: ap("100.100.101.234"), + }, + // Should match. + { + Tags: []string{"tag:smallbrother"}, + IPv4: ap("100.100.101.239"), + }, + }, + want: []netip.Prefix{mp("100.100.101.239/32")}, + }, + { + name: "tag-owned-by-tag-call-parent", + toResolve: tp("tag:bigbrother"), + pol: &Policy{ + TagOwners: TagOwners{ + Tag("tag:bigbrother"): {}, + Tag("tag:smallbrother"): {ptr.To(Tag("tag:bigbrother"))}, + }, + }, + nodes: types.Nodes{ + // Should match - we are resolving "tag:bigbrother" which this node has. + { + Tags: []string{"tag:bigbrother"}, + IPv4: ap("100.100.101.234"), + }, + // Should not match - this node has "tag:smallbrother", not the tag we're resolving. + { + Tags: []string{"tag:smallbrother"}, + IPv4: ap("100.100.101.239"), + }, + }, + want: []netip.Prefix{mp("100.100.101.234/32")}, + }, + { + name: "empty-policy", + toResolve: pp("100.100.101.101/32"), + pol: &Policy{}, + want: []netip.Prefix{mp("100.100.101.101/32")}, + }, + { + name: "invalid-host", + toResolve: hp("invalidhost"), + pol: &Policy{ + Hosts: Hosts{ + "testhost": p("100.100.101.102/32"), + }, + }, + wantErr: `unable to resolve host: "invalidhost"`, + }, + { + name: "multiple-groups", + toResolve: ptr.To(Group("group:testgroup")), + nodes: types.Nodes{ + { + User: ptr.To(groupuser1), + IPv4: ap("100.100.101.203"), + }, + { + User: ptr.To(groupuser2), + IPv4: ap("100.100.101.204"), + }, + }, + pol: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"groupuser1@", "groupuser2@"}, + }, + }, + want: []netip.Prefix{mp("100.100.101.203/32"), mp("100.100.101.204/32")}, + }, + { + name: "autogroup-internet", + toResolve: agp("autogroup:internet"), + want: util.TheInternet().Prefixes(), + }, + { + name: "invalid-username", + toResolve: ptr.To(Username("invaliduser@")), + nodes: types.Nodes{ + { + User: ptr.To(testuser), + IPv4: ap("100.100.101.103"), + }, + }, + wantErr: `user with token "invaliduser@" not found`, + }, + { + name: "invalid-tag", + toResolve: tp("tag:invalid"), + nodes: types.Nodes{ + { + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.234"), + }, + }, + }, + { + name: "ipv6-address", + toResolve: pp("fd7a:115c:a1e0::1/128"), + want: []netip.Prefix{mp("fd7a:115c:a1e0::1/128")}, + }, + { + name: "wildcard-alias", + toResolve: Wildcard, + want: []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}, + }, + { + name: "autogroup-member-comprehensive", + toResolve: ptr.To(AutoGroup(AutoGroupMember)), + nodes: types.Nodes{ + // Node with no tags (should be included - is a member) + { + User: ptr.To(testuser), + IPv4: ap("100.100.101.1"), + }, + // Node with single tag (should be excluded - tagged nodes are not members) + { + User: ptr.To(testuser), + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.2"), + }, + // Node with multiple tags, all defined in policy (should be excluded) + { + User: ptr.To(testuser), + Tags: []string{"tag:test", "tag:other"}, + IPv4: ap("100.100.101.3"), + }, + // Node with tag not defined in policy (should be excluded - still tagged) + { + User: ptr.To(testuser), + Tags: []string{"tag:undefined"}, + IPv4: ap("100.100.101.4"), + }, + // Node with mixed tags - some defined, some not (should be excluded) + { + User: ptr.To(testuser), + Tags: []string{"tag:test", "tag:undefined"}, + IPv4: ap("100.100.101.5"), + }, + // Another untagged node from different user (should be included) + { + User: ptr.To(testuser2), + IPv4: ap("100.100.101.6"), + }, + }, + pol: &Policy{ + TagOwners: TagOwners{ + 
Tag("tag:test"): Owners{ptr.To(Username("testuser@"))}, + Tag("tag:other"): Owners{ptr.To(Username("testuser@"))}, + }, + }, + want: []netip.Prefix{ + mp("100.100.101.1/32"), // No tags - is a member + mp("100.100.101.6/32"), // No tags, different user - is a member + }, + }, + { + name: "autogroup-tagged", + toResolve: ptr.To(AutoGroup(AutoGroupTagged)), + nodes: types.Nodes{ + // Node with no tags (should be excluded - not tagged) + { + User: ptr.To(testuser), + IPv4: ap("100.100.101.1"), + }, + // Node with single tag defined in policy (should be included) + { + User: ptr.To(testuser), + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.2"), + }, + // Node with multiple tags, all defined in policy (should be included) + { + User: ptr.To(testuser), + Tags: []string{"tag:test", "tag:other"}, + IPv4: ap("100.100.101.3"), + }, + // Node with tag not defined in policy (should be included - still tagged) + { + User: ptr.To(testuser), + Tags: []string{"tag:undefined"}, + IPv4: ap("100.100.101.4"), + }, + // Node with mixed tags - some defined, some not (should be included) + { + User: ptr.To(testuser), + Tags: []string{"tag:test", "tag:undefined"}, + IPv4: ap("100.100.101.5"), + }, + // Another untagged node from different user (should be excluded) + { + User: ptr.To(testuser2), + IPv4: ap("100.100.101.6"), + }, + // Tagged node from different user (should be included) + { + User: ptr.To(testuser2), + Tags: []string{"tag:server"}, + IPv4: ap("100.100.101.7"), + }, + }, + pol: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("testuser@"))}, + Tag("tag:other"): Owners{ptr.To(Username("testuser@"))}, + Tag("tag:server"): Owners{ptr.To(Username("testuser2@"))}, + }, + }, + want: []netip.Prefix{ + mp("100.100.101.2/31"), // .2, .3 consecutive tagged nodes + mp("100.100.101.4/31"), // .4, .5 consecutive tagged nodes + mp("100.100.101.7/32"), // Tagged node from different user + }, + }, + { + name: "autogroup-self", + toResolve: ptr.To(AutoGroupSelf), + nodes: types.Nodes{ + { + User: ptr.To(testuser), + IPv4: ap("100.100.101.1"), + }, + { + User: ptr.To(testuser2), + IPv4: ap("100.100.101.2"), + }, + { + User: ptr.To(testuser), + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.3"), + }, + { + User: ptr.To(testuser2), + Tags: []string{"tag:test"}, + IPv4: ap("100.100.101.4"), + }, + }, + pol: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("testuser@"))}, + }, + }, + wantErr: "autogroup:self requires per-node resolution", + }, + { + name: "autogroup-invalid", + toResolve: ptr.To(AutoGroup("autogroup:invalid")), + wantErr: "unknown autogroup", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ips, err := tt.toResolve.Resolve(tt.pol, + xmaps.Values(users), + tt.nodes.ViewSlice()) + if tt.wantErr == "" { + if err != nil { + t.Fatalf("got %v; want no error", err) + } + } else { + if err == nil { + t.Fatalf("got nil; want error %q", tt.wantErr) + } else if !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("got err %v; want error %q", err, tt.wantErr) + } + } + + var prefs []netip.Prefix + if ips != nil { + if p := ips.Prefixes(); len(p) > 0 { + prefs = p + } + } + + if diff := cmp.Diff(tt.want, prefs, util.Comparers...); diff != "" { + t.Fatalf("unexpected prefs (-want +got):\n%s", diff) + } + }) + } +} + +func TestResolveAutoApprovers(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + {Model: gorm.Model{ID: 3}, Name: "user3"}, + } 
+ + nodes := types.Nodes{ + { + IPv4: ap("100.64.0.1"), + User: &users[0], + }, + { + IPv4: ap("100.64.0.2"), + User: &users[1], + }, + { + IPv4: ap("100.64.0.3"), + User: &users[2], + }, + { + IPv4: ap("100.64.0.4"), + Tags: []string{"tag:testtag"}, + }, + { + IPv4: ap("100.64.0.5"), + Tags: []string{"tag:exittest"}, + }, + } + + tests := []struct { + name string + policy *Policy + want map[netip.Prefix]*netipx.IPSet + wantAllIPRoutes *netipx.IPSet + wantErr bool + }{ + { + name: "single-route", + policy: &Policy{ + AutoApprovers: AutoApproverPolicy{ + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, + }, + }, + }, + want: map[netip.Prefix]*netipx.IPSet{ + mp("10.0.0.0/24"): mustIPSet("100.64.0.1/32"), + }, + wantAllIPRoutes: nil, + wantErr: false, + }, + { + name: "multiple-routes", + policy: &Policy{ + AutoApprovers: AutoApproverPolicy{ + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, + mp("10.0.1.0/24"): {ptr.To(Username("user2@"))}, + }, + }, + }, + want: map[netip.Prefix]*netipx.IPSet{ + mp("10.0.0.0/24"): mustIPSet("100.64.0.1/32"), + mp("10.0.1.0/24"): mustIPSet("100.64.0.2/32"), + }, + wantAllIPRoutes: nil, + wantErr: false, + }, + { + name: "exit-node", + policy: &Policy{ + AutoApprovers: AutoApproverPolicy{ + ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + }, + }, + want: map[netip.Prefix]*netipx.IPSet{}, + wantAllIPRoutes: mustIPSet("100.64.0.1/32"), + wantErr: false, + }, + { + name: "group-route", + policy: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"user1@", "user2@"}, + }, + AutoApprovers: AutoApproverPolicy{ + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, + }, + }, + }, + want: map[netip.Prefix]*netipx.IPSet{ + mp("10.0.0.0/24"): mustIPSet("100.64.0.1/32", "100.64.0.2/32"), + }, + wantAllIPRoutes: nil, + wantErr: false, + }, + { + name: "tag-route-and-exit", + policy: &Policy{ + TagOwners: TagOwners{ + "tag:testtag": Owners{ + ptr.To(Username("user1@")), + ptr.To(Username("user2@")), + }, + "tag:exittest": Owners{ + ptr.To(Group("group:exitgroup")), + }, + }, + Groups: Groups{ + "group:exitgroup": Usernames{"user2@"}, + }, + AutoApprovers: AutoApproverPolicy{ + ExitNode: AutoApprovers{ptr.To(Tag("tag:exittest"))}, + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.1.0/24"): {ptr.To(Tag("tag:testtag"))}, + }, + }, + }, + want: map[netip.Prefix]*netipx.IPSet{ + mp("10.0.1.0/24"): mustIPSet("100.64.0.4/32"), + }, + wantAllIPRoutes: mustIPSet("100.64.0.5/32"), + wantErr: false, + }, + { + name: "mixed-routes-and-exit-nodes", + policy: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"user1@", "user2@"}, + }, + AutoApprovers: AutoApproverPolicy{ + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, + mp("10.0.1.0/24"): {ptr.To(Username("user3@"))}, + }, + ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + }, + }, + want: map[netip.Prefix]*netipx.IPSet{ + mp("10.0.0.0/24"): mustIPSet("100.64.0.1/32", "100.64.0.2/32"), + mp("10.0.1.0/24"): mustIPSet("100.64.0.3/32"), + }, + wantAllIPRoutes: mustIPSet("100.64.0.1/32"), + wantErr: false, + }, + } + + cmps := append(util.Comparers, cmp.Comparer(ipSetComparer)) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, gotAllIPRoutes, err := resolveAutoApprovers(tt.policy, users, nodes.ViewSlice()) + if (err != nil) != tt.wantErr { + t.Errorf("resolveAutoApprovers() error = %v, wantErr %v", err, 
tt.wantErr) + return + } + if diff := cmp.Diff(tt.want, got, cmps...); diff != "" { + t.Errorf("resolveAutoApprovers() mismatch (-want +got):\n%s", diff) + } + if tt.wantAllIPRoutes != nil { + if gotAllIPRoutes == nil { + t.Error("resolveAutoApprovers() expected non-nil allIPRoutes, got nil") + } else if diff := cmp.Diff(tt.wantAllIPRoutes, gotAllIPRoutes, cmps...); diff != "" { + t.Errorf("resolveAutoApprovers() allIPRoutes mismatch (-want +got):\n%s", diff) + } + } else if gotAllIPRoutes != nil { + t.Error("resolveAutoApprovers() expected nil allIPRoutes, got non-nil") + } + }) + } +} + +func TestSSHUsers_NormalUsers(t *testing.T) { + tests := []struct { + name string + users SSHUsers + expected []SSHUser + }{ + { + name: "empty users", + users: SSHUsers{}, + expected: []SSHUser{}, + }, + { + name: "only root", + users: SSHUsers{"root"}, + expected: []SSHUser{}, + }, + { + name: "only autogroup:nonroot", + users: SSHUsers{SSHUser(AutoGroupNonRoot)}, + expected: []SSHUser{}, + }, + { + name: "only normal user", + users: SSHUsers{"ssh-it-user"}, + expected: []SSHUser{"ssh-it-user"}, + }, + { + name: "multiple normal users", + users: SSHUsers{"ubuntu", "admin", "user1"}, + expected: []SSHUser{"ubuntu", "admin", "user1"}, + }, + { + name: "mixed users with root", + users: SSHUsers{"ubuntu", "root", "admin"}, + expected: []SSHUser{"ubuntu", "admin"}, + }, + { + name: "mixed users with autogroup:nonroot", + users: SSHUsers{"ubuntu", SSHUser(AutoGroupNonRoot), "admin"}, + expected: []SSHUser{"ubuntu", "admin"}, + }, + { + name: "mixed users with both root and autogroup:nonroot", + users: SSHUsers{"ubuntu", "root", SSHUser(AutoGroupNonRoot), "admin"}, + expected: []SSHUser{"ubuntu", "admin"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.users.NormalUsers() + assert.ElementsMatch(t, tt.expected, result, "NormalUsers() should return expected normal users") + }) + } +} + +func TestSSHUsers_ContainsRoot(t *testing.T) { + tests := []struct { + name string + users SSHUsers + expected bool + }{ + { + name: "empty users", + users: SSHUsers{}, + expected: false, + }, + { + name: "contains root", + users: SSHUsers{"root"}, + expected: true, + }, + { + name: "does not contain root", + users: SSHUsers{"ubuntu", "admin"}, + expected: false, + }, + { + name: "contains root among others", + users: SSHUsers{"ubuntu", "root", "admin"}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.users.ContainsRoot() + assert.Equal(t, tt.expected, result, "ContainsRoot() should return expected result") + }) + } +} + +func TestSSHUsers_ContainsNonRoot(t *testing.T) { + tests := []struct { + name string + users SSHUsers + expected bool + }{ + { + name: "empty users", + users: SSHUsers{}, + expected: false, + }, + { + name: "contains autogroup:nonroot", + users: SSHUsers{SSHUser(AutoGroupNonRoot)}, + expected: true, + }, + { + name: "does not contain autogroup:nonroot", + users: SSHUsers{"ubuntu", "admin", "root"}, + expected: false, + }, + { + name: "contains autogroup:nonroot among others", + users: SSHUsers{"ubuntu", SSHUser(AutoGroupNonRoot), "admin"}, + expected: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := tt.users.ContainsNonRoot() + assert.Equal(t, tt.expected, result, "ContainsNonRoot() should return expected result") + }) + } +} + +func mustIPSet(prefixes ...string) *netipx.IPSet { + var builder netipx.IPSetBuilder + for _, p := range prefixes { + 
builder.AddPrefix(mp(p)) + } + ipSet, _ := builder.IPSet() + + return ipSet +} + +func ipSetComparer(x, y *netipx.IPSet) bool { + if x == nil || y == nil { + return x == y + } + return cmp.Equal(x.Prefixes(), y.Prefixes(), util.Comparers...) +} + +func TestNodeCanApproveRoute(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + {Model: gorm.Model{ID: 3}, Name: "user3"}, + } + + nodes := types.Nodes{ + { + IPv4: ap("100.64.0.1"), + User: &users[0], + }, + { + IPv4: ap("100.64.0.2"), + User: &users[1], + }, + { + IPv4: ap("100.64.0.3"), + User: &users[2], + }, + } + + tests := []struct { + name string + policy *Policy + node *types.Node + route netip.Prefix + want bool + wantErr bool + }{ + { + name: "single-route-approval", + policy: &Policy{ + AutoApprovers: AutoApproverPolicy{ + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, + }, + }, + }, + node: nodes[0], + route: mp("10.0.0.0/24"), + want: true, + }, + { + name: "multiple-routes-approval", + policy: &Policy{ + AutoApprovers: AutoApproverPolicy{ + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.0.0/24"): {ptr.To(Username("user1@"))}, + mp("10.0.1.0/24"): {ptr.To(Username("user2@"))}, + }, + }, + }, + node: nodes[1], + route: mp("10.0.1.0/24"), + want: true, + }, + { + name: "exit-node-approval", + policy: &Policy{ + AutoApprovers: AutoApproverPolicy{ + ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + }, + }, + node: nodes[0], + route: tsaddr.AllIPv4(), + want: true, + }, + { + name: "group-route-approval", + policy: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"user1@", "user2@"}, + }, + AutoApprovers: AutoApproverPolicy{ + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, + }, + }, + }, + node: nodes[1], + route: mp("10.0.0.0/24"), + want: true, + }, + { + name: "mixed-routes-and-exit-nodes-approval", + policy: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"user1@", "user2@"}, + }, + AutoApprovers: AutoApproverPolicy{ + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.0.0/24"): {ptr.To(Group("group:testgroup"))}, + mp("10.0.1.0/24"): {ptr.To(Username("user3@"))}, + }, + ExitNode: AutoApprovers{ptr.To(Username("user1@"))}, + }, + }, + node: nodes[0], + route: tsaddr.AllIPv4(), + want: true, + }, + { + name: "no-approval", + policy: &Policy{ + AutoApprovers: AutoApproverPolicy{ + Routes: map[netip.Prefix]AutoApprovers{ + mp("10.0.0.0/24"): {ptr.To(Username("user2@"))}, + }, + }, + }, + node: nodes[0], + route: mp("10.0.0.0/24"), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := json.Marshal(tt.policy) + require.NoError(t, err) + + pm, err := NewPolicyManager(b, users, nodes.ViewSlice()) + require.NoErrorf(t, err, "NewPolicyManager() error = %v", err) + + got := pm.NodeCanApproveRoute(tt.node.View(), tt.route) + if got != tt.want { + t.Errorf("NodeCanApproveRoute() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestResolveTagOwners(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + {Model: gorm.Model{ID: 3}, Name: "user3"}, + } + + nodes := types.Nodes{ + { + IPv4: ap("100.64.0.1"), + User: &users[0], + }, + { + IPv4: ap("100.64.0.2"), + User: &users[1], + }, + { + IPv4: ap("100.64.0.3"), + User: &users[2], + }, + } + + tests := []struct { + name string + policy *Policy + want 
map[Tag]*netipx.IPSet + wantErr bool + }{ + { + name: "single-tag-owner", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + }, + }, + want: map[Tag]*netipx.IPSet{ + Tag("tag:test"): mustIPSet("100.64.0.1/32"), + }, + wantErr: false, + }, + { + name: "multiple-tag-owners", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, + }, + }, + want: map[Tag]*netipx.IPSet{ + Tag("tag:test"): mustIPSet("100.64.0.1/32", "100.64.0.2/32"), + }, + wantErr: false, + }, + { + name: "group-tag-owner", + policy: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"user1@", "user2@"}, + }, + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + }, + }, + want: map[Tag]*netipx.IPSet{ + Tag("tag:test"): mustIPSet("100.64.0.1/32", "100.64.0.2/32"), + }, + wantErr: false, + }, + { + name: "tag-owns-tag", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:bigbrother"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))}, + }, + }, + want: map[Tag]*netipx.IPSet{ + Tag("tag:bigbrother"): mustIPSet("100.64.0.1/32"), + Tag("tag:smallbrother"): mustIPSet("100.64.0.1/32"), + }, + wantErr: false, + }, + } + + cmps := append(util.Comparers, cmp.Comparer(ipSetComparer)) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := resolveTagOwners(tt.policy, users, nodes.ViewSlice()) + if (err != nil) != tt.wantErr { + t.Errorf("resolveTagOwners() error = %v, wantErr %v", err, tt.wantErr) + return + } + if diff := cmp.Diff(tt.want, got, cmps...); diff != "" { + t.Errorf("resolveTagOwners() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestNodeCanHaveTag(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + {Model: gorm.Model{ID: 3}, Name: "user3"}, + } + + nodes := types.Nodes{ + { + IPv4: ap("100.64.0.1"), + User: &users[0], + }, + { + IPv4: ap("100.64.0.2"), + User: &users[1], + }, + { + IPv4: ap("100.64.0.3"), + User: &users[2], + }, + } + + tests := []struct { + name string + policy *Policy + node *types.Node + tag string + want bool + wantErr string + }{ + { + name: "single-tag-owner", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + }, + }, + node: nodes[0], + tag: "tag:test", + want: true, + }, + { + name: "multiple-tag-owners", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, + }, + }, + node: nodes[1], + tag: "tag:test", + want: true, + }, + { + name: "group-tag-owner", + policy: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"user1@", "user2@"}, + }, + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + }, + }, + node: nodes[1], + tag: "tag:test", + want: true, + }, + { + name: "invalid-group", + policy: &Policy{ + Groups: Groups{ + "group:testgroup": Usernames{"invalid"}, + }, + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Group("group:testgroup"))}, + }, + }, + node: nodes[0], + tag: "tag:test", + want: false, + wantErr: "Username has to contain @", + }, + { + name: "node-cannot-have-tag", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("user2@"))}, + }, + }, + node: nodes[0], + tag: "tag:test", + want: false, + }, + { + name: "node-with-unauthorized-tag-different-user", + 
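+			// user1 owns tag:prod, but the node under test belongs to user3,
+			// so the tag must be rejected.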
policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + }, + }, + node: nodes[2], // user3's node + tag: "tag:prod", + want: false, + }, + { + name: "node-with-multiple-tags-one-unauthorized", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:web"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:database"): Owners{ptr.To(Username("user2@"))}, + }, + }, + node: nodes[0], // user1's node + tag: "tag:database", + want: false, // user1 cannot have tag:database (owned by user2) + }, + { + name: "empty-tagowners-map", + policy: &Policy{ + TagOwners: TagOwners{}, + }, + node: nodes[0], + tag: "tag:test", + want: false, // No one can have tags if tagOwners is empty + }, + { + name: "tag-not-in-tagowners", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:prod"): Owners{ptr.To(Username("user1@"))}, + }, + }, + node: nodes[0], + tag: "tag:dev", // This tag is not defined in tagOwners + want: false, + }, + // Test cases for nodes without IPs (new registration scenario) + // These test the user-based fallback in NodeCanHaveTag + { + name: "node-without-ip-user-owns-tag", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + }, + }, + node: &types.Node{ + // No IPv4 or IPv6 - simulates new node registration + User: &users[0], + UserID: ptr.To(users[0].ID), + }, + tag: "tag:test", + want: true, // Should succeed via user-based fallback + }, + { + name: "node-without-ip-user-does-not-own-tag", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("user2@"))}, + }, + }, + node: &types.Node{ + // No IPv4 or IPv6 - simulates new node registration + User: &users[0], // user1, but tag owned by user2 + UserID: ptr.To(users[0].ID), + }, + tag: "tag:test", + want: false, // user1 does not own tag:test + }, + { + name: "node-without-ip-group-owns-tag", + policy: &Policy{ + Groups: Groups{ + "group:admins": Usernames{"user1@", "user2@"}, + }, + TagOwners: TagOwners{ + Tag("tag:admin"): Owners{ptr.To(Group("group:admins"))}, + }, + }, + node: &types.Node{ + // No IPv4 or IPv6 - simulates new node registration + User: &users[1], // user2 is in group:admins + UserID: ptr.To(users[1].ID), + }, + tag: "tag:admin", + want: true, // Should succeed via group membership + }, + { + name: "node-without-ip-not-in-group", + policy: &Policy{ + Groups: Groups{ + "group:admins": Usernames{"user1@"}, + }, + TagOwners: TagOwners{ + Tag("tag:admin"): Owners{ptr.To(Group("group:admins"))}, + }, + }, + node: &types.Node{ + // No IPv4 or IPv6 - simulates new node registration + User: &users[1], // user2 is NOT in group:admins + UserID: ptr.To(users[1].ID), + }, + tag: "tag:admin", + want: false, // user2 is not in group:admins + }, + { + name: "node-without-ip-no-user", + policy: &Policy{ + TagOwners: TagOwners{ + Tag("tag:test"): Owners{ptr.To(Username("user1@"))}, + }, + }, + node: &types.Node{ + // No IPv4, IPv6, or User - edge case + }, + tag: "tag:test", + want: false, // No user means can't authorize via user-based fallback + }, + { + name: "node-without-ip-mixed-owners-user-match", + policy: &Policy{ + Groups: Groups{ + "group:ops": Usernames{"user3@"}, + }, + TagOwners: TagOwners{ + Tag("tag:server"): Owners{ + ptr.To(Username("user1@")), + ptr.To(Group("group:ops")), + }, + }, + }, + node: &types.Node{ + User: &users[0], // user1 directly owns the tag + UserID: ptr.To(users[0].ID), + }, + tag: "tag:server", + want: true, + }, + { + name: "node-without-ip-mixed-owners-group-match", + policy: &Policy{ + 
Groups: Groups{ + "group:ops": Usernames{"user3@"}, + }, + TagOwners: TagOwners{ + Tag("tag:server"): Owners{ + ptr.To(Username("user1@")), + ptr.To(Group("group:ops")), + }, + }, + }, + node: &types.Node{ + User: &users[2], // user3 is in group:ops + UserID: ptr.To(users[2].ID), + }, + tag: "tag:server", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + b, err := json.Marshal(tt.policy) + require.NoError(t, err) + + pm, err := NewPolicyManager(b, users, nodes.ViewSlice()) + if tt.wantErr != "" { + require.ErrorContains(t, err, tt.wantErr) + return + } + require.NoError(t, err) + + got := pm.NodeCanHaveTag(tt.node.View(), tt.tag) + if got != tt.want { + t.Errorf("NodeCanHaveTag() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestUserMatchesOwner(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "user1"}, + {Model: gorm.Model{ID: 2}, Name: "user2"}, + {Model: gorm.Model{ID: 3}, Name: "user3"}, + } + + tests := []struct { + name string + policy *Policy + user types.User + owner Owner + want bool + }{ + { + name: "username-match", + policy: &Policy{}, + user: users[0], + owner: ptr.To(Username("user1@")), + want: true, + }, + { + name: "username-no-match", + policy: &Policy{}, + user: users[0], + owner: ptr.To(Username("user2@")), + want: false, + }, + { + name: "group-match", + policy: &Policy{ + Groups: Groups{ + "group:admins": Usernames{"user1@", "user2@"}, + }, + }, + user: users[1], // user2 is in group:admins + owner: ptr.To(Group("group:admins")), + want: true, + }, + { + name: "group-no-match", + policy: &Policy{ + Groups: Groups{ + "group:admins": Usernames{"user1@"}, + }, + }, + user: users[1], // user2 is NOT in group:admins + owner: ptr.To(Group("group:admins")), + want: false, + }, + { + name: "group-not-defined", + policy: &Policy{ + Groups: Groups{}, + }, + user: users[0], + owner: ptr.To(Group("group:undefined")), + want: false, + }, + { + name: "nil-username-owner", + policy: &Policy{}, + user: users[0], + owner: (*Username)(nil), + want: false, + }, + { + name: "nil-group-owner", + policy: &Policy{}, + user: users[0], + owner: (*Group)(nil), + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create a minimal PolicyManager for testing + // We need nodes with IPs to initialize the tagOwnerMap + nodes := types.Nodes{ + { + IPv4: ap("100.64.0.1"), + User: &users[0], + }, + } + + b, err := json.Marshal(tt.policy) + require.NoError(t, err) + + pm, err := NewPolicyManager(b, users, nodes.ViewSlice()) + require.NoError(t, err) + + got := pm.userMatchesOwner(tt.user.View(), tt.owner) + if got != tt.want { + t.Errorf("userMatchesOwner() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestACL_UnmarshalJSON_WithCommentFields(t *testing.T) { + tests := []struct { + name string + input string + expected ACL + wantErr bool + }{ + { + name: "basic ACL with comment fields", + input: `{ + "#comment": "This is a comment", + "action": "accept", + "proto": "tcp", + "src": ["user1@example.com"], + "dst": ["tag:server:80"] + }`, + expected: ACL{ + Action: "accept", + Protocol: "tcp", + Sources: []Alias{mustParseAlias("user1@example.com")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("tag:server"), + Ports: []tailcfg.PortRange{{First: 80, Last: 80}}, + }, + }, + }, + wantErr: false, + }, + { + name: "multiple comment fields", + input: `{ + "#description": "Allow access to web servers", + "#note": "Created by admin", + "#created_date": "2024-01-15", + 
"action": "accept", + "proto": "tcp", + "src": ["group:developers"], + "dst": ["10.0.0.0/24:443"] + }`, + expected: ACL{ + Action: "accept", + Protocol: "tcp", + Sources: []Alias{mustParseAlias("group:developers")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("10.0.0.0/24"), + Ports: []tailcfg.PortRange{{First: 443, Last: 443}}, + }, + }, + }, + wantErr: false, + }, + { + name: "comment field with complex object value", + input: `{ + "#metadata": { + "description": "Complex comment object", + "tags": ["web", "production"], + "created_by": "admin" + }, + "action": "accept", + "proto": "udp", + "src": ["*"], + "dst": ["autogroup:internet:53"] + }`, + expected: ACL{ + Action: ActionAccept, + Protocol: "udp", + Sources: []Alias{Wildcard}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("autogroup:internet"), + Ports: []tailcfg.PortRange{{First: 53, Last: 53}}, + }, + }, + }, + wantErr: false, + }, + { + name: "invalid action should fail", + input: `{ + "action": "deny", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + }`, + wantErr: true, + }, + { + name: "no comment fields", + input: `{ + "action": "accept", + "proto": "icmp", + "src": ["tag:client"], + "dst": ["tag:server:*"] + }`, + expected: ACL{ + Action: ActionAccept, + Protocol: "icmp", + Sources: []Alias{mustParseAlias("tag:client")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("tag:server"), + Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}, + }, + }, + }, + wantErr: false, + }, + { + name: "only comment fields", + input: `{ + "#comment": "This rule is disabled", + "#reason": "Temporary disable for maintenance" + }`, + expected: ACL{ + Action: Action(""), + Protocol: Protocol(""), + Sources: nil, + Destinations: nil, + }, + wantErr: false, + }, + { + name: "invalid JSON", + input: `{ + "#comment": "This is a comment", + "action": "accept", + "proto": "tcp" + "src": ["invalid json"] + }`, + wantErr: true, + }, + { + name: "invalid field after comment filtering", + input: `{ + "#comment": "This is a comment", + "action": "accept", + "proto": "tcp", + "src": ["user1@example.com"], + "dst": ["invalid-destination"] + }`, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var acl ACL + err := json.Unmarshal([]byte(tt.input), &acl) + + if tt.wantErr { + assert.Error(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.expected.Action, acl.Action) + assert.Equal(t, tt.expected.Protocol, acl.Protocol) + assert.Equal(t, len(tt.expected.Sources), len(acl.Sources)) + assert.Equal(t, len(tt.expected.Destinations), len(acl.Destinations)) + + // Compare sources + for i, expectedSrc := range tt.expected.Sources { + if i < len(acl.Sources) { + assert.Equal(t, expectedSrc, acl.Sources[i]) + } + } + + // Compare destinations + for i, expectedDst := range tt.expected.Destinations { + if i < len(acl.Destinations) { + assert.Equal(t, expectedDst.Alias, acl.Destinations[i].Alias) + assert.Equal(t, expectedDst.Ports, acl.Destinations[i].Ports) + } + } + }) + } +} + +func TestACL_UnmarshalJSON_Roundtrip(t *testing.T) { + // Test that marshaling and unmarshaling preserves data (excluding comments) + original := ACL{ + Action: "accept", + Protocol: "tcp", + Sources: []Alias{mustParseAlias("group:admins")}, + Destinations: []AliasWithPorts{ + { + Alias: mustParseAlias("tag:server"), + Ports: []tailcfg.PortRange{{First: 22, Last: 22}, {First: 80, Last: 80}}, + }, + }, + } + + // Marshal to JSON + jsonBytes, err := json.Marshal(original) + 
require.NoError(t, err) + + // Unmarshal back + var unmarshaled ACL + err = json.Unmarshal(jsonBytes, &unmarshaled) + require.NoError(t, err) + + // Should be equal + assert.Equal(t, original.Action, unmarshaled.Action) + assert.Equal(t, original.Protocol, unmarshaled.Protocol) + assert.Equal(t, len(original.Sources), len(unmarshaled.Sources)) + assert.Equal(t, len(original.Destinations), len(unmarshaled.Destinations)) +} + +func TestACL_UnmarshalJSON_PolicyIntegration(t *testing.T) { + // Test that ACL unmarshaling works within a Policy context + policyJSON := `{ + "groups": { + "group:developers": ["user1@example.com", "user2@example.com"] + }, + "tagOwners": { + "tag:server": ["group:developers"] + }, + "acls": [ + { + "#description": "Allow developers to access servers", + "#priority": "high", + "action": "accept", + "proto": "tcp", + "src": ["group:developers"], + "dst": ["tag:server:22,80,443"] + }, + { + "#note": "Allow all other traffic", + "action": "accept", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + } + ] + }` + + policy, err := unmarshalPolicy([]byte(policyJSON)) + require.NoError(t, err) + require.NotNil(t, policy) + + // Check that ACLs were parsed correctly + require.Len(t, policy.ACLs, 2) + + // First ACL + acl1 := policy.ACLs[0] + assert.Equal(t, ActionAccept, acl1.Action) + assert.Equal(t, Protocol("tcp"), acl1.Protocol) + require.Len(t, acl1.Sources, 1) + require.Len(t, acl1.Destinations, 1) + + // Second ACL + acl2 := policy.ACLs[1] + assert.Equal(t, ActionAccept, acl2.Action) + assert.Equal(t, Protocol("tcp"), acl2.Protocol) + require.Len(t, acl2.Sources, 1) + require.Len(t, acl2.Destinations, 1) +} + +func TestACL_UnmarshalJSON_InvalidAction(t *testing.T) { + // Test that invalid actions are rejected + policyJSON := `{ + "acls": [ + { + "action": "deny", + "proto": "tcp", + "src": ["*"], + "dst": ["*:*"] + } + ] + }` + + _, err := unmarshalPolicy([]byte(policyJSON)) + require.Error(t, err) + assert.Contains(t, err.Error(), `invalid action "deny"`) +} + +// Helper function to parse aliases for testing +func mustParseAlias(s string) Alias { + alias, err := parseAlias(s) + if err != nil { + panic(err) + } + return alias +} + +func TestFlattenTagOwners(t *testing.T) { + tests := []struct { + name string + input TagOwners + want TagOwners + wantErr string + }{ + { + name: "tag-owns-tag", + input: TagOwners{ + Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:smallbrother"): Owners{ptr.To(Tag("tag:bigbrother"))}, + }, + want: TagOwners{ + Tag("tag:bigbrother"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:smallbrother"): Owners{ptr.To(Group("group:user1"))}, + }, + wantErr: "", + }, + { + name: "circular-reference", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, + }, + want: nil, + wantErr: "circular reference detected: tag:a -> tag:b", + }, + { + name: "mixed-owners", + input: TagOwners{ + Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))}, + Tag("tag:y"): Owners{ptr.To(Username("user2@"))}, + }, + want: TagOwners{ + Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Username("user2@"))}, + Tag("tag:y"): Owners{ptr.To(Username("user2@"))}, + }, + wantErr: "", + }, + { + name: "mixed-dupe-owners", + input: TagOwners{ + Tag("tag:x"): Owners{ptr.To(Username("user1@")), ptr.To(Tag("tag:y"))}, + Tag("tag:y"): Owners{ptr.To(Username("user1@"))}, + }, + want: TagOwners{ + Tag("tag:x"): Owners{ptr.To(Username("user1@"))}, + Tag("tag:y"): 
Owners{ptr.To(Username("user1@"))}, + }, + wantErr: "", + }, + { + name: "no-tag-owners", + input: TagOwners{ + Tag("tag:solo"): Owners{ptr.To(Username("user1@"))}, + }, + want: TagOwners{ + Tag("tag:solo"): Owners{ptr.To(Username("user1@"))}, + }, + wantErr: "", + }, + { + name: "tag-long-owner-chain", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, + Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))}, + Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))}, + Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))}, + Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))}, + }, + want: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:b"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:c"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:d"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:e"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:f"): Owners{ptr.To(Group("group:user1"))}, + Tag("tag:g"): Owners{ptr.To(Group("group:user1"))}, + }, + wantErr: "", + }, + { + name: "tag-long-circular-chain", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Tag("tag:g"))}, + Tag("tag:b"): Owners{ptr.To(Tag("tag:a"))}, + Tag("tag:c"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:d"): Owners{ptr.To(Tag("tag:c"))}, + Tag("tag:e"): Owners{ptr.To(Tag("tag:d"))}, + Tag("tag:f"): Owners{ptr.To(Tag("tag:e"))}, + Tag("tag:g"): Owners{ptr.To(Tag("tag:f"))}, + }, + wantErr: "circular reference detected: tag:a -> tag:b -> tag:c -> tag:d -> tag:e -> tag:f -> tag:g", + }, + { + name: "undefined-tag-reference", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Tag("tag:nonexistent"))}, + }, + wantErr: `tag "tag:a" references undefined tag "tag:nonexistent"`, + }, + { + name: "tag-with-empty-owners-is-valid", + input: TagOwners{ + Tag("tag:a"): Owners{ptr.To(Tag("tag:b"))}, + Tag("tag:b"): Owners{}, // empty owners but exists + }, + want: TagOwners{ + Tag("tag:a"): nil, + Tag("tag:b"): nil, + }, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := flattenTagOwners(tt.input) + if tt.wantErr != "" { + if err == nil { + t.Fatalf("flattenTagOwners() expected error %q, got nil", tt.wantErr) + } + + if err.Error() != tt.wantErr { + t.Fatalf("flattenTagOwners() expected error %q, got %q", tt.wantErr, err.Error()) + } + + return + } + + if err != nil { + t.Fatalf("flattenTagOwners() unexpected error: %v", err) + } + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("flattenTagOwners() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/policy/v2/utils.go b/hscontrol/policy/v2/utils.go new file mode 100644 index 00000000..a4367775 --- /dev/null +++ b/hscontrol/policy/v2/utils.go @@ -0,0 +1,99 @@ +package v2 + +import ( + "errors" + "slices" + "strconv" + "strings" + + "tailscale.com/tailcfg" +) + +// splitDestinationAndPort takes an input string and returns the destination and port as a tuple, or an error if the input is invalid. 
+func splitDestinationAndPort(input string) (string, string, error) { + // Find the last occurrence of the colon character + lastColonIndex := strings.LastIndex(input, ":") + + // Check if the colon character is present and not at the beginning or end of the string + if lastColonIndex == -1 { + return "", "", errors.New("input must contain a colon character separating destination and port") + } + if lastColonIndex == 0 { + return "", "", errors.New("input cannot start with a colon character") + } + if lastColonIndex == len(input)-1 { + return "", "", errors.New("input cannot end with a colon character") + } + + // Split the string into destination and port based on the last colon + destination := input[:lastColonIndex] + port := input[lastColonIndex+1:] + + return destination, port, nil +} + +// parsePortRange parses a port definition string and returns a slice of PortRange structs. +func parsePortRange(portDef string) ([]tailcfg.PortRange, error) { + if portDef == "*" { + return []tailcfg.PortRange{tailcfg.PortRangeAny}, nil + } + + var portRanges []tailcfg.PortRange + + parts := strings.SplitSeq(portDef, ",") + + for part := range parts { + if strings.Contains(part, "-") { + rangeParts := strings.Split(part, "-") + rangeParts = slices.DeleteFunc(rangeParts, func(e string) bool { + return e == "" + }) + if len(rangeParts) != 2 { + return nil, errors.New("invalid port range format") + } + + first, err := parsePort(rangeParts[0]) + if err != nil { + return nil, err + } + + last, err := parsePort(rangeParts[1]) + if err != nil { + return nil, err + } + + if first > last { + return nil, errors.New("invalid port range: first port is greater than last port") + } + + portRanges = append(portRanges, tailcfg.PortRange{First: first, Last: last}) + } else { + port, err := parsePort(part) + if err != nil { + return nil, err + } + + if port < 1 { + return nil, errors.New("first port must be >0, or use '*' for wildcard") + } + + portRanges = append(portRanges, tailcfg.PortRange{First: port, Last: port}) + } + } + + return portRanges, nil +} + +// parsePort parses a single port number from a string. +func parsePort(portStr string) (uint16, error) { + port, err := strconv.Atoi(portStr) + if err != nil { + return 0, errors.New("invalid port number") + } + + if port < 0 || port > 65535 { + return 0, errors.New("port number out of range") + } + + return uint16(port), nil +} diff --git a/hscontrol/policy/v2/utils_test.go b/hscontrol/policy/v2/utils_test.go new file mode 100644 index 00000000..2084b22f --- /dev/null +++ b/hscontrol/policy/v2/utils_test.go @@ -0,0 +1,102 @@ +package v2 + +import ( + "errors" + "testing" + + "github.com/google/go-cmp/cmp" + "tailscale.com/tailcfg" +) + +// TestParseDestinationAndPort tests the parseDestinationAndPort function using table-driven tests. 
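+// The table covers hostname, CIDR, IPv6, and tag destinations handled by
+// splitDestinationAndPort, plus the three malformed-input error cases.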
+func TestParseDestinationAndPort(t *testing.T) { + testCases := []struct { + input string + expectedDst string + expectedPort string + expectedErr error + }{ + {"git-server:*", "git-server", "*", nil}, + {"192.168.1.0/24:22", "192.168.1.0/24", "22", nil}, + {"fd7a:115c:a1e0::2:22", "fd7a:115c:a1e0::2", "22", nil}, + {"fd7a:115c:a1e0::2/128:22", "fd7a:115c:a1e0::2/128", "22", nil}, + {"tag:montreal-webserver:80,443", "tag:montreal-webserver", "80,443", nil}, + {"tag:api-server:443", "tag:api-server", "443", nil}, + {"example-host-1:*", "example-host-1", "*", nil}, + {"hostname:80-90", "hostname", "80-90", nil}, + {"invalidinput", "", "", errors.New("input must contain a colon character separating destination and port")}, + {":invalid", "", "", errors.New("input cannot start with a colon character")}, + {"invalid:", "", "", errors.New("input cannot end with a colon character")}, + } + + for _, testCase := range testCases { + dst, port, err := splitDestinationAndPort(testCase.input) + if dst != testCase.expectedDst || port != testCase.expectedPort || (err != nil && err.Error() != testCase.expectedErr.Error()) { + t.Errorf("parseDestinationAndPort(%q) = (%q, %q, %v), want (%q, %q, %v)", + testCase.input, dst, port, err, testCase.expectedDst, testCase.expectedPort, testCase.expectedErr) + } + } +} + +func TestParsePort(t *testing.T) { + tests := []struct { + input string + expected uint16 + err string + }{ + {"80", 80, ""}, + {"0", 0, ""}, + {"65535", 65535, ""}, + {"-1", 0, "port number out of range"}, + {"65536", 0, "port number out of range"}, + {"abc", 0, "invalid port number"}, + {"", 0, "invalid port number"}, + } + + for _, test := range tests { + result, err := parsePort(test.input) + if err != nil && err.Error() != test.err { + t.Errorf("parsePort(%q) error = %v, expected error = %v", test.input, err, test.err) + } + if err == nil && test.err != "" { + t.Errorf("parsePort(%q) expected error = %v, got nil", test.input, test.err) + } + if result != test.expected { + t.Errorf("parsePort(%q) = %v, expected %v", test.input, result, test.expected) + } + } +} + +func TestParsePortRange(t *testing.T) { + tests := []struct { + input string + expected []tailcfg.PortRange + err string + }{ + {"80", []tailcfg.PortRange{{First: 80, Last: 80}}, ""}, + {"80-90", []tailcfg.PortRange{{First: 80, Last: 90}}, ""}, + {"80,90", []tailcfg.PortRange{{First: 80, Last: 80}, {First: 90, Last: 90}}, ""}, + {"80-91,92,93-95", []tailcfg.PortRange{{First: 80, Last: 91}, {First: 92, Last: 92}, {First: 93, Last: 95}}, ""}, + {"*", []tailcfg.PortRange{tailcfg.PortRangeAny}, ""}, + {"80-", nil, "invalid port range format"}, + {"-90", nil, "invalid port range format"}, + {"80-90,", nil, "invalid port number"}, + {"80,90-", nil, "invalid port range format"}, + {"80-90,abc", nil, "invalid port number"}, + {"80-90,65536", nil, "port number out of range"}, + {"80-90,90-80", nil, "invalid port range: first port is greater than last port"}, + } + + for _, test := range tests { + result, err := parsePortRange(test.input) + if err != nil && err.Error() != test.err { + t.Errorf("parsePortRange(%q) error = %v, expected error = %v", test.input, err, test.err) + } + if err == nil && test.err != "" { + t.Errorf("parsePortRange(%q) expected error = %v, got nil", test.input, test.err) + } + if diff := cmp.Diff(result, test.expected); diff != "" { + t.Errorf("parsePortRange(%q) mismatch (-want +got):\n%s", test.input, diff) + } + } +} diff --git a/hscontrol/poll.go b/hscontrol/poll.go index a07fda08..02275751 100644 --- 
a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -2,569 +2,339 @@ package hscontrol import ( "context" + "encoding/binary" + "encoding/json" "fmt" + "math/rand/v2" "net/http" "time" - "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog" "github.com/rs/zerolog/log" - xslices "golang.org/x/exp/slices" + "github.com/sasha-s/go-deadlock" "tailscale.com/tailcfg" + "tailscale.com/util/zstdframe" ) const ( - keepAliveInterval = 60 * time.Second + keepAliveInterval = 50 * time.Second ) type contextKey string const nodeNameContextKey = contextKey("nodeName") -type UpdateNode func() +type mapSession struct { + h *Headscale + req tailcfg.MapRequest + ctx context.Context + capVer tailcfg.CapabilityVersion -func logPollFunc( - mapRequest tailcfg.MapRequest, - node *types.Node, -) (func(string), func(error, string)) { - return func(msg string) { - log.Info(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Str("node_key", node.NodeKey.ShortString()). - Str("node", node.Hostname). - Msg(msg) - }, - func(err error, msg string) { - log.Error(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Str("node_key", node.NodeKey.ShortString()). - Str("node", node.Hostname). - Err(err). - Msg(msg) - } + cancelChMu deadlock.Mutex + + ch chan *tailcfg.MapResponse + cancelCh chan struct{} + cancelChOpen bool + + keepAlive time.Duration + keepAliveTicker *time.Ticker + + node *types.Node + w http.ResponseWriter } -// handlePoll ensures the node gets the appropriate updates from either -// polling or immediate responses. -// -//nolint:gocyclo -func (h *Headscale) handlePoll( - writer http.ResponseWriter, +func (h *Headscale) newMapSession( ctx context.Context, + req tailcfg.MapRequest, + w http.ResponseWriter, node *types.Node, - mapRequest tailcfg.MapRequest, -) { - logInfo, logErr := logPollFunc(mapRequest, node) +) *mapSession { + ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond) + return &mapSession{ + h: h, + ctx: ctx, + req: req, + w: w, + node: node, + capVer: req.Version, + + ch: make(chan *tailcfg.MapResponse, h.cfg.Tuning.NodeMapSessionBufferedChanSize), + cancelCh: make(chan struct{}), + cancelChOpen: true, + + keepAlive: ka, + keepAliveTicker: nil, + } +} + +func (m *mapSession) isStreaming() bool { + return m.req.Stream +} + +func (m *mapSession) isEndpointUpdate() bool { + return !m.req.Stream && m.req.OmitPeers +} + +func (m *mapSession) resetKeepAlive() { + m.keepAliveTicker.Reset(m.keepAlive) +} + +func (m *mapSession) beforeServeLongPoll() { + if m.node.IsEphemeral() { + m.h.ephemeralGC.Cancel(m.node.ID) + } +} + +// afterServeLongPoll is called when a long-polling session ends and the node +// is disconnected. +func (m *mapSession) afterServeLongPoll() { + if m.node.IsEphemeral() { + m.h.ephemeralGC.Schedule(m.node.ID, m.h.cfg.EphemeralNodeInactivityTimeout) + } +} + +// serve handles non-streaming requests. +func (m *mapSession) serve() { // This is the mechanism where the node gives us information about its // current configuration. // - // If OmitPeers is true, Stream is false, and ReadOnly is false, - // then te server will let clients update their endpoints without + // Process the MapRequest to update node state (endpoints, hostinfo, etc.) 
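+	// The resulting change set is passed to m.h.Change below; this happens
+	// even for lite endpoint updates that only expect a bare 200 response.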
+ c, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req) + if err != nil { + httpError(m.w, err) + return + } + + m.h.Change(c) + + // If OmitPeers is true and Stream is false + // then the server will let clients update their endpoints without // breaking existing long-polling (Stream == true) connections. // In this case, the server can omit the entire response; the client // only checks the HTTP response status code. - // TODO(kradalby): remove ReadOnly when we only support capVer 68+ - if mapRequest.OmitPeers && !mapRequest.Stream && !mapRequest.ReadOnly { - log.Info(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Str("node_key", node.NodeKey.ShortString()). - Str("node", node.Hostname). - Int("cap_ver", int(mapRequest.Version)). - Msg("Received update") + // + // This is what Tailscale calls a Lite update, the client ignores + // the response and just wants a 200. + // !req.stream && req.OmitPeers + if m.isEndpointUpdate() { + m.w.WriteHeader(http.StatusOK) + mapResponseEndpointUpdates.WithLabelValues("ok").Inc() + } +} - change := node.PeerChangeFromMapRequest(mapRequest) +// serveLongPoll ensures the node gets the appropriate updates from either +// polling or immediate responses. +// +//nolint:gocyclo +func (m *mapSession) serveLongPoll() { + m.beforeServeLongPoll() - online := h.nodeNotifier.IsConnected(node.MachineKey) - change.Online = &online + log.Trace().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("Long poll session started because client connected") - node.ApplyPeerChange(&change) + // Clean up the session when the client disconnects + defer func() { + m.cancelChMu.Lock() + m.cancelChOpen = false + close(m.cancelCh) + m.cancelChMu.Unlock() - hostInfoChange := node.Hostinfo.Equal(mapRequest.Hostinfo) + _ = m.h.mapBatcher.RemoveNode(m.node.ID, m.ch) - logTracePeerChange(node.Hostname, hostInfoChange, &change) - - // Check if the Hostinfo of the node has changed. - // If it has changed, check if there has been a change tod - // the routable IPs of the host and update update them in - // the database. Then send a Changed update - // (containing the whole node object) to peers to inform about - // the route change. - // If the hostinfo has changed, but not the routes, just update - // hostinfo and let the function continue. - if !hostInfoChange { - oldRoutes := node.Hostinfo.RoutableIPs - newRoutes := mapRequest.Hostinfo.RoutableIPs - - oldServicesCount := len(node.Hostinfo.Services) - newServicesCount := len(mapRequest.Hostinfo.Services) - - node.Hostinfo = mapRequest.Hostinfo - - sendUpdate := false - - // Route changes come as part of Hostinfo, which means that - // when an update comes, the Node Route logic need to run. - // This will require a "change" in comparison to a "patch", - // which is more costly. - if !xslices.Equal(oldRoutes, newRoutes) { - var err error - sendUpdate, err = h.db.SaveNodeRoutes(node) - if err != nil { - logErr(err, "Error processing node routes") - http.Error(writer, "", http.StatusInternalServerError) - - return - } - } - - // Services is mostly useful for discovery and not critical, - // except for peerapi, which is how nodes talk to eachother. - // If peerapi was not part of the initial mapresponse, we - // need to make sure its sent out later as it is needed for - // Taildrop. - // TODO(kradalby): Length comparison is a bit naive, replace. 
- if oldServicesCount != newServicesCount { - sendUpdate = true - } - - if sendUpdate { - if err := h.db.NodeSave(node); err != nil { - logErr(err, "Failed to persist/update node in the database") - http.Error(writer, "", http.StatusInternalServerError) - - return - } - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from handlePoll -> update -> new hostinfo", - } - if stateUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore( - stateUpdate, - node.MachineKey.String()) - } - - return - } - } - - if err := h.db.NodeSave(node); err != nil { - logErr(err, "Failed to persist/update node in the database") - http.Error(writer, "", http.StatusInternalServerError) - - return - } - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{&change}, - } - if stateUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore( - stateUpdate, - node.MachineKey.String()) - } - - writer.WriteHeader(http.StatusOK) - if f, ok := writer.(http.Flusher); ok { - f.Flush() - } - - return - } else if mapRequest.OmitPeers && !mapRequest.Stream && mapRequest.ReadOnly { - // ReadOnly is whether the client just wants to fetch the - // MapResponse, without updating their Endpoints. The - // Endpoints field will be ignored and LastSeen will not be - // updated and peers will not be notified of changes. + // When a node disconnects, it might rapidly reconnect (e.g. mobile clients, network weather). + // Instead of immediately marking the node as offline, we wait a few seconds to see if it reconnects. + // If it does reconnect, the existing mapSession will be replaced and the node remains online. + // If it doesn't reconnect within the timeout, we mark it as offline. // - // The intended use is for clients to discover the DERP map at - // start-up before their first real endpoint update. - } else if mapRequest.OmitPeers && !mapRequest.Stream && mapRequest.ReadOnly { - h.handleLiteRequest(writer, node, mapRequest) - - return - } else if mapRequest.OmitPeers && mapRequest.Stream { - logErr(nil, "Ignoring request, don't know how to handle it") - - return - } - - change := node.PeerChangeFromMapRequest(mapRequest) - - // A stream is being set up, the node is Online - online := true - change.Online = &online - - node.ApplyPeerChange(&change) - - // Only save HostInfo if changed, update routes if changed - // TODO(kradalby): Remove when capver is over 68 - if !node.Hostinfo.Equal(mapRequest.Hostinfo) { - oldRoutes := node.Hostinfo.RoutableIPs - newRoutes := mapRequest.Hostinfo.RoutableIPs - - node.Hostinfo = mapRequest.Hostinfo - - if !xslices.Equal(oldRoutes, newRoutes) { - _, err := h.db.SaveNodeRoutes(node) - if err != nil { - logErr(err, "Error processing node routes") - http.Error(writer, "", http.StatusInternalServerError) - - return + // This avoids flapping nodes in the UI and unnecessary churn in the network. + // This is not my favourite solution, but it kind of works in our eventually consistent world. + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + disconnected := true + // Wait up to 10 seconds for the node to reconnect. + // 10 seconds was arbitrary chosen as a reasonable time to reconnect. 
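+	// The loop below polls the batcher once per second; if the node has not
+	// re-registered a stream within ten ticks it is treated as disconnected.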
+ for range 10 { + if m.h.mapBatcher.IsConnected(m.node.ID) { + disconnected = false + break } + <-ticker.C } - } - if err := h.db.NodeSave(node); err != nil { - logErr(err, "Failed to persist/update node in the database") - http.Error(writer, "", http.StatusInternalServerError) + if disconnected { + disconnectChanges, err := m.h.state.Disconnect(m.node.ID) + if err != nil { + m.errf(err, "Failed to disconnect node %s", m.node.Hostname) + } - return - } - - // When a node connects to control, list the peers it has at - // that given point, further updates are kept in memory in - // the Mapper, which lives for the duration of the polling - // session. - peers, err := h.db.ListPeers(node) - if err != nil { - logErr(err, "Failed to list peers when opening poller") - http.Error(writer, "", http.StatusInternalServerError) - - return - } - - for _, peer := range peers { - online := h.nodeNotifier.IsConnected(peer.MachineKey) - peer.IsOnline = &online - } - - mapp := mapper.NewMapper( - node, - peers, - h.DERPMap, - h.cfg.BaseDomain, - h.cfg.DNSConfig, - h.cfg.LogTail.Enabled, - h.cfg.RandomizeClientPort, - ) - - // update ACLRules with peer informations (to update server tags if necessary) - if h.ACLPolicy != nil { - // update routes with peer information - err = h.db.EnableAutoApprovedRoutes(h.ACLPolicy, node) - if err != nil { - logErr(err, "Error running auto approved routes") + m.h.Change(disconnectChanges...) + m.afterServeLongPoll() + m.infof("node has disconnected, mapSession: %p, chan: %p", m, m.ch) } - } - - logInfo("Sending initial map") - - mapResp, err := mapp.FullMapResponse(mapRequest, node, h.ACLPolicy) - if err != nil { - logErr(err, "Failed to create MapResponse") - http.Error(writer, "", http.StatusInternalServerError) - - return - } - - // Send the client an update to make sure we send an initial mapresponse - _, err = writer.Write(mapResp) - if err != nil { - logErr(err, "Could not write the map response") - - return - } - - if flusher, ok := writer.(http.Flusher); ok { - flusher.Flush() - } else { - return - } - - stateUpdate := types.StateUpdate{ - Type: types.StatePeerChanged, - ChangeNodes: types.Nodes{node}, - Message: "called from handlePoll -> new node added", - } - if stateUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore( - stateUpdate, - node.MachineKey.String()) - } + }() // Set up the client stream - h.pollNetMapStreamWG.Add(1) - defer h.pollNetMapStreamWG.Done() + m.h.clientStreamsOpen.Add(1) + defer m.h.clientStreamsOpen.Done() - // Use a buffered channel in case a node is not fully ready - // to receive a message to make sure we dont block the entire - // notifier. - // 12 is arbitrarily chosen. - updateChan := make(chan types.StateUpdate, 12) - defer closeChanWithLog(updateChan, node.Hostname, "updateChan") - - // Register the node's update channel - h.nodeNotifier.AddNode(node.MachineKey, updateChan) - defer h.nodeNotifier.RemoveNode(node.MachineKey) - - keepAliveTicker := time.NewTicker(keepAliveInterval) - - ctx = context.WithValue(ctx, nodeNameContextKey, node.Hostname) - - ctx, cancel := context.WithCancel(ctx) + ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname)) defer cancel() - if len(node.Routes) > 0 { - go h.db.EnsureFailoverRouteIsAvailable(node) + m.keepAliveTicker = time.NewTicker(m.keepAlive) + + // Process the initial MapRequest to update node state (endpoints, hostinfo, etc.) + // This must be done BEFORE calling Connect() to ensure routes are properly synchronized. 
+ // When nodes reconnect, they send their hostinfo with announced routes in the MapRequest. + // We need this data in NodeStore before Connect() sets up the primary routes, because + // SubnetRoutes() calculates the intersection of announced and approved routes. If we + // call Connect() first, SubnetRoutes() returns empty (no announced routes yet), causing + // the node to be incorrectly removed from AvailableRoutes. + mapReqChange, err := m.h.state.UpdateNodeFromMapRequest(m.node.ID, m.req) + if err != nil { + m.errf(err, "failed to update node from initial MapRequest") + return } + // Connect the node after its state has been updated. + // We send two separate change notifications because these are distinct operations: + // 1. UpdateNodeFromMapRequest: processes the client's reported state (routes, endpoints, hostinfo) + // 2. Connect: marks the node online and recalculates primary routes based on the updated state + // While this results in two notifications, it ensures route data is synchronized before + // primary route selection occurs, which is critical for proper HA subnet router failover. + connectChanges := m.h.state.Connect(m.node.ID) + + m.infof("node has connected, mapSession: %p, chan: %p", m, m.ch) + + // TODO(kradalby): Redo the comments here + // Add node to batcher so it can receive updates, + // adding this before connecting it to the state ensure that + // it does not miss any updates that might be sent in the split + // time between the node connecting and the batcher being ready. + if err := m.h.mapBatcher.AddNode(m.node.ID, m.ch, m.capVer); err != nil { + m.errf(err, "failed to add node to batcher") + log.Error().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Err(err).Msg("AddNode failed in poll session") + return + } + log.Debug().Caller().Uint64("node.id", m.node.ID.Uint64()).Str("node.name", m.node.Hostname).Msg("AddNode succeeded in poll session because node added to batcher") + + m.h.Change(mapReqChange) + m.h.Change(connectChanges...) + + // Loop through updates and continuously send them to the + // client. for { - logInfo("Waiting for update on stream channel") + // consume channels with update, keep alives or "batch" blocking signals select { - case <-keepAliveTicker.C: - data, err := mapp.KeepAliveResponse(mapRequest, node) - if err != nil { - logErr(err, "Error generating the keep alive msg") - - return - } - _, err = writer.Write(data) - if err != nil { - logErr(err, "Cannot write keep alive message") - - return - } - if flusher, ok := writer.(http.Flusher); ok { - flusher.Flush() - } else { - log.Error().Msg("Failed to create http flusher") - - return - } - - // This goroutine is not ideal, but we have a potential issue here - // where it blocks too long and that holds up updates. - // One alternative is to split these different channels into - // goroutines, but then you might have a problem without a lock - // if a keepalive is written at the same time as an update. 
- go h.updateNodeOnlineStatus(true, node) - - case update := <-updateChan: - logInfo("Received update") - now := time.Now() - - var data []byte - var err error - - switch update.Type { - case types.StateFullUpdate: - logInfo("Sending Full MapResponse") - - data, err = mapp.FullMapResponse(mapRequest, node, h.ACLPolicy) - case types.StatePeerChanged: - logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message)) - - for _, node := range update.ChangeNodes { - // If a node is not reported to be online, it might be - // because the value is outdated, check with the notifier. - // However, if it is set to Online, and not in the notifier, - // this might be because it has announced itself, but not - // reached the stage to actually create the notifier channel. - if node.IsOnline != nil && !*node.IsOnline { - isOnline := h.nodeNotifier.IsConnected(node.MachineKey) - node.IsOnline = &isOnline - } - } - - data, err = mapp.PeerChangedResponse(mapRequest, node, update.ChangeNodes, h.ACLPolicy, update.Message) - case types.StatePeerChangedPatch: - logInfo("Sending PeerChangedPatch MapResponse") - data, err = mapp.PeerChangedPatchResponse(mapRequest, node, update.ChangePatches, h.ACLPolicy) - case types.StatePeerRemoved: - logInfo("Sending PeerRemoved MapResponse") - data, err = mapp.PeerRemovedResponse(mapRequest, node, update.Removed) - case types.StateDERPUpdated: - logInfo("Sending DERPUpdate MapResponse") - data, err = mapp.DERPMapResponse(mapRequest, node, update.DERPMap) - } - - if err != nil { - logErr(err, "Could not get the create map update") - - return - } - - // Only send update if there is change - if data != nil { - _, err = writer.Write(data) - if err != nil { - logErr(err, "Could not write the map response") - - updateRequestsSentToNode.WithLabelValues(node.User.Name, node.Hostname, "failed"). - Inc() - - return - } - - if flusher, ok := writer.(http.Flusher); ok { - flusher.Flush() - } else { - log.Error().Msg("Failed to create http flusher") - - return - } - - log.Info(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Str("node_key", node.NodeKey.ShortString()). - Str("machine_key", node.MachineKey.ShortString()). - Str("node", node.Hostname). - TimeDiff("timeSpent", time.Now(), now). - Msg("update sent") - } + case <-m.cancelCh: + m.tracef("poll cancelled received") + mapResponseEnded.WithLabelValues("cancelled").Inc() + return case <-ctx.Done(): - logInfo("The client has closed the connection") - - go h.updateNodeOnlineStatus(false, node) - - // Failover the node's routes if any. - go h.db.FailoverNodeRoutesWithNotify(node) - - // The connection has been closed, so we can stop polling. 
+ m.tracef("poll context done chan:%p", m.ch) + mapResponseEnded.WithLabelValues("done").Inc() return - case <-h.shutdownChan: - logInfo("The long-poll handler is shutting down") + // Consume updates sent to node + case update, ok := <-m.ch: + m.tracef("received update from channel, ok: %t", ok) + if !ok { + m.tracef("update channel closed, streaming session is likely being replaced") + return + } - return + if err := m.writeMap(update); err != nil { + m.errf(err, "cannot write update to client") + return + } + + m.tracef("update sent") + m.resetKeepAlive() + + case <-m.keepAliveTicker.C: + if err := m.writeMap(&keepAlive); err != nil { + m.errf(err, "cannot write keep alive") + return + } + + if debugHighCardinalityMetrics { + mapResponseLastSentSeconds.WithLabelValues("keepalive", m.node.ID.String()).Set(float64(time.Now().Unix())) + } + mapResponseSent.WithLabelValues("ok", "keepalive").Inc() + m.resetKeepAlive() } } } -// updateNodeOnlineStatus records the last seen status of a node and notifies peers -// about change in their online/offline status. -// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged. -func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) { - now := time.Now() - - node.LastSeen = &now - - statusUpdate := types.StateUpdate{ - Type: types.StatePeerChangedPatch, - ChangePatches: []*tailcfg.PeerChange{ - { - NodeID: tailcfg.NodeID(node.ID), - Online: &online, - LastSeen: &now, - }, - }, - } - if statusUpdate.Valid() { - h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String()) - } - - err := h.db.UpdateLastSeen(node) +// writeMap writes the map response to the client. +// It handles compression if requested and any headers that need to be set. +// It also handles flushing the response if the ResponseWriter +// implements http.Flusher. +func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error { + jsonBody, err := json.Marshal(msg) if err != nil { - log.Error().Err(err).Msg("Cannot update node LastSeen") - - return + return fmt.Errorf("marshalling map response: %w", err) + } + + if m.req.Compress == util.ZstdCompression { + jsonBody = zstdframe.AppendEncode(nil, jsonBody, zstdframe.FastestCompression) + } + + data := make([]byte, reservedResponseHeaderSize) + //nolint:gosec // G115: JSON response size will not exceed uint32 max + binary.LittleEndian.PutUint32(data, uint32(len(jsonBody))) + data = append(data, jsonBody...) + + startWrite := time.Now() + + _, err = m.w.Write(data) + if err != nil { + return err + } + + if m.isStreaming() { + if f, ok := m.w.(http.Flusher); ok { + f.Flush() + } else { + m.errf(nil, "ResponseWriter does not implement http.Flusher, cannot flush") + } } -} -func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, node, name string) { log.Trace(). - Str("handler", "PollNetMap"). - Str("node", node). - Str("channel", "Done"). - Msg(fmt.Sprintf("Closing %s channel", name)) + Caller(). + Str("node.name", m.node.Hostname). + Uint64("node.id", m.node.ID.Uint64()). + Str("chan", fmt.Sprintf("%p", m.ch)). + TimeDiff("timeSpent", time.Now(), startWrite). + Str("machine.key", m.node.MachineKey.String()). + Bool("keepalive", msg.KeepAlive). 
+ Msgf("finished writing mapresp to node chan(%p)", m.ch) - close(channel) + return nil } -func (h *Headscale) handleLiteRequest( - writer http.ResponseWriter, - node *types.Node, - mapRequest tailcfg.MapRequest, -) { - logInfo, logErr := logPollFunc(mapRequest, node) - - mapp := mapper.NewMapper( - node, - types.Nodes{}, - h.DERPMap, - h.cfg.BaseDomain, - h.cfg.DNSConfig, - h.cfg.LogTail.Enabled, - h.cfg.RandomizeClientPort, - ) - - logInfo("Client asked for a lite update, responding without peers") - - mapResp, err := mapp.LiteMapResponse(mapRequest, node, h.ACLPolicy) - if err != nil { - logErr(err, "Failed to create MapResponse") - http.Error(writer, "", http.StatusInternalServerError) - - return - } - - writer.Header().Set("Content-Type", "application/json; charset=utf-8") - writer.WriteHeader(http.StatusOK) - _, err = writer.Write(mapResp) - if err != nil { - logErr(err, "Failed to write response") - } +var keepAlive = tailcfg.MapResponse{ + KeepAlive: true, } -func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) { - trace := log.Trace().Str("node_id", change.NodeID.String()).Str("hostname", hostname) - - if change.Key != nil { - trace = trace.Str("node_key", change.Key.ShortString()) - } - - if change.DiscoKey != nil { - trace = trace.Str("disco_key", change.DiscoKey.ShortString()) - } - - if change.Online != nil { - trace = trace.Bool("online", *change.Online) - } - - if change.Endpoints != nil { - eps := make([]string, len(change.Endpoints)) - for idx, ep := range change.Endpoints { - eps[idx] = ep.String() - } - - trace = trace.Strs("endpoints", eps) - } - - if hostinfoChange { - trace = trace.Bool("hostinfo_changed", hostinfoChange) - } - - if change.DERPRegion != 0 { - trace = trace.Int("derp_region", change.DERPRegion) - } - - trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received") +// logf adds common mapSession context to a zerolog event. +func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event { + return event. + Bool("omitPeers", m.req.OmitPeers). + Bool("stream", m.req.Stream). + Uint64("node.id", m.node.ID.Uint64()). + Str("node.name", m.node.Hostname) +} + +//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) } + +//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) } + +//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf +func (m *mapSession) errf(err error, msg string, a ...any) { + m.logf(log.Error().Caller()).Err(err).Msgf(msg, a...) } diff --git a/hscontrol/poll_noise.go b/hscontrol/poll_noise.go deleted file mode 100644 index ee1b67f9..00000000 --- a/hscontrol/poll_noise.go +++ /dev/null @@ -1,96 +0,0 @@ -package hscontrol - -import ( - "encoding/json" - "errors" - "io" - "net/http" - - "github.com/rs/zerolog/log" - "gorm.io/gorm" - "tailscale.com/tailcfg" - "tailscale.com/types/key" -) - -const ( - MinimumCapVersion tailcfg.CapabilityVersion = 36 -) - -// NoisePollNetMapHandler takes care of /machine/:id/map using the Noise protocol -// -// This is the busiest endpoint, as it keeps the HTTP long poll that updates -// the clients when something in the network changes. 
-// -// The clients POST stuff like HostInfo and their Endpoints here, but -// only after their first request (marked with the ReadOnly field). -// -// At this moment the updates are sent in a quite horrendous way, but they kinda work. -func (ns *noiseServer) NoisePollNetMapHandler( - writer http.ResponseWriter, - req *http.Request, -) { - log.Trace(). - Str("handler", "NoisePollNetMap"). - Msg("PollNetMapHandler called") - - log.Trace(). - Any("headers", req.Header). - Caller(). - Msg("Headers") - - body, _ := io.ReadAll(req.Body) - - mapRequest := tailcfg.MapRequest{} - if err := json.Unmarshal(body, &mapRequest); err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot parse MapRequest") - http.Error(writer, "Internal error", http.StatusInternalServerError) - - return - } - - // Reject unsupported versions - if mapRequest.Version < MinimumCapVersion { - log.Info(). - Caller(). - Int("min_version", int(MinimumCapVersion)). - Int("client_version", int(mapRequest.Version)). - Msg("unsupported client connected") - http.Error(writer, "Internal error", http.StatusBadRequest) - - return - } - - ns.nodeKey = mapRequest.NodeKey - - node, err := ns.headscale.db.GetNodeByAnyKey( - ns.conn.Peer(), - mapRequest.NodeKey, - key.NodePublic{}, - ) - if err != nil { - if errors.Is(err, gorm.ErrRecordNotFound) { - log.Warn(). - Str("handler", "NoisePollNetMap"). - Msgf("Ignoring request, cannot find node with key %s", mapRequest.NodeKey.String()) - http.Error(writer, "Internal error", http.StatusNotFound) - - return - } - log.Error(). - Str("handler", "NoisePollNetMap"). - Msgf("Failed to fetch node from the database with node key: %s", mapRequest.NodeKey.String()) - http.Error(writer, "Internal error", http.StatusInternalServerError) - - return - } - log.Debug(). - Str("handler", "NoisePollNetMap"). - Str("node", node.Hostname). - Int("cap_ver", int(mapRequest.Version)). - Msg("A node sending a MapRequest with Noise protocol") - - ns.headscale.handlePoll(writer, req.Context(), node, mapRequest) -} diff --git a/hscontrol/routes/primary.go b/hscontrol/routes/primary.go new file mode 100644 index 00000000..977dc7a9 --- /dev/null +++ b/hscontrol/routes/primary.go @@ -0,0 +1,307 @@ +package routes + +import ( + "fmt" + "net/netip" + "slices" + "sort" + "strings" + "sync" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + xmaps "golang.org/x/exp/maps" + "tailscale.com/net/tsaddr" + "tailscale.com/util/set" +) + +type PrimaryRoutes struct { + mu sync.Mutex + + // routes is a map of prefixes that are adverties and approved and available + // in the global headscale state. + routes map[types.NodeID]set.Set[netip.Prefix] + + // primaries is a map of prefixes to the node that is the primary for that prefix. + primaries map[netip.Prefix]types.NodeID + isPrimary map[types.NodeID]bool +} + +func New() *PrimaryRoutes { + return &PrimaryRoutes{ + routes: make(map[types.NodeID]set.Set[netip.Prefix]), + primaries: make(map[netip.Prefix]types.NodeID), + isPrimary: make(map[types.NodeID]bool), + } +} + +// updatePrimaryLocked recalculates the primary routes and updates the internal state. +// It returns true if the primary routes have changed. +// It is assumed that the caller holds the lock. +// The algorithm is as follows: +// 1. Reset the primaries map. +// 2. Iterate over the routes and count the number of times a prefix is advertised. +// 3. If a prefix is advertised by at least two nodes, it is a primary route. +// 4. 
If the primary routes have changed, update the internal state and return true. +// 5. Otherwise, return false. +func (pr *PrimaryRoutes) updatePrimaryLocked() bool { + log.Debug().Caller().Msg("updatePrimaryLocked starting") + + // reset the primaries map, as we are going to recalculate it. + allPrimaries := make(map[netip.Prefix][]types.NodeID) + pr.isPrimary = make(map[types.NodeID]bool) + changed := false + + // sort the node ids so we can iterate over them in a deterministic order. + // this is important so the same node is chosen two times in a row + // as the primary route. + ids := types.NodeIDs(xmaps.Keys(pr.routes)) + sort.Sort(ids) + + // Create a map of prefixes to nodes that serve them so we + // can determine the primary route for each prefix. + for _, id := range ids { + routes := pr.routes[id] + for route := range routes { + if _, ok := allPrimaries[route]; !ok { + allPrimaries[route] = []types.NodeID{id} + } else { + allPrimaries[route] = append(allPrimaries[route], id) + } + } + } + + // Go through all prefixes and determine the primary route for each. + // If the number of routes is below the minimum, remove the primary. + // If the current primary is still available, continue. + // If the current primary is not available, select a new one. + for prefix, nodes := range allPrimaries { + log.Debug(). + Caller(). + Str("prefix", prefix.String()). + Uints64("availableNodes", func() []uint64 { + ids := make([]uint64, len(nodes)) + for i, id := range nodes { + ids[i] = id.Uint64() + } + + return ids + }()). + Msg("Processing prefix for primary route selection") + + if node, ok := pr.primaries[prefix]; ok { + // If the current primary is still available, continue. + if slices.Contains(nodes, node) { + log.Debug(). + Caller(). + Str("prefix", prefix.String()). + Uint64("currentPrimary", node.Uint64()). + Msg("Current primary still available, keeping it") + + continue + } else { + log.Debug(). + Caller(). + Str("prefix", prefix.String()). + Uint64("oldPrimary", node.Uint64()). + Msg("Current primary no longer available") + } + } + if len(nodes) >= 1 { + pr.primaries[prefix] = nodes[0] + changed = true + log.Debug(). + Caller(). + Str("prefix", prefix.String()). + Uint64("newPrimary", nodes[0].Uint64()). + Msg("Selected new primary for prefix") + } + } + + // Clean up any remaining primaries that are no longer valid. + for prefix := range pr.primaries { + if _, ok := allPrimaries[prefix]; !ok { + log.Debug(). + Caller(). + Str("prefix", prefix.String()). + Msg("Cleaning up primary route that no longer has available nodes") + delete(pr.primaries, prefix) + changed = true + } + } + + // Populate the quick lookup index for primary routes + for _, nodeID := range pr.primaries { + pr.isPrimary[nodeID] = true + } + + log.Debug(). + Caller(). + Bool("changed", changed). + Str("finalState", pr.stringLocked()). + Msg("updatePrimaryLocked completed") + + return changed +} + +// SetRoutes sets the routes for a given Node ID and recalculates the primary routes +// of the headscale. +// It returns true if there was a change in primary routes. +// All exit routes are ignored as they are not used in primary route context. +func (pr *PrimaryRoutes) SetRoutes(node types.NodeID, prefixes ...netip.Prefix) bool { + pr.mu.Lock() + defer pr.mu.Unlock() + + log.Debug(). + Caller(). + Uint64("node.id", node.Uint64()). + Strs("prefixes", util.PrefixesToString(prefixes)). + Msg("PrimaryRoutes.SetRoutes called") + + // If no routes are being set, remove the node from the routes map. 
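+ // Calling SetRoutes with no prefixes is therefore the deregistration path:
+ // the node is dropped from the route map and primaries are re-elected below.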
+ if len(prefixes) == 0 { + wasPresent := false + if _, ok := pr.routes[node]; ok { + delete(pr.routes, node) + wasPresent = true + log.Debug(). + Caller(). + Uint64("node.id", node.Uint64()). + Msg("Removed node from primary routes (no prefixes)") + } + changed := pr.updatePrimaryLocked() + log.Debug(). + Caller(). + Uint64("node.id", node.Uint64()). + Bool("wasPresent", wasPresent). + Bool("changed", changed). + Str("newState", pr.stringLocked()). + Msg("SetRoutes completed (remove)") + + return changed + } + + rs := make(set.Set[netip.Prefix], len(prefixes)) + for _, prefix := range prefixes { + if !tsaddr.IsExitRoute(prefix) { + rs.Add(prefix) + } + } + + if rs.Len() != 0 { + pr.routes[node] = rs + log.Debug(). + Caller(). + Uint64("node.id", node.Uint64()). + Strs("routes", util.PrefixesToString(rs.Slice())). + Msg("Updated node routes in primary route manager") + } else { + delete(pr.routes, node) + log.Debug(). + Caller(). + Uint64("node.id", node.Uint64()). + Msg("Removed node from primary routes (only exit routes)") + } + + changed := pr.updatePrimaryLocked() + log.Debug(). + Caller(). + Uint64("node.id", node.Uint64()). + Bool("changed", changed). + Str("newState", pr.stringLocked()). + Msg("SetRoutes completed (update)") + + return changed +} + +func (pr *PrimaryRoutes) PrimaryRoutes(id types.NodeID) []netip.Prefix { + if pr == nil { + return nil + } + + pr.mu.Lock() + defer pr.mu.Unlock() + + // Short circuit if the node is not a primary for any route. + if _, ok := pr.isPrimary[id]; !ok { + return nil + } + + var routes []netip.Prefix + + for prefix, node := range pr.primaries { + if node == id { + routes = append(routes, prefix) + } + } + + tsaddr.SortPrefixes(routes) + + return routes +} + +func (pr *PrimaryRoutes) String() string { + pr.mu.Lock() + defer pr.mu.Unlock() + + return pr.stringLocked() +} + +func (pr *PrimaryRoutes) stringLocked() string { + var sb strings.Builder + + fmt.Fprintln(&sb, "Available routes:") + + ids := types.NodeIDs(xmaps.Keys(pr.routes)) + sort.Sort(ids) + for _, id := range ids { + prefixes := pr.routes[id] + fmt.Fprintf(&sb, "\nNode %d: %s", id, strings.Join(util.PrefixesToString(prefixes.Slice()), ", ")) + } + + fmt.Fprintln(&sb, "\n\nCurrent primary routes:") + for route, nodeID := range pr.primaries { + fmt.Fprintf(&sb, "\nRoute %s: %d", route, nodeID) + } + + return sb.String() +} + +// DebugRoutes represents the primary routes state in a structured format for JSON serialization. +type DebugRoutes struct { + // AvailableRoutes maps node IDs to their advertised routes + // In the context of primary routes, this represents the routes that are available + // for each node. A route will only be available if it is advertised by the node + // AND approved. + // Only routes by nodes currently connected to the headscale server are included. + AvailableRoutes map[types.NodeID][]netip.Prefix `json:"available_routes"` + + // PrimaryRoutes maps route prefixes to the primary node serving them + PrimaryRoutes map[string]types.NodeID `json:"primary_routes"` +} + +// DebugJSON returns a structured representation of the primary routes state suitable for JSON serialization. 
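+//
+// Illustrative shape of the serialized result (values are hypothetical):
+//
+//	{
+//	  "available_routes": {"1": ["192.168.1.0/24"]},
+//	  "primary_routes": {"192.168.1.0/24": 1}
+//	}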
+func (pr *PrimaryRoutes) DebugJSON() DebugRoutes { + pr.mu.Lock() + defer pr.mu.Unlock() + + debug := DebugRoutes{ + AvailableRoutes: make(map[types.NodeID][]netip.Prefix), + PrimaryRoutes: make(map[string]types.NodeID), + } + + // Populate available routes + for nodeID, routes := range pr.routes { + prefixes := routes.Slice() + tsaddr.SortPrefixes(prefixes) + debug.AvailableRoutes[nodeID] = prefixes + } + + // Populate primary routes + for prefix, nodeID := range pr.primaries { + debug.PrimaryRoutes[prefix.String()] = nodeID + } + + return debug +} diff --git a/hscontrol/routes/primary_test.go b/hscontrol/routes/primary_test.go new file mode 100644 index 00000000..7a9767b2 --- /dev/null +++ b/hscontrol/routes/primary_test.go @@ -0,0 +1,468 @@ +package routes + +import ( + "net/netip" + "sync" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "tailscale.com/util/set" +) + +// mp is a helper function that wraps netip.MustParsePrefix. +func mp(prefix string) netip.Prefix { + return netip.MustParsePrefix(prefix) +} + +func TestPrimaryRoutes(t *testing.T) { + tests := []struct { + name string + operations func(pr *PrimaryRoutes) bool + expectedRoutes map[types.NodeID]set.Set[netip.Prefix] + expectedPrimaries map[netip.Prefix]types.NodeID + expectedIsPrimary map[types.NodeID]bool + expectedChange bool + + // primaries is a map of prefixes to the node that is the primary for that prefix. + primaries map[netip.Prefix]types.NodeID + isPrimary map[types.NodeID]bool + }{ + { + name: "single-node-registers-single-route", + operations: func(pr *PrimaryRoutes) bool { + return pr.SetRoutes(1, mp("192.168.1.0/24")) + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: true, + }, + { + name: "multiple-nodes-register-different-routes", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) + return pr.SetRoutes(2, mp("192.168.2.0/24")) + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.2.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + mp("192.168.2.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + 2: true, + }, + expectedChange: true, + }, + { + name: "multiple-nodes-register-overlapping-routes", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) // true + return pr.SetRoutes(2, mp("192.168.1.0/24")) // false + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: false, + }, + { + name: "node-deregisters-a-route", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) + return pr.SetRoutes(1) // Deregister by setting no routes + }, + expectedRoutes: nil, + expectedPrimaries: nil, + expectedIsPrimary: nil, + expectedChange: true, + }, + { + name: "node-deregisters-one-of-multiple-routes", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, 
mp("192.168.1.0/24"), mp("192.168.2.0/24")) + return pr.SetRoutes(1, mp("192.168.2.0/24")) // Deregister one route by setting the remaining route + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.2.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.2.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: true, + }, + { + name: "node-registers-and-deregisters-routes-in-sequence", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) + pr.SetRoutes(2, mp("192.168.2.0/24")) + pr.SetRoutes(1) // Deregister by setting no routes + return pr.SetRoutes(1, mp("192.168.3.0/24")) + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.3.0/24"): {}, + }, + 2: { + mp("192.168.2.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.2.0/24"): 2, + mp("192.168.3.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + 2: true, + }, + expectedChange: true, + }, + { + name: "multiple-nodes-register-same-route", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true + return pr.SetRoutes(3, mp("192.168.1.0/24")) // false + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: false, + }, + { + name: "register-multiple-routes-shift-primary-check-primary", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary + pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary + return pr.SetRoutes(1) // true, 2 primary + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 2: true, + }, + expectedChange: true, + }, + { + name: "primary-route-map-is-cleared-up-no-primary", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary + pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary + pr.SetRoutes(1) // true, 2 primary + + return pr.SetRoutes(2) // true, no primary + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 3, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 3: true, + }, + expectedChange: true, + }, + { + name: "primary-route-map-is-cleared-up-all-no-primary", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary + pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary + pr.SetRoutes(1) // true, 2 primary + pr.SetRoutes(2) // true, no primary + + return pr.SetRoutes(3) // false, no primary + }, + expectedChange: true, + }, + { + name: "primary-route-map-is-cleared-up", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + 
pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary + pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary + pr.SetRoutes(1) // true, 2 primary + + return pr.SetRoutes(2) // true, no primary + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 3, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 3: true, + }, + expectedChange: true, + }, + { + name: "primary-route-no-flake", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary + pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary + pr.SetRoutes(1) // true, 2 primary + + return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 2: true, + }, + expectedChange: false, + }, + { + name: "primary-route-no-flake-check-old-primary", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary + pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary + pr.SetRoutes(1) // true, 2 primary + + return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 2 primary + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 2: true, + }, + expectedChange: false, + }, + { + name: "primary-route-no-flake-full-integration", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("192.168.1.0/24")) // false + pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 1 primary + pr.SetRoutes(3, mp("192.168.1.0/24")) // false, 1 primary + pr.SetRoutes(1) // true, 2 primary + pr.SetRoutes(2) // true, 3 primary + pr.SetRoutes(1, mp("192.168.1.0/24")) // true, 3 primary + pr.SetRoutes(2, mp("192.168.1.0/24")) // true, 3 primary + pr.SetRoutes(1) // true, 3 primary + + return pr.SetRoutes(1, mp("192.168.1.0/24")) // false, 3 primary + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + 3: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 3, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 3: true, + }, + expectedChange: false, + }, + { + name: "multiple-nodes-register-same-route-and-exit", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("0.0.0.0/0"), mp("192.168.1.0/24")) + return pr.SetRoutes(2, mp("192.168.1.0/24")) + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.1.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: false, + }, + { + name: "deregister-non-existent-route", + operations: func(pr *PrimaryRoutes) bool { + return pr.SetRoutes(1) // Deregister by setting no routes + }, + 
expectedRoutes: nil, + expectedChange: false, + }, + { + name: "register-empty-prefix-list", + operations: func(pr *PrimaryRoutes) bool { + return pr.SetRoutes(1) + }, + expectedRoutes: nil, + expectedChange: false, + }, + { + name: "exit-nodes", + operations: func(pr *PrimaryRoutes) bool { + pr.SetRoutes(1, mp("10.0.0.0/16"), mp("0.0.0.0/0"), mp("::/0")) + pr.SetRoutes(3, mp("0.0.0.0/0"), mp("::/0")) + return pr.SetRoutes(2, mp("0.0.0.0/0"), mp("::/0")) + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("10.0.0.0/16"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("10.0.0.0/16"): 1, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + }, + expectedChange: false, + }, + { + name: "concurrent-access", + operations: func(pr *PrimaryRoutes) bool { + var wg sync.WaitGroup + wg.Add(2) + var change1, change2 bool + go func() { + defer wg.Done() + change1 = pr.SetRoutes(1, mp("192.168.1.0/24")) + }() + go func() { + defer wg.Done() + change2 = pr.SetRoutes(2, mp("192.168.2.0/24")) + }() + wg.Wait() + + return change1 || change2 + }, + expectedRoutes: map[types.NodeID]set.Set[netip.Prefix]{ + 1: { + mp("192.168.1.0/24"): {}, + }, + 2: { + mp("192.168.2.0/24"): {}, + }, + }, + expectedPrimaries: map[netip.Prefix]types.NodeID{ + mp("192.168.1.0/24"): 1, + mp("192.168.2.0/24"): 2, + }, + expectedIsPrimary: map[types.NodeID]bool{ + 1: true, + 2: true, + }, + expectedChange: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pr := New() + change := tt.operations(pr) + if change != tt.expectedChange { + t.Errorf("change = %v, want %v", change, tt.expectedChange) + } + comps := append(util.Comparers, cmpopts.EquateEmpty()) + if diff := cmp.Diff(tt.expectedRoutes, pr.routes, comps...); diff != "" { + t.Errorf("routes mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tt.expectedPrimaries, pr.primaries, comps...); diff != "" { + t.Errorf("primaries mismatch (-want +got):\n%s", diff) + } + if diff := cmp.Diff(tt.expectedIsPrimary, pr.isPrimary, comps...); diff != "" { + t.Errorf("isPrimary mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go new file mode 100644 index 00000000..3ed1d79f --- /dev/null +++ b/hscontrol/state/debug.go @@ -0,0 +1,376 @@ +package state + +import ( + "fmt" + "strings" + "time" + + hsdb "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/routes" + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" +) + +// DebugOverviewInfo represents the state overview information in a structured format. +type DebugOverviewInfo struct { + Nodes struct { + Total int `json:"total"` + Online int `json:"online"` + Expired int `json:"expired"` + Ephemeral int `json:"ephemeral"` + } `json:"nodes"` + Users map[string]int `json:"users"` // username -> node count + TotalUsers int `json:"total_users"` + Policy struct { + Mode string `json:"mode"` + Path string `json:"path,omitempty"` + } `json:"policy"` + DERP struct { + Configured bool `json:"configured"` + Regions int `json:"regions"` + } `json:"derp"` + PrimaryRoutes int `json:"primary_routes"` +} + +// DebugDERPInfo represents DERP map information in a structured format. +type DebugDERPInfo struct { + Configured bool `json:"configured"` + TotalRegions int `json:"total_regions"` + Regions map[int]*DebugDERPRegion `json:"regions,omitempty"` +} + +// DebugDERPRegion represents a single DERP region. 
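+// The fields mirror the subset of the tailcfg DERP region and node data that
+// DebugDERPJSON copies out for display.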
+type DebugDERPRegion struct { + RegionID int `json:"region_id"` + RegionName string `json:"region_name"` + Nodes []*DebugDERPNode `json:"nodes"` +} + +// DebugDERPNode represents a single DERP node. +type DebugDERPNode struct { + Name string `json:"name"` + HostName string `json:"hostname"` + DERPPort int `json:"derp_port"` + STUNPort int `json:"stun_port,omitempty"` +} + +// DebugStringInfo wraps a debug string for JSON serialization. +type DebugStringInfo struct { + Content string `json:"content"` +} + +// DebugOverview returns a comprehensive overview of the current state for debugging. +func (s *State) DebugOverview() string { + allNodes := s.nodeStore.ListNodes() + users, _ := s.ListAllUsers() + + var sb strings.Builder + + sb.WriteString("=== Headscale State Overview ===\n\n") + + // Node statistics + sb.WriteString(fmt.Sprintf("Nodes: %d total\n", allNodes.Len())) + + userNodeCounts := make(map[string]int) + onlineCount := 0 + expiredCount := 0 + ephemeralCount := 0 + + now := time.Now() + for _, node := range allNodes.All() { + if node.Valid() { + userName := node.Owner().Name() + userNodeCounts[userName]++ + + if node.IsOnline().Valid() && node.IsOnline().Get() { + onlineCount++ + } + + if node.Expiry().Valid() && node.Expiry().Get().Before(now) { + expiredCount++ + } + + if node.AuthKey().Valid() && node.AuthKey().Ephemeral() { + ephemeralCount++ + } + } + } + + sb.WriteString(fmt.Sprintf(" - Online: %d\n", onlineCount)) + sb.WriteString(fmt.Sprintf(" - Expired: %d\n", expiredCount)) + sb.WriteString(fmt.Sprintf(" - Ephemeral: %d\n", ephemeralCount)) + sb.WriteString("\n") + + // User statistics + sb.WriteString(fmt.Sprintf("Users: %d total\n", len(users))) + for userName, nodeCount := range userNodeCounts { + sb.WriteString(fmt.Sprintf(" - %s: %d nodes\n", userName, nodeCount)) + } + sb.WriteString("\n") + + // Policy information + sb.WriteString("Policy:\n") + sb.WriteString(fmt.Sprintf(" - Mode: %s\n", s.cfg.Policy.Mode)) + if s.cfg.Policy.Mode == types.PolicyModeFile { + sb.WriteString(fmt.Sprintf(" - Path: %s\n", s.cfg.Policy.Path)) + } + sb.WriteString("\n") + + // DERP information + derpMap := s.derpMap.Load() + if derpMap != nil { + sb.WriteString(fmt.Sprintf("DERP: %d regions configured\n", len(derpMap.Regions))) + } else { + sb.WriteString("DERP: not configured\n") + } + sb.WriteString("\n") + + // Route information + routeCount := len(strings.Split(strings.TrimSpace(s.primaryRoutes.String()), "\n")) + if s.primaryRoutes.String() == "" { + routeCount = 0 + } + sb.WriteString(fmt.Sprintf("Primary Routes: %d active\n", routeCount)) + sb.WriteString("\n") + + // Registration cache + sb.WriteString("Registration Cache: active\n") + sb.WriteString("\n") + + return sb.String() +} + +// DebugNodeStore returns debug information about the NodeStore. +func (s *State) DebugNodeStore() string { + return s.nodeStore.DebugString() +} + +// DebugDERPMap returns debug information about the DERP map configuration. 
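+//
+// The result is a plain-text dump along these lines (region values are
+// illustrative):
+//
+//	=== DERP Map Configuration ===
+//	Total Regions: 1
+//	Region 999: example
+//	 - Nodes: 1
+//	   - derp1 (derp.example.com:443)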
+func (s *State) DebugDERPMap() string { + derpMap := s.derpMap.Load() + if derpMap == nil { + return "DERP Map: not configured\n" + } + + var sb strings.Builder + + sb.WriteString("=== DERP Map Configuration ===\n\n") + + sb.WriteString(fmt.Sprintf("Total Regions: %d\n\n", len(derpMap.Regions))) + + for regionID, region := range derpMap.Regions { + sb.WriteString(fmt.Sprintf("Region %d: %s\n", regionID, region.RegionName)) + sb.WriteString(fmt.Sprintf(" - Nodes: %d\n", len(region.Nodes))) + + for _, node := range region.Nodes { + sb.WriteString(fmt.Sprintf(" - %s (%s:%d)\n", + node.Name, node.HostName, node.DERPPort)) + if node.STUNPort != 0 { + sb.WriteString(fmt.Sprintf(" STUN: %d\n", node.STUNPort)) + } + } + sb.WriteString("\n") + } + + return sb.String() +} + +// DebugSSHPolicies returns debug information about SSH policies for all nodes. +func (s *State) DebugSSHPolicies() map[string]*tailcfg.SSHPolicy { + nodes := s.nodeStore.ListNodes() + + sshPolicies := make(map[string]*tailcfg.SSHPolicy) + + for _, node := range nodes.All() { + if !node.Valid() { + continue + } + + pol, err := s.SSHPolicy(node) + if err != nil { + // Store the error information + continue + } + + key := fmt.Sprintf("id:%d hostname:%s givenname:%s", + node.ID(), node.Hostname(), node.GivenName()) + sshPolicies[key] = pol + } + + return sshPolicies +} + +// DebugRegistrationCache returns debug information about the registration cache. +func (s *State) DebugRegistrationCache() map[string]any { + // The cache doesn't expose internal statistics, so we provide basic info + result := map[string]any{ + "type": "zcache", + "expiration": registerCacheExpiration.String(), + "cleanup": registerCacheCleanup.String(), + "status": "active", + } + + return result +} + +// DebugConfig returns debug information about the current configuration. +func (s *State) DebugConfig() *types.Config { + return s.cfg +} + +// DebugPolicy returns the current policy data as a string. +func (s *State) DebugPolicy() (string, error) { + switch s.cfg.Policy.Mode { + case types.PolicyModeDB: + p, err := s.GetPolicy() + if err != nil { + return "", err + } + + return p.Data, nil + case types.PolicyModeFile: + pol, err := hsdb.PolicyBytes(s.db.DB, s.cfg) + if err != nil { + return "", err + } + + return string(pol), nil + default: + return "", fmt.Errorf("unsupported policy mode: %s", s.cfg.Policy.Mode) + } +} + +// DebugFilter returns the current filter rules and matchers. +func (s *State) DebugFilter() ([]tailcfg.FilterRule, error) { + filter, _ := s.Filter() + return filter, nil +} + +// DebugRoutes returns the current primary routes information as a structured object. +func (s *State) DebugRoutes() routes.DebugRoutes { + return s.primaryRoutes.DebugJSON() +} + +// DebugRoutesString returns the current primary routes information as a string. +func (s *State) DebugRoutesString() string { + return s.PrimaryRoutesString() +} + +// DebugPolicyManager returns the policy manager debug string. +func (s *State) DebugPolicyManager() string { + return s.PolicyDebugString() +} + +// PolicyDebugString returns a debug representation of the current policy. +func (s *State) PolicyDebugString() string { + return s.polMan.DebugString() +} + +// DebugOverviewJSON returns a structured overview of the current state for debugging. 
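+//
+// Rough shape of the serialized DebugOverviewInfo (all values hypothetical,
+// including the policy mode and path strings):
+//
+//	{
+//	  "nodes": {"total": 3, "online": 2, "expired": 0, "ephemeral": 1},
+//	  "users": {"alice": 2, "bob": 1},
+//	  "total_users": 2,
+//	  "policy": {"mode": "file", "path": "/etc/headscale/policy.hujson"},
+//	  "derp": {"configured": true, "regions": 1},
+//	  "primary_routes": 0
+//	}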
+func (s *State) DebugOverviewJSON() DebugOverviewInfo { + allNodes := s.nodeStore.ListNodes() + users, _ := s.ListAllUsers() + + info := DebugOverviewInfo{ + Users: make(map[string]int), + TotalUsers: len(users), + } + + // Node statistics + info.Nodes.Total = allNodes.Len() + now := time.Now() + + for _, node := range allNodes.All() { + if node.Valid() { + userName := node.Owner().Name() + info.Users[userName]++ + + if node.IsOnline().Valid() && node.IsOnline().Get() { + info.Nodes.Online++ + } + + if node.Expiry().Valid() && node.Expiry().Get().Before(now) { + info.Nodes.Expired++ + } + + if node.AuthKey().Valid() && node.AuthKey().Ephemeral() { + info.Nodes.Ephemeral++ + } + } + } + + // Policy information + info.Policy.Mode = string(s.cfg.Policy.Mode) + if s.cfg.Policy.Mode == types.PolicyModeFile { + info.Policy.Path = s.cfg.Policy.Path + } + + derpMap := s.derpMap.Load() + if derpMap != nil { + info.DERP.Configured = true + info.DERP.Regions = len(derpMap.Regions) + } else { + info.DERP.Configured = false + info.DERP.Regions = 0 + } + + // Route information + routeCount := len(strings.Split(strings.TrimSpace(s.primaryRoutes.String()), "\n")) + if s.primaryRoutes.String() == "" { + routeCount = 0 + } + info.PrimaryRoutes = routeCount + + return info +} + +// DebugDERPJSON returns structured debug information about the DERP map configuration. +func (s *State) DebugDERPJSON() DebugDERPInfo { + derpMap := s.derpMap.Load() + + info := DebugDERPInfo{ + Configured: derpMap != nil, + Regions: make(map[int]*DebugDERPRegion), + } + + if derpMap == nil { + return info + } + + info.TotalRegions = len(derpMap.Regions) + + for regionID, region := range derpMap.Regions { + debugRegion := &DebugDERPRegion{ + RegionID: regionID, + RegionName: region.RegionName, + Nodes: make([]*DebugDERPNode, 0, len(region.Nodes)), + } + + for _, node := range region.Nodes { + debugNode := &DebugDERPNode{ + Name: node.Name, + HostName: node.HostName, + DERPPort: node.DERPPort, + STUNPort: node.STUNPort, + } + debugRegion.Nodes = append(debugRegion.Nodes, debugNode) + } + + info.Regions[regionID] = debugRegion + } + + return info +} + +// DebugNodeStoreJSON returns the actual nodes map from the current NodeStore snapshot. +func (s *State) DebugNodeStoreJSON() map[types.NodeID]types.Node { + snapshot := s.nodeStore.data.Load() + return snapshot.nodesByID +} + +// DebugPolicyManagerJSON returns structured debug information about the policy manager. 
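+// The wrapped value is whatever the policy manager's DebugString() returns,
+// serialized via DebugStringInfo as {"content": "..."}.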
+func (s *State) DebugPolicyManagerJSON() DebugStringInfo { + return DebugStringInfo{ + Content: s.polMan.DebugString(), + } +} diff --git a/hscontrol/state/debug_test.go b/hscontrol/state/debug_test.go new file mode 100644 index 00000000..6fd528a8 --- /dev/null +++ b/hscontrol/state/debug_test.go @@ -0,0 +1,78 @@ +package state + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestNodeStoreDebugString(t *testing.T) { + tests := []struct { + name string + setupFn func() *NodeStore + contains []string + }{ + { + name: "empty nodestore", + setupFn: func() *NodeStore { + return NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + }, + contains: []string{ + "=== NodeStore Debug Information ===", + "Total Nodes: 0", + "Users with Nodes: 0", + "NodeKey Index: 0 entries", + }, + }, + { + name: "nodestore with data", + setupFn: func() *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 2, "user2", "node2") + + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + + _ = store.PutNode(node1) + _ = store.PutNode(node2) + + return store + }, + contains: []string{ + "Total Nodes: 2", + "Users with Nodes: 2", + "Peer Relationships:", + "NodeKey Index: 2 entries", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := tt.setupFn() + if store.writeQueue != nil { + defer store.Stop() + } + + debugStr := store.DebugString() + + for _, expected := range tt.contains { + assert.Contains(t, debugStr, expected, + "Debug string should contain: %s\nActual debug:\n%s", expected, debugStr) + } + }) + } +} + +func TestDebugRegistrationCache(t *testing.T) { + // Create a minimal NodeStore for testing debug methods + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + + debugStr := store.DebugString() + + // Should contain basic debug information + assert.Contains(t, debugStr, "=== NodeStore Debug Information ===") + assert.Contains(t, debugStr, "Total Nodes: 0") + assert.Contains(t, debugStr, "Users with Nodes: 0") + assert.Contains(t, debugStr, "NodeKey Index: 0 entries") +} diff --git a/hscontrol/state/endpoint_test.go b/hscontrol/state/endpoint_test.go new file mode 100644 index 00000000..b8905ab7 --- /dev/null +++ b/hscontrol/state/endpoint_test.go @@ -0,0 +1,113 @@ +package state + +import ( + "net/netip" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" +) + +// TestEndpointStorageInNodeStore verifies that endpoints sent in MapRequest via ApplyPeerChange +// are correctly stored in the NodeStore and can be retrieved for sending to peers. 
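+// The flow exercised below is: MapRequest.Endpoints -> PeerChangeFromMapRequest ->
+// NodeStore.UpdateNode(ApplyPeerChange) -> endpoints visible via GetNode and in
+// the peer views returned by ListPeers.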
+// This test reproduces the issue reported in https://github.com/juanfont/headscale/issues/2846 +func TestEndpointStorageInNodeStore(t *testing.T) { + // Create two test nodes + node1 := createTestNode(1, 1, "test-user", "node1") + node2 := createTestNode(2, 1, "test-user", "node2") + + // Create NodeStore with allow-all peers function + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + + store.Start() + defer store.Stop() + + // Add both nodes to NodeStore + store.PutNode(node1) + store.PutNode(node2) + + // Create a MapRequest with endpoints for node1 + endpoints := []netip.AddrPort{ + netip.MustParseAddrPort("192.168.1.1:41641"), + netip.MustParseAddrPort("10.0.0.1:41641"), + } + + mapReq := tailcfg.MapRequest{ + NodeKey: node1.NodeKey, + DiscoKey: node1.DiscoKey, + Endpoints: endpoints, + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "node1", + }, + } + + // Simulate what UpdateNodeFromMapRequest does: create PeerChange and apply it + peerChange := node1.PeerChangeFromMapRequest(mapReq) + + // Verify PeerChange has endpoints + require.NotNil(t, peerChange.Endpoints, "PeerChange should contain endpoints") + assert.Len(t, peerChange.Endpoints, len(endpoints), + "PeerChange should have same number of endpoints as MapRequest") + + // Apply the PeerChange via NodeStore.UpdateNode + updatedNode, ok := store.UpdateNode(node1.ID, func(n *types.Node) { + n.ApplyPeerChange(&peerChange) + }) + require.True(t, ok, "UpdateNode should succeed") + require.True(t, updatedNode.Valid(), "Updated node should be valid") + + // Verify endpoints are in the updated node view + storedEndpoints := updatedNode.Endpoints().AsSlice() + assert.Len(t, storedEndpoints, len(endpoints), + "NodeStore should have same number of endpoints as sent") + + if len(storedEndpoints) == len(endpoints) { + for i, ep := range endpoints { + assert.Equal(t, ep, storedEndpoints[i], + "Endpoint %d should match", i) + } + } + + // Verify we can retrieve the node again and endpoints are still there + retrievedNode, found := store.GetNode(node1.ID) + require.True(t, found, "node1 should exist in NodeStore") + + retrievedEndpoints := retrievedNode.Endpoints().AsSlice() + assert.Len(t, retrievedEndpoints, len(endpoints), + "Retrieved node should have same number of endpoints") + + // Verify that when we get node1 as a peer of node2, it has endpoints + // This is the critical part that was failing in the bug report + peers := store.ListPeers(node2.ID) + require.Positive(t, peers.Len(), "node2 should have at least one peer") + + // Find node1 in the peer list + var node1Peer types.NodeView + + foundPeer := false + + for _, peer := range peers.All() { + if peer.ID() == node1.ID { + node1Peer = peer + foundPeer = true + + break + } + } + + require.True(t, foundPeer, "node1 should be in node2's peer list") + + // Check that node1's endpoints are available in the peer view + peerEndpoints := node1Peer.Endpoints().AsSlice() + assert.Len(t, peerEndpoints, len(endpoints), + "Peer view should have same number of endpoints as sent") + + if len(peerEndpoints) == len(endpoints) { + for i, ep := range endpoints { + assert.Equal(t, ep, peerEndpoints[i], + "Peer endpoint %d should match", i) + } + } +} diff --git a/hscontrol/state/ephemeral_test.go b/hscontrol/state/ephemeral_test.go new file mode 100644 index 00000000..632af13c --- /dev/null +++ b/hscontrol/state/ephemeral_test.go @@ -0,0 +1,449 @@ +package state + +import ( + "net/netip" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/ptr" +) + +// TestEphemeralNodeDeleteWithConcurrentUpdate tests the race condition where UpdateNode and DeleteNode +// are called concurrently and may be batched together. This reproduces the issue where ephemeral nodes +// are not properly deleted during logout because UpdateNodeFromMapRequest returns a stale node view +// after the node has been deleted from the NodeStore. +func TestEphemeralNodeDeleteWithConcurrentUpdate(t *testing.T) { + // Create a simple test node + node := createTestNode(1, 1, "test-user", "test-node") + + // Create NodeStore + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + defer store.Stop() + + // Put the node in the store + resultNode := store.PutNode(node) + require.True(t, resultNode.Valid(), "initial PutNode should return valid node") + + // Verify node exists + retrievedNode, found := store.GetNode(node.ID) + require.True(t, found) + require.Equal(t, node.ID, retrievedNode.ID()) + + // Test scenario: UpdateNode is called, returns a node view from the batch, + // but in the same batch a DeleteNode removes the node. + // This simulates what happens when: + // 1. UpdateNodeFromMapRequest calls UpdateNode and gets back updatedNode + // 2. At the same time, handleLogout calls DeleteNode + // 3. They get batched together: [UPDATE, DELETE] + // 4. UPDATE modifies the node, DELETE removes it + // 5. UpdateNode returns a node view based on the state AFTER both operations + // 6. If DELETE came after UPDATE, the returned node should be invalid + + done := make(chan bool, 2) + var updatedNode types.NodeView + var updateOk bool + + // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) + go func() { + updatedNode, updateOk = store.UpdateNode(node.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + done <- true + }() + + // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) + go func() { + store.DeleteNode(node.ID) + done <- true + }() + + // Wait for both operations + <-done + <-done + + // Verify node is eventually deleted + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found = store.GetNode(node.ID) + assert.False(c, found, "node should be deleted from NodeStore") + }, 1*time.Second, 10*time.Millisecond, "waiting for node to be deleted") + + // If the update happened before delete in the batch, the returned node might be invalid + if updateOk { + t.Logf("UpdateNode returned ok=true, valid=%v", updatedNode.Valid()) + // This is the bug scenario - UpdateNode thinks it succeeded but node is gone + if updatedNode.Valid() { + t.Logf("WARNING: UpdateNode returned valid node but node was deleted - this indicates the race condition bug") + } + } else { + t.Logf("UpdateNode correctly returned ok=false (node deleted in same batch)") + } +} + +// TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch specifically tests that when +// UpdateNode and DeleteNode are in the same batch with DELETE after UPDATE, +// the UpdateNode should return an invalid node view. 
+func TestUpdateNodeReturnsInvalidWhenDeletedInSameBatch(t *testing.T) { + node := createTestNode(2, 1, "test-user", "test-node-2") + + // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together + store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() + defer store.Stop() + + // Put node in store + _ = store.PutNode(node) + + // Queue UpdateNode and DeleteNode - with batch size of 2, they will batch together + resultChan := make(chan struct { + node types.NodeView + ok bool + }) + + // Start UpdateNode in goroutine - it will queue and wait for batch + go func() { + node, ok := store.UpdateNode(node.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + resultChan <- struct { + node types.NodeView + ok bool + }{node, ok} + }() + + // Start DeleteNode in goroutine - it will queue and trigger batch processing + // Since batch size is 2, both operations will be processed together + go func() { + store.DeleteNode(node.ID) + }() + + // Get the result from UpdateNode + result := <-resultChan + + // Node should be deleted + _, found := store.GetNode(node.ID) + assert.False(t, found, "node should be deleted") + + // The critical check: what did UpdateNode return? + // After the commit c6b09289988f34398eb3157e31ba092eb8721a9f, + // UpdateNode returns the node state from the batch. + // If DELETE came after UPDATE in the batch, the node doesn't exist anymore, + // so UpdateNode should return (invalid, false) + t.Logf("UpdateNode returned: ok=%v, valid=%v", result.ok, result.node.Valid()) + + // This is the expected behavior - if node was deleted in same batch, + // UpdateNode should return invalid node + if result.ok && result.node.Valid() { + t.Error("BUG: UpdateNode returned valid node even though it was deleted in same batch") + } +} + +// TestPersistNodeToDBPreventsRaceCondition tests that persistNodeToDB correctly handles +// the race condition where a node is deleted after UpdateNode returns but before +// persistNodeToDB is called. This reproduces the ephemeral node deletion bug. 
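+//
+// The guard relied on here is, in sketch form (the real persistNodeToDB may
+// differ in its exact wording and error message):
+//
+//	if _, exists := s.nodeStore.GetNode(node.ID()); !exists {
+//	    return fmt.Errorf("node %d was deleted, refusing to persist", node.ID())
+//	}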
+func TestPersistNodeToDBPreventsRaceCondition(t *testing.T) { + node := createTestNode(3, 1, "test-user", "test-node-3") + + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + defer store.Stop() + + // Put node in store + _ = store.PutNode(node) + + // Simulate UpdateNode being called + updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + require.True(t, ok, "UpdateNode should succeed") + require.True(t, updatedNode.Valid(), "UpdateNode should return valid node") + + // Now delete the node (simulating ephemeral logout happening concurrently) + store.DeleteNode(node.ID) + + // Verify node is eventually deleted + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := store.GetNode(node.ID) + assert.False(c, found, "node should be deleted") + }, 1*time.Second, 10*time.Millisecond, "waiting for node to be deleted") + + // Now try to use the updatedNode from before the deletion + // In the old code, this would re-insert the node into the database + // With our fix, GetNode check in persistNodeToDB should prevent this + + // Simulate what persistNodeToDB does - check if node still exists + _, exists := store.GetNode(updatedNode.ID()) + if !exists { + t.Log("SUCCESS: persistNodeToDB check would prevent re-insertion of deleted node") + } else { + t.Error("BUG: Node still exists in NodeStore after deletion") + } + + // The key assertion: after deletion, attempting to persist the old updatedNode + // should fail because the node no longer exists in NodeStore + assert.False(t, exists, "persistNodeToDB should detect node was deleted and refuse to persist") +} + +// TestEphemeralNodeLogoutRaceCondition tests the specific race condition that occurs +// when an ephemeral node logs out. This reproduces the bug where: +// 1. UpdateNodeFromMapRequest calls UpdateNode and receives a node view +// 2. Concurrently, handleLogout is called for the ephemeral node and calls DeleteNode +// 3. UpdateNode and DeleteNode get batched together +// 4. If UpdateNode's result is used to call persistNodeToDB after the deletion, +// the node could be re-inserted into the database even though it was deleted +func TestEphemeralNodeLogoutRaceCondition(t *testing.T) { + ephemeralNode := createTestNode(4, 1, "test-user", "ephemeral-node") + ephemeralNode.AuthKey = &types.PreAuthKey{ + ID: 1, + Key: "test-key", + Ephemeral: true, + } + + // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together + store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() + defer store.Stop() + + // Put ephemeral node in store + _ = store.PutNode(ephemeralNode) + + // Simulate concurrent operations: + // 1. UpdateNode (from UpdateNodeFromMapRequest during polling) + // 2. 
DeleteNode (from handleLogout when client sends logout request) + + var updatedNode types.NodeView + var updateOk bool + done := make(chan bool, 2) + + // Goroutine 1: UpdateNode (simulates UpdateNodeFromMapRequest) + go func() { + updatedNode, updateOk = store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + done <- true + }() + + // Goroutine 2: DeleteNode (simulates handleLogout for ephemeral node) + go func() { + store.DeleteNode(ephemeralNode.ID) + done <- true + }() + + // Wait for both operations + <-done + <-done + + // Verify node is eventually deleted + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, found := store.GetNode(ephemeralNode.ID) + assert.False(c, found, "ephemeral node should be deleted from NodeStore") + }, 1*time.Second, 10*time.Millisecond, "waiting for ephemeral node to be deleted") + + // Critical assertion: if UpdateNode returned before DeleteNode completed, + // the updatedNode might be valid but the node is actually deleted. + // This is the bug - UpdateNodeFromMapRequest would get a valid node, + // then try to persist it, re-inserting the deleted ephemeral node. + if updateOk && updatedNode.Valid() { + t.Log("UpdateNode returned valid node, but node is deleted - this is the race condition") + + // In the real code, this would cause persistNodeToDB to be called with updatedNode + // The fix in persistNodeToDB checks if the node still exists: + _, stillExists := store.GetNode(updatedNode.ID()) + assert.False(t, stillExists, "persistNodeToDB should check NodeStore and find node deleted") + } else if !updateOk || !updatedNode.Valid() { + t.Log("UpdateNode correctly returned invalid/not-ok result (delete happened in same batch)") + } +} + +// TestUpdateNodeFromMapRequestEphemeralLogoutSequence tests the exact sequence +// that causes ephemeral node logout failures: +// 1. Client sends MapRequest with updated endpoint info +// 2. UpdateNodeFromMapRequest starts processing, calls UpdateNode +// 3. Client sends logout request (past expiry) +// 4. handleLogout calls DeleteNode for ephemeral node +// 5. UpdateNode and DeleteNode batch together +// 6. UpdateNode returns a valid node (from before delete in batch) +// 7. persistNodeToDB is called with the stale valid node +// 8. 
Node gets re-inserted into database instead of staying deleted +func TestUpdateNodeFromMapRequestEphemeralLogoutSequence(t *testing.T) { + ephemeralNode := createTestNode(5, 1, "test-user", "ephemeral-node-5") + ephemeralNode.AuthKey = &types.PreAuthKey{ + ID: 2, + Key: "test-key-2", + Ephemeral: true, + } + + // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together + // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together + store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() + defer store.Stop() + + // Put ephemeral node in store + _ = store.PutNode(ephemeralNode) + + // Step 1: UpdateNodeFromMapRequest calls UpdateNode + // (simulating client sending MapRequest with endpoint updates) + updateResult := make(chan struct { + node types.NodeView + ok bool + }) + + go func() { + node, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + endpoint := netip.MustParseAddrPort("10.0.0.1:41641") + n.Endpoints = []netip.AddrPort{endpoint} + }) + updateResult <- struct { + node types.NodeView + ok bool + }{node, ok} + }() + + // Step 2: Logout happens - handleLogout calls DeleteNode + // With batch size of 2, this will trigger batch processing with UpdateNode + go func() { + store.DeleteNode(ephemeralNode.ID) + }() + + // Step 3: Wait and verify node is eventually deleted + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, nodeExists := store.GetNode(ephemeralNode.ID) + assert.False(c, nodeExists, "ephemeral node must be deleted after logout") + }, 1*time.Second, 10*time.Millisecond, "waiting for ephemeral node to be deleted") + + // Step 4: Get the update result + result := <-updateResult + + // Simulate what happens if we try to persist the updatedNode + if result.ok && result.node.Valid() { + // This is the problematic path - UpdateNode returned a valid node + // but the node was deleted in the same batch + t.Log("UpdateNode returned valid node even though node was deleted") + + // The fix: persistNodeToDB must check NodeStore before persisting + _, checkExists := store.GetNode(result.node.ID()) + if checkExists { + t.Error("BUG: Node still exists in NodeStore after deletion - should be impossible") + } else { + t.Log("SUCCESS: persistNodeToDB would detect node is deleted and refuse to persist") + } + } else { + t.Log("UpdateNode correctly indicated node was deleted (returned invalid or not-ok)") + } + + // Final assertion: node must not exist + _, finalExists := store.GetNode(ephemeralNode.ID) + assert.False(t, finalExists, "ephemeral node must remain deleted") +} + +// TestUpdateNodeDeletedInSameBatchReturnsInvalid specifically tests that when +// UpdateNode and DeleteNode are batched together with DELETE after UPDATE, +// UpdateNode returns ok=false to indicate the node was deleted. 
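+//
+// In other words, the contract asserted below is roughly:
+//
+//	node, ok := store.UpdateNode(id, mutate)
+//	// ok == false and !node.Valid() when a DELETE for the same node landed
+//	// later in the same batch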
+func TestUpdateNodeDeletedInSameBatchReturnsInvalid(t *testing.T) { + node := createTestNode(6, 1, "test-user", "test-node-6") + + // Use batch size of 2 to guarantee UpdateNode and DeleteNode batch together + store := NewNodeStore(nil, allowAllPeersFunc, 2, TestBatchTimeout) + store.Start() + defer store.Stop() + + // Put node in store + _ = store.PutNode(node) + + // Queue UpdateNode and DeleteNode - with batch size of 2, they will batch together + updateDone := make(chan struct { + node types.NodeView + ok bool + }) + + go func() { + updatedNode, ok := store.UpdateNode(node.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + updateDone <- struct { + node types.NodeView + ok bool + }{updatedNode, ok} + }() + + // Queue DeleteNode - with batch size of 2, this triggers batch processing + go func() { + store.DeleteNode(node.ID) + }() + + // Get UpdateNode result + result := <-updateDone + + // Node should be deleted + _, exists := store.GetNode(node.ID) + assert.False(t, exists, "node should be deleted from store") + + // UpdateNode should indicate the node was deleted + // After c6b09289988f34398eb3157e31ba092eb8721a9f, when UPDATE and DELETE + // are in the same batch with DELETE after UPDATE, UpdateNode returns + // the state after the batch is applied - which means the node doesn't exist + assert.False(t, result.ok, "UpdateNode should return ok=false when node deleted in same batch") + assert.False(t, result.node.Valid(), "UpdateNode should return invalid node when node deleted in same batch") +} + +// TestPersistNodeToDBChecksNodeStoreBeforePersist verifies that persistNodeToDB +// checks if the node still exists in NodeStore before persisting to database. +// This prevents the race condition where: +// 1. UpdateNodeFromMapRequest calls UpdateNode and gets a valid node +// 2. Ephemeral node logout calls DeleteNode +// 3. UpdateNode and DeleteNode batch together +// 4. UpdateNode returns a valid node (from before delete in batch) +// 5. UpdateNodeFromMapRequest calls persistNodeToDB with the stale node +// 6. persistNodeToDB must detect the node is deleted and refuse to persist +func TestPersistNodeToDBChecksNodeStoreBeforePersist(t *testing.T) { + ephemeralNode := createTestNode(7, 1, "test-user", "ephemeral-node-7") + ephemeralNode.AuthKey = &types.PreAuthKey{ + ID: 3, + Key: "test-key-3", + Ephemeral: true, + } + + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + defer store.Stop() + + // Put node + _ = store.PutNode(ephemeralNode) + + // UpdateNode returns a node + updatedNode, ok := store.UpdateNode(ephemeralNode.ID, func(n *types.Node) { + n.LastSeen = ptr.To(time.Now()) + }) + require.True(t, ok, "UpdateNode should succeed") + require.True(t, updatedNode.Valid(), "updated node should be valid") + + // Delete the node + store.DeleteNode(ephemeralNode.ID) + + // Verify node is eventually deleted + require.EventuallyWithT(t, func(c *assert.CollectT) { + _, exists := store.GetNode(ephemeralNode.ID) + assert.False(c, exists, "node should be deleted from NodeStore") + }, 1*time.Second, 10*time.Millisecond, "waiting for node to be deleted") + + // 4. 
Simulate what persistNodeToDB does - check if node still exists + // The fix in persistNodeToDB checks NodeStore before persisting: + // if !exists { return error } + // This prevents re-inserting the deleted node into the database + + // Verify the node from UpdateNode is valid but node is gone from store + assert.True(t, updatedNode.Valid(), "UpdateNode returned a valid node view") + _, stillExists := store.GetNode(updatedNode.ID()) + assert.False(t, stillExists, "but node should be deleted from NodeStore") + + // This is the critical test: persistNodeToDB must check NodeStore + // and refuse to persist if the node doesn't exist anymore + // The actual persistNodeToDB implementation does: + // _, exists := s.nodeStore.GetNode(node.ID()) + // if !exists { return error } +} diff --git a/hscontrol/state/maprequest.go b/hscontrol/state/maprequest.go new file mode 100644 index 00000000..e7dfc11c --- /dev/null +++ b/hscontrol/state/maprequest.go @@ -0,0 +1,50 @@ +// Package state provides pure functions for processing MapRequest data. +// These functions are extracted from UpdateNodeFromMapRequest to improve +// testability and maintainability. + +package state + +import ( + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" + "tailscale.com/tailcfg" +) + +// netInfoFromMapRequest determines the correct NetInfo to use. +// Returns the NetInfo that should be used for this request. +func netInfoFromMapRequest( + nodeID types.NodeID, + currentHostinfo *tailcfg.Hostinfo, + reqHostinfo *tailcfg.Hostinfo, +) *tailcfg.NetInfo { + // If request has NetInfo, use it + if reqHostinfo != nil && reqHostinfo.NetInfo != nil { + return reqHostinfo.NetInfo + } + + // Otherwise, use current NetInfo if available + if currentHostinfo != nil && currentHostinfo.NetInfo != nil { + log.Debug(). + Caller(). + Uint64("node.id", nodeID.Uint64()). + Int("preferredDERP", currentHostinfo.NetInfo.PreferredDERP). + Msg("using NetInfo from previous Hostinfo in MapRequest") + return currentHostinfo.NetInfo + } + + // No NetInfo available anywhere - log for debugging + var hostname string + if reqHostinfo != nil { + hostname = reqHostinfo.Hostname + } else if currentHostinfo != nil { + hostname = currentHostinfo.Hostname + } + + log.Debug(). + Caller(). + Uint64("node.id", nodeID.Uint64()). + Str("node.hostname", hostname). 
+ Msg("node sent update but has no NetInfo in request or database") + + return nil +} diff --git a/hscontrol/state/maprequest_test.go b/hscontrol/state/maprequest_test.go new file mode 100644 index 00000000..99f781d4 --- /dev/null +++ b/hscontrol/state/maprequest_test.go @@ -0,0 +1,161 @@ +package state + +import ( + "net/netip" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +func TestNetInfoFromMapRequest(t *testing.T) { + nodeID := types.NodeID(1) + + tests := []struct { + name string + currentHostinfo *tailcfg.Hostinfo + reqHostinfo *tailcfg.Hostinfo + expectNetInfo *tailcfg.NetInfo + }{ + { + name: "no current NetInfo - return nil", + currentHostinfo: nil, + reqHostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + expectNetInfo: nil, + }, + { + name: "current has NetInfo, request has NetInfo - use request", + currentHostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{PreferredDERP: 1}, + }, + reqHostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + NetInfo: &tailcfg.NetInfo{PreferredDERP: 2}, + }, + expectNetInfo: &tailcfg.NetInfo{PreferredDERP: 2}, + }, + { + name: "current has NetInfo, request has no NetInfo - use current", + currentHostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{PreferredDERP: 3}, + }, + reqHostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + expectNetInfo: &tailcfg.NetInfo{PreferredDERP: 3}, + }, + { + name: "current has NetInfo, no request Hostinfo - use current", + currentHostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{PreferredDERP: 4}, + }, + reqHostinfo: nil, + expectNetInfo: &tailcfg.NetInfo{PreferredDERP: 4}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := netInfoFromMapRequest(nodeID, tt.currentHostinfo, tt.reqHostinfo) + + if tt.expectNetInfo == nil { + assert.Nil(t, result, "expected nil NetInfo") + } else { + require.NotNil(t, result, "expected non-nil NetInfo") + assert.Equal(t, tt.expectNetInfo.PreferredDERP, result.PreferredDERP, "DERP mismatch") + } + }) + } +} + +func TestNetInfoPreservationInRegistrationFlow(t *testing.T) { + nodeID := types.NodeID(1) + + // This test reproduces the bug in registration flows where NetInfo was lost + // because we used the wrong hostinfo reference when calling NetInfoFromMapRequest + t.Run("registration_flow_bug_reproduction", func(t *testing.T) { + // Simulate existing node with NetInfo (before re-registration) + existingNodeHostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + NetInfo: &tailcfg.NetInfo{PreferredDERP: 5}, + } + + // Simulate new registration request (no NetInfo) + newRegistrationHostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + OS: "linux", + // NetInfo is nil - this is what comes from the registration request + } + + // Simulate what was happening in the bug: we passed the "current node being modified" + // hostinfo (which has no NetInfo) instead of the existing node's hostinfo + nodeBeingModifiedHostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + // NetInfo is nil because this node is being modified/reset + } + + // BUG: Using the node being modified (no NetInfo) instead of existing node (has NetInfo) + buggyResult := netInfoFromMapRequest(nodeID, nodeBeingModifiedHostinfo, newRegistrationHostinfo) + assert.Nil(t, buggyResult, "Bug: Should return nil when using wrong hostinfo reference") + + // CORRECT: Using the 
existing node's hostinfo (has NetInfo) + correctResult := netInfoFromMapRequest(nodeID, existingNodeHostinfo, newRegistrationHostinfo) + assert.NotNil(t, correctResult, "Fix: Should preserve NetInfo when using correct hostinfo reference") + assert.Equal(t, 5, correctResult.PreferredDERP, "Should preserve the DERP region from existing node") + }) + + t.Run("new_node_creation_for_different_user_should_preserve_netinfo", func(t *testing.T) { + // This test covers the scenario where: + // 1. A node exists for user1 with NetInfo + // 2. The same machine logs in as user2 (different user) + // 3. A NEW node is created for user2 (pre-auth key flow) + // 4. The new node should preserve NetInfo from the old node + + // Existing node for user1 with NetInfo + existingNodeUser1Hostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + NetInfo: &tailcfg.NetInfo{PreferredDERP: 7}, + } + + // New registration request for user2 (no NetInfo yet) + newNodeUser2Hostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + OS: "linux", + // NetInfo is nil - registration request doesn't include it + } + + // When creating a new node for user2, we should preserve NetInfo from user1's node + result := netInfoFromMapRequest(types.NodeID(2), existingNodeUser1Hostinfo, newNodeUser2Hostinfo) + assert.NotNil(t, result, "New node for user2 should preserve NetInfo from user1's node") + assert.Equal(t, 7, result.PreferredDERP, "Should preserve DERP region from existing node") + }) +} + +// Simple helper function for tests +func createTestNodeSimple(id types.NodeID) *types.Node { + user := types.User{ + Name: "test-user", + } + + machineKey := key.NewMachine() + nodeKey := key.NewNode() + + node := &types.Node{ + ID: id, + Hostname: "test-node", + UserID: ptr.To(uint(id)), + User: &user, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + IPv4: &netip.Addr{}, + IPv6: &netip.Addr{}, + } + + return node +} diff --git a/hscontrol/state/node_store.go b/hscontrol/state/node_store.go new file mode 100644 index 00000000..6327b46b --- /dev/null +++ b/hscontrol/state/node_store.go @@ -0,0 +1,605 @@ +package state + +import ( + "fmt" + "maps" + "strings" + "sync/atomic" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" + "tailscale.com/types/key" + "tailscale.com/types/views" +) + +const ( + put = 1 + del = 2 + update = 3 + rebuildPeerMaps = 4 +) + +const prometheusNamespace = "headscale" + +var ( + nodeStoreOperations = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_operations_total", + Help: "Total number of NodeStore operations", + }, []string{"operation"}) + nodeStoreOperationDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_operation_duration_seconds", + Help: "Duration of NodeStore operations", + Buckets: prometheus.DefBuckets, + }, []string{"operation"}) + nodeStoreBatchSize = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_batch_size", + Help: "Size of NodeStore write batches", + Buckets: []float64{1, 2, 5, 10, 20, 50, 100}, + }) + nodeStoreBatchDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_batch_duration_seconds", + Help: "Duration of NodeStore batch processing", + Buckets: prometheus.DefBuckets, + }) + nodeStoreSnapshotBuildDuration = 
promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_snapshot_build_duration_seconds", + Help: "Duration of NodeStore snapshot building from nodes", + Buckets: prometheus.DefBuckets, + }) + nodeStoreNodesCount = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_nodes_total", + Help: "Total number of nodes in the NodeStore", + }) + nodeStorePeersCalculationDuration = promauto.NewHistogram(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_peers_calculation_duration_seconds", + Help: "Duration of peers calculation in NodeStore", + Buckets: prometheus.DefBuckets, + }) + nodeStoreQueueDepth = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "nodestore_queue_depth", + Help: "Current depth of NodeStore write queue", + }) +) + +// NodeStore is a thread-safe store for nodes. +// It is a copy-on-write structure, replacing the "snapshot" +// when a change to the structure occurs. It is optimised for reads, +// and while batches are not fast, they are grouped together +// to do less of the expensive peer calculation if there are many +// changes rapidly. +// +// Writes will block until committed, while reads are never +// blocked. This means that the caller of a write operation +// is responsible for ensuring an update depending on a write +// is not issued before the write is complete. +type NodeStore struct { + data atomic.Pointer[Snapshot] + + peersFunc PeersFunc + writeQueue chan work + + batchSize int + batchTimeout time.Duration +} + +func NewNodeStore(allNodes types.Nodes, peersFunc PeersFunc, batchSize int, batchTimeout time.Duration) *NodeStore { + nodes := make(map[types.NodeID]types.Node, len(allNodes)) + for _, n := range allNodes { + nodes[n.ID] = *n + } + snap := snapshotFromNodes(nodes, peersFunc) + + store := &NodeStore{ + peersFunc: peersFunc, + batchSize: batchSize, + batchTimeout: batchTimeout, + } + store.data.Store(&snap) + + // Initialize node count gauge + nodeStoreNodesCount.Set(float64(len(nodes))) + + return store +} + +// Snapshot is the representation of the current state of the NodeStore. +// It contains all nodes and their relationships. +// It is a copy-on-write structure, meaning that when a write occurs, +// a new Snapshot is created with the updated state, +// and replaces the old one atomically. +type Snapshot struct { + // nodesByID is the main source of truth for nodes. + nodesByID map[types.NodeID]types.Node + + // calculated from nodesByID + nodesByNodeKey map[key.NodePublic]types.NodeView + nodesByMachineKey map[key.MachinePublic]map[types.UserID]types.NodeView + peersByNode map[types.NodeID][]types.NodeView + nodesByUser map[types.UserID][]types.NodeView + allNodes []types.NodeView +} + +// PeersFunc is a function that takes a list of nodes and returns a map +// with the relationships between nodes and their peers. +// This will typically be used to calculate which nodes can see each other +// based on the current policy. +type PeersFunc func(nodes []types.NodeView) map[types.NodeID][]types.NodeView + +// work represents a single operation to be performed on the NodeStore. +type work struct { + op int + nodeID types.NodeID + node types.Node + updateFn UpdateNodeFunc + result chan struct{} + nodeResult chan types.NodeView // Channel to return the resulting node after batch application + // For rebuildPeerMaps operation + rebuildResult chan struct{} +} + +// PutNode adds or updates a node in the store. 
+// If the node already exists, it will be replaced. +// If the node does not exist, it will be added. +// This is a blocking operation that waits for the write to complete. +// Returns the resulting node after all modifications in the batch have been applied. +func (s *NodeStore) PutNode(n types.Node) types.NodeView { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("put")) + defer timer.ObserveDuration() + + work := work{ + op: put, + nodeID: n.ID, + node: n, + result: make(chan struct{}), + nodeResult: make(chan types.NodeView, 1), + } + + nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result + nodeStoreQueueDepth.Dec() + + resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("put").Inc() + + return resultNode +} + +// UpdateNodeFunc is a function type that takes a pointer to a Node and modifies it. +type UpdateNodeFunc func(n *types.Node) + +// UpdateNode applies a function to modify a specific node in the store. +// This is a blocking operation that waits for the write to complete. +// This is analogous to a database "transaction", or, the caller should +// rather collect all data they want to change, and then call this function. +// Fewer calls are better. +// Returns the resulting node after all modifications in the batch have been applied. +// +// TODO(kradalby): Technically we could have a version of this that modifies the node +// in the current snapshot if _we know_ that the change will not affect the peer relationships. +// This is because the main nodesByID map contains the struct, and every other map is using a +// pointer to the underlying struct. The gotcha with this is that we will need to introduce +// a lock around the nodesByID map to ensure that no other writes are happening +// while we are modifying the node. Which mean we would need to implement read-write locks +// on all read operations. +func (s *NodeStore) UpdateNode(nodeID types.NodeID, updateFn func(n *types.Node)) (types.NodeView, bool) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("update")) + defer timer.ObserveDuration() + + work := work{ + op: update, + nodeID: nodeID, + updateFn: updateFn, + result: make(chan struct{}), + nodeResult: make(chan types.NodeView, 1), + } + + nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result + nodeStoreQueueDepth.Dec() + + resultNode := <-work.nodeResult + nodeStoreOperations.WithLabelValues("update").Inc() + + // Return the node and whether it exists (is valid) + return resultNode, resultNode.Valid() +} + +// DeleteNode removes a node from the store by its ID. +// This is a blocking operation that waits for the write to complete. +func (s *NodeStore) DeleteNode(id types.NodeID) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("delete")) + defer timer.ObserveDuration() + + work := work{ + op: del, + nodeID: id, + result: make(chan struct{}), + } + + nodeStoreQueueDepth.Inc() + s.writeQueue <- work + <-work.result + nodeStoreQueueDepth.Dec() + + nodeStoreOperations.WithLabelValues("delete").Inc() +} + +// Start initializes the NodeStore and starts processing the write queue. +func (s *NodeStore) Start() { + s.writeQueue = make(chan work) + go s.processWrite() +} + +// Stop stops the NodeStore. +func (s *NodeStore) Stop() { + close(s.writeQueue) +} + +// processWrite processes the write queue in batches. 
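+// A batch is flushed either when it reaches batchSize or when batchTimeout +// elapses, whichever happens first; every queued work item is unblocked only +// after its batch has been applied to a fresh snapshot.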
+func (s *NodeStore) processWrite() { + c := time.NewTicker(s.batchTimeout) + defer c.Stop() + + batch := make([]work, 0, s.batchSize) + + for { + select { + case w, ok := <-s.writeQueue: + if !ok { + // Channel closed, apply any remaining batch and exit + if len(batch) != 0 { + s.applyBatch(batch) + } + return + } + batch = append(batch, w) + if len(batch) >= s.batchSize { + s.applyBatch(batch) + batch = batch[:0] + + c.Reset(s.batchTimeout) + } + case <-c.C: + if len(batch) != 0 { + s.applyBatch(batch) + batch = batch[:0] + } + + c.Reset(s.batchTimeout) + } + } +} + +// applyBatch applies a batch of work to the node store. +// This means that it takes a copy of the current nodes, +// then applies the batch of operations to that copy, +// runs any precomputation needed (like calculating peers), +// and finally replaces the snapshot in the store with the new one. +// The replacement of the snapshot is atomic, ensuring that reads +// are never blocked by writes. +// Each write item is blocked until the batch is applied to ensure +// the caller knows the operation is complete and do not send any +// updates that are dependent on a read that is yet to be written. +func (s *NodeStore) applyBatch(batch []work) { + timer := prometheus.NewTimer(nodeStoreBatchDuration) + defer timer.ObserveDuration() + + nodeStoreBatchSize.Observe(float64(len(batch))) + + nodes := make(map[types.NodeID]types.Node) + maps.Copy(nodes, s.data.Load().nodesByID) + + // Track which work items need node results + nodeResultRequests := make(map[types.NodeID][]*work) + + // Track rebuildPeerMaps operations + var rebuildOps []*work + + for i := range batch { + w := &batch[i] + switch w.op { + case put: + nodes[w.nodeID] = w.node + if w.nodeResult != nil { + nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) + } + case update: + // Update the specific node identified by nodeID + if n, exists := nodes[w.nodeID]; exists { + w.updateFn(&n) + nodes[w.nodeID] = n + } + if w.nodeResult != nil { + nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) + } + case del: + delete(nodes, w.nodeID) + // For delete operations, send an invalid NodeView if requested + if w.nodeResult != nil { + nodeResultRequests[w.nodeID] = append(nodeResultRequests[w.nodeID], w) + } + case rebuildPeerMaps: + // rebuildPeerMaps doesn't modify nodes, it just forces the snapshot rebuild + // below to recalculate peer relationships using the current peersFunc + rebuildOps = append(rebuildOps, w) + } + } + + newSnap := snapshotFromNodes(nodes, s.peersFunc) + s.data.Store(&newSnap) + + // Update node count gauge + nodeStoreNodesCount.Set(float64(len(nodes))) + + // Send the resulting nodes to all work items that requested them + for nodeID, workItems := range nodeResultRequests { + if node, exists := nodes[nodeID]; exists { + nodeView := node.View() + for _, w := range workItems { + w.nodeResult <- nodeView + close(w.nodeResult) + } + } else { + // Node was deleted or doesn't exist + for _, w := range workItems { + w.nodeResult <- types.NodeView{} // Send invalid view + close(w.nodeResult) + } + } + } + + // Signal completion for rebuildPeerMaps operations + for _, w := range rebuildOps { + close(w.rebuildResult) + } + + // Signal completion for all other work items + for _, w := range batch { + if w.op != rebuildPeerMaps { + close(w.result) + } + } +} + +// snapshotFromNodes creates a new Snapshot from the provided nodes. 
+// It builds several "indexes" to make lookups fast for data that is used +// frequently, like nodesByNodeKey, peersByNode, and nodesByUser. +// This is not a fast operation; it is the "slow" part of our copy-on-write +// structure, but it allows fast reads and efficient lookups. +func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) Snapshot { + timer := prometheus.NewTimer(nodeStoreSnapshotBuildDuration) + defer timer.ObserveDuration() + + allNodes := make([]types.NodeView, 0, len(nodes)) + for _, n := range nodes { + allNodes = append(allNodes, n.View()) + } + + newSnap := Snapshot{ + nodesByID: nodes, + allNodes: allNodes, + nodesByNodeKey: make(map[key.NodePublic]types.NodeView), + nodesByMachineKey: make(map[key.MachinePublic]map[types.UserID]types.NodeView), + + // peersByNode is most likely the most expensive operation; + // it will use the list of all nodes, combined with the + // current policy, to precalculate which nodes are peers and + // can see each other. + peersByNode: func() map[types.NodeID][]types.NodeView { + peersTimer := prometheus.NewTimer(nodeStorePeersCalculationDuration) + defer peersTimer.ObserveDuration() + return peersFunc(allNodes) + }(), + nodesByUser: make(map[types.UserID][]types.NodeView), + } + + // Build nodesByUser, nodesByNodeKey, and nodesByMachineKey maps + for _, n := range nodes { + nodeView := n.View() + userID := n.TypedUserID() + + newSnap.nodesByUser[userID] = append(newSnap.nodesByUser[userID], nodeView) + newSnap.nodesByNodeKey[n.NodeKey] = nodeView + + // Build machine key index + if newSnap.nodesByMachineKey[n.MachineKey] == nil { + newSnap.nodesByMachineKey[n.MachineKey] = make(map[types.UserID]types.NodeView) + } + newSnap.nodesByMachineKey[n.MachineKey][userID] = nodeView + } + + return newSnap +} + +// GetNode retrieves a node by its ID. +// The bool reports whether the node exists (think "not found"). +// The returned NodeView may still be invalid, so callers must check .Valid() to +// guard against a broken or partially populated node. +func (s *NodeStore) GetNode(id types.NodeID) (types.NodeView, bool) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("get").Inc() + + n, exists := s.data.Load().nodesByID[id] + if !exists { + return types.NodeView{}, false + } + + return n.View(), true +} + +// GetNodeByNodeKey retrieves a node by its NodeKey. +// The bool reports whether the node exists (think "not found"). +// The returned NodeView may still be invalid, so callers must check .Valid() to +// guard against a broken or partially populated node. +func (s *NodeStore) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bool) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_key")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("get_by_key").Inc() + + nodeView, exists := s.data.Load().nodesByNodeKey[nodeKey] + + return nodeView, exists +} + +// GetNodeByMachineKey returns a node by its machine key and user ID. The bool indicates if the node exists.
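+// +// Illustrative lookup (a sketch; the store, machine key and user ID are assumed +// to exist): +// +// if nv, ok := store.GetNodeByMachineKey(mk, types.UserID(1)); ok && nv.Valid() { +// // the machine key is registered to user 1 +// }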
+func (s *NodeStore) GetNodeByMachineKey(machineKey key.MachinePublic, userID types.UserID) (types.NodeView, bool) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("get_by_machine_key").Inc() + + snapshot := s.data.Load() + if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists { + if node, exists := userMap[userID]; exists { + return node, true + } + } + + return types.NodeView{}, false +} + +// GetNodeByMachineKeyAnyUser returns the first node with the given machine key, +// regardless of which user it belongs to. This is useful for scenarios like +// transferring a node to a different user when re-authenticating with a +// different user's auth key. +// If multiple nodes exist with the same machine key (different users), the +// first one found is returned (order is not guaranteed). +func (s *NodeStore) GetNodeByMachineKeyAnyUser(machineKey key.MachinePublic) (types.NodeView, bool) { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("get_by_machine_key_any_user")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("get_by_machine_key_any_user").Inc() + + snapshot := s.data.Load() + if userMap, exists := snapshot.nodesByMachineKey[machineKey]; exists { + // Return the first node found (order not guaranteed due to map iteration) + for _, node := range userMap { + return node, true + } + } + + return types.NodeView{}, false +} + +// DebugString returns debug information about the NodeStore. +func (s *NodeStore) DebugString() string { + snapshot := s.data.Load() + + var sb strings.Builder + + sb.WriteString("=== NodeStore Debug Information ===\n\n") + + // Basic counts + sb.WriteString(fmt.Sprintf("Total Nodes: %d\n", len(snapshot.nodesByID))) + sb.WriteString(fmt.Sprintf("Users with Nodes: %d\n", len(snapshot.nodesByUser))) + sb.WriteString("\n") + + // User distribution (shows internal UserID tracking, not display owner) + sb.WriteString("Nodes by Internal User ID:\n") + for userID, nodes := range snapshot.nodesByUser { + if len(nodes) > 0 { + userName := "unknown" + taggedCount := 0 + if len(nodes) > 0 && nodes[0].Valid() { + userName = nodes[0].User().Name() + // Count tagged nodes (which have UserID set but are owned by "tagged-devices") + for _, n := range nodes { + if n.IsTagged() { + taggedCount++ + } + } + } + + if taggedCount > 0 { + sb.WriteString(fmt.Sprintf(" - User %d (%s): %d nodes (%d tagged)\n", userID, userName, len(nodes), taggedCount)) + } else { + sb.WriteString(fmt.Sprintf(" - User %d (%s): %d nodes\n", userID, userName, len(nodes))) + } + } + } + sb.WriteString("\n") + + // Peer relationships summary + sb.WriteString("Peer Relationships:\n") + totalPeers := 0 + for nodeID, peers := range snapshot.peersByNode { + peerCount := len(peers) + totalPeers += peerCount + if node, exists := snapshot.nodesByID[nodeID]; exists { + sb.WriteString(fmt.Sprintf(" - Node %d (%s): %d peers\n", + nodeID, node.Hostname, peerCount)) + } + } + if len(snapshot.peersByNode) > 0 { + avgPeers := float64(totalPeers) / float64(len(snapshot.peersByNode)) + sb.WriteString(fmt.Sprintf(" - Average peers per node: %.1f\n", avgPeers)) + } + sb.WriteString("\n") + + // Node key index + sb.WriteString(fmt.Sprintf("NodeKey Index: %d entries\n", len(snapshot.nodesByNodeKey))) + sb.WriteString("\n") + + return sb.String() +} + +// ListNodes returns a slice of all nodes in the store. 
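+// +// Reads never block on writers; a typical read path is simply (illustrative): +// +// nodes := store.ListNodes() +// for i := 0; i < nodes.Len(); i++ { +// _ = nodes.At(i).Hostname() +// }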
+func (s *NodeStore) ListNodes() views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list").Inc() + + return views.SliceOf(s.data.Load().allNodes) +} + +// ListPeers returns a slice of all peers for a given node ID. +func (s *NodeStore) ListPeers(id types.NodeID) views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_peers")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list_peers").Inc() + + return views.SliceOf(s.data.Load().peersByNode[id]) +} + +// RebuildPeerMaps rebuilds the peer relationship map using the current peersFunc. +// This must be called after policy changes because peersFunc uses PolicyManager's +// filters to determine which nodes can see each other. Without rebuilding, the +// peer map would use stale filter data until the next node add/delete. +func (s *NodeStore) RebuildPeerMaps() { + result := make(chan struct{}) + + w := work{ + op: rebuildPeerMaps, + rebuildResult: result, + } + + s.writeQueue <- w + <-result +} + +// ListNodesByUser returns a slice of all nodes for a given user ID. +func (s *NodeStore) ListNodesByUser(uid types.UserID) views.Slice[types.NodeView] { + timer := prometheus.NewTimer(nodeStoreOperationDuration.WithLabelValues("list_by_user")) + defer timer.ObserveDuration() + + nodeStoreOperations.WithLabelValues("list_by_user").Inc() + + return views.SliceOf(s.data.Load().nodesByUser[uid]) +} diff --git a/hscontrol/state/node_store_test.go b/hscontrol/state/node_store_test.go new file mode 100644 index 00000000..3d6184ba --- /dev/null +++ b/hscontrol/state/node_store_test.go @@ -0,0 +1,1243 @@ +package state + +import ( + "context" + "fmt" + "net/netip" + "runtime" + "sync" + "testing" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +func TestSnapshotFromNodes(t *testing.T) { + tests := []struct { + name string + setupFunc func() (map[types.NodeID]types.Node, PeersFunc) + validate func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) + }{ + { + name: "empty nodes", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := make(map[types.NodeID]types.Node) + peersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + return make(map[types.NodeID][]types.NodeView) + } + + return nodes, peersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Empty(t, snapshot.nodesByID) + assert.Empty(t, snapshot.allNodes) + assert.Empty(t, snapshot.peersByNode) + assert.Empty(t, snapshot.nodesByUser) + }, + }, + { + name: "single node", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + } + return nodes, allowAllPeersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 1) + assert.Len(t, snapshot.allNodes, 1) + assert.Len(t, snapshot.peersByNode, 1) + assert.Len(t, snapshot.nodesByUser, 1) + + require.Contains(t, snapshot.nodesByID, types.NodeID(1)) + assert.Equal(t, nodes[1].ID, snapshot.nodesByID[1].ID) + assert.Empty(t, snapshot.peersByNode[1]) // no other nodes, so no peers + assert.Len(t, snapshot.nodesByUser[1], 1) + 
assert.Equal(t, types.NodeID(1), snapshot.nodesByUser[1][0].ID()) + }, + }, + { + name: "multiple nodes same user", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + 2: createTestNode(2, 1, "user1", "node2"), + } + + return nodes, allowAllPeersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 2) + assert.Len(t, snapshot.allNodes, 2) + assert.Len(t, snapshot.peersByNode, 2) + assert.Len(t, snapshot.nodesByUser, 1) + + // Each node sees the other as peer (but not itself) + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID()) + assert.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID()) + assert.Len(t, snapshot.nodesByUser[1], 2) + }, + }, + { + name: "multiple nodes different users", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + 2: createTestNode(2, 2, "user2", "node2"), + 3: createTestNode(3, 1, "user1", "node3"), + } + + return nodes, allowAllPeersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 3) + assert.Len(t, snapshot.allNodes, 3) + assert.Len(t, snapshot.peersByNode, 3) + assert.Len(t, snapshot.nodesByUser, 2) + + // Each node should have 2 peers (all others, but not itself) + assert.Len(t, snapshot.peersByNode[1], 2) + assert.Len(t, snapshot.peersByNode[2], 2) + assert.Len(t, snapshot.peersByNode[3], 2) + + // User groupings + assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,3 + assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 2 + }, + }, + { + name: "odd-even peers filtering", + setupFunc: func() (map[types.NodeID]types.Node, PeersFunc) { + nodes := map[types.NodeID]types.Node{ + 1: createTestNode(1, 1, "user1", "node1"), + 2: createTestNode(2, 2, "user2", "node2"), + 3: createTestNode(3, 3, "user3", "node3"), + 4: createTestNode(4, 4, "user4", "node4"), + } + peersFunc := oddEvenPeersFunc + + return nodes, peersFunc + }, + validate: func(t *testing.T, nodes map[types.NodeID]types.Node, snapshot Snapshot) { + assert.Len(t, snapshot.nodesByID, 4) + assert.Len(t, snapshot.allNodes, 4) + assert.Len(t, snapshot.peersByNode, 4) + assert.Len(t, snapshot.nodesByUser, 4) + + // Odd nodes should only see other odd nodes as peers + require.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID()) + + require.Len(t, snapshot.peersByNode[3], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID()) + + // Even nodes should only see other even nodes as peers + require.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID()) + + require.Len(t, snapshot.peersByNode[4], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID()) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + nodes, peersFunc := tt.setupFunc() + snapshot := snapshotFromNodes(nodes, peersFunc) + tt.validate(t, nodes, snapshot) + }) + } +} + +// Helper functions + +func createTestNode(nodeID types.NodeID, userID uint, username, hostname string) types.Node { + now := time.Now() + machineKey := key.NewMachine() + nodeKey := key.NewNode() + discoKey := key.NewDisco() + + ipv4 := 
netip.MustParseAddr("100.64.0.1") + ipv6 := netip.MustParseAddr("fd7a:115c:a1e0::1") + + return types.Node{ + ID: nodeID, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + DiscoKey: discoKey.Public(), + Hostname: hostname, + GivenName: hostname, + UserID: ptr.To(userID), + User: &types.User{ + Name: username, + DisplayName: username, + }, + RegisterMethod: "test", + IPv4: &ipv4, + IPv6: &ipv6, + CreatedAt: now, + UpdatedAt: now, + } +} + +// Peer functions + +func allowAllPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + ret := make(map[types.NodeID][]types.NodeView, len(nodes)) + for _, node := range nodes { + var peers []types.NodeView + for _, n := range nodes { + if n.ID() != node.ID() { + peers = append(peers, n) + } + } + ret[node.ID()] = peers + } + + return ret +} + +func oddEvenPeersFunc(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + ret := make(map[types.NodeID][]types.NodeView, len(nodes)) + for _, node := range nodes { + var peers []types.NodeView + nodeIsOdd := node.ID()%2 == 1 + + for _, n := range nodes { + if n.ID() == node.ID() { + continue + } + + peerIsOdd := n.ID()%2 == 1 + + // Only add peer if both are odd or both are even + if nodeIsOdd == peerIsOdd { + peers = append(peers, n) + } + } + ret[node.ID()] = peers + } + + return ret +} + +func TestNodeStoreOperations(t *testing.T) { + tests := []struct { + name string + setupFunc func(t *testing.T) *NodeStore + steps []testStep + }{ + { + name: "create empty store and add single node", + setupFunc: func(t *testing.T) *NodeStore { + return NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + }, + steps: []testStep{ + { + name: "verify empty store", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Empty(t, snapshot.nodesByID) + assert.Empty(t, snapshot.allNodes) + assert.Empty(t, snapshot.peersByNode) + assert.Empty(t, snapshot.nodesByUser) + }, + }, + { + name: "add first node", + action: func(store *NodeStore) { + node := createTestNode(1, 1, "user1", "node1") + resultNode := store.PutNode(node) + assert.True(t, resultNode.Valid(), "PutNode should return valid node") + assert.Equal(t, node.ID, resultNode.ID()) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 1) + assert.Len(t, snapshot.allNodes, 1) + assert.Len(t, snapshot.peersByNode, 1) + assert.Len(t, snapshot.nodesByUser, 1) + + require.Contains(t, snapshot.nodesByID, types.NodeID(1)) + assert.Equal(t, node.ID, snapshot.nodesByID[1].ID) + assert.Empty(t, snapshot.peersByNode[1]) // no peers yet + assert.Len(t, snapshot.nodesByUser[1], 1) + }, + }, + }, + }, + { + name: "create store with initial node and add more", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + initialNodes := types.Nodes{&node1} + + return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + }, + steps: []testStep{ + { + name: "verify initial state", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 1) + assert.Len(t, snapshot.allNodes, 1) + assert.Len(t, snapshot.peersByNode, 1) + assert.Len(t, snapshot.nodesByUser, 1) + assert.Empty(t, snapshot.peersByNode[1]) + }, + }, + { + name: "add second node same user", + action: func(store *NodeStore) { + node2 := createTestNode(2, 1, "user1", "node2") + resultNode := store.PutNode(node2) + assert.True(t, resultNode.Valid(), "PutNode should return valid node") + assert.Equal(t, types.NodeID(2), 
resultNode.ID()) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 2) + assert.Len(t, snapshot.allNodes, 2) + assert.Len(t, snapshot.peersByNode, 2) + assert.Len(t, snapshot.nodesByUser, 1) + + // Now both nodes should see each other as peers + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID()) + assert.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID()) + assert.Len(t, snapshot.nodesByUser[1], 2) + }, + }, + { + name: "add third node different user", + action: func(store *NodeStore) { + node3 := createTestNode(3, 2, "user2", "node3") + resultNode := store.PutNode(node3) + assert.True(t, resultNode.Valid(), "PutNode should return valid node") + assert.Equal(t, types.NodeID(3), resultNode.ID()) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + assert.Len(t, snapshot.allNodes, 3) + assert.Len(t, snapshot.peersByNode, 3) + assert.Len(t, snapshot.nodesByUser, 2) + + // All nodes should see the other 2 as peers + assert.Len(t, snapshot.peersByNode[1], 2) + assert.Len(t, snapshot.peersByNode[2], 2) + assert.Len(t, snapshot.peersByNode[3], 2) + + // User groupings + assert.Len(t, snapshot.nodesByUser[1], 2) // user1 has nodes 1,2 + assert.Len(t, snapshot.nodesByUser[2], 1) // user2 has node 3 + }, + }, + }, + }, + { + name: "test node deletion", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + node3 := createTestNode(3, 2, "user2", "node3") + initialNodes := types.Nodes{&node1, &node2, &node3} + + return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + }, + steps: []testStep{ + { + name: "verify initial 3 nodes", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + assert.Len(t, snapshot.allNodes, 3) + assert.Len(t, snapshot.peersByNode, 3) + assert.Len(t, snapshot.nodesByUser, 2) + }, + }, + { + name: "delete middle node", + action: func(store *NodeStore) { + store.DeleteNode(2) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 2) + assert.Len(t, snapshot.allNodes, 2) + assert.Len(t, snapshot.peersByNode, 2) + assert.Len(t, snapshot.nodesByUser, 2) + + // Node 2 should be gone + assert.NotContains(t, snapshot.nodesByID, types.NodeID(2)) + + // Remaining nodes should see each other as peers + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID()) + assert.Len(t, snapshot.peersByNode[3], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID()) + + // User groupings updated + assert.Len(t, snapshot.nodesByUser[1], 1) // user1 now has only node 1 + assert.Len(t, snapshot.nodesByUser[2], 1) // user2 still has node 3 + }, + }, + { + name: "delete all remaining nodes", + action: func(store *NodeStore) { + store.DeleteNode(1) + store.DeleteNode(3) + + snapshot := store.data.Load() + assert.Empty(t, snapshot.nodesByID) + assert.Empty(t, snapshot.allNodes) + assert.Empty(t, snapshot.peersByNode) + assert.Empty(t, snapshot.nodesByUser) + }, + }, + }, + }, + { + name: "test node updates", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + initialNodes := types.Nodes{&node1, &node2} + + return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + }, + steps: 
[]testStep{ + { + name: "verify initial hostnames", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) + }, + }, + { + name: "update node hostname", + action: func(store *NodeStore) { + resultNode, ok := store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "updated-node1" + n.GivenName = "updated-node1" + }) + assert.True(t, ok, "UpdateNode should return true for existing node") + assert.True(t, resultNode.Valid(), "Result node should be valid") + assert.Equal(t, "updated-node1", resultNode.Hostname()) + assert.Equal(t, "updated-node1", resultNode.GivenName()) + + snapshot := store.data.Load() + assert.Equal(t, "updated-node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "updated-node1", snapshot.nodesByID[1].GivenName) + assert.Equal(t, "node2", snapshot.nodesByID[2].Hostname) // unchanged + + // Peers should still work correctly + assert.Len(t, snapshot.peersByNode[1], 1) + assert.Len(t, snapshot.peersByNode[2], 1) + }, + }, + }, + }, + { + name: "test with odd-even peers filtering", + setupFunc: func(t *testing.T) *NodeStore { + return NewNodeStore(nil, oddEvenPeersFunc, TestBatchSize, TestBatchTimeout) + }, + steps: []testStep{ + { + name: "add nodes with odd-even filtering", + action: func(store *NodeStore) { + // Add nodes in sequence + n1 := store.PutNode(createTestNode(1, 1, "user1", "node1")) + assert.True(t, n1.Valid()) + n2 := store.PutNode(createTestNode(2, 2, "user2", "node2")) + assert.True(t, n2.Valid()) + n3 := store.PutNode(createTestNode(3, 3, "user3", "node3")) + assert.True(t, n3.Valid()) + n4 := store.PutNode(createTestNode(4, 4, "user4", "node4")) + assert.True(t, n4.Valid()) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 4) + + // Verify odd-even peer relationships + require.Len(t, snapshot.peersByNode[1], 1) + assert.Equal(t, types.NodeID(3), snapshot.peersByNode[1][0].ID()) + + require.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID()) + + require.Len(t, snapshot.peersByNode[3], 1) + assert.Equal(t, types.NodeID(1), snapshot.peersByNode[3][0].ID()) + + require.Len(t, snapshot.peersByNode[4], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID()) + }, + }, + { + name: "delete odd node and verify even nodes unaffected", + action: func(store *NodeStore) { + store.DeleteNode(1) + + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + + // Node 3 (odd) should now have no peers + assert.Empty(t, snapshot.peersByNode[3]) + + // Even nodes should still see each other + require.Len(t, snapshot.peersByNode[2], 1) + assert.Equal(t, types.NodeID(4), snapshot.peersByNode[2][0].ID()) + require.Len(t, snapshot.peersByNode[4], 1) + assert.Equal(t, types.NodeID(2), snapshot.peersByNode[4][0].ID()) + }, + }, + }, + }, + { + name: "test batch modifications return correct node state", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + initialNodes := types.Nodes{&node1, &node2} + + return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + }, + steps: []testStep{ + { + name: "verify initial state", + action: func(store *NodeStore) { + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 2) + assert.Equal(t, "node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "node2", 
snapshot.nodesByID[2].Hostname) + }, + }, + { + name: "concurrent updates should reflect all batch changes", + action: func(store *NodeStore) { + // Start multiple updates that will be batched together + done1 := make(chan struct{}) + done2 := make(chan struct{}) + done3 := make(chan struct{}) + + var resultNode1, resultNode2 types.NodeView + var newNode3 types.NodeView + var ok1, ok2 bool + + // These should all be processed in the same batch + go func() { + resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "batch-updated-node1" + n.GivenName = "batch-given-1" + }) + close(done1) + }() + + go func() { + resultNode2, ok2 = store.UpdateNode(2, func(n *types.Node) { + n.Hostname = "batch-updated-node2" + n.GivenName = "batch-given-2" + }) + close(done2) + }() + + go func() { + node3 := createTestNode(3, 1, "user1", "node3") + newNode3 = store.PutNode(node3) + close(done3) + }() + + // Wait for all operations to complete + <-done1 + <-done2 + <-done3 + + // Verify the returned nodes reflect the batch state + assert.True(t, ok1, "UpdateNode should succeed for node 1") + assert.True(t, ok2, "UpdateNode should succeed for node 2") + assert.True(t, resultNode1.Valid()) + assert.True(t, resultNode2.Valid()) + assert.True(t, newNode3.Valid()) + + // Check that returned nodes have the updated values + assert.Equal(t, "batch-updated-node1", resultNode1.Hostname()) + assert.Equal(t, "batch-given-1", resultNode1.GivenName()) + assert.Equal(t, "batch-updated-node2", resultNode2.Hostname()) + assert.Equal(t, "batch-given-2", resultNode2.GivenName()) + assert.Equal(t, "node3", newNode3.Hostname()) + + // Verify the snapshot also reflects all changes + snapshot := store.data.Load() + assert.Len(t, snapshot.nodesByID, 3) + assert.Equal(t, "batch-updated-node1", snapshot.nodesByID[1].Hostname) + assert.Equal(t, "batch-updated-node2", snapshot.nodesByID[2].Hostname) + assert.Equal(t, "node3", snapshot.nodesByID[3].Hostname) + + // Verify peer relationships are updated correctly with new node + assert.Len(t, snapshot.peersByNode[1], 2) // sees nodes 2 and 3 + assert.Len(t, snapshot.peersByNode[2], 2) // sees nodes 1 and 3 + assert.Len(t, snapshot.peersByNode[3], 2) // sees nodes 1 and 2 + }, + }, + { + name: "update non-existent node returns invalid view", + action: func(store *NodeStore) { + resultNode, ok := store.UpdateNode(999, func(n *types.Node) { + n.Hostname = "should-not-exist" + }) + + assert.False(t, ok, "UpdateNode should return false for non-existent node") + assert.False(t, resultNode.Valid(), "Result should be invalid NodeView") + }, + }, + { + name: "multiple updates to same node in batch all see final state", + action: func(store *NodeStore) { + // This test verifies that when multiple updates to the same node + // are batched together, each returned node reflects ALL changes + // in the batch, not just the individual update's changes. 
+ + done1 := make(chan struct{}) + done2 := make(chan struct{}) + done3 := make(chan struct{}) + + var resultNode1, resultNode2, resultNode3 types.NodeView + var ok1, ok2, ok3 bool + + // These updates all modify node 1 and should be batched together + // The final state should have all three modifications applied + go func() { + resultNode1, ok1 = store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "multi-update-hostname" + }) + close(done1) + }() + + go func() { + resultNode2, ok2 = store.UpdateNode(1, func(n *types.Node) { + n.GivenName = "multi-update-givenname" + }) + close(done2) + }() + + go func() { + resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) { + n.Tags = []string{"tag1", "tag2"} + }) + close(done3) + }() + + // Wait for all operations to complete + <-done1 + <-done2 + <-done3 + + // All updates should succeed + assert.True(t, ok1, "First update should succeed") + assert.True(t, ok2, "Second update should succeed") + assert.True(t, ok3, "Third update should succeed") + + // CRITICAL: Each returned node should reflect ALL changes from the batch + // not just the change from its specific update call + + // resultNode1 (from hostname update) should also have the givenname and tags changes + assert.Equal(t, "multi-update-hostname", resultNode1.Hostname()) + assert.Equal(t, "multi-update-givenname", resultNode1.GivenName()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.Tags().AsSlice()) + + // resultNode2 (from givenname update) should also have the hostname and tags changes + assert.Equal(t, "multi-update-hostname", resultNode2.Hostname()) + assert.Equal(t, "multi-update-givenname", resultNode2.GivenName()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.Tags().AsSlice()) + + // resultNode3 (from tags update) should also have the hostname and givenname changes + assert.Equal(t, "multi-update-hostname", resultNode3.Hostname()) + assert.Equal(t, "multi-update-givenname", resultNode3.GivenName()) + assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.Tags().AsSlice()) + + // Verify the snapshot also has all changes + snapshot := store.data.Load() + finalNode := snapshot.nodesByID[1] + assert.Equal(t, "multi-update-hostname", finalNode.Hostname) + assert.Equal(t, "multi-update-givenname", finalNode.GivenName) + assert.Equal(t, []string{"tag1", "tag2"}, finalNode.Tags) + }, + }, + }, + }, + { + name: "test UpdateNode result is immutable for database save", + setupFunc: func(t *testing.T) *NodeStore { + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 1, "user1", "node2") + initialNodes := types.Nodes{&node1, &node2} + + return NewNodeStore(initialNodes, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + }, + steps: []testStep{ + { + name: "verify returned node is complete and consistent", + action: func(store *NodeStore) { + // Update a node and verify the returned view is complete + resultNode, ok := store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "db-save-hostname" + n.GivenName = "db-save-given" + n.Tags = []string{"db-tag1", "db-tag2"} + }) + + assert.True(t, ok, "UpdateNode should succeed") + assert.True(t, resultNode.Valid(), "Result should be valid") + + // Verify the returned node has all expected values + assert.Equal(t, "db-save-hostname", resultNode.Hostname()) + assert.Equal(t, "db-save-given", resultNode.GivenName()) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.Tags().AsSlice()) + + // Convert to struct as would be done for database save + nodePtr := resultNode.AsStruct() + 
assert.NotNil(t, nodePtr) + assert.Equal(t, "db-save-hostname", nodePtr.Hostname) + assert.Equal(t, "db-save-given", nodePtr.GivenName) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.Tags) + + // Verify the snapshot also reflects the same state + snapshot := store.data.Load() + storedNode := snapshot.nodesByID[1] + assert.Equal(t, "db-save-hostname", storedNode.Hostname) + assert.Equal(t, "db-save-given", storedNode.GivenName) + assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.Tags) + }, + }, + { + name: "concurrent updates all return consistent final state for DB save", + action: func(store *NodeStore) { + // Multiple goroutines updating the same node + // All should receive the final batch state suitable for DB save + done1 := make(chan struct{}) + done2 := make(chan struct{}) + done3 := make(chan struct{}) + + var result1, result2, result3 types.NodeView + var ok1, ok2, ok3 bool + + // Start concurrent updates + go func() { + result1, ok1 = store.UpdateNode(1, func(n *types.Node) { + n.Hostname = "concurrent-db-hostname" + }) + close(done1) + }() + + go func() { + result2, ok2 = store.UpdateNode(1, func(n *types.Node) { + n.GivenName = "concurrent-db-given" + }) + close(done2) + }() + + go func() { + result3, ok3 = store.UpdateNode(1, func(n *types.Node) { + n.Tags = []string{"concurrent-tag"} + }) + close(done3) + }() + + // Wait for all to complete + <-done1 + <-done2 + <-done3 + + assert.True(t, ok1 && ok2 && ok3, "All updates should succeed") + + // All results should be valid and suitable for database save + assert.True(t, result1.Valid()) + assert.True(t, result2.Valid()) + assert.True(t, result3.Valid()) + + // Convert each to struct as would be done for DB save + nodePtr1 := result1.AsStruct() + nodePtr2 := result2.AsStruct() + nodePtr3 := result3.AsStruct() + + // All should have the complete final state + assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname) + assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.Tags) + + assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname) + assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.Tags) + + assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname) + assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName) + assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.Tags) + + // Verify consistency with stored state + snapshot := store.data.Load() + storedNode := snapshot.nodesByID[1] + assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname) + assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName) + assert.Equal(t, nodePtr1.Tags, storedNode.Tags) + }, + }, + { + name: "verify returned node preserves all fields for DB save", + action: func(store *NodeStore) { + // Get initial state + snapshot := store.data.Load() + originalNode := snapshot.nodesByID[2] + originalIPv4 := originalNode.IPv4 + originalIPv6 := originalNode.IPv6 + originalCreatedAt := originalNode.CreatedAt + originalUser := originalNode.User + + // Update only hostname + resultNode, ok := store.UpdateNode(2, func(n *types.Node) { + n.Hostname = "preserve-test-hostname" + }) + + assert.True(t, ok, "Update should succeed") + + // Convert to struct for DB save + nodeForDB := resultNode.AsStruct() + + // Verify all fields are preserved + assert.Equal(t, "preserve-test-hostname", nodeForDB.Hostname) + assert.Equal(t, originalIPv4, nodeForDB.IPv4) + assert.Equal(t, originalIPv6, nodeForDB.IPv6) + assert.Equal(t, 
originalCreatedAt, nodeForDB.CreatedAt) + assert.Equal(t, originalUser.Name, nodeForDB.User.Name) + assert.Equal(t, types.NodeID(2), nodeForDB.ID) + + // These fields should be suitable for direct database save + assert.NotNil(t, nodeForDB.IPv4) + assert.NotNil(t, nodeForDB.IPv6) + assert.False(t, nodeForDB.CreatedAt.IsZero()) + }, + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + store := tt.setupFunc(t) + store.Start() + defer store.Stop() + + for _, step := range tt.steps { + t.Run(step.name, func(t *testing.T) { + step.action(store) + }) + } + }) + } +} + +type testStep struct { + name string + action func(store *NodeStore) +} + +// --- Additional NodeStore concurrency, batching, race, resource, timeout, and allocation tests --- + +// Helper for concurrent test nodes +func createConcurrentTestNode(id types.NodeID, hostname string) types.Node { + machineKey := key.NewMachine() + nodeKey := key.NewNode() + return types.Node{ + ID: id, + Hostname: hostname, + MachineKey: machineKey.Public(), + NodeKey: nodeKey.Public(), + UserID: ptr.To(uint(1)), + User: &types.User{ + Name: "concurrent-test-user", + }, + } +} + +// --- Concurrency: concurrent PutNode operations --- +func TestNodeStoreConcurrentPutNode(t *testing.T) { + const concurrentOps = 20 + + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + defer store.Stop() + + var wg sync.WaitGroup + results := make(chan bool, concurrentOps) + for i := range concurrentOps { + wg.Add(1) + go func(nodeID int) { + defer wg.Done() + node := createConcurrentTestNode(types.NodeID(nodeID), "concurrent-node") + resultNode := store.PutNode(node) + results <- resultNode.Valid() + }(i + 1) + } + wg.Wait() + close(results) + + successCount := 0 + for success := range results { + if success { + successCount++ + } + } + require.Equal(t, concurrentOps, successCount, "All concurrent PutNode operations should succeed") +} + +// --- Batching: concurrent ops fit in one batch --- +func TestNodeStoreBatchingEfficiency(t *testing.T) { + const batchSize = 10 + const ops = 15 // more than batchSize + + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + defer store.Stop() + + var wg sync.WaitGroup + results := make(chan bool, ops) + for i := range ops { + wg.Add(1) + go func(nodeID int) { + defer wg.Done() + node := createConcurrentTestNode(types.NodeID(nodeID), "batch-node") + resultNode := store.PutNode(node) + results <- resultNode.Valid() + }(i + 1) + } + wg.Wait() + close(results) + + successCount := 0 + for success := range results { + if success { + successCount++ + } + } + require.Equal(t, ops, successCount, "All batch PutNode operations should succeed") +} + +// --- Race conditions: many goroutines on same node --- +func TestNodeStoreRaceConditions(t *testing.T) { + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + defer store.Stop() + + nodeID := types.NodeID(1) + node := createConcurrentTestNode(nodeID, "race-node") + resultNode := store.PutNode(node) + require.True(t, resultNode.Valid()) + + const numGoroutines = 30 + const opsPerGoroutine = 10 + var wg sync.WaitGroup + errors := make(chan error, numGoroutines*opsPerGoroutine) + + for i := range numGoroutines { + wg.Add(1) + go func(gid int) { + defer wg.Done() + + for j := range opsPerGoroutine { + switch j % 3 { + case 0: + resultNode, _ := store.UpdateNode(nodeID, func(n *types.Node) { + n.Hostname = "race-updated" + }) + if 
!resultNode.Valid() { + errors <- fmt.Errorf("UpdateNode failed in goroutine %d, op %d", gid, j) + } + case 1: + retrieved, found := store.GetNode(nodeID) + if !found || !retrieved.Valid() { + errors <- fmt.Errorf("GetNode failed in goroutine %d, op %d", gid, j) + } + case 2: + newNode := createConcurrentTestNode(nodeID, "race-put") + resultNode := store.PutNode(newNode) + if !resultNode.Valid() { + errors <- fmt.Errorf("PutNode failed in goroutine %d, op %d", gid, j) + } + } + } + }(i) + } + wg.Wait() + close(errors) + + errorCount := 0 + for err := range errors { + t.Error(err) + errorCount++ + } + if errorCount > 0 { + t.Fatalf("Race condition test failed with %d errors", errorCount) + } +} + +// --- Resource cleanup: goroutine leak detection --- +func TestNodeStoreResourceCleanup(t *testing.T) { + // initialGoroutines := runtime.NumGoroutine() + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + defer store.Stop() + + // Wait for store to be ready + var afterStartGoroutines int + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + afterStartGoroutines = runtime.NumGoroutine() + assert.Positive(c, afterStartGoroutines) // Just ensure we have a valid count + }, time.Second, 10*time.Millisecond, "store should be running") + + const ops = 100 + for i := range ops { + nodeID := types.NodeID(i + 1) + node := createConcurrentTestNode(nodeID, "cleanup-node") + resultNode := store.PutNode(node) + assert.True(t, resultNode.Valid()) + store.UpdateNode(nodeID, func(n *types.Node) { + n.Hostname = "cleanup-updated" + }) + retrieved, found := store.GetNode(nodeID) + assert.True(t, found && retrieved.Valid()) + if i%10 == 9 { + store.DeleteNode(nodeID) + } + } + runtime.GC() + + // Wait for goroutines to settle and check for leaks + assert.EventuallyWithT(t, func(c *assert.CollectT) { + finalGoroutines := runtime.NumGoroutine() + assert.LessOrEqual(c, finalGoroutines, afterStartGoroutines+2, + "Potential goroutine leak: started with %d, ended with %d", afterStartGoroutines, finalGoroutines) + }, time.Second, 10*time.Millisecond, "goroutines should not leak") +} + +// --- Timeout/deadlock: operations complete within reasonable time --- +func TestNodeStoreOperationTimeout(t *testing.T) { + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + defer store.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + const ops = 30 + var wg sync.WaitGroup + putResults := make([]error, ops) + updateResults := make([]error, ops) + + // Launch all PutNode operations concurrently + for i := 1; i <= ops; i++ { + nodeID := types.NodeID(i) + wg.Add(1) + go func(idx int, id types.NodeID) { + defer wg.Done() + startPut := time.Now() + fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) starting\n", startPut.Format("15:04:05.000"), id) + node := createConcurrentTestNode(id, "timeout-node") + resultNode := store.PutNode(node) + endPut := time.Now() + fmt.Printf("[TestNodeStoreOperationTimeout] %s: PutNode(%d) finished, valid=%v, duration=%v\n", endPut.Format("15:04:05.000"), id, resultNode.Valid(), endPut.Sub(startPut)) + if !resultNode.Valid() { + putResults[idx-1] = fmt.Errorf("PutNode failed for node %d", id) + } + }(i, nodeID) + } + wg.Wait() + + // Launch all UpdateNode operations concurrently + wg = sync.WaitGroup{} + for i := 1; i <= ops; i++ { + nodeID := types.NodeID(i) + wg.Add(1) + go func(idx int, id types.NodeID) { + defer wg.Done() + startUpdate := time.Now() 
+ fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) starting\n", startUpdate.Format("15:04:05.000"), id) + resultNode, ok := store.UpdateNode(id, func(n *types.Node) { + n.Hostname = "timeout-updated" + }) + endUpdate := time.Now() + fmt.Printf("[TestNodeStoreOperationTimeout] %s: UpdateNode(%d) finished, valid=%v, ok=%v, duration=%v\n", endUpdate.Format("15:04:05.000"), id, resultNode.Valid(), ok, endUpdate.Sub(startUpdate)) + if !ok || !resultNode.Valid() { + updateResults[idx-1] = fmt.Errorf("UpdateNode failed for node %d", id) + } + }(i, nodeID) + } + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + errorCount := 0 + for _, err := range putResults { + if err != nil { + t.Error(err) + errorCount++ + } + } + for _, err := range updateResults { + if err != nil { + t.Error(err) + errorCount++ + } + } + if errorCount == 0 { + t.Log("All concurrent operations completed successfully within timeout") + } else { + t.Fatalf("Some concurrent operations failed: %d errors", errorCount) + } + case <-ctx.Done(): + fmt.Println("[TestNodeStoreOperationTimeout] Timeout reached, test failed") + t.Fatal("Operations timed out - potential deadlock or resource issue") + } +} + +// --- Edge case: update non-existent node --- +func TestNodeStoreUpdateNonExistentNode(t *testing.T) { + for i := range 10 { + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + nonExistentID := types.NodeID(999 + i) + updateCallCount := 0 + fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) starting\n", nonExistentID) + resultNode, ok := store.UpdateNode(nonExistentID, func(n *types.Node) { + updateCallCount++ + n.Hostname = "should-never-be-called" + }) + fmt.Printf("[TestNodeStoreUpdateNonExistentNode] UpdateNode(%d) finished, valid=%v, ok=%v, updateCallCount=%d\n", nonExistentID, resultNode.Valid(), ok, updateCallCount) + assert.False(t, ok, "UpdateNode should return false for non-existent node") + assert.False(t, resultNode.Valid(), "UpdateNode should return invalid node for non-existent node") + assert.Equal(t, 0, updateCallCount, "UpdateFn should not be called for non-existent node") + store.Stop() + } +} + +// --- Allocation benchmark --- +func BenchmarkNodeStoreAllocations(b *testing.B) { + store := NewNodeStore(nil, allowAllPeersFunc, TestBatchSize, TestBatchTimeout) + store.Start() + defer store.Stop() + + for i := 0; b.Loop(); i++ { + nodeID := types.NodeID(i + 1) + node := createConcurrentTestNode(nodeID, "bench-node") + store.PutNode(node) + store.UpdateNode(nodeID, func(n *types.Node) { + n.Hostname = "bench-updated" + }) + store.GetNode(nodeID) + if i%10 == 9 { + store.DeleteNode(nodeID) + } + } +} + +func TestNodeStoreAllocationStats(t *testing.T) { + res := testing.Benchmark(BenchmarkNodeStoreAllocations) + allocs := res.AllocsPerOp() + t.Logf("NodeStore allocations per op: %.2f", float64(allocs)) +} + +// TestRebuildPeerMapsWithChangedPeersFunc tests that RebuildPeerMaps correctly +// rebuilds the peer map when the peersFunc behavior changes. +// This simulates what happens when SetNodeTags changes node tags and the +// PolicyManager's matchers are updated, requiring the peer map to be rebuilt. 
+func TestRebuildPeerMapsWithChangedPeersFunc(t *testing.T) { + // Create a peersFunc that can be controlled via a channel + // Initially it returns all nodes as peers, then we change it to return no peers + allowPeers := true + + // This simulates how PolicyManager.BuildPeerMap works - it reads state + // that can change between calls + dynamicPeersFunc := func(nodes []types.NodeView) map[types.NodeID][]types.NodeView { + ret := make(map[types.NodeID][]types.NodeView, len(nodes)) + if allowPeers { + // Allow all peers + for _, node := range nodes { + var peers []types.NodeView + + for _, n := range nodes { + if n.ID() != node.ID() { + peers = append(peers, n) + } + } + + ret[node.ID()] = peers + } + } else { + // Allow no peers + for _, node := range nodes { + ret[node.ID()] = []types.NodeView{} + } + } + + return ret + } + + // Create nodes + node1 := createTestNode(1, 1, "user1", "node1") + node2 := createTestNode(2, 2, "user2", "node2") + initialNodes := types.Nodes{&node1, &node2} + + // Create store with dynamic peersFunc + store := NewNodeStore(initialNodes, dynamicPeersFunc, TestBatchSize, TestBatchTimeout) + + store.Start() + defer store.Stop() + + // Initially, nodes should see each other as peers + snapshot := store.data.Load() + require.Len(t, snapshot.peersByNode[1], 1, "node1 should have 1 peer initially") + require.Len(t, snapshot.peersByNode[2], 1, "node2 should have 1 peer initially") + require.Equal(t, types.NodeID(2), snapshot.peersByNode[1][0].ID()) + require.Equal(t, types.NodeID(1), snapshot.peersByNode[2][0].ID()) + + // Now "change the policy" by disabling peers + allowPeers = false + + // Call RebuildPeerMaps to rebuild with the new behavior + store.RebuildPeerMaps() + + // After rebuild, nodes should have no peers + snapshot = store.data.Load() + assert.Empty(t, snapshot.peersByNode[1], "node1 should have no peers after rebuild") + assert.Empty(t, snapshot.peersByNode[2], "node2 should have no peers after rebuild") + + // Verify that ListPeers returns the correct result + peers1 := store.ListPeers(1) + peers2 := store.ListPeers(2) + + assert.Equal(t, 0, peers1.Len(), "ListPeers for node1 should return empty") + assert.Equal(t, 0, peers2.Len(), "ListPeers for node2 should return empty") + + // Now re-enable peers and rebuild again + allowPeers = true + + store.RebuildPeerMaps() + + // Nodes should see each other again + snapshot = store.data.Load() + require.Len(t, snapshot.peersByNode[1], 1, "node1 should have 1 peer after re-enabling") + require.Len(t, snapshot.peersByNode[2], 1, "node2 should have 1 peer after re-enabling") + + peers1 = store.ListPeers(1) + peers2 = store.ListPeers(2) + + assert.Equal(t, 1, peers1.Len(), "ListPeers for node1 should return 1") + assert.Equal(t, 1, peers2.Len(), "ListPeers for node2 should return 1") +} diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go new file mode 100644 index 00000000..d1401ef0 --- /dev/null +++ b/hscontrol/state/state.go @@ -0,0 +1,2170 @@ +// Package state provides core state management for Headscale, coordinating +// between subsystems like database, IP allocation, policy management, and DERP routing. 
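+//
+// Node mutations in this package follow a NodeStore-first pattern: the in-memory
+// NodeStore is updated synchronously and the resulting node state is then persisted
+// to the database, so the NodeStore remains the source of truth for the batcher
+// (see SetNodeExpiry, SetNodeTags, and RenameNode below).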
+ +package state + +import ( + "cmp" + "context" + "errors" + "fmt" + "net/netip" + "slices" + "strings" + "sync" + "sync/atomic" + "time" + + hsdb "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/routes" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/types/change" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" + "golang.org/x/sync/errgroup" + "gorm.io/gorm" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/ptr" + "tailscale.com/types/views" + zcache "zgo.at/zcache/v2" +) + +const ( + // registerCacheExpiration defines how long node registration entries remain in cache. + registerCacheExpiration = time.Minute * 15 + + // registerCacheCleanup defines the interval for cleaning up expired cache entries. + registerCacheCleanup = time.Minute * 20 + + // defaultNodeStoreBatchSize is the default number of write operations to batch + // before rebuilding the in-memory node snapshot. + defaultNodeStoreBatchSize = 100 + + // defaultNodeStoreBatchTimeout is the default maximum time to wait before + // processing a partial batch of node operations. + defaultNodeStoreBatchTimeout = 500 * time.Millisecond +) + +// ErrUnsupportedPolicyMode is returned for invalid policy modes. Valid modes are "file" and "db". +var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode") + +// ErrNodeNotFound is returned when a node cannot be found by its ID. +var ErrNodeNotFound = errors.New("node not found") + +// ErrInvalidNodeView is returned when an invalid node view is provided. +var ErrInvalidNodeView = errors.New("invalid node view provided") + +// ErrNodeNotInNodeStore is returned when a node no longer exists in the NodeStore. +var ErrNodeNotInNodeStore = errors.New("node no longer exists in NodeStore") + +// ErrNodeNameNotUnique is returned when a node name is not unique. +var ErrNodeNameNotUnique = errors.New("node name is not unique") + +// State manages Headscale's core state, coordinating between database, policy management, +// IP allocation, and DERP routing. All methods are thread-safe. +type State struct { + // cfg holds the current Headscale configuration + cfg *types.Config + + // nodeStore provides an in-memory cache for nodes. + nodeStore *NodeStore + + // subsystem keeping state + // db provides persistent storage and database operations + db *hsdb.HSDatabase + // ipAlloc manages IP address allocation for nodes + ipAlloc *hsdb.IPAllocator + // derpMap contains the current DERP relay configuration + derpMap atomic.Pointer[tailcfg.DERPMap] + // polMan handles policy evaluation and management + polMan policy.PolicyManager + // registrationCache caches node registration data to reduce database load + registrationCache *zcache.Cache[types.RegistrationID, types.RegisterNode] + // primaryRoutes tracks primary route assignments for nodes + primaryRoutes *routes.PrimaryRoutes +} + +// NewState creates and initializes a new State instance, setting up the database, +// IP allocator, DERP map, policy manager, and loading existing users and nodes. 
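+//
+// A minimal construction sketch (illustrative only, assuming a fully populated
+// types.Config):
+//
+//	st, err := NewState(cfg)
+//	if err != nil {
+//		return err
+//	}
+//	defer st.Close()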
+func NewState(cfg *types.Config) (*State, error) {
+	cacheExpiration := registerCacheExpiration
+	if cfg.Tuning.RegisterCacheExpiration != 0 {
+		cacheExpiration = cfg.Tuning.RegisterCacheExpiration
+	}
+
+	cacheCleanup := registerCacheCleanup
+	if cfg.Tuning.RegisterCacheCleanup != 0 {
+		cacheCleanup = cfg.Tuning.RegisterCacheCleanup
+	}
+
+	registrationCache := zcache.New[types.RegistrationID, types.RegisterNode](
+		cacheExpiration,
+		cacheCleanup,
+	)
+
+	registrationCache.OnEvicted(
+		func(id types.RegistrationID, rn types.RegisterNode) {
+			rn.SendAndClose(nil)
+		},
+	)
+
+	db, err := hsdb.NewHeadscaleDatabase(
+		cfg,
+		registrationCache,
+	)
+	if err != nil {
+		return nil, fmt.Errorf("init database: %w", err)
+	}
+
+	ipAlloc, err := hsdb.NewIPAllocator(db, cfg.PrefixV4, cfg.PrefixV6, cfg.IPAllocation)
+	if err != nil {
+		return nil, fmt.Errorf("init ip allocator: %w", err)
+	}
+
+	nodes, err := db.ListNodes()
+	if err != nil {
+		return nil, fmt.Errorf("loading nodes: %w", err)
+	}
+
+	// On startup, all nodes should be marked as offline until they reconnect.
+	// This ensures we don't have stale online status from previous runs.
+	for _, node := range nodes {
+		node.IsOnline = ptr.To(false)
+	}
+	users, err := db.ListUsers()
+	if err != nil {
+		return nil, fmt.Errorf("loading users: %w", err)
+	}
+
+	pol, err := hsdb.PolicyBytes(db.DB, cfg)
+	if err != nil {
+		return nil, fmt.Errorf("loading policy: %w", err)
+	}
+
+	polMan, err := policy.NewPolicyManager(pol, users, nodes.ViewSlice())
+	if err != nil {
+		return nil, fmt.Errorf("init policy manager: %w", err)
+	}
+
+	// Apply defaults for NodeStore batch configuration if not set.
+	// This ensures tests that create Config directly (without viper) still work.
+	batchSize := cfg.Tuning.NodeStoreBatchSize
+	if batchSize == 0 {
+		batchSize = defaultNodeStoreBatchSize
+	}
+	batchTimeout := cfg.Tuning.NodeStoreBatchTimeout
+	if batchTimeout == 0 {
+		batchTimeout = defaultNodeStoreBatchTimeout
+	}
+
+	// PolicyManager.BuildPeerMap handles both global and per-node filter complexity.
+	// This moves the complex peer relationship logic into the policy package where it belongs.
+	nodeStore := NewNodeStore(
+		nodes,
+		func(nodes []types.NodeView) map[types.NodeID][]types.NodeView {
+			return polMan.BuildPeerMap(views.SliceOf(nodes))
+		},
+		batchSize,
+		batchTimeout,
+	)
+	nodeStore.Start()
+
+	return &State{
+		cfg: cfg,
+
+		db:                db,
+		ipAlloc:           ipAlloc,
+		polMan:            polMan,
+		registrationCache: registrationCache,
+		primaryRoutes:     routes.New(),
+		nodeStore:         nodeStore,
+	}, nil
+}
+
+// Close gracefully shuts down the State instance and releases all resources.
+func (s *State) Close() error {
+	s.nodeStore.Stop()
+
+	if err := s.db.Close(); err != nil {
+		return fmt.Errorf("closing database: %w", err)
+	}
+
+	return nil
+}
+
+// SetDERPMap updates the DERP relay configuration.
+func (s *State) SetDERPMap(dm *tailcfg.DERPMap) {
+	s.derpMap.Store(dm)
+}
+
+// DERPMap returns the current DERP relay configuration for peer-to-peer connectivity.
+func (s *State) DERPMap() tailcfg.DERPMapView {
+	return s.derpMap.Load().View()
+}
+
+// ReloadPolicy reloads the access control policy and re-evaluates route auto-approval.
+// It returns the change set that should be propagated to connected nodes.
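+//
+// A sketch of the expected call pattern:
+//
+//	changes, err := st.ReloadPolicy()
+//	if err != nil {
+//		return err
+//	}
+//	// propagate the returned changes to connected nodes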
+func (s *State) ReloadPolicy() ([]change.Change, error) { + pol, err := hsdb.PolicyBytes(s.db.DB, s.cfg) + if err != nil { + return nil, fmt.Errorf("loading policy: %w", err) + } + + policyChanged, err := s.polMan.SetPolicy(pol) + if err != nil { + return nil, fmt.Errorf("setting policy: %w", err) + } + + // Rebuild peer maps after policy changes because the peersFunc in NodeStore + // uses the PolicyManager's filters. Without this, nodes won't see newly allowed + // peers until a node is added/removed, causing autogroup:self policies to not + // propagate correctly when switching between policy types. + s.nodeStore.RebuildPeerMaps() + + cs := []change.Change{change.PolicyChange()} + + // Always call autoApproveNodes during policy reload, regardless of whether + // the policy content has changed. This ensures that routes are re-evaluated + // when they might have been manually disabled but could now be auto-approved + // with the current policy. + rcs, err := s.autoApproveNodes() + if err != nil { + return nil, fmt.Errorf("auto approving nodes: %w", err) + } + + // TODO(kradalby): These changes can probably be safely ignored. + // If the PolicyChange is happening, that will lead to a full update + // meaning that we do not need to send individual route changes. + cs = append(cs, rcs...) + + if len(rcs) > 0 || policyChanged { + log.Info(). + Bool("policy.changed", policyChanged). + Int("route.changes", len(rcs)). + Int("total.changes", len(cs)). + Msg("Policy reload completed with changes") + } + + return cs, nil +} + +// CreateUser creates a new user and updates the policy manager. +// Returns the created user, change set, and any error. +func (s *State) CreateUser(user types.User) (*types.User, change.Change, error) { + if err := s.db.DB.Save(&user).Error; err != nil { + return nil, change.Change{}, fmt.Errorf("creating user: %w", err) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerUsers() + if err != nil { + // Log the error but don't fail the user creation + return &user, change.Change{}, fmt.Errorf("failed to update policy manager after user creation: %w", err) + } + + // Even if the policy manager doesn't detect a filter change, SSH policies + // might now be resolvable when they weren't before. If there are existing + // nodes, we should send a policy change to ensure they get updated SSH policies. + // TODO(kradalby): detect this, or rebuild all SSH policies so we can determine + // this upstream. + if c.IsEmpty() { + c = change.PolicyChange() + } + + log.Info().Str("user.name", user.Name).Msg("User created") + + return &user, c, nil +} + +// UpdateUser modifies an existing user using the provided update function within a transaction. +// Returns the updated user, change set, and any error. +func (s *State) UpdateUser(userID types.UserID, updateFn func(*types.User) error) (*types.User, change.Change, error) { + user, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.User, error) { + user, err := hsdb.GetUserByID(tx, userID) + if err != nil { + return nil, err + } + + if err := updateFn(user); err != nil { + return nil, err + } + + // Use Updates() to only update modified fields, preserving unchanged values. 
+ err = tx.Updates(user).Error + if err != nil { + return nil, fmt.Errorf("updating user: %w", err) + } + + return user, nil + }) + if err != nil { + return nil, change.Change{}, err + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerUsers() + if err != nil { + return user, change.Change{}, fmt.Errorf("failed to update policy manager after user update: %w", err) + } + + // TODO(kradalby): We might want to update nodestore with the user data + + return user, c, nil +} + +// DeleteUser permanently removes a user and all associated data (nodes, API keys, etc). +// This operation is irreversible. +// It also updates the policy manager to ensure ACL policies referencing the deleted +// user are re-evaluated immediately, fixing issue #2967. +func (s *State) DeleteUser(userID types.UserID) (change.Change, error) { + err := s.db.DestroyUser(userID) + if err != nil { + return change.Change{}, err + } + + // Update policy manager with the new user list (without the deleted user) + // This ensures that if the policy references the deleted user, it gets + // re-evaluated immediately rather than when some other operation triggers it. + c, err := s.updatePolicyManagerUsers() + if err != nil { + return change.Change{}, fmt.Errorf("updating policy after user deletion: %w", err) + } + + // If the policy manager doesn't detect changes, still return UserRemoved + // to ensure peer lists are refreshed + if c.IsEmpty() { + c = change.UserRemoved() + } + + return c, nil +} + +// RenameUser changes a user's name. The new name must be unique. +func (s *State) RenameUser(userID types.UserID, newName string) (*types.User, change.Change, error) { + return s.UpdateUser(userID, func(user *types.User) error { + user.Name = newName + return nil + }) +} + +// GetUserByID retrieves a user by ID. +func (s *State) GetUserByID(userID types.UserID) (*types.User, error) { + return s.db.GetUserByID(userID) +} + +// GetUserByName retrieves a user by name. +func (s *State) GetUserByName(name string) (*types.User, error) { + return s.db.GetUserByName(name) +} + +// GetUserByOIDCIdentifier retrieves a user by their OIDC identifier. +func (s *State) GetUserByOIDCIdentifier(id string) (*types.User, error) { + return s.db.GetUserByOIDCIdentifier(id) +} + +// ListUsersWithFilter retrieves users matching the specified filter criteria. +func (s *State) ListUsersWithFilter(filter *types.User) ([]types.User, error) { + return s.db.ListUsers(filter) +} + +// ListAllUsers retrieves all users in the system. +func (s *State) ListAllUsers() ([]types.User, error) { + return s.db.ListUsers() +} + +// persistNodeToDB saves the given node state to the database. +// This function must receive the exact node state to save to ensure consistency between +// NodeStore and the database. It verifies the node still exists in NodeStore to prevent +// race conditions where a node might be deleted between UpdateNode returning and +// persistNodeToDB being called. +func (s *State) persistNodeToDB(node types.NodeView) (types.NodeView, change.Change, error) { + if !node.Valid() { + return types.NodeView{}, change.Change{}, ErrInvalidNodeView + } + + // Verify the node still exists in NodeStore before persisting to database. + // Without this check, we could hit a race condition where UpdateNode returns a valid + // node from a batch update, then the node gets deleted (e.g., ephemeral node logout), + // and persistNodeToDB would incorrectly re-insert the deleted node into the database. 
+ _, exists := s.nodeStore.GetNode(node.ID()) + if !exists { + log.Warn(). + Uint64("node.id", node.ID().Uint64()). + Str("node.name", node.Hostname()). + Bool("is_ephemeral", node.IsEphemeral()). + Msg("Node no longer exists in NodeStore, skipping database persist to prevent race condition") + + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, node.ID()) + } + + nodePtr := node.AsStruct() + + // Use Omit("expiry") to prevent overwriting expiry during MapRequest updates. + // Expiry should only be updated through explicit SetNodeExpiry calls or re-registration. + // See: https://github.com/juanfont/headscale/issues/2862 + err := s.db.DB.Omit("expiry").Updates(nodePtr).Error + if err != nil { + return types.NodeView{}, change.Change{}, fmt.Errorf("saving node: %w", err) + } + + // Check if policy manager needs updating + c, err := s.updatePolicyManagerNodes() + if err != nil { + return nodePtr.View(), change.Change{}, fmt.Errorf("failed to update policy manager after node save: %w", err) + } + + if c.IsEmpty() { + c = change.NodeAdded(node.ID()) + } + + return node, c, nil +} + +func (s *State) SaveNode(node types.NodeView) (types.NodeView, change.Change, error) { + // Update NodeStore first + nodePtr := node.AsStruct() + + resultNode := s.nodeStore.PutNode(*nodePtr) + + // Then save to database using the result from PutNode + return s.persistNodeToDB(resultNode) +} + +// DeleteNode permanently removes a node and cleans up associated resources. +// Returns whether policies changed and any error. This operation is irreversible. +func (s *State) DeleteNode(node types.NodeView) (change.Change, error) { + s.nodeStore.DeleteNode(node.ID()) + + err := s.db.DeleteNode(node.AsStruct()) + if err != nil { + return change.Change{}, err + } + + s.ipAlloc.FreeIPs(node.IPs()) + + c := change.NodeRemoved(node.ID()) + + // Check if policy manager needs updating after node deletion + policyChange, err := s.updatePolicyManagerNodes() + if err != nil { + return change.Change{}, fmt.Errorf("failed to update policy manager after node deletion: %w", err) + } + + if !policyChange.IsEmpty() { + // Merge policy change with NodeRemoved to preserve PeersRemoved info + // This ensures the batcher cleans up the deleted node from its state + c = c.Merge(policyChange) + } + + return c, nil +} + +// Connect marks a node as connected and updates its primary routes in the state. +func (s *State) Connect(id types.NodeID) []change.Change { + // CRITICAL FIX: Update the online status in NodeStore BEFORE creating change notification + // This ensures that when the NodeCameOnline change is distributed and processed by other nodes, + // the NodeStore already reflects the correct online status for full map generation. + // now := time.Now() + node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) { + n.IsOnline = ptr.To(true) + // n.LastSeen = ptr.To(now) + }) + if !ok { + return nil + } + + c := []change.Change{change.NodeOnlineFor(node)} + + log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node connected") + + // Use the node's current routes for primary route update + // AllApprovedRoutes() returns only the intersection of announced AND approved routes + // We MUST use AllApprovedRoutes() to maintain the security model + routeChange := s.primaryRoutes.SetRoutes(id, node.AllApprovedRoutes()...) 
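+	// SetRoutes reports whether the primary route table changed for this node;
+	// if it did, an extra change is emitted below so peers pick up the new routes.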
+
+	if routeChange {
+		c = append(c, change.NodeAdded(id))
+	}
+
+	return c
+}
+
+// Disconnect marks a node as disconnected and updates its primary routes in the state.
+func (s *State) Disconnect(id types.NodeID) ([]change.Change, error) {
+	now := time.Now()
+
+	node, ok := s.nodeStore.UpdateNode(id, func(n *types.Node) {
+		n.LastSeen = ptr.To(now)
+		// NodeStore is the source of truth for all node state including online status.
+		n.IsOnline = ptr.To(false)
+	})
+
+	if !ok {
+		return nil, fmt.Errorf("node not found: %d", id)
+	}
+
+	log.Info().Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Node disconnected")
+
+	// Special error handling for disconnect - we log errors but continue
+	// because NodeStore is already updated and we need to notify peers
+	_, c, err := s.persistNodeToDB(node)
+	if err != nil {
+		// Log error but don't fail the disconnection - NodeStore is already updated
+		// and we need to send change notifications to peers
+		log.Error().Err(err).Uint64("node.id", id.Uint64()).Str("node.name", node.Hostname()).Msg("Failed to update last seen in database")
+
+		c = change.Change{}
+	}
+
+	// The node is disconnecting so make sure that none of the routes it
+	// announced are served to any nodes.
+	routeChange := s.primaryRoutes.SetRoutes(id)
+
+	cs := []change.Change{change.NodeOfflineFor(node), c}
+
+	// If the database update produced a full change, or the primary routes changed,
+	// also append a policy change so peers receive a comprehensive update.
+	if c.IsFull() || routeChange {
+		cs = append(cs, change.PolicyChange())
+	}
+
+	return cs, nil
+}
+
+// GetNodeByID retrieves a node by its ID.
+// The bool reports whether the node exists.
+// The returned NodeView may still be invalid (for example if the node is broken),
+// so callers must check it with .Valid() before use.
+func (s *State) GetNodeByID(nodeID types.NodeID) (types.NodeView, bool) {
+	return s.nodeStore.GetNode(nodeID)
+}
+
+// GetNodeByNodeKey retrieves a node by its Tailscale public key.
+// The bool reports whether the node exists.
+// The returned NodeView may still be invalid, so callers must check it with .Valid() before use.
+func (s *State) GetNodeByNodeKey(nodeKey key.NodePublic) (types.NodeView, bool) {
+	return s.nodeStore.GetNodeByNodeKey(nodeKey)
+}
+
+// GetNodeByMachineKey retrieves a node by its machine key and user ID.
+// The bool reports whether the node exists.
+// The returned NodeView may still be invalid, so callers must check it with .Valid() before use.
+func (s *State) GetNodeByMachineKey(machineKey key.MachinePublic, userID types.UserID) (types.NodeView, bool) {
+	return s.nodeStore.GetNodeByMachineKey(machineKey, userID)
+}
+
+// ListNodes retrieves specific nodes by ID, or all nodes if no IDs provided.
+func (s *State) ListNodes(nodeIDs ...types.NodeID) views.Slice[types.NodeView] { + if len(nodeIDs) == 0 { + return s.nodeStore.ListNodes() + } + + // Filter nodes by the requested IDs + allNodes := s.nodeStore.ListNodes() + nodeIDSet := make(map[types.NodeID]struct{}, len(nodeIDs)) + for _, id := range nodeIDs { + nodeIDSet[id] = struct{}{} + } + + var filteredNodes []types.NodeView + for _, node := range allNodes.All() { + if _, exists := nodeIDSet[node.ID()]; exists { + filteredNodes = append(filteredNodes, node) + } + } + + return views.SliceOf(filteredNodes) +} + +// ListNodesByUser retrieves all nodes belonging to a specific user. +func (s *State) ListNodesByUser(userID types.UserID) views.Slice[types.NodeView] { + return s.nodeStore.ListNodesByUser(userID) +} + +// ListPeers retrieves nodes that can communicate with the specified node based on policy. +func (s *State) ListPeers(nodeID types.NodeID, peerIDs ...types.NodeID) views.Slice[types.NodeView] { + if len(peerIDs) == 0 { + return s.nodeStore.ListPeers(nodeID) + } + + // For specific peerIDs, filter from all nodes + allNodes := s.nodeStore.ListNodes() + nodeIDSet := make(map[types.NodeID]struct{}, len(peerIDs)) + for _, id := range peerIDs { + nodeIDSet[id] = struct{}{} + } + + var filteredNodes []types.NodeView + for _, node := range allNodes.All() { + if _, exists := nodeIDSet[node.ID()]; exists { + filteredNodes = append(filteredNodes, node) + } + } + + return views.SliceOf(filteredNodes) +} + +// ListEphemeralNodes retrieves all ephemeral (temporary) nodes in the system. +func (s *State) ListEphemeralNodes() views.Slice[types.NodeView] { + allNodes := s.nodeStore.ListNodes() + var ephemeralNodes []types.NodeView + + for _, node := range allNodes.All() { + // Check if node is ephemeral by checking its AuthKey + if node.AuthKey().Valid() && node.AuthKey().Ephemeral() { + ephemeralNodes = append(ephemeralNodes, node) + } + } + + return views.SliceOf(ephemeralNodes) +} + +// SetNodeExpiry updates the expiration time for a node. +func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.NodeView, change.Change, error) { + // Update NodeStore before database to ensure consistency. The NodeStore update is + // blocking and will be the source of truth for the batcher. The database update must + // make the exact same change. If the database update fails, the NodeStore change will + // remain, but since we return an error, no change notification will be sent to the + // batcher, preventing inconsistent state propagation. + expiryPtr := expiry + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.Expiry = &expiryPtr + }) + + if !ok { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) + } + + return s.persistNodeToDB(n) +} + +// SetNodeTags assigns tags to a node, making it a "tagged node". +// Once a node is tagged, it cannot be un-tagged (only tags can be changed). +// The UserID is preserved as "created by" information. 
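+//
+// Tags must use the "tag:" prefix and be defined in the current policy; for
+// example, SetNodeTags(id, []string{"tag:server"}) is accepted only if
+// "tag:server" exists in the policy.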
+func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.Change, error) { + // CANNOT REMOVE ALL TAGS + if len(tags) == 0 { + return types.NodeView{}, change.Change{}, types.ErrCannotRemoveAllTags + } + + // Get node for validation + existingNode, exists := s.nodeStore.GetNode(nodeID) + if !exists { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID) + } + + // Validate tags: must have correct format and exist in policy + validatedTags := make([]string, 0, len(tags)) + invalidTags := make([]string, 0) + + for _, tag := range tags { + if !strings.HasPrefix(tag, "tag:") || !s.polMan.TagExists(tag) { + invalidTags = append(invalidTags, tag) + + continue + } + + validatedTags = append(validatedTags, tag) + } + + if len(invalidTags) > 0 { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, invalidTags) + } + + slices.Sort(validatedTags) + validatedTags = slices.Compact(validatedTags) + + // Log the operation + logTagOperation(existingNode, validatedTags) + + // Update NodeStore before database to ensure consistency. The NodeStore update is + // blocking and will be the source of truth for the batcher. The database update must + // make the exact same change. + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.Tags = validatedTags + // UserID is preserved as "created by" - do NOT set to nil + }) + + if !ok { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) + } + + nodeView, c, err := s.persistNodeToDB(n) + if err != nil { + return nodeView, c, err + } + + // Set OriginNode so the mapper knows to include self info for this node. + // When tags change, persistNodeToDB returns PolicyChange which doesn't set OriginNode, + // so the mapper's self-update check fails and the node never sees its new tags. + // Setting OriginNode ensures the node gets a self-update with the new tags. + c.OriginNode = nodeID + + return nodeView, c, nil +} + +// SetApprovedRoutes sets the network routes that a node is approved to advertise. +func (s *State) SetApprovedRoutes(nodeID types.NodeID, routes []netip.Prefix) (types.NodeView, change.Change, error) { + // TODO(kradalby): In principle we should call the AutoApprove logic here + // because even if the CLI removes an auto-approved route, it will be added + // back automatically. + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.ApprovedRoutes = routes + }) + + if !ok { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) + } + + // Persist the node changes to the database + nodeView, c, err := s.persistNodeToDB(n) + if err != nil { + return types.NodeView{}, change.Change{}, err + } + + // Update primary routes table based on SubnetRoutes (intersection of announced and approved). + // The primary routes table is what the mapper uses to generate network maps, so updating it + // here ensures that route changes are distributed to peers. + routeChange := s.primaryRoutes.SetRoutes(nodeID, nodeView.AllApprovedRoutes()...) + + // If routes changed or the changeset isn't already a full update, trigger a policy change + // to ensure all nodes get updated network maps + if routeChange || !c.IsFull() { + c = change.PolicyChange() + } + + return nodeView, c, nil +} + +// RenameNode changes the display name of a node. 
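+// The new name must pass hostname validation and be unique among the given
+// names of existing nodes, e.g. RenameNode(id, "office-printer-2").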
+func (s *State) RenameNode(nodeID types.NodeID, newName string) (types.NodeView, change.Change, error) { + if err := util.ValidateHostname(newName); err != nil { + return types.NodeView{}, change.Change{}, fmt.Errorf("renaming node: %w", err) + } + + // Check name uniqueness against NodeStore + allNodes := s.nodeStore.ListNodes() + for i := 0; i < allNodes.Len(); i++ { + node := allNodes.At(i) + if node.ID() != nodeID && node.AsStruct().GivenName == newName { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %s", ErrNodeNameNotUnique, newName) + } + } + + // Update NodeStore before database to ensure consistency. The NodeStore update is + // blocking and will be the source of truth for the batcher. The database update must + // make the exact same change. + n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) { + node.GivenName = newName + }) + + if !ok { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, nodeID) + } + + return s.persistNodeToDB(n) +} + +// BackfillNodeIPs assigns IP addresses to nodes that don't have them. +func (s *State) BackfillNodeIPs() ([]string, error) { + changes, err := s.db.BackfillNodeIPs(s.ipAlloc) + if err != nil { + return nil, err + } + + // Refresh NodeStore after IP changes to ensure consistency + if len(changes) > 0 { + nodes, err := s.db.ListNodes() + if err != nil { + return changes, fmt.Errorf("failed to refresh NodeStore after IP backfill: %w", err) + } + + for _, node := range nodes { + // Preserve online status and NetInfo when refreshing from database + existingNode, exists := s.nodeStore.GetNode(node.ID) + if exists && existingNode.Valid() { + node.IsOnline = ptr.To(existingNode.IsOnline().Get()) + + // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics + // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). + + // Preserve NetInfo from existing node to prevent loss during backfill + netInfo := netInfoFromMapRequest(node.ID, existingNode.Hostinfo().AsStruct(), node.Hostinfo) + node.Hostinfo = existingNode.Hostinfo().AsStruct() + node.Hostinfo.NetInfo = netInfo + } + // TODO(kradalby): This should just update the IP addresses, nothing else in the node store. + // We should avoid PutNode here. + _ = s.nodeStore.PutNode(*node) + } + } + + return changes, nil +} + +// ExpireExpiredNodes finds and processes expired nodes since the last check. +// Returns next check time, state update with expired nodes, and whether any were found. +func (s *State) ExpireExpiredNodes(lastCheck time.Time) (time.Time, []change.Change, bool) { + // Why capture start time: We need to ensure we don't miss nodes that expire + // while this function is running by using a consistent timestamp for the next check + started := time.Now() + + var updates []change.Change + + for _, node := range s.nodeStore.ListNodes().All() { + if !node.Valid() { + continue + } + + // Why check After(lastCheck): We only want to notify about nodes that + // expired since the last check to avoid duplicate notifications + if node.IsExpired() && node.Expiry().Valid() && node.Expiry().Get().After(lastCheck) { + updates = append(updates, change.KeyExpiryFor(node.ID(), node.Expiry().Get())) + } + } + + if len(updates) > 0 { + return started, updates, true + } + + return started, nil, false +} + +// SSHPolicy returns the SSH access policy for a node. 
+func (s *State) SSHPolicy(node types.NodeView) (*tailcfg.SSHPolicy, error) {
+	return s.polMan.SSHPolicy(node)
+}
+
+// Filter returns the current network filter rules and matches.
+func (s *State) Filter() ([]tailcfg.FilterRule, []matcher.Match) {
+	return s.polMan.Filter()
+}
+
+// FilterForNode returns filter rules for a specific node, handling autogroup:self per-node.
+func (s *State) FilterForNode(node types.NodeView) ([]tailcfg.FilterRule, error) {
+	return s.polMan.FilterForNode(node)
+}
+
+// MatchersForNode returns matchers for peer relationship determination (unreduced).
+func (s *State) MatchersForNode(node types.NodeView) ([]matcher.Match, error) {
+	return s.polMan.MatchersForNode(node)
+}
+
+// NodeCanHaveTag checks if a node is allowed to have a specific tag.
+func (s *State) NodeCanHaveTag(node types.NodeView, tag string) bool {
+	return s.polMan.NodeCanHaveTag(node, tag)
+}
+
+// SetPolicy updates the policy configuration.
+func (s *State) SetPolicy(pol []byte) (bool, error) {
+	return s.polMan.SetPolicy(pol)
+}
+
+// AutoApproveRoutes checks whether any of a node's announced routes should be
+// auto-approved under the current policy and, if so, persists the approval.
+func (s *State) AutoApproveRoutes(nv types.NodeView) (change.Change, error) {
+	approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes())
+	if changed {
+		log.Debug().
+			Uint64("node.id", nv.ID().Uint64()).
+			Str("node.name", nv.Hostname()).
+			Strs("routes.announced", util.PrefixesToString(nv.AnnouncedRoutes())).
+			Strs("routes.approved.old", util.PrefixesToString(nv.ApprovedRoutes().AsSlice())).
+			Strs("routes.approved.new", util.PrefixesToString(approved)).
+			Msg("Single node auto-approval detected route changes")
+
+		// Persist the auto-approved routes to database and NodeStore via SetApprovedRoutes
+		// This ensures consistency between database and NodeStore
+		_, c, err := s.SetApprovedRoutes(nv.ID(), approved)
+		if err != nil {
+			log.Error().
+				Uint64("node.id", nv.ID().Uint64()).
+				Str("node.name", nv.Hostname()).
+				Err(err).
+				Msg("Failed to persist auto-approved routes")
+
+			return change.Change{}, err
+		}
+
+		log.Info().Uint64("node.id", nv.ID().Uint64()).Str("node.name", nv.Hostname()).Strs("routes.approved", util.PrefixesToString(approved)).Msg("Routes approved")
+
+		return c, nil
+	}
+
+	return change.Change{}, nil
+}
+
+// GetPolicy retrieves the current policy from the database.
+func (s *State) GetPolicy() (*types.Policy, error) {
+	return s.db.GetPolicy()
+}
+
+// SetPolicyInDB stores policy data in the database.
+func (s *State) SetPolicyInDB(data string) (*types.Policy, error) {
+	return s.db.SetPolicy(data)
+}
+
+// SetNodeRoutes sets the primary routes for a node.
+func (s *State) SetNodeRoutes(nodeID types.NodeID, routes ...netip.Prefix) change.Change {
+	if s.primaryRoutes.SetRoutes(nodeID, routes...) {
+		// Route changes affect packet filters for all nodes, so trigger a policy change
+		// to ensure filters are regenerated across the entire network
+		return change.PolicyChange()
+	}
+
+	return change.Change{}
+}
+
+// GetNodePrimaryRoutes returns the primary routes for a node.
+func (s *State) GetNodePrimaryRoutes(nodeID types.NodeID) []netip.Prefix {
+	return s.primaryRoutes.PrimaryRoutes(nodeID)
+}
+
+// PrimaryRoutesString returns a string representation of all primary routes.
+func (s *State) PrimaryRoutesString() string { + return s.primaryRoutes.String() +} + +// ValidateAPIKey checks if an API key is valid and active. +func (s *State) ValidateAPIKey(keyStr string) (bool, error) { + return s.db.ValidateAPIKey(keyStr) +} + +// CreateAPIKey generates a new API key with optional expiration. +func (s *State) CreateAPIKey(expiration *time.Time) (string, *types.APIKey, error) { + return s.db.CreateAPIKey(expiration) +} + +// GetAPIKey retrieves an API key by its prefix. +// Accepts both display format (hskey-api-{12chars}-***) and database format ({12chars}). +func (s *State) GetAPIKey(displayPrefix string) (*types.APIKey, error) { + // Parse the display prefix to extract the database prefix + prefix, err := hsdb.ParseAPIKeyPrefix(displayPrefix) + if err != nil { + return nil, err + } + + return s.db.GetAPIKey(prefix) +} + +// GetAPIKeyByID retrieves an API key by its database ID. +func (s *State) GetAPIKeyByID(id uint64) (*types.APIKey, error) { + return s.db.GetAPIKeyByID(id) +} + +// ExpireAPIKey marks an API key as expired. +func (s *State) ExpireAPIKey(key *types.APIKey) error { + return s.db.ExpireAPIKey(key) +} + +// ListAPIKeys returns all API keys in the system. +func (s *State) ListAPIKeys() ([]types.APIKey, error) { + return s.db.ListAPIKeys() +} + +// DestroyAPIKey permanently removes an API key. +func (s *State) DestroyAPIKey(key types.APIKey) error { + return s.db.DestroyAPIKey(key) +} + +// CreatePreAuthKey generates a new pre-authentication key for a user. +// The userID parameter is now optional (can be nil) for system-created tagged keys. +func (s *State) CreatePreAuthKey(userID *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) { + return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags) +} + +// Test helpers for the state layer + +// CreateUserForTest creates a test user. This is a convenience wrapper around the database layer. +func (s *State) CreateUserForTest(name ...string) *types.User { + return s.db.CreateUserForTest(name...) +} + +// CreateNodeForTest creates a test node. This is a convenience wrapper around the database layer. +func (s *State) CreateNodeForTest(user *types.User, hostname ...string) *types.Node { + return s.db.CreateNodeForTest(user, hostname...) +} + +// CreateRegisteredNodeForTest creates a test node with allocated IPs. This is a convenience wrapper around the database layer. +func (s *State) CreateRegisteredNodeForTest(user *types.User, hostname ...string) *types.Node { + return s.db.CreateRegisteredNodeForTest(user, hostname...) +} + +// CreateNodesForTest creates multiple test nodes. This is a convenience wrapper around the database layer. +func (s *State) CreateNodesForTest(user *types.User, count int, namePrefix ...string) []*types.Node { + return s.db.CreateNodesForTest(user, count, namePrefix...) +} + +// CreateUsersForTest creates multiple test users. This is a convenience wrapper around the database layer. +func (s *State) CreateUsersForTest(count int, namePrefix ...string) []*types.User { + return s.db.CreateUsersForTest(count, namePrefix...) +} + +// DB returns the underlying database for testing purposes. +func (s *State) DB() *hsdb.HSDatabase { + return s.db +} + +// GetPreAuthKey retrieves a pre-authentication key by ID. +func (s *State) GetPreAuthKey(id string) (*types.PreAuthKey, error) { + return s.db.GetPreAuthKey(id) +} + +// ListPreAuthKeys returns all pre-authentication keys for a user. 
+func (s *State) ListPreAuthKeys() ([]types.PreAuthKey, error) { + return s.db.ListPreAuthKeys() +} + +// ExpirePreAuthKey marks a pre-authentication key as expired. +func (s *State) ExpirePreAuthKey(id uint64) error { + return s.db.ExpirePreAuthKey(id) +} + +// DeletePreAuthKey permanently deletes a pre-authentication key. +func (s *State) DeletePreAuthKey(id uint64) error { + return s.db.DeletePreAuthKey(id) +} + +// GetRegistrationCacheEntry retrieves a node registration from cache. +func (s *State) GetRegistrationCacheEntry(id types.RegistrationID) (*types.RegisterNode, bool) { + entry, found := s.registrationCache.Get(id) + if !found { + return nil, false + } + + return &entry, true +} + +// SetRegistrationCacheEntry stores a node registration in cache. +func (s *State) SetRegistrationCacheEntry(id types.RegistrationID, entry types.RegisterNode) { + s.registrationCache.Set(id, entry) +} + +// logHostinfoValidation logs warnings when hostinfo is nil or has empty hostname. +func logHostinfoValidation(machineKey, nodeKey, username, hostname string, hostinfo *tailcfg.Hostinfo) { + if hostinfo == nil { + log.Warn(). + Caller(). + Str("machine.key", machineKey). + Str("node.key", nodeKey). + Str("user.name", username). + Str("generated.hostname", hostname). + Msg("Registration had nil hostinfo, generated default hostname") + } else if hostinfo.Hostname == "" { + log.Warn(). + Caller(). + Str("machine.key", machineKey). + Str("node.key", nodeKey). + Str("user.name", username). + Str("generated.hostname", hostname). + Msg("Registration had empty hostname, generated default") + } +} + +// preserveNetInfo preserves NetInfo from an existing node for faster DERP connectivity. +// If no existing node is provided, it creates new netinfo from the provided hostinfo. +func preserveNetInfo(existingNode types.NodeView, nodeID types.NodeID, validHostinfo *tailcfg.Hostinfo) *tailcfg.NetInfo { + var existingHostinfo *tailcfg.Hostinfo + if existingNode.Valid() { + existingHostinfo = existingNode.Hostinfo().AsStruct() + } + return netInfoFromMapRequest(nodeID, existingHostinfo, validHostinfo) +} + +// newNodeParams contains parameters for creating a new node. +type newNodeParams struct { + User types.User + MachineKey key.MachinePublic + NodeKey key.NodePublic + DiscoKey key.DiscoPublic + Hostname string + Hostinfo *tailcfg.Hostinfo + Endpoints []netip.AddrPort + Expiry *time.Time + RegisterMethod string + + // Optional: Pre-auth key specific fields + PreAuthKey *types.PreAuthKey + + // Optional: Existing node for netinfo preservation + ExistingNodeForNetinfo types.NodeView +} + +// createAndSaveNewNode creates a new node, allocates IPs, saves to DB, and adds to NodeStore. +// It preserves netinfo from an existing node if one is provided (for faster DERP connectivity). 
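+//
+// Ownership follows the registration source: a tagged pre-auth key produces a
+// tagged node with key expiry disabled, while a user-owned pre-auth key or an
+// OIDC/CLI registration produces a user-owned node.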
+func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, error) {
+	// Preserve NetInfo from existing node if available
+	if params.Hostinfo != nil {
+		params.Hostinfo.NetInfo = preserveNetInfo(
+			params.ExistingNodeForNetinfo,
+			types.NodeID(0),
+			params.Hostinfo,
+		)
+	}
+
+	// Prepare the node for registration
+	nodeToRegister := types.Node{
+		Hostname:       params.Hostname,
+		MachineKey:     params.MachineKey,
+		NodeKey:        params.NodeKey,
+		DiscoKey:       params.DiscoKey,
+		Hostinfo:       params.Hostinfo,
+		Endpoints:      params.Endpoints,
+		LastSeen:       ptr.To(time.Now()),
+		RegisterMethod: params.RegisterMethod,
+		Expiry:         params.Expiry,
+	}
+
+	// Assign ownership based on PreAuthKey
+	if params.PreAuthKey != nil {
+		if params.PreAuthKey.IsTagged() {
+			// TAGGED NODE
+			// Tags from PreAuthKey are assigned ONLY during initial authentication
+			nodeToRegister.Tags = params.PreAuthKey.Proto().GetAclTags()
+
+			// Set UserID to track "created by" (who created the PreAuthKey)
+			if params.PreAuthKey.UserID != nil {
+				nodeToRegister.UserID = params.PreAuthKey.UserID
+				nodeToRegister.User = params.PreAuthKey.User
+			}
+			// If PreAuthKey.UserID is nil, the node is "orphaned" (system-created)
+
+			// Tagged nodes have key expiry disabled.
+			nodeToRegister.Expiry = nil
+		} else {
+			// USER-OWNED NODE
+			nodeToRegister.UserID = &params.PreAuthKey.User.ID
+			nodeToRegister.User = params.PreAuthKey.User
+			nodeToRegister.Tags = nil
+		}
+		nodeToRegister.AuthKey = params.PreAuthKey
+		nodeToRegister.AuthKeyID = &params.PreAuthKey.ID
+	} else {
+		// Non-PreAuthKey registration (OIDC, CLI) - always user-owned
+		nodeToRegister.UserID = &params.User.ID
+		nodeToRegister.User = &params.User
+		nodeToRegister.Tags = nil
+	}
+
+	// Reject advertise-tags for PreAuthKey registrations early, before any resource allocation.
+	// PreAuthKey nodes get their tags from the key itself, not from client requests.
+	if params.PreAuthKey != nil && params.Hostinfo != nil && len(params.Hostinfo.RequestTags) > 0 {
+		return types.NodeView{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, params.Hostinfo.RequestTags)
+	}
+
+	// Process RequestTags (from tailscale up --advertise-tags) ONLY for non-PreAuthKey registrations.
+	// Validate early before IP allocation to avoid resource leaks on failure.
+	if params.PreAuthKey == nil && params.Hostinfo != nil && len(params.Hostinfo.RequestTags) > 0 {
+		var approvedTags, rejectedTags []string
+
+		for _, tag := range params.Hostinfo.RequestTags {
+			if s.polMan.NodeCanHaveTag(nodeToRegister.View(), tag) {
+				approvedTags = append(approvedTags, tag)
+			} else {
+				rejectedTags = append(rejectedTags, tag)
+			}
+		}
+
+		// Reject registration if any requested tags are unauthorized
+		if len(rejectedTags) > 0 {
+			return types.NodeView{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, rejectedTags)
+		}
+
+		if len(approvedTags) > 0 {
+			nodeToRegister.Tags = approvedTags
+			slices.Sort(nodeToRegister.Tags)
+			nodeToRegister.Tags = slices.Compact(nodeToRegister.Tags)
+
+			// Tagged nodes have key expiry disabled.
+			nodeToRegister.Expiry = nil
+
+			log.Info().
+				Str("node.name", nodeToRegister.Hostname).
+				Strs("tags", nodeToRegister.Tags).
+ Msg("approved advertise-tags during registration") + } + } + + // Validate before saving + err := validateNodeOwnership(&nodeToRegister) + if err != nil { + return types.NodeView{}, err + } + + // Allocate new IPs + ipv4, ipv6, err := s.ipAlloc.Next() + if err != nil { + return types.NodeView{}, fmt.Errorf("allocating IPs: %w", err) + } + + nodeToRegister.IPv4 = ipv4 + nodeToRegister.IPv6 = ipv6 + + // Ensure unique given name if not set + if nodeToRegister.GivenName == "" { + givenName, err := hsdb.EnsureUniqueGivenName(s.db.DB, nodeToRegister.Hostname) + if err != nil { + return types.NodeView{}, fmt.Errorf("failed to ensure unique given name: %w", err) + } + nodeToRegister.GivenName = givenName + } + + // New node - database first to get ID, then NodeStore + savedNode, err := hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + if err := tx.Save(&nodeToRegister).Error; err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if params.PreAuthKey != nil && !params.PreAuthKey.Reusable { + err := hsdb.UsePreAuthKey(tx, params.PreAuthKey) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return &nodeToRegister, nil + }) + if err != nil { + return types.NodeView{}, err + } + + // Add to NodeStore after database creates the ID + return s.nodeStore.PutNode(*savedNode), nil +} + +// processReauthTags handles tag changes during node re-authentication. +// It processes RequestTags from the client and updates node tags accordingly. +// Returns rejected tags (if any) for post-validation error handling. +func (s *State) processReauthTags( + node *types.Node, + requestTags []string, + user *types.User, + oldTags []string, +) []string { + wasAuthKeyTagged := node.AuthKey != nil && node.AuthKey.IsTagged() + + logEvent := log.Debug(). + Uint64("node.id", uint64(node.ID)). + Str("node.name", node.Hostname). + Strs("request.tags", requestTags). + Strs("current.tags", node.Tags). + Bool("is.tagged", node.IsTagged()). + Bool("was.authkey.tagged", wasAuthKeyTagged) + logEvent.Msg("Processing RequestTags during reauth") + + // Empty RequestTags means untag node (transition to user-owned) + if len(requestTags) == 0 { + if node.IsTagged() { + log.Info(). + Uint64("node.id", uint64(node.ID)). + Str("node.name", node.Hostname). + Strs("removed.tags", node.Tags). + Str("user.name", user.Name). + Bool("was.authkey.tagged", wasAuthKeyTagged). + Msg("Reauth: removing all tags, returning node ownership to user") + + node.Tags = []string{} + node.UserID = &user.ID + } + + return nil + } + + // Non-empty RequestTags: validate and apply + var approvedTags, rejectedTags []string + + for _, tag := range requestTags { + if s.polMan.NodeCanHaveTag(node.View(), tag) { + approvedTags = append(approvedTags, tag) + } else { + rejectedTags = append(rejectedTags, tag) + } + } + + if len(rejectedTags) > 0 { + log.Warn(). + Uint64("node.id", uint64(node.ID)). + Str("node.name", node.Hostname). + Strs("rejected.tags", rejectedTags). + Msg("Reauth: requested tags are not permitted") + + return rejectedTags + } + + if len(approvedTags) > 0 { + slices.Sort(approvedTags) + approvedTags = slices.Compact(approvedTags) + + wasTagged := node.IsTagged() + node.Tags = approvedTags + + // Note: UserID is preserved as "created by" tracking, consistent with SetNodeTags + if !wasTagged { + log.Info(). + Uint64("node.id", uint64(node.ID)). + Str("node.name", node.Hostname). + Strs("new.tags", approvedTags). + Str("old.user", user.Name). 
+ Msg("Reauth: applying tags, transferring node to tagged-devices") + } else { + log.Info(). + Uint64("node.id", uint64(node.ID)). + Str("node.name", node.Hostname). + Strs("old.tags", oldTags). + Strs("new.tags", approvedTags). + Msg("Reauth: updating tags on already-tagged node") + } + } + + return nil +} + +// HandleNodeFromAuthPath handles node registration through authentication flow (like OIDC). +func (s *State) HandleNodeFromAuthPath( + registrationID types.RegistrationID, + userID types.UserID, + expiry *time.Time, + registrationMethod string, +) (types.NodeView, change.Change, error) { + // Get the registration entry from cache + regEntry, ok := s.GetRegistrationCacheEntry(registrationID) + if !ok { + return types.NodeView{}, change.Change{}, hsdb.ErrNodeNotFoundRegistrationCache + } + + // Get the user + user, err := s.db.GetUserByID(userID) + if err != nil { + return types.NodeView{}, change.Change{}, fmt.Errorf("failed to find user: %w", err) + } + + // Ensure we have a valid hostname from the registration cache entry + hostname := util.EnsureHostname( + regEntry.Node.Hostinfo, + regEntry.Node.MachineKey.String(), + regEntry.Node.NodeKey.String(), + ) + + // Ensure we have valid hostinfo + validHostinfo := cmp.Or(regEntry.Node.Hostinfo, &tailcfg.Hostinfo{}) + validHostinfo.Hostname = hostname + + logHostinfoValidation( + regEntry.Node.MachineKey.ShortString(), + regEntry.Node.NodeKey.String(), + user.Name, + hostname, + regEntry.Node.Hostinfo, + ) + + var finalNode types.NodeView + + // Check if node already exists with same machine key for this user + existingNodeSameUser, existsSameUser := s.nodeStore.GetNodeByMachineKey(regEntry.Node.MachineKey, types.UserID(user.ID)) + + // If this node exists for this user, update the node in place. + if existsSameUser && existingNodeSameUser.Valid() { + log.Info(). + Caller(). + Str("registration_id", registrationID.String()). + Str("user.name", user.Name). + Str("registrationMethod", registrationMethod). + Str("node.name", existingNodeSameUser.Hostname()). + Uint64("node.id", existingNodeSameUser.ID().Uint64()). + Interface("hostinfo", regEntry.Node.Hostinfo). + Msg("Updating existing node registration via reauth") + + // Process RequestTags during reauth (#2979) + // Due to json:",omitempty", we treat empty/nil as "clear tags" + var requestTags []string + if regEntry.Node.Hostinfo != nil { + requestTags = regEntry.Node.Hostinfo.RequestTags + } + + oldTags := existingNodeSameUser.Tags().AsSlice() + + var rejectedTags []string + + // Update existing node - NodeStore first, then database + updatedNodeView, ok := s.nodeStore.UpdateNode(existingNodeSameUser.ID(), func(node *types.Node) { + node.NodeKey = regEntry.Node.NodeKey + node.DiscoKey = regEntry.Node.DiscoKey + node.Hostname = hostname + + // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics + // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). + + // Preserve NetInfo from existing node when re-registering + node.Hostinfo = validHostinfo + node.Hostinfo.NetInfo = preserveNetInfo(existingNodeSameUser, existingNodeSameUser.ID(), validHostinfo) + + node.Endpoints = regEntry.Node.Endpoints + node.RegisterMethod = regEntry.Node.RegisterMethod + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + + // Tagged nodes keep their existing expiry (disabled). + // User-owned nodes update expiry from the provided value or registration entry. 
+ if !node.IsTagged() { + if expiry != nil { + node.Expiry = expiry + } else { + node.Expiry = regEntry.Node.Expiry + } + } + + rejectedTags = s.processReauthTags(node, requestTags, user, oldTags) + }) + + if !ok { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeSameUser.ID()) + } + + if len(rejectedTags) > 0 { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w %v are invalid or not permitted", ErrRequestedTagsInvalidOrNotPermitted, rejectedTags) + } + + _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + // Use Updates() to preserve fields not modified by UpdateNode. + err := tx.Updates(updatedNodeView.AsStruct()).Error + if err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + return nil, nil + }) + if err != nil { + return types.NodeView{}, change.Change{}, err + } + + log.Trace(). + Caller(). + Str("node.name", updatedNodeView.Hostname()). + Uint64("node.id", updatedNodeView.ID().Uint64()). + Str("machine.key", regEntry.Node.MachineKey.ShortString()). + Str("node.key", updatedNodeView.NodeKey().ShortString()). + Str("user.name", user.Name). + Msg("Node re-authorized") + + finalNode = updatedNodeView + } else { + // Node does not exist for this user with this machine key + // Check if node exists with this machine key for a different user (for netinfo preservation) + existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(regEntry.Node.MachineKey) + + if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != user.ID { + // Node exists but belongs to a different user + // Create a NEW node for the new user (do not transfer) + // This allows the same machine to have separate node identities per user + oldUser := existingNodeAnyUser.User() + log.Info(). + Caller(). + Str("existing.node.name", existingNodeAnyUser.Hostname()). + Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). + Str("machine.key", regEntry.Node.MachineKey.ShortString()). + Str("old.user", oldUser.Name()). + Str("new.user", user.Name). + Str("method", registrationMethod). + Msg("Creating new node for different user (same machine key exists for another user)") + } + + // Create a completely new node + log.Debug(). + Caller(). + Str("registration_id", registrationID.String()). + Str("user.name", user.Name). + Str("registrationMethod", registrationMethod). + Str("expiresAt", fmt.Sprintf("%v", expiry)). 
+ Msg("Registering new node from auth callback") + + // Create and save new node + var err error + finalNode, err = s.createAndSaveNewNode(newNodeParams{ + User: *user, + MachineKey: regEntry.Node.MachineKey, + NodeKey: regEntry.Node.NodeKey, + DiscoKey: regEntry.Node.DiscoKey, + Hostname: hostname, + Hostinfo: validHostinfo, + Endpoints: regEntry.Node.Endpoints, + Expiry: cmp.Or(expiry, regEntry.Node.Expiry), + RegisterMethod: registrationMethod, + ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}), + }) + if err != nil { + return types.NodeView{}, change.Change{}, err + } + } + + // Signal to waiting clients + regEntry.SendAndClose(finalNode.AsStruct()) + + // Delete from registration cache + s.registrationCache.Delete(registrationID) + + // Update policy managers + usersChange, err := s.updatePolicyManagerUsers() + if err != nil { + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager users: %w", err) + } + + nodesChange, err := s.updatePolicyManagerNodes() + if err != nil { + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err) + } + + var c change.Change + if !usersChange.IsEmpty() || !nodesChange.IsEmpty() { + c = change.PolicyChange() + } else { + c = change.NodeAdded(finalNode.ID()) + } + + return finalNode, c, nil +} + +// HandleNodeFromPreAuthKey handles node registration using a pre-authentication key. +func (s *State) HandleNodeFromPreAuthKey( + regReq tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (types.NodeView, change.Change, error) { + pak, err := s.GetPreAuthKey(regReq.Auth.AuthKey) + if err != nil { + return types.NodeView{}, change.Change{}, err + } + + // Helper to get username for logging (handles nil User for tags-only keys) + pakUsername := func() string { + if pak.User != nil { + return pak.User.Username() + } + + return types.TaggedDevices.Name + } + + // Check if node exists with same machine key before validating the key. + // For #2830: container restarts send the same pre-auth key which may be used/expired. + // Skip validation for existing nodes re-registering with the same NodeKey, as the + // key was only needed for initial authentication. NodeKey rotation requires validation. + // + // For tags-only keys (pak.User == nil), we skip the user-based lookup since there's + // no user to match against. These keys create tagged nodes without user ownership. + var existingNodeSameUser types.NodeView + + var existsSameUser bool + + if pak.User != nil { + existingNodeSameUser, existsSameUser = s.nodeStore.GetNodeByMachineKey(machineKey, types.UserID(pak.User.ID)) + } + + // For existing nodes, skip validation if: + // 1. MachineKey matches (cryptographic proof of machine identity) + // 2. User matches (from the PAK being used) + // 3. Not a NodeKey rotation (rotation requires fresh validation) + // + // Security: MachineKey is the cryptographic identity. If someone has the MachineKey, + // they control the machine. The PAK was only needed to authorize initial join. + // We don't check which specific PAK was used originally because: + // - Container restarts may use different PAKs (e.g., env var changed) + // - Original PAK may be deleted + // - MachineKey + User is sufficient to prove this is the same node + // + // Note: For tags-only keys, existsSameUser is always false, so we always validate. 
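+	// Illustrative summary of the checks below:
+	//   same MachineKey + same user + same NodeKey -> skip pak.Validate() (container restart / re-register)
+	//   same MachineKey + same user + new NodeKey  -> validate (NodeKey rotation)
+	//   unknown MachineKey, or tags-only key       -> validate (treated as a new node)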
+ isExistingNodeReregistering := existsSameUser && existingNodeSameUser.Valid() + + // Check if this is a NodeKey rotation (different NodeKey) + isNodeKeyRotation := existsSameUser && existingNodeSameUser.Valid() && + existingNodeSameUser.NodeKey() != regReq.NodeKey + + if isExistingNodeReregistering && !isNodeKeyRotation { + // Existing node re-registering with same NodeKey: skip validation. + // Pre-auth keys are only needed for initial authentication. Critical for + // containers that run "tailscale up --authkey=KEY" on every restart. + log.Debug(). + Caller(). + Uint64("node.id", existingNodeSameUser.ID().Uint64()). + Str("node.name", existingNodeSameUser.Hostname()). + Str("machine.key", machineKey.ShortString()). + Str("node.key.existing", existingNodeSameUser.NodeKey().ShortString()). + Str("node.key.request", regReq.NodeKey.ShortString()). + Uint64("authkey.id", pak.ID). + Bool("authkey.used", pak.Used). + Bool("authkey.expired", pak.Expiration != nil && pak.Expiration.Before(time.Now())). + Bool("authkey.reusable", pak.Reusable). + Bool("nodekey.rotation", isNodeKeyRotation). + Msg("Existing node re-registering with same NodeKey and auth key, skipping validation") + } else { + // New node or NodeKey rotation: require valid auth key. + err = pak.Validate() + if err != nil { + return types.NodeView{}, change.Change{}, err + } + } + + // Ensure we have a valid hostname - handle nil/empty cases + hostname := util.EnsureHostname( + regReq.Hostinfo, + machineKey.String(), + regReq.NodeKey.String(), + ) + + // Ensure we have valid hostinfo + validHostinfo := cmp.Or(regReq.Hostinfo, &tailcfg.Hostinfo{}) + validHostinfo.Hostname = hostname + + logHostinfoValidation( + machineKey.ShortString(), + regReq.NodeKey.ShortString(), + pakUsername(), + hostname, + regReq.Hostinfo, + ) + + log.Debug(). + Caller(). + Str("node.name", hostname). + Str("machine.key", machineKey.ShortString()). + Str("node.key", regReq.NodeKey.ShortString()). + Str("user.name", pakUsername()). + Msg("Registering node with pre-auth key") + + var finalNode types.NodeView + + // If this node exists for this user, update the node in place. + // Note: For tags-only keys (pak.User == nil), existsSameUser is always false. + if existsSameUser && existingNodeSameUser.Valid() { + log.Trace(). + Caller(). + Str("node.name", existingNodeSameUser.Hostname()). + Uint64("node.id", existingNodeSameUser.ID().Uint64()). + Str("machine.key", machineKey.ShortString()). + Str("node.key", existingNodeSameUser.NodeKey().ShortString()). + Str("user.name", pakUsername()). + Msg("Node re-registering with existing machine key and user, updating in place") + + // Update existing node - NodeStore first, then database + updatedNodeView, ok := s.nodeStore.UpdateNode(existingNodeSameUser.ID(), func(node *types.Node) { + node.NodeKey = regReq.NodeKey + node.Hostname = hostname + + // TODO(kradalby): We should ensure we use the same hostinfo and node merge semantics + // when a node re-registers as we do when it sends a map request (UpdateNodeFromMapRequest). 
+ + // Preserve NetInfo from existing node when re-registering + node.Hostinfo = validHostinfo + node.Hostinfo.NetInfo = preserveNetInfo(existingNodeSameUser, existingNodeSameUser.ID(), validHostinfo) + + node.RegisterMethod = util.RegisterMethodAuthKey + + // CRITICAL: Tags from PreAuthKey are ONLY applied during initial authentication + // On re-registration, we MUST NOT change tags or node ownership + // The node keeps whatever tags/user ownership it already has + // + // Only update AuthKey reference + node.AuthKey = pak + node.AuthKeyID = &pak.ID + node.IsOnline = ptr.To(false) + node.LastSeen = ptr.To(time.Now()) + + // Tagged nodes keep their existing expiry (disabled). + // User-owned nodes update expiry from the client request. + if !node.IsTagged() { + node.Expiry = ®Req.Expiry + } + }) + + if !ok { + return types.NodeView{}, change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, existingNodeSameUser.ID()) + } + + _, err = hsdb.Write(s.db.DB, func(tx *gorm.DB) (*types.Node, error) { + // Use Updates() to preserve fields not modified by UpdateNode. + err := tx.Updates(updatedNodeView.AsStruct()).Error + if err != nil { + return nil, fmt.Errorf("failed to save node: %w", err) + } + + if !pak.Reusable { + err = hsdb.UsePreAuthKey(tx, pak) + if err != nil { + return nil, fmt.Errorf("using pre auth key: %w", err) + } + } + + return nil, nil + }) + if err != nil { + return types.NodeView{}, change.Change{}, fmt.Errorf("writing node to database: %w", err) + } + + log.Trace(). + Caller(). + Str("node.name", updatedNodeView.Hostname()). + Uint64("node.id", updatedNodeView.ID().Uint64()). + Str("machine.key", machineKey.ShortString()). + Str("node.key", updatedNodeView.NodeKey().ShortString()). + Str("user.name", pakUsername()). + Msg("Node re-authorized") + + finalNode = updatedNodeView + } else { + // Node does not exist for this user with this machine key + // Check if node exists with this machine key for a different user + existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey) + + // For user-owned keys, check if node exists for a different user + // For tags-only keys (pak.User == nil), this check is skipped + if pak.User != nil && existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != pak.User.ID { + // Node exists but belongs to a different user + // Create a NEW node for the new user (do not transfer) + // This allows the same machine to have separate node identities per user + oldUser := existingNodeAnyUser.User() + log.Info(). + Caller(). + Str("existing.node.name", existingNodeAnyUser.Hostname()). + Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()). + Str("machine.key", machineKey.ShortString()). + Str("old.user", oldUser.Name()). + Str("new.user", pakUsername()). 
+ Msg("Creating new node for different user (same machine key exists for another user)") + } + + // This is a new node - create it + // For user-owned keys: create for the user + // For tags-only keys: create as tagged node (createAndSaveNewNode handles this via PreAuthKey) + + // Create and save new node + // Note: For tags-only keys, User is empty but createAndSaveNewNode uses PreAuthKey for ownership + var pakUser types.User + if pak.User != nil { + pakUser = *pak.User + } + + var err error + finalNode, err = s.createAndSaveNewNode(newNodeParams{ + User: pakUser, + MachineKey: machineKey, + NodeKey: regReq.NodeKey, + DiscoKey: key.DiscoPublic{}, // DiscoKey not available in RegisterRequest + Hostname: hostname, + Hostinfo: validHostinfo, + Endpoints: nil, // Endpoints not available in RegisterRequest + Expiry: ®Req.Expiry, + RegisterMethod: util.RegisterMethodAuthKey, + PreAuthKey: pak, + ExistingNodeForNetinfo: cmp.Or(existingNodeAnyUser, types.NodeView{}), + }) + if err != nil { + return types.NodeView{}, change.Change{}, fmt.Errorf("creating new node: %w", err) + } + } + + // Update policy managers + usersChange, err := s.updatePolicyManagerUsers() + if err != nil { + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager users: %w", err) + } + + nodesChange, err := s.updatePolicyManagerNodes() + if err != nil { + return finalNode, change.NodeAdded(finalNode.ID()), fmt.Errorf("failed to update policy manager nodes: %w", err) + } + + var c change.Change + if !usersChange.IsEmpty() || !nodesChange.IsEmpty() { + c = change.PolicyChange() + } else { + c = change.NodeAdded(finalNode.ID()) + } + + return finalNode, c, nil +} + +// updatePolicyManagerUsers updates the policy manager with current users. +// Returns true if the policy changed and notifications should be sent. +// TODO(kradalby): This is a temporary stepping stone, ultimately we should +// have the list already available so it could go much quicker. Alternatively +// the policy manager could have a remove or add list for users. +// updatePolicyManagerUsers refreshes the policy manager with current user data. +func (s *State) updatePolicyManagerUsers() (change.Change, error) { + users, err := s.ListAllUsers() + if err != nil { + return change.Change{}, fmt.Errorf("listing users for policy update: %w", err) + } + + log.Debug().Caller().Int("user.count", len(users)).Msg("Policy manager user update initiated because user list modification detected") + + changed, err := s.polMan.SetUsers(users) + if err != nil { + return change.Change{}, fmt.Errorf("updating policy manager users: %w", err) + } + + log.Debug().Caller().Bool("policy.changed", changed).Msg("Policy manager user update completed because SetUsers operation finished") + + if changed { + return change.PolicyChange(), nil + } + + return change.Change{}, nil +} + +// UpdatePolicyManagerUsersForTest updates the policy manager's user cache. +// This is exposed for testing purposes to sync the policy manager after +// creating test users via CreateUserForTest(). +func (s *State) UpdatePolicyManagerUsersForTest() error { + _, err := s.updatePolicyManagerUsers() + return err +} + +// updatePolicyManagerNodes updates the policy manager with current nodes. +// Returns true if the policy changed and notifications should be sent. +// TODO(kradalby): This is a temporary stepping stone, ultimately we should +// have the list already available so it could go much quicker. 
Alternatively +// the policy manager could have a remove or add list for nodes. +// updatePolicyManagerNodes refreshes the policy manager with current node data. +func (s *State) updatePolicyManagerNodes() (change.Change, error) { + nodes := s.ListNodes() + + changed, err := s.polMan.SetNodes(nodes) + if err != nil { + return change.Change{}, fmt.Errorf("updating policy manager nodes: %w", err) + } + + if changed { + // Rebuild peer maps because policy-affecting node changes (tags, user, IPs) + // affect ACL visibility. Without this, cached peer relationships use stale data. + s.nodeStore.RebuildPeerMaps() + return change.PolicyChange(), nil + } + + return change.Change{}, nil +} + +// PingDB checks if the database connection is healthy. +func (s *State) PingDB(ctx context.Context) error { + return s.db.PingDB(ctx) +} + +// autoApproveNodes mass approves routes on all nodes. It is _only_ intended for +// use when the policy is replaced. It is not sending or reporting any changes +// or updates as we send full updates after replacing the policy. +// TODO(kradalby): This is kind of messy, maybe this is another +1 +// for an event bus. See example comments here. +// autoApproveNodes automatically approves nodes based on policy rules. +func (s *State) autoApproveNodes() ([]change.Change, error) { + nodes := s.ListNodes() + + // Approve routes concurrently, this should make it likely + // that the writes end in the same batch in the nodestore write. + var ( + errg errgroup.Group + cs []change.Change + mu sync.Mutex + ) + for _, nv := range nodes.All() { + errg.Go(func() error { + approved, changed := policy.ApproveRoutesWithPolicy(s.polMan, nv, nv.ApprovedRoutes().AsSlice(), nv.AnnouncedRoutes()) + if changed { + log.Debug(). + Uint64("node.id", nv.ID().Uint64()). + Str("node.name", nv.Hostname()). + Strs("routes.approved.old", util.PrefixesToString(nv.ApprovedRoutes().AsSlice())). + Strs("routes.approved.new", util.PrefixesToString(approved)). + Msg("Routes auto-approved by policy") + + _, c, err := s.SetApprovedRoutes(nv.ID(), approved) + if err != nil { + return err + } + + mu.Lock() + cs = append(cs, c) + mu.Unlock() + } + + return nil + }) + } + + err := errg.Wait() + if err != nil { + return nil, err + } + + return cs, nil +} + +// UpdateNodeFromMapRequest processes a MapRequest and updates the node. +// TODO(kradalby): This is essentially a patch update that could be sent directly to nodes, +// which means we could shortcut the whole change thing if there are no other important updates. +// When a field is added to this function, remember to also add it to: +// - node.PeerChangeFromMapRequest +// - node.ApplyPeerChange +// - logTracePeerChange in poll.go. +func (s *State) UpdateNodeFromMapRequest(id types.NodeID, req tailcfg.MapRequest) (change.Change, error) { + log.Trace(). + Caller(). + Uint64("node.id", id.Uint64()). + Interface("request", req). + Msg("Processing MapRequest for node") + + var ( + routeChange bool + hostinfoChanged bool + needsRouteApproval bool + autoApprovedRoutes []netip.Prefix + endpointChanged bool + derpChanged bool + ) + // We need to ensure we update the node as it is in the NodeStore at + // the time of the request. 
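+	// Ordering note: the NodeStore copy is mutated first; the database write happens
+	// later via persistNodeToDB, and any auto-approved routes are persisted via
+	// SetApprovedRoutes once the NodeStore update has completed.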
+ updatedNode, ok := s.nodeStore.UpdateNode(id, func(currentNode *types.Node) { + peerChange := currentNode.PeerChangeFromMapRequest(req) + + // Track what specifically changed + endpointChanged = peerChange.Endpoints != nil + derpChanged = peerChange.DERPRegion != 0 + hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) + + // Get the correct NetInfo to use + netInfo := netInfoFromMapRequest(id, currentNode.Hostinfo, req.Hostinfo) + if req.Hostinfo != nil { + req.Hostinfo.NetInfo = netInfo + } else { + req.Hostinfo = &tailcfg.Hostinfo{NetInfo: netInfo} + } + + // Re-check hostinfoChanged after potential NetInfo preservation + hostinfoChanged = !hostinfoEqual(currentNode.View(), req.Hostinfo) + + // If there is no changes and nothing to save, + // return early. + if peerChangeEmpty(peerChange) && !hostinfoChanged { + return + } + + // Calculate route approval before NodeStore update to avoid calling View() inside callback + var hasNewRoutes bool + if hi := req.Hostinfo; hi != nil { + hasNewRoutes = len(hi.RoutableIPs) > 0 + } + needsRouteApproval = hostinfoChanged && (routesChanged(currentNode.View(), req.Hostinfo) || (hasNewRoutes && len(currentNode.ApprovedRoutes) == 0)) + if needsRouteApproval { + // Extract announced routes from request + var announcedRoutes []netip.Prefix + if req.Hostinfo != nil { + announcedRoutes = req.Hostinfo.RoutableIPs + } + + // Apply policy-based auto-approval if routes are announced + if len(announcedRoutes) > 0 { + autoApprovedRoutes, routeChange = policy.ApproveRoutesWithPolicy( + s.polMan, + currentNode.View(), + currentNode.ApprovedRoutes, + announcedRoutes, + ) + } + } + + // Log when routes change but approval doesn't + if hostinfoChanged && !routeChange { + if hi := req.Hostinfo; hi != nil { + if routesChanged(currentNode.View(), hi) { + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Strs("oldAnnouncedRoutes", util.PrefixesToString(currentNode.AnnouncedRoutes())). + Strs("newAnnouncedRoutes", util.PrefixesToString(hi.RoutableIPs)). + Strs("approvedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). + Bool("routeChange", routeChange). + Msg("announced routes changed but approved routes did not") + } + } + } + + currentNode.ApplyPeerChange(&peerChange) + + if hostinfoChanged { + // The node might not set NetInfo if it has not changed and if + // the full HostInfo object is overwritten, the information is lost. + // If there is no NetInfo, keep the previous one. + // From 1.66 the client only sends it if changed: + // https://github.com/tailscale/tailscale/commit/e1011f138737286ecf5123ff887a7a5800d129a2 + // TODO(kradalby): evaluate if we need better comparing of hostinfo + // before we take the changes. + // NetInfo preservation has already been handled above before early return check + currentNode.Hostinfo = req.Hostinfo + currentNode.ApplyHostnameFromHostInfo(req.Hostinfo) + + if routeChange { + // Apply pre-calculated route approval + // Always apply the route approval result to ensure consistency, + // regardless of whether the policy evaluation detected changes. + // This fixes the bug where routes weren't properly cleared when + // auto-approvers were removed from the policy. + log.Info(). + Uint64("node.id", id.Uint64()). + Strs("oldApprovedRoutes", util.PrefixesToString(currentNode.ApprovedRoutes)). + Strs("newApprovedRoutes", util.PrefixesToString(autoApprovedRoutes)). + Bool("routeChanged", routeChange). 
+ Msg("applying route approval results") + } + } + }) + + if !ok { + return change.Change{}, fmt.Errorf("%w: %d", ErrNodeNotInNodeStore, id) + } + + if routeChange { + log.Debug(). + Uint64("node.id", id.Uint64()). + Strs("autoApprovedRoutes", util.PrefixesToString(autoApprovedRoutes)). + Msg("Persisting auto-approved routes from MapRequest") + + // SetApprovedRoutes will update both database and PrimaryRoutes table + _, c, err := s.SetApprovedRoutes(id, autoApprovedRoutes) + if err != nil { + return change.Change{}, fmt.Errorf("persisting auto-approved routes: %w", err) + } + + // If SetApprovedRoutes resulted in a policy change, return it + if !c.IsEmpty() { + return c, nil + } + } // Continue with the rest of the processing using the updated node + + // Handle route changes after NodeStore update. + // Update routes if announced routes changed (even if approved routes stayed the same) + // because SubnetRoutes is the intersection of announced AND approved routes. + nodeRouteChange := s.maybeUpdateNodeRoutes(id, updatedNode, hostinfoChanged, needsRouteApproval, routeChange, req.Hostinfo) + + _, policyChange, err := s.persistNodeToDB(updatedNode) + if err != nil { + return change.Change{}, fmt.Errorf("saving to database: %w", err) + } + + if policyChange.IsFull() { + return policyChange, nil + } + + if !nodeRouteChange.IsEmpty() { + return nodeRouteChange, nil + } + + // Determine the most specific change type based on what actually changed. + // This allows us to send lightweight patch updates instead of full map responses. + return buildMapRequestChangeResponse(id, updatedNode, hostinfoChanged, endpointChanged, derpChanged) +} + +// buildMapRequestChangeResponse determines the appropriate response type for a MapRequest update. +// Hostinfo changes require a full update, while endpoint/DERP changes can use lightweight patches. +func buildMapRequestChangeResponse( + id types.NodeID, + node types.NodeView, + hostinfoChanged, endpointChanged, derpChanged bool, +) (change.Change, error) { + // Hostinfo changes require NodeAdded (full update) as they may affect many fields. + if hostinfoChanged { + return change.NodeAdded(id), nil + } + + // Return specific change types for endpoint and/or DERP updates. 
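+	// Only the fields that actually changed are populated on the tailcfg.PeerChange
+	// patch below; the patch is keyed by the node's tailcfg ID via id.NodeID().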
+ if endpointChanged || derpChanged { + patch := &tailcfg.PeerChange{NodeID: id.NodeID()} + + if endpointChanged { + patch.Endpoints = node.Endpoints().AsSlice() + } + + if derpChanged { + if hi := node.Hostinfo(); hi.Valid() { + if ni := hi.NetInfo(); ni.Valid() { + patch.DERPRegion = ni.PreferredDERP() + } + } + } + + return change.EndpointOrDERPUpdate(id, patch), nil + } + + return change.NodeAdded(id), nil +} + +func hostinfoEqual(oldNode types.NodeView, newHI *tailcfg.Hostinfo) bool { + if !oldNode.Valid() && newHI == nil { + return true + } + + if !oldNode.Valid() || newHI == nil { + return false + } + old := oldNode.AsStruct().Hostinfo + + return old.Equal(newHI) +} + +func routesChanged(oldNode types.NodeView, newHI *tailcfg.Hostinfo) bool { + var oldRoutes []netip.Prefix + if oldNode.Valid() && oldNode.AsStruct().Hostinfo != nil { + oldRoutes = oldNode.AsStruct().Hostinfo.RoutableIPs + } + + newRoutes := newHI.RoutableIPs + if newRoutes == nil { + newRoutes = []netip.Prefix{} + } + + tsaddr.SortPrefixes(oldRoutes) + tsaddr.SortPrefixes(newRoutes) + + return !slices.Equal(oldRoutes, newRoutes) +} + +func peerChangeEmpty(peerChange tailcfg.PeerChange) bool { + return peerChange.Key == nil && + peerChange.DiscoKey == nil && + peerChange.Online == nil && + peerChange.Endpoints == nil && + peerChange.DERPRegion == 0 && + peerChange.LastSeen == nil && + peerChange.KeyExpiry == nil +} + +// maybeUpdateNodeRoutes updates node routes if announced routes changed but approved routes didn't. +// This is needed because SubnetRoutes is the intersection of announced AND approved routes. +func (s *State) maybeUpdateNodeRoutes( + id types.NodeID, + node types.NodeView, + hostinfoChanged, needsRouteApproval, routeChange bool, + hostinfo *tailcfg.Hostinfo, +) change.Change { + // Only update if announced routes changed without approval change + if !hostinfoChanged || !needsRouteApproval || routeChange || hostinfo == nil { + return change.Change{} + } + + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Msg("updating routes because announced routes changed but approved routes did not") + + // SetNodeRoutes sets the active/distributed routes using AllApprovedRoutes() + // which returns only the intersection of announced AND approved routes. + log.Debug(). + Caller(). + Uint64("node.id", id.Uint64()). + Strs("announcedRoutes", util.PrefixesToString(node.AnnouncedRoutes())). + Strs("approvedRoutes", util.PrefixesToString(node.ApprovedRoutes().AsSlice())). + Strs("allApprovedRoutes", util.PrefixesToString(node.AllApprovedRoutes())). + Msg("updating node routes for distribution") + + return s.SetNodeRoutes(id, node.AllApprovedRoutes()...) +} diff --git a/hscontrol/state/tags.go b/hscontrol/state/tags.go new file mode 100644 index 00000000..ef745241 --- /dev/null +++ b/hscontrol/state/tags.go @@ -0,0 +1,68 @@ +package state + +import ( + "errors" + "fmt" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/rs/zerolog/log" +) + +var ( + // ErrNodeMarkedTaggedButHasNoTags is returned when a node is marked as tagged but has no tags. + ErrNodeMarkedTaggedButHasNoTags = errors.New("node marked as tagged but has no tags") + + // ErrNodeHasNeitherUserNorTags is returned when a node has neither a user nor tags. + ErrNodeHasNeitherUserNorTags = errors.New("node has neither user nor tags - must be owned by user or tagged") + + // ErrRequestedTagsInvalidOrNotPermitted is returned when requested tags are invalid or not permitted. 
+ // This message format matches Tailscale SaaS: "requested tags [tag:xxx] are invalid or not permitted". + ErrRequestedTagsInvalidOrNotPermitted = errors.New("requested tags") +) + +// validateNodeOwnership ensures proper node ownership model. +// A node must be EITHER user-owned OR tagged (mutually exclusive by behavior). +// Tagged nodes CAN have a UserID for "created by" tracking, but the tag is the owner. +func validateNodeOwnership(node *types.Node) error { + isTagged := node.IsTagged() + + // Tagged nodes: Must have tags, UserID is optional (just "created by") + if isTagged { + if len(node.Tags) == 0 { + return fmt.Errorf("%w: %q", ErrNodeMarkedTaggedButHasNoTags, node.Hostname) + } + // UserID can be set (created by) or nil (orphaned), both valid for tagged nodes + return nil + } + + // User-owned nodes: Must have UserID, must NOT have tags + if node.UserID == nil { + return fmt.Errorf("%w: %q", ErrNodeHasNeitherUserNorTags, node.Hostname) + } + + return nil +} + +// logTagOperation logs tag assignment operations for audit purposes. +func logTagOperation(existingNode types.NodeView, newTags []string) { + if existingNode.IsTagged() { + log.Info(). + Uint64("node.id", existingNode.ID().Uint64()). + Str("node.name", existingNode.Hostname()). + Strs("old.tags", existingNode.Tags().AsSlice()). + Strs("new.tags", newTags). + Msg("Updating tags on already-tagged node") + } else { + var userID uint + if existingNode.UserID().Valid() { + userID = existingNode.UserID().Get() + } + + log.Info(). + Uint64("node.id", existingNode.ID().Uint64()). + Str("node.name", existingNode.Hostname()). + Uint("created.by.user", userID). + Strs("new.tags", newTags). + Msg("Converting user-owned node to tagged node (irreversible)") + } +} diff --git a/hscontrol/state/test_helpers.go b/hscontrol/state/test_helpers.go new file mode 100644 index 00000000..95203106 --- /dev/null +++ b/hscontrol/state/test_helpers.go @@ -0,0 +1,12 @@ +package state + +import ( + "time" +) + +// Test configuration for NodeStore batching. +// These values are optimized for test speed rather than production use. 
+const ( + TestBatchSize = 5 + TestBatchTimeout = 5 * time.Millisecond +) diff --git a/hscontrol/suite_test.go b/hscontrol/suite_test.go deleted file mode 100644 index 82bdc797..00000000 --- a/hscontrol/suite_test.go +++ /dev/null @@ -1,58 +0,0 @@ -package hscontrol - -import ( - "net/netip" - "os" - "testing" - - "github.com/juanfont/headscale/hscontrol/types" - "gopkg.in/check.v1" -) - -func Test(t *testing.T) { - check.TestingT(t) -} - -var _ = check.Suite(&Suite{}) - -type Suite struct{} - -var ( - tmpDir string - app *Headscale -) - -func (s *Suite) SetUpTest(c *check.C) { - s.ResetDB(c) -} - -func (s *Suite) TearDownTest(c *check.C) { - os.RemoveAll(tmpDir) -} - -func (s *Suite) ResetDB(c *check.C) { - if len(tmpDir) != 0 { - os.RemoveAll(tmpDir) - } - var err error - tmpDir, err = os.MkdirTemp("", "autoygg-client-test2") - if err != nil { - c.Fatal(err) - } - cfg := types.Config{ - NoisePrivateKeyPath: tmpDir + "/noise_private.key", - DBtype: "sqlite3", - DBpath: tmpDir + "/headscale_test.db", - IPPrefixes: []netip.Prefix{ - netip.MustParsePrefix("10.27.0.0/23"), - }, - OIDC: types.OIDCConfig{ - StripEmaildomain: false, - }, - } - - app, err = NewHeadscale(&cfg) - if err != nil { - c.Fatal(err) - } -} diff --git a/hscontrol/tailsql.go b/hscontrol/tailsql.go new file mode 100644 index 00000000..1a949173 --- /dev/null +++ b/hscontrol/tailsql.go @@ -0,0 +1,101 @@ +package hscontrol + +import ( + "context" + "errors" + "fmt" + "net/http" + "os" + + "github.com/tailscale/tailsql/server/tailsql" + "tailscale.com/tsnet" + "tailscale.com/tsweb" + "tailscale.com/types/logger" +) + +func runTailSQLService(ctx context.Context, logf logger.Logf, stateDir, dbPath string) error { + opts := tailsql.Options{ + Hostname: "tailsql-headscale", + StateDir: stateDir, + Sources: []tailsql.DBSpec{ + { + Source: "headscale", + Label: "headscale - sqlite", + Driver: "sqlite", + URL: fmt.Sprintf("file:%s?mode=ro", dbPath), + Named: map[string]string{ + "schema": `select * from sqlite_schema`, + }, + }, + }, + } + + tsNode := &tsnet.Server{ + Dir: os.ExpandEnv(opts.StateDir), + Hostname: opts.Hostname, + Logf: logger.Discard, + } + // if *doDebugLog { + // tsNode.Logf = logf + // } + defer tsNode.Close() + + logf("Starting tailscale (hostname=%q)", opts.Hostname) + lc, err := tsNode.LocalClient() + if err != nil { + return fmt.Errorf("connect local client: %w", err) + } + opts.LocalClient = lc // for authentication + + // Make sure the Tailscale node starts up. It might not, if it is a new node + // and the user did not provide an auth key. + if st, err := tsNode.Up(ctx); err != nil { + return fmt.Errorf("starting tailscale: %w", err) + } else { + logf("tailscale started, node state %q", st.BackendState) + } + + // Reaching here, we have a running Tailscale node, now we can set up the + // HTTP and/or HTTPS plumbing for TailSQL itself. + tsql, err := tailsql.NewServer(opts) + if err != nil { + return fmt.Errorf("creating tailsql server: %w", err) + } + + lst, err := tsNode.Listen("tcp", ":80") + if err != nil { + return fmt.Errorf("listen port 80: %w", err) + } + + if opts.ServeHTTPS { + // When serving TLS, add a redirect from HTTP on port 80 to HTTPS on 443. 
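+		// The :80 listener stays alive only to serve this redirect goroutine; lst is
+		// then re-pointed at the TLS listener on :443, which serves the TailSQL mux below.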
+ certDomains := tsNode.CertDomains() + if len(certDomains) == 0 { + return errors.New("no cert domains available for HTTPS") + } + base := "https://" + certDomains[0] + go http.Serve(lst, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + target := base + r.RequestURI + http.Redirect(w, r, target, http.StatusPermanentRedirect) + })) + // log.Printf("Redirecting HTTP to HTTPS at %q", base) + + // For the real service, start a separate listener. + // Note: Replaces the port 80 listener. + var err error + lst, err = tsNode.ListenTLS("tcp", ":443") + if err != nil { + return fmt.Errorf("listen TLS: %w", err) + } + logf("enabled serving via HTTPS") + } + + mux := tsql.NewMux() + tsweb.Debugger(mux) + go http.Serve(lst, mux) + logf("TailSQL started") + <-ctx.Done() + logf("TailSQL shutting down...") + + return tsNode.Close() +} diff --git a/hscontrol/templates/apple.go b/hscontrol/templates/apple.go new file mode 100644 index 00000000..3b120069 --- /dev/null +++ b/hscontrol/templates/apple.go @@ -0,0 +1,191 @@ +package templates + +import ( + "fmt" + + "github.com/chasefleming/elem-go" + "github.com/chasefleming/elem-go/attrs" + "github.com/chasefleming/elem-go/styles" +) + +func Apple(url string) *elem.Element { + return HtmlStructure( + elem.Title(nil, + elem.Text("headscale - Apple")), + mdTypesetBody( + headscaleLogo(), + H1(elem.Text("iOS configuration")), + H2(elem.Text("GUI")), + Ol( + elem.Li( + nil, + elem.Text("Install the official Tailscale iOS client from the "), + externalLink("https://apps.apple.com/app/tailscale/id1470499037", "App Store"), + ), + elem.Li( + nil, + elem.Text("Open the "), + elem.Strong(nil, elem.Text("Tailscale")), + elem.Text(" app"), + ), + elem.Li( + nil, + elem.Text("Click the account icon in the top-right corner and select "), + elem.Strong(nil, elem.Text("Log in…")), + ), + elem.Li( + nil, + elem.Text("Tap the top-right options menu button and select "), + elem.Strong(nil, elem.Text("Use custom coordination server")), + ), + elem.Li( + nil, + elem.Text("Enter your instance URL: "), + Code(elem.Text(url)), + ), + elem.Li( + nil, + elem.Text( + "Enter your credentials and log in. 
Headscale should now be working on your iOS device", + ), + ), + ), + H1(elem.Text("macOS configuration")), + H2(elem.Text("Command line")), + P( + elem.Text("Use Tailscale's login command to add your profile:"), + ), + Pre(PreCode("tailscale login --login-server "+url)), + H2(elem.Text("GUI")), + Ol( + elem.Li( + nil, + elem.Text("Option + Click the "), + elem.Strong(nil, elem.Text("Tailscale")), + elem.Text(" icon in the menu and hover over the "), + elem.Strong(nil, elem.Text("Debug")), + elem.Text(" menu"), + ), + elem.Li(nil, + elem.Text("Under "), + elem.Strong(nil, elem.Text("Custom Login Server")), + elem.Text(", select "), + elem.Strong(nil, elem.Text("Add Account...")), + ), + elem.Li( + nil, + elem.Text("Enter "), + Code(elem.Text(url)), + elem.Text(" of the headscale instance and press "), + elem.Strong(nil, elem.Text("Add Account")), + ), + elem.Li(nil, + elem.Text("Follow the login procedure in the browser"), + ), + ), + H2(elem.Text("Profiles")), + P( + elem.Text( + "Headscale can be set to the default server by installing a Headscale configuration profile:", + ), + ), + elem.Div(attrs.Props{attrs.Style: styles.Props{styles.MarginTop: spaceL, styles.MarginBottom: spaceL}.ToInline()}, + downloadButton("/apple/macos-app-store", "macOS AppStore profile"), + downloadButton("/apple/macos-standalone", "macOS Standalone profile"), + ), + Ol( + elem.Li( + nil, + elem.Text( + "Download the profile, then open it. When it has been opened, there should be a notification that a profile can be installed", + ), + ), + elem.Li(nil, + elem.Text("Open "), + elem.Strong(nil, elem.Text("System Preferences")), + elem.Text(" and go to "), + elem.Strong(nil, elem.Text("Profiles")), + ), + elem.Li(nil, + elem.Text("Find and install the "), + elem.Strong(nil, elem.Text("Headscale")), + elem.Text(" profile"), + ), + elem.Li(nil, + elem.Text("Restart "), + elem.Strong(nil, elem.Text("Tailscale.app")), + elem.Text(" and log in"), + ), + ), + orDivider(), + P( + elem.Text( + "Use your terminal to configure the default setting for Tailscale by issuing one of the following commands:", + ), + ), + P(elem.Text("For app store client:")), + Pre(PreCode("defaults write io.tailscale.ipn.macos ControlURL "+url)), + P(elem.Text("For standalone client:")), + Pre(PreCode("defaults write io.tailscale.ipn.macsys ControlURL "+url)), + P( + elem.Text("Restart "), + elem.Strong(nil, elem.Text("Tailscale.app")), + elem.Text(" and log in."), + ), + warningBox("Caution", "You should always download and inspect the profile before installing it."), + P(elem.Text("For app store client:")), + Pre(PreCode(fmt.Sprintf(`curl %s/apple/macos-app-store`, url))), + P(elem.Text("For standalone client:")), + Pre(PreCode(fmt.Sprintf(`curl %s/apple/macos-standalone`, url))), + H1(elem.Text("tvOS configuration")), + H2(elem.Text("GUI")), + Ol( + elem.Li( + nil, + elem.Text("Install the official Tailscale tvOS client from the "), + externalLink("https://apps.apple.com/app/tailscale/id1470499037", "App Store"), + ), + elem.Li( + nil, + elem.Text("Open "), + elem.Strong(nil, elem.Text("Settings")), + elem.Text(" (the Apple tvOS settings) > "), + elem.Strong(nil, elem.Text("Apps")), + elem.Text(" > "), + elem.Strong(nil, elem.Text("Tailscale")), + ), + elem.Li( + nil, + elem.Text("Enter "), + Code(elem.Text(url)), + elem.Text(" under "), + elem.Strong(nil, elem.Text("ALTERNATE COORDINATION SERVER URL")), + ), + elem.Li(nil, + elem.Text("Return to the tvOS "), + elem.Strong(nil, elem.Text("Home")), + elem.Text(" screen"), + ), + elem.Li(nil, + 
elem.Text("Open "), + elem.Strong(nil, elem.Text("Tailscale")), + ), + elem.Li(nil, + elem.Text("Select "), + elem.Strong(nil, elem.Text("Install VPN configuration")), + ), + elem.Li(nil, + elem.Text("Select "), + elem.Strong(nil, elem.Text("Allow")), + ), + elem.Li(nil, + elem.Text("Scan the QR code and follow the login procedure"), + ), + elem.Li(nil, + elem.Text("Headscale should now be working on your tvOS device"), + ), + ), + pageFooter(), + ), + ) +} diff --git a/hscontrol/templates/apple.html b/hscontrol/templates/apple.html deleted file mode 100644 index 4064dced..00000000 --- a/hscontrol/templates/apple.html +++ /dev/null @@ -1,170 +0,0 @@ - - - - - - - headscale - Apple - - - - -

-      headscale: macOS configuration
-
-      Recent Tailscale versions (1.34.0 and higher)
-
-      Tailscale added Fast User Switching in version 1.34 and you can now use
-      the new login command to connect to one or more headscale (and Tailscale)
-      servers. The previously used profiles does not have an effect anymore.
-
-      Command line
-
-      Use Tailscale's login command to add your profile:
-
-      tailscale login --login-server {{.URL}}
-
-      GUI
-
-        1. ALT + Click the Tailscale icon in the menu and hover over the Debug menu
-        2. Under "Custom Login Server", select "Add Account..."
-        3. Enter "{{.URL}}" of the headscale instance and press "Add Account"
-        4. Follow the login procedure in the browser
-
-      Apple configuration profiles (1.32.0 and lower)
-
-      This page provides configuration profiles for the official Tailscale clients for
-
-      The profiles will configure Tailscale.app to use {{.URL}} as
-      its control server.
-
-      Caution
-
-      You should always download and inspect the profile before installing it:
-
-        • for app store client: curl {{.URL}}/apple/macos-app-store
-        • for standalone client: curl {{.URL}}/apple/macos-standalone
-
-      Profiles
-
-      macOS
-
-      Headscale can be set to the default server by installing a Headscale
-      configuration profile:
-
-      macOS AppStore profile
-      macOS Standalone profile
-
-        1. Download the profile, then open it. When it has been opened, there
-           should be a notification that a profile can be installed
-        2. Open System Preferences and go to "Profiles"
-        3. Find and install the Headscale profile
-        4. Restart Tailscale.app and log in
-
-      Or
-
-      Use your terminal to configure the default setting for Tailscale by
-      issuing:
-
-        • for app store client: defaults write io.tailscale.ipn.macos ControlURL {{.URL}}
-        • for standalone client: defaults write io.tailscale.ipn.macsys ControlURL {{.URL}}
-
-      Restart Tailscale.app and log in.
-
-      headscale: iOS configuration
-
-      Recent Tailscale versions (1.38.1 and higher)
-
-      Tailscale 1.38.1 on iOS added a configuration option to allow user to set
-      an "Alternate Coordination server". This can be used to connect to your
-      headscale server.
-
-      GUI
-
-        1. Install the official Tailscale iOS client from the App store
-        2. Open Tailscale and make sure you are not logged in to any account
-        3. Open Settings on the iOS device
-        4. Scroll down to the "third party apps" section, under "Game Center" or
-           "TV Provider"
-        5. Find Tailscale and select it
-           • If the iOS device was previously logged into Tailscale, switch the
-             "Reset Keychain" toggle to "on"
-        6. Enter "{{.URL}}" under "Alternate Coordination Server URL"
-        7. Restart the app by closing it from the iOS app switcher, open the app
-           and select the regular sign in option (non-SSO). It should open up to
-           the headscale authentication page.
-        8. Enter your credentials and log in. Headscale should now be working on
-           your iOS device
- - diff --git a/hscontrol/templates/design.go b/hscontrol/templates/design.go new file mode 100644 index 00000000..615c0e41 --- /dev/null +++ b/hscontrol/templates/design.go @@ -0,0 +1,482 @@ +package templates + +import ( + elem "github.com/chasefleming/elem-go" + "github.com/chasefleming/elem-go/attrs" + "github.com/chasefleming/elem-go/styles" +) + +// Design System Constants +// These constants define the visual language for all Headscale HTML templates. +// They ensure consistency across all pages and make it easy to maintain and update the design. + +// Color System +// EXTRACTED FROM: https://headscale.net/stable/assets/stylesheets/main.342714a4.min.css +// Material for MkDocs design system - exact values from official docs. +const ( + // Text colors - from --md-default-fg-color CSS variables. + colorTextPrimary = "#000000de" //nolint:unused // rgba(0,0,0,0.87) - Body text + colorTextSecondary = "#0000008a" //nolint:unused // rgba(0,0,0,0.54) - Headings (--md-default-fg-color--light) + colorTextTertiary = "#00000052" //nolint:unused // rgba(0,0,0,0.32) - Lighter text + colorTextLightest = "#00000012" //nolint:unused // rgba(0,0,0,0.07) - Lightest text + + // Code colors - from --md-code-* CSS variables. + colorCodeFg = "#36464e" //nolint:unused // Code text color (--md-code-fg-color) + colorCodeBg = "#f5f5f5" //nolint:unused // Code background (--md-code-bg-color) + + // Border colors. + colorBorderLight = "#e5e7eb" //nolint:unused // Light borders + colorBorderMedium = "#d1d5db" //nolint:unused // Medium borders + + // Background colors. + colorBackgroundPage = "#ffffff" //nolint:unused // Page background + colorBackgroundCard = "#ffffff" //nolint:unused // Card/content background + + // Accent colors - from --md-primary/accent-fg-color. + colorPrimaryAccent = "#4051b5" //nolint:unused // Primary accent (links) + colorAccent = "#526cfe" //nolint:unused // Secondary accent + + // Success colors. + colorSuccess = "#059669" //nolint:unused // Success states + colorSuccessLight = "#d1fae5" //nolint:unused // Success backgrounds +) + +// Spacing System +// Based on 4px/8px base unit for consistent rhythm. +// Uses rem units for scalability with user font size preferences. +const ( + spaceXS = "0.25rem" //nolint:unused // 4px - Tight spacing + spaceS = "0.5rem" //nolint:unused // 8px - Small spacing + spaceM = "1rem" //nolint:unused // 16px - Medium spacing (base) + spaceL = "1.5rem" //nolint:unused // 24px - Large spacing + spaceXL = "2rem" //nolint:unused // 32px - Extra large spacing + space2XL = "3rem" //nolint:unused // 48px - 2x extra large spacing + space3XL = "4rem" //nolint:unused // 64px - 3x extra large spacing +) + +// Typography System +// EXTRACTED FROM: https://headscale.net/stable/assets/stylesheets/main.342714a4.min.css +// Material for MkDocs typography - exact values from .md-typeset CSS. +const ( + // Font families - from CSS custom properties. + fontFamilySystem = `"Roboto", -apple-system, BlinkMacSystemFont, "Segoe UI", "Helvetica Neue", Arial, sans-serif` //nolint:unused + fontFamilyCode = `"Roboto Mono", "SF Mono", Monaco, "Cascadia Code", Consolas, "Courier New", monospace` //nolint:unused + + // Font sizes - from .md-typeset CSS rules. 
+ fontSizeBase = "0.8rem" //nolint:unused // 12.8px - Base text (.md-typeset) + fontSizeH1 = "2em" //nolint:unused // 2x base - Main headings + fontSizeH2 = "1.5625em" //nolint:unused // 1.5625x base - Section headings + fontSizeH3 = "1.25em" //nolint:unused // 1.25x base - Subsection headings + fontSizeSmall = "0.8em" //nolint:unused // 0.8x base - Small text + fontSizeCode = "0.85em" //nolint:unused // 0.85x base - Inline code + + // Line heights - from .md-typeset CSS rules. + lineHeightBase = "1.6" //nolint:unused // Body text (.md-typeset) + lineHeightH1 = "1.3" //nolint:unused // H1 headings + lineHeightH2 = "1.4" //nolint:unused // H2 headings + lineHeightH3 = "1.5" //nolint:unused // H3 headings + lineHeightCode = "1.4" //nolint:unused // Code blocks (pre) +) + +// Responsive Container Component +// Creates a centered container with responsive padding and max-width. +// Mobile-first approach: starts at 100% width with padding, constrains on larger screens. +// +//nolint:unused // Reserved for future use in Phase 4. +func responsiveContainer(children ...elem.Node) *elem.Element { + return elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Width: "100%", + styles.MaxWidth: "min(800px, 90vw)", // Responsive: 90% of viewport or 800px max + styles.Margin: "0 auto", // Center horizontally + styles.Padding: "clamp(1rem, 5vw, 2.5rem)", // Fluid padding: 16px to 40px + }.ToInline(), + }, children...) +} + +// Card Component +// Reusable card for grouping related content with visual separation. +// Parameters: +// - title: Optional title for the card (empty string for no title) +// - children: Content elements to display in the card +// +//nolint:unused // Reserved for future use in Phase 4. +func card(title string, children ...elem.Node) *elem.Element { + cardContent := children + if title != "" { + // Prepend title as H3 if provided + cardContent = append([]elem.Node{ + elem.H3(attrs.Props{ + attrs.Style: styles.Props{ + styles.MarginTop: "0", + styles.MarginBottom: spaceM, + styles.FontSize: fontSizeH3, + styles.LineHeight: lineHeightH3, // 1.5 - H3 line height + styles.Color: colorTextSecondary, + }.ToInline(), + }, elem.Text(title)), + }, children...) + } + + return elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Background: colorBackgroundCard, + styles.Border: "1px solid " + colorBorderLight, + styles.BorderRadius: "0.5rem", // 8px rounded corners + styles.Padding: "clamp(1rem, 3vw, 1.5rem)", // Responsive padding + styles.MarginBottom: spaceL, + styles.BoxShadow: "0 1px 3px rgba(0,0,0,0.1)", // Subtle shadow + }.ToInline(), + }, cardContent...) +} + +// Code Block Component +// EXTRACTED FROM: .md-typeset pre CSS rules +// Exact styling from Material for MkDocs documentation. +// +//nolint:unused // Used across apple.go, windows.go, register_web.go templates. 
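+// Example (illustrative): codeBlock("tailscale up --login-server https://headscale.example.com")
+// renders a pre/code block carrying the Material for MkDocs code styling above.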
+func codeBlock(code string) *elem.Element { + return elem.Pre(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "block", + styles.Padding: "0.77em 1.18em", // From .md-typeset pre + styles.Border: "none", // No border in original + styles.BorderRadius: "0.1rem", // From .md-typeset code + styles.BackgroundColor: colorCodeBg, // #f5f5f5 + styles.FontFamily: fontFamilyCode, // Roboto Mono + styles.FontSize: fontSizeCode, // 0.85em + styles.LineHeight: lineHeightCode, // 1.4 + styles.OverflowX: "auto", // Horizontal scroll + "overflow-wrap": "break-word", // Word wrapping + "word-wrap": "break-word", // Legacy support + styles.WhiteSpace: "pre-wrap", // Preserve whitespace + styles.MarginTop: spaceM, // 1em + styles.MarginBottom: spaceM, // 1em + styles.Color: colorCodeFg, // #36464e + styles.BoxShadow: "none", // No shadow in original + }.ToInline(), + }, + elem.Code(nil, elem.Text(code)), + ) +} + +// Base Typeset Styles +// Returns inline styles for the main content container that matches .md-typeset. +// EXTRACTED FROM: .md-typeset CSS rule from Material for MkDocs. +// +//nolint:unused // Used in general.go for mdTypesetBody. +func baseTypesetStyles() styles.Props { + return styles.Props{ + styles.FontSize: fontSizeBase, // 0.8rem + styles.LineHeight: lineHeightBase, // 1.6 + styles.Color: colorTextPrimary, + styles.FontFamily: fontFamilySystem, + "overflow-wrap": "break-word", + styles.TextAlign: "left", + } +} + +// H1 Styles +// Returns inline styles for H1 headings that match .md-typeset h1. +// EXTRACTED FROM: .md-typeset h1 CSS rule from Material for MkDocs. +// +//nolint:unused // Used across templates for main headings. +func h1Styles() styles.Props { + return styles.Props{ + styles.Color: colorTextSecondary, // rgba(0, 0, 0, 0.54) + styles.FontSize: fontSizeH1, // 2em + styles.LineHeight: lineHeightH1, // 1.3 + styles.Margin: "0 0 1.25em", + styles.FontWeight: "300", + "letter-spacing": "-0.01em", + styles.FontFamily: fontFamilySystem, // Roboto + "overflow-wrap": "break-word", + } +} + +// H2 Styles +// Returns inline styles for H2 headings that match .md-typeset h2. +// EXTRACTED FROM: .md-typeset h2 CSS rule from Material for MkDocs. +// +//nolint:unused // Used across templates for section headings. +func h2Styles() styles.Props { + return styles.Props{ + styles.FontSize: fontSizeH2, // 1.5625em + styles.LineHeight: lineHeightH2, // 1.4 + styles.Margin: "1.6em 0 0.64em", + styles.FontWeight: "300", + "letter-spacing": "-0.01em", + styles.Color: colorTextSecondary, // rgba(0, 0, 0, 0.54) + styles.FontFamily: fontFamilySystem, // Roboto + "overflow-wrap": "break-word", + } +} + +// H3 Styles +// Returns inline styles for H3 headings that match .md-typeset h3. +// EXTRACTED FROM: .md-typeset h3 CSS rule from Material for MkDocs. +// +//nolint:unused // Used across templates for subsection headings. +func h3Styles() styles.Props { + return styles.Props{ + styles.FontSize: fontSizeH3, // 1.25em + styles.LineHeight: lineHeightH3, // 1.5 + styles.Margin: "1.6em 0 0.8em", + styles.FontWeight: "400", + "letter-spacing": "-0.01em", + styles.Color: colorTextSecondary, // rgba(0, 0, 0, 0.54) + styles.FontFamily: fontFamilySystem, // Roboto + "overflow-wrap": "break-word", + } +} + +// Paragraph Styles +// Returns inline styles for paragraphs that match .md-typeset p. +// EXTRACTED FROM: .md-typeset p CSS rule from Material for MkDocs. +// +//nolint:unused // Used for consistent paragraph spacing. 
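+// Applied inline, for example: elem.P(attrs.Props{attrs.Style: paragraphStyles().ToInline()}, elem.Text("...")).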
+func paragraphStyles() styles.Props { + return styles.Props{ + styles.Margin: "1em 0", + styles.FontFamily: fontFamilySystem, // Roboto + styles.FontSize: fontSizeBase, // 0.8rem - inherited from .md-typeset + styles.LineHeight: lineHeightBase, // 1.6 - inherited from .md-typeset + styles.Color: colorTextPrimary, // rgba(0, 0, 0, 0.87) + "overflow-wrap": "break-word", + } +} + +// Ordered List Styles +// Returns inline styles for ordered lists that match .md-typeset ol. +// EXTRACTED FROM: .md-typeset ol CSS rule from Material for MkDocs. +// +//nolint:unused // Used for numbered instruction lists. +func orderedListStyles() styles.Props { + return styles.Props{ + styles.MarginBottom: "1em", + styles.MarginTop: "1em", + styles.PaddingLeft: "2em", + styles.FontFamily: fontFamilySystem, // Roboto - inherited from .md-typeset + styles.FontSize: fontSizeBase, // 0.8rem - inherited from .md-typeset + styles.LineHeight: lineHeightBase, // 1.6 - inherited from .md-typeset + styles.Color: colorTextPrimary, // rgba(0, 0, 0, 0.87) - inherited from .md-typeset + "overflow-wrap": "break-word", + } +} + +// Unordered List Styles +// Returns inline styles for unordered lists that match .md-typeset ul. +// EXTRACTED FROM: .md-typeset ul CSS rule from Material for MkDocs. +// +//nolint:unused // Used for bullet point lists. +func unorderedListStyles() styles.Props { + return styles.Props{ + styles.MarginBottom: "1em", + styles.MarginTop: "1em", + styles.PaddingLeft: "2em", + styles.FontFamily: fontFamilySystem, // Roboto - inherited from .md-typeset + styles.FontSize: fontSizeBase, // 0.8rem - inherited from .md-typeset + styles.LineHeight: lineHeightBase, // 1.6 - inherited from .md-typeset + styles.Color: colorTextPrimary, // rgba(0, 0, 0, 0.87) - inherited from .md-typeset + "overflow-wrap": "break-word", + } +} + +// Link Styles +// Returns inline styles for links that match .md-typeset a. +// EXTRACTED FROM: .md-typeset a CSS rule from Material for MkDocs. +// Note: Hover states cannot be implemented with inline styles. +// +//nolint:unused // Used for text links. +func linkStyles() styles.Props { + return styles.Props{ + styles.Color: colorPrimaryAccent, // #4051b5 - var(--md-primary-fg-color) + styles.TextDecoration: "none", + "word-break": "break-word", + styles.FontFamily: fontFamilySystem, // Roboto - inherited from .md-typeset + } +} + +// Inline Code Styles (updated) +// Returns inline styles for inline code that matches .md-typeset code. +// EXTRACTED FROM: .md-typeset code CSS rule from Material for MkDocs. +// +//nolint:unused // Used for inline code snippets. +func inlineCodeStyles() styles.Props { + return styles.Props{ + styles.BackgroundColor: colorCodeBg, // #f5f5f5 + styles.Color: colorCodeFg, // #36464e + styles.BorderRadius: "0.1rem", + styles.FontSize: fontSizeCode, // 0.85em + styles.FontFamily: fontFamilyCode, // Roboto Mono + styles.Padding: "0 0.2941176471em", + "word-break": "break-word", + } +} + +// Inline Code Component +// For inline code snippets within text. +// +//nolint:unused // Reserved for future inline code usage. +func inlineCode(code string) *elem.Element { + return elem.Code(attrs.Props{ + attrs.Style: inlineCodeStyles().ToInline(), + }, elem.Text(code)) +} + +// orDivider creates a visual "or" divider between sections. +// Styled with lines on either side for better visual separation. +// +//nolint:unused // Used in apple.go template. 
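+// In apple.go it separates the profile-installation steps from the
+// "defaults write ... ControlURL" terminal instructions.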
+func orDivider() *elem.Element { + return elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "flex", + styles.AlignItems: "center", + styles.Gap: spaceM, + styles.MarginTop: space2XL, + styles.MarginBottom: space2XL, + styles.Width: "100%", + }.ToInline(), + }, + elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Flex: "1", + styles.Height: "1px", + styles.BackgroundColor: colorBorderLight, + }.ToInline(), + }), + elem.Strong(attrs.Props{ + attrs.Style: styles.Props{ + styles.Color: colorTextSecondary, + styles.FontSize: fontSizeBase, + styles.FontWeight: "500", + "text-transform": "uppercase", + "letter-spacing": "0.05em", + }.ToInline(), + }, elem.Text("or")), + elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Flex: "1", + styles.Height: "1px", + styles.BackgroundColor: colorBorderLight, + }.ToInline(), + }), + ) +} + +// warningBox creates a warning message box with icon and content. +// +//nolint:unused // Used in apple.go template. +func warningBox(title, message string) *elem.Element { + return elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "flex", + styles.AlignItems: "flex-start", + styles.Gap: spaceM, + styles.Padding: spaceL, + styles.BackgroundColor: "#fef3c7", // yellow-100 + styles.Border: "1px solid #f59e0b", // yellow-500 + styles.BorderRadius: "0.5rem", + styles.MarginTop: spaceL, + styles.MarginBottom: spaceL, + }.ToInline(), + }, + elem.Raw(``), + elem.Div(nil, + elem.Strong(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "block", + styles.Color: "#92400e", // yellow-800 + styles.FontSize: fontSizeH3, + styles.MarginBottom: spaceXS, + }.ToInline(), + }, elem.Text(title)), + elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Color: colorTextPrimary, + styles.FontSize: fontSizeBase, + }.ToInline(), + }, elem.Text(message)), + ), + ) +} + +// downloadButton creates a nice button-style link for downloads. +// +//nolint:unused // Used in apple.go template. +func downloadButton(href, text string) *elem.Element { + return elem.A(attrs.Props{ + attrs.Href: href, + attrs.Download: "headscale_macos.mobileconfig", + attrs.Style: styles.Props{ + styles.Display: "inline-block", + styles.Padding: "0.75rem 1.5rem", + styles.BackgroundColor: "#3b82f6", // blue-500 + styles.Color: "#ffffff", + styles.TextDecoration: "none", + styles.BorderRadius: "0.5rem", + styles.FontWeight: "500", + styles.Transition: "background-color 0.2s", + styles.MarginRight: spaceM, + styles.MarginBottom: spaceM, + }.ToInline(), + }, elem.Text(text)) +} + +// External Link Component +// Creates a link with proper security attributes for external URLs. +// Automatically adds rel="noreferrer noopener" and target="_blank". +// +//nolint:unused // Used in apple.go, oidc_callback.go templates. +func externalLink(href, text string) *elem.Element { + return elem.A(attrs.Props{ + attrs.Href: href, + attrs.Rel: "noreferrer noopener", + attrs.Target: "_blank", + attrs.Style: styles.Props{ + styles.Color: colorPrimaryAccent, // #4051b5 - base link color + styles.TextDecoration: "none", + }.ToInline(), + }, elem.Text(text)) +} + +// Instruction Step Component +// For numbered instruction lists with consistent formatting. +// +//nolint:unused // Reserved for future use in Phase 4. 
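+// Example (illustrative): instructionStep(1, "Install the Tailscale client"). The numeric
+// argument is currently ignored; ordering comes from the enclosing ordered list.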
+func instructionStep(_ int, text string) *elem.Element { + return elem.Li(attrs.Props{ + attrs.Style: styles.Props{ + styles.MarginBottom: spaceS, + styles.LineHeight: lineHeightBase, + }.ToInline(), + }, elem.Text(text)) +} + +// Status Message Component +// For displaying success/error/info messages with appropriate styling. +// +//nolint:unused // Reserved for future use in Phase 4. +func statusMessage(message string, isSuccess bool) *elem.Element { + bgColor := colorSuccessLight + textColor := colorSuccess + + if !isSuccess { + bgColor = "#fee2e2" // red-100 + textColor = "#dc2626" // red-600 + } + + return elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Padding: spaceM, + styles.BackgroundColor: bgColor, + styles.Color: textColor, + styles.BorderRadius: "0.5rem", + styles.Border: "1px solid " + textColor, + styles.MarginBottom: spaceL, + styles.FontSize: fontSizeBase, + styles.LineHeight: lineHeightBase, + }.ToInline(), + }, elem.Text(message)) +} diff --git a/hscontrol/templates/general.go b/hscontrol/templates/general.go new file mode 100644 index 00000000..ccc5a360 --- /dev/null +++ b/hscontrol/templates/general.go @@ -0,0 +1,216 @@ +package templates + +import ( + "github.com/chasefleming/elem-go" + "github.com/chasefleming/elem-go/attrs" + "github.com/chasefleming/elem-go/styles" + "github.com/juanfont/headscale/hscontrol/assets" +) + +// mdTypesetBody creates a body element with md-typeset styling +// that matches the official Headscale documentation design. +// Uses CSS classes with styles defined in assets.CSS. +func mdTypesetBody(children ...elem.Node) *elem.Element { + return elem.Body(attrs.Props{ + attrs.Style: styles.Props{ + styles.MinHeight: "100vh", + styles.Display: "flex", + styles.FlexDirection: "column", + styles.AlignItems: "center", + styles.BackgroundColor: "#ffffff", + styles.Padding: "3rem 1.5rem", + }.ToInline(), + "translate": "no", + }, + elem.Div(attrs.Props{ + attrs.Class: "md-typeset", + attrs.Style: styles.Props{ + styles.MaxWidth: "min(800px, 90vw)", + styles.Width: "100%", + }.ToInline(), + }, children...), + ) +} + +// Styled Element Wrappers +// These functions wrap elem-go elements using CSS classes. +// Styling is handled by the CSS in assets.CSS. + +// H1 creates a H1 element styled by .md-typeset h1 +func H1(children ...elem.Node) *elem.Element { + return elem.H1(nil, children...) +} + +// H2 creates a H2 element styled by .md-typeset h2 +func H2(children ...elem.Node) *elem.Element { + return elem.H2(nil, children...) +} + +// H3 creates a H3 element styled by .md-typeset h3 +func H3(children ...elem.Node) *elem.Element { + return elem.H3(nil, children...) +} + +// P creates a paragraph element styled by .md-typeset p +func P(children ...elem.Node) *elem.Element { + return elem.P(nil, children...) +} + +// Ol creates an ordered list element styled by .md-typeset ol +func Ol(children ...elem.Node) *elem.Element { + return elem.Ol(nil, children...) +} + +// Ul creates an unordered list element styled by .md-typeset ul +func Ul(children ...elem.Node) *elem.Element { + return elem.Ul(nil, children...) +} + +// A creates a link element styled by .md-typeset a +func A(href string, children ...elem.Node) *elem.Element { + return elem.A(attrs.Props{attrs.Href: href}, children...) +} + +// Code creates an inline code element styled by .md-typeset code +func Code(children ...elem.Node) *elem.Element { + return elem.Code(nil, children...) 
+} + +// Pre creates a preformatted text block styled by .md-typeset pre +func Pre(children ...elem.Node) *elem.Element { + return elem.Pre(nil, children...) +} + +// PreCode creates a code block inside Pre styled by .md-typeset pre > code +func PreCode(code string) *elem.Element { + return elem.Code(nil, elem.Text(code)) +} + +// Deprecated: use H1, H2, H3 instead +func headerOne(text string) *elem.Element { + return H1(elem.Text(text)) +} + +// Deprecated: use H1, H2, H3 instead +func headerTwo(text string) *elem.Element { + return H2(elem.Text(text)) +} + +// Deprecated: use H1, H2, H3 instead +func headerThree(text string) *elem.Element { + return H3(elem.Text(text)) +} + +// contentContainer wraps page content with proper width. +// Content inside is left-aligned by default. +func contentContainer(children ...elem.Node) *elem.Element { + containerStyle := styles.Props{ + styles.MaxWidth: "720px", + styles.Width: "100%", + styles.Display: "flex", + styles.FlexDirection: "column", + styles.AlignItems: "flex-start", // Left-align all children + } + + return elem.Div(attrs.Props{attrs.Style: containerStyle.ToInline()}, children...) +} + +// headscaleLogo returns the Headscale SVG logo for consistent branding across all pages. +// The logo is styled by the .headscale-logo CSS class. +func headscaleLogo() elem.Node { + // Return the embedded SVG as-is + return elem.Raw(assets.SVG) +} + +// pageFooter creates a consistent footer for all pages. +func pageFooter() *elem.Element { + footerStyle := styles.Props{ + styles.MarginTop: space3XL, + styles.TextAlign: "center", + styles.FontSize: fontSizeSmall, + styles.Color: colorTextSecondary, + styles.LineHeight: lineHeightBase, + } + + linkStyle := styles.Props{ + styles.Color: colorTextSecondary, + styles.TextDecoration: "underline", + } + + return elem.Div(attrs.Props{attrs.Style: footerStyle.ToInline()}, + elem.Text("Powered by "), + elem.A(attrs.Props{ + attrs.Href: "https://github.com/juanfont/headscale", + attrs.Rel: "noreferrer noopener", + attrs.Target: "_blank", + attrs.Style: linkStyle.ToInline(), + }, elem.Text("Headscale")), + ) +} + +// listStyle provides consistent styling for ordered and unordered lists +// EXTRACTED FROM: .md-typeset ol, .md-typeset ul CSS rules +var listStyle = styles.Props{ + styles.LineHeight: lineHeightBase, // 1.6 - From .md-typeset + styles.MarginTop: "1em", // From CSS: margin-top: 1em + styles.MarginBottom: "1em", // From CSS: margin-bottom: 1em + styles.PaddingLeft: "clamp(1.5rem, 5vw, 2.5rem)", // Responsive indentation +} + +// HtmlStructure creates a complete HTML document structure with proper meta tags +// and semantic HTML5 structure. The head and body elements are passed as parameters +// to allow for customization of each page. +// Styling is provided via a CSS stylesheet (Material for MkDocs design system) with +// minimal inline styles for layout and positioning. 
+func HtmlStructure(head, body *elem.Element) *elem.Element { + return elem.Html(attrs.Props{attrs.Lang: "en"}, + elem.Head(nil, + elem.Meta(attrs.Props{ + attrs.Charset: "UTF-8", + }), + elem.Meta(attrs.Props{ + attrs.HTTPequiv: "X-UA-Compatible", + attrs.Content: "IE=edge", + }), + elem.Meta(attrs.Props{ + attrs.Name: "viewport", + attrs.Content: "width=device-width, initial-scale=1.0", + }), + elem.Link(attrs.Props{ + attrs.Rel: "icon", + attrs.Href: "/favicon.ico", + }), + // Google Fonts for Roboto and Roboto Mono + elem.Link(attrs.Props{ + attrs.Rel: "preconnect", + attrs.Href: "https://fonts.gstatic.com", + "crossorigin": "", + }), + elem.Link(attrs.Props{ + attrs.Rel: "stylesheet", + attrs.Href: "https://fonts.googleapis.com/css2?family=Roboto:wght@300;400;500;700&family=Roboto+Mono:wght@400;700&display=swap", + }), + // Material for MkDocs CSS styles + elem.Style(attrs.Props{attrs.Type: "text/css"}, elem.Raw(assets.CSS)), + head, + ), + body, + ) +} + +// BlankPage creates a minimal blank HTML page with favicon. +// Used for endpoints that need to return a valid HTML page with no content. +func BlankPage() *elem.Element { + return elem.Html(attrs.Props{attrs.Lang: "en"}, + elem.Head(nil, + elem.Meta(attrs.Props{ + attrs.Charset: "UTF-8", + }), + elem.Link(attrs.Props{ + attrs.Rel: "icon", + attrs.Href: "/favicon.ico", + }), + ), + elem.Body(nil), + ) +} diff --git a/hscontrol/templates/oidc_callback.go b/hscontrol/templates/oidc_callback.go new file mode 100644 index 00000000..16c08fde --- /dev/null +++ b/hscontrol/templates/oidc_callback.go @@ -0,0 +1,69 @@ +package templates + +import ( + "github.com/chasefleming/elem-go" + "github.com/chasefleming/elem-go/attrs" + "github.com/chasefleming/elem-go/styles" +) + +// checkboxIcon returns the success checkbox SVG icon as raw HTML. +func checkboxIcon() elem.Node { + return elem.Raw(``) +} + +// OIDCCallback renders the OIDC authentication success callback page. +func OIDCCallback(user, verb string) *elem.Element { + // Success message box + successBox := elem.Div(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "flex", + styles.AlignItems: "center", + styles.Gap: spaceM, + styles.Padding: spaceL, + styles.BackgroundColor: colorSuccessLight, + styles.Border: "1px solid " + colorSuccess, + styles.BorderRadius: "0.5rem", + styles.MarginBottom: spaceXL, + }.ToInline(), + }, + checkboxIcon(), + elem.Div(nil, + elem.Strong(attrs.Props{ + attrs.Style: styles.Props{ + styles.Display: "block", + styles.Color: colorSuccess, + styles.FontSize: fontSizeH3, + styles.MarginBottom: spaceXS, + }.ToInline(), + }, elem.Text("Signed in successfully")), + elem.P(attrs.Props{ + attrs.Style: styles.Props{ + styles.Margin: "0", + styles.Color: colorTextPrimary, + styles.FontSize: fontSizeBase, + }.ToInline(), + }, elem.Text(verb), elem.Text(" as "), elem.Strong(nil, elem.Text(user)), elem.Text(". 
You can now close this window.")), + ), + ) + + return HtmlStructure( + elem.Title(nil, elem.Text("Headscale Authentication Succeeded")), + mdTypesetBody( + headscaleLogo(), + successBox, + H2(elem.Text("Getting started")), + P(elem.Text("Check out the documentation to learn more about headscale and Tailscale:")), + Ul( + elem.Li(nil, + externalLink("https://headscale.net/stable/", "Headscale documentation"), + ), + elem.Li(nil, + externalLink("https://tailscale.com/kb/", "Tailscale knowledge base"), + ), + ), + pageFooter(), + ), + ) +} diff --git a/hscontrol/templates/register_web.go b/hscontrol/templates/register_web.go new file mode 100644 index 00000000..829af7fb --- /dev/null +++ b/hscontrol/templates/register_web.go @@ -0,0 +1,21 @@ +package templates + +import ( + "fmt" + + "github.com/chasefleming/elem-go" + "github.com/juanfont/headscale/hscontrol/types" +) + +func RegisterWeb(registrationID types.RegistrationID) *elem.Element { + return HtmlStructure( + elem.Title(nil, elem.Text("Registration - Headscale")), + mdTypesetBody( + headscaleLogo(), + H1(elem.Text("Machine registration")), + P(elem.Text("Run the command below in the headscale server to add this machine to your network:")), + Pre(PreCode(fmt.Sprintf("headscale nodes register --key %s --user USERNAME", registrationID.String()))), + pageFooter(), + ), + ) +} diff --git a/hscontrol/templates/windows.go b/hscontrol/templates/windows.go new file mode 100644 index 00000000..f649509a --- /dev/null +++ b/hscontrol/templates/windows.go @@ -0,0 +1,27 @@ +package templates + +import ( + "github.com/chasefleming/elem-go" +) + +func Windows(url string) *elem.Element { + return HtmlStructure( + elem.Title(nil, + elem.Text("headscale - Windows"), + ), + mdTypesetBody( + headscaleLogo(), + H1(elem.Text("Windows configuration")), + P( + elem.Text("Download "), + externalLink("https://tailscale.com/download/windows", "Tailscale for Windows"), + elem.Text(" and install it."), + ), + P( + elem.Text("Open a Command Prompt or PowerShell and use Tailscale's login command to connect with headscale:"), + ), + Pre(PreCode("tailscale login --login-server "+url)), + pageFooter(), + ), + ) +} diff --git a/hscontrol/templates/windows.html b/hscontrol/templates/windows.html deleted file mode 100644 index c590494f..00000000 --- a/hscontrol/templates/windows.html +++ /dev/null @@ -1,99 +0,0 @@ - - - - - - - headscale - Windows - - - - -

-      headscale: Windows configuration
-
-      Recent Tailscale versions (1.34.0 and higher)
-
-      Tailscale added Fast User Switching in version 1.34 and you can now use
-      the new login command to connect to one or more headscale (and Tailscale)
-      servers. The previously used profiles does not have an effect anymore.
-
-      Use Tailscale's login command to add your profile:
-
-      tailscale login --login-server {{.URL}}
-
-      Windows registry configuration (1.32.0 and lower)
-
-      This page provides Windows registry information for the official Windows
-      Tailscale client.
-
-      The registry file will configure Tailscale to use {{.URL}} as
-      its control server.
-
-      Caution
-
-      You should always download and inspect the registry file before installing
-      it:
-
-      curl {{.URL}}/windows/tailscale.reg
-
-      Installation
-
-      Headscale can be set to the default server by running the registry file:
-
-      Windows registry file
-
-      1. Download the registry file, then run it
-      2. Follow the prompts
-      3. Install and run the official windows Tailscale client
-      4. When the installation has finished, start Tailscale, and log in by
-         clicking the icon in the system tray
-
-      Or using REG:
-
-      Open command prompt with Administrator rights. Issue the following
-      commands to add the required registry entries:
-
-      REG ADD "HKLM\Software\Tailscale IPN" /v UnattendedMode /t REG_SZ /d always
-      REG ADD "HKLM\Software\Tailscale IPN" /v LoginURL /t REG_SZ /d "{{.URL}}"
-
-      Or using Powershell
-
-      Open Powershell with Administrator rights. Issue the following commands to
-      add the required registry entries:
-
-      New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name UnattendedMode -PropertyType String -Value always
-      New-ItemProperty -Path 'HKLM:\Software\Tailscale IPN' -Name LoginURL -PropertyType String -Value "{{.URL}}"
-
-      Finally, restart Tailscale and log in.
-
- -

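The static `windows.html` template above is superseded by the Go-built `templates.Windows` element added earlier in this patch. As a rough sketch of how such a page could be served, assuming a plain `net/http` handler and a `serverURL` string (the actual wiring inside headscale is not shown here); `templates.Windows(url)` and `Render()` come from the new template code and its consistency tests:

```go
package main

import (
	"net/http"

	"github.com/juanfont/headscale/hscontrol/templates"
)

// windowsPageHandler is a hypothetical handler: it renders the elem-go tree
// returned by templates.Windows into an HTML string and writes it out.
func windowsPageHandler(serverURL string) http.HandlerFunc {
	return func(w http.ResponseWriter, _ *http.Request) {
		w.Header().Set("Content-Type", "text/html; charset=utf-8")
		_, _ = w.Write([]byte(templates.Windows(serverURL).Render()))
	}
}
```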
- - diff --git a/hscontrol/templates_consistency_test.go b/hscontrol/templates_consistency_test.go new file mode 100644 index 00000000..369639cc --- /dev/null +++ b/hscontrol/templates_consistency_test.go @@ -0,0 +1,213 @@ +package hscontrol + +import ( + "strings" + "testing" + + "github.com/juanfont/headscale/hscontrol/templates" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" +) + +func TestTemplateHTMLConsistency(t *testing.T) { + // Test all templates produce consistent modern HTML + testCases := []struct { + name string + html string + }{ + { + name: "OIDC Callback", + html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + }, + { + name: "Register Web", + html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + }, + { + name: "Windows Config", + html: templates.Windows("https://example.com").Render(), + }, + { + name: "Apple Config", + html: templates.Apple("https://example.com").Render(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Check DOCTYPE + assert.True(t, strings.HasPrefix(tc.html, ""), + "%s should start with ", tc.name) + + // Check HTML5 lang attribute + assert.Contains(t, tc.html, ``, + "%s should have html lang=\"en\"", tc.name) + + // Check UTF-8 charset + assert.Contains(t, tc.html, `charset="UTF-8"`, + "%s should have UTF-8 charset", tc.name) + + // Check viewport meta tag + assert.Contains(t, tc.html, `name="viewport"`, + "%s should have viewport meta tag", tc.name) + + // Check IE compatibility meta tag + assert.Contains(t, tc.html, `X-UA-Compatible`, + "%s should have X-UA-Compatible meta tag", tc.name) + + // Check closing tags + assert.Contains(t, tc.html, "", + "%s should have closing html tag", tc.name) + assert.Contains(t, tc.html, "", + "%s should have closing head tag", tc.name) + assert.Contains(t, tc.html, "", + "%s should have closing body tag", tc.name) + }) + } +} + +func TestTemplateModernHTMLFeatures(t *testing.T) { + testCases := []struct { + name string + html string + }{ + { + name: "OIDC Callback", + html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + }, + { + name: "Register Web", + html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + }, + { + name: "Windows Config", + html: templates.Windows("https://example.com").Render(), + }, + { + name: "Apple Config", + html: templates.Apple("https://example.com").Render(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Check no deprecated tags + assert.NotContains(t, tc.html, " tag", tc.name) + assert.NotContains(t, tc.html, " tag", tc.name) + + // Check modern structure + assert.Contains(t, tc.html, "", + "%s should have section", tc.name) + assert.Contains(t, tc.html, " section", tc.name) + assert.Contains(t, tc.html, "", + "%s should have <title> tag", tc.name) + }) + } +} + +func TestTemplateExternalLinkSecurity(t *testing.T) { + // Test that all external links (http/https) have proper security attributes + testCases := []struct { + name string + html string + externalURLs []string // URLs that should have security attributes + }{ + { + name: "OIDC Callback", + html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + externalURLs: []string{ + "https://headscale.net/stable/", + "https://tailscale.com/kb/", + }, + }, + { + name: "Register Web", + html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + externalURLs: []string{}, // No external links + }, + { + 
name: "Windows Config", + html: templates.Windows("https://example.com").Render(), + externalURLs: []string{ + "https://tailscale.com/download/windows", + }, + }, + { + name: "Apple Config", + html: templates.Apple("https://example.com").Render(), + externalURLs: []string{ + "https://apps.apple.com/app/tailscale/id1470499037", + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + for _, url := range tc.externalURLs { + // Find the link tag containing this URL + if !strings.Contains(tc.html, url) { + t.Errorf("%s should contain external link %s", tc.name, url) + continue + } + + // Check for rel="noreferrer noopener" + // We look for the pattern: href="URL"...rel="noreferrer noopener" + // The attributes might be in any order, so we check within a reasonable window + idx := strings.Index(tc.html, url) + if idx == -1 { + continue + } + + // Look for the closing > of the <a> tag (within 200 chars should be safe) + endIdx := strings.Index(tc.html[idx:idx+200], ">") + if endIdx == -1 { + endIdx = 200 + } + + linkTag := tc.html[idx : idx+endIdx] + + assert.Contains(t, linkTag, `rel="noreferrer noopener"`, + "%s external link %s should have rel=\"noreferrer noopener\"", tc.name, url) + assert.Contains(t, linkTag, `target="_blank"`, + "%s external link %s should have target=\"_blank\"", tc.name, url) + } + }) + } +} + +func TestTemplateAccessibilityAttributes(t *testing.T) { + // Test that all templates have proper accessibility attributes + testCases := []struct { + name string + html string + }{ + { + name: "OIDC Callback", + html: templates.OIDCCallback("test@example.com", "Logged in").Render(), + }, + { + name: "Register Web", + html: templates.RegisterWeb(types.RegistrationID("test-key-123")).Render(), + }, + { + name: "Windows Config", + html: templates.Windows("https://example.com").Render(), + }, + { + name: "Apple Config", + html: templates.Apple("https://example.com").Render(), + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + // Check for translate="no" on body tag to prevent browser translation + // This is important for technical documentation with commands + assert.Contains(t, tc.html, `translate="no"`, + "%s should have translate=\"no\" attribute on body tag", tc.name) + }) + } +} diff --git a/hscontrol/types/api_key.go b/hscontrol/types/api_key.go index 8ca00044..b6a12b65 100644 --- a/hscontrol/types/api_key.go +++ b/hscontrol/types/api_key.go @@ -7,6 +7,13 @@ import ( "google.golang.org/protobuf/types/known/timestamppb" ) +const ( + // NewAPIKeyPrefixLength is the length of the prefix for new API keys. + NewAPIKeyPrefixLength = 12 + // LegacyAPIKeyPrefixLength is the length of the prefix for legacy API keys. + LegacyAPIKeyPrefixLength = 7 +) + // APIKey describes the datamodel for API keys used to remotely authenticate with // headscale. 
type APIKey struct { @@ -21,8 +28,16 @@ type APIKey struct { func (key *APIKey) Proto() *v1.ApiKey { protoKey := v1.ApiKey{ - Id: key.ID, - Prefix: key.Prefix, + Id: key.ID, + } + + // Show prefix format: distinguish between new (12-char) and legacy (7-char) keys + if len(key.Prefix) == NewAPIKeyPrefixLength { + // New format key (12-char prefix) + protoKey.Prefix = "hskey-api-" + key.Prefix + "-***" + } else { + // Legacy format key (7-char prefix) or fallback + protoKey.Prefix = key.Prefix + "***" } if key.Expiration != nil { diff --git a/hscontrol/types/change/change.go b/hscontrol/types/change/change.go new file mode 100644 index 00000000..a76fb7c4 --- /dev/null +++ b/hscontrol/types/change/change.go @@ -0,0 +1,457 @@ +package change + +import ( + "slices" + "time" + + "github.com/juanfont/headscale/hscontrol/types" + "tailscale.com/tailcfg" +) + +// Change declares what should be included in a MapResponse. +// The mapper uses this to build the response without guessing. +type Change struct { + // Reason is a human-readable description for logging/debugging. + Reason string + + // TargetNode, if set, means this response should only be sent to this node. + TargetNode types.NodeID + + // OriginNode is the node that triggered this change. + // Used for self-update detection and filtering. + OriginNode types.NodeID + + // Content flags - what to include in the MapResponse. + IncludeSelf bool + IncludeDERPMap bool + IncludeDNS bool + IncludeDomain bool + IncludePolicy bool // PacketFilters and SSHPolicy - always sent together + + // Peer changes. + PeersChanged []types.NodeID + PeersRemoved []types.NodeID + PeerPatches []*tailcfg.PeerChange + SendAllPeers bool + + // RequiresRuntimePeerComputation indicates that peer visibility + // must be computed at runtime per-node. Used for policy changes + // where each node may have different peer visibility. + RequiresRuntimePeerComputation bool +} + +// boolFieldNames returns all boolean field names for exhaustive testing. +// When adding a new boolean field to Change, add it here. +// Tests use reflection to verify this matches the struct. +func (r Change) boolFieldNames() []string { + return []string{ + "IncludeSelf", + "IncludeDERPMap", + "IncludeDNS", + "IncludeDomain", + "IncludePolicy", + "SendAllPeers", + "RequiresRuntimePeerComputation", + } +} + +func (r Change) Merge(other Change) Change { + merged := r + + merged.IncludeSelf = r.IncludeSelf || other.IncludeSelf + merged.IncludeDERPMap = r.IncludeDERPMap || other.IncludeDERPMap + merged.IncludeDNS = r.IncludeDNS || other.IncludeDNS + merged.IncludeDomain = r.IncludeDomain || other.IncludeDomain + merged.IncludePolicy = r.IncludePolicy || other.IncludePolicy + merged.SendAllPeers = r.SendAllPeers || other.SendAllPeers + merged.RequiresRuntimePeerComputation = r.RequiresRuntimePeerComputation || other.RequiresRuntimePeerComputation + + merged.PeersChanged = uniqueNodeIDs(append(r.PeersChanged, other.PeersChanged...)) + merged.PeersRemoved = uniqueNodeIDs(append(r.PeersRemoved, other.PeersRemoved...)) + merged.PeerPatches = append(r.PeerPatches, other.PeerPatches...) + + // Preserve OriginNode for self-update detection. + // If either change has OriginNode set, keep it so the mapper + // can detect self-updates and send the node its own changes. + if merged.OriginNode == 0 { + merged.OriginNode = other.OriginNode + } + + // Preserve TargetNode for targeted responses. 
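+	// Like OriginNode above, the receiver's TargetNode wins if both changes set one.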
+ if merged.TargetNode == 0 { + merged.TargetNode = other.TargetNode + } + + if r.Reason != "" && other.Reason != "" && r.Reason != other.Reason { + merged.Reason = r.Reason + "; " + other.Reason + } else if other.Reason != "" { + merged.Reason = other.Reason + } + + return merged +} + +func (r Change) IsEmpty() bool { + if r.IncludeSelf || r.IncludeDERPMap || r.IncludeDNS || + r.IncludeDomain || r.IncludePolicy || r.SendAllPeers { + return false + } + + if r.RequiresRuntimePeerComputation { + return false + } + + return len(r.PeersChanged) == 0 && + len(r.PeersRemoved) == 0 && + len(r.PeerPatches) == 0 +} + +func (r Change) IsSelfOnly() bool { + if r.TargetNode == 0 || !r.IncludeSelf { + return false + } + + if r.SendAllPeers || len(r.PeersChanged) > 0 || len(r.PeersRemoved) > 0 || len(r.PeerPatches) > 0 { + return false + } + + return true +} + +// IsTargetedToNode returns true if this response should only be sent to TargetNode. +func (r Change) IsTargetedToNode() bool { + return r.TargetNode != 0 +} + +// IsFull reports whether this is a full update response. +func (r Change) IsFull() bool { + return r.SendAllPeers && r.IncludeSelf && r.IncludeDERPMap && + r.IncludeDNS && r.IncludeDomain && r.IncludePolicy +} + +// Type returns a categorized type string for metrics. +// This provides a bounded set of values suitable for Prometheus labels, +// unlike Reason which is free-form text for logging. +func (r Change) Type() string { + if r.IsFull() { + return "full" + } + + if r.IsSelfOnly() { + return "self" + } + + if r.RequiresRuntimePeerComputation { + return "policy" + } + + if len(r.PeerPatches) > 0 && len(r.PeersChanged) == 0 && len(r.PeersRemoved) == 0 && !r.SendAllPeers { + return "patch" + } + + if len(r.PeersChanged) > 0 || len(r.PeersRemoved) > 0 || r.SendAllPeers { + return "peers" + } + + if r.IncludeDERPMap || r.IncludeDNS || r.IncludeDomain || r.IncludePolicy { + return "config" + } + + return "unknown" +} + +// ShouldSendToNode determines if this response should be sent to nodeID. +// It handles self-only targeting and filtering out self-updates for non-origin nodes. +func (r Change) ShouldSendToNode(nodeID types.NodeID) bool { + // If targeted to a specific node, only send to that node + if r.TargetNode != 0 { + return r.TargetNode == nodeID + } + + return true +} + +// HasFull returns true if any response in the slice is a full update. +func HasFull(rs []Change) bool { + for _, r := range rs { + if r.IsFull() { + return true + } + } + + return false +} + +// SplitTargetedAndBroadcast separates responses into targeted (to specific node) and broadcast. +func SplitTargetedAndBroadcast(rs []Change) ([]Change, []Change) { + var broadcast, targeted []Change + + for _, r := range rs { + if r.IsTargetedToNode() { + targeted = append(targeted, r) + } else { + broadcast = append(broadcast, r) + } + } + + return broadcast, targeted +} + +// FilterForNode returns responses that should be sent to the given node. 
+func FilterForNode(nodeID types.NodeID, rs []Change) []Change { + var result []Change + + for _, r := range rs { + if r.ShouldSendToNode(nodeID) { + result = append(result, r) + } + } + + return result +} + +func uniqueNodeIDs(ids []types.NodeID) []types.NodeID { + if len(ids) == 0 { + return nil + } + + slices.Sort(ids) + + return slices.Compact(ids) +} + +// Constructor functions + +func FullUpdate() Change { + return Change{ + Reason: "full update", + IncludeSelf: true, + IncludeDERPMap: true, + IncludeDNS: true, + IncludeDomain: true, + IncludePolicy: true, + SendAllPeers: true, + } +} + +// FullSelf returns a full update targeted at a specific node. +func FullSelf(nodeID types.NodeID) Change { + return Change{ + Reason: "full self update", + TargetNode: nodeID, + IncludeSelf: true, + IncludeDERPMap: true, + IncludeDNS: true, + IncludeDomain: true, + IncludePolicy: true, + SendAllPeers: true, + } +} + +func SelfUpdate(nodeID types.NodeID) Change { + return Change{ + Reason: "self update", + TargetNode: nodeID, + IncludeSelf: true, + } +} + +func PolicyOnly() Change { + return Change{ + Reason: "policy update", + IncludePolicy: true, + } +} + +func PolicyAndPeers(changedPeers ...types.NodeID) Change { + return Change{ + Reason: "policy and peers update", + IncludePolicy: true, + PeersChanged: changedPeers, + } +} + +func VisibilityChange(reason string, added, removed []types.NodeID) Change { + return Change{ + Reason: reason, + IncludePolicy: true, + PeersChanged: added, + PeersRemoved: removed, + } +} + +func PeersChanged(reason string, peerIDs ...types.NodeID) Change { + return Change{ + Reason: reason, + PeersChanged: peerIDs, + } +} + +func PeersRemoved(peerIDs ...types.NodeID) Change { + return Change{ + Reason: "peers removed", + PeersRemoved: peerIDs, + } +} + +func PeerPatched(reason string, patches ...*tailcfg.PeerChange) Change { + return Change{ + Reason: reason, + PeerPatches: patches, + } +} + +func DERPMap() Change { + return Change{ + Reason: "DERP map update", + IncludeDERPMap: true, + } +} + +// PolicyChange creates a response for policy changes. +// Policy changes require runtime peer visibility computation. +func PolicyChange() Change { + return Change{ + Reason: "policy change", + IncludePolicy: true, + RequiresRuntimePeerComputation: true, + } +} + +// DNSConfig creates a response for DNS configuration updates. +func DNSConfig() Change { + return Change{ + Reason: "DNS config update", + IncludeDNS: true, + } +} + +// NodeOnline creates a patch response for a node coming online. +func NodeOnline(nodeID types.NodeID) Change { + return Change{ + Reason: "node online", + PeerPatches: []*tailcfg.PeerChange{ + { + NodeID: nodeID.NodeID(), + Online: ptrTo(true), + }, + }, + } +} + +// NodeOffline creates a patch response for a node going offline. +func NodeOffline(nodeID types.NodeID) Change { + return Change{ + Reason: "node offline", + PeerPatches: []*tailcfg.PeerChange{ + { + NodeID: nodeID.NodeID(), + Online: ptrTo(false), + }, + }, + } +} + +// KeyExpiry creates a patch response for a node's key expiry change. +func KeyExpiry(nodeID types.NodeID, expiry *time.Time) Change { + return Change{ + Reason: "key expiry", + PeerPatches: []*tailcfg.PeerChange{ + { + NodeID: nodeID.NodeID(), + KeyExpiry: expiry, + }, + }, + } +} + +// ptrTo returns a pointer to the given value. +func ptrTo[T any](v T) *T { + return &v +} + +// High-level change constructors + +// NodeAdded returns a Change for when a node is added or updated. 
+// The OriginNode field enables self-update detection by the mapper. +func NodeAdded(id types.NodeID) Change { + c := PeersChanged("node added", id) + c.OriginNode = id + + return c +} + +// NodeRemoved returns a Change for when a node is removed. +func NodeRemoved(id types.NodeID) Change { + return PeersRemoved(id) +} + +// NodeOnlineFor returns a Change for when a node comes online. +// If the node is a subnet router, a full update is sent instead of a patch. +func NodeOnlineFor(node types.NodeView) Change { + if node.IsSubnetRouter() { + c := FullUpdate() + c.Reason = "subnet router online" + + return c + } + + return NodeOnline(node.ID()) +} + +// NodeOfflineFor returns a Change for when a node goes offline. +// If the node is a subnet router, a full update is sent instead of a patch. +func NodeOfflineFor(node types.NodeView) Change { + if node.IsSubnetRouter() { + c := FullUpdate() + c.Reason = "subnet router offline" + + return c + } + + return NodeOffline(node.ID()) +} + +// KeyExpiryFor returns a Change for when a node's key expiry changes. +// The OriginNode field enables self-update detection by the mapper. +func KeyExpiryFor(id types.NodeID, expiry time.Time) Change { + c := KeyExpiry(id, &expiry) + c.OriginNode = id + + return c +} + +// EndpointOrDERPUpdate returns a Change for when a node's endpoints or DERP region changes. +// The OriginNode field enables self-update detection by the mapper. +func EndpointOrDERPUpdate(id types.NodeID, patch *tailcfg.PeerChange) Change { + c := PeerPatched("endpoint/DERP update", patch) + c.OriginNode = id + + return c +} + +// UserAdded returns a Change for when a user is added or updated. +// A full update is sent to refresh user profiles on all nodes. +func UserAdded() Change { + c := FullUpdate() + c.Reason = "user added" + + return c +} + +// UserRemoved returns a Change for when a user is removed. +// A full update is sent to refresh user profiles on all nodes. +func UserRemoved() Change { + c := FullUpdate() + c.Reason = "user removed" + + return c +} + +// ExtraRecords returns a Change for when DNS extra records change. 
+func ExtraRecords() Change { + c := DNSConfig() + c.Reason = "extra records update" + + return c +} diff --git a/hscontrol/types/change/change_test.go b/hscontrol/types/change/change_test.go new file mode 100644 index 00000000..9f181dd6 --- /dev/null +++ b/hscontrol/types/change/change_test.go @@ -0,0 +1,479 @@ +package change + +import ( + "reflect" + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "tailscale.com/tailcfg" +) + +func TestChange_FieldSync(t *testing.T) { + r := Change{} + fieldNames := r.boolFieldNames() + + typ := reflect.TypeFor[Change]() + boolCount := 0 + + for i := range typ.NumField() { + if typ.Field(i).Type.Kind() == reflect.Bool { + boolCount++ + } + } + + if len(fieldNames) != boolCount { + t.Fatalf("boolFieldNames() returns %d fields but struct has %d bool fields; "+ + "update boolFieldNames() when adding new bool fields", len(fieldNames), boolCount) + } +} + +func TestChange_IsEmpty(t *testing.T) { + tests := []struct { + name string + response Change + want bool + }{ + { + name: "zero value is empty", + response: Change{}, + want: true, + }, + { + name: "only reason is still empty", + response: Change{Reason: "test"}, + want: true, + }, + { + name: "IncludeSelf not empty", + response: Change{IncludeSelf: true}, + want: false, + }, + { + name: "IncludeDERPMap not empty", + response: Change{IncludeDERPMap: true}, + want: false, + }, + { + name: "IncludeDNS not empty", + response: Change{IncludeDNS: true}, + want: false, + }, + { + name: "IncludeDomain not empty", + response: Change{IncludeDomain: true}, + want: false, + }, + { + name: "IncludePolicy not empty", + response: Change{IncludePolicy: true}, + want: false, + }, + { + name: "SendAllPeers not empty", + response: Change{SendAllPeers: true}, + want: false, + }, + { + name: "PeersChanged not empty", + response: Change{PeersChanged: []types.NodeID{1}}, + want: false, + }, + { + name: "PeersRemoved not empty", + response: Change{PeersRemoved: []types.NodeID{1}}, + want: false, + }, + { + name: "PeerPatches not empty", + response: Change{PeerPatches: []*tailcfg.PeerChange{{}}}, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.response.IsEmpty() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestChange_IsSelfOnly(t *testing.T) { + tests := []struct { + name string + response Change + want bool + }{ + { + name: "empty is not self only", + response: Change{}, + want: false, + }, + { + name: "IncludeSelf without TargetNode is not self only", + response: Change{IncludeSelf: true}, + want: false, + }, + { + name: "TargetNode without IncludeSelf is not self only", + response: Change{TargetNode: 1}, + want: false, + }, + { + name: "TargetNode with IncludeSelf is self only", + response: Change{TargetNode: 1, IncludeSelf: true}, + want: true, + }, + { + name: "self only with SendAllPeers is not self only", + response: Change{TargetNode: 1, IncludeSelf: true, SendAllPeers: true}, + want: false, + }, + { + name: "self only with PeersChanged is not self only", + response: Change{TargetNode: 1, IncludeSelf: true, PeersChanged: []types.NodeID{2}}, + want: false, + }, + { + name: "self only with PeersRemoved is not self only", + response: Change{TargetNode: 1, IncludeSelf: true, PeersRemoved: []types.NodeID{2}}, + want: false, + }, + { + name: "self only with PeerPatches is not self only", + response: Change{TargetNode: 1, IncludeSelf: true, PeerPatches: []*tailcfg.PeerChange{{}}}, + want: false, + }, + { + 
name: "self only with other include flags is still self only", + response: Change{ + TargetNode: 1, + IncludeSelf: true, + IncludePolicy: true, + IncludeDNS: true, + }, + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.response.IsSelfOnly() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestChange_Merge(t *testing.T) { + tests := []struct { + name string + r1 Change + r2 Change + want Change + }{ + { + name: "empty merge", + r1: Change{}, + r2: Change{}, + want: Change{}, + }, + { + name: "bool fields OR together", + r1: Change{IncludeSelf: true, IncludePolicy: true}, + r2: Change{IncludeDERPMap: true, IncludePolicy: true}, + want: Change{IncludeSelf: true, IncludeDERPMap: true, IncludePolicy: true}, + }, + { + name: "all bool fields merge", + r1: Change{IncludeSelf: true, IncludeDNS: true, IncludePolicy: true}, + r2: Change{IncludeDERPMap: true, IncludeDomain: true, SendAllPeers: true}, + want: Change{ + IncludeSelf: true, + IncludeDERPMap: true, + IncludeDNS: true, + IncludeDomain: true, + IncludePolicy: true, + SendAllPeers: true, + }, + }, + { + name: "peers deduplicated and sorted", + r1: Change{PeersChanged: []types.NodeID{3, 1}}, + r2: Change{PeersChanged: []types.NodeID{2, 1}}, + want: Change{PeersChanged: []types.NodeID{1, 2, 3}}, + }, + { + name: "peers removed deduplicated", + r1: Change{PeersRemoved: []types.NodeID{1, 2}}, + r2: Change{PeersRemoved: []types.NodeID{2, 3}}, + want: Change{PeersRemoved: []types.NodeID{1, 2, 3}}, + }, + { + name: "peer patches concatenated", + r1: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 1}}}, + r2: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 2}}}, + want: Change{PeerPatches: []*tailcfg.PeerChange{{NodeID: 1}, {NodeID: 2}}}, + }, + { + name: "reasons combined when different", + r1: Change{Reason: "route change"}, + r2: Change{Reason: "tag change"}, + want: Change{Reason: "route change; tag change"}, + }, + { + name: "same reason not duplicated", + r1: Change{Reason: "policy"}, + r2: Change{Reason: "policy"}, + want: Change{Reason: "policy"}, + }, + { + name: "empty reason takes other", + r1: Change{}, + r2: Change{Reason: "update"}, + want: Change{Reason: "update"}, + }, + { + name: "OriginNode preserved from first", + r1: Change{OriginNode: 42}, + r2: Change{IncludePolicy: true}, + want: Change{OriginNode: 42, IncludePolicy: true}, + }, + { + name: "OriginNode preserved from second when first is zero", + r1: Change{IncludePolicy: true}, + r2: Change{OriginNode: 42}, + want: Change{OriginNode: 42, IncludePolicy: true}, + }, + { + name: "OriginNode first wins when both set", + r1: Change{OriginNode: 1}, + r2: Change{OriginNode: 2}, + want: Change{OriginNode: 1}, + }, + { + name: "TargetNode preserved from first", + r1: Change{TargetNode: 42}, + r2: Change{IncludeSelf: true}, + want: Change{TargetNode: 42, IncludeSelf: true}, + }, + { + name: "TargetNode preserved from second when first is zero", + r1: Change{IncludeSelf: true}, + r2: Change{TargetNode: 42}, + want: Change{TargetNode: 42, IncludeSelf: true}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.r1.Merge(tt.r2) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestChange_Constructors(t *testing.T) { + tests := []struct { + name string + constructor func() Change + wantReason string + want Change + }{ + { + name: "FullUpdateResponse", + constructor: FullUpdate, + wantReason: "full update", + want: Change{ + Reason: "full update", + IncludeSelf: true, + IncludeDERPMap: true, 
+ IncludeDNS: true, + IncludeDomain: true, + IncludePolicy: true, + SendAllPeers: true, + }, + }, + { + name: "PolicyOnlyResponse", + constructor: PolicyOnly, + wantReason: "policy update", + want: Change{ + Reason: "policy update", + IncludePolicy: true, + }, + }, + { + name: "DERPMapResponse", + constructor: DERPMap, + wantReason: "DERP map update", + want: Change{ + Reason: "DERP map update", + IncludeDERPMap: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + r := tt.constructor() + assert.Equal(t, tt.wantReason, r.Reason) + assert.Equal(t, tt.want, r) + }) + } +} + +func TestSelfUpdate(t *testing.T) { + r := SelfUpdate(42) + assert.Equal(t, "self update", r.Reason) + assert.Equal(t, types.NodeID(42), r.TargetNode) + assert.True(t, r.IncludeSelf) + assert.True(t, r.IsSelfOnly()) +} + +func TestPolicyAndPeers(t *testing.T) { + r := PolicyAndPeers(1, 2, 3) + assert.Equal(t, "policy and peers update", r.Reason) + assert.True(t, r.IncludePolicy) + assert.Equal(t, []types.NodeID{1, 2, 3}, r.PeersChanged) +} + +func TestVisibilityChange(t *testing.T) { + r := VisibilityChange("tag change", []types.NodeID{1}, []types.NodeID{2, 3}) + assert.Equal(t, "tag change", r.Reason) + assert.True(t, r.IncludePolicy) + assert.Equal(t, []types.NodeID{1}, r.PeersChanged) + assert.Equal(t, []types.NodeID{2, 3}, r.PeersRemoved) +} + +func TestPeersChanged(t *testing.T) { + r := PeersChanged("routes approved", 1, 2) + assert.Equal(t, "routes approved", r.Reason) + assert.Equal(t, []types.NodeID{1, 2}, r.PeersChanged) + assert.False(t, r.IncludePolicy) +} + +func TestPeersRemoved(t *testing.T) { + r := PeersRemoved(1, 2, 3) + assert.Equal(t, "peers removed", r.Reason) + assert.Equal(t, []types.NodeID{1, 2, 3}, r.PeersRemoved) +} + +func TestPeerPatched(t *testing.T) { + patch := &tailcfg.PeerChange{NodeID: 1} + r := PeerPatched("endpoint change", patch) + assert.Equal(t, "endpoint change", r.Reason) + assert.Equal(t, []*tailcfg.PeerChange{patch}, r.PeerPatches) +} + +func TestChange_Type(t *testing.T) { + tests := []struct { + name string + response Change + want string + }{ + { + name: "full update", + response: FullUpdate(), + want: "full", + }, + { + name: "self only", + response: SelfUpdate(1), + want: "self", + }, + { + name: "policy with runtime computation", + response: PolicyChange(), + want: "policy", + }, + { + name: "patch only", + response: PeerPatched("test", &tailcfg.PeerChange{NodeID: 1}), + want: "patch", + }, + { + name: "peers changed", + response: PeersChanged("test", 1, 2), + want: "peers", + }, + { + name: "peers removed", + response: PeersRemoved(1, 2), + want: "peers", + }, + { + name: "config - DERP map", + response: DERPMap(), + want: "config", + }, + { + name: "config - DNS", + response: DNSConfig(), + want: "config", + }, + { + name: "config - policy only (no runtime)", + response: PolicyOnly(), + want: "config", + }, + { + name: "empty is unknown", + response: Change{}, + want: "unknown", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.response.Type() + assert.Equal(t, tt.want, got) + }) + } +} + +func TestUniqueNodeIDs(t *testing.T) { + tests := []struct { + name string + input []types.NodeID + want []types.NodeID + }{ + { + name: "nil input", + input: nil, + want: nil, + }, + { + name: "empty input", + input: []types.NodeID{}, + want: nil, + }, + { + name: "single element", + input: []types.NodeID{1}, + want: []types.NodeID{1}, + }, + { + name: "no duplicates", + input: []types.NodeID{1, 2, 3}, + 
want: []types.NodeID{1, 2, 3}, + }, + { + name: "with duplicates", + input: []types.NodeID{3, 1, 2, 1, 3}, + want: []types.NodeID{1, 2, 3}, + }, + { + name: "all same", + input: []types.NodeID{5, 5, 5, 5}, + want: []types.NodeID{5}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := uniqueNodeIDs(tt.input) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 6e8bfff8..f4814519 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -1,87 +1,48 @@ +//go:generate go tool viewer --type=User,Node,PreAuthKey package types +//go:generate go run tailscale.com/cmd/viewer --type=User,Node,PreAuthKey + import ( - "database/sql/driver" - "encoding/json" "errors" "fmt" - "net/netip" + "runtime" + "sync/atomic" + "time" + "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" ) +const ( + SelfUpdateIdentifier = "self-update" + DatabasePostgres = "postgres" + DatabaseSqlite = "sqlite3" +) + var ErrCannotParsePrefix = errors.New("cannot parse prefix") -type IPPrefix netip.Prefix - -func (i *IPPrefix) Scan(destination interface{}) error { - switch value := destination.(type) { - case string: - prefix, err := netip.ParsePrefix(value) - if err != nil { - return err - } - *i = IPPrefix(prefix) - - return nil - default: - return fmt.Errorf("%w: unexpected data type %T", ErrCannotParsePrefix, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i IPPrefix) Value() (driver.Value, error) { - prefixStr := netip.Prefix(i).String() - - return prefixStr, nil -} - -type IPPrefixes []netip.Prefix - -func (i *IPPrefixes) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, i) - - case string: - return json.Unmarshal([]byte(value), i) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i IPPrefixes) Value() (driver.Value, error) { - bytes, err := json.Marshal(i) - - return string(bytes), err -} - -type StringList []string - -func (i *StringList) Scan(destination interface{}) error { - switch value := destination.(type) { - case []byte: - return json.Unmarshal(value, i) - - case string: - return json.Unmarshal([]byte(value), i) - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (i StringList) Value() (driver.Value, error) { - bytes, err := json.Marshal(i) - - return string(bytes), err -} - type StateUpdateType int +func (su StateUpdateType) String() string { + switch su { + case StateFullUpdate: + return "StateFullUpdate" + case StatePeerChanged: + return "StatePeerChanged" + case StatePeerChangedPatch: + return "StatePeerChangedPatch" + case StatePeerRemoved: + return "StatePeerRemoved" + case StateSelfUpdate: + return "StateSelfUpdate" + case StateDERPUpdated: + return "StateDERPUpdated" + } + + return "unknown state update type" +} + const ( StateFullUpdate StateUpdateType = iota // StatePeerChanged is used for updates that needs @@ -91,6 +52,12 @@ const ( StatePeerChanged StatePeerChangedPatch StatePeerRemoved + // StateSelfUpdate is used to indicate that the node + // has changed in control, and the client needs to be + // informed. 
+ // The updated node is inside the ChangeNodes field + // which should have a length of one. + StateSelfUpdate StateDERPUpdated ) @@ -104,7 +71,7 @@ type StateUpdate struct { // ChangeNodes must be set when Type is StatePeerAdded // and StatePeerChanged and contains the full node // object for added nodes. - ChangeNodes Nodes + ChangeNodes []NodeID // ChangePatches must be set when Type is StatePeerChangedPatch // and contains a populated PeerChange object. @@ -113,7 +80,7 @@ type StateUpdate struct { // Removed must be set when Type is StatePeerRemoved and // contain a list of the nodes that has been removed from // the network. - Removed []tailcfg.NodeID + Removed []NodeID // DERPMap must be set when Type is StateDERPUpdated and // contain the new DERP Map. @@ -124,29 +91,141 @@ type StateUpdate struct { Message string } -// Valid reports if a StateUpdate is correctly filled and -// panics if the mandatory fields for a type is not -// filled. -// Reports true if valid. -func (su *StateUpdate) Valid() bool { +// Empty reports if there are any updates in the StateUpdate. +func (su *StateUpdate) Empty() bool { switch su.Type { case StatePeerChanged: - if su.ChangeNodes == nil { - panic("Mandatory field ChangeNodes is not set on StatePeerChanged update") - } + return len(su.ChangeNodes) == 0 case StatePeerChangedPatch: - if su.ChangePatches == nil { - panic("Mandatory field ChangePatches is not set on StatePeerChangedPatch update") - } + return len(su.ChangePatches) == 0 case StatePeerRemoved: - if su.Removed == nil { - panic("Mandatory field Removed is not set on StatePeerRemove update") - } - case StateDERPUpdated: - if su.DERPMap == nil { - panic("Mandatory field DERPMap is not set on StateDERPUpdated update") - } + return len(su.Removed) == 0 } - return true + return false +} + +func UpdateFull() StateUpdate { + return StateUpdate{ + Type: StateFullUpdate, + } +} + +func UpdateSelf(nodeID NodeID) StateUpdate { + return StateUpdate{ + Type: StateSelfUpdate, + ChangeNodes: []NodeID{nodeID}, + } +} + +func UpdatePeerChanged(nodeIDs ...NodeID) StateUpdate { + return StateUpdate{ + Type: StatePeerChanged, + ChangeNodes: nodeIDs, + } +} + +func UpdatePeerPatch(changes ...*tailcfg.PeerChange) StateUpdate { + return StateUpdate{ + Type: StatePeerChangedPatch, + ChangePatches: changes, + } +} + +func UpdatePeerRemoved(nodeIDs ...NodeID) StateUpdate { + return StateUpdate{ + Type: StatePeerRemoved, + Removed: nodeIDs, + } +} + +func UpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate { + return StateUpdate{ + Type: StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: nodeID.NodeID(), + KeyExpiry: &expiry, + }, + }, + } +} + +const RegistrationIDLength = 24 + +type RegistrationID string + +func NewRegistrationID() (RegistrationID, error) { + rid, err := util.GenerateRandomStringURLSafe(RegistrationIDLength) + if err != nil { + return "", err + } + + return RegistrationID(rid), nil +} + +func MustRegistrationID() RegistrationID { + rid, err := NewRegistrationID() + if err != nil { + panic(err) + } + + return rid +} + +func RegistrationIDFromString(str string) (RegistrationID, error) { + if len(str) != RegistrationIDLength { + return "", fmt.Errorf("registration ID must be %d characters long", RegistrationIDLength) + } + return RegistrationID(str), nil +} + +func (r RegistrationID) String() string { + return string(r) +} + +type RegisterNode struct { + Node Node + Registered chan *Node + closed *atomic.Bool +} + +func NewRegisterNode(node Node) RegisterNode { + return 
RegisterNode{ + Node: node, + Registered: make(chan *Node), + closed: &atomic.Bool{}, + } +} + +func (rn *RegisterNode) SendAndClose(node *Node) { + if rn.closed.Swap(true) { + return + } + + select { + case rn.Registered <- node: + default: + } + + close(rn.Registered) +} + +// DefaultBatcherWorkers returns the default number of batcher workers. +// Default to 3/4 of CPU cores, minimum 1, no maximum. +func DefaultBatcherWorkers() int { + return DefaultBatcherWorkersFor(runtime.NumCPU()) +} + +// DefaultBatcherWorkersFor returns the default number of batcher workers for a given CPU count. +// Default to 3/4 of CPU cores, minimum 1, no maximum. +func DefaultBatcherWorkersFor(cpuCount int) int { + const ( + workerNumerator = 3 + workerDenominator = 4 + ) + + defaultWorkers := max((cpuCount*workerNumerator)/workerDenominator, 1) + + return defaultWorkers } diff --git a/hscontrol/types/common_test.go b/hscontrol/types/common_test.go new file mode 100644 index 00000000..a443918b --- /dev/null +++ b/hscontrol/types/common_test.go @@ -0,0 +1,36 @@ +package types + +import ( + "testing" +) + +func TestDefaultBatcherWorkersFor(t *testing.T) { + tests := []struct { + cpuCount int + expected int + }{ + {1, 1}, // (1*3)/4 = 0, should be minimum 1 + {2, 1}, // (2*3)/4 = 1 + {4, 3}, // (4*3)/4 = 3 + {8, 6}, // (8*3)/4 = 6 + {12, 9}, // (12*3)/4 = 9 + {16, 12}, // (16*3)/4 = 12 + {20, 15}, // (20*3)/4 = 15 + {24, 18}, // (24*3)/4 = 18 + } + + for _, test := range tests { + result := DefaultBatcherWorkersFor(test.cpuCount) + if result != test.expected { + t.Errorf("DefaultBatcherWorkersFor(%d) = %d, expected %d", test.cpuCount, result, test.expected) + } + } +} + +func TestDefaultBatcherWorkers(t *testing.T) { + // Just verify it returns a valid value (>= 1) + result := DefaultBatcherWorkers() + if result < 1 { + t.Errorf("DefaultBatcherWorkers() = %d, expected value >= 1", result) + } +} diff --git a/hscontrol/types/config.go b/hscontrol/types/config.go index 4b29c4b7..4068d72e 100644 --- a/hscontrol/types/config.go +++ b/hscontrol/types/config.go @@ -20,15 +20,37 @@ import ( "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/dnstype" + "tailscale.com/util/set" ) const ( defaultOIDCExpiryTime = 180 * 24 * time.Hour // 180 Days maxDuration time.Duration = 1<<63 - 1 + PKCEMethodPlain string = "plain" + PKCEMethodS256 string = "S256" + + defaultNodeStoreBatchSize = 100 ) -var errOidcMutuallyExclusive = errors.New( - "oidc_client_secret and oidc_client_secret_path are mutually exclusive", +var ( + errOidcMutuallyExclusive = errors.New("oidc_client_secret and oidc_client_secret_path are mutually exclusive") + errServerURLSuffix = errors.New("server_url cannot be part of base_domain in a way that could make the DERP and headscale server unreachable") + errServerURLSame = errors.New("server_url cannot use the same domain as base_domain in a way that could make the DERP and headscale server unreachable") + errInvalidPKCEMethod = errors.New("pkce.method must be either 'plain' or 'S256'") +) + +type IPAllocationStrategy string + +const ( + IPAllocationStrategySequential IPAllocationStrategy = "sequential" + IPAllocationStrategyRandom IPAllocationStrategy = "random" +) + +type PolicyMode string + +const ( + PolicyModeDB = "database" + PolicyModeFile = "file" ) // Config contains the initial Headscale configuration. 
@@ -39,30 +61,31 @@ type Config struct { GRPCAddr string GRPCAllowInsecure bool EphemeralNodeInactivityTimeout time.Duration - NodeUpdateCheckInterval time.Duration - IPPrefixes []netip.Prefix + PrefixV4 *netip.Prefix + PrefixV6 *netip.Prefix + IPAllocation IPAllocationStrategy NoisePrivateKeyPath string BaseDomain string Log LogConfig DisableUpdateCheck bool - DERP DERPConfig + Database DatabaseConfig - DBtype string - DBpath string - DBhost string - DBport int - DBname string - DBuser string - DBpass string - DBssl string + DERP DERPConfig TLS TLSConfig ACMEURL string ACMEEmail string - DNSConfig *tailcfg.DNSConfig + // DNSConfig is the headscale representation of the DNS configuration. + // It is kept in the config update for some settings that are + // not directly converted into a tailcfg.DNSConfig. + DNSConfig DNSConfig + + // TailcfgDNSConfig is the tailcfg representation of the DNS configuration, + // it can be used directly when sending Netmaps to clients. + TailcfgDNSConfig *tailcfg.DNSConfig UnixSocket string UnixSocketPermission fs.FileMode @@ -71,10 +94,66 @@ type Config struct { LogTail LogTailConfig RandomizeClientPort bool + Taildrop TaildropConfig CLI CLIConfig - ACL ACLConfig + Policy PolicyConfig + + Tuning Tuning +} + +type DNSConfig struct { + MagicDNS bool `mapstructure:"magic_dns"` + BaseDomain string `mapstructure:"base_domain"` + OverrideLocalDNS bool `mapstructure:"override_local_dns"` + Nameservers Nameservers + SearchDomains []string `mapstructure:"search_domains"` + ExtraRecords []tailcfg.DNSRecord `mapstructure:"extra_records"` + ExtraRecordsPath string `mapstructure:"extra_records_path"` +} + +type Nameservers struct { + Global []string + Split map[string][]string +} + +type SqliteConfig struct { + Path string + WriteAheadLog bool + WALAutoCheckPoint int +} + +type PostgresConfig struct { + Host string + Port int + Name string + User string + Pass string + Ssl string + MaxOpenConnections int + MaxIdleConnections int + ConnMaxIdleTimeSecs int +} + +type GormConfig struct { + Debug bool + SlowThreshold time.Duration + SkipErrRecordNotFound bool + ParameterizedQueries bool + PrepareStmt bool +} + +type DatabaseConfig struct { + // Type sets the database type, either "sqlite3" or "postgres" + Type string + Debug bool + + // Type sets the gorm configuration + Gorm GormConfig + + Sqlite SqliteConfig + Postgres PostgresConfig } type TLSConfig struct { @@ -91,6 +170,11 @@ type LetsEncryptConfig struct { ChallengeType string } +type PKCEConfig struct { + Enabled bool + Method string +} + type OIDCConfig struct { OnlyStartIfOIDCIsAvailable bool Issuer string @@ -101,28 +185,38 @@ type OIDCConfig struct { AllowedDomains []string AllowedUsers []string AllowedGroups []string - StripEmaildomain bool + EmailVerifiedRequired bool Expiry time.Duration UseExpiryFromToken bool + PKCE PKCEConfig } type DERPConfig struct { - ServerEnabled bool - ServerRegionID int - ServerRegionCode string - ServerRegionName string - ServerPrivateKeyPath string - STUNAddr string - URLs []url.URL - Paths []string - AutoUpdate bool - UpdateFrequency time.Duration + ServerEnabled bool + AutomaticallyAddEmbeddedDerpRegion bool + ServerRegionID int + ServerRegionCode string + ServerRegionName string + ServerPrivateKeyPath string + ServerVerifyClients bool + STUNAddr string + URLs []url.URL + Paths []string + DERPMap *tailcfg.DERPMap + AutoUpdate bool + UpdateFrequency time.Duration + IPv4 string + IPv6 string } type LogTailConfig struct { Enabled bool } +type TaildropConfig struct { + Enabled bool +} 
+ type CLIConfig struct { Address string APIKey string @@ -130,8 +224,13 @@ type CLIConfig struct { Insecure bool } -type ACLConfig struct { - PolicyPath string +type PolicyConfig struct { + Path string + Mode PolicyMode +} + +func (p *PolicyConfig) IsEmpty() bool { + return p.Mode == PolicyModeFile && p.Path == "" } type LogConfig struct { @@ -139,6 +238,89 @@ type LogConfig struct { Level zerolog.Level } +// Tuning contains advanced performance tuning parameters for Headscale. +// These settings control internal batching, timeouts, and resource allocation. +// The defaults are carefully chosen for typical deployments and should rarely +// need adjustment. Changes to these values can significantly impact performance +// and resource usage. +type Tuning struct { + // NotifierSendTimeout is the maximum time to wait when sending notifications + // to connected clients about network changes. + NotifierSendTimeout time.Duration + + // BatchChangeDelay controls how long to wait before sending batched updates + // to clients when multiple changes occur in rapid succession. + BatchChangeDelay time.Duration + + // NodeMapSessionBufferedChanSize sets the buffer size for the channel that + // queues map updates to be sent to connected clients. + NodeMapSessionBufferedChanSize int + + // BatcherWorkers controls the number of parallel workers processing map + // updates for connected clients. + BatcherWorkers int + + // RegisterCacheCleanup is the interval between cleanup operations for + // expired registration cache entries. + RegisterCacheCleanup time.Duration + + // RegisterCacheExpiration is how long registration cache entries remain + // valid before being eligible for cleanup. + RegisterCacheExpiration time.Duration + + // NodeStoreBatchSize controls how many write operations are accumulated + // before rebuilding the in-memory node snapshot. + // + // The NodeStore batches write operations (add/update/delete nodes) before + // rebuilding its in-memory data structures. Rebuilding involves recalculating + // peer relationships between all nodes based on the current ACL policy, which + // is computationally expensive and scales with the square of the number of nodes. + // + // By batching writes, Headscale can process N operations but only rebuild once, + // rather than rebuilding N times. This significantly reduces CPU usage during + // bulk operations like initial sync or policy updates. + // + // Trade-off: Higher values reduce CPU usage from rebuilds but increase latency + // for individual operations waiting for their batch to complete. + NodeStoreBatchSize int + + // NodeStoreBatchTimeout is the maximum time to wait before processing a + // partial batch of node operations. + // + // When NodeStoreBatchSize operations haven't accumulated, this timeout ensures + // writes don't wait indefinitely. The batch processes when either the size + // threshold is reached OR this timeout expires, whichever comes first. + // + // Trade-off: Lower values provide faster response for individual operations + // but trigger more frequent (expensive) peer map rebuilds. Higher values + // optimize for bulk throughput at the cost of individual operation latency. + NodeStoreBatchTimeout time.Duration +} + +func validatePKCEMethod(method string) error { + if method != PKCEMethodPlain && method != PKCEMethodS256 { + return errInvalidPKCEMethod + } + return nil +} + +// Domain returns the hostname/domain part of the ServerURL. +// If the ServerURL is not a valid URL, it returns the BaseDomain. 
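+// For example, a ServerURL of "https://headscale.example.com:8080/" yields
+// "headscale.example.com".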
+func (c *Config) Domain() string { + u, err := url.Parse(c.ServerURL) + if err != nil { + return c.BaseDomain + } + + return u.Hostname() +} + +// LoadConfig prepares and loads the Headscale configuration into Viper. +// This means it sets the default values, reads the configuration file and +// environment variables, and handles deprecated configuration options. +// It has to be called before LoadServerConfig and LoadCLIConfig. +// The configuration is not validated and the caller should check for errors +// using a validation function. func LoadConfig(path string, isFile bool) error { if isFile { viper.SetConfigFile(path) @@ -154,21 +336,31 @@ func LoadConfig(path string, isFile bool) error { } } - viper.SetEnvPrefix("headscale") + envPrefix := "headscale" + viper.SetEnvPrefix(envPrefix) viper.SetEnvKeyReplacer(strings.NewReplacer(".", "_")) viper.AutomaticEnv() + viper.SetDefault("policy.mode", "file") + viper.SetDefault("tls_letsencrypt_cache_dir", "/var/www/.cache") viper.SetDefault("tls_letsencrypt_challenge_type", HTTP01ChallengeType) viper.SetDefault("log.level", "info") viper.SetDefault("log.format", TextLogFormat) - viper.SetDefault("dns_config", nil) - viper.SetDefault("dns_config.override_local_dns", true) + viper.SetDefault("dns.magic_dns", true) + viper.SetDefault("dns.base_domain", "") + viper.SetDefault("dns.override_local_dns", true) + viper.SetDefault("dns.nameservers.global", []string{}) + viper.SetDefault("dns.nameservers.split", map[string]string{}) + viper.SetDefault("dns.search_domains", []string{}) viper.SetDefault("derp.server.enabled", false) + viper.SetDefault("derp.server.verify_clients", true) viper.SetDefault("derp.server.stun.enabled", true) + viper.SetDefault("derp.server.automatically_add_embedded_derp_region", true) + viper.SetDefault("derp.update_frequency", "3h") viper.SetDefault("unix_socket", "/var/run/headscale/headscale.sock") viper.SetDefault("unix_socket_permission", "0o770") @@ -179,31 +371,88 @@ func LoadConfig(path string, isFile bool) error { viper.SetDefault("cli.timeout", "5s") viper.SetDefault("cli.insecure", false) - viper.SetDefault("db_ssl", false) + viper.SetDefault("database.postgres.ssl", false) + viper.SetDefault("database.postgres.max_open_conns", 10) + viper.SetDefault("database.postgres.max_idle_conns", 10) + viper.SetDefault("database.postgres.conn_max_idle_time_secs", 3600) + + viper.SetDefault("database.sqlite.write_ahead_log", true) + viper.SetDefault("database.sqlite.wal_autocheckpoint", 1000) // SQLite default viper.SetDefault("oidc.scope", []string{oidc.ScopeOpenID, "profile", "email"}) - viper.SetDefault("oidc.strip_email_domain", true) viper.SetDefault("oidc.only_start_if_oidc_is_available", true) viper.SetDefault("oidc.expiry", "180d") viper.SetDefault("oidc.use_expiry_from_token", false) + viper.SetDefault("oidc.pkce.enabled", false) + viper.SetDefault("oidc.pkce.method", "S256") + viper.SetDefault("oidc.email_verified_required", true) viper.SetDefault("logtail.enabled", false) viper.SetDefault("randomize_client_port", false) + viper.SetDefault("taildrop.enabled", true) viper.SetDefault("ephemeral_node_inactivity_timeout", "120s") - viper.SetDefault("node_update_check_interval", "10s") + viper.SetDefault("tuning.notifier_send_timeout", "800ms") + viper.SetDefault("tuning.batch_change_delay", "800ms") + viper.SetDefault("tuning.node_mapsession_buffered_chan_size", 30) + viper.SetDefault("tuning.node_store_batch_size", defaultNodeStoreBatchSize) + viper.SetDefault("tuning.node_store_batch_timeout", "500ms") - if 
IsCLIConfigured() { - return nil - } + viper.SetDefault("prefixes.allocation", string(IPAllocationStrategySequential)) if err := viper.ReadInConfig(); err != nil { - log.Warn().Err(err).Msg("Failed to read configuration from disk") + if _, ok := err.(viper.ConfigFileNotFoundError); ok { + log.Warn().Msg("No config file found, using defaults") + return nil + } return fmt.Errorf("fatal error reading config file: %w", err) } + return nil +} + +func validateServerConfig() error { + depr := deprecator{ + warns: make(set.Set[string]), + fatals: make(set.Set[string]), + } + + // Register aliases for backward compatibility + // Has to be called _after_ viper.ReadInConfig() + // https://github.com/spf13/viper/issues/560 + + // Alias the old ACL Policy path with the new configuration option. + depr.fatalIfNewKeyIsNotUsed("policy.path", "acl_policy_path") + + // Move dns_config -> dns + depr.fatalIfNewKeyIsNotUsed("dns.magic_dns", "dns_config.magic_dns") + depr.fatalIfNewKeyIsNotUsed("dns.base_domain", "dns_config.base_domain") + depr.fatalIfNewKeyIsNotUsed("dns.override_local_dns", "dns_config.override_local_dns") + depr.fatalIfNewKeyIsNotUsed("dns.nameservers.global", "dns_config.nameservers") + depr.fatalIfNewKeyIsNotUsed("dns.nameservers.split", "dns_config.restricted_nameservers") + depr.fatalIfNewKeyIsNotUsed("dns.search_domains", "dns_config.domains") + depr.fatalIfNewKeyIsNotUsed("dns.extra_records", "dns_config.extra_records") + depr.fatal("dns.use_username_in_magic_dns") + depr.fatal("dns_config.use_username_in_magic_dns") + + // Removed since version v0.26.0 + depr.fatal("oidc.strip_email_domain") + depr.fatal("oidc.map_legacy_users") + + if viper.GetBool("oidc.enabled") { + if err := validatePKCEMethod(viper.GetString("oidc.pkce.method")); err != nil { + return err + } + } + + depr.Log() + + if viper.IsSet("dns.extra_records") && viper.IsSet("dns.extra_records_path") { + log.Fatal().Msg("Fatal config error: dns.extra_records and dns.extra_records_path are mutually exclusive. 
Please remove one of them from your config file") + } + // Collect any validation errors and return them all at once var errorText string if (viper.GetString("tls_letsencrypt_hostname") != "") && @@ -211,7 +460,7 @@ func LoadConfig(path string, isFile bool) error { errorText += "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both\n" } - if !viper.IsSet("noise") || viper.GetString("noise.private_key_path") == "" { + if viper.GetString("noise.private_key_path") == "" { errorText += "Fatal config error: headscale now requires a new `noise.private_key_path` field in the config file for the Tailscale v2 protocol\n" } @@ -244,24 +493,36 @@ func LoadConfig(path string, isFile bool) error { ) } - maxNodeUpdateCheckInterval, _ := time.ParseDuration("60s") - if viper.GetDuration("node_update_check_interval") > maxNodeUpdateCheckInterval { + if viper.GetBool("dns.override_local_dns") { + if global := viper.GetStringSlice("dns.nameservers.global"); len(global) == 0 { + errorText += "Fatal config error: dns.nameservers.global must be set when dns.override_local_dns is true\n" + } + } + + // Validate tuning parameters + if size := viper.GetInt("tuning.node_store_batch_size"); size <= 0 { errorText += fmt.Sprintf( - "Fatal config error: node_update_check_interval (%s) is set too high, must be less than %s", - viper.GetString("node_update_check_interval"), - maxNodeUpdateCheckInterval, + "Fatal config error: tuning.node_store_batch_size must be positive, got %d\n", + size, + ) + } + + if timeout := viper.GetDuration("tuning.node_store_batch_timeout"); timeout <= 0 { + errorText += fmt.Sprintf( + "Fatal config error: tuning.node_store_batch_timeout must be positive, got %s\n", + timeout, ) } if errorText != "" { - //nolint + // nolint return errors.New(strings.TrimSuffix(errorText, "\n")) - } else { - return nil } + + return nil } -func GetTLSConfig() TLSConfig { +func tlsConfig() TLSConfig { return TLSConfig{ LetsEncrypt: LetsEncryptConfig{ Hostname: viper.GetString("tls_letsencrypt_hostname"), @@ -280,14 +541,21 @@ func GetTLSConfig() TLSConfig { } } -func GetDERPConfig() DERPConfig { +func derpConfig() DERPConfig { serverEnabled := viper.GetBool("derp.server.enabled") serverRegionID := viper.GetInt("derp.server.region_id") serverRegionCode := viper.GetString("derp.server.region_code") serverRegionName := viper.GetString("derp.server.region_name") + serverVerifyClients := viper.GetBool("derp.server.verify_clients") stunAddr := viper.GetString("derp.server.stun_listen_addr") - privateKeyPath := util.AbsolutePathFromConfigPath(viper.GetString("derp.server.private_key_path")) - + privateKeyPath := util.AbsolutePathFromConfigPath( + viper.GetString("derp.server.private_key_path"), + ) + ipv4 := viper.GetString("derp.server.ipv4") + ipv6 := viper.GetString("derp.server.ipv6") + automaticallyAddEmbeddedDerpRegion := viper.GetBool( + "derp.server.automatically_add_embedded_derp_region", + ) if serverEnabled && stunAddr == "" { log.Fatal(). Msg("derp.server.stun_listen_addr must be set if derp.server.enabled is true") @@ -300,6 +568,7 @@ func GetDERPConfig() DERPConfig { urlAddr, err := url.Parse(urlStr) if err != nil { log.Error(). + Caller(). Str("url", urlStr). Err(err). Msg("Failed to parse url, ignoring...") @@ -310,24 +579,33 @@ func GetDERPConfig() DERPConfig { paths := viper.GetStringSlice("derp.paths") + if serverEnabled && !automaticallyAddEmbeddedDerpRegion && len(paths) == 0 { + log.Fatal(). 
+ Msg("Disabling derp.server.automatically_add_embedded_derp_region requires to configure the derp server in derp.paths") + } + autoUpdate := viper.GetBool("derp.auto_update_enabled") updateFrequency := viper.GetDuration("derp.update_frequency") return DERPConfig{ - ServerEnabled: serverEnabled, - ServerRegionID: serverRegionID, - ServerRegionCode: serverRegionCode, - ServerRegionName: serverRegionName, - ServerPrivateKeyPath: privateKeyPath, - STUNAddr: stunAddr, - URLs: urls, - Paths: paths, - AutoUpdate: autoUpdate, - UpdateFrequency: updateFrequency, + ServerEnabled: serverEnabled, + ServerRegionID: serverRegionID, + ServerRegionCode: serverRegionCode, + ServerRegionName: serverRegionName, + ServerVerifyClients: serverVerifyClients, + ServerPrivateKeyPath: privateKeyPath, + STUNAddr: stunAddr, + URLs: urls, + Paths: paths, + AutoUpdate: autoUpdate, + UpdateFrequency: updateFrequency, + IPv4: ipv4, + IPv6: ipv6, + AutomaticallyAddEmbeddedDerpRegion: automaticallyAddEmbeddedDerpRegion, } } -func GetLogTailConfig() LogTailConfig { +func logtailConfig() LogTailConfig { enabled := viper.GetBool("logtail.enabled") return LogTailConfig{ @@ -335,15 +613,17 @@ func GetLogTailConfig() LogTailConfig { } } -func GetACLConfig() ACLConfig { - policyPath := viper.GetString("acl_policy_path") +func policyConfig() PolicyConfig { + policyPath := viper.GetString("policy.path") + policyMode := viper.GetString("policy.mode") - return ACLConfig{ - PolicyPath: policyPath, + return PolicyConfig{ + Path: policyPath, + Mode: PolicyMode(policyMode), } } -func GetLogConfig() LogConfig { +func logConfig() LogConfig { logLevelStr := viper.GetString("log.level") logLevel, err := zerolog.ParseLevel(logLevelStr) if err != nil { @@ -353,14 +633,15 @@ func GetLogConfig() LogConfig { logFormatOpt := viper.GetString("log.format") var logFormat string switch logFormatOpt { - case "json": + case JSONLogFormat: logFormat = JSONLogFormat - case "text": + case TextLogFormat: logFormat = TextLogFormat case "": logFormat = TextLogFormat default: log.Error(). + Caller(). Str("func", "GetLogConfig"). Msgf("Could not parse log format: %s. Valid choices are 'json' or 'text'", logFormatOpt) } @@ -371,197 +652,310 @@ func GetLogConfig() LogConfig { } } -func GetDNSConfig() (*tailcfg.DNSConfig, string) { - if viper.IsSet("dns_config") { - dnsConfig := &tailcfg.DNSConfig{} +func databaseConfig() DatabaseConfig { + debug := viper.GetBool("database.debug") - overrideLocalDNS := viper.GetBool("dns_config.override_local_dns") + type_ := viper.GetString("database.type") - if viper.IsSet("dns_config.nameservers") { - nameserversStr := viper.GetStringSlice("dns_config.nameservers") + skipErrRecordNotFound := viper.GetBool("database.gorm.skip_err_record_not_found") + slowThreshold := viper.GetDuration("database.gorm.slow_threshold") * time.Millisecond + parameterizedQueries := viper.GetBool("database.gorm.parameterized_queries") + prepareStmt := viper.GetBool("database.gorm.prepare_stmt") - nameservers := []netip.Addr{} - resolvers := []*dnstype.Resolver{} - - for _, nameserverStr := range nameserversStr { - // Search for explicit DNS-over-HTTPS resolvers - if strings.HasPrefix(nameserverStr, "https://") { - resolvers = append(resolvers, &dnstype.Resolver{ - Addr: nameserverStr, - }) - - // This nameserver can not be parsed as an IP address - continue - } - - // Parse nameserver as a regular IP - nameserver, err := netip.ParseAddr(nameserverStr) - if err != nil { - log.Error(). - Str("func", "getDNSConfig"). - Err(err). 
- Msgf("Could not parse nameserver IP: %s", nameserverStr) - } - - nameservers = append(nameservers, nameserver) - resolvers = append(resolvers, &dnstype.Resolver{ - Addr: nameserver.String(), - }) - } - - dnsConfig.Nameservers = nameservers - - if overrideLocalDNS { - dnsConfig.Resolvers = resolvers - } else { - dnsConfig.FallbackResolvers = resolvers - } - } - - if viper.IsSet("dns_config.restricted_nameservers") { - dnsConfig.Routes = make(map[string][]*dnstype.Resolver) - domains := []string{} - restrictedDNS := viper.GetStringMapStringSlice( - "dns_config.restricted_nameservers", - ) - for domain, restrictedNameservers := range restrictedDNS { - restrictedResolvers := make( - []*dnstype.Resolver, - len(restrictedNameservers), - ) - for index, nameserverStr := range restrictedNameservers { - nameserver, err := netip.ParseAddr(nameserverStr) - if err != nil { - log.Error(). - Str("func", "getDNSConfig"). - Err(err). - Msgf("Could not parse restricted nameserver IP: %s", nameserverStr) - } - restrictedResolvers[index] = &dnstype.Resolver{ - Addr: nameserver.String(), - } - } - dnsConfig.Routes[domain] = restrictedResolvers - domains = append(domains, domain) - } - dnsConfig.Domains = domains - } - - if viper.IsSet("dns_config.domains") { - domains := viper.GetStringSlice("dns_config.domains") - if len(dnsConfig.Resolvers) > 0 { - dnsConfig.Domains = domains - } else if domains != nil { - log.Warn(). - Msg("Warning: dns_config.domains is set, but no nameservers are configured. Ignoring domains.") - } - } - - if viper.IsSet("dns_config.extra_records") { - var extraRecords []tailcfg.DNSRecord - - err := viper.UnmarshalKey("dns_config.extra_records", &extraRecords) - if err != nil { - log.Error(). - Str("func", "getDNSConfig"). - Err(err). - Msgf("Could not parse dns_config.extra_records") - } - - dnsConfig.ExtraRecords = extraRecords - } - - if viper.IsSet("dns_config.magic_dns") { - dnsConfig.Proxied = viper.GetBool("dns_config.magic_dns") - } - - var baseDomain string - if viper.IsSet("dns_config.base_domain") { - baseDomain = viper.GetString("dns_config.base_domain") - } else { - baseDomain = "headscale.net" // does not really matter when MagicDNS is not enabled - } - - log.Trace().Interface("dns_config", dnsConfig).Msg("DNS configuration loaded") - - return dnsConfig, baseDomain + switch type_ { + case DatabaseSqlite, DatabasePostgres: + break + case "sqlite": + type_ = "sqlite3" + default: + log.Fatal(). 
+ Msgf("invalid database type %q, must be sqlite, sqlite3 or postgres", type_) } - return nil, "" + return DatabaseConfig{ + Type: type_, + Debug: debug, + Gorm: GormConfig{ + Debug: debug, + SkipErrRecordNotFound: skipErrRecordNotFound, + SlowThreshold: slowThreshold, + ParameterizedQueries: parameterizedQueries, + PrepareStmt: prepareStmt, + }, + Sqlite: SqliteConfig{ + Path: util.AbsolutePathFromConfigPath( + viper.GetString("database.sqlite.path"), + ), + WriteAheadLog: viper.GetBool("database.sqlite.write_ahead_log"), + WALAutoCheckPoint: viper.GetInt("database.sqlite.wal_autocheckpoint"), + }, + Postgres: PostgresConfig{ + Host: viper.GetString("database.postgres.host"), + Port: viper.GetInt("database.postgres.port"), + Name: viper.GetString("database.postgres.name"), + User: viper.GetString("database.postgres.user"), + Pass: viper.GetString("database.postgres.pass"), + Ssl: viper.GetString("database.postgres.ssl"), + MaxOpenConnections: viper.GetInt("database.postgres.max_open_conns"), + MaxIdleConnections: viper.GetInt("database.postgres.max_idle_conns"), + ConnMaxIdleTimeSecs: viper.GetInt( + "database.postgres.conn_max_idle_time_secs", + ), + }, + } } -func GetHeadscaleConfig() (*Config, error) { - if IsCLIConfigured() { - return &Config{ - CLI: CLIConfig{ - Address: viper.GetString("cli.address"), - APIKey: viper.GetString("cli.api_key"), - Timeout: viper.GetDuration("cli.timeout"), - Insecure: viper.GetBool("cli.insecure"), - }, - }, nil - } +func dns() (DNSConfig, error) { + var dns DNSConfig - dnsConfig, baseDomain := GetDNSConfig() - derpConfig := GetDERPConfig() - logConfig := GetLogTailConfig() - randomizeClientPort := viper.GetBool("randomize_client_port") + // TODO: Use this instead of manually getting settings when + // UnmarshalKey is compatible with Environment Variables. + // err := viper.UnmarshalKey("dns", &dns) + // if err != nil { + // return DNSConfig{}, fmt.Errorf("unmarshalling dns config: %w", err) + // } - configuredPrefixes := viper.GetStringSlice("ip_prefixes") - parsedPrefixes := make([]netip.Prefix, 0, len(configuredPrefixes)+1) + dns.MagicDNS = viper.GetBool("dns.magic_dns") + dns.BaseDomain = viper.GetString("dns.base_domain") + dns.OverrideLocalDNS = viper.GetBool("dns.override_local_dns") + dns.Nameservers.Global = viper.GetStringSlice("dns.nameservers.global") + dns.Nameservers.Split = viper.GetStringMapStringSlice("dns.nameservers.split") + dns.SearchDomains = viper.GetStringSlice("dns.search_domains") + dns.ExtraRecordsPath = viper.GetString("dns.extra_records_path") - for i, prefixInConfig := range configuredPrefixes { - prefix, err := netip.ParsePrefix(prefixInConfig) + if viper.IsSet("dns.extra_records") { + var extraRecords []tailcfg.DNSRecord + + err := viper.UnmarshalKey("dns.extra_records", &extraRecords) if err != nil { - panic(fmt.Errorf("failed to parse ip_prefixes[%d]: %w", i, err)) + return DNSConfig{}, fmt.Errorf("unmarshalling dns extra records: %w", err) } - - if prefix.Addr().Is4() { - builder := netipx.IPSetBuilder{} - builder.AddPrefix(tsaddr.CGNATRange()) - ipSet, _ := builder.IPSet() - if !ipSet.ContainsPrefix(prefix) { - log.Warn(). - Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.", - prefixInConfig, tsaddr.CGNATRange()) - } - } - - if prefix.Addr().Is6() { - builder := netipx.IPSetBuilder{} - builder.AddPrefix(tsaddr.TailscaleULARange()) - ipSet, _ := builder.IPSet() - if !ipSet.ContainsPrefix(prefix) { - log.Warn(). - Msgf("Prefix %s is not in the %s range. 
This is an unsupported configuration.", - prefixInConfig, tsaddr.TailscaleULARange()) - } - } - - parsedPrefixes = append(parsedPrefixes, prefix) + dns.ExtraRecords = extraRecords } - prefixes := make([]netip.Prefix, 0, len(parsedPrefixes)) - { - // dedup - normalizedPrefixes := make(map[string]int, len(parsedPrefixes)) - for i, p := range parsedPrefixes { - normalized, _ := netipx.RangeOfPrefix(p).Prefix() - normalizedPrefixes[normalized.String()] = i + return dns, nil +} + +// globalResolvers returns the global DNS resolvers +// defined in the config file. +// If a nameserver is a valid IP, it will be used as a regular resolver. +// If a nameserver is a valid URL, it will be used as a DoH resolver. +// If a nameserver is neither a valid URL nor a valid IP, it will be ignored. +func (d *DNSConfig) globalResolvers() []*dnstype.Resolver { + var resolvers []*dnstype.Resolver + + for _, nsStr := range d.Nameservers.Global { + warn := "" + if _, err := netip.ParseAddr(nsStr); err == nil { + resolvers = append(resolvers, &dnstype.Resolver{ + Addr: nsStr, + }) + + continue + } else { + warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err) } - // convert back to list - for _, i := range normalizedPrefixes { - prefixes = append(prefixes, parsedPrefixes[i]) + if _, err := url.Parse(nsStr); err == nil { + resolvers = append(resolvers, &dnstype.Resolver{ + Addr: nsStr, + }) + + continue + } else { + warn = fmt.Sprintf("Invalid global nameserver %q. Parsing error: %s ignoring", nsStr, err) + } + + if warn != "" { + log.Warn().Msg(warn) } } - if len(prefixes) < 1 { - prefixes = append(prefixes, netip.MustParsePrefix("100.64.0.0/10")) + return resolvers +} + +// splitResolvers returns a map of domain to DNS resolvers. +// If a nameserver is a valid IP, it will be used as a regular resolver. +// If a nameserver is a valid URL, it will be used as a DoH resolver. +// If a nameserver is neither a valid URL nor a valid IP, it will be ignored. +func (d *DNSConfig) splitResolvers() map[string][]*dnstype.Resolver { + routes := make(map[string][]*dnstype.Resolver) + for domain, nameservers := range d.Nameservers.Split { + var resolvers []*dnstype.Resolver + for _, nsStr := range nameservers { + warn := "" + if _, err := netip.ParseAddr(nsStr); err == nil { + resolvers = append(resolvers, &dnstype.Resolver{ + Addr: nsStr, + }) + + continue + } else { + warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err) + } + + if _, err := url.Parse(nsStr); err == nil { + resolvers = append(resolvers, &dnstype.Resolver{ + Addr: nsStr, + }) + + continue + } else { + warn = fmt.Sprintf("Invalid split dns nameserver %q. Parsing error: %s ignoring", nsStr, err) + } + + if warn != "" { + log.Warn().Msg(warn) + } + } + routes[domain] = resolvers + } + + return routes +} + +func dnsToTailcfgDNS(dns DNSConfig) *tailcfg.DNSConfig { + cfg := tailcfg.DNSConfig{} + + if dns.BaseDomain == "" && dns.MagicDNS { + log.Fatal().Msg("dns.base_domain must be set when using MagicDNS (dns.magic_dns)") + } + + cfg.Proxied = dns.MagicDNS + cfg.ExtraRecords = dns.ExtraRecords + if dns.OverrideLocalDNS { + cfg.Resolvers = dns.globalResolvers() + } else { + cfg.FallbackResolvers = dns.globalResolvers() + } + + routes := dns.splitResolvers() + cfg.Routes = routes + if dns.BaseDomain != "" { + cfg.Domains = []string{dns.BaseDomain} + } + cfg.Domains = append(cfg.Domains, dns.SearchDomains...) 
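+	// Domains now holds BaseDomain first (when set), followed by the configured search domains.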
+ + return &cfg +} + +func prefixV4() (*netip.Prefix, error) { + prefixV4Str := viper.GetString("prefixes.v4") + + if prefixV4Str == "" { + return nil, nil + } + + prefixV4, err := netip.ParsePrefix(prefixV4Str) + if err != nil { + return nil, fmt.Errorf("parsing IPv4 prefix from config: %w", err) + } + + builder := netipx.IPSetBuilder{} + builder.AddPrefix(tsaddr.CGNATRange()) + ipSet, _ := builder.IPSet() + if !ipSet.ContainsPrefix(prefixV4) { log.Warn(). - Msgf("'ip_prefixes' not configured, falling back to default: %v", prefixes) + Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.", + prefixV4Str, tsaddr.CGNATRange()) } + return &prefixV4, nil +} + +func prefixV6() (*netip.Prefix, error) { + prefixV6Str := viper.GetString("prefixes.v6") + + if prefixV6Str == "" { + return nil, nil + } + + prefixV6, err := netip.ParsePrefix(prefixV6Str) + if err != nil { + return nil, fmt.Errorf("parsing IPv6 prefix from config: %w", err) + } + + builder := netipx.IPSetBuilder{} + builder.AddPrefix(tsaddr.TailscaleULARange()) + ipSet, _ := builder.IPSet() + + if !ipSet.ContainsPrefix(prefixV6) { + log.Warn(). + Msgf("Prefix %s is not in the %s range. This is an unsupported configuration.", + prefixV6Str, tsaddr.TailscaleULARange()) + } + + return &prefixV6, nil +} + +// LoadCLIConfig returns the needed configuration for the CLI client +// of Headscale to connect to a Headscale server. +func LoadCLIConfig() (*Config, error) { + logConfig := logConfig() + zerolog.SetGlobalLevel(logConfig.Level) + + return &Config{ + DisableUpdateCheck: viper.GetBool("disable_check_updates"), + UnixSocket: viper.GetString("unix_socket"), + CLI: CLIConfig{ + Address: viper.GetString("cli.address"), + APIKey: viper.GetString("cli.api_key"), + Timeout: viper.GetDuration("cli.timeout"), + Insecure: viper.GetBool("cli.insecure"), + }, + Log: logConfig, + }, nil +} + +// LoadServerConfig returns the full Headscale configuration to +// host a Headscale server. This is called as part of `headscale serve`. 
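+// LoadConfig must have been called beforehand; the settings are checked with
+// validateServerConfig before the Config struct is assembled.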
+func LoadServerConfig() (*Config, error) { + if err := validateServerConfig(); err != nil { + return nil, err + } + + logConfig := logConfig() + zerolog.SetGlobalLevel(logConfig.Level) + + prefix4, err := prefixV4() + if err != nil { + return nil, err + } + + prefix6, err := prefixV6() + if err != nil { + return nil, err + } + + if prefix4 == nil && prefix6 == nil { + return nil, errors.New("no IPv4 or IPv6 prefix configured, minimum one prefix is required") + } + + allocStr := viper.GetString("prefixes.allocation") + var alloc IPAllocationStrategy + switch allocStr { + case string(IPAllocationStrategySequential): + alloc = IPAllocationStrategySequential + case string(IPAllocationStrategyRandom): + alloc = IPAllocationStrategyRandom + default: + return nil, fmt.Errorf( + "config error, prefixes.allocation is set to %s, which is not a valid strategy, allowed options: %s, %s", + allocStr, + IPAllocationStrategySequential, + IPAllocationStrategyRandom, + ) + } + + dnsConfig, err := dns() + if err != nil { + return nil, err + } + + derpConfig := derpConfig() + logTailConfig := logtailConfig() + randomizeClientPort := viper.GetBool("randomize_client_port") + oidcClientSecret := viper.GetString("oidc.client_secret") oidcClientSecretPath := viper.GetString("oidc.client_secret_path") if oidcClientSecretPath != "" && oidcClientSecret != "" { @@ -572,22 +966,40 @@ func GetHeadscaleConfig() (*Config, error) { if err != nil { return nil, err } - oidcClientSecret = string(secretBytes) + oidcClientSecret = strings.TrimSpace(string(secretBytes)) + } + + serverURL := viper.GetString("server_url") + + // BaseDomain cannot be the same as the server URL. + // This is because Tailscale takes over the domain in BaseDomain, + // causing the headscale server and DERP to be unreachable. + // For Tailscale upstream, the following is true: + // - DERP run on their own domains + // - Control plane runs on login.tailscale.com/controlplane.tailscale.com + // - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. 
tail-scale.ts.net) + if dnsConfig.BaseDomain != "" { + if err := isSafeServerURL(serverURL, dnsConfig.BaseDomain); err != nil { + return nil, err + } } return &Config{ - ServerURL: viper.GetString("server_url"), + ServerURL: serverURL, Addr: viper.GetString("listen_addr"), MetricsAddr: viper.GetString("metrics_listen_addr"), GRPCAddr: viper.GetString("grpc_listen_addr"), GRPCAllowInsecure: viper.GetBool("grpc_allow_insecure"), - DisableUpdateCheck: viper.GetBool("disable_check_updates"), + DisableUpdateCheck: false, + + PrefixV4: prefix4, + PrefixV6: prefix6, + IPAllocation: IPAllocationStrategy(alloc), - IPPrefixes: prefixes, NoisePrivateKeyPath: util.AbsolutePathFromConfigPath( viper.GetString("noise.private_key_path"), ), - BaseDomain: baseDomain, + BaseDomain: dnsConfig.BaseDomain, DERP: derpConfig, @@ -595,22 +1007,12 @@ func GetHeadscaleConfig() (*Config, error) { "ephemeral_node_inactivity_timeout", ), - NodeUpdateCheckInterval: viper.GetDuration( - "node_update_check_interval", - ), + Database: databaseConfig(), - DBtype: viper.GetString("db_type"), - DBpath: util.AbsolutePathFromConfigPath(viper.GetString("db_path")), - DBhost: viper.GetString("db_host"), - DBport: viper.GetInt("db_port"), - DBname: viper.GetString("db_name"), - DBuser: viper.GetString("db_user"), - DBpass: viper.GetString("db_pass"), - DBssl: viper.GetString("db_ssl"), + TLS: tlsConfig(), - TLS: GetTLSConfig(), - - DNSConfig: dnsConfig, + DNSConfig: dnsConfig, + TailcfgDNSConfig: dnsToTailcfgDNS(dnsConfig), ACMEEmail: viper.GetString("acme_email"), ACMEURL: viper.GetString("acme_url"), @@ -622,15 +1024,15 @@ func GetHeadscaleConfig() (*Config, error) { OnlyStartIfOIDCIsAvailable: viper.GetBool( "oidc.only_start_if_oidc_is_available", ), - Issuer: viper.GetString("oidc.issuer"), - ClientID: viper.GetString("oidc.client_id"), - ClientSecret: oidcClientSecret, - Scope: viper.GetStringSlice("oidc.scope"), - ExtraParams: viper.GetStringMapString("oidc.extra_params"), - AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), - AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), - AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"), - StripEmaildomain: viper.GetBool("oidc.strip_email_domain"), + Issuer: viper.GetString("oidc.issuer"), + ClientID: viper.GetString("oidc.client_id"), + ClientSecret: oidcClientSecret, + Scope: viper.GetStringSlice("oidc.scope"), + ExtraParams: viper.GetStringMapString("oidc.extra_params"), + AllowedDomains: viper.GetStringSlice("oidc.allowed_domains"), + AllowedUsers: viper.GetStringSlice("oidc.allowed_users"), + AllowedGroups: viper.GetStringSlice("oidc.allowed_groups"), + EmailVerifiedRequired: viper.GetBool("oidc.email_verified_required"), Expiry: func() time.Duration { // if set to 0, we assume no expiry if value := viper.GetString("oidc.expiry"); value == "0" { @@ -647,12 +1049,19 @@ func GetHeadscaleConfig() (*Config, error) { } }(), UseExpiryFromToken: viper.GetBool("oidc.use_expiry_from_token"), + PKCE: PKCEConfig{ + Enabled: viper.GetBool("oidc.pkce.enabled"), + Method: viper.GetString("oidc.pkce.method"), + }, }, - LogTail: logConfig, + LogTail: logTailConfig, RandomizeClientPort: randomizeClientPort, + Taildrop: TaildropConfig{ + Enabled: viper.GetBool("taildrop.enabled"), + }, - ACL: GetACLConfig(), + Policy: policyConfig(), CLI: CLIConfig{ Address: viper.GetString("cli.address"), @@ -661,10 +1070,158 @@ func GetHeadscaleConfig() (*Config, error) { Insecure: viper.GetBool("cli.insecure"), }, - Log: GetLogConfig(), + Log: logConfig, + + Tuning: Tuning{ + 
NotifierSendTimeout: viper.GetDuration("tuning.notifier_send_timeout"), + BatchChangeDelay: viper.GetDuration("tuning.batch_change_delay"), + NodeMapSessionBufferedChanSize: viper.GetInt( + "tuning.node_mapsession_buffered_chan_size", + ), + BatcherWorkers: func() int { + if workers := viper.GetInt("tuning.batcher_workers"); workers > 0 { + return workers + } + return DefaultBatcherWorkers() + }(), + RegisterCacheCleanup: viper.GetDuration("tuning.register_cache_cleanup"), + RegisterCacheExpiration: viper.GetDuration("tuning.register_cache_expiration"), + NodeStoreBatchSize: viper.GetInt("tuning.node_store_batch_size"), + NodeStoreBatchTimeout: viper.GetDuration("tuning.node_store_batch_timeout"), + }, }, nil } -func IsCLIConfigured() bool { - return viper.GetString("cli.address") != "" && viper.GetString("cli.api_key") != "" +// BaseDomain cannot be a suffix of the server URL. +// This is because Tailscale takes over the domain in BaseDomain, +// causing the headscale server and DERP to be unreachable. +// For Tailscale upstream, the following is true: +// - DERP run on their own domains. +// - Control plane runs on login.tailscale.com/controlplane.tailscale.com. +// - MagicDNS (BaseDomain) for users is on a *.ts.net domain per tailnet (e.g. tail-scale.ts.net). +func isSafeServerURL(serverURL, baseDomain string) error { + server, err := url.Parse(serverURL) + if err != nil { + return err + } + + if server.Hostname() == baseDomain { + return errServerURLSame + } + + serverDomainParts := strings.Split(server.Host, ".") + baseDomainParts := strings.Split(baseDomain, ".") + + if len(serverDomainParts) <= len(baseDomainParts) { + return nil + } + + s := len(serverDomainParts) + b := len(baseDomainParts) + for i := range baseDomainParts { + if serverDomainParts[s-i-1] != baseDomainParts[b-i-1] { + return nil + } + } + + return errServerURLSuffix +} + +type deprecator struct { + warns set.Set[string] + fatals set.Set[string] +} + +// warnWithAlias will register an alias between the newKey and the oldKey, +// and log a deprecation warning if the oldKey is set. +func (d *deprecator) warnWithAlias(newKey, oldKey string) { + // NOTE: RegisterAlias is called with NEW KEY -> OLD KEY + viper.RegisterAlias(newKey, oldKey) + if viper.IsSet(oldKey) { + d.warns.Add( + fmt.Sprintf( + "The %q configuration key is deprecated. Please use %q instead. %q will be removed in the future.", + oldKey, + newKey, + oldKey, + ), + ) + } +} + +// fatal deprecates and adds an entry to the fatal list of options if the oldKey is set. +func (d *deprecator) fatal(oldKey string) { + if viper.IsSet(oldKey) { + d.fatals.Add( + fmt.Sprintf( + "The %q configuration key has been removed. Please see the changelog for more details.", + oldKey, + ), + ) + } +} + +// fatalIfNewKeyIsNotUsed deprecates and adds an entry to the fatal list of options if the oldKey is set and the new key is _not_ set. +// If the new key is set, a warning is emitted instead. +func (d *deprecator) fatalIfNewKeyIsNotUsed(newKey, oldKey string) { + if viper.IsSet(oldKey) && !viper.IsSet(newKey) { + d.fatals.Add( + fmt.Sprintf( + "The %q configuration key is deprecated. Please use %q instead. %q has been removed.", + oldKey, + newKey, + oldKey, + ), + ) + } else if viper.IsSet(oldKey) { + d.warns.Add(fmt.Sprintf("The %q configuration key is deprecated. Please use %q instead. %q has been removed.", oldKey, newKey, oldKey)) + } +} + +// warn deprecates and adds an option to log a warning if the oldKey is set. 
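+// Unlike warnWithAlias, it does not register a viper alias for the new key; the old key only produces a warning.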
+func (d *deprecator) warnNoAlias(newKey, oldKey string) { + if viper.IsSet(oldKey) { + d.warns.Add( + fmt.Sprintf( + "The %q configuration key is deprecated. Please use %q instead. %q has been removed.", + oldKey, + newKey, + oldKey, + ), + ) + } +} + +// warn deprecates and adds an entry to the warn list of options if the oldKey is set. +func (d *deprecator) warn(oldKey string) { + if viper.IsSet(oldKey) { + d.warns.Add( + fmt.Sprintf( + "The %q configuration key is deprecated and has been removed. Please see the changelog for more details.", + oldKey, + ), + ) + } +} + +func (d *deprecator) String() string { + var b strings.Builder + + for _, w := range d.warns.Slice() { + fmt.Fprintf(&b, "WARN: %s\n", w) + } + + for _, f := range d.fatals.Slice() { + fmt.Fprintf(&b, "FATAL: %s\n", f) + } + + return b.String() +} + +func (d *deprecator) Log() { + if len(d.fatals) > 0 { + log.Fatal().Msg("\n" + d.String()) + } else if len(d.warns) > 0 { + log.Warn().Msg("\n" + d.String()) + } } diff --git a/hscontrol/types/config_test.go b/hscontrol/types/config_test.go new file mode 100644 index 00000000..6b9fc2ef --- /dev/null +++ b/hscontrol/types/config_test.go @@ -0,0 +1,469 @@ +package types + +import ( + "fmt" + "os" + "path/filepath" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + "github.com/spf13/viper" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/dnstype" +) + +func TestReadConfig(t *testing.T) { + tests := []struct { + name string + configPath string + setup func(*testing.T) (any, error) + want any + wantErr string + }{ + { + name: "unmarshal-dns-full-config", + configPath: "testdata/dns_full.yaml", + setup: func(t *testing.T) (any, error) { + dns, err := dns() + if err != nil { + return nil, err + } + + return dns, nil + }, + want: DNSConfig{ + MagicDNS: true, + BaseDomain: "example.com", + OverrideLocalDNS: false, + Nameservers: Nameservers{ + Global: []string{ + "1.1.1.1", + "1.0.0.1", + "2606:4700:4700::1111", + "2606:4700:4700::1001", + "https://dns.nextdns.io/abc123", + }, + Split: map[string][]string{ + "darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, + "foo.bar.com": {"1.1.1.1"}, + }, + }, + ExtraRecords: []tailcfg.DNSRecord{ + {Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"}, + {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"}, + }, + SearchDomains: []string{"test.com", "bar.com"}, + }, + }, + { + name: "dns-to-tailcfg.DNSConfig", + configPath: "testdata/dns_full.yaml", + setup: func(t *testing.T) (any, error) { + dns, err := dns() + if err != nil { + return nil, err + } + + return dnsToTailcfgDNS(dns), nil + }, + want: &tailcfg.DNSConfig{ + Proxied: true, + Domains: []string{"example.com", "test.com", "bar.com"}, + FallbackResolvers: []*dnstype.Resolver{ + {Addr: "1.1.1.1"}, + {Addr: "1.0.0.1"}, + {Addr: "2606:4700:4700::1111"}, + {Addr: "2606:4700:4700::1001"}, + {Addr: "https://dns.nextdns.io/abc123"}, + }, + Routes: map[string][]*dnstype.Resolver{ + "darp.headscale.net": {{Addr: "1.1.1.1"}, {Addr: "8.8.8.8"}}, + "foo.bar.com": {{Addr: "1.1.1.1"}}, + }, + ExtraRecords: []tailcfg.DNSRecord{ + {Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"}, + {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"}, + }, + }, + }, + { + name: "unmarshal-dns-full-no-magic", + configPath: "testdata/dns_full_no_magic.yaml", + setup: func(t *testing.T) (any, error) { + dns, err := dns() + if err != nil { + 
return nil, err + } + + return dns, nil + }, + want: DNSConfig{ + MagicDNS: false, + BaseDomain: "example.com", + OverrideLocalDNS: false, + Nameservers: Nameservers{ + Global: []string{ + "1.1.1.1", + "1.0.0.1", + "2606:4700:4700::1111", + "2606:4700:4700::1001", + "https://dns.nextdns.io/abc123", + }, + Split: map[string][]string{ + "darp.headscale.net": {"1.1.1.1", "8.8.8.8"}, + "foo.bar.com": {"1.1.1.1"}, + }, + }, + ExtraRecords: []tailcfg.DNSRecord{ + {Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"}, + {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"}, + }, + SearchDomains: []string{"test.com", "bar.com"}, + }, + }, + { + name: "dns-to-tailcfg.DNSConfig", + configPath: "testdata/dns_full_no_magic.yaml", + setup: func(t *testing.T) (any, error) { + dns, err := dns() + if err != nil { + return nil, err + } + + return dnsToTailcfgDNS(dns), nil + }, + want: &tailcfg.DNSConfig{ + Proxied: false, + Domains: []string{"example.com", "test.com", "bar.com"}, + FallbackResolvers: []*dnstype.Resolver{ + {Addr: "1.1.1.1"}, + {Addr: "1.0.0.1"}, + {Addr: "2606:4700:4700::1111"}, + {Addr: "2606:4700:4700::1001"}, + {Addr: "https://dns.nextdns.io/abc123"}, + }, + Routes: map[string][]*dnstype.Resolver{ + "darp.headscale.net": {{Addr: "1.1.1.1"}, {Addr: "8.8.8.8"}}, + "foo.bar.com": {{Addr: "1.1.1.1"}}, + }, + ExtraRecords: []tailcfg.DNSRecord{ + {Name: "grafana.myvpn.example.com", Type: "A", Value: "100.64.0.3"}, + {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"}, + }, + }, + }, + { + name: "base-domain-in-server-url-err", + configPath: "testdata/base-domain-in-server-url.yaml", + setup: func(t *testing.T) (any, error) { + return LoadServerConfig() + }, + want: nil, + wantErr: errServerURLSuffix.Error(), + }, + { + name: "base-domain-not-in-server-url", + configPath: "testdata/base-domain-not-in-server-url.yaml", + setup: func(t *testing.T) (any, error) { + cfg, err := LoadServerConfig() + if err != nil { + return nil, err + } + + return map[string]string{ + "server_url": cfg.ServerURL, + "base_domain": cfg.BaseDomain, + }, err + }, + want: map[string]string{ + "server_url": "https://derp.no", + "base_domain": "clients.derp.no", + }, + wantErr: "", + }, + { + name: "dns-override-true-errors", + configPath: "testdata/dns-override-true-error.yaml", + setup: func(t *testing.T) (any, error) { + return LoadServerConfig() + }, + wantErr: "Fatal config error: dns.nameservers.global must be set when dns.override_local_dns is true", + }, + { + name: "dns-override-true", + configPath: "testdata/dns-override-true.yaml", + setup: func(t *testing.T) (any, error) { + _, err := LoadServerConfig() + if err != nil { + return nil, err + } + + dns, err := dns() + if err != nil { + return nil, err + } + + return dnsToTailcfgDNS(dns), nil + }, + want: &tailcfg.DNSConfig{ + Proxied: true, + Domains: []string{"derp2.no"}, + Routes: map[string][]*dnstype.Resolver{}, + Resolvers: []*dnstype.Resolver{ + {Addr: "1.1.1.1"}, + {Addr: "1.0.0.1"}, + }, + }, + }, + { + name: "policy-path-is-loaded", + configPath: "testdata/policy-path-is-loaded.yaml", + setup: func(t *testing.T) (any, error) { + cfg, err := LoadServerConfig() + if err != nil { + return nil, err + } + + return map[string]string{ + "policy.mode": string(cfg.Policy.Mode), + "policy.path": cfg.Policy.Path, + }, err + }, + want: map[string]string{ + "policy.mode": "file", + "policy.path": "/etc/policy.hujson", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + viper.Reset() + 
err := LoadConfig(tt.configPath, true) + require.NoError(t, err) + + conf, err := tt.setup(t) + + if tt.wantErr != "" { + assert.Equal(t, tt.wantErr, err.Error()) + + return + } + + require.NoError(t, err) + + if diff := cmp.Diff(tt.want, conf); diff != "" { + t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestReadConfigFromEnv(t *testing.T) { + tests := []struct { + name string + configEnv map[string]string + setup func(*testing.T) (any, error) + want any + }{ + { + name: "test-random-base-settings-with-env", + configEnv: map[string]string{ + "HEADSCALE_LOG_LEVEL": "trace", + "HEADSCALE_DATABASE_SQLITE_WRITE_AHEAD_LOG": "false", + "HEADSCALE_PREFIXES_V4": "100.64.0.0/10", + }, + setup: func(t *testing.T) (any, error) { + t.Logf("all settings: %#v", viper.AllSettings()) + + assert.Equal(t, "trace", viper.GetString("log.level")) + assert.Equal(t, "100.64.0.0/10", viper.GetString("prefixes.v4")) + assert.False(t, viper.GetBool("database.sqlite.write_ahead_log")) + + return nil, nil + }, + want: nil, + }, + { + name: "unmarshal-dns-full-config", + configEnv: map[string]string{ + "HEADSCALE_DNS_MAGIC_DNS": "true", + "HEADSCALE_DNS_BASE_DOMAIN": "example.com", + "HEADSCALE_DNS_OVERRIDE_LOCAL_DNS": "false", + "HEADSCALE_DNS_NAMESERVERS_GLOBAL": `1.1.1.1 8.8.8.8`, + "HEADSCALE_DNS_SEARCH_DOMAINS": "test.com bar.com", + + // TODO(kradalby): Figure out how to pass these as env vars + // "HEADSCALE_DNS_NAMESERVERS_SPLIT": `{foo.bar.com: ["1.1.1.1"]}`, + // "HEADSCALE_DNS_EXTRA_RECORDS": `[{ name: "prometheus.myvpn.example.com", type: "A", value: "100.64.0.4" }]`, + }, + setup: func(t *testing.T) (any, error) { + t.Logf("all settings: %#v", viper.AllSettings()) + + dns, err := dns() + if err != nil { + return nil, err + } + + return dns, nil + }, + want: DNSConfig{ + MagicDNS: true, + BaseDomain: "example.com", + OverrideLocalDNS: false, + Nameservers: Nameservers{ + Global: []string{"1.1.1.1", "8.8.8.8"}, + Split: map[string][]string{ + // "foo.bar.com": {"1.1.1.1"}, + }, + }, + // ExtraRecords: []tailcfg.DNSRecord{ + // {Name: "prometheus.myvpn.example.com", Type: "A", Value: "100.64.0.4"}, + // }, + SearchDomains: []string{"test.com", "bar.com"}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + for k, v := range tt.configEnv { + t.Setenv(k, v) + } + + viper.Reset() + err := LoadConfig("testdata/minimal.yaml", true) + require.NoError(t, err) + + conf, err := tt.setup(t) + require.NoError(t, err) + + if diff := cmp.Diff(tt.want, conf, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("ReadConfig() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestTLSConfigValidation(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "headscale") + if err != nil { + t.Fatal(err) + } + // defer os.RemoveAll(tmpDir) + configYaml := []byte(`--- +tls_letsencrypt_hostname: example.com +tls_letsencrypt_challenge_type: "" +tls_cert_path: abc.pem +noise: + private_key_path: noise_private.key`) + + // Populate a custom config file + configFilePath := filepath.Join(tmpDir, "config.yaml") + err = os.WriteFile(configFilePath, configYaml, 0o600) + if err != nil { + t.Fatalf("Couldn't write file %s", configFilePath) + } + + // Check configuration validation errors (1) + err = LoadConfig(tmpDir, false) + require.NoError(t, err) + + err = validateServerConfig() + require.Error(t, err) + assert.Contains( + t, + err.Error(), + "Fatal config error: set either tls_letsencrypt_hostname or tls_cert_path/tls_key_path, not both", + ) + assert.Contains( + t, + 
err.Error(), + "Fatal config error: the only supported values for tls_letsencrypt_challenge_type are", + ) + assert.Contains( + t, + err.Error(), + "Fatal config error: server_url must start with https:// or http://", + ) + + // Check configuration validation errors (2) + configYaml = []byte(`--- +noise: + private_key_path: noise_private.key +server_url: http://127.0.0.1:8080 +tls_letsencrypt_hostname: example.com +tls_letsencrypt_challenge_type: TLS-ALPN-01 +`) + err = os.WriteFile(configFilePath, configYaml, 0o600) + if err != nil { + t.Fatalf("Couldn't write file %s", configFilePath) + } + err = LoadConfig(tmpDir, false) + require.NoError(t, err) +} + +// OK +// server_url: headscale.com, base: clients.headscale.com +// server_url: headscale.com, base: headscale.net +// +// NOT OK +// server_url: server.headscale.com, base: headscale.com. +func TestSafeServerURL(t *testing.T) { + tests := []struct { + serverURL, baseDomain, + wantErr string + }{ + { + serverURL: "https://example.com", + baseDomain: "example.org", + }, + { + serverURL: "https://headscale.com", + baseDomain: "headscale.com", + wantErr: errServerURLSame.Error(), + }, + { + serverURL: "https://headscale.com", + baseDomain: "clients.headscale.com", + }, + { + serverURL: "https://headscale.com", + baseDomain: "clients.subdomain.headscale.com", + }, + { + serverURL: "https://headscale.kristoffer.com", + baseDomain: "mybase", + }, + { + serverURL: "https://server.headscale.com", + baseDomain: "headscale.com", + wantErr: errServerURLSuffix.Error(), + }, + { + serverURL: "https://server.subdomain.headscale.com", + baseDomain: "headscale.com", + wantErr: errServerURLSuffix.Error(), + }, + { + serverURL: "http://foo\x00", + wantErr: `parse "http://foo\x00": net/url: invalid control character in URL`, + }, + } + + for _, tt := range tests { + testName := fmt.Sprintf("server=%s domain=%s", tt.serverURL, tt.baseDomain) + t.Run(testName, func(t *testing.T) { + err := isSafeServerURL(tt.serverURL, tt.baseDomain) + if tt.wantErr != "" { + assert.EqualError(t, err, tt.wantErr) + + return + } + assert.NoError(t, err) + }) + } +} diff --git a/hscontrol/types/const.go b/hscontrol/types/const.go index e718eb2e..019c14b6 100644 --- a/hscontrol/types/const.go +++ b/hscontrol/types/const.go @@ -3,7 +3,7 @@ package types import "time" const ( - HTTPReadTimeout = 30 * time.Second + HTTPTimeout = 30 * time.Second HTTPShutdownTimeout = 3 * time.Second TLSALPN01ChallengeType = "TLS-ALPN-01" HTTP01ChallengeType = "HTTP-01" diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 9b2ba769..41cd9759 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -1,23 +1,25 @@ package types import ( - "database/sql/driver" - "encoding/json" "errors" "fmt" "net/netip" - "sort" + "regexp" + "slices" + "strconv" "strings" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" "go4.org/netipx" "google.golang.org/protobuf/types/known/timestamppb" - "gorm.io/gorm" + "tailscale.com/net/tsaddr" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/views" ) var ( @@ -25,48 +27,69 @@ var ( ErrHostnameTooLong = errors.New("hostname too long, cannot except 255 ASCII chars") ErrNodeHasNoGivenName = errors.New("node has no given name") ErrNodeUserHasNoName = errors.New("node user has no name") + ErrCannotRemoveAllTags = errors.New("cannot remove all tags from node") + ErrInvalidNodeView = 
errors.New("cannot convert invalid NodeView to tailcfg.Node") + + invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") ) +// RouteFunc is a function that takes a node ID and returns a list of +// netip.Prefixes representing the primary routes for that node. +type RouteFunc func(id NodeID) []netip.Prefix + +type ( + NodeID uint64 + NodeIDs []NodeID +) + +func (n NodeIDs) Len() int { return len(n) } +func (n NodeIDs) Less(i, j int) bool { return n[i] < n[j] } +func (n NodeIDs) Swap(i, j int) { n[i], n[j] = n[j], n[i] } + +func (id NodeID) StableID() tailcfg.StableNodeID { + return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10)) +} + +func (id NodeID) NodeID() tailcfg.NodeID { + return tailcfg.NodeID(id) +} + +func (id NodeID) Uint64() uint64 { + return uint64(id) +} + +func (id NodeID) String() string { + return strconv.FormatUint(id.Uint64(), util.Base10) +} + +func ParseNodeID(s string) (NodeID, error) { + id, err := strconv.ParseUint(s, util.Base10, 64) + return NodeID(id), err +} + +func MustParseNodeID(s string) NodeID { + id, err := ParseNodeID(s) + if err != nil { + panic(err) + } + + return id +} + // Node is a Headscale client. type Node struct { - ID uint64 `gorm:"primary_key"` + ID NodeID `gorm:"primary_key"` - // MachineKeyDatabaseField is the string representation of MachineKey - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use MachineKey instead. - MachineKeyDatabaseField string `gorm:"column:machine_key;unique_index"` - MachineKey key.MachinePublic `gorm:"-"` + MachineKey key.MachinePublic `gorm:"serializer:text"` + NodeKey key.NodePublic `gorm:"serializer:text"` + DiscoKey key.DiscoPublic `gorm:"serializer:text"` - // NodeKeyDatabaseField is the string representation of NodeKey - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use NodeKey instead. - NodeKeyDatabaseField string `gorm:"column:node_key"` - NodeKey key.NodePublic `gorm:"-"` + Endpoints []netip.AddrPort `gorm:"serializer:json"` - // DiscoKeyDatabaseField is the string representation of DiscoKey - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use DiscoKey instead. - DiscoKeyDatabaseField string `gorm:"column:disco_key"` - DiscoKey key.DiscoPublic `gorm:"-"` + Hostinfo *tailcfg.Hostinfo `gorm:"column:host_info;serializer:json"` - // EndpointsDatabaseField is the string list representation of Endpoints - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use Endpoints instead. - EndpointsDatabaseField StringList `gorm:"column:endpoints"` - Endpoints []netip.AddrPort `gorm:"-"` - - // EndpointsDatabaseField is the string list representation of Endpoints - // it is _only_ used for reading and writing the key to the - // database and should not be used. - // Use Endpoints instead. - HostinfoDatabaseField string `gorm:"column:host_info"` - Hostinfo *tailcfg.Hostinfo `gorm:"-"` - - IPAddresses NodeAddresses + IPv4 *netip.Addr `gorm:"column:ipv4;serializer:text"` + IPv6 *netip.Addr `gorm:"column:ipv6;serializer:text"` // Hostname represents the name given by the Tailscale // client during registration @@ -79,21 +102,40 @@ type Node struct { // GivenName is the name used in all DNS related // parts of headscale. 
GivenName string `gorm:"type:varchar(63);unique_index"` - UserID uint - User User `gorm:"foreignKey:UserID"` + + // UserID is set for ALL nodes (tagged and user-owned) to track "created by". + // For tagged nodes, this is informational only - the tag is the owner. + // For user-owned nodes, this identifies the owner. + // Only nil for orphaned nodes (should not happen in normal operation). + UserID *uint + User *User `gorm:"constraint:OnDelete:CASCADE;"` RegisterMethod string - ForcedTags StringList + // Tags is the definitive owner for tagged nodes. + // When non-empty, the node is "tagged" and tags define its identity. + // Empty for user-owned nodes. + // Tags cannot be removed once set (one-way transition). + Tags []string `gorm:"column:tags;serializer:json"` - // TODO(kradalby): This seems like irrelevant information? - AuthKeyID uint + // When a node has been created with a PreAuthKey, we need to + // prevent the preauthkey from being deleted before the node. + // The preauthkey can define "tags" of the node so we need it + // around. + AuthKeyID *uint64 `sql:"DEFAULT:NULL"` AuthKey *PreAuthKey - LastSeen *time.Time - Expiry *time.Time + Expiry *time.Time - Routes []Route + // LastSeen is when the node was last in contact with + // headscale. It is best effort and not persisted. + LastSeen *time.Time `gorm:"column:last_seen"` + + // ApprovedRoutes is a list of routes that the node is allowed to announce + // as a subnet router. They are not necessarily the routes that the node + // announces at the moment. + // See [Node.Hostinfo] + ApprovedRoutes []netip.Prefix `gorm:"column:approved_routes;serializer:json"` CreatedAt time.Time UpdatedAt time.Time @@ -102,103 +144,35 @@ type Node struct { IsOnline *bool `gorm:"-"` } -type ( - Nodes []*Node -) +type Nodes []*Node -type NodeAddresses []netip.Addr - -func (na NodeAddresses) Sort() { - sort.Slice(na, func(index1, index2 int) bool { - if na[index1].Is4() && na[index2].Is6() { - return true - } - if na[index1].Is6() && na[index2].Is4() { - return false - } - - return na[index1].Compare(na[index2]) < 0 - }) -} - -func (na NodeAddresses) StringSlice() []string { - na.Sort() - strSlice := make([]string, 0, len(na)) - for _, addr := range na { - strSlice = append(strSlice, addr.String()) +func (ns Nodes) ViewSlice() views.Slice[NodeView] { + vs := make([]NodeView, len(ns)) + for i, n := range ns { + vs[i] = n.View() } - return strSlice + return views.SliceOf(vs) } -func (na NodeAddresses) Prefixes() []netip.Prefix { - addrs := []netip.Prefix{} - for _, nodeAddress := range na { - ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen()) - addrs = append(addrs, ip) - } - - return addrs -} - -func (na NodeAddresses) InIPSet(set *netipx.IPSet) bool { - for _, nodeAddr := range na { - if set.Contains(nodeAddr) { - return true - } - } - - return false -} - -// AppendToIPSet adds the individual ips in NodeAddresses to a -// given netipx.IPSetBuilder. 
-func (na NodeAddresses) AppendToIPSet(build *netipx.IPSetBuilder) { - for _, ip := range na { - build.Add(ip) - } -} - -func (na *NodeAddresses) Scan(destination interface{}) error { - switch value := destination.(type) { - case string: - addresses := strings.Split(value, ",") - *na = (*na)[:0] - for _, addr := range addresses { - if len(addr) < 1 { - continue - } - parsed, err := netip.ParseAddr(addr) - if err != nil { - return err - } - *na = append(*na, parsed) - } - - return nil - - default: - return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination) - } -} - -// Value return json value, implement driver.Valuer interface. -func (na NodeAddresses) Value() (driver.Value, error) { - addresses := strings.Join(na.StringSlice(), ",") - - return addresses, nil +// GivenNameHasBeenChanged returns whether the `givenName` can be automatically changed based on the `Hostname` of the node. +func (node *Node) GivenNameHasBeenChanged() bool { + // Strip invalid DNS characters for givenName comparison + normalised := strings.ToLower(node.Hostname) + normalised = invalidDNSRegex.ReplaceAllString(normalised, "") + return node.GivenName == normalised } // IsExpired returns whether the node registration has expired. func (node Node) IsExpired() bool { // If Expiry is not set, the client has not indicated that - // it wants an expiry time, it is therefor considered + // it wants an expiry time, it is therefore considered // to mean "not expired" if node.Expiry == nil || node.Expiry.IsZero() { return false } - return time.Now().UTC().After(*node.Expiry) + return time.Since(*node.Expiry) > 0 } // IsEphemeral returns if the node is registered as an Ephemeral node. @@ -207,16 +181,155 @@ func (node *Node) IsEphemeral() bool { return node.AuthKey != nil && node.AuthKey.Ephemeral } -func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool { - for _, rule := range filter { - // TODO(kradalby): Cache or pregen this - matcher := matcher.MatchFromFilterRule(rule) +func (node *Node) IPs() []netip.Addr { + var ret []netip.Addr - if !matcher.SrcsContainsIPs([]netip.Addr(node.IPAddresses)) { + if node.IPv4 != nil { + ret = append(ret, *node.IPv4) + } + + if node.IPv6 != nil { + ret = append(ret, *node.IPv6) + } + + return ret +} + +// HasIP reports if a node has a given IP address. +func (node *Node) HasIP(i netip.Addr) bool { + for _, ip := range node.IPs() { + if ip.Compare(i) == 0 { + return true + } + } + + return false +} + +// IsTagged reports if a device is tagged and therefore should not be treated +// as a user-owned device. +// When a node has tags, the tags define its identity (not the user). +func (node *Node) IsTagged() bool { + return len(node.Tags) > 0 +} + +// IsUserOwned returns true if node is owned by a user (not tagged). +// Tagged nodes may have a UserID for "created by" tracking, but the tag is the owner. +func (node *Node) IsUserOwned() bool { + return !node.IsTagged() +} + +// HasTag reports if a node has a given tag. +func (node *Node) HasTag(tag string) bool { + return slices.Contains(node.Tags, tag) +} + +// TypedUserID returns the UserID as a typed UserID type. +// Returns 0 if UserID is nil. 
+func (node *Node) TypedUserID() UserID { + if node.UserID == nil { + return 0 + } + + return UserID(*node.UserID) +} + +func (node *Node) RequestTags() []string { + if node.Hostinfo == nil { + return []string{} + } + + return node.Hostinfo.RequestTags +} + +func (node *Node) Prefixes() []netip.Prefix { + var addrs []netip.Prefix + for _, nodeAddress := range node.IPs() { + ip := netip.PrefixFrom(nodeAddress, nodeAddress.BitLen()) + addrs = append(addrs, ip) + } + + return addrs +} + +// ExitRoutes returns a list of both exit routes if the +// node has any exit routes enabled. +// If none are enabled, it will return nil. +func (node *Node) ExitRoutes() []netip.Prefix { + var routes []netip.Prefix + + for _, route := range node.AnnouncedRoutes() { + if tsaddr.IsExitRoute(route) && slices.Contains(node.ApprovedRoutes, route) { + routes = append(routes, route) + } + } + + return routes +} + +func (node *Node) IsExitNode() bool { + return len(node.ExitRoutes()) > 0 +} + +func (node *Node) IPsAsString() []string { + var ret []string + + for _, ip := range node.IPs() { + ret = append(ret, ip.String()) + } + + return ret +} + +func (node *Node) InIPSet(set *netipx.IPSet) bool { + return slices.ContainsFunc(node.IPs(), set.Contains) +} + +// AppendToIPSet adds the individual ips in NodeAddresses to a +// given netipx.IPSetBuilder. +func (node *Node) AppendToIPSet(build *netipx.IPSetBuilder) { + for _, ip := range node.IPs() { + build.Add(ip) + } +} + +func (node *Node) CanAccess(matchers []matcher.Match, node2 *Node) bool { + src := node.IPs() + allowedIPs := node2.IPs() + + for _, matcher := range matchers { + if !matcher.SrcsContainsIPs(src...) { continue } - if matcher.DestsContainsIP([]netip.Addr(node2.IPAddresses)) { + if matcher.DestsContainsIP(allowedIPs...) { + return true + } + + // Check if the node has access to routes that might be part of a + // smaller subnet that is served from node2 as a subnet router. + if matcher.DestsOverlapsPrefixes(node2.SubnetRoutes()...) { + return true + } + + // If the dst is "the internet" and node2 is an exit node, allow access. + if matcher.DestsIsTheInternet() && node2.IsExitNode() { + return true + } + } + + return false +} + +func (node *Node) CanAccessRoute(matchers []matcher.Match, route netip.Prefix) bool { + src := node.IPs() + + for _, matcher := range matchers { + if matcher.SrcsContainsIPs(src...) && matcher.DestsOverlapsPrefixes(route) { + return true + } + + if matcher.SrcsOverlapsPrefixes(route) && matcher.DestsContainsIP(src...) { return true } } @@ -225,109 +338,66 @@ func (node *Node) CanAccess(filter []tailcfg.FilterRule, node2 *Node) bool { } func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes { - found := make(Nodes, 0) + var found Nodes for _, node := range nodes { - for _, mIP := range node.IPAddresses { - if ip == mIP { - found = append(found, node) - } + if node.IPv4 != nil && ip == *node.IPv4 { + found = append(found, node) + continue + } + + if node.IPv6 != nil && ip == *node.IPv6 { + found = append(found, node) } } return found } -// BeforeSave is a hook that ensures that some values that -// cannot be directly marshalled into database values are stored -// correctly in the database. -// This currently means storing the keys as strings. 
-func (node *Node) BeforeSave(tx *gorm.DB) error { - node.MachineKeyDatabaseField = node.MachineKey.String() - node.NodeKeyDatabaseField = node.NodeKey.String() - node.DiscoKeyDatabaseField = node.DiscoKey.String() - - var endpoints StringList - for _, addrPort := range node.Endpoints { - endpoints = append(endpoints, addrPort.String()) - } - - node.EndpointsDatabaseField = endpoints - - hi, err := json.Marshal(node.Hostinfo) - if err != nil { - return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err) - } - node.HostinfoDatabaseField = string(hi) - - return nil -} - -// AfterFind is a hook that ensures that Node objects fields that -// has a different type in the database is unwrapped and populated -// correctly. -// This currently unmarshals all the keys, stored as strings, into -// the proper types. -func (node *Node) AfterFind(tx *gorm.DB) error { - var machineKey key.MachinePublic - if err := machineKey.UnmarshalText([]byte(node.MachineKeyDatabaseField)); err != nil { - return fmt.Errorf("failed to unmarshal machine key from db: %w", err) - } - node.MachineKey = machineKey - - var nodeKey key.NodePublic - if err := nodeKey.UnmarshalText([]byte(node.NodeKeyDatabaseField)); err != nil { - return fmt.Errorf("failed to unmarshal node key from db: %w", err) - } - node.NodeKey = nodeKey - - var discoKey key.DiscoPublic - if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil { - return fmt.Errorf("failed to unmarshal disco key from db: %w", err) - } - node.DiscoKey = discoKey - - endpoints := make([]netip.AddrPort, len(node.EndpointsDatabaseField)) - for idx, ep := range node.EndpointsDatabaseField { - addrPort, err := netip.ParseAddrPort(ep) - if err != nil { - return fmt.Errorf("failed to parse endpoint from db: %w", err) +func (nodes Nodes) ContainsNodeKey(nodeKey key.NodePublic) bool { + for _, node := range nodes { + if node.NodeKey == nodeKey { + return true } - - endpoints[idx] = addrPort } - node.Endpoints = endpoints - var hi tailcfg.Hostinfo - if err := json.Unmarshal([]byte(node.HostinfoDatabaseField), &hi); err != nil { - log.Trace().Err(err).Msgf("Hostinfo content: %s", node.HostinfoDatabaseField) - - return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err) - } - node.Hostinfo = &hi - - return nil + return false } func (node *Node) Proto() *v1.Node { nodeProto := &v1.Node{ - Id: node.ID, + Id: uint64(node.ID), MachineKey: node.MachineKey.String(), - NodeKey: node.NodeKey.String(), - DiscoKey: node.DiscoKey.String(), - IpAddresses: node.IPAddresses.StringSlice(), + NodeKey: node.NodeKey.String(), + DiscoKey: node.DiscoKey.String(), + + // TODO(kradalby): replace list with v4, v6 field? + IpAddresses: node.IPsAsString(), Name: node.Hostname, GivenName: node.GivenName, - User: node.User.Proto(), - ForcedTags: node.ForcedTags, + User: nil, // Will be set below based on node type + Tags: node.Tags, + Online: node.IsOnline != nil && *node.IsOnline, - // TODO(kradalby): Implement register method enum converter - // RegisterMethod: , + // Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has + // to be populated manually with PrimaryRoute, to ensure it includes the + // routes that are actively served from the node. 
+ ApprovedRoutes: util.PrefixesToString(node.ApprovedRoutes), + AvailableRoutes: util.PrefixesToString(node.AnnouncedRoutes()), + + RegisterMethod: node.RegisterMethodToV1Enum(), CreatedAt: timestamppb.New(node.CreatedAt), } + // Set User field based on node ownership + // Note: User will be set to TaggedDevices in the gRPC layer (grpcv1.go) + // for proper MapResponse formatting + if node.User != nil { + nodeProto.User = node.User.Proto() + } + if node.AuthKey != nil { nodeProto.PreAuthKey = node.AuthKey.Proto() } @@ -343,47 +413,86 @@ func (node *Node) Proto() *v1.Node { return nodeProto } -func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) { - var hostname string - if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS - if node.GivenName == "" { - return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName) - } +func (node *Node) GetFQDN(baseDomain string) (string, error) { + if node.GivenName == "" { + return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName) + } - if node.User.Name == "" { - return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName) - } + hostname := node.GivenName + if baseDomain != "" { hostname = fmt.Sprintf( - "%s.%s.%s", + "%s.%s.", node.GivenName, - node.User.Name, baseDomain, ) - if len(hostname) > MaxHostnameLength { - return "", fmt.Errorf( - "failed to create valid FQDN (%s): %w", - hostname, - ErrHostnameTooLong, - ) - } - } else { - hostname = node.GivenName + } + + if len(hostname) > MaxHostnameLength { + return "", fmt.Errorf( + "failed to create valid FQDN (%s): %w", + hostname, + ErrHostnameTooLong, + ) } return hostname, nil } -// func (node *Node) String() string { -// return node.Hostname -// } +// AnnouncedRoutes returns the list of routes that the node announces. +// It should be used instead of checking Hostinfo.RoutableIPs directly. +func (node *Node) AnnouncedRoutes() []netip.Prefix { + if node.Hostinfo == nil { + return nil + } + + return node.Hostinfo.RoutableIPs +} + +// SubnetRoutes returns the list of routes (excluding exit routes) that the node +// announces and are approved. +// +// IMPORTANT: This method is used for internal data structures and should NOT be +// used for the gRPC Proto conversion. For Proto, SubnetRoutes must be populated +// manually with PrimaryRoutes to ensure it includes only routes actively served +// by the node. See the comment in Proto() method and the implementation in +// grpcv1.go/nodesToProto. +func (node *Node) SubnetRoutes() []netip.Prefix { + var routes []netip.Prefix + + for _, route := range node.AnnouncedRoutes() { + if tsaddr.IsExitRoute(route) { + continue + } + + if slices.Contains(node.ApprovedRoutes, route) { + routes = append(routes, route) + } + } + + return routes +} + +// IsSubnetRouter reports if the node has any subnet routes. +func (node *Node) IsSubnetRouter() bool { + return len(node.SubnetRoutes()) > 0 +} + +// AllApprovedRoutes returns the combination of SubnetRoutes and ExitRoutes +func (node *Node) AllApprovedRoutes() []netip.Prefix { + return append(node.SubnetRoutes(), node.ExitRoutes()...) +} + +func (node *Node) String() string { + return node.Hostname +} // PeerChangeFromMapRequest takes a MapRequest and compares it to the node // to produce a PeerChange struct that can be used to updated the node and // inform peers about smaller changes to the node. 
// When a field is added to this function, remember to also add it to: // - node.ApplyPeerChange -// - logTracePeerChange in poll.go +// - logTracePeerChange in poll.go. func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange { ret := tailcfg.PeerChange{ NodeID: tailcfg.NodeID(node.ID), @@ -420,8 +529,10 @@ func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC } } - // TODO(kradalby): Find a good way to compare updates - ret.Endpoints = req.Endpoints + // Compare endpoints using order-independent comparison + if EndpointsChanged(node.Endpoints, req.Endpoints) { + ret.Endpoints = req.Endpoints + } now := time.Now() ret.LastSeen = &now @@ -429,6 +540,88 @@ func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerC return ret } +// EndpointsChanged compares two endpoint slices and returns true if they differ. +// The comparison is order-independent - endpoints are sorted before comparison. +func EndpointsChanged(oldEndpoints, newEndpoints []netip.AddrPort) bool { + if len(oldEndpoints) != len(newEndpoints) { + return true + } + + if len(oldEndpoints) == 0 { + return false + } + + // Make copies to avoid modifying the original slices + oldCopy := slices.Clone(oldEndpoints) + newCopy := slices.Clone(newEndpoints) + + // Sort both slices to enable order-independent comparison + slices.SortFunc(oldCopy, func(a, b netip.AddrPort) int { + return a.Compare(b) + }) + slices.SortFunc(newCopy, func(a, b netip.AddrPort) int { + return a.Compare(b) + }) + + return !slices.Equal(oldCopy, newCopy) +} + +func (node *Node) RegisterMethodToV1Enum() v1.RegisterMethod { + switch node.RegisterMethod { + case "authkey": + return v1.RegisterMethod_REGISTER_METHOD_AUTH_KEY + case "oidc": + return v1.RegisterMethod_REGISTER_METHOD_OIDC + case "cli": + return v1.RegisterMethod_REGISTER_METHOD_CLI + default: + return v1.RegisterMethod_REGISTER_METHOD_UNSPECIFIED + } +} + +// ApplyHostnameFromHostInfo takes a Hostinfo struct and updates the node. +func (node *Node) ApplyHostnameFromHostInfo(hostInfo *tailcfg.Hostinfo) { + if hostInfo == nil { + return + } + + newHostname := strings.ToLower(hostInfo.Hostname) + if err := util.ValidateHostname(newHostname); err != nil { + log.Warn(). + Str("node.id", node.ID.String()). + Str("current_hostname", node.Hostname). + Str("rejected_hostname", hostInfo.Hostname). + Err(err). + Msg("Rejecting invalid hostname update from hostinfo") + return + } + + if node.Hostname != newHostname { + log.Trace(). + Str("node.id", node.ID.String()). + Str("old_hostname", node.Hostname). + Str("new_hostname", newHostname). + Str("old_given_name", node.GivenName). + Bool("given_name_changed", node.GivenNameHasBeenChanged()). + Msg("Updating hostname from hostinfo") + + if node.GivenNameHasBeenChanged() { + // Strip invalid DNS characters for givenName display + givenName := strings.ToLower(newHostname) + givenName = invalidDNSRegex.ReplaceAllString(givenName, "") + node.GivenName = givenName + } + + node.Hostname = newHostname + + log.Trace(). + Str("node.id", node.ID.String()). + Str("new_hostname", node.Hostname). + Str("new_given_name", node.GivenName). + Msg("Hostname updated") + } +} + // ApplyPeerChange takes a PeerChange struct and updates the node. 
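The endpoint handling above replaces the old unconditional copy with an order-independent comparison. A self-contained sketch of that comparison, mirroring the EndpointsChanged helper introduced in this diff:

```go
package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// endpointsChanged reports whether two endpoint sets differ, ignoring order,
// following the same shape as the EndpointsChanged helper above.
func endpointsChanged(oldEndpoints, newEndpoints []netip.AddrPort) bool {
	if len(oldEndpoints) != len(newEndpoints) {
		return true
	}
	if len(oldEndpoints) == 0 {
		return false
	}

	// Sort copies so the caller's slices are left untouched.
	oldCopy := slices.Clone(oldEndpoints)
	newCopy := slices.Clone(newEndpoints)
	slices.SortFunc(oldCopy, func(a, b netip.AddrPort) int { return a.Compare(b) })
	slices.SortFunc(newCopy, func(a, b netip.AddrPort) int { return a.Compare(b) })

	return !slices.Equal(oldCopy, newCopy)
}

func main() {
	a := netip.MustParseAddrPort("10.0.0.1:41641")
	b := netip.MustParseAddrPort("192.0.2.7:3478")

	// Same set in a different order is not a change; a length difference is.
	fmt.Println(endpointsChanged([]netip.AddrPort{a, b}, []netip.AddrPort{b, a})) // false
	fmt.Println(endpointsChanged([]netip.AddrPort{a}, []netip.AddrPort{a, b}))    // true
}
```

Sorting clones rather than the inputs keeps the comparison free of side effects on either the node or the incoming MapRequest.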
func (node *Node) ApplyPeerChange(change *tailcfg.PeerChange) { if change.Key != nil { @@ -478,8 +671,8 @@ func (nodes Nodes) String() string { return fmt.Sprintf("[ %s ](%d)", strings.Join(temp, ", "), len(temp)) } -func (nodes Nodes) IDMap() map[uint64]*Node { - ret := map[uint64]*Node{} +func (nodes Nodes) IDMap() map[NodeID]*Node { + ret := map[NodeID]*Node{} for _, node := range nodes { ret[node.ID] = node @@ -487,3 +680,427 @@ func (nodes Nodes) IDMap() map[uint64]*Node { return ret } + +func (nodes Nodes) DebugString() string { + var sb strings.Builder + sb.WriteString("Nodes:\n") + for _, node := range nodes { + sb.WriteString(node.DebugString()) + sb.WriteString("\n") + } + + return sb.String() +} + +func (node Node) DebugString() string { + var sb strings.Builder + fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID) + + // Show ownership status + if node.IsTagged() { + fmt.Fprintf(&sb, "\tTagged: %v\n", node.Tags) + + if node.User != nil { + fmt.Fprintf(&sb, "\tCreated by: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) + } + } else if node.User != nil { + fmt.Fprintf(&sb, "\tUser-owned: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username()) + } else { + fmt.Fprintf(&sb, "\tOrphaned: no user or tags\n") + } + + fmt.Fprintf(&sb, "\tIPs: %v\n", node.IPs()) + fmt.Fprintf(&sb, "\tApprovedRoutes: %v\n", node.ApprovedRoutes) + fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes()) + fmt.Fprintf(&sb, "\tSubnetRoutes: %v\n", node.SubnetRoutes()) + fmt.Fprintf(&sb, "\tExitRoutes: %v\n", node.ExitRoutes()) + sb.WriteString("\n") + + return sb.String() +} + +// Owner returns the owner for display purposes. +// For tagged nodes, returns TaggedDevices. For user-owned nodes, returns the user. +func (nv NodeView) Owner() UserView { + if nv.IsTagged() { + return TaggedDevices.View() + } + + return nv.User() +} + +func (nv NodeView) IPs() []netip.Addr { + if !nv.Valid() { + return nil + } + + return nv.ж.IPs() +} + +func (nv NodeView) InIPSet(set *netipx.IPSet) bool { + if !nv.Valid() { + return false + } + + return nv.ж.InIPSet(set) +} + +func (nv NodeView) CanAccess(matchers []matcher.Match, node2 NodeView) bool { + if !nv.Valid() { + return false + } + + return nv.ж.CanAccess(matchers, node2.AsStruct()) +} + +func (nv NodeView) CanAccessRoute(matchers []matcher.Match, route netip.Prefix) bool { + if !nv.Valid() { + return false + } + + return nv.ж.CanAccessRoute(matchers, route) +} + +func (nv NodeView) AnnouncedRoutes() []netip.Prefix { + if !nv.Valid() { + return nil + } + + return nv.ж.AnnouncedRoutes() +} + +func (nv NodeView) SubnetRoutes() []netip.Prefix { + if !nv.Valid() { + return nil + } + + return nv.ж.SubnetRoutes() +} + +func (nv NodeView) IsSubnetRouter() bool { + if !nv.Valid() { + return false + } + + return nv.ж.IsSubnetRouter() +} + +func (nv NodeView) AllApprovedRoutes() []netip.Prefix { + if !nv.Valid() { + return nil + } + + return nv.ж.AllApprovedRoutes() +} + +func (nv NodeView) AppendToIPSet(build *netipx.IPSetBuilder) { + if !nv.Valid() { + return + } + + nv.ж.AppendToIPSet(build) +} + +func (nv NodeView) RequestTagsSlice() views.Slice[string] { + if !nv.Valid() || !nv.Hostinfo().Valid() { + return views.Slice[string]{} + } + + return nv.Hostinfo().RequestTags() +} + +// IsTagged reports if a device is tagged +// and therefore should not be treated as a +// user owned device. +// Currently, this function only handles tags set +// via CLI ("forced tags" and preauthkeys). 
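The NodeView accessors above all follow the same guard-then-delegate pattern. A simplified stand-in (the real type is generated by tailscale's views tooling) showing why every accessor checks Valid() before touching the underlying struct:

```go
package main

import "fmt"

// node is a stand-in for the mutable Node struct.
type node struct {
	Hostname string
	Tags     []string
}

// nodeView is a simplified stand-in for the generated NodeView: a read-only
// wrapper that may be invalid (zero value), so every accessor guards on
// Valid() before delegating to the underlying struct, as the wrappers above do.
type nodeView struct {
	ж *node // underlying value; nil means "invalid view"
}

func (nv nodeView) Valid() bool { return nv.ж != nil }

func (nv nodeView) IsTagged() bool {
	if !nv.Valid() {
		return false // invalid views fall back to a safe zero answer
	}
	return len(nv.ж.Tags) > 0
}

func main() {
	tagged := nodeView{ж: &node{Hostname: "db1", Tags: []string{"tag:server"}}}
	var empty nodeView

	fmt.Println(tagged.IsTagged()) // true
	fmt.Println(empty.IsTagged())  // false, no panic on the zero value
}
```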
+func (nv NodeView) IsTagged() bool { + if !nv.Valid() { + return false + } + + return nv.ж.IsTagged() +} + +// IsExpired returns whether the node registration has expired. +func (nv NodeView) IsExpired() bool { + if !nv.Valid() { + return true + } + + return nv.ж.IsExpired() +} + +// IsEphemeral returns if the node is registered as an Ephemeral node. +// https://tailscale.com/kb/1111/ephemeral-nodes/ +func (nv NodeView) IsEphemeral() bool { + if !nv.Valid() { + return false + } + + return nv.ж.IsEphemeral() +} + +// PeerChangeFromMapRequest takes a MapRequest and compares it to the node +// to produce a PeerChange struct that can be used to updated the node and +// inform peers about smaller changes to the node. +func (nv NodeView) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange { + if !nv.Valid() { + return tailcfg.PeerChange{} + } + + return nv.ж.PeerChangeFromMapRequest(req) +} + +// GetFQDN returns the fully qualified domain name for the node. +func (nv NodeView) GetFQDN(baseDomain string) (string, error) { + if !nv.Valid() { + return "", errors.New("failed to create valid FQDN: node view is invalid") + } + + return nv.ж.GetFQDN(baseDomain) +} + +// ExitRoutes returns a list of both exit routes if the +// node has any exit routes enabled. +// If none are enabled, it will return nil. +func (nv NodeView) ExitRoutes() []netip.Prefix { + if !nv.Valid() { + return nil + } + + return nv.ж.ExitRoutes() +} + +func (nv NodeView) IsExitNode() bool { + if !nv.Valid() { + return false + } + + return nv.ж.IsExitNode() +} + +// RequestTags returns the ACL tags that the node is requesting. +func (nv NodeView) RequestTags() []string { + if !nv.Valid() || !nv.Hostinfo().Valid() { + return []string{} + } + + return nv.Hostinfo().RequestTags().AsSlice() +} + +// Proto converts the NodeView to a protobuf representation. +func (nv NodeView) Proto() *v1.Node { + if !nv.Valid() { + return nil + } + + return nv.ж.Proto() +} + +// HasIP reports if a node has a given IP address. +func (nv NodeView) HasIP(i netip.Addr) bool { + if !nv.Valid() { + return false + } + + return nv.ж.HasIP(i) +} + +// HasTag reports if a node has a given tag. +func (nv NodeView) HasTag(tag string) bool { + if !nv.Valid() { + return false + } + + return nv.ж.HasTag(tag) +} + +// TypedUserID returns the UserID as a typed UserID type. +// Returns 0 if UserID is nil or node is invalid. +func (nv NodeView) TypedUserID() UserID { + if !nv.Valid() { + return 0 + } + + return nv.ж.TypedUserID() +} + +// TailscaleUserID returns the user ID to use in Tailscale protocol. +// Tagged nodes always return TaggedDevices.ID, user-owned nodes return their actual UserID. +func (nv NodeView) TailscaleUserID() tailcfg.UserID { + if !nv.Valid() { + return 0 + } + + if nv.IsTagged() { + //nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64 + return tailcfg.UserID(int64(TaggedDevices.ID)) + } + + //nolint:gosec // G115: UserID values are within int64 range + return tailcfg.UserID(int64(nv.UserID().Get())) +} + +// Prefixes returns the node IPs as netip.Prefix. +func (nv NodeView) Prefixes() []netip.Prefix { + if !nv.Valid() { + return nil + } + + return nv.ж.Prefixes() +} + +// IPsAsString returns the node IPs as strings. +func (nv NodeView) IPsAsString() []string { + if !nv.Valid() { + return nil + } + + return nv.ж.IPsAsString() +} + +// HasNetworkChanges checks if the node has network-related changes. +// Returns true if IPs, announced routes, or approved routes changed. 
+// This is primarily used for policy cache invalidation. +func (nv NodeView) HasNetworkChanges(other NodeView) bool { + if !slices.Equal(nv.IPs(), other.IPs()) { + return true + } + + if !slices.Equal(nv.AnnouncedRoutes(), other.AnnouncedRoutes()) { + return true + } + + if !slices.Equal(nv.SubnetRoutes(), other.SubnetRoutes()) { + return true + } + + return false +} + +// HasPolicyChange reports whether the node has changes that affect policy evaluation. +func (nv NodeView) HasPolicyChange(other NodeView) bool { + if nv.UserID() != other.UserID() { + return true + } + + if !views.SliceEqual(nv.Tags(), other.Tags()) { + return true + } + + if !slices.Equal(nv.IPs(), other.IPs()) { + return true + } + + return false +} + +// TailNodes converts a slice of NodeViews into Tailscale tailcfg.Nodes. +func TailNodes( + nodes views.Slice[NodeView], + capVer tailcfg.CapabilityVersion, + primaryRouteFunc RouteFunc, + cfg *Config, +) ([]*tailcfg.Node, error) { + tNodes := make([]*tailcfg.Node, 0, nodes.Len()) + + for _, node := range nodes.All() { + tNode, err := node.TailNode(capVer, primaryRouteFunc, cfg) + if err != nil { + return nil, err + } + + tNodes = append(tNodes, tNode) + } + + return tNodes, nil +} + +// TailNode converts a NodeView into a Tailscale tailcfg.Node. +func (nv NodeView) TailNode( + capVer tailcfg.CapabilityVersion, + primaryRouteFunc RouteFunc, + cfg *Config, +) (*tailcfg.Node, error) { + if !nv.Valid() { + return nil, ErrInvalidNodeView + } + + hostname, err := nv.GetFQDN(cfg.BaseDomain) + if err != nil { + return nil, err + } + + var derp int + // TODO(kradalby): legacyDERP was removed in tailscale/tailscale@2fc4455e6dd9ab7f879d4e2f7cffc2be81f14077 + // and should be removed after 111 is the minimum capver. + legacyDERP := "127.3.3.40:0" // Zero means disconnected or unknown. + if nv.Hostinfo().Valid() && nv.Hostinfo().NetInfo().Valid() { + legacyDERP = fmt.Sprintf("127.3.3.40:%d", nv.Hostinfo().NetInfo().PreferredDERP()) + derp = nv.Hostinfo().NetInfo().PreferredDERP() + } + + var keyExpiry time.Time + if nv.Expiry().Valid() { + keyExpiry = nv.Expiry().Get() + } + + primaryRoutes := primaryRouteFunc(nv.ID()) + allowedIPs := slices.Concat(nv.Prefixes(), primaryRoutes, nv.ExitRoutes()) + tsaddr.SortPrefixes(allowedIPs) + + capMap := tailcfg.NodeCapMap{ + tailcfg.CapabilityAdmin: []tailcfg.RawMessage{}, + tailcfg.CapabilitySSH: []tailcfg.RawMessage{}, + } + if cfg.RandomizeClientPort { + capMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} + } + + if cfg.Taildrop.Enabled { + capMap[tailcfg.CapabilityFileSharing] = []tailcfg.RawMessage{} + } + + tNode := tailcfg.Node{ + //nolint:gosec // G115: NodeID values are within int64 range + ID: tailcfg.NodeID(nv.ID()), + StableID: nv.ID().StableID(), + Name: hostname, + Cap: capVer, + CapMap: capMap, + + User: nv.TailscaleUserID(), + + Key: nv.NodeKey(), + KeyExpiry: keyExpiry.UTC(), + + Machine: nv.MachineKey(), + DiscoKey: nv.DiscoKey(), + Addresses: nv.Prefixes(), + PrimaryRoutes: primaryRoutes, + AllowedIPs: allowedIPs, + Endpoints: nv.Endpoints().AsSlice(), + HomeDERP: derp, + LegacyDERPString: legacyDERP, + Hostinfo: nv.Hostinfo(), + Created: nv.CreatedAt().UTC(), + + Online: nv.IsOnline().Clone(), + + Tags: nv.Tags().AsSlice(), + + MachineAuthorized: !nv.IsExpired(), + Expired: nv.IsExpired(), + } + + // Set LastSeen only for offline nodes to avoid confusing Tailscale clients + // during rapid reconnection cycles. 
Online nodes should not have LastSeen set + // as this can make clients interpret them as "not online" despite Online=true. + if nv.LastSeen().Valid() && nv.IsOnline().Valid() && !nv.IsOnline().Get() { + lastSeen := nv.LastSeen().Get() + tNode.LastSeen = &lastSeen + } + + return &tNode, nil +} diff --git a/hscontrol/types/node_tags_test.go b/hscontrol/types/node_tags_test.go new file mode 100644 index 00000000..72598b3c --- /dev/null +++ b/hscontrol/types/node_tags_test.go @@ -0,0 +1,295 @@ +package types + +import ( + "testing" + + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" + "gorm.io/gorm" + "tailscale.com/types/ptr" +) + +// TestNodeIsTagged tests the IsTagged() method for determining if a node is tagged. +func TestNodeIsTagged(t *testing.T) { + tests := []struct { + name string + node Node + want bool + }{ + { + name: "node with tags - is tagged", + node: Node{ + Tags: []string{"tag:server", "tag:prod"}, + }, + want: true, + }, + { + name: "node with single tag - is tagged", + node: Node{ + Tags: []string{"tag:web"}, + }, + want: true, + }, + { + name: "node with no tags - not tagged", + node: Node{ + Tags: []string{}, + }, + want: false, + }, + { + name: "node with nil tags - not tagged", + node: Node{ + Tags: nil, + }, + want: false, + }, + { + // Tags should be copied from AuthKey during registration, so a node + // with only AuthKey.Tags and no Tags would be invalid in practice. + // IsTagged() only checks node.Tags, not AuthKey.Tags. + name: "node registered with tagged authkey only - not tagged (tags should be copied)", + node: Node{ + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + want: false, + }, + { + name: "node with both tags and authkey tags - is tagged", + node: Node{ + Tags: []string{"tag:server"}, + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + want: true, + }, + { + name: "node with user and no tags - not tagged", + node: Node{ + UserID: ptr.To(uint(42)), + Tags: []string{}, + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.IsTagged() + assert.Equal(t, tt.want, got, "IsTagged() returned unexpected value") + }) + } +} + +// TestNodeViewIsTagged tests the IsTagged() method on NodeView. +func TestNodeViewIsTagged(t *testing.T) { + tests := []struct { + name string + node Node + want bool + }{ + { + name: "tagged node via Tags field", + node: Node{ + Tags: []string{"tag:server"}, + }, + want: true, + }, + { + // Tags should be copied from AuthKey during registration, so a node + // with only AuthKey.Tags and no Tags would be invalid in practice. + name: "node with only AuthKey tags - not tagged (tags should be copied)", + node: Node{ + AuthKey: &PreAuthKey{ + Tags: []string{"tag:web"}, + }, + }, + want: false, // IsTagged() only checks node.Tags + }, + { + name: "user-owned node", + node: Node{ + UserID: ptr.To(uint(1)), + }, + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + view := tt.node.View() + got := view.IsTagged() + assert.Equal(t, tt.want, got, "NodeView.IsTagged() returned unexpected value") + }) + } +} + +// TestNodeHasTag tests the HasTag() method for checking specific tag membership. 
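TailNode assembles AllowedIPs from three sources: the node's own addresses as host prefixes, its primary routes, and its exit routes. A standalone sketch of that assembly using only the standard library (the real code additionally sorts the result with tsaddr.SortPrefixes):

```go
package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// hostPrefixes mirrors Node.Prefixes(): each node address becomes a
// single-address prefix (/32 for IPv4, /128 for IPv6).
func hostPrefixes(ips []netip.Addr) []netip.Prefix {
	var out []netip.Prefix
	for _, ip := range ips {
		out = append(out, netip.PrefixFrom(ip, ip.BitLen()))
	}
	return out
}

func main() {
	ips := []netip.Addr{
		netip.MustParseAddr("100.64.0.1"),
		netip.MustParseAddr("fd7a:115c:a1e0::1"),
	}
	primaryRoutes := []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}
	exitRoutes := []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")}

	// AllowedIPs is the concatenation of host prefixes, primary routes and
	// exit routes; sorting is left out of this sketch.
	allowed := slices.Concat(hostPrefixes(ips), primaryRoutes, exitRoutes)
	fmt.Println(allowed)
}
```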
+func TestNodeHasTag(t *testing.T) { + tests := []struct { + name string + node Node + tag string + want bool + }{ + { + name: "node has the tag", + node: Node{ + Tags: []string{"tag:server", "tag:prod"}, + }, + tag: "tag:server", + want: true, + }, + { + name: "node does not have the tag", + node: Node{ + Tags: []string{"tag:server", "tag:prod"}, + }, + tag: "tag:web", + want: false, + }, + { + // Tags should be copied from AuthKey during registration + // HasTag() only checks node.Tags, not AuthKey.Tags + name: "node has tag only in authkey - returns false", + node: Node{ + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + tag: "tag:database", + want: false, + }, + { + // node.Tags is what matters, not AuthKey.Tags + name: "node has tag in Tags but not in AuthKey", + node: Node{ + Tags: []string{"tag:server"}, + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + tag: "tag:server", + want: true, + }, + { + name: "invalid tag format still returns false", + node: Node{ + Tags: []string{"tag:server"}, + }, + tag: "invalid-tag", + want: false, + }, + { + name: "empty tag returns false", + node: Node{ + Tags: []string{"tag:server"}, + }, + tag: "", + want: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.HasTag(tt.tag) + assert.Equal(t, tt.want, got, "HasTag() returned unexpected value") + }) + } +} + +// TestNodeTagsImmutableAfterRegistration tests that tags can only be set during registration. +func TestNodeTagsImmutableAfterRegistration(t *testing.T) { + // Test that a node registered with tags keeps them + taggedNode := Node{ + ID: 1, + Tags: []string{"tag:server"}, + AuthKey: &PreAuthKey{ + Tags: []string{"tag:server"}, + }, + RegisterMethod: util.RegisterMethodAuthKey, + } + + // Node should be tagged + assert.True(t, taggedNode.IsTagged(), "Node registered with tags should be tagged") + + // Node should have the tag + has := taggedNode.HasTag("tag:server") + assert.True(t, has, "Node should have the tag it was registered with") + + // Test that a user-owned node is not tagged + userNode := Node{ + ID: 2, + UserID: ptr.To(uint(42)), + Tags: []string{}, + RegisterMethod: util.RegisterMethodOIDC, + } + + assert.False(t, userNode.IsTagged(), "User-owned node should not be tagged") +} + +// TestNodeOwnershipModel tests the tags-as-identity model. 
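The tag tests above encode the tags-as-identity model: only node.Tags defines ownership, and AuthKey.Tags matter only because they are copied onto the node at registration. A minimal sketch with stand-in types:

```go
package main

import (
	"fmt"
	"slices"
)

// Stand-ins for the ownership model exercised by these tests: a node is
// tagged when node.Tags is non-empty; tags on the pre-auth key are only
// relevant at registration time, when they are copied onto the node.
type preAuthKey struct{ Tags []string }

type node struct {
	Tags    []string
	AuthKey *preAuthKey
}

func (n node) IsTagged() bool       { return len(n.Tags) > 0 }
func (n node) HasTag(t string) bool { return slices.Contains(n.Tags, t) }

func main() {
	// Tags only on the key: IsTagged/HasTag look at node.Tags, so this is false.
	onlyKey := node{AuthKey: &preAuthKey{Tags: []string{"tag:database"}}}
	fmt.Println(onlyKey.IsTagged(), onlyKey.HasTag("tag:database")) // false false

	// Tags copied to the node during registration: the tag is the identity.
	registered := node{
		Tags:    []string{"tag:database"},
		AuthKey: &preAuthKey{Tags: []string{"tag:database"}},
	}
	fmt.Println(registered.IsTagged(), registered.HasTag("tag:database")) // true true
}
```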
+func TestNodeOwnershipModel(t *testing.T) { + tests := []struct { + name string + node Node + wantIsTagged bool + description string + }{ + { + name: "tagged node has tags, UserID is informational", + node: Node{ + ID: 1, + UserID: ptr.To(uint(5)), // "created by" user 5 + Tags: []string{"tag:server"}, + }, + wantIsTagged: true, + description: "Tagged nodes may have UserID set for tracking, but ownership is defined by tags", + }, + { + name: "user-owned node has no tags", + node: Node{ + ID: 2, + UserID: ptr.To(uint(5)), + Tags: []string{}, + }, + wantIsTagged: false, + description: "User-owned nodes are owned by the user, not by tags", + }, + { + // Tags should be copied from AuthKey to Node during registration + // IsTagged() only checks node.Tags, not AuthKey.Tags + name: "node with only authkey tags - not tagged (tags should be copied)", + node: Node{ + ID: 3, + UserID: ptr.To(uint(5)), // "created by" user 5 + AuthKey: &PreAuthKey{ + Tags: []string{"tag:database"}, + }, + }, + wantIsTagged: false, + description: "IsTagged() only checks node.Tags; AuthKey.Tags should be copied during registration", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.IsTagged() + assert.Equal(t, tt.wantIsTagged, got, tt.description) + }) + } +} + +// TestUserTypedID tests the TypedID() helper method. +func TestUserTypedID(t *testing.T) { + user := User{ + Model: gorm.Model{ID: 42}, + } + + typedID := user.TypedID() + assert.NotNil(t, typedID, "TypedID() should return non-nil pointer") + assert.Equal(t, UserID(42), *typedID, "TypedID() should return correct UserID value") +} diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 7e6c9840..9518833f 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -1,16 +1,25 @@ package types import ( + "fmt" "net/netip" + "strings" "testing" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/policy/matcher" + "github.com/juanfont/headscale/hscontrol/util" "tailscale.com/tailcfg" "tailscale.com/types/key" ) func Test_NodeCanAccess(t *testing.T) { + iap := func(ipStr string) *netip.Addr { + ip := netip.MustParseAddr(ipStr) + return &ip + } tests := []struct { name string node1 Node @@ -21,10 +30,10 @@ func Test_NodeCanAccess(t *testing.T) { { name: "no-rules", node1: Node{ - IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, + IPv4: iap("10.0.0.1"), }, node2: Node{ - IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + IPv4: iap("10.0.0.2"), }, rules: []tailcfg.FilterRule{}, want: false, @@ -32,10 +41,10 @@ func Test_NodeCanAccess(t *testing.T) { { name: "wildcard", node1: Node{ - IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.1")}, + IPv4: iap("10.0.0.1"), }, node2: Node{ - IPAddresses: []netip.Addr{netip.MustParseAddr("10.0.0.2")}, + IPv4: iap("10.0.0.2"), }, rules: []tailcfg.FilterRule{ { @@ -53,10 +62,10 @@ func Test_NodeCanAccess(t *testing.T) { { name: "other-cant-access-src", node1: Node{ - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.1")}, + IPv4: iap("100.64.0.1"), }, node2: Node{ - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, + IPv4: iap("100.64.0.3"), }, rules: []tailcfg.FilterRule{ { @@ -71,10 +80,10 @@ func Test_NodeCanAccess(t *testing.T) { { name: "dest-cant-access-src", node1: Node{ - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, + IPv4: iap("100.64.0.3"), }, node2: Node{ - 
IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, + IPv4: iap("100.64.0.2"), }, rules: []tailcfg.FilterRule{ { @@ -89,10 +98,10 @@ func Test_NodeCanAccess(t *testing.T) { { name: "src-can-access-dest", node1: Node{ - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.2")}, + IPv4: iap("100.64.0.2"), }, node2: Node{ - IPAddresses: []netip.Addr{netip.MustParseAddr("100.64.0.3")}, + IPv4: iap("100.64.0.3"), }, rules: []tailcfg.FilterRule{ { @@ -108,7 +117,8 @@ func Test_NodeCanAccess(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.node1.CanAccess(tt.rules, &tt.node2) + matchers := matcher.MatchesFromFilterRules(tt.rules) + got := tt.node1.CanAccess(matchers, &tt.node2) if got != tt.want { t.Errorf("canAccess() failed: want (%t), got (%t)", tt.want, got) @@ -117,110 +127,72 @@ func Test_NodeCanAccess(t *testing.T) { } } -func TestNodeAddressesOrder(t *testing.T) { - machineAddresses := NodeAddresses{ - netip.MustParseAddr("2001:db8::2"), - netip.MustParseAddr("100.64.0.2"), - netip.MustParseAddr("2001:db8::1"), - netip.MustParseAddr("100.64.0.1"), - } - - strSlice := machineAddresses.StringSlice() - expected := []string{ - "100.64.0.1", - "100.64.0.2", - "2001:db8::1", - "2001:db8::2", - } - - if len(strSlice) != len(expected) { - t.Fatalf("unexpected slice length: got %v, want %v", len(strSlice), len(expected)) - } - for i, addr := range strSlice { - if addr != expected[i] { - t.Errorf("unexpected address at index %v: got %v, want %v", i, addr, expected[i]) - } - } -} - func TestNodeFQDN(t *testing.T) { tests := []struct { name string node Node - dns tailcfg.DNSConfig domain string want string wantErr string }{ { - name: "all-set", + name: "no-dnsconfig-with-username", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, - dns: tailcfg.DNSConfig{ - Proxied: true, + domain: "example.com", + want: "test.example.com.", + }, + { + name: "all-set", + node: Node{ + GivenName: "test", + User: &User{ + Name: "user", + }, }, domain: "example.com", - want: "test.user.example.com", + want: "test.example.com.", }, { name: "no-given-name", node: Node{ - User: User{ + User: &User{ Name: "user", }, }, - dns: tailcfg.DNSConfig{ - Proxied: true, - }, domain: "example.com", wantErr: "failed to create valid FQDN: node has no given name", }, { - name: "no-user-name", + name: "too-long-username", node: Node{ - GivenName: "test", - User: User{}, - }, - dns: tailcfg.DNSConfig{ - Proxied: true, + GivenName: strings.Repeat("a", 256), }, domain: "example.com", - wantErr: "failed to create valid FQDN: node user has no name", - }, - { - name: "no-magic-dns", - node: Node{ - GivenName: "test", - User: User{ - Name: "user", - }, - }, - dns: tailcfg.DNSConfig{ - Proxied: false, - }, - domain: "example.com", - want: "test", + wantErr: fmt.Sprintf("failed to create valid FQDN (%s.example.com.): hostname too long, cannot except 255 ASCII chars", strings.Repeat("a", 256)), }, { name: "no-dnsconfig", node: Node{ GivenName: "test", - User: User{ + User: &User{ Name: "user", }, }, domain: "example.com", - want: "test", + want: "test.example.com.", }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got, err := tc.node.GetFQDN(&tc.dns, tc.domain) + got, err := tc.node.GetFQDN(tc.domain) + + t.Logf("GOT: %q, %q", got, tc.domain) if (err != nil) && (err.Error() != tc.wantErr) { t.Errorf("GetFQDN() error = %s, wantErr %s", err, tc.wantErr) @@ -366,3 +338,634 @@ func TestPeerChangeFromMapRequest(t *testing.T) { }) } } + +func 
TestApplyHostnameFromHostInfo(t *testing.T) { + tests := []struct { + name string + nodeBefore Node + change *tailcfg.Hostinfo + want Node + }{ + { + name: "hostinfo-not-exists", + nodeBefore: Node{ + GivenName: "manual-test.local", + Hostname: "TestHost.Local", + }, + change: nil, + want: Node{ + GivenName: "manual-test.local", + Hostname: "TestHost.Local", + }, + }, + { + name: "hostinfo-exists-no-automatic-givenName", + nodeBefore: Node{ + GivenName: "manual-test.local", + Hostname: "TestHost.Local", + }, + change: &tailcfg.Hostinfo{ + Hostname: "NewHostName.Local", + }, + want: Node{ + GivenName: "manual-test.local", + Hostname: "newhostname.local", + }, + }, + { + name: "hostinfo-exists-automatic-givenName", + nodeBefore: Node{ + GivenName: "automaticname.test", + Hostname: "AutomaticName.Test", + }, + change: &tailcfg.Hostinfo{ + Hostname: "NewHostName.Local", + }, + want: Node{ + GivenName: "newhostname.local", + Hostname: "newhostname.local", + }, + }, + { + name: "invalid-hostname-with-emoji-rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "hostname-with-💩", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", // Should reject and keep old hostname + }, + }, + { + name: "invalid-hostname-with-unicode-rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "我的电脑", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", // Should keep old hostname + }, + }, + { + name: "invalid-hostname-with-special-chars-rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "node-with-special!@#$%", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", // Should reject and keep old hostname + }, + }, + { + name: "invalid-hostname-too-short-rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "a", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", // Should keep old hostname + }, + }, + { + name: "invalid-hostname-uppercase-accepted-lowercased", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "ValidHostName", + }, + want: Node{ + GivenName: "validhostname", // GivenName follows hostname when it changes + Hostname: "validhostname", // Uppercase is lowercased, not rejected + }, + }, + { + name: "uppercase_to_lowercase_accepted", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "User2-Host", + }, + want: Node{ + GivenName: "user2-host", + Hostname: "user2-host", + }, + }, + { + name: "at_sign_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "Test@Host", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "chinese_chars_with_dash_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "server-北京-01", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "chinese_only_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: 
"valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "我的电脑", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "emoji_with_text_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "laptop-🚀", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "mixed_chinese_emoji_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "测试💻机器", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "only_emojis_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "🎉🎊", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "only_at_signs_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "@@@", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "starts_with_dash_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "-test", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "ends_with_dash_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "test-", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "too_long_hostname_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: strings.Repeat("t", 65), + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + { + name: "underscore_rejected", + nodeBefore: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + change: &tailcfg.Hostinfo{ + Hostname: "test_node", + }, + want: Node{ + GivenName: "valid-hostname", + Hostname: "valid-hostname", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.nodeBefore.ApplyHostnameFromHostInfo(tt.change) + + if diff := cmp.Diff(tt.want, tt.nodeBefore, util.Comparers...); diff != "" { + t.Errorf("Patch unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +func TestApplyPeerChange(t *testing.T) { + tests := []struct { + name string + nodeBefore Node + change *tailcfg.PeerChange + want Node + }{ + { + name: "hostinfo-and-netinfo-not-exists", + nodeBefore: Node{}, + change: &tailcfg.PeerChange{ + DERPRegion: 1, + }, + want: Node{ + Hostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 1, + }, + }, + }, + }, + { + name: "hostinfo-netinfo-not-exists", + nodeBefore: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + }, + }, + change: &tailcfg.PeerChange{ + DERPRegion: 3, + }, + want: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 3, + }, + }, + }, + }, + { + name: "hostinfo-netinfo-exists-derp-set", + nodeBefore: Node{ + Hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 999, + }, + }, + }, + change: &tailcfg.PeerChange{ + DERPRegion: 2, + }, + want: Node{ + Hostinfo: 
&tailcfg.Hostinfo{ + Hostname: "test", + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 2, + }, + }, + }, + }, + { + name: "endpoints-not-set", + nodeBefore: Node{}, + change: &tailcfg.PeerChange{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + want: Node{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + }, + { + name: "endpoints-set", + nodeBefore: Node{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("6.6.6.6:66"), + }, + }, + change: &tailcfg.PeerChange{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + want: Node{ + Endpoints: []netip.AddrPort{ + netip.MustParseAddrPort("8.8.8.8:88"), + }, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.nodeBefore.ApplyPeerChange(tt.change) + + if diff := cmp.Diff(tt.want, tt.nodeBefore, util.Comparers...); diff != "" { + t.Errorf("Patch unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +func TestNodeRegisterMethodToV1Enum(t *testing.T) { + tests := []struct { + name string + node Node + want v1.RegisterMethod + }{ + { + name: "authkey", + node: Node{ + ID: 1, + RegisterMethod: util.RegisterMethodAuthKey, + }, + want: v1.RegisterMethod_REGISTER_METHOD_AUTH_KEY, + }, + { + name: "oidc", + node: Node{ + ID: 1, + RegisterMethod: util.RegisterMethodOIDC, + }, + want: v1.RegisterMethod_REGISTER_METHOD_OIDC, + }, + { + name: "cli", + node: Node{ + ID: 1, + RegisterMethod: util.RegisterMethodCLI, + }, + want: v1.RegisterMethod_REGISTER_METHOD_CLI, + }, + { + name: "unknown", + node: Node{ + ID: 0, + }, + want: v1.RegisterMethod_REGISTER_METHOD_UNSPECIFIED, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.node.RegisterMethodToV1Enum() + + if diff := cmp.Diff(tt.want, got); diff != "" { + t.Errorf("RegisterMethodToV1Enum() unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +// TestHasNetworkChanges tests the NodeView method for detecting +// when a node's network properties have changed. 
+func TestHasNetworkChanges(t *testing.T) { + mustIPPtr := func(s string) *netip.Addr { + ip := netip.MustParseAddr(s) + return &ip + } + + tests := []struct { + name string + old *Node + new *Node + changed bool + }{ + { + name: "no changes", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + changed: false, + }, + { + name: "IPv4 changed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.2"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + }, + changed: true, + }, + { + name: "IPv6 changed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::1"), + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + IPv6: mustIPPtr("fd7a:115c:a1e0::2"), + }, + changed: true, + }, + { + name: "RoutableIPs added", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + }, + changed: true, + }, + { + name: "RoutableIPs removed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{}, + }, + changed: true, + }, + { + name: "RoutableIPs changed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + }, + changed: true, + }, + { + name: "SubnetRoutes added", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + changed: true, + }, + { + name: "SubnetRoutes removed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{}, + }, + changed: true, + }, + { + name: "SubnetRoutes changed", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: 
[]netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")}, + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostinfo: &tailcfg.Hostinfo{RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24"), netip.MustParsePrefix("192.168.0.0/24")}}, + ApprovedRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, + }, + changed: true, + }, + { + name: "irrelevant property changed (Hostname)", + old: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostname: "old-name", + }, + new: &Node{ + ID: 1, + IPv4: mustIPPtr("100.64.0.1"), + Hostname: "new-name", + }, + changed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := tt.new.View().HasNetworkChanges(tt.old.View()) + if got != tt.changed { + t.Errorf("HasNetworkChanges() = %v, want %v", got, tt.changed) + } + }) + } +} diff --git a/hscontrol/types/policy.go b/hscontrol/types/policy.go new file mode 100644 index 00000000..a30bf640 --- /dev/null +++ b/hscontrol/types/policy.go @@ -0,0 +1,20 @@ +package types + +import ( + "errors" + + "gorm.io/gorm" +) + +var ( + ErrPolicyNotFound = errors.New("acl policy not found") + ErrPolicyUpdateIsDisabled = errors.New("update is disabled for modes other than 'database'") +) + +// Policy represents a policy in the database. +type Policy struct { + gorm.Model + + // Data contains the policy in HuJSON format. + Data string +} diff --git a/hscontrol/types/preauth_key.go b/hscontrol/types/preauth_key.go index 0d8c9cff..2ce02f02 100644 --- a/hscontrol/types/preauth_key.go +++ b/hscontrol/types/preauth_key.go @@ -1,45 +1,71 @@ package types import ( - "strconv" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" ) +type PAKError string + +func (e PAKError) Error() string { return string(e) } + // PreAuthKey describes a pre-authorization key usable in a particular user. type PreAuthKey struct { - ID uint64 `gorm:"primary_key"` - Key string - UserID uint - User User + ID uint64 `gorm:"primary_key"` + + // Legacy plaintext key (for backwards compatibility) + Key string + + // New bcrypt-based authentication + Prefix string + Hash []byte // bcrypt + + // For tagged keys: UserID tracks who created the key (informational) + // For user-owned keys: UserID tracks the node owner + // Can be nil for system-created tagged keys + UserID *uint + User *User `gorm:"constraint:OnDelete:SET NULL;"` + Reusable bool Ephemeral bool `gorm:"default:false"` Used bool `gorm:"default:false"` - ACLTags []PreAuthKeyACLTag + + // Tags to assign to nodes registered with this key. + // Tags are copied to the node during registration. + // If non-empty, this creates tagged nodes (not user-owned). + Tags []string `gorm:"serializer:json"` CreatedAt *time.Time Expiration *time.Time } -// PreAuthKeyACLTag describes an autmatic tag applied to a node when registered with the associated PreAuthKey. -type PreAuthKeyACLTag struct { - ID uint64 `gorm:"primary_key"` - PreAuthKeyID uint64 - Tag string +// PreAuthKeyNew is returned once when the key is created. 
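The HasNetworkChanges cases above boil down to three slice comparisons; anything else on the node (such as the hostname) is ignored. A standalone sketch, using a hypothetical snapshot struct in place of NodeView:

```go
package main

import (
	"fmt"
	"net/netip"
	"slices"
)

// snapshot is a stand-in holding only the fields the change check consults.
type snapshot struct {
	IPs             []netip.Addr
	AnnouncedRoutes []netip.Prefix
	SubnetRoutes    []netip.Prefix
}

// hasNetworkChanges mirrors the check exercised by TestHasNetworkChanges:
// only IPs, announced routes and approved subnet routes are compared.
func hasNetworkChanges(prev, curr snapshot) bool {
	return !slices.Equal(prev.IPs, curr.IPs) ||
		!slices.Equal(prev.AnnouncedRoutes, curr.AnnouncedRoutes) ||
		!slices.Equal(prev.SubnetRoutes, curr.SubnetRoutes)
}

func main() {
	before := snapshot{IPs: []netip.Addr{netip.MustParseAddr("100.64.0.1")}}
	after := snapshot{IPs: []netip.Addr{netip.MustParseAddr("100.64.0.2")}}

	fmt.Println(hasNetworkChanges(before, before)) // false: nothing changed
	fmt.Println(hasNetworkChanges(before, after))  // true: IP changed
}
```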
+type PreAuthKeyNew struct { + ID uint64 `gorm:"primary_key"` + Key string + Reusable bool + Ephemeral bool + Tags []string + Expiration *time.Time + CreatedAt *time.Time + User *User // Can be nil for system-created tagged keys } -func (key *PreAuthKey) Proto() *v1.PreAuthKey { +func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey { protoKey := v1.PreAuthKey{ - User: key.User.Name, - Id: strconv.FormatUint(key.ID, util.Base10), + Id: key.ID, Key: key.Key, - Ephemeral: key.Ephemeral, + User: nil, // Will be set below if not nil Reusable: key.Reusable, - Used: key.Used, - AclTags: make([]string, len(key.ACLTags)), + Ephemeral: key.Ephemeral, + AclTags: key.Tags, + } + + if key.User != nil { + protoKey.User = key.User.Proto() } if key.Expiration != nil { @@ -50,9 +76,83 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey { protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) } - for idx := range key.ACLTags { - protoKey.AclTags[idx] = key.ACLTags[idx].Tag + return &protoKey +} + +func (key *PreAuthKey) Proto() *v1.PreAuthKey { + protoKey := v1.PreAuthKey{ + User: nil, // Will be set below if not nil + Id: key.ID, + Ephemeral: key.Ephemeral, + Reusable: key.Reusable, + Used: key.Used, + AclTags: key.Tags, + } + + if key.User != nil { + protoKey.User = key.User.Proto() + } + + // For new keys (with prefix/hash), show the prefix so users can identify the key + // For legacy keys (with plaintext key), show the full key for backwards compatibility + if key.Prefix != "" { + protoKey.Key = "hskey-auth-" + key.Prefix + "-***" + } else if key.Key != "" { + // Legacy key - show full key for backwards compatibility + // TODO: Consider hiding this in a future major version + protoKey.Key = key.Key + } + + if key.Expiration != nil { + protoKey.Expiration = timestamppb.New(*key.Expiration) + } + + if key.CreatedAt != nil { + protoKey.CreatedAt = timestamppb.New(*key.CreatedAt) } return &protoKey } + +// canUsePreAuthKey checks if a pre auth key can be used. +func (pak *PreAuthKey) Validate() error { + if pak == nil { + return PAKError("invalid authkey") + } + + log.Debug(). + Caller(). + Str("key", pak.Key). + Bool("hasExpiration", pak.Expiration != nil). + Time("expiration", func() time.Time { + if pak.Expiration != nil { + return *pak.Expiration + } + return time.Time{} + }()). + Time("now", time.Now()). + Bool("reusable", pak.Reusable). + Bool("used", pak.Used). + Msg("PreAuthKey.Validate: checking key") + + if pak.Expiration != nil && pak.Expiration.Before(time.Now()) { + return PAKError("authkey expired") + } + + // we don't need to check if has been used before + if pak.Reusable { + return nil + } + + if pak.Used { + return PAKError("authkey already used") + } + + return nil +} + +// IsTagged returns true if this PreAuthKey creates tagged nodes. +// When a PreAuthKey has tags, nodes registered with it will be tagged nodes. 
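Validate's rules are small enough to show in isolation: a nil or expired key fails, reusable keys skip the single-use check, and single-use keys must be unused. A trimmed, runnable sketch of the same checks, with stand-in types:

```go
package main

import (
	"errors"
	"fmt"
	"time"
)

// pakError mirrors the string-backed error type used for pre-auth key failures.
type pakError string

func (e pakError) Error() string { return string(e) }

// preAuthKey is a trimmed stand-in carrying only the fields the check consults.
type preAuthKey struct {
	Reusable   bool
	Used       bool
	Expiration *time.Time
}

// validate mirrors the Validate logic: nil and expired keys are rejected,
// reusable keys skip the single-use check, single-use keys must be unused.
func (pak *preAuthKey) validate() error {
	if pak == nil {
		return pakError("invalid authkey")
	}
	if pak.Expiration != nil && pak.Expiration.Before(time.Now()) {
		return pakError("authkey expired")
	}
	if pak.Reusable {
		return nil
	}
	if pak.Used {
		return pakError("authkey already used")
	}
	return nil
}

func main() {
	past := time.Now().Add(-time.Hour)
	expired := &preAuthKey{Expiration: &past}

	var perr pakError
	err := expired.validate()
	fmt.Println(errors.As(err, &perr), perr) // true authkey expired

	reusedOK := &preAuthKey{Reusable: true, Used: true}
	fmt.Println(reusedOK.validate()) // <nil>: reusable keys may be used again
}
```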
+func (pak *PreAuthKey) IsTagged() bool { + return len(pak.Tags) > 0 +} diff --git a/hscontrol/types/preauth_key_test.go b/hscontrol/types/preauth_key_test.go new file mode 100644 index 00000000..4ab1c717 --- /dev/null +++ b/hscontrol/types/preauth_key_test.go @@ -0,0 +1,130 @@ +package types + +import ( + "errors" + "testing" + "time" + + "github.com/google/go-cmp/cmp" +) + +func TestCanUsePreAuthKey(t *testing.T) { + now := time.Now() + past := now.Add(-time.Hour) + future := now.Add(time.Hour) + + tests := []struct { + name string + pak *PreAuthKey + wantErr bool + err PAKError + }{ + { + name: "valid reusable key", + pak: &PreAuthKey{ + Reusable: true, + Used: false, + Expiration: &future, + }, + wantErr: false, + }, + { + name: "valid non-reusable key", + pak: &PreAuthKey{ + Reusable: false, + Used: false, + Expiration: &future, + }, + wantErr: false, + }, + { + name: "expired key", + pak: &PreAuthKey{ + Reusable: false, + Used: false, + Expiration: &past, + }, + wantErr: true, + err: PAKError("authkey expired"), + }, + { + name: "used non-reusable key", + pak: &PreAuthKey{ + Reusable: false, + Used: true, + Expiration: &future, + }, + wantErr: true, + err: PAKError("authkey already used"), + }, + { + name: "used reusable key", + pak: &PreAuthKey{ + Reusable: true, + Used: true, + Expiration: &future, + }, + wantErr: false, + }, + { + name: "no expiration date", + pak: &PreAuthKey{ + Reusable: false, + Used: false, + Expiration: nil, + }, + wantErr: false, + }, + { + name: "nil preauth key", + pak: nil, + wantErr: true, + err: PAKError("invalid authkey"), + }, + { + name: "expired and used key", + pak: &PreAuthKey{ + Reusable: false, + Used: true, + Expiration: &past, + }, + wantErr: true, + err: PAKError("authkey expired"), + }, + { + name: "no expiration and used key", + pak: &PreAuthKey{ + Reusable: false, + Used: true, + Expiration: nil, + }, + wantErr: true, + err: PAKError("authkey already used"), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.pak.Validate() + if tt.wantErr { + if err == nil { + t.Errorf("expected error but got none") + } else { + var httpErr PAKError + ok := errors.As(err, &httpErr) + if !ok { + t.Errorf("expected HTTPError but got %T", err) + } else { + if diff := cmp.Diff(tt.err, httpErr); diff != "" { + t.Errorf("unexpected error (-want +got):\n%s", diff) + } + } + } + } else { + if err != nil { + t.Errorf("expected no error but got %v", err) + } + } + }) + } +} diff --git a/hscontrol/types/routes.go b/hscontrol/types/routes.go index 697cbc36..3ff56027 100644 --- a/hscontrol/types/routes.go +++ b/hscontrol/types/routes.go @@ -1,103 +1,31 @@ package types import ( - "fmt" "net/netip" - v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" ) -var ( - ExitRouteV4 = netip.MustParsePrefix("0.0.0.0/0") - ExitRouteV6 = netip.MustParsePrefix("::/0") -) - +// Deprecated: Approval of routes is denormalised onto the relevant node. +// Struct is kept for GORM migrations only. type Route struct { gorm.Model - NodeID uint64 - Node Node + NodeID uint64 `gorm:"not null"` + Node *Node - // TODO(kradalby): change this custom type to netip.Prefix - Prefix IPPrefix + Prefix netip.Prefix `gorm:"serializer:text"` + // Advertised is now only stored as part of [Node.Hostinfo]. Advertised bool - Enabled bool - IsPrimary bool + + // Enabled is stored directly on the node as ApprovedRoutes. 
+ Enabled bool + + // IsPrimary is only determined in memory as it is only relevant + // when the server is up. + IsPrimary bool } +// Deprecated: Approval of routes is denormalised onto the relevant node. type Routes []Route - -func (r *Route) String() string { - return fmt.Sprintf("%s:%s", r.Node.Hostname, netip.Prefix(r.Prefix).String()) -} - -func (r *Route) IsExitRoute() bool { - return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 -} - -func (r *Route) IsAnnouncable() bool { - return r.Advertised && r.Enabled -} - -func (rs Routes) Prefixes() []netip.Prefix { - prefixes := make([]netip.Prefix, len(rs)) - for i, r := range rs { - prefixes[i] = netip.Prefix(r.Prefix) - } - - return prefixes -} - -// Primaries returns Primary routes from a list of routes. -func (rs Routes) Primaries() Routes { - res := make(Routes, 0) - for _, route := range rs { - if route.IsPrimary { - res = append(res, route) - } - } - - return res -} - -func (rs Routes) PrefixMap() map[IPPrefix][]Route { - res := map[IPPrefix][]Route{} - - for _, route := range rs { - if _, ok := res[route.Prefix]; ok { - res[route.Prefix] = append(res[route.Prefix], route) - } else { - res[route.Prefix] = []Route{route} - } - } - - return res -} - -func (rs Routes) Proto() []*v1.Route { - protoRoutes := []*v1.Route{} - - for _, route := range rs { - protoRoute := v1.Route{ - Id: uint64(route.ID), - Node: route.Node.Proto(), - Prefix: netip.Prefix(route.Prefix).String(), - Advertised: route.Advertised, - Enabled: route.Enabled, - IsPrimary: route.IsPrimary, - CreatedAt: timestamppb.New(route.CreatedAt), - UpdatedAt: timestamppb.New(route.UpdatedAt), - } - - if route.DeletedAt.Valid { - protoRoute.DeletedAt = timestamppb.New(route.DeletedAt.Time) - } - - protoRoutes = append(protoRoutes, &protoRoute) - } - - return protoRoutes -} diff --git a/hscontrol/types/routes_test.go b/hscontrol/types/routes_test.go deleted file mode 100644 index ead4c595..00000000 --- a/hscontrol/types/routes_test.go +++ /dev/null @@ -1,94 +0,0 @@ -package types - -import ( - "fmt" - "net/netip" - "testing" - - "github.com/google/go-cmp/cmp" - "github.com/juanfont/headscale/hscontrol/util" -) - -func TestPrefixMap(t *testing.T) { - ipp := func(s string) IPPrefix { return IPPrefix(netip.MustParsePrefix(s)) } - - // TODO(kradalby): Remove when we have gotten rid of IPPrefix type - prefixComparer := cmp.Comparer(func(x, y IPPrefix) bool { - return x == y - }) - - tests := []struct { - rs Routes - want map[IPPrefix][]Route - }{ - { - rs: Routes{ - Route{ - Prefix: ipp("10.0.0.0/24"), - }, - }, - want: map[IPPrefix][]Route{ - ipp("10.0.0.0/24"): Routes{ - Route{ - Prefix: ipp("10.0.0.0/24"), - }, - }, - }, - }, - { - rs: Routes{ - Route{ - Prefix: ipp("10.0.0.0/24"), - }, - Route{ - Prefix: ipp("10.0.1.0/24"), - }, - }, - want: map[IPPrefix][]Route{ - ipp("10.0.0.0/24"): Routes{ - Route{ - Prefix: ipp("10.0.0.0/24"), - }, - }, - ipp("10.0.1.0/24"): Routes{ - Route{ - Prefix: ipp("10.0.1.0/24"), - }, - }, - }, - }, - { - rs: Routes{ - Route{ - Prefix: ipp("10.0.0.0/24"), - Enabled: true, - }, - Route{ - Prefix: ipp("10.0.0.0/24"), - Enabled: false, - }, - }, - want: map[IPPrefix][]Route{ - ipp("10.0.0.0/24"): Routes{ - Route{ - Prefix: ipp("10.0.0.0/24"), - Enabled: true, - }, - Route{ - Prefix: ipp("10.0.0.0/24"), - Enabled: false, - }, - }, - }, - }, - } - - for idx, tt := range tests { - t.Run(fmt.Sprintf("test-%d", idx), func(t *testing.T) { - got := tt.rs.PrefixMap() - if diff := cmp.Diff(tt.want, got, prefixComparer, 
util.MkeyComparer, util.NkeyComparer, util.DkeyComparer); diff != "" { - t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff) - } - }) - } -} diff --git a/hscontrol/types/testdata/base-domain-in-server-url.yaml b/hscontrol/types/testdata/base-domain-in-server-url.yaml new file mode 100644 index 00000000..10a0b82a --- /dev/null +++ b/hscontrol/types/testdata/base-domain-in-server-url.yaml @@ -0,0 +1,16 @@ +noise: + private_key_path: "private_key.pem" + +prefixes: + v6: fd7a:115c:a1e0::/48 + v4: 100.64.0.0/10 + +database: + type: sqlite3 + +server_url: "https://server.derp.no" + +dns: + magic_dns: true + base_domain: derp.no + override_local_dns: false diff --git a/hscontrol/types/testdata/base-domain-not-in-server-url.yaml b/hscontrol/types/testdata/base-domain-not-in-server-url.yaml new file mode 100644 index 00000000..e78cd6f8 --- /dev/null +++ b/hscontrol/types/testdata/base-domain-not-in-server-url.yaml @@ -0,0 +1,16 @@ +noise: + private_key_path: "private_key.pem" + +prefixes: + v6: fd7a:115c:a1e0::/48 + v4: 100.64.0.0/10 + +database: + type: sqlite3 + +server_url: "https://derp.no" + +dns: + magic_dns: true + base_domain: clients.derp.no + override_local_dns: false diff --git a/hscontrol/types/testdata/dns-override-true-error.yaml b/hscontrol/types/testdata/dns-override-true-error.yaml new file mode 100644 index 00000000..c11e2fca --- /dev/null +++ b/hscontrol/types/testdata/dns-override-true-error.yaml @@ -0,0 +1,16 @@ +noise: + private_key_path: "private_key.pem" + +prefixes: + v6: fd7a:115c:a1e0::/48 + v4: 100.64.0.0/10 + +database: + type: sqlite3 + +server_url: "https://server.derp.no" + +dns: + magic_dns: true + base_domain: derp.no + override_local_dns: true diff --git a/hscontrol/types/testdata/dns-override-true.yaml b/hscontrol/types/testdata/dns-override-true.yaml new file mode 100644 index 00000000..359cea56 --- /dev/null +++ b/hscontrol/types/testdata/dns-override-true.yaml @@ -0,0 +1,20 @@ +noise: + private_key_path: "private_key.pem" + +prefixes: + v6: fd7a:115c:a1e0::/48 + v4: 100.64.0.0/10 + +database: + type: sqlite3 + +server_url: "https://server.derp.no" + +dns: + magic_dns: true + base_domain: derp2.no + override_local_dns: true + nameservers: + global: + - 1.1.1.1 + - 1.0.0.1 diff --git a/hscontrol/types/testdata/dns_full.yaml b/hscontrol/types/testdata/dns_full.yaml new file mode 100644 index 00000000..d27e0fee --- /dev/null +++ b/hscontrol/types/testdata/dns_full.yaml @@ -0,0 +1,36 @@ +# minimum to not fatal +noise: + private_key_path: "private_key.pem" +server_url: "https://derp.no" + +dns: + magic_dns: true + base_domain: example.com + + override_local_dns: false + nameservers: + global: + - 1.1.1.1 + - 1.0.0.1 + - 2606:4700:4700::1111 + - 2606:4700:4700::1001 + - https://dns.nextdns.io/abc123 + + split: + foo.bar.com: + - 1.1.1.1 + darp.headscale.net: + - 1.1.1.1 + - 8.8.8.8 + + search_domains: + - test.com + - bar.com + + extra_records: + - name: "grafana.myvpn.example.com" + type: "A" + value: "100.64.0.3" + + # you can also put it in one line + - { name: "prometheus.myvpn.example.com", type: "A", value: "100.64.0.4" } diff --git a/hscontrol/types/testdata/dns_full_no_magic.yaml b/hscontrol/types/testdata/dns_full_no_magic.yaml new file mode 100644 index 00000000..4fb25d65 --- /dev/null +++ b/hscontrol/types/testdata/dns_full_no_magic.yaml @@ -0,0 +1,36 @@ +# minimum to not fatal +noise: + private_key_path: "private_key.pem" +server_url: "https://derp.no" + +dns: + magic_dns: false + base_domain: example.com + + override_local_dns: false + 
nameservers: + global: + - 1.1.1.1 + - 1.0.0.1 + - 2606:4700:4700::1111 + - 2606:4700:4700::1001 + - https://dns.nextdns.io/abc123 + + split: + foo.bar.com: + - 1.1.1.1 + darp.headscale.net: + - 1.1.1.1 + - 8.8.8.8 + + search_domains: + - test.com + - bar.com + + extra_records: + - name: "grafana.myvpn.example.com" + type: "A" + value: "100.64.0.3" + + # you can also put it in one line + - { name: "prometheus.myvpn.example.com", type: "A", value: "100.64.0.4" } diff --git a/hscontrol/types/testdata/minimal.yaml b/hscontrol/types/testdata/minimal.yaml new file mode 100644 index 00000000..1d9b1e00 --- /dev/null +++ b/hscontrol/types/testdata/minimal.yaml @@ -0,0 +1,3 @@ +noise: + private_key_path: "private_key.pem" +server_url: "https://derp.no" diff --git a/hscontrol/types/testdata/policy-path-is-loaded.yaml b/hscontrol/types/testdata/policy-path-is-loaded.yaml new file mode 100644 index 00000000..94f60b74 --- /dev/null +++ b/hscontrol/types/testdata/policy-path-is-loaded.yaml @@ -0,0 +1,20 @@ +noise: + private_key_path: "private_key.pem" + +prefixes: + v6: fd7a:115c:a1e0::/48 + v4: 100.64.0.0/10 + +database: + type: sqlite3 + +server_url: "https://derp.no" + +acl_policy_path: "/etc/acl_policy.yaml" +policy: + type: file + path: "/etc/policy.hujson" + +dns: + magic_dns: false + override_local_dns: false diff --git a/hscontrol/types/types_clone.go b/hscontrol/types/types_clone.go new file mode 100644 index 00000000..4dfeedc2 --- /dev/null +++ b/hscontrol/types/types_clone.go @@ -0,0 +1,150 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by tailscale.com/cmd/cloner; DO NOT EDIT. + +package types + +import ( + "database/sql" + "net/netip" + "time" + + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/ptr" +) + +// Clone makes a deep copy of User. +// The result aliases no memory with the original. +func (src *User) Clone() *User { + if src == nil { + return nil + } + dst := new(User) + *dst = *src + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _UserCloneNeedsRegeneration = User(struct { + gorm.Model + Name string + DisplayName string + Email string + ProviderIdentifier sql.NullString + Provider string + ProfilePicURL string +}{}) + +// Clone makes a deep copy of Node. +// The result aliases no memory with the original. +func (src *Node) Clone() *Node { + if src == nil { + return nil + } + dst := new(Node) + *dst = *src + dst.Endpoints = append(src.Endpoints[:0:0], src.Endpoints...) + dst.Hostinfo = src.Hostinfo.Clone() + if dst.IPv4 != nil { + dst.IPv4 = ptr.To(*src.IPv4) + } + if dst.IPv6 != nil { + dst.IPv6 = ptr.To(*src.IPv6) + } + if dst.UserID != nil { + dst.UserID = ptr.To(*src.UserID) + } + if dst.User != nil { + dst.User = ptr.To(*src.User) + } + dst.Tags = append(src.Tags[:0:0], src.Tags...) + if dst.AuthKeyID != nil { + dst.AuthKeyID = ptr.To(*src.AuthKeyID) + } + dst.AuthKey = src.AuthKey.Clone() + if dst.Expiry != nil { + dst.Expiry = ptr.To(*src.Expiry) + } + if dst.LastSeen != nil { + dst.LastSeen = ptr.To(*src.LastSeen) + } + dst.ApprovedRoutes = append(src.ApprovedRoutes[:0:0], src.ApprovedRoutes...) + if dst.DeletedAt != nil { + dst.DeletedAt = ptr.To(*src.DeletedAt) + } + if dst.IsOnline != nil { + dst.IsOnline = ptr.To(*src.IsOnline) + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. 
+var _NodeCloneNeedsRegeneration = Node(struct { + ID NodeID + MachineKey key.MachinePublic + NodeKey key.NodePublic + DiscoKey key.DiscoPublic + Endpoints []netip.AddrPort + Hostinfo *tailcfg.Hostinfo + IPv4 *netip.Addr + IPv6 *netip.Addr + Hostname string + GivenName string + UserID *uint + User *User + RegisterMethod string + Tags []string + AuthKeyID *uint64 + AuthKey *PreAuthKey + Expiry *time.Time + LastSeen *time.Time + ApprovedRoutes []netip.Prefix + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time + IsOnline *bool +}{}) + +// Clone makes a deep copy of PreAuthKey. +// The result aliases no memory with the original. +func (src *PreAuthKey) Clone() *PreAuthKey { + if src == nil { + return nil + } + dst := new(PreAuthKey) + *dst = *src + dst.Hash = append(src.Hash[:0:0], src.Hash...) + if dst.UserID != nil { + dst.UserID = ptr.To(*src.UserID) + } + if dst.User != nil { + dst.User = ptr.To(*src.User) + } + dst.Tags = append(src.Tags[:0:0], src.Tags...) + if dst.CreatedAt != nil { + dst.CreatedAt = ptr.To(*src.CreatedAt) + } + if dst.Expiration != nil { + dst.Expiration = ptr.To(*src.Expiration) + } + return dst +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct { + ID uint64 + Key string + Prefix string + Hash []byte + UserID *uint + User *User + Reusable bool + Ephemeral bool + Used bool + Tags []string + CreatedAt *time.Time + Expiration *time.Time +}{}) diff --git a/hscontrol/types/types_view.go b/hscontrol/types/types_view.go new file mode 100644 index 00000000..e48dd029 --- /dev/null +++ b/hscontrol/types/types_view.go @@ -0,0 +1,406 @@ +// Copyright (c) Tailscale Inc & AUTHORS +// SPDX-License-Identifier: BSD-3-Clause + +// Code generated by tailscale/cmd/viewer; DO NOT EDIT. + +package types + +import ( + "database/sql" + jsonv1 "encoding/json" + "errors" + "net/netip" + "time" + + jsonv2 "github.com/go-json-experiment/json" + "github.com/go-json-experiment/json/jsontext" + "gorm.io/gorm" + "tailscale.com/tailcfg" + "tailscale.com/types/key" + "tailscale.com/types/views" +) + +//go:generate go run tailscale.com/cmd/cloner -clonefunc=false -type=User,Node,PreAuthKey + +// View returns a read-only view of User. +func (p *User) View() UserView { + return UserView{ж: p} +} + +// UserView provides a read-only view over User. +// +// Its methods should only be called if `Valid()` returns true. +type UserView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *User +} + +// Valid reports whether v's underlying value is non-nil. +func (v UserView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v UserView) AsStruct() *User { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v UserView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.ж) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v UserView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.ж) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. 
+func (v *UserView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x User + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. +func (v *UserView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.ж != nil { + return errors.New("already initialized") + } + var x User + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v UserView) Model() gorm.Model { return v.ж.Model } + +// Name (username) for the user, is used if email is empty +// Should not be used, please use Username(). +// It is unique if ProviderIdentifier is not set. +func (v UserView) Name() string { return v.ж.Name } + +// Typically the full name of the user +func (v UserView) DisplayName() string { return v.ж.DisplayName } + +// Email of the user +// Should not be used, please use Username(). +func (v UserView) Email() string { return v.ж.Email } + +// ProviderIdentifier is a unique or not set identifier of the +// user from OIDC. It is the combination of `iss` +// and `sub` claim in the OIDC token. +// It is unique if set. +// It is unique together with Name. +func (v UserView) ProviderIdentifier() sql.NullString { return v.ж.ProviderIdentifier } + +// Provider is the origin of the user account, +// same as RegistrationMethod, without authkey. +func (v UserView) Provider() string { return v.ж.Provider } +func (v UserView) ProfilePicURL() string { return v.ж.ProfilePicURL } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _UserViewNeedsRegeneration = User(struct { + gorm.Model + Name string + DisplayName string + Email string + ProviderIdentifier sql.NullString + Provider string + ProfilePicURL string +}{}) + +// View returns a read-only view of Node. +func (p *Node) View() NodeView { + return NodeView{ж: p} +} + +// NodeView provides a read-only view over Node. +// +// Its methods should only be called if `Valid()` returns true. +type NodeView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *Node +} + +// Valid reports whether v's underlying value is non-nil. +func (v NodeView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v NodeView) AsStruct() *Node { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v NodeView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.ж) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v NodeView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.ж) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *NodeView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x Node + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. 
+func (v *NodeView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.ж != nil { + return errors.New("already initialized") + } + var x Node + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v NodeView) ID() NodeID { return v.ж.ID } +func (v NodeView) MachineKey() key.MachinePublic { return v.ж.MachineKey } +func (v NodeView) NodeKey() key.NodePublic { return v.ж.NodeKey } +func (v NodeView) DiscoKey() key.DiscoPublic { return v.ж.DiscoKey } +func (v NodeView) Endpoints() views.Slice[netip.AddrPort] { return views.SliceOf(v.ж.Endpoints) } +func (v NodeView) Hostinfo() tailcfg.HostinfoView { return v.ж.Hostinfo.View() } +func (v NodeView) IPv4() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv4) } + +func (v NodeView) IPv6() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv6) } + +// Hostname represents the name given by the Tailscale +// client during registration +func (v NodeView) Hostname() string { return v.ж.Hostname } + +// Givenname represents either: +// a DNS normalized version of Hostname +// a valid name set by the User +// +// GivenName is the name used in all DNS related +// parts of headscale. +func (v NodeView) GivenName() string { return v.ж.GivenName } + +// UserID is set for ALL nodes (tagged and user-owned) to track "created by". +// For tagged nodes, this is informational only - the tag is the owner. +// For user-owned nodes, this identifies the owner. +// Only nil for orphaned nodes (should not happen in normal operation). +func (v NodeView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) } + +func (v NodeView) User() UserView { return v.ж.User.View() } +func (v NodeView) RegisterMethod() string { return v.ж.RegisterMethod } + +// Tags is the definitive owner for tagged nodes. +// When non-empty, the node is "tagged" and tags define its identity. +// Empty for user-owned nodes. +// Tags cannot be removed once set (one-way transition). +func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } + +// When a node has been created with a PreAuthKey, we need to +// prevent the preauthkey from being deleted before the node. +// The preauthkey can define "tags" of the node so we need it +// around. +func (v NodeView) AuthKeyID() views.ValuePointer[uint64] { return views.ValuePointerOf(v.ж.AuthKeyID) } + +func (v NodeView) AuthKey() PreAuthKeyView { return v.ж.AuthKey.View() } +func (v NodeView) Expiry() views.ValuePointer[time.Time] { return views.ValuePointerOf(v.ж.Expiry) } + +// LastSeen is when the node was last in contact with +// headscale. It is best effort and not persisted. +func (v NodeView) LastSeen() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.LastSeen) +} + +// ApprovedRoutes is a list of routes that the node is allowed to announce +// as a subnet router. They are not necessarily the routes that the node +// announces at the moment. 
+// See [Node.Hostinfo] +func (v NodeView) ApprovedRoutes() views.Slice[netip.Prefix] { + return views.SliceOf(v.ж.ApprovedRoutes) +} +func (v NodeView) CreatedAt() time.Time { return v.ж.CreatedAt } +func (v NodeView) UpdatedAt() time.Time { return v.ж.UpdatedAt } +func (v NodeView) DeletedAt() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.DeletedAt) +} + +func (v NodeView) IsOnline() views.ValuePointer[bool] { return views.ValuePointerOf(v.ж.IsOnline) } + +func (v NodeView) String() string { return v.ж.String() } + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _NodeViewNeedsRegeneration = Node(struct { + ID NodeID + MachineKey key.MachinePublic + NodeKey key.NodePublic + DiscoKey key.DiscoPublic + Endpoints []netip.AddrPort + Hostinfo *tailcfg.Hostinfo + IPv4 *netip.Addr + IPv6 *netip.Addr + Hostname string + GivenName string + UserID *uint + User *User + RegisterMethod string + Tags []string + AuthKeyID *uint64 + AuthKey *PreAuthKey + Expiry *time.Time + LastSeen *time.Time + ApprovedRoutes []netip.Prefix + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt *time.Time + IsOnline *bool +}{}) + +// View returns a read-only view of PreAuthKey. +func (p *PreAuthKey) View() PreAuthKeyView { + return PreAuthKeyView{ж: p} +} + +// PreAuthKeyView provides a read-only view over PreAuthKey. +// +// Its methods should only be called if `Valid()` returns true. +type PreAuthKeyView struct { + // ж is the underlying mutable value, named with a hard-to-type + // character that looks pointy like a pointer. + // It is named distinctively to make you think of how dangerous it is to escape + // to callers. You must not let callers be able to mutate it. + ж *PreAuthKey +} + +// Valid reports whether v's underlying value is non-nil. +func (v PreAuthKeyView) Valid() bool { return v.ж != nil } + +// AsStruct returns a clone of the underlying value which aliases no memory with +// the original. +func (v PreAuthKeyView) AsStruct() *PreAuthKey { + if v.ж == nil { + return nil + } + return v.ж.Clone() +} + +// MarshalJSON implements [jsonv1.Marshaler]. +func (v PreAuthKeyView) MarshalJSON() ([]byte, error) { + return jsonv1.Marshal(v.ж) +} + +// MarshalJSONTo implements [jsonv2.MarshalerTo]. +func (v PreAuthKeyView) MarshalJSONTo(enc *jsontext.Encoder) error { + return jsonv2.MarshalEncode(enc, v.ж) +} + +// UnmarshalJSON implements [jsonv1.Unmarshaler]. +func (v *PreAuthKeyView) UnmarshalJSON(b []byte) error { + if v.ж != nil { + return errors.New("already initialized") + } + if len(b) == 0 { + return nil + } + var x PreAuthKey + if err := jsonv1.Unmarshal(b, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +// UnmarshalJSONFrom implements [jsonv2.UnmarshalerFrom]. 
+func (v *PreAuthKeyView) UnmarshalJSONFrom(dec *jsontext.Decoder) error { + if v.ж != nil { + return errors.New("already initialized") + } + var x PreAuthKey + if err := jsonv2.UnmarshalDecode(dec, &x); err != nil { + return err + } + v.ж = &x + return nil +} + +func (v PreAuthKeyView) ID() uint64 { return v.ж.ID } + +// Legacy plaintext key (for backwards compatibility) +func (v PreAuthKeyView) Key() string { return v.ж.Key } + +// New bcrypt-based authentication +func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix } + +// bcrypt +func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) } + +// For tagged keys: UserID tracks who created the key (informational) +// For user-owned keys: UserID tracks the node owner +// Can be nil for system-created tagged keys +func (v PreAuthKeyView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) } + +func (v PreAuthKeyView) User() UserView { return v.ж.User.View() } +func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable } +func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral } +func (v PreAuthKeyView) Used() bool { return v.ж.Used } + +// Tags to assign to nodes registered with this key. +// Tags are copied to the node during registration. +// If non-empty, this creates tagged nodes (not user-owned). +func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) } +func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.CreatedAt) +} + +func (v PreAuthKeyView) Expiration() views.ValuePointer[time.Time] { + return views.ValuePointerOf(v.ж.Expiration) +} + +// A compilation failure here means this code must be regenerated, with the command at the top of this file. +var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct { + ID uint64 + Key string + Prefix string + Hash []byte + UserID *uint + User *User + Reusable bool + Ephemeral bool + Used bool + Tags []string + CreatedAt *time.Time + Expiration *time.Time +}{}) diff --git a/hscontrol/types/users.go b/hscontrol/types/users.go index 7f6b40ed..ec40492b 100644 --- a/hscontrol/types/users.go +++ b/hscontrol/types/users.go @@ -1,53 +1,389 @@ package types import ( + "cmp" + "database/sql" + "encoding/json" + "fmt" + "net/mail" + "net/url" "strconv" - "time" + "strings" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" + "github.com/rs/zerolog/log" "google.golang.org/protobuf/types/known/timestamppb" "gorm.io/gorm" "tailscale.com/tailcfg" ) +type UserID uint64 + +type Users []User + +const ( + // TaggedDevicesUserID is the special user ID for tagged devices. + // This ID is used when rendering tagged nodes in the Tailscale protocol. + TaggedDevicesUserID = 2147455555 +) + +// TaggedDevices is a special user used in MapResponse for tagged nodes. +// Tagged nodes don't belong to a real user - the tag is their identity. +// This special user ID is used when rendering tagged nodes in the Tailscale protocol. 
+var TaggedDevices = User{ + Model: gorm.Model{ID: TaggedDevicesUserID}, + Name: "tagged-devices", + DisplayName: "Tagged Devices", +} + +func (u Users) String() string { + var sb strings.Builder + sb.WriteString("[ ") + for _, user := range u { + fmt.Fprintf(&sb, "%d: %s, ", user.ID, user.Name) + } + sb.WriteString(" ]") + + return sb.String() +} + // User is the way Headscale implements the concept of users in Tailscale // // At the end of the day, users in Tailscale are some kind of 'bubbles' or users // that contain our machines. type User struct { gorm.Model - Name string `gorm:"unique"` + // The index `idx_name_provider_identifier` is to enforce uniqueness + // between Name and ProviderIdentifier. This ensures that + // you can have multiple users with the same name in OIDC, + // but not if you only run with CLI users. + + // Name (username) for the user, is used if email is empty + // Should not be used, please use Username(). + // It is unique if ProviderIdentifier is not set. + Name string + + // Typically the full name of the user + DisplayName string + + // Email of the user + // Should not be used, please use Username(). + Email string + + // ProviderIdentifier is a unique or not set identifier of the + // user from OIDC. It is the combination of `iss` + // and `sub` claim in the OIDC token. + // It is unique if set. + // It is unique together with Name. + ProviderIdentifier sql.NullString + + // Provider is the origin of the user account, + // same as RegistrationMethod, without authkey. + Provider string + + ProfilePicURL string } -func (n *User) TailscaleUser() *tailcfg.User { - user := tailcfg.User{ - ID: tailcfg.UserID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, - ProfilePicURL: "", - Logins: []tailcfg.LoginID{}, - Created: time.Time{}, +func (u *User) StringID() string { + if u == nil { + return "" } - - return &user + return strconv.FormatUint(uint64(u.ID), 10) } -func (n *User) TailscaleLogin() *tailcfg.Login { - login := tailcfg.Login{ - ID: tailcfg.LoginID(n.ID), - LoginName: n.Name, - DisplayName: n.Name, - ProfilePicURL: "", +// TypedID returns a pointer to the user's ID as a UserID type. +// This is a convenience method to avoid ugly casting like ptr.To(types.UserID(user.ID)). +func (u *User) TypedID() *UserID { + uid := UserID(u.ID) + return &uid +} + +// Username is the main way to get the username of a user, +// it will return the email if it exists, the name if it exists, +// the OIDCIdentifier if it exists, and the ID if nothing else exists. +// Email and OIDCIdentifier will be set when the user has headscale +// enabled with OIDC, which means that there is a domain involved which +// should be used throughout headscale, in information returned to the +// user and the Policy engine. +func (u *User) Username() string { + return cmp.Or( + u.Email, + u.Name, + u.ProviderIdentifier.String, + u.StringID(), + ) +} + +// Display returns the DisplayName if it exists, otherwise +// it will return the Username. +func (u *User) Display() string { + return cmp.Or(u.DisplayName, u.Username()) +} + +// TODO(kradalby): See if we can fill in Gravatar here. 
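+// profilePicURL returns the stored ProfilePicURL as-is; there is currently
+// no fallback or generated avatar.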
+func (u *User) profilePicURL() string { + return u.ProfilePicURL +} + +func (u *User) TailscaleUser() tailcfg.User { + return tailcfg.User{ + ID: tailcfg.UserID(u.ID), + DisplayName: u.Display(), + ProfilePicURL: u.profilePicURL(), + Created: u.CreatedAt, } - - return &login } -func (n *User) Proto() *v1.User { +func (u UserView) TailscaleUser() tailcfg.User { + return u.ж.TailscaleUser() +} + +// ID returns the user's ID. +// This is a custom accessor because gorm.Model.ID is embedded +// and the viewer generator doesn't always produce it. +func (u UserView) ID() uint { + return u.ж.ID +} + +func (u *User) TailscaleLogin() tailcfg.Login { + return tailcfg.Login{ + ID: tailcfg.LoginID(u.ID), + Provider: u.Provider, + LoginName: u.Username(), + DisplayName: u.Display(), + ProfilePicURL: u.profilePicURL(), + } +} + +func (u UserView) TailscaleLogin() tailcfg.Login { + return u.ж.TailscaleLogin() +} + +func (u *User) TailscaleUserProfile() tailcfg.UserProfile { + return tailcfg.UserProfile{ + ID: tailcfg.UserID(u.ID), + LoginName: u.Username(), + DisplayName: u.Display(), + ProfilePicURL: u.profilePicURL(), + } +} + +func (u UserView) TailscaleUserProfile() tailcfg.UserProfile { + return u.ж.TailscaleUserProfile() +} + +func (u *User) Proto() *v1.User { + // Use Name if set, otherwise fall back to Username() which provides + // a display-friendly identifier (Email > ProviderIdentifier > ID). + // This ensures OIDC users (who typically have empty Name) display + // their email, while CLI users retain their original Name. + name := u.Name + if name == "" { + name = u.Username() + } return &v1.User{ - Id: strconv.FormatUint(uint64(n.ID), util.Base10), - Name: n.Name, - CreatedAt: timestamppb.New(n.CreatedAt), + Id: uint64(u.ID), + Name: name, + CreatedAt: timestamppb.New(u.CreatedAt), + DisplayName: u.DisplayName, + Email: u.Email, + ProviderId: u.ProviderIdentifier.String, + Provider: u.Provider, + ProfilePicUrl: u.ProfilePicURL, } } + +// JumpCloud returns a JSON where email_verified is returned as a +// string "true" or "false" instead of a boolean. +// This maps bool to a specific type with a custom unmarshaler to +// ensure we can decode it from a string. +// https://github.com/juanfont/headscale/issues/2293 +type FlexibleBoolean bool + +func (bit *FlexibleBoolean) UnmarshalJSON(data []byte) error { + var val any + err := json.Unmarshal(data, &val) + if err != nil { + return fmt.Errorf("could not unmarshal data: %w", err) + } + + switch v := val.(type) { + case bool: + *bit = FlexibleBoolean(v) + case string: + pv, err := strconv.ParseBool(v) + if err != nil { + return fmt.Errorf("could not parse %s as boolean: %w", v, err) + } + *bit = FlexibleBoolean(pv) + + default: + return fmt.Errorf("could not parse %v as boolean", v) + } + + return nil +} + +type OIDCClaims struct { + // Sub is the user's unique identifier at the provider. + Sub string `json:"sub"` + Iss string `json:"iss"` + + // Name is the user's full name. + Name string `json:"name,omitempty"` + Groups []string `json:"groups,omitempty"` + Email string `json:"email,omitempty"` + EmailVerified FlexibleBoolean `json:"email_verified,omitempty"` + ProfilePictureURL string `json:"picture,omitempty"` + Username string `json:"preferred_username,omitempty"` +} + +// Identifier returns a unique identifier string combining the Iss and Sub claims. 
+// The format depends on whether Iss is a URL or not: +// - For URLs: Joins the URL and sub path (e.g., "https://example.com/sub") +// - For non-URLs: Joins with a slash (e.g., "oidc/sub") +// - For empty Iss: Returns just "sub" +// - For empty Sub: Returns just the Issuer +// - For both empty: Returns empty string +// +// The result is cleaned using CleanIdentifier() to ensure consistent formatting. +func (c *OIDCClaims) Identifier() string { + // Handle empty components special cases + if c.Iss == "" && c.Sub == "" { + return "" + } + if c.Iss == "" { + return CleanIdentifier(c.Sub) + } + if c.Sub == "" { + return CleanIdentifier(c.Iss) + } + + // We'll use the raw values and let CleanIdentifier handle all the whitespace + issuer := c.Iss + subject := c.Sub + + var result string + // Try to parse as URL to handle URL joining correctly + if u, err := url.Parse(issuer); err == nil && u.Scheme != "" { + // For URLs, use proper URL path joining + if joined, err := url.JoinPath(issuer, subject); err == nil { + result = joined + } + } + + // If URL joining failed or issuer wasn't a URL, do simple string join + if result == "" { + // Default case: simple string joining with slash + issuer = strings.TrimSuffix(issuer, "/") + subject = strings.TrimPrefix(subject, "/") + result = issuer + "/" + subject + } + + // Clean the result and return it + return CleanIdentifier(result) +} + +// CleanIdentifier cleans a potentially malformed identifier by removing double slashes +// while preserving protocol specifications like http://. This function will: +// - Trim all whitespace from the beginning and end of the identifier +// - Remove whitespace within path segments +// - Preserve the scheme (http://, https://, etc.) for URLs +// - Remove any duplicate slashes in the path +// - Remove empty path segments +// - For non-URL identifiers, it joins non-empty segments with a single slash +// - Returns empty string for identifiers with only slashes +// - Normalize URL schemes to lowercase. 
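+//
+// Illustrative examples, mirroring the accompanying unit tests:
+//
+//	CleanIdentifier("oidc//sub")                     // "oidc/sub"
+//	CleanIdentifier("https://example.com///path//")  // "https://example.com/path"
+//	CleanIdentifier("///")                           // ""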
+func CleanIdentifier(identifier string) string { + if identifier == "" { + return identifier + } + + // Trim leading/trailing whitespace + identifier = strings.TrimSpace(identifier) + + // Handle URLs with schemes + u, err := url.Parse(identifier) + if err == nil && u.Scheme != "" { + // Clean path by removing empty segments and whitespace within segments + parts := strings.FieldsFunc(u.Path, func(c rune) bool { return c == '/' }) + for i, part := range parts { + parts[i] = strings.TrimSpace(part) + } + // Remove empty parts after trimming + cleanParts := make([]string, 0, len(parts)) + for _, part := range parts { + if part != "" { + cleanParts = append(cleanParts, part) + } + } + + if len(cleanParts) == 0 { + u.Path = "" + } else { + u.Path = "/" + strings.Join(cleanParts, "/") + } + // Ensure scheme is lowercase + u.Scheme = strings.ToLower(u.Scheme) + + return u.String() + } + + // Handle non-URL identifiers + parts := strings.FieldsFunc(identifier, func(c rune) bool { return c == '/' }) + // Clean whitespace from each part + cleanParts := make([]string, 0, len(parts)) + for _, part := range parts { + trimmed := strings.TrimSpace(part) + if trimmed != "" { + cleanParts = append(cleanParts, trimmed) + } + } + if len(cleanParts) == 0 { + return "" + } + + return strings.Join(cleanParts, "/") +} + +type OIDCUserInfo struct { + Sub string `json:"sub"` + Name string `json:"name"` + GivenName string `json:"given_name"` + FamilyName string `json:"family_name"` + PreferredUsername string `json:"preferred_username"` + Email string `json:"email"` + EmailVerified FlexibleBoolean `json:"email_verified,omitempty"` + Groups []string `json:"groups"` + Picture string `json:"picture"` +} + +// FromClaim overrides a User from OIDC claims. +// All fields will be updated, except for the ID. 
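+//
+// A minimal usage sketch (claims as decoded from an OIDC ID token; the second
+// argument reflects whether the deployment requires a verified email):
+//
+//	var u User
+//	u.FromClaim(&claims, true)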
+func (u *User) FromClaim(claims *OIDCClaims, emailVerifiedRequired bool) { + err := util.ValidateUsername(claims.Username) + if err == nil { + u.Name = claims.Username + } else { + log.Debug().Caller().Err(err).Msgf("Username %s is not valid", claims.Username) + } + + if claims.EmailVerified || !FlexibleBoolean(emailVerifiedRequired) { + _, err = mail.ParseAddress(claims.Email) + if err == nil { + u.Email = claims.Email + } + } + + // Get provider identifier + identifier := claims.Identifier() + // Ensure provider identifier always has a leading slash for backward compatibility + if claims.Iss == "" && !strings.HasPrefix(identifier, "/") { + identifier = "/" + identifier + } + u.ProviderIdentifier = sql.NullString{String: identifier, Valid: true} + u.DisplayName = claims.Name + u.ProfilePicURL = claims.ProfilePictureURL + u.Provider = util.RegisterMethodOIDC +} diff --git a/hscontrol/types/users_test.go b/hscontrol/types/users_test.go new file mode 100644 index 00000000..15386553 --- /dev/null +++ b/hscontrol/types/users_test.go @@ -0,0 +1,495 @@ +package types + +import ( + "database/sql" + "encoding/json" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" +) + +func TestUnmarshallOIDCClaims(t *testing.T) { + tests := []struct { + name string + jsonstr string + want OIDCClaims + }{ + { + name: "normal-bool", + jsonstr: ` +{ + "sub": "test", + "email": "test@test.no", + "email_verified": true +} + `, + want: OIDCClaims{ + Sub: "test", + Email: "test@test.no", + EmailVerified: true, + }, + }, + { + name: "string-bool-true", + jsonstr: ` +{ + "sub": "test2", + "email": "test2@test.no", + "email_verified": "true" +} + `, + want: OIDCClaims{ + Sub: "test2", + Email: "test2@test.no", + EmailVerified: true, + }, + }, + { + name: "string-bool-false", + jsonstr: ` +{ + "sub": "test3", + "email": "test3@test.no", + "email_verified": "false" +} + `, + want: OIDCClaims{ + Sub: "test3", + Email: "test3@test.no", + EmailVerified: false, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got OIDCClaims + if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil { + t.Errorf("UnmarshallOIDCClaims() error = %v", err) + return + } + if diff := cmp.Diff(got, tt.want); diff != "" { + t.Errorf("UnmarshallOIDCClaims() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestOIDCClaimsIdentifier(t *testing.T) { + tests := []struct { + name string + iss string + sub string + expected string + }{ + { + name: "standard URL with trailing slash", + iss: "https://oidc.example.com/", + sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + }, + { + name: "standard URL without trailing slash", + iss: "https://oidc.example.com", + sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + }, + { + name: "standard URL with uppercase protocol", + iss: "HTTPS://oidc.example.com/", + sub: "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + expected: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + }, + { + name: "standard URL with path and trailing slash", + iss: "https://login.microsoftonline.com/v2.0/", + sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + }, + { + name: "standard URL with path without trailing slash", + iss: 
"https://login.microsoftonline.com/v2.0", + sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + }, + { + name: "non-URL identifier with slash", + iss: "oidc", + sub: "sub", + expected: "oidc/sub", + }, + { + name: "non-URL identifier with trailing slash", + iss: "oidc/", + sub: "sub", + expected: "oidc/sub", + }, + { + name: "subject with slash", + iss: "oidc/", + sub: "sub/", + expected: "oidc/sub", + }, + { + name: "whitespace", + iss: " oidc/ ", + sub: " sub ", + expected: "oidc/sub", + }, + { + name: "newline", + iss: "\noidc/\n", + sub: "\nsub\n", + expected: "oidc/sub", + }, + { + name: "tab", + iss: "\toidc/\t", + sub: "\tsub\t", + expected: "oidc/sub", + }, + { + name: "empty issuer", + iss: "", + sub: "sub", + expected: "sub", + }, + { + name: "empty subject", + iss: "https://oidc.example.com", + sub: "", + expected: "https://oidc.example.com", + }, + { + name: "both empty", + iss: "", + sub: "", + expected: "", + }, + { + name: "URL with double slash", + iss: "https://login.microsoftonline.com//v2.0", + sub: "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + }, + { + name: "FTP URL protocol", + iss: "ftp://example.com/directory", + sub: "resource", + expected: "ftp://example.com/directory/resource", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + claims := OIDCClaims{ + Iss: tt.iss, + Sub: tt.sub, + } + result := claims.Identifier() + assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { + t.Errorf("Identifier() mismatch (-want +got):\n%s", diff) + } + + // Now clean the identifier and verify it's still the same + cleaned := CleanIdentifier(result) + + // Double-check with cmp.Diff for better error messages + if diff := cmp.Diff(tt.expected, cleaned); diff != "" { + t.Errorf("CleanIdentifier(Identifier()) mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestCleanIdentifier(t *testing.T) { + tests := []struct { + name string + identifier string + expected string + }{ + { + name: "empty identifier", + identifier: "", + expected: "", + }, + { + name: "simple identifier", + identifier: "oidc/sub", + expected: "oidc/sub", + }, + { + name: "double slashes in the middle", + identifier: "oidc//sub", + expected: "oidc/sub", + }, + { + name: "trailing slash", + identifier: "oidc/sub/", + expected: "oidc/sub", + }, + { + name: "multiple double slashes", + identifier: "oidc//sub///id//", + expected: "oidc/sub/id", + }, + { + name: "HTTP URL with proper scheme", + identifier: "http://example.com/path", + expected: "http://example.com/path", + }, + { + name: "HTTP URL with double slashes in path", + identifier: "http://example.com//path///resource", + expected: "http://example.com/path/resource", + }, + { + name: "HTTPS URL with empty segments", + identifier: "https://example.com///path//", + expected: "https://example.com/path", + }, + { + name: "URL with double slashes in domain", + identifier: "https://login.microsoftonline.com//v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + expected: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + }, + { + name: "FTP URL with double slashes", + identifier: "ftp://example.com//resource//", + expected: "ftp://example.com/resource", + }, + { + name: "Just slashes", + identifier: "///", + expected: "", + }, + { + name: "Leading slash without 
URL", + identifier: "/path//to///resource", + expected: "path/to/resource", + }, + { + name: "Non-standard protocol", + identifier: "ldap://example.org//path//to//resource", + expected: "ldap://example.org/path/to/resource", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := CleanIdentifier(tt.identifier) + assert.Equal(t, tt.expected, result) + if diff := cmp.Diff(tt.expected, result); diff != "" { + t.Errorf("CleanIdentifier() mismatch (-want +got):\n%s", diff) + } + }) + } +} + +func TestOIDCClaimsJSONToUser(t *testing.T) { + tests := []struct { + name string + jsonstr string + emailVerifiedRequired bool + want User + }{ + { + name: "normal-bool", + emailVerifiedRequired: true, + jsonstr: ` +{ + "sub": "test", + "email": "test@test.no", + "email_verified": true +} + `, + want: User{ + Provider: util.RegisterMethodOIDC, + Email: "test@test.no", + ProviderIdentifier: sql.NullString{ + String: "/test", + Valid: true, + }, + }, + }, + { + name: "string-bool-true", + emailVerifiedRequired: true, + jsonstr: ` +{ + "sub": "test2", + "email": "test2@test.no", + "email_verified": "true" +} + `, + want: User{ + Provider: util.RegisterMethodOIDC, + Email: "test2@test.no", + ProviderIdentifier: sql.NullString{ + String: "/test2", + Valid: true, + }, + }, + }, + { + name: "string-bool-false", + emailVerifiedRequired: true, + jsonstr: ` +{ + "sub": "test3", + "email": "test3@test.no", + "email_verified": "false" +} + `, + want: User{ + Provider: util.RegisterMethodOIDC, + ProviderIdentifier: sql.NullString{ + String: "/test3", + Valid: true, + }, + }, + }, + { + name: "allow-unverified-email", + emailVerifiedRequired: false, + jsonstr: ` +{ + "sub": "test4", + "email": "test4@test.no", + "email_verified": "false" +} + `, + want: User{ + Provider: util.RegisterMethodOIDC, + Email: "test4@test.no", + ProviderIdentifier: sql.NullString{ + String: "/test4", + Valid: true, + }, + }, + }, + { + // From https://github.com/juanfont/headscale/issues/2333 + name: "okta-oidc-claim-20250121", + emailVerifiedRequired: true, + jsonstr: ` +{ + "sub": "00u7dr4qp7XXXXXXXXXX", + "name": "Tim Horton", + "email": "tim.horton@company.com", + "ver": 1, + "iss": "https://sso.company.com/oauth2/default", + "aud": "0oa8neto4tXXXXXXXXXX", + "iat": 1737455152, + "exp": 1737458752, + "jti": "ID.zzJz93koTunMKv5Bq-XXXXXXXXXXXXXXXXXXXXXXXXX", + "amr": [ + "pwd" + ], + "idp": "00o42r3s2cXXXXXXXX", + "nonce": "nonce", + "preferred_username": "tim.horton@company.com", + "auth_time": 1000, + "at_hash": "preview_at_hash" +} + `, + want: User{ + Provider: util.RegisterMethodOIDC, + DisplayName: "Tim Horton", + Email: "", + Name: "tim.horton@company.com", + ProviderIdentifier: sql.NullString{ + String: "https://sso.company.com/oauth2/default/00u7dr4qp7XXXXXXXXXX", + Valid: true, + }, + }, + }, + { + // From https://github.com/juanfont/headscale/issues/2333 + name: "okta-oidc-claim-20250121", + emailVerifiedRequired: true, + jsonstr: ` +{ + "aud": "79xxxxxx-xxxx-xxxx-xxxx-892146xxxxxx", + "iss": "https://login.microsoftonline.com//v2.0", + "iat": 1737346441, + "nbf": 1737346441, + "exp": 1737350341, + "aio": "AWQAm/8ZAAAABKne9EWr6ygVO2DbcRmoPIpRM819qqlP/mmK41AAWv/C2tVkld4+znbG8DaXFdLQa9jRUzokvsT7rt9nAT6Fg7QC+/ecDWsF5U+QX11f9Ox7ZkK4UAIWFcIXpuZZvRS7", + "email": "user@domain.com", + "name": "XXXXXX XXXX", + "oid": "54c2323d-5052-4130-9588-ad751909003f", + "preferred_username": "user@domain.com", + "rh": "1.AXUAXdg0Rfc11UifLDJv67ChfSluoXmD9z1EmK-JIUYuSK9cAQl1AA.", + "sid": 
"5250a0a2-0b4e-4e68-8652-b4e97866411d", + "sub": "I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + "tid": "<redacted>", + "uti": "zAuXeEtMM0GwcTAcOsBZAA", + "ver": "2.0" +} + `, + want: User{ + Provider: util.RegisterMethodOIDC, + DisplayName: "XXXXXX XXXX", + Name: "user@domain.com", + Email: "", + ProviderIdentifier: sql.NullString{ + String: "https://login.microsoftonline.com/v2.0/I-70OQnj3TogrNSfkZQqB3f7dGwyBWSm1dolHNKrMzQ", + Valid: true, + }, + }, + }, + { + // From https://github.com/juanfont/headscale/issues/2333 + name: "casby-oidc-claim-20250513", + emailVerifiedRequired: true, + jsonstr: ` + { + "sub": "xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + "iss": "https://oidc.example.com/", + "aud": "xxxxxxxxxxxx", + "preferred_username": "user001", + "name": "User001", + "email": "user001@example.com", + "email_verified": true, + "picture": "https://cdn.casbin.org/img/casbin.svg", + "groups": [ + "org1/department1", + "org1/department2" + ] +} + `, + want: User{ + Provider: util.RegisterMethodOIDC, + Name: "user001", + DisplayName: "User001", + Email: "user001@example.com", + ProviderIdentifier: sql.NullString{ + String: "https://oidc.example.com/xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx", + Valid: true, + }, + ProfilePicURL: "https://cdn.casbin.org/img/casbin.svg", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var got OIDCClaims + if err := json.Unmarshal([]byte(tt.jsonstr), &got); err != nil { + t.Errorf("TestOIDCClaimsJSONToUser() error = %v", err) + return + } + + var user User + + user.FromClaim(&got, tt.emailVerifiedRequired) + if diff := cmp.Diff(user, tt.want); diff != "" { + t.Errorf("TestOIDCClaimsJSONToUser() mismatch (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/types/version.go b/hscontrol/types/version.go new file mode 100644 index 00000000..6676c92f --- /dev/null +++ b/hscontrol/types/version.go @@ -0,0 +1,81 @@ +package types + +import ( + "fmt" + "runtime" + "runtime/debug" + "strings" + "sync" +) + +type GoInfo struct { + Version string `json:"version"` + OS string `json:"os"` + Arch string `json:"arch"` +} + +type VersionInfo struct { + Version string `json:"version"` + Commit string `json:"commit"` + BuildTime string `json:"buildTime"` + Go GoInfo `json:"go"` + Dirty bool `json:"dirty"` +} + +func (v *VersionInfo) String() string { + var sb strings.Builder + + version := v.Version + if v.Dirty && !strings.Contains(version, "dirty") { + version += "-dirty" + } + + sb.WriteString(fmt.Sprintf("headscale version %s\n", version)) + sb.WriteString(fmt.Sprintf("commit: %s\n", v.Commit)) + sb.WriteString(fmt.Sprintf("build time: %s\n", v.BuildTime)) + sb.WriteString(fmt.Sprintf("built with: %s %s/%s\n", v.Go.Version, v.Go.OS, v.Go.Arch)) + + return sb.String() +} + +var buildInfo = sync.OnceValues(func() (*debug.BuildInfo, bool) { + return debug.ReadBuildInfo() +}) + +var GetVersionInfo = sync.OnceValue(func() *VersionInfo { + info := &VersionInfo{ + Version: "dev", + Commit: "unknown", + BuildTime: "unknown", + Go: GoInfo{ + Version: runtime.Version(), + OS: runtime.GOOS, + Arch: runtime.GOARCH, + }, + Dirty: false, + } + + buildInfo, ok := buildInfo() + if !ok { + return info + } + + // Extract version from module path or main version + if buildInfo.Main.Version != "" && buildInfo.Main.Version != "(devel)" { + info.Version = buildInfo.Main.Version + } + + // Extract build settings + for _, setting := range buildInfo.Settings { + switch setting.Key { + case "vcs.revision": + info.Commit = setting.Value + case "vcs.modified": 
+ info.Dirty = setting.Value == "true" + case "vcs.time": + info.BuildTime = setting.Value + } + } + + return info +}) diff --git a/hscontrol/util/addr.go b/hscontrol/util/addr.go index 5c02c933..c91ef0ba 100644 --- a/hscontrol/util/addr.go +++ b/hscontrol/util/addr.go @@ -2,8 +2,8 @@ package util import ( "fmt" + "iter" "net/netip" - "reflect" "strings" "go4.org/netipx" @@ -104,7 +104,7 @@ func StringToIPPrefix(prefixes []string) ([]netip.Prefix, error) { for index, prefixStr := range prefixes { prefix, err := netip.ParsePrefix(prefixStr) if err != nil { - return []netip.Prefix{}, err + return nil, err } result[index] = prefix @@ -113,12 +113,15 @@ func StringToIPPrefix(prefixes []string) ([]netip.Prefix, error) { return result, nil } -func StringOrPrefixListContains[T string | netip.Prefix](ts []T, t T) bool { - for _, v := range ts { - if reflect.DeepEqual(v, t) { - return true +// IPSetAddrIter returns a function that iterates over all the IPs in the IPSet. +func IPSetAddrIter(ipSet *netipx.IPSet) iter.Seq[netip.Addr] { + return func(yield func(netip.Addr) bool) { + for _, rng := range ipSet.Ranges() { + for ip := rng.From(); ip.Compare(rng.To()) <= 0; ip = ip.Next() { + if !yield(ip) { + return + } + } } } - - return false } diff --git a/hscontrol/util/dns.go b/hscontrol/util/dns.go index c6bd2b69..dcd58528 100644 --- a/hscontrol/util/dns.go +++ b/hscontrol/util/dns.go @@ -5,9 +5,10 @@ import ( "fmt" "net/netip" "regexp" + "strconv" "strings" + "unicode" - "github.com/spf13/viper" "go4.org/netipx" "tailscale.com/util/dnsname" ) @@ -21,68 +22,136 @@ const ( LabelHostnameLength = 63 ) -var invalidCharsInUserRegex = regexp.MustCompile("[^a-z0-9-.]+") +var invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+") -var ErrInvalidUserName = errors.New("invalid user name") +var ErrInvalidHostName = errors.New("invalid hostname") -func NormalizeToFQDNRulesConfigFromViper(name string) (string, error) { - strip := viper.GetBool("oidc.strip_email_domain") - - return NormalizeToFQDNRules(name, strip) -} - -// NormalizeToFQDNRules will replace forbidden chars in user -// it can also return an error if the user doesn't respect RFC 952 and 1123. -func NormalizeToFQDNRules(name string, stripEmailDomain bool) (string, error) { - name = strings.ToLower(name) - name = strings.ReplaceAll(name, "'", "") - atIdx := strings.Index(name, "@") - if stripEmailDomain && atIdx > 0 { - name = name[:atIdx] - } else { - name = strings.ReplaceAll(name, "@", ".") +// ValidateUsername checks if a username is valid. +// It must be at least 2 characters long, start with a letter, and contain +// only letters, numbers, hyphens, dots, and underscores. +// It cannot contain more than one '@'. +// It cannot contain invalid characters. 
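+//
+// Illustrative examples, following the rules above:
+//
+//	ValidateUsername("valid-user")        // nil
+//	ValidateUsername("user@example.com")  // nil, a single '@' is allowed
+//	ValidateUsername("1user")             // error, must start with a letter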
+func ValidateUsername(username string) error { + // Ensure the username meets the minimum length requirement + if len(username) < 2 { + return errors.New("username must be at least 2 characters long") } - name = invalidCharsInUserRegex.ReplaceAllString(name, "-") - for _, elt := range strings.Split(name, ".") { - if len(elt) > LabelHostnameLength { - return "", fmt.Errorf( - "label %v is more than 63 chars: %w", - elt, - ErrInvalidUserName, - ) + // Ensure the username starts with a letter + if !unicode.IsLetter(rune(username[0])) { + return errors.New("username must start with a letter") + } + + atCount := 0 + + for _, char := range username { + switch { + case unicode.IsLetter(char), + unicode.IsDigit(char), + char == '-', + char == '.', + char == '_': + // Valid characters + case char == '@': + atCount++ + if atCount > 1 { + return errors.New("username cannot contain more than one '@'") + } + default: + return fmt.Errorf("username contains invalid character: '%c'", char) } } - return name, nil + return nil } -func CheckForFQDNRules(name string) error { +// ValidateHostname checks if a hostname meets DNS requirements. +// This function does NOT modify the input - it only validates. +// The hostname must already be lowercase and contain only valid characters. +func ValidateHostname(name string) error { + if len(name) < 2 { + return fmt.Errorf( + "hostname %q is too short, must be at least 2 characters", + name, + ) + } if len(name) > LabelHostnameLength { return fmt.Errorf( - "DNS segment must not be over 63 chars. %v doesn't comply with this rule: %w", + "hostname %q is too long, must not exceed 63 characters", name, - ErrInvalidUserName, ) } if strings.ToLower(name) != name { return fmt.Errorf( - "DNS segment should be lowercase. %v doesn't comply with this rule: %w", + "hostname %q must be lowercase (try %q)", name, - ErrInvalidUserName, + strings.ToLower(name), ) } - if invalidCharsInUserRegex.MatchString(name) { + + if strings.HasPrefix(name, "-") || strings.HasSuffix(name, "-") { return fmt.Errorf( - "DNS segment should only be composed of lowercase ASCII letters numbers, hyphen and dots. %v doesn't comply with theses rules: %w", + "hostname %q cannot start or end with a hyphen", + name, + ) + } + + if strings.HasPrefix(name, ".") || strings.HasSuffix(name, ".") { + return fmt.Errorf( + "hostname %q cannot start or end with a dot", + name, + ) + } + + if invalidDNSRegex.MatchString(name) { + return fmt.Errorf( + "hostname %q contains invalid characters, only lowercase letters, numbers, hyphens and dots are allowed", name, - ErrInvalidUserName, ) } return nil } +// NormaliseHostname transforms a string into a valid DNS hostname. +// Returns error if the transformation results in an invalid hostname. +// +// Transformations applied: +// - Converts to lowercase +// - Removes invalid DNS characters +// - Truncates to 63 characters if needed +// +// After transformation, validates the result. 
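+//
+// Illustrative examples, mirroring the accompanying unit tests:
+//
+//	NormaliseHostname("Invalid-CapItaLIzed-user")  // "invalid-capitalized-user", nil
+//	NormaliseHostname("super-user+name")           // "super-username", nil
+//	NormaliseHostname("hostname-with-💩")           // "", error: stripping the emoji leaves a trailing hyphen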
+func NormaliseHostname(name string) (string, error) { + // Early return if already valid + err := ValidateHostname(name) + if err == nil { + return name, nil + } + + // Transform to lowercase + name = strings.ToLower(name) + + // Strip invalid DNS characters + name = invalidDNSRegex.ReplaceAllString(name, "") + + // Truncate to DNS label limit + if len(name) > LabelHostnameLength { + name = name[:LabelHostnameLength] + } + + // Validate result after transformation + err = ValidateHostname(name) + if err != nil { + return "", fmt.Errorf( + "hostname invalid after normalisation: %w", + err, + ) + } + + return name, nil +} + // generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`. // This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS // server (listening in 100.100.100.100 udp/53) should be used for. @@ -103,33 +172,7 @@ func CheckForFQDNRules(name string) error { // From the netmask we can find out the wildcard bits (the bits that are not set in the netmask). // This allows us to then calculate the subnets included in the subsequent class block and generate the entries. -func GenerateMagicDNSRootDomains(ipPrefixes []netip.Prefix) []dnsname.FQDN { - fqdns := make([]dnsname.FQDN, 0, len(ipPrefixes)) - for _, ipPrefix := range ipPrefixes { - var generateDNSRoot func(netip.Prefix) []dnsname.FQDN - switch ipPrefix.Addr().BitLen() { - case ipv4AddressLength: - generateDNSRoot = generateIPv4DNSRootDomain - - case ipv6AddressLength: - generateDNSRoot = generateIPv6DNSRootDomain - - default: - panic( - fmt.Sprintf( - "unsupported IP version with address length %d", - ipPrefix.Addr().BitLen(), - ), - ) - } - - fqdns = append(fqdns, generateDNSRoot(ipPrefix)...) - } - - return fqdns -} - -func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { +func GenerateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // Conversion to the std lib net.IPnet, a bit easier to operate netRange := netipx.PrefixIPNet(ipPrefix) maskBits, _ := netRange.Mask.Size() @@ -148,7 +191,7 @@ func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // here we generate the base domain (e.g., 100.in-addr.arpa., 16.172.in-addr.arpa., etc.) rdnsSlice := []string{} for i := lastOctet - 1; i >= 0; i-- { - rdnsSlice = append(rdnsSlice, fmt.Sprintf("%d", netRange.IP[i])) + rdnsSlice = append(rdnsSlice, strconv.FormatUint(uint64(netRange.IP[i]), 10)) } rdnsSlice = append(rdnsSlice, "in-addr.arpa.") rdnsBase := strings.Join(rdnsSlice, ".") @@ -165,7 +208,27 @@ func generateIPv4DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { return fqdns } -func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { +// generateMagicDNSRootDomains generates a list of DNS entries to be included in `Routes` in `MapResponse`. +// This list of reverse DNS entries instructs the OS on what subnets and domains the Tailscale embedded DNS +// server (listening in 100.100.100.100 udp/53) should be used for. +// +// Tailscale.com includes in the list: +// - the `BaseDomain` of the user +// - the reverse DNS entry for IPv6 (0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa., see below more on IPv6) +// - the reverse DNS entries for the IPv4 subnets covered by the user's `IPPrefix`. +// In the public SaaS this is [64-127].100.in-addr.arpa. +// +// The main purpose of this function is then generating the list of IPv4 entries. For the 100.64.0.0/10, this +// is clear, and could be hardcoded. 
But we are allowing any range as `IPPrefix`, so we need to find out the +// subnets when we have 172.16.0.0/16 (i.e., [0-255].16.172.in-addr.arpa.), or any other subnet. +// +// How IN-ADDR.ARPA domains work is defined in RFC1035 (section 3.5). Tailscale.com seems to adhere to this, +// and do not make use of RFC2317 ("Classless IN-ADDR.ARPA delegation") - hence generating the entries for the next +// class block only. + +// From the netmask we can find out the wildcard bits (the bits that are not set in the netmask). +// This allows us to then calculate the subnets included in the subsequent class block and generate the entries. +func GenerateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { const nibbleLen = 4 maskBits, _ := netipx.PrefixIPNet(ipPrefix).Mask.Size() @@ -183,7 +246,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { // and from what I can see, the generateMagicDNSRootDomains // function is called only once over the lifetime of a server process. prefixConstantParts := []string{} - for i := 0; i < maskBits/nibbleLen; i++ { + for i := range maskBits / nibbleLen { prefixConstantParts = append( []string{string(nibbleStr[i])}, prefixConstantParts...) @@ -192,7 +255,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { makeDomain := func(variablePrefix ...string) (dnsname.FQDN, error) { prefix := strings.Join(append(variablePrefix, prefixConstantParts...), ".") - return dnsname.ToFQDN(fmt.Sprintf("%s.ip6.arpa", prefix)) + return dnsname.ToFQDN(prefix + ".ip6.arpa") } var fqdns []dnsname.FQDN @@ -202,7 +265,7 @@ func generateIPv6DNSRootDomain(ipPrefix netip.Prefix) []dnsname.FQDN { } else { domCount := 1 << (maskBits % nibbleLen) fqdns = make([]dnsname.FQDN, 0, domCount) - for i := 0; i < domCount; i++ { + for i := range domCount { varNibble := fmt.Sprintf("%x", i) dom, err := makeDomain(varNibble) if err != nil { diff --git a/hscontrol/util/dns_test.go b/hscontrol/util/dns_test.go index 9d9b08b3..b492e4d6 100644 --- a/hscontrol/util/dns_test.go +++ b/hscontrol/util/dns_test.go @@ -2,15 +2,17 @@ package util import ( "net/netip" + "strings" "testing" "github.com/stretchr/testify/assert" + "tailscale.com/util/dnsname" + "tailscale.com/util/must" ) -func TestNormalizeToFQDNRules(t *testing.T) { +func TestNormaliseHostname(t *testing.T) { type args struct { - name string - stripEmailDomain bool + name string } tests := []struct { name string @@ -19,214 +21,191 @@ func TestNormalizeToFQDNRules(t *testing.T) { wantErr bool }{ { - name: "normalize simple name", - args: args{ - name: "normalize-simple.name", - stripEmailDomain: false, - }, - want: "normalize-simple.name", + name: "valid: lowercase user", + args: args{name: "valid-user"}, + want: "valid-user", wantErr: false, }, { - name: "normalize an email", - args: args{ - name: "foo.bar@example.com", - stripEmailDomain: false, - }, - want: "foo.bar.example.com", + name: "normalise: capitalized user", + args: args{name: "Invalid-CapItaLIzed-user"}, + want: "invalid-capitalized-user", wantErr: false, }, { - name: "normalize an email domain should be removed", - args: args{ - name: "foo.bar@example.com", - stripEmailDomain: true, - }, - want: "foo.bar", + name: "normalise: email as user", + args: args{name: "foo.bar@example.com"}, + want: "foo.barexample.com", wantErr: false, }, { - name: "strip enabled no email passed as argument", - args: args{ - name: "not-email-and-strip-enabled", - stripEmailDomain: true, - }, - want: "not-email-and-strip-enabled", + name: "normalise: chars in 
user name", + args: args{name: "super-user+name"}, + want: "super-username", wantErr: false, }, { - name: "normalize complex email", + name: "invalid: too long name truncated leaves trailing hyphen", args: args{ - name: "foo.bar+complex-email@example.com", - stripEmailDomain: false, + name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", }, - want: "foo.bar-complex-email.example.com", + want: "", + wantErr: true, + }, + { + name: "invalid: emoji stripped leaves trailing hyphen", + args: args{name: "hostname-with-💩"}, + want: "", + wantErr: true, + }, + { + name: "normalise: multiple emojis stripped", + args: args{name: "node-🎉-🚀-test"}, + want: "node---test", wantErr: false, }, { - name: "user name with space", - args: args{ - name: "name space", - stripEmailDomain: false, - }, - want: "name-space", - wantErr: false, + name: "invalid: only emoji becomes empty", + args: args{name: "💩"}, + want: "", + wantErr: true, }, { - name: "user with quote", - args: args{ - name: "Jamie's iPhone 5", - stripEmailDomain: false, - }, - want: "jamies-iphone-5", - wantErr: false, + name: "invalid: emoji at start leaves leading hyphen", + args: args{name: "🚀-rocket-node"}, + want: "", + wantErr: true, + }, + { + name: "invalid: emoji at end leaves trailing hyphen", + args: args{name: "node-test-🎉"}, + want: "", + wantErr: true, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := NormalizeToFQDNRules(tt.args.name, tt.args.stripEmailDomain) + got, err := NormaliseHostname(tt.args.name) if (err != nil) != tt.wantErr { - t.Errorf( - "NormalizeToFQDNRules() error = %v, wantErr %v", - err, - tt.wantErr, - ) - + t.Errorf("NormaliseHostname() error = %v, wantErr %v", err, tt.wantErr) return } - if got != tt.want { - t.Errorf("NormalizeToFQDNRules() = %v, want %v", got, tt.want) + if !tt.wantErr && got != tt.want { + t.Errorf("NormaliseHostname() = %v, want %v", got, tt.want) } }) } } -func TestCheckForFQDNRules(t *testing.T) { - type args struct { - name string - } +func TestValidateHostname(t *testing.T) { tests := []struct { - name string - args args - wantErr bool + name string + hostname string + wantErr bool + errorContains string }{ { - name: "valid: user", - args: args{name: "valid-user"}, - wantErr: false, + name: "valid lowercase", + hostname: "valid-hostname", + wantErr: false, }, { - name: "invalid: capitalized user", - args: args{name: "Invalid-CapItaLIzed-user"}, - wantErr: true, + name: "uppercase rejected", + hostname: "MyHostname", + wantErr: true, + errorContains: "must be lowercase", }, { - name: "invalid: email as user", - args: args{name: "foo.bar@example.com"}, - wantErr: true, + name: "too short", + hostname: "a", + wantErr: true, + errorContains: "too short", }, { - name: "invalid: chars in user name", - args: args{name: "super-user+name"}, - wantErr: true, + name: "too long", + hostname: "a" + strings.Repeat("b", 63), + wantErr: true, + errorContains: "too long", }, { - name: "invalid: too long name for user", - args: args{ - name: "super-long-useruseruser-name-that-should-be-a-little-more-than-63-chars", - }, - wantErr: true, + name: "emoji rejected", + hostname: "hostname-💩", + wantErr: true, + errorContains: "invalid characters", + }, + { + name: "starts with hyphen", + hostname: "-hostname", + wantErr: true, + errorContains: "cannot start or end with a hyphen", + }, + { + name: "ends with hyphen", + hostname: "hostname-", + wantErr: true, + errorContains: "cannot start or end with a hyphen", + }, + { + name: "starts with dot", + 
hostname: ".hostname", + wantErr: true, + errorContains: "cannot start or end with a dot", + }, + { + name: "ends with dot", + hostname: "hostname.", + wantErr: true, + errorContains: "cannot start or end with a dot", + }, + { + name: "special characters", + hostname: "host!@#$name", + wantErr: true, + errorContains: "invalid characters", }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if err := CheckForFQDNRules(tt.args.name); (err != nil) != tt.wantErr { - t.Errorf("CheckForFQDNRules() error = %v, wantErr %v", err, tt.wantErr) + err := ValidateHostname(tt.hostname) + if (err != nil) != tt.wantErr { + t.Errorf("ValidateHostname() error = %v, wantErr %v", err, tt.wantErr) + return + } + if tt.wantErr && tt.errorContains != "" { + if err == nil || !strings.Contains(err.Error(), tt.errorContains) { + t.Errorf("ValidateHostname() error = %v, should contain %q", err, tt.errorContains) + } } }) } } func TestMagicDNSRootDomains100(t *testing.T) { - prefixes := []netip.Prefix{ - netip.MustParsePrefix("100.64.0.0/10"), - } - domains := GenerateMagicDNSRootDomains(prefixes) + domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("100.64.0.0/10")) - found := false - for _, domain := range domains { - if domain == "64.100.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) - - found = false - for _, domain := range domains { - if domain == "100.100.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) - - found = false - for _, domain := range domains { - if domain == "127.100.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("64.100.in-addr.arpa."))) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("100.100.in-addr.arpa."))) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("127.100.in-addr.arpa."))) } func TestMagicDNSRootDomains172(t *testing.T) { - prefixes := []netip.Prefix{ - netip.MustParsePrefix("172.16.0.0/16"), - } - domains := GenerateMagicDNSRootDomains(prefixes) + domains := GenerateIPv4DNSRootDomain(netip.MustParsePrefix("172.16.0.0/16")) - found := false - for _, domain := range domains { - if domain == "0.16.172.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) - - found = false - for _, domain := range domains { - if domain == "255.16.172.in-addr.arpa." { - found = true - - break - } - } - assert.True(t, found) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("0.16.172.in-addr.arpa."))) + assert.Contains(t, domains, must.Get(dnsname.ToFQDN("255.16.172.in-addr.arpa."))) } // Happens when netmask is a multiple of 4 bits (sounds likely). 
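+// A /48 prefix has 48/4 = 12 constant nibbles and no leftover bits, so
+// GenerateIPv6DNSRootDomain produces exactly one root domain; the /50 prefix in
+// the multiple-domain test below leaves two variable bits and therefore expands
+// to 1<<2 = 4 candidate domains.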
func TestMagicDNSRootDomainsIPv6Single(t *testing.T) { - prefixes := []netip.Prefix{ - netip.MustParsePrefix("fd7a:115c:a1e0::/48"), - } - domains := GenerateMagicDNSRootDomains(prefixes) + domains := GenerateIPv6DNSRootDomain(netip.MustParsePrefix("fd7a:115c:a1e0::/48")) assert.Len(t, domains, 1) assert.Equal(t, "0.e.1.a.c.5.1.1.a.7.d.f.ip6.arpa.", domains[0].WithTrailingDot()) } func TestMagicDNSRootDomainsIPv6SingleMultiple(t *testing.T) { - prefixes := []netip.Prefix{ - netip.MustParsePrefix("fd7a:115c:a1e0::/50"), - } - domains := GenerateMagicDNSRootDomains(prefixes) + domains := GenerateIPv6DNSRootDomain(netip.MustParsePrefix("fd7a:115c:a1e0::/50")) yieldsRoot := func(dom string) bool { for _, candidate := range domains { diff --git a/hscontrol/util/file.go b/hscontrol/util/file.go index 5b8656ff..86af636c 100644 --- a/hscontrol/util/file.go +++ b/hscontrol/util/file.go @@ -1,6 +1,8 @@ package util import ( + "errors" + "fmt" "io/fs" "os" "path/filepath" @@ -42,3 +44,21 @@ func GetFileMode(key string) fs.FileMode { return fs.FileMode(mode) } + +func EnsureDir(dir string) error { + if _, err := os.Stat(dir); os.IsNotExist(err) { + err := os.MkdirAll(dir, PermissionFallback) + if err != nil { + if errors.Is(err, os.ErrPermission) { + return fmt.Errorf( + "creating directory %s, failed with permission error, is it located somewhere Headscale can write?", + dir, + ) + } + + return fmt.Errorf("creating directory %s: %w", dir, err) + } + } + + return nil +} diff --git a/hscontrol/util/key.go b/hscontrol/util/key.go index 6501daca..ae107053 100644 --- a/hscontrol/util/key.go +++ b/hscontrol/util/key.go @@ -1,33 +1,10 @@ package util import ( - "encoding/json" "errors" - "regexp" - - "tailscale.com/types/key" ) var ( - NodePublicKeyRegex = regexp.MustCompile("nodekey:[a-fA-F0-9]+") ErrCannotDecryptResponse = errors.New("cannot decrypt response") ZstdCompression = "zstd" ) - -func DecodeAndUnmarshalNaCl( - msg []byte, - output interface{}, - pubKey *key.MachinePublic, - privKey *key.MachinePrivate, -) error { - decrypted, ok := privKey.OpenFrom(*pubKey, msg) - if !ok { - return ErrCannotDecryptResponse - } - - if err := json.Unmarshal(decrypted, output); err != nil { - return err - } - - return nil -} diff --git a/hscontrol/util/log.go b/hscontrol/util/log.go index ebbdb792..f28cd4a3 100644 --- a/hscontrol/util/log.go +++ b/hscontrol/util/log.go @@ -1,7 +1,91 @@ package util -import "github.com/rs/zerolog/log" +import ( + "context" + "errors" + "time" + + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" + "gorm.io/gorm" + gormLogger "gorm.io/gorm/logger" + "tailscale.com/types/logger" +) func LogErr(err error, msg string) { log.Error().Caller().Err(err).Msg(msg) } + +func TSLogfWrapper() logger.Logf { + return func(format string, args ...any) { + log.Debug().Caller().Msgf(format, args...) 
+ } +} + +type DBLogWrapper struct { + Logger *zerolog.Logger + Level zerolog.Level + Event *zerolog.Event + SlowThreshold time.Duration + SkipErrRecordNotFound bool + ParameterizedQueries bool +} + +func NewDBLogWrapper(origin *zerolog.Logger, slowThreshold time.Duration, skipErrRecordNotFound bool, parameterizedQueries bool) *DBLogWrapper { + l := &DBLogWrapper{ + Logger: origin, + Level: origin.GetLevel(), + SlowThreshold: slowThreshold, + SkipErrRecordNotFound: skipErrRecordNotFound, + ParameterizedQueries: parameterizedQueries, + } + + return l +} + +type DBLogWrapperOption func(*DBLogWrapper) + +func (l *DBLogWrapper) LogMode(gormLogger.LogLevel) gormLogger.Interface { + return l +} + +func (l *DBLogWrapper) Info(ctx context.Context, msg string, data ...any) { + l.Logger.Info().Msgf(msg, data...) +} + +func (l *DBLogWrapper) Warn(ctx context.Context, msg string, data ...any) { + l.Logger.Warn().Msgf(msg, data...) +} + +func (l *DBLogWrapper) Error(ctx context.Context, msg string, data ...any) { + l.Logger.Error().Msgf(msg, data...) +} + +func (l *DBLogWrapper) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { + elapsed := time.Since(begin) + sql, rowsAffected := fc() + fields := map[string]any{ + "duration": elapsed, + "sql": sql, + "rowsAffected": rowsAffected, + } + + if err != nil && (!errors.Is(err, gorm.ErrRecordNotFound) || !l.SkipErrRecordNotFound) { + l.Logger.Error().Err(err).Fields(fields).Msgf("") + return + } + + if l.SlowThreshold != 0 && elapsed > l.SlowThreshold { + l.Logger.Warn().Fields(fields).Msgf("") + return + } + + l.Logger.Debug().Fields(fields).Msgf("") +} + +func (l *DBLogWrapper) ParamsFilter(ctx context.Context, sql string, params ...any) (string, []any) { + if l.ParameterizedQueries { + return sql, nil + } + return sql, params +} diff --git a/hscontrol/util/net.go b/hscontrol/util/net.go index b704c936..e28bb00b 100644 --- a/hscontrol/util/net.go +++ b/hscontrol/util/net.go @@ -3,6 +3,11 @@ package util import ( "context" "net" + "net/netip" + "sync" + + "go4.org/netipx" + "tailscale.com/net/tsaddr" ) func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { @@ -10,3 +15,49 @@ func GrpcSocketDialer(ctx context.Context, addr string) (net.Conn, error) { return d.DialContext(ctx, "unix", addr) } + +func PrefixesToString(prefixes []netip.Prefix) []string { + ret := make([]string, 0, len(prefixes)) + for _, prefix := range prefixes { + ret = append(ret, prefix.String()) + } + + return ret +} + +func MustStringsToPrefixes(strings []string) []netip.Prefix { + ret := make([]netip.Prefix, 0, len(strings)) + for _, str := range strings { + prefix := netip.MustParsePrefix(str) + ret = append(ret, prefix) + } + + return ret +} + +// TheInternet returns the IPSet for the Internet. 
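+// The set is built from 2000::/3 plus all of IPv4, with the RFC 1918 private
+// ranges, IPv6 ULA space (including the Tailscale range), the Tailscale CGNAT
+// range, and link-local addresses removed - roughly "everything publicly
+// routable".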
+// https://www.youtube.com/watch?v=iDbyYGrswtg +var TheInternet = sync.OnceValue(func() *netipx.IPSet { + var internetBuilder netipx.IPSetBuilder + internetBuilder.AddPrefix(netip.MustParsePrefix("2000::/3")) + internetBuilder.AddPrefix(tsaddr.AllIPv4()) + + // Delete Private network addresses + // https://datatracker.ietf.org/doc/html/rfc1918 + internetBuilder.RemovePrefix(netip.MustParsePrefix("fc00::/7")) + internetBuilder.RemovePrefix(netip.MustParsePrefix("10.0.0.0/8")) + internetBuilder.RemovePrefix(netip.MustParsePrefix("172.16.0.0/12")) + internetBuilder.RemovePrefix(netip.MustParsePrefix("192.168.0.0/16")) + + // Delete Tailscale networks + internetBuilder.RemovePrefix(tsaddr.TailscaleULARange()) + internetBuilder.RemovePrefix(tsaddr.CGNATRange()) + + // Delete "can't find DHCP networks" + internetBuilder.RemovePrefix(netip.MustParsePrefix("fe80::/10")) // link-local + internetBuilder.RemovePrefix(netip.MustParsePrefix("169.254.0.0/16")) + + theInternetSet, _ := internetBuilder.IPSet() + + return theInternetSet +}) diff --git a/hscontrol/util/prompt.go b/hscontrol/util/prompt.go new file mode 100644 index 00000000..098f1979 --- /dev/null +++ b/hscontrol/util/prompt.go @@ -0,0 +1,24 @@ +package util + +import ( + "fmt" + "os" + "strings" +) + +// YesNo takes a question and prompts the user to answer the +// question with a yes or no. It appends a [y/n] to the message. +// The question is written to stderr so that content can be redirected +// without interfering with the prompt. +func YesNo(msg string) bool { + fmt.Fprint(os.Stderr, msg+" [y/n] ") + + var resp string + fmt.Scanln(&resp) + resp = strings.ToLower(resp) + switch resp { + case "y", "yes", "sure": + return true + } + return false +} diff --git a/hscontrol/util/prompt_test.go b/hscontrol/util/prompt_test.go new file mode 100644 index 00000000..d726ec60 --- /dev/null +++ b/hscontrol/util/prompt_test.go @@ -0,0 +1,209 @@ +package util + +import ( + "bytes" + "io" + "os" + "strings" + "testing" +) + +func TestYesNo(t *testing.T) { + tests := []struct { + name string + input string + expected bool + }{ + { + name: "y answer", + input: "y\n", + expected: true, + }, + { + name: "Y answer", + input: "Y\n", + expected: true, + }, + { + name: "yes answer", + input: "yes\n", + expected: true, + }, + { + name: "YES answer", + input: "YES\n", + expected: true, + }, + { + name: "sure answer", + input: "sure\n", + expected: true, + }, + { + name: "SURE answer", + input: "SURE\n", + expected: true, + }, + { + name: "n answer", + input: "n\n", + expected: false, + }, + { + name: "no answer", + input: "no\n", + expected: false, + }, + { + name: "empty answer", + input: "\n", + expected: false, + }, + { + name: "invalid answer", + input: "maybe\n", + expected: false, + }, + { + name: "random text", + input: "foobar\n", + expected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Capture stdin + oldStdin := os.Stdin + r, w, _ := os.Pipe() + os.Stdin = r + + // Capture stderr + oldStderr := os.Stderr + stderrR, stderrW, _ := os.Pipe() + os.Stderr = stderrW + + // Write test input + go func() { + defer w.Close() + w.WriteString(tt.input) + }() + + // Call the function + result := YesNo("Test question") + + // Restore stdin and stderr + os.Stdin = oldStdin + os.Stderr = oldStderr + stderrW.Close() + + // Check the result + if result != tt.expected { + t.Errorf("YesNo() = %v, want %v", result, tt.expected) + } + + // Check that the prompt was written to stderr + var stderrBuf bytes.Buffer + 
io.Copy(&stderrBuf, stderrR) + stderrR.Close() + + expectedPrompt := "Test question [y/n] " + actualPrompt := stderrBuf.String() + if actualPrompt != expectedPrompt { + t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) + } + }) + } +} + +func TestYesNoPromptMessage(t *testing.T) { + // Capture stdin + oldStdin := os.Stdin + r, w, _ := os.Pipe() + os.Stdin = r + + // Capture stderr + oldStderr := os.Stderr + stderrR, stderrW, _ := os.Pipe() + os.Stderr = stderrW + + // Write test input + go func() { + defer w.Close() + w.WriteString("n\n") + }() + + // Call the function with a custom message + customMessage := "Do you want to continue with this dangerous operation?" + YesNo(customMessage) + + // Restore stdin and stderr + os.Stdin = oldStdin + os.Stderr = oldStderr + stderrW.Close() + + // Check that the custom message was included in the prompt + var stderrBuf bytes.Buffer + io.Copy(&stderrBuf, stderrR) + stderrR.Close() + + expectedPrompt := customMessage + " [y/n] " + actualPrompt := stderrBuf.String() + if actualPrompt != expectedPrompt { + t.Errorf("Expected prompt %q, got %q", expectedPrompt, actualPrompt) + } +} + +func TestYesNoCaseInsensitive(t *testing.T) { + testCases := []struct { + input string + expected bool + }{ + {"y\n", true}, + {"Y\n", true}, + {"yes\n", true}, + {"Yes\n", true}, + {"YES\n", true}, + {"yEs\n", true}, + {"sure\n", true}, + {"Sure\n", true}, + {"SURE\n", true}, + {"SuRe\n", true}, + } + + for _, tc := range testCases { + t.Run("input_"+strings.TrimSpace(tc.input), func(t *testing.T) { + // Capture stdin + oldStdin := os.Stdin + r, w, _ := os.Pipe() + os.Stdin = r + + // Capture stderr to avoid output during tests + oldStderr := os.Stderr + stderrR, stderrW, _ := os.Pipe() + os.Stderr = stderrW + + // Write test input + go func() { + defer w.Close() + w.WriteString(tc.input) + }() + + // Call the function + result := YesNo("Test") + + // Restore stdin and stderr + os.Stdin = oldStdin + os.Stderr = oldStderr + stderrW.Close() + + // Drain stderr + io.Copy(io.Discard, stderrR) + stderrR.Close() + + if result != tc.expected { + t.Errorf("Input %q: expected %v, got %v", strings.TrimSpace(tc.input), tc.expected, result) + } + }) + } +} diff --git a/hscontrol/util/string.go b/hscontrol/util/string.go index 6f018aff..d1d7ece7 100644 --- a/hscontrol/util/string.go +++ b/hscontrol/util/string.go @@ -32,7 +32,8 @@ func GenerateRandomBytes(n int) ([]byte, error) { func GenerateRandomStringURLSafe(n int) (string, error) { b, err := GenerateRandomBytes(n) - return base64.RawURLEncoding.EncodeToString(b), err + uenc := base64.RawURLEncoding.EncodeToString(b) + return uenc[:n], err } // GenerateRandomStringDNSSafe returns a DNS-safe @@ -56,14 +57,18 @@ func GenerateRandomStringDNSSafe(size int) (string, error) { return str[:size], nil } -func IsStringInSlice(slice []string, str string) bool { - for _, s := range slice { - if s == str { - return true - } +func MustGenerateRandomStringDNSSafe(size int) string { + hash, err := GenerateRandomStringDNSSafe(size) + if err != nil { + panic(err) } - return false + return hash +} + +func InvalidString() string { + hash, _ := GenerateRandomStringDNSSafe(8) + return "invalid-" + hash } func TailNodesToString(nodes []*tailcfg.Node) string { @@ -83,3 +88,21 @@ func TailMapResponseToString(resp tailcfg.MapResponse) string { TailNodesToString(resp.Peers), ) } + +func TailcfgFilterRulesToString(rules []tailcfg.FilterRule) string { + var sb strings.Builder + + for index, rule := range rules { + 
sb.WriteString(fmt.Sprintf(` +{ + SrcIPs: %v + DstIPs: %v +} +`, rule.SrcIPs, rule.DstPorts)) + if index < len(rules)-1 { + sb.WriteString(", ") + } + } + + return fmt.Sprintf("[ %s ](%d)", sb.String(), len(rules)) +} diff --git a/hscontrol/util/string_test.go b/hscontrol/util/string_test.go index 87a8be1c..f0b4c558 100644 --- a/hscontrol/util/string_test.go +++ b/hscontrol/util/string_test.go @@ -4,12 +4,13 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestGenerateRandomStringDNSSafe(t *testing.T) { - for i := 0; i < 100000; i++ { + for range 100000 { str, err := GenerateRandomStringDNSSafe(8) - assert.Nil(t, err) + require.NoError(t, err) assert.Len(t, str, 8) } } diff --git a/hscontrol/util/test.go b/hscontrol/util/test.go index 6d465426..d93ae1f2 100644 --- a/hscontrol/util/test.go +++ b/hscontrol/util/test.go @@ -4,7 +4,9 @@ import ( "net/netip" "github.com/google/go-cmp/cmp" + "tailscale.com/types/ipproto" "tailscale.com/types/key" + "tailscale.com/types/views" ) var PrefixComparer = cmp.Comparer(func(x, y netip.Prefix) bool { @@ -15,6 +17,10 @@ var IPComparer = cmp.Comparer(func(x, y netip.Addr) bool { return x.Compare(y) == 0 }) +var AddrPortComparer = cmp.Comparer(func(x, y netip.AddrPort) bool { + return x == y +}) + var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool { return x.String() == y.String() }) @@ -27,6 +33,8 @@ var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool { return x.String() == y.String() }) +var ViewSliceIPProtoComparer = cmp.Comparer(func(a, b views.Slice[ipproto.Proto]) bool { return views.SliceEqual(a, b) }) + var Comparers []cmp.Option = []cmp.Option{ - IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer, + IPComparer, PrefixComparer, AddrPortComparer, MkeyComparer, NkeyComparer, DkeyComparer, ViewSliceIPProtoComparer, } diff --git a/hscontrol/util/util.go b/hscontrol/util/util.go new file mode 100644 index 00000000..4d828d02 --- /dev/null +++ b/hscontrol/util/util.go @@ -0,0 +1,313 @@ +package util + +import ( + "cmp" + "errors" + "fmt" + "net/netip" + "net/url" + "os" + "regexp" + "strconv" + "strings" + "time" + + "tailscale.com/tailcfg" + "tailscale.com/util/cmpver" +) + +func TailscaleVersionNewerOrEqual(minimum, toCheck string) bool { + if cmpver.Compare(minimum, toCheck) <= 0 || + toCheck == "unstable" || + toCheck == "head" { + return true + } + + return false +} + +// ParseLoginURLFromCLILogin parses the output of the tailscale up command to extract the login URL. +// It returns an error if not exactly one URL is found. +func ParseLoginURLFromCLILogin(output string) (*url.URL, error) { + lines := strings.Split(output, "\n") + var urlStr string + + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "http://") || strings.HasPrefix(line, "https://") { + if urlStr != "" { + return nil, fmt.Errorf("multiple URLs found: %s and %s", urlStr, line) + } + urlStr = line + } + } + + if urlStr == "" { + return nil, errors.New("no URL found") + } + + loginURL, err := url.Parse(urlStr) + if err != nil { + return nil, fmt.Errorf("failed to parse URL: %w", err) + } + + return loginURL, nil +} + +type TraceroutePath struct { + // Hop is the current jump in the total traceroute. 
+ Hop int + + // Hostname is the resolved hostname or IP address identifying the jump + Hostname string + + // IP is the IP address of the jump + IP netip.Addr + + // Latencies is a list of the latencies for this jump + Latencies []time.Duration +} + +type Traceroute struct { + // Hostname is the resolved hostname or IP address identifying the target + Hostname string + + // IP is the IP address of the target + IP netip.Addr + + // Route is the path taken to reach the target if successful. The list is ordered by the path taken. + Route []TraceroutePath + + // Success indicates if the traceroute was successful. + Success bool + + // Err contains an error if the traceroute was not successful. + Err error +} + +// ParseTraceroute parses the output of the traceroute command and returns a Traceroute struct. +func ParseTraceroute(output string) (Traceroute, error) { + lines := strings.Split(strings.TrimSpace(output), "\n") + if len(lines) < 1 { + return Traceroute{}, errors.New("empty traceroute output") + } + + // Parse the header line - handle both 'traceroute' and 'tracert' (Windows) + headerRegex := regexp.MustCompile(`(?i)(?:traceroute|tracing route) to ([^ ]+) (?:\[([^\]]+)\]|\(([^)]+)\))`) + headerMatches := headerRegex.FindStringSubmatch(lines[0]) + if len(headerMatches) < 2 { + return Traceroute{}, fmt.Errorf("parsing traceroute header: %s", lines[0]) + } + + hostname := headerMatches[1] + // IP can be in either capture group 2 or 3 depending on format + ipStr := headerMatches[2] + if ipStr == "" { + ipStr = headerMatches[3] + } + ip, err := netip.ParseAddr(ipStr) + if err != nil { + return Traceroute{}, fmt.Errorf("parsing IP address %s: %w", ipStr, err) + } + + result := Traceroute{ + Hostname: hostname, + IP: ip, + Route: []TraceroutePath{}, + Success: false, + } + + // More flexible regex that handles various traceroute output formats + // Main pattern handles: "hostname (IP)", "hostname [IP]", "IP only", "* * *" + hopRegex := regexp.MustCompile(`^\s*(\d+)\s+(.*)$`) + // Patterns for parsing the hop details + hostIPRegex := regexp.MustCompile(`^([^ ]+) \(([^)]+)\)`) + hostIPBracketRegex := regexp.MustCompile(`^([^ ]+) \[([^\]]+)\]`) + // Pattern for latencies with flexible spacing and optional '<' + latencyRegex := regexp.MustCompile(`(<?\d+(?:\.\d+)?)\s*ms\b`) + + for i := 1; i < len(lines); i++ { + line := strings.TrimSpace(lines[i]) + if line == "" { + continue + } + + matches := hopRegex.FindStringSubmatch(line) + if len(matches) == 0 { + continue + } + + hop, err := strconv.Atoi(matches[1]) + if err != nil { + // Skip lines that don't start with a hop number + continue + } + + remainder := strings.TrimSpace(matches[2]) + var hopHostname string + var hopIP netip.Addr + var latencies []time.Duration + + // Check for Windows tracert format which has latencies before hostname + // Format: " 1 <1 ms <1 ms <1 ms router.local [192.168.1.1]" + latencyFirst := false + if strings.Contains(remainder, " ms ") && !strings.HasPrefix(remainder, "*") { + // Check if latencies appear before any hostname/IP + firstSpace := strings.Index(remainder, " ") + if firstSpace > 0 { + firstPart := remainder[:firstSpace] + if _, err := strconv.ParseFloat(strings.TrimPrefix(firstPart, "<"), 64); err == nil { + latencyFirst = true + } + } + } + + if latencyFirst { + // Windows format: extract latencies first + for { + latMatch := latencyRegex.FindStringSubmatchIndex(remainder) + if latMatch == nil || latMatch[0] > 0 { + break + } + // Extract and remove the latency from the beginning + latStr := 
strings.TrimPrefix(remainder[latMatch[2]:latMatch[3]], "<")
+				ms, err := strconv.ParseFloat(latStr, 64)
+				if err == nil {
+					// Round to nearest microsecond to avoid floating point precision issues
+					duration := time.Duration(ms * float64(time.Millisecond))
+					latencies = append(latencies, duration.Round(time.Microsecond))
+				}
+				remainder = strings.TrimSpace(remainder[latMatch[1]:])
+			}
+		}
+
+		// Now parse hostname/IP from remainder
+		if strings.HasPrefix(remainder, "*") {
+			// Timeout hop
+			hopHostname = "*"
+			// Skip any remaining asterisks
+			remainder = strings.TrimLeft(remainder, "* ")
+		} else if hostMatch := hostIPRegex.FindStringSubmatch(remainder); len(hostMatch) >= 3 {
+			// Format: hostname (IP)
+			hopHostname = hostMatch[1]
+			hopIP, _ = netip.ParseAddr(hostMatch[2])
+			remainder = strings.TrimSpace(remainder[len(hostMatch[0]):])
+		} else if hostMatch := hostIPBracketRegex.FindStringSubmatch(remainder); len(hostMatch) >= 3 {
+			// Format: hostname [IP] (Windows)
+			hopHostname = hostMatch[1]
+			hopIP, _ = netip.ParseAddr(hostMatch[2])
+			remainder = strings.TrimSpace(remainder[len(hostMatch[0]):])
+		} else {
+			// Try to parse as IP only or hostname only
+			parts := strings.Fields(remainder)
+			if len(parts) > 0 {
+				hopHostname = parts[0]
+				if ip, err := netip.ParseAddr(parts[0]); err == nil {
+					hopIP = ip
+				}
+				remainder = strings.TrimSpace(strings.Join(parts[1:], " "))
+			}
+		}
+
+		// Extract latencies from the remaining part (if not already done)
+		if !latencyFirst {
+			latencyMatches := latencyRegex.FindAllStringSubmatch(remainder, -1)
+			for _, match := range latencyMatches {
+				if len(match) > 1 {
+					// Remove '<' prefix if present (e.g., "<1 ms")
+					latStr := strings.TrimPrefix(match[1], "<")
+					ms, err := strconv.ParseFloat(latStr, 64)
+					if err == nil {
+						// Round to nearest microsecond to avoid floating point precision issues
+						duration := time.Duration(ms * float64(time.Millisecond))
+						latencies = append(latencies, duration.Round(time.Microsecond))
+					}
+				}
+			}
+		}
+
+		path := TraceroutePath{
+			Hop:       hop,
+			Hostname:  hopHostname,
+			IP:        hopIP,
+			Latencies: latencies,
+		}
+
+		result.Route = append(result.Route, path)
+
+		// Check if we've reached the target
+		if hopIP == ip {
+			result.Success = true
+		}
+	}
+
+	// If we didn't reach the target, it's unsuccessful
+	if !result.Success {
+		result.Err = errors.New("traceroute did not reach target")
+	}
+
+	return result, nil
+}
+
+func IsCI() bool {
+	if _, ok := os.LookupEnv("CI"); ok {
+		return true
+	}
+
+	if _, ok := os.LookupEnv("GITHUB_RUN_ID"); ok {
+		return true
+	}
+
+	return false
+}
+
+// EnsureHostname guarantees a valid hostname for node registration, providing
+// a sensible default if Hostinfo is nil or Hostname is empty. This prevents nil
+// pointer dereferences and ensures nodes always have a valid hostname.
+// This function never fails - it always returns a valid hostname.
+//
+// Strategy:
+// 1. If hostinfo is nil or the hostname is empty → derive a default from the machine or node key
+// 2. If a hostname is provided → lowercase it and validate it against DNS label rules (RFC 1123, max 63 characters)
+// 3. If validation fails → generate an invalid-<random> replacement
+//
+// Returns the guaranteed-valid hostname to use.
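+//
+// For example (mirroring the test cases below): "User2-Host" becomes
+// "user2-host"; a nil Hostinfo with machine key "mkey12345678" yields
+// "node-mkey1234"; no hostname and no keys yields "unknown-node"; and input
+// that fails validation (emoji, underscores, more than 63 characters) is
+// replaced with a random "invalid-<suffix>" name.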
+func EnsureHostname(hostinfo *tailcfg.Hostinfo, machineKey, nodeKey string) string { + if hostinfo == nil || hostinfo.Hostname == "" { + key := cmp.Or(machineKey, nodeKey) + if key == "" { + return "unknown-node" + } + keyPrefix := key + if len(key) > 8 { + keyPrefix = key[:8] + } + return fmt.Sprintf("node-%s", keyPrefix) + } + + lowercased := strings.ToLower(hostinfo.Hostname) + if err := ValidateHostname(lowercased); err == nil { + return lowercased + } + + return InvalidString() +} + +// GenerateRegistrationKey generates a vanity key for tracking web authentication +// registration flows in logs. This key is NOT stored in the database and does NOT use bcrypt - +// it's purely for observability and correlating log entries during the registration process. +func GenerateRegistrationKey() (string, error) { + const ( + registerKeyPrefix = "hskey-reg-" //nolint:gosec // This is a vanity key for logging, not a credential + registerKeyLength = 64 + ) + + randomPart, err := GenerateRandomStringURLSafe(registerKeyLength) + if err != nil { + return "", fmt.Errorf("generating registration key: %w", err) + } + + return registerKeyPrefix + randomPart, nil +} diff --git a/hscontrol/util/util_test.go b/hscontrol/util/util_test.go new file mode 100644 index 00000000..33f27b7a --- /dev/null +++ b/hscontrol/util/util_test.go @@ -0,0 +1,1386 @@ +package util + +import ( + "errors" + "net/netip" + "strings" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "tailscale.com/tailcfg" +) + +func TestTailscaleVersionNewerOrEqual(t *testing.T) { + type args struct { + minimum string + toCheck string + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "is-equal", + args: args{ + minimum: "1.56", + toCheck: "1.56", + }, + want: true, + }, + { + name: "is-newer-head", + args: args{ + minimum: "1.56", + toCheck: "head", + }, + want: true, + }, + { + name: "is-newer-unstable", + args: args{ + minimum: "1.56", + toCheck: "unstable", + }, + want: true, + }, + { + name: "is-newer-patch", + args: args{ + minimum: "1.56.1", + toCheck: "1.56.1", + }, + want: true, + }, + { + name: "is-older-patch-same-minor", + args: args{ + minimum: "1.56.1", + toCheck: "1.56.0", + }, + want: false, + }, + { + name: "is-older-unstable", + args: args{ + minimum: "1.56", + toCheck: "1.55", + }, + want: false, + }, + { + name: "is-older-one-stable", + args: args{ + minimum: "1.56", + toCheck: "1.54", + }, + want: false, + }, + { + name: "is-older-five-stable", + args: args{ + minimum: "1.56", + toCheck: "1.46", + }, + want: false, + }, + { + name: "is-older-patch", + args: args{ + minimum: "1.56", + toCheck: "1.48.1", + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := TailscaleVersionNewerOrEqual(tt.args.minimum, tt.args.toCheck); got != tt.want { + t.Errorf("TailscaleVersionNewerThan() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestParseLoginURLFromCLILogin(t *testing.T) { + tests := []struct { + name string + output string + wantURL string + wantErr string + }{ + { + name: "valid https URL", + output: ` +To authenticate, visit: + + https://headscale.example.com/register/3oYCOZYA2zZmGB4PQ7aHBaMi + +Success.`, + wantURL: "https://headscale.example.com/register/3oYCOZYA2zZmGB4PQ7aHBaMi", + wantErr: "", + }, + { + name: "valid http URL", + output: ` +To authenticate, visit: + + http://headscale.example.com/register/3oYCOZYA2zZmGB4PQ7aHBaMi + +Success.`, + wantURL: "http://headscale.example.com/register/3oYCOZYA2zZmGB4PQ7aHBaMi", + wantErr: 
"", + }, + { + name: "no URL", + output: ` +To authenticate, visit: + +Success.`, + wantURL: "", + wantErr: "no URL found", + }, + { + name: "multiple URLs", + output: ` +To authenticate, visit: + + https://headscale.example.com/register/3oYCOZYA2zZmGB4PQ7aHBaMi + +To authenticate, visit: + + http://headscale.example.com/register/dv1l2k5FackOYl-7-V3mSd_E + +Success.`, + wantURL: "", + wantErr: "multiple URLs found: https://headscale.example.com/register/3oYCOZYA2zZmGB4PQ7aHBaMi and http://headscale.example.com/register/dv1l2k5FackOYl-7-V3mSd_E", + }, + { + name: "invalid URL", + output: ` +To authenticate, visit: + + invalid-url + +Success.`, + wantURL: "", + wantErr: "no URL found", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotURL, err := ParseLoginURLFromCLILogin(tt.output) + if tt.wantErr != "" { + if err == nil || err.Error() != tt.wantErr { + t.Errorf("ParseLoginURLFromCLILogin() error = %v, wantErr %v", err, tt.wantErr) + } + } else { + if err != nil { + t.Errorf("ParseLoginURLFromCLILogin() error = %v, wantErr %v", err, tt.wantErr) + } + if gotURL.String() != tt.wantURL { + t.Errorf("ParseLoginURLFromCLILogin() = %v, want %v", gotURL, tt.wantURL) + } + } + }) + } +} + +func TestParseTraceroute(t *testing.T) { + tests := []struct { + name string + input string + want Traceroute + wantErr bool + }{ + { + name: "simple successful traceroute", + input: `traceroute to 172.24.0.3 (172.24.0.3), 30 hops max, 46 byte packets + 1 ts-head-hk0urr.headscale.net (100.64.0.1) 1.135 ms 0.922 ms 0.619 ms + 2 172.24.0.3 (172.24.0.3) 0.593 ms 0.549 ms 0.522 ms`, + want: Traceroute{ + Hostname: "172.24.0.3", + IP: netip.MustParseAddr("172.24.0.3"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "ts-head-hk0urr.headscale.net", + IP: netip.MustParseAddr("100.64.0.1"), + Latencies: []time.Duration{ + 1135 * time.Microsecond, + 922 * time.Microsecond, + 619 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "172.24.0.3", + IP: netip.MustParseAddr("172.24.0.3"), + Latencies: []time.Duration{ + 593 * time.Microsecond, + 549 * time.Microsecond, + 522 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "traceroute with timeouts", + input: `traceroute to 8.8.8.8 (8.8.8.8), 30 hops max, 60 byte packets + 1 router.local (192.168.1.1) 1.234 ms 1.123 ms 1.121 ms + 2 * * * + 3 isp-gateway.net (10.0.0.1) 15.678 ms 14.789 ms 15.432 ms + 4 8.8.8.8 (8.8.8.8) 20.123 ms 19.876 ms 20.345 ms`, + want: Traceroute{ + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "router.local", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 1234 * time.Microsecond, + 1123 * time.Microsecond, + 1121 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "*", + }, + { + Hop: 3, + Hostname: "isp-gateway.net", + IP: netip.MustParseAddr("10.0.0.1"), + Latencies: []time.Duration{ + 15678 * time.Microsecond, + 14789 * time.Microsecond, + 15432 * time.Microsecond, + }, + }, + { + Hop: 4, + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Latencies: []time.Duration{ + 20123 * time.Microsecond, + 19876 * time.Microsecond, + 20345 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "unsuccessful traceroute", + input: `traceroute to 10.0.0.99 (10.0.0.99), 5 hops max, 60 byte packets + 1 router.local (192.168.1.1) 1.234 ms 1.123 ms 1.121 ms + 2 * * * + 3 * * * + 4 * * * + 5 * * *`, + want: Traceroute{ 
+ Hostname: "10.0.0.99", + IP: netip.MustParseAddr("10.0.0.99"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "router.local", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 1234 * time.Microsecond, + 1123 * time.Microsecond, + 1121 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "*", + }, + { + Hop: 3, + Hostname: "*", + }, + { + Hop: 4, + Hostname: "*", + }, + { + Hop: 5, + Hostname: "*", + }, + }, + Success: false, + Err: errors.New("traceroute did not reach target"), + }, + wantErr: false, + }, + { + name: "empty input", + input: "", + want: Traceroute{}, + wantErr: true, + }, + { + name: "invalid header", + input: "not a valid traceroute output", + want: Traceroute{}, + wantErr: true, + }, + { + name: "windows tracert format", + input: `Tracing route to google.com [8.8.8.8] +over a maximum of 30 hops: + + 1 <1 ms <1 ms <1 ms router.local [192.168.1.1] + 2 5 ms 4 ms 5 ms 10.0.0.1 + 3 * * * Request timed out. + 4 20 ms 19 ms 21 ms 8.8.8.8`, + want: Traceroute{ + Hostname: "google.com", + IP: netip.MustParseAddr("8.8.8.8"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "router.local", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 1 * time.Millisecond, + 1 * time.Millisecond, + 1 * time.Millisecond, + }, + }, + { + Hop: 2, + Hostname: "10.0.0.1", + IP: netip.MustParseAddr("10.0.0.1"), + Latencies: []time.Duration{ + 5 * time.Millisecond, + 4 * time.Millisecond, + 5 * time.Millisecond, + }, + }, + { + Hop: 3, + Hostname: "*", + }, + { + Hop: 4, + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Latencies: []time.Duration{ + 20 * time.Millisecond, + 19 * time.Millisecond, + 21 * time.Millisecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "mixed latency formats", + input: `traceroute to 192.168.1.1 (192.168.1.1), 30 hops max, 60 byte packets + 1 gateway (192.168.1.1) 0.5 ms * 0.4 ms`, + want: Traceroute{ + Hostname: "192.168.1.1", + IP: netip.MustParseAddr("192.168.1.1"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "gateway", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 500 * time.Microsecond, + 400 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "only one latency value", + input: `traceroute to 10.0.0.1 (10.0.0.1), 30 hops max, 60 byte packets + 1 10.0.0.1 (10.0.0.1) 1.5 ms`, + want: Traceroute{ + Hostname: "10.0.0.1", + IP: netip.MustParseAddr("10.0.0.1"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "10.0.0.1", + IP: netip.MustParseAddr("10.0.0.1"), + Latencies: []time.Duration{ + 1500 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "backward compatibility - original format with 3 latencies", + input: `traceroute to 172.24.0.3 (172.24.0.3), 30 hops max, 46 byte packets + 1 ts-head-hk0urr.headscale.net (100.64.0.1) 1.135 ms 0.922 ms 0.619 ms + 2 172.24.0.3 (172.24.0.3) 0.593 ms 0.549 ms 0.522 ms`, + want: Traceroute{ + Hostname: "172.24.0.3", + IP: netip.MustParseAddr("172.24.0.3"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "ts-head-hk0urr.headscale.net", + IP: netip.MustParseAddr("100.64.0.1"), + Latencies: []time.Duration{ + 1135 * time.Microsecond, + 922 * time.Microsecond, + 619 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "172.24.0.3", + IP: netip.MustParseAddr("172.24.0.3"), + Latencies: []time.Duration{ + 593 * time.Microsecond, + 549 * time.Microsecond, + 522 * 
time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "two latencies only - common on packet loss", + input: `traceroute to 8.8.8.8 (8.8.8.8), 30 hops max, 60 byte packets + 1 gateway (192.168.1.1) 1.2 ms 1.1 ms`, + want: Traceroute{ + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "gateway", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 1200 * time.Microsecond, + 1100 * time.Microsecond, + }, + }, + }, + Success: false, + Err: errors.New("traceroute did not reach target"), + }, + wantErr: false, + }, + { + name: "hostname without parentheses - some traceroute versions", + input: `traceroute to 8.8.8.8 (8.8.8.8), 30 hops max, 60 byte packets + 1 192.168.1.1 1.2 ms 1.1 ms 1.0 ms + 2 8.8.8.8 20.1 ms 19.9 ms 20.2 ms`, + want: Traceroute{ + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "192.168.1.1", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 1200 * time.Microsecond, + 1100 * time.Microsecond, + 1000 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Latencies: []time.Duration{ + 20100 * time.Microsecond, + 19900 * time.Microsecond, + 20200 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "ipv6 traceroute", + input: `traceroute to 2001:4860:4860::8888 (2001:4860:4860::8888), 30 hops max, 80 byte packets + 1 2001:db8::1 (2001:db8::1) 1.123 ms 1.045 ms 0.987 ms + 2 2001:4860:4860::8888 (2001:4860:4860::8888) 15.234 ms 14.876 ms 15.123 ms`, + want: Traceroute{ + Hostname: "2001:4860:4860::8888", + IP: netip.MustParseAddr("2001:4860:4860::8888"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "2001:db8::1", + IP: netip.MustParseAddr("2001:db8::1"), + Latencies: []time.Duration{ + 1123 * time.Microsecond, + 1045 * time.Microsecond, + 987 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "2001:4860:4860::8888", + IP: netip.MustParseAddr("2001:4860:4860::8888"), + Latencies: []time.Duration{ + 15234 * time.Microsecond, + 14876 * time.Microsecond, + 15123 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "macos traceroute with extra spacing", + input: `traceroute to google.com (8.8.8.8), 64 hops max, 52 byte packets + 1 router.home (192.168.1.1) 2.345 ms 1.234 ms 1.567 ms + 2 * * * + 3 isp-gw.net (10.1.1.1) 15.234 ms 14.567 ms 15.890 ms + 4 google.com (8.8.8.8) 20.123 ms 19.456 ms 20.789 ms`, + want: Traceroute{ + Hostname: "google.com", + IP: netip.MustParseAddr("8.8.8.8"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "router.home", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 2345 * time.Microsecond, + 1234 * time.Microsecond, + 1567 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "*", + }, + { + Hop: 3, + Hostname: "isp-gw.net", + IP: netip.MustParseAddr("10.1.1.1"), + Latencies: []time.Duration{ + 15234 * time.Microsecond, + 14567 * time.Microsecond, + 15890 * time.Microsecond, + }, + }, + { + Hop: 4, + Hostname: "google.com", + IP: netip.MustParseAddr("8.8.8.8"), + Latencies: []time.Duration{ + 20123 * time.Microsecond, + 19456 * time.Microsecond, + 20789 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "busybox traceroute minimal format", + input: `traceroute to 10.0.0.1 (10.0.0.1), 30 hops 
max, 38 byte packets + 1 10.0.0.1 (10.0.0.1) 1.234 ms 1.123 ms 1.456 ms`, + want: Traceroute{ + Hostname: "10.0.0.1", + IP: netip.MustParseAddr("10.0.0.1"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "10.0.0.1", + IP: netip.MustParseAddr("10.0.0.1"), + Latencies: []time.Duration{ + 1234 * time.Microsecond, + 1123 * time.Microsecond, + 1456 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "linux traceroute with dns failure fallback to IP", + input: `traceroute to example.com (93.184.216.34), 30 hops max, 60 byte packets + 1 192.168.1.1 (192.168.1.1) 1.234 ms 1.123 ms 1.098 ms + 2 10.0.0.1 (10.0.0.1) 5.678 ms 5.432 ms 5.321 ms + 3 93.184.216.34 (93.184.216.34) 20.123 ms 19.876 ms 20.234 ms`, + want: Traceroute{ + Hostname: "example.com", + IP: netip.MustParseAddr("93.184.216.34"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "192.168.1.1", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 1234 * time.Microsecond, + 1123 * time.Microsecond, + 1098 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "10.0.0.1", + IP: netip.MustParseAddr("10.0.0.1"), + Latencies: []time.Duration{ + 5678 * time.Microsecond, + 5432 * time.Microsecond, + 5321 * time.Microsecond, + }, + }, + { + Hop: 3, + Hostname: "93.184.216.34", + IP: netip.MustParseAddr("93.184.216.34"), + Latencies: []time.Duration{ + 20123 * time.Microsecond, + 19876 * time.Microsecond, + 20234 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "alpine linux traceroute with ms variations", + input: `traceroute to 1.1.1.1 (1.1.1.1), 30 hops max, 46 byte packets + 1 gateway (192.168.0.1) 0.456ms 0.389ms 0.412ms + 2 1.1.1.1 (1.1.1.1) 8.234ms 7.987ms 8.123ms`, + want: Traceroute{ + Hostname: "1.1.1.1", + IP: netip.MustParseAddr("1.1.1.1"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "gateway", + IP: netip.MustParseAddr("192.168.0.1"), + Latencies: []time.Duration{ + 456 * time.Microsecond, + 389 * time.Microsecond, + 412 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "1.1.1.1", + IP: netip.MustParseAddr("1.1.1.1"), + Latencies: []time.Duration{ + 8234 * time.Microsecond, + 7987 * time.Microsecond, + 8123 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + { + name: "mixed asterisk and latency values", + input: `traceroute to 8.8.8.8 (8.8.8.8), 30 hops max, 60 byte packets + 1 gateway (192.168.1.1) * 1.234 ms 1.123 ms + 2 10.0.0.1 (10.0.0.1) 5.678 ms * 5.432 ms + 3 8.8.8.8 (8.8.8.8) 20.123 ms 19.876 ms *`, + want: Traceroute{ + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Route: []TraceroutePath{ + { + Hop: 1, + Hostname: "gateway", + IP: netip.MustParseAddr("192.168.1.1"), + Latencies: []time.Duration{ + 1234 * time.Microsecond, + 1123 * time.Microsecond, + }, + }, + { + Hop: 2, + Hostname: "10.0.0.1", + IP: netip.MustParseAddr("10.0.0.1"), + Latencies: []time.Duration{ + 5678 * time.Microsecond, + 5432 * time.Microsecond, + }, + }, + { + Hop: 3, + Hostname: "8.8.8.8", + IP: netip.MustParseAddr("8.8.8.8"), + Latencies: []time.Duration{ + 20123 * time.Microsecond, + 19876 * time.Microsecond, + }, + }, + }, + Success: true, + Err: nil, + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := ParseTraceroute(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("ParseTraceroute() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if tt.wantErr { + return + } 
+ + // Special handling for error field since it can't be directly compared with cmp.Diff + gotErr := got.Err + wantErr := tt.want.Err + got.Err = nil + tt.want.Err = nil + + if diff := cmp.Diff(tt.want, got, IPComparer); diff != "" { + t.Errorf("ParseTraceroute() mismatch (-want +got):\n%s", diff) + } + + // Now check error field separately + if (gotErr == nil) != (wantErr == nil) { + t.Errorf("Error field: got %v, want %v", gotErr, wantErr) + } else if gotErr != nil && wantErr != nil && gotErr.Error() != wantErr.Error() { + t.Errorf("Error message: got %q, want %q", gotErr.Error(), wantErr.Error()) + } + }) + } +} + +func TestEnsureHostname(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + hostinfo *tailcfg.Hostinfo + machineKey string + nodeKey string + want string + }{ + { + name: "valid_hostname", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "test-node", + }, + { + name: "nil_hostinfo_with_machine_key", + hostinfo: nil, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "node-mkey1234", + }, + { + name: "nil_hostinfo_with_node_key_only", + hostinfo: nil, + machineKey: "", + nodeKey: "nkey12345678", + want: "node-nkey1234", + }, + { + name: "nil_hostinfo_no_keys", + hostinfo: nil, + machineKey: "", + nodeKey: "", + want: "unknown-node", + }, + { + name: "empty_hostname_with_machine_key", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "node-mkey1234", + }, + { + name: "empty_hostname_with_node_key_only", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "", + nodeKey: "nkey12345678", + want: "node-nkey1234", + }, + { + name: "empty_hostname_no_keys", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "", + nodeKey: "", + want: "unknown-node", + }, + { + name: "hostname_exactly_63_chars", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "123456789012345678901234567890123456789012345678901234567890123", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "123456789012345678901234567890123456789012345678901234567890123", + }, + { + name: "hostname_64_chars_truncated", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "1234567890123456789012345678901234567890123456789012345678901234", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "hostname_very_long_truncated", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits-of-63-characters-and-should-be-truncated", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "hostname_with_special_chars", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "node-with-special!@#$%", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "hostname_with_unicode", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "node-ñoño-测试", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "short_machine_key", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "short", + nodeKey: "nkey12345678", + want: "node-short", + }, + { + name: "short_node_key", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "", + nodeKey: "short", + want: "node-short", + }, + { + name: "hostname_with_emoji_replaced", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "hostname-with-💩", + }, + machineKey: "mkey12345678", + nodeKey: 
"nkey12345678", + want: "invalid-", + }, + { + name: "hostname_only_emoji_replaced", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "🚀", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "hostname_with_multiple_emojis_replaced", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "node-🎉-🚀-test", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "uppercase_to_lowercase", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "User2-Host", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "user2-host", + }, + { + name: "underscore_removed", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test_node", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "at_sign_invalid", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "Test@Host", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "chinese_chars_with_dash_invalid", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "server-北京-01", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "chinese_only_invalid", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "我的电脑", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "emoji_with_text_invalid", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "laptop-🚀", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "mixed_chinese_emoji_invalid", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "测试💻机器", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "only_emojis_invalid", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "🎉🎊", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "only_at_signs_invalid", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "@@@", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "starts_with_dash_invalid", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "-test", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "ends_with_dash_invalid", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + { + name: "very_long_hostname_truncated", + hostinfo: &tailcfg.Hostinfo{ + Hostname: strings.Repeat("t", 70), + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + want: "invalid-", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + got := EnsureHostname(tt.hostinfo, tt.machineKey, tt.nodeKey) + // For invalid hostnames, we just check the prefix since the random part varies + if strings.HasPrefix(tt.want, "invalid-") { + if !strings.HasPrefix(got, "invalid-") { + t.Errorf("EnsureHostname() = %v, want prefix %v", got, tt.want) + } + } else if got != tt.want { + t.Errorf("EnsureHostname() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEnsureHostnameWithHostinfo(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + hostinfo *tailcfg.Hostinfo + machineKey string + nodeKey string + wantHostname string + checkHostinfo func(*testing.T, *tailcfg.Hostinfo) + }{ + { + name: "valid_hostinfo_unchanged", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node", + OS: "linux", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "test-node", + checkHostinfo: 
func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "test-node" { + t.Errorf("hostname = %v, want test-node", hi.Hostname) + } + if hi.OS != "linux" { + t.Errorf("OS = %v, want linux", hi.OS) + } + }, + }, + { + name: "nil_hostinfo_creates_default", + hostinfo: nil, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "node-mkey1234", + }, + { + name: "empty_hostname_updated", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + OS: "darwin", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "node-mkey1234", + }, + { + name: "long_hostname_rejected", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test-node-with-very-long-hostname-that-exceeds-dns-label-limits-of-63-characters", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "invalid-", + }, + { + name: "nil_hostinfo_node_key_only", + hostinfo: nil, + machineKey: "", + nodeKey: "nkey12345678", + wantHostname: "node-nkey1234", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "node-nkey1234" { + t.Errorf("hostname = %v, want node-nkey1234", hi.Hostname) + } + }, + }, + { + name: "nil_hostinfo_no_keys", + hostinfo: nil, + machineKey: "", + nodeKey: "", + wantHostname: "unknown-node", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "unknown-node" { + t.Errorf("hostname = %v, want unknown-node", hi.Hostname) + } + }, + }, + { + name: "empty_hostname_no_keys", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "", + }, + machineKey: "", + nodeKey: "", + wantHostname: "unknown-node", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "unknown-node" { + t.Errorf("hostname = %v, want unknown-node", hi.Hostname) + } + }, + }, + { + name: "preserves_other_fields", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "test", + OS: "windows", + OSVersion: "10.0.19044", + DeviceModel: "test-device", + BackendLogID: "log123", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "test", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if hi.Hostname != "test" { + t.Errorf("hostname = %v, want test", hi.Hostname) + } + if hi.OS != "windows" { + t.Errorf("OS = %v, want windows", hi.OS) + } + if hi.OSVersion != "10.0.19044" { + t.Errorf("OSVersion = %v, want 10.0.19044", hi.OSVersion) + } + if hi.DeviceModel != "test-device" { + t.Errorf("DeviceModel = %v, want test-device", hi.DeviceModel) + } + if hi.BackendLogID != "log123" { + t.Errorf("BackendLogID = %v, want log123", hi.BackendLogID) + } + }, + }, + { + name: "exactly_63_chars_unchanged", + hostinfo: &tailcfg.Hostinfo{ + Hostname: "123456789012345678901234567890123456789012345678901234567890123", + }, + machineKey: "mkey12345678", + nodeKey: "nkey12345678", + wantHostname: "123456789012345678901234567890123456789012345678901234567890123", + checkHostinfo: func(t *testing.T, hi *tailcfg.Hostinfo) { + if hi == nil { + t.Error("hostinfo should not be nil") + } + if len(hi.Hostname) != 63 { + t.Errorf("hostname length = %v, want 63", len(hi.Hostname)) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + gotHostname := EnsureHostname(tt.hostinfo, tt.machineKey, 
tt.nodeKey) + // For invalid hostnames, we just check the prefix since the random part varies + if strings.HasPrefix(tt.wantHostname, "invalid-") { + if !strings.HasPrefix(gotHostname, "invalid-") { + t.Errorf("EnsureHostname() = %v, want prefix %v", gotHostname, tt.wantHostname) + } + } else if gotHostname != tt.wantHostname { + t.Errorf("EnsureHostname() hostname = %v, want %v", gotHostname, tt.wantHostname) + } + }) + } +} + +func TestEnsureHostname_DNSLabelLimit(t *testing.T) { + t.Parallel() + + testCases := []string{ + "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa", + "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb", + "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc", + "dddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd", + } + + for i, hostname := range testCases { + t.Run(cmp.Diff("", ""), func(t *testing.T) { + hostinfo := &tailcfg.Hostinfo{Hostname: hostname} + result := EnsureHostname(hostinfo, "mkey", "nkey") + if len(result) > 63 { + t.Errorf("test case %d: hostname length = %d, want <= 63", i, len(result)) + } + }) + } +} + +func TestEnsureHostname_Idempotent(t *testing.T) { + t.Parallel() + + originalHostinfo := &tailcfg.Hostinfo{ + Hostname: "test-node", + OS: "linux", + } + + hostname1 := EnsureHostname(originalHostinfo, "mkey", "nkey") + hostname2 := EnsureHostname(originalHostinfo, "mkey", "nkey") + + if hostname1 != hostname2 { + t.Errorf("hostnames not equal: %v != %v", hostname1, hostname2) + } +} + +func TestGenerateRegistrationKey(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + test func(*testing.T) + }{ + { + name: "generates_key_with_correct_prefix", + test: func(t *testing.T) { + t.Helper() + + key, err := GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + if !strings.HasPrefix(key, "hskey-reg-") { + t.Errorf("key does not have expected prefix: %s", key) + } + }, + }, + { + name: "generates_key_with_correct_length", + test: func(t *testing.T) { + t.Helper() + + key, err := 
GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + // Expected format: hskey-reg-{64-char-random} + // Total length: 10 (prefix) + 64 (random) = 74 + if len(key) != 74 { + t.Errorf("key length = %d, want 74", len(key)) + } + }, + }, + { + name: "generates_unique_keys", + test: func(t *testing.T) { + t.Helper() + + key1, err := GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + key2, err := GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + if key1 == key2 { + t.Error("generated keys should be unique") + } + }, + }, + { + name: "key_contains_only_valid_chars", + test: func(t *testing.T) { + t.Helper() + + key, err := GenerateRegistrationKey() + if err != nil { + t.Errorf("GenerateRegistrationKey() error = %v", err) + } + + // Remove prefix + _, randomPart, found := strings.Cut(key, "hskey-reg-") + if !found { + t.Error("key does not contain expected prefix") + } + + // Verify base64 URL-safe characters (A-Za-z0-9_-) + for _, ch := range randomPart { + if (ch < 'A' || ch > 'Z') && + (ch < 'a' || ch > 'z') && + (ch < '0' || ch > '9') && + ch != '_' && ch != '-' { + t.Errorf("key contains invalid character: %c", ch) + } + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + tt.test(t) + }) + } +} diff --git a/integration/README.md b/integration/README.md index e5676a44..56247c52 100644 --- a/integration/README.md +++ b/integration/README.md @@ -11,10 +11,10 @@ Tests are located in files ending with `_test.go` and the framework are located ## Running integration tests locally -The easiest way to run tests locally is to use `[act](INSERT LINK)`, a local GitHub Actions runner: +The easiest way to run tests locally is to use [act](https://github.com/nektos/act), a local GitHub Actions runner: ``` -act pull_request -W .github/workflows/test-integration-v2-TestPingAllByIP.yaml +act pull_request -W .github/workflows/test-integration.yaml ``` Alternatively, the `docker run` command in each GitHub workflow file can be used. 
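The `integration/acl_test.go` changes that follow migrate ACL policies from the string-based `policy.ACLPolicy` format to the typed `policyv2.Policy` format: rules are now built from typed aliases and `tailcfg` port ranges instead of `"user2:80"`-style destination strings. A minimal sketch of the new shape, assuming the package-internal test helpers used in the diff (`usernamep`, `wildcard`, `aliasWithPorts`) keep the signatures implied by their call sites; this is illustrative only and is not runnable outside the `integration` test package:

```go
// Sketch only: usernamep, wildcard and aliasWithPorts are test helpers
// defined elsewhere in the integration package (assumed from their usage
// in the diff below); they are not exported headscale API.
policy := policyv2.Policy{
	ACLs: []policyv2.ACL{
		{
			// "user1 may reach user2 on port 80", expressed with typed
			// aliases instead of the old "user2:80" destination string.
			Action:  "accept",
			Sources: []policyv2.Alias{usernamep("user1@")},
			Destinations: []policyv2.AliasWithPorts{
				aliasWithPorts(usernamep("user2@"), tailcfg.PortRange{First: 80, Last: 80}),
			},
		},
		{
			// Wildcard source and destination on any port.
			Action:  "accept",
			Sources: []policyv2.Alias{wildcard()},
			Destinations: []policyv2.AliasWithPorts{
				aliasWithPorts(wildcard(), tailcfg.PortRangeAny),
			},
		},
	},
}
_ = policy
```

Two related conventions recur throughout the hunks below: user references carry the login-name form (`user1@`, matching the `user1@test.no` keys in the expected-peer maps), and one-shot assertions (`assertNoErr`, bare `assert.*`) are replaced by `require.NoError` for setup plus `assert.EventuallyWithT` retry loops for network-dependent checks.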
diff --git a/integration/acl_test.go b/integration/acl_test.go index 9a415ab2..c746f900 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -3,82 +3,93 @@ package integration import ( "fmt" "net/netip" + "strconv" "strings" "testing" + "time" - "github.com/juanfont/headscale/hscontrol/policy" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/tsic" + "github.com/ory/dockertest/v3" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) -var veryLargeDestination = []string{ - "0.0.0.0/5:*", - "8.0.0.0/7:*", - "11.0.0.0/8:*", - "12.0.0.0/6:*", - "16.0.0.0/4:*", - "32.0.0.0/3:*", - "64.0.0.0/2:*", - "128.0.0.0/3:*", - "160.0.0.0/5:*", - "168.0.0.0/6:*", - "172.0.0.0/12:*", - "172.32.0.0/11:*", - "172.64.0.0/10:*", - "172.128.0.0/9:*", - "173.0.0.0/8:*", - "174.0.0.0/7:*", - "176.0.0.0/4:*", - "192.0.0.0/9:*", - "192.128.0.0/11:*", - "192.160.0.0/13:*", - "192.169.0.0/16:*", - "192.170.0.0/15:*", - "192.172.0.0/14:*", - "192.176.0.0/12:*", - "192.192.0.0/10:*", - "193.0.0.0/8:*", - "194.0.0.0/7:*", - "196.0.0.0/6:*", - "200.0.0.0/5:*", - "208.0.0.0/4:*", +var veryLargeDestination = []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("0.0.0.0/5"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("8.0.0.0/7"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("11.0.0.0/8"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("12.0.0.0/6"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("16.0.0.0/4"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("32.0.0.0/3"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("64.0.0.0/2"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("128.0.0.0/3"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("160.0.0.0/5"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("168.0.0.0/6"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("172.0.0.0/12"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("172.32.0.0/11"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("172.64.0.0/10"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("172.128.0.0/9"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("173.0.0.0/8"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("174.0.0.0/7"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("176.0.0.0/4"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.0.0.0/9"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.128.0.0/11"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.160.0.0/13"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.169.0.0/16"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.170.0.0/15"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.172.0.0/14"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.176.0.0/12"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("192.192.0.0/10"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("193.0.0.0/8"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("194.0.0.0/7"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("196.0.0.0/6"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("200.0.0.0/5"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("208.0.0.0/4"), tailcfg.PortRangeAny), } func aclScenario( t *testing.T, - policy 
*policy.ACLPolicy, + policy *policyv2.Policy, clientsPerUser int, ) *Scenario { t.Helper() - scenario, err := NewScenario() - assertNoErr(t, err) - spec := map[string]int{ - "user1": clientsPerUser, - "user2": clientsPerUser, + spec := ScenarioSpec{ + NodesPerUser: clientsPerUser, + Users: []string{"user1", "user2"}, } - err = scenario.CreateHeadscaleEnv(spec, + scenario, err := NewScenario(spec) + require.NoError(t, err) + + err = scenario.CreateHeadscaleEnv( []tsic.Option{ // Alpine containers dont have ip6tables set up, which causes // tailscaled to stop configuring the wgengine, causing it // to not configure DNS. tsic.WithNetfilter("off"), - tsic.WithDockerEntrypoint([]string{ - "/bin/sh", - "-c", - "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", - }), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), tsic.WithDockerWorkdir("/"), }, hsic.WithACLPolicy(policy), hsic.WithTestName("acl"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), ) - assertNoErr(t, err) + require.NoError(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + require.NoError(t, err) return scenario } @@ -91,95 +102,105 @@ func aclScenario( func TestACLHostsInNetMapTable(t *testing.T) { IntegrationSkip(t) + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{"user1", "user2"}, + } + // NOTE: All want cases currently checks the // total count of expected peers, this would // typically be the client count of the users // they can access minus one (them self). tests := map[string]struct { - users map[string]int - policy policy.ACLPolicy + users ScenarioSpec + policy policyv2.Policy want map[string]int }{ // Test that when we have no ACL, each client netmap has // the amount of peers of the total amount of clients "base-acls": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, - policy: policy.ACLPolicy{ - ACLs: []policy.ACL{ + users: spec, + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, }, want: map[string]int{ - "user1": 3, // ns1 + ns2 - "user2": 3, // ns2 + ns1 + "user1@test.no": 3, // ns1 + ns2 + "user2@test.no": 3, // ns2 + ns1 }, }, // Test that when we have two users, which cannot see - // eachother, each node has only the number of pairs from + // each other, each node has only the number of pairs from // their own user. 
"two-isolated-users": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, - policy: policy.ACLPolicy{ - ACLs: []policy.ACL{ + users: spec, + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"user1:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRangeAny), + }, }, { - Action: "accept", - Sources: []string{"user2"}, - Destinations: []string{"user2:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, }, }, }, want: map[string]int{ - "user1": 1, - "user2": 1, + "user1@test.no": 1, + "user2@test.no": 1, }, }, // Test that when we have two users, with ACLs and they // are restricted to a single port, nodes are still present // in the netmap. "two-restricted-present-in-netmap": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, - policy: policy.ACLPolicy{ - ACLs: []policy.ACL{ + users: spec, + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"user1:22"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRange{First: 22, Last: 22}), + }, }, { - Action: "accept", - Sources: []string{"user2"}, - Destinations: []string{"user2:22"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRange{First: 22, Last: 22}), + }, }, { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"user2:22"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRange{First: 22, Last: 22}), + }, }, { - Action: "accept", - Sources: []string{"user2"}, - Destinations: []string{"user1:22"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRange{First: 22, Last: 22}), + }, }, }, }, want: map[string]int{ - "user1": 3, - "user2": 3, + "user1@test.no": 3, + "user2@test.no": 3, }, }, // Test that when we have two users, that are isolated, @@ -187,108 +208,125 @@ func TestACLHostsInNetMapTable(t *testing.T) { // of peers. This will still result in all the peers as we // need them present on the other side for the "return path". 
"two-ns-one-isolated": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, - policy: policy.ACLPolicy{ - ACLs: []policy.ACL{ + users: spec, + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"user1:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRangeAny), + }, }, { - Action: "accept", - Sources: []string{"user2"}, - Destinations: []string{"user2:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, }, { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"user2:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, }, }, }, want: map[string]int{ - "user1": 3, // ns1 + ns2 - "user2": 3, // ns1 + ns2 (return path) + "user1@test.no": 3, // ns1 + ns2 + "user2@test.no": 3, // ns1 + ns2 (return path) }, }, "very-large-destination-prefix-1372": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, - policy: policy.ACLPolicy{ - ACLs: []policy.ACL{ + users: spec, + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1"}, - Destinations: append([]string{"user1:*"}, veryLargeDestination...), + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: append( + []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRangeAny), + }, + veryLargeDestination..., + ), }, { - Action: "accept", - Sources: []string{"user2"}, - Destinations: append([]string{"user2:*"}, veryLargeDestination...), + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: append( + []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + veryLargeDestination..., + ), }, { - Action: "accept", - Sources: []string{"user1"}, - Destinations: append([]string{"user2:*"}, veryLargeDestination...), + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: append( + []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + veryLargeDestination..., + ), }, }, }, want: map[string]int{ - "user1": 3, // ns1 + ns2 - "user2": 3, // ns1 + ns2 (return path) + "user1@test.no": 3, // ns1 + ns2 + "user2@test.no": 3, // ns1 + ns2 (return path) }, }, "ipv6-acls-1470": { - users: map[string]int{ - "user1": 2, - "user2": 2, - }, - policy: policy.ACLPolicy{ - ACLs: []policy.ACL{ + users: spec, + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"0.0.0.0/0:*", "::/0:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("0.0.0.0/0"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("::/0"), tailcfg.PortRangeAny), + }, }, }, }, want: map[string]int{ - "user1": 3, // ns1 + ns2 - "user2": 3, // ns2 + ns1 + "user1@test.no": 3, // ns1 + ns2 + "user2@test.no": 3, // ns2 + ns1 }, }, } for name, testCase := range tests { t.Run(name, func(t *testing.T) { - scenario, err := NewScenario() - assertNoErr(t, err) + caseSpec := testCase.users + scenario, err := NewScenario(caseSpec) + require.NoError(t, err) - spec := 
testCase.users - - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{}, hsic.WithACLPolicy(&testCase.policy), ) - assertNoErr(t, err) - defer scenario.Shutdown() + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErr(t, err) + require.NoError(t, err) - err = scenario.WaitForTailscaleSyncWithPeerCount(testCase.want["user1"]) - assertNoErrSync(t, err) + err = scenario.WaitForTailscaleSyncWithPeerCount(testCase.want["user1@test.no"], integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval()) + require.NoError(t, err) for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) - user := status.User[status.Self.UserID].LoginName + user := status.User[status.Self.UserID].LoginName - assert.Equal(t, (testCase.want[user]), len(status.Peer)) + assert.Len(c, status.Peer, (testCase.want[user])) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer visibility") } }) } @@ -303,37 +341,41 @@ func TestACLAllowUser80Dst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policy.ACLPolicy{ - ACLs: []policy.ACL{ + &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"user2:80"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRange{First: 80, Last: 80}), + }, }, }, }, 1, ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErr(t, err) + require.NoError(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErr(t, err) + require.NoError(t, err) // Test that user1 can visit all user2 for _, client := range user1Clients { for _, peer := range user2Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Len(t, result, 13) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 20*time.Second, 500*time.Millisecond, "Verifying user1 can reach user2") } } @@ -341,14 +383,16 @@ func TestACLAllowUser80Dst(t *testing.T) { for _, client := range user2Clients { for _, peer := range user1Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Empty(t, result) - assert.Error(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.Error(c, err) + assert.Empty(c, result) + }, 20*time.Second, 500*time.Millisecond, "Verifying user2 cannot reach user1") } } } @@ -357,27 +401,29 @@ func TestACLDenyAllPort80(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policy.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-acl-test": {"user1", "user2"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-acl-test"): []policyv2.Username{policyv2.Username("user1@"), 
policyv2.Username("user2@")}, }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"group:integration-acl-test"}, - Destinations: []string{"*:22"}, + Action: "accept", + Sources: []policyv2.Alias{groupp("group:integration-acl-test")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRange{First: 22, Last: 22}), + }, }, }, }, 4, ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErr(t, err) + require.NoError(t, err) allHostnames, err := scenario.ListTailscaleClientsFQDNs() - assertNoErr(t, err) + require.NoError(t, err) for _, client := range allClients { for _, hostname := range allHostnames { @@ -390,9 +436,11 @@ func TestACLDenyAllPort80(t *testing.T) { url := fmt.Sprintf("http://%s/etc/hostname", hostname) t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Empty(t, result) - assert.Error(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.Error(c, err) + assert.Empty(c, result) + }, 20*time.Second, 500*time.Millisecond, "Verifying all traffic is denied") } } } @@ -404,37 +452,41 @@ func TestACLAllowUserDst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policy.ACLPolicy{ - ACLs: []policy.ACL{ + &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"user2:*"}, + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, }, }, }, 2, ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErr(t, err) + require.NoError(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErr(t, err) + require.NoError(t, err) // Test that user1 can visit all user2 for _, client := range user1Clients { for _, peer := range user2Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Len(t, result, 13) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 20*time.Second, 500*time.Millisecond, "Verifying user1 can reach user2") } } @@ -442,14 +494,16 @@ func TestACLAllowUserDst(t *testing.T) { for _, client := range user2Clients { for _, peer := range user1Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Empty(t, result) - assert.Error(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.Error(c, err) + assert.Empty(c, result) + }, 20*time.Second, 500*time.Millisecond, "Verifying user2 cannot reach user1") } } } @@ -460,37 +514,41 @@ func TestACLAllowStarDst(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policy.ACLPolicy{ - ACLs: []policy.ACL{ + &policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"user1"}, - Destinations: []string{"*:*"}, + Action: "accept", + Sources: 
[]policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, }, 2, ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErr(t, err) + require.NoError(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErr(t, err) + require.NoError(t, err) // Test that user1 can visit all user2 for _, client := range user1Clients { for _, peer := range user2Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Len(t, result, 13) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 20*time.Second, 500*time.Millisecond, "Verifying user1 can reach user2") } } @@ -498,14 +556,16 @@ func TestACLAllowStarDst(t *testing.T) { for _, client := range user2Clients { for _, peer := range user1Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Empty(t, result) - assert.Error(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.Error(c, err) + assert.Empty(c, result) + }, 20*time.Second, 500*time.Millisecond, "Verifying user2 cannot reach user1") } } } @@ -517,56 +577,64 @@ func TestACLNamedHostsCanReachBySubnet(t *testing.T) { IntegrationSkip(t) scenario := aclScenario(t, - &policy.ACLPolicy{ - Hosts: policy.Hosts{ - "all": netip.MustParsePrefix("100.64.0.0/24"), + &policyv2.Policy{ + Hosts: policyv2.Hosts{ + "all": policyv2.Prefix(netip.MustParsePrefix("100.64.0.0/24")), }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ // Everyone can curl test3 { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"all:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("all"), tailcfg.PortRangeAny), + }, }, }, }, 3, ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) user1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErr(t, err) + require.NoError(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErr(t, err) + require.NoError(t, err) // Test that user1 can visit all user2 for _, client := range user1Clients { for _, peer := range user2Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Len(t, result, 13) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 20*time.Second, 500*time.Millisecond, "Verifying user1 can reach user2") } } // Test that user2 can visit all user1 + // Test that user2 can visit all user1, note that this + // is _not_ symmetric. 
for _, client := range user2Clients { for _, peer := range user1Clients { fqdn, err := peer.FQDN() - assertNoErr(t, err) + require.NoError(t, err) url := fmt.Sprintf("http://%s/etc/hostname", fqdn) t.Logf("url from %s to %s", client.Hostname(), url) - result, err := client.Curl(url) - assert.Len(t, result, 13) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 20*time.Second, 500*time.Millisecond, "Verifying user2 can reach user1") } } } @@ -614,50 +682,58 @@ func TestACLNamedHostsCanReach(t *testing.T) { IntegrationSkip(t) tests := map[string]struct { - policy policy.ACLPolicy + policy policyv2.Policy }{ "ipv4": { - policy: policy.ACLPolicy{ - Hosts: policy.Hosts{ - "test1": netip.MustParsePrefix("100.64.0.1/32"), - "test2": netip.MustParsePrefix("100.64.0.2/32"), - "test3": netip.MustParsePrefix("100.64.0.3/32"), + policy: policyv2.Policy{ + Hosts: policyv2.Hosts{ + "test1": policyv2.Prefix(netip.MustParsePrefix("100.64.0.1/32")), + "test2": policyv2.Prefix(netip.MustParsePrefix("100.64.0.2/32")), + "test3": policyv2.Prefix(netip.MustParsePrefix("100.64.0.3/32")), }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ // Everyone can curl test3 { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"test3:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test3"), tailcfg.PortRangeAny), + }, }, // test1 can curl test2 { - Action: "accept", - Sources: []string{"test1"}, - Destinations: []string{"test2:*"}, + Action: "accept", + Sources: []policyv2.Alias{hostp("test1")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test2"), tailcfg.PortRangeAny), + }, }, }, }, }, "ipv6": { - policy: policy.ACLPolicy{ - Hosts: policy.Hosts{ - "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), - "test3": netip.MustParsePrefix("fd7a:115c:a1e0::3/128"), + policy: policyv2.Policy{ + Hosts: policyv2.Hosts{ + "test1": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::1/128")), + "test2": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::2/128")), + "test3": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::3/128")), }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ // Everyone can curl test3 { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"test3:*"}, + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test3"), tailcfg.PortRangeAny), + }, }, // test1 can curl test2 { - Action: "accept", - Sources: []string{"test1"}, - Destinations: []string{"test2:*"}, + Action: "accept", + Sources: []policyv2.Alias{hostp("test1")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test2"), tailcfg.PortRangeAny), + }, }, }, }, @@ -670,17 +746,18 @@ func TestACLNamedHostsCanReach(t *testing.T) { &testCase.policy, 2, ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) // Since user/users dont matter here, we basically expect that some clients // will be assigned these ips and that we can pick them up for our own use. 
test1ip4 := netip.MustParseAddr("100.64.0.1") test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1") test1, err := scenario.FindTailscaleClientByIP(test1ip6) - assertNoErr(t, err) + require.NoError(t, err) test1fqdn, err := test1.FQDN() - assertNoErr(t, err) + require.NoError(t, err) + test1ip4URL := fmt.Sprintf("http://%s/etc/hostname", test1ip4.String()) test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String()) test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn) @@ -688,10 +765,11 @@ func TestACLNamedHostsCanReach(t *testing.T) { test2ip4 := netip.MustParseAddr("100.64.0.2") test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2") test2, err := scenario.FindTailscaleClientByIP(test2ip6) - assertNoErr(t, err) + require.NoError(t, err) test2fqdn, err := test2.FQDN() - assertNoErr(t, err) + require.NoError(t, err) + test2ip4URL := fmt.Sprintf("http://%s/etc/hostname", test2ip4.String()) test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String()) test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) @@ -699,154 +777,173 @@ func TestACLNamedHostsCanReach(t *testing.T) { test3ip4 := netip.MustParseAddr("100.64.0.3") test3ip6 := netip.MustParseAddr("fd7a:115c:a1e0::3") test3, err := scenario.FindTailscaleClientByIP(test3ip6) - assertNoErr(t, err) + require.NoError(t, err) test3fqdn, err := test3.FQDN() - assertNoErr(t, err) + require.NoError(t, err) + test3ip4URL := fmt.Sprintf("http://%s/etc/hostname", test3ip4.String()) test3ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test3ip6.String()) test3fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test3fqdn) // test1 can query test3 - result, err := test1.Curl(test3ip4URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3ip4URL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test3ip4URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3ip4URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv4") - result, err = test1.Curl(test3ip6URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3ip6URL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test3ip6URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3ip6URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via IPv6") - result, err = test1.Curl(test3fqdnURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3fqdnURL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test3fqdnURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3fqdnURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test3 via FQDN") // test2 can query test3 - result, err = test2.Curl(test3ip4URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL 
%s, expected hostname of 13 chars, got %s", - test3ip4URL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test3ip4URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3ip4URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv4") - result, err = test2.Curl(test3ip6URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3ip6URL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test3ip6URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3ip6URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via IPv6") - result, err = test2.Curl(test3fqdnURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", - test3fqdnURL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test3fqdnURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test3 with URL %s, expected hostname of 13 chars, got %s", + test3fqdnURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test2 should reach test3 via FQDN") // test3 cannot query test1 - result, err = test3.Curl(test1ip4URL) + result, err := test3.Curl(test1ip4URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test3.Curl(test1ip6URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test3.Curl(test1fqdnURL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) // test3 cannot query test2 result, err = test3.Curl(test2ip4URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test3.Curl(test2ip6URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test3.Curl(test2fqdnURL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) // test1 can query test2 - result, err = test1.Curl(test2ip4URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", - test2ip4URL, - result, - ) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2ip4URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", + test2ip4URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4") - assertNoErr(t, err) - result, err = test1.Curl(test2ip6URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", - test2ip6URL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2ip6URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", + test2ip6URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 
via IPv6") - result, err = test1.Curl(test2fqdnURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", - test2fqdnURL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2fqdnURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test2 with URL %s, expected hostname of 13 chars, got %s", + test2fqdnURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via FQDN") // test2 cannot query test1 result, err = test2.Curl(test1ip4URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test2.Curl(test1ip6URL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) result, err = test2.Curl(test1fqdnURL) assert.Empty(t, result) - assert.Error(t, err) + require.Error(t, err) }) } } @@ -863,71 +960,81 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { IntegrationSkip(t) tests := map[string]struct { - policy policy.ACLPolicy + policy policyv2.Policy }{ "ipv4": { - policy: policy.ACLPolicy{ - ACLs: []policy.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"100.64.0.1"}, - Destinations: []string{"100.64.0.2:*"}, + Action: "accept", + Sources: []policyv2.Alias{prefixp("100.64.0.1/32")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("100.64.0.2/32"), tailcfg.PortRangeAny), + }, }, }, }, }, "ipv6": { - policy: policy.ACLPolicy{ - ACLs: []policy.ACL{ + policy: policyv2.Policy{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"fd7a:115c:a1e0::1"}, - Destinations: []string{"fd7a:115c:a1e0::2:*"}, + Action: "accept", + Sources: []policyv2.Alias{prefixp("fd7a:115c:a1e0::1/128")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("fd7a:115c:a1e0::2/128"), tailcfg.PortRangeAny), + }, }, }, }, }, "hostv4cidr": { - policy: policy.ACLPolicy{ - Hosts: policy.Hosts{ - "test1": netip.MustParsePrefix("100.64.0.1/32"), - "test2": netip.MustParsePrefix("100.64.0.2/32"), + policy: policyv2.Policy{ + Hosts: policyv2.Hosts{ + "test1": policyv2.Prefix(netip.MustParsePrefix("100.64.0.1/32")), + "test2": policyv2.Prefix(netip.MustParsePrefix("100.64.0.2/32")), }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"test1"}, - Destinations: []string{"test2:*"}, + Action: "accept", + Sources: []policyv2.Alias{hostp("test1")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test2"), tailcfg.PortRangeAny), + }, }, }, }, }, "hostv6cidr": { - policy: policy.ACLPolicy{ - Hosts: policy.Hosts{ - "test1": netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), - "test2": netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + policy: policyv2.Policy{ + Hosts: policyv2.Hosts{ + "test1": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::1/128")), + "test2": policyv2.Prefix(netip.MustParsePrefix("fd7a:115c:a1e0::2/128")), }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"test1"}, - Destinations: []string{"test2:*"}, + Action: "accept", + Sources: []policyv2.Alias{hostp("test1")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(hostp("test2"), tailcfg.PortRangeAny), + }, }, }, }, }, "group": { - policy: policy.ACLPolicy{ - Groups: map[string][]string{ - "group:one": {"user1"}, - "group:two": {"user2"}, + policy: policyv2.Policy{ + Groups: policyv2.Groups{ + 
policyv2.Group("group:one"): []policyv2.Username{policyv2.Username("user1@")}, + policyv2.Group("group:two"): []policyv2.Username{policyv2.Username("user2@")}, }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"group:one"}, - Destinations: []string{"group:two:*"}, + Action: "accept", + Sources: []policyv2.Alias{groupp("group:one")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(groupp("group:two"), tailcfg.PortRangeAny), + }, }, }, }, @@ -939,15 +1046,17 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { for name, testCase := range tests { t.Run(name, func(t *testing.T) { scenario := aclScenario(t, &testCase.policy, 1) + defer scenario.ShutdownAssertNoPanics(t) test1ip := netip.MustParseAddr("100.64.0.1") test1ip6 := netip.MustParseAddr("fd7a:115c:a1e0::1") test1, err := scenario.FindTailscaleClientByIP(test1ip) assert.NotNil(t, test1) - assertNoErr(t, err) + require.NoError(t, err) test1fqdn, err := test1.FQDN() - assertNoErr(t, err) + require.NoError(t, err) + test1ipURL := fmt.Sprintf("http://%s/etc/hostname", test1ip.String()) test1ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test1ip6.String()) test1fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test1fqdn) @@ -956,59 +1065,2767 @@ func TestACLDevice1CanAccessDevice2(t *testing.T) { test2ip6 := netip.MustParseAddr("fd7a:115c:a1e0::2") test2, err := scenario.FindTailscaleClientByIP(test2ip) assert.NotNil(t, test2) - assertNoErr(t, err) + require.NoError(t, err) test2fqdn, err := test2.FQDN() - assertNoErr(t, err) + require.NoError(t, err) + test2ipURL := fmt.Sprintf("http://%s/etc/hostname", test2ip.String()) test2ip6URL := fmt.Sprintf("http://[%s]/etc/hostname", test2ip6.String()) test2fqdnURL := fmt.Sprintf("http://%s/etc/hostname", test2fqdn) // test1 can query test2 - result, err := test1.Curl(test2ipURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", - test2ipURL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2ipURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", + test2ipURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv4") - result, err = test1.Curl(test2ip6URL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", - test2ip6URL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2ip6URL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", + test2ip6URL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via IPv6") - result, err = test1.Curl(test2fqdnURL) - assert.Lenf( - t, - result, - 13, - "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", - test2fqdnURL, - result, - ) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test1.Curl(test2fqdnURL) + assert.NoError(c, err) + assert.Lenf( + c, + result, + 13, + "failed to connect from test1 to test with URL %s, expected hostname of 13 chars, got %s", + test2fqdnURL, + result, + ) + }, 10*time.Second, 200*time.Millisecond, "test1 should reach test2 via 
FQDN") - result, err = test2.Curl(test1ipURL) - assert.Empty(t, result) - assert.Error(t, err) + // test2 cannot query test1 (negative test case) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test1ipURL) + assert.Error(c, err) + assert.Empty(c, result) + }, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv4") - result, err = test2.Curl(test1ip6URL) - assert.Empty(t, result) - assert.Error(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test1ip6URL) + assert.Error(c, err) + assert.Empty(c, result) + }, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via IPv6") - result, err = test2.Curl(test1fqdnURL) - assert.Empty(t, result) - assert.Error(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := test2.Curl(test1fqdnURL) + assert.Error(c, err) + assert.Empty(c, result) + }, 10*time.Second, 200*time.Millisecond, "test2 should NOT reach test1 via FQDN") }) } } + +func TestPolicyUpdateWhileRunningWithCLIInDatabase(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + // Alpine containers dont have ip6tables set up, which causes + // tailscaled to stop configuring the wgengine, causing it + // to not configure DNS. + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithTestName("policyreload"), + hsic.WithPolicyMode(types.PolicyModeDB), + ) + require.NoError(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + require.NoError(t, err) + + all := append(user1Clients, user2Clients...) + + // Initially all nodes can reach each other + for _, client := range all { + for _, peer := range all { + if client.ContainerID() == peer.ContainerID() { + continue + } + + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s to %s", client.Hostname(), url) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 20*time.Second, 500*time.Millisecond, "Verifying user1 can reach user2") + } + } + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + p := policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + Hosts: policyv2.Hosts{}, + } + + err = headscale.SetPolicy(&p) + require.NoError(t, err) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Get the current policy and check + // if it is the same as the one we set. 
+ var output *policyv2.Policy + + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "policy", + "get", + "--output", + "json", + }, + &output, + ) + assert.NoError(ct, err) + + assert.Len(t, output.ACLs, 1) + + if diff := cmp.Diff(p, *output, cmpopts.IgnoreUnexported(policyv2.Policy{}), cmpopts.EquateEmpty()); diff != "" { + ct.Errorf("unexpected policy(-want +got):\n%s", diff) + } + }, 30*time.Second, 1*time.Second, "verifying that the new policy took place") + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Test that user1 can visit all user2 + for _, client := range user1Clients { + for _, peer := range user2Clients { + fqdn, err := peer.FQDN() + assert.NoError(ct, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s to %s", client.Hostname(), url) + + result, err := client.Curl(url) + assert.Len(ct, result, 13) + assert.NoError(ct, err) + } + } + + // Test that user2 _cannot_ visit user1 + for _, client := range user2Clients { + for _, peer := range user1Clients { + fqdn, err := peer.FQDN() + assert.NoError(ct, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s to %s", client.Hostname(), url) + + result, err := client.Curl(url) + assert.Empty(ct, result) + assert.Error(ct, err) + } + } + }, 30*time.Second, 1*time.Second, "new policy did not get propagated to nodes") +} + +func TestACLAutogroupMember(t *testing.T) { + IntegrationSkip(t) + + scenario := aclScenario(t, + &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(ptr.To(policyv2.AutoGroupMember), tailcfg.PortRangeAny), + }, + }, + }, + }, + 2, + ) + defer scenario.ShutdownAssertNoPanics(t) + + allClients, err := scenario.ListTailscaleClients() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + // Test that untagged nodes can access each other + for _, client := range allClients { + var clientIsUntagged bool + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + clientIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0 + assert.True(c, clientIsUntagged, "Expected client %s to be untagged for autogroup:member test", client.Hostname()) + }, 10*time.Second, 200*time.Millisecond, "Waiting for client %s to be untagged", client.Hostname()) + + if !clientIsUntagged { + continue + } + + for _, peer := range allClients { + if client.Hostname() == peer.Hostname() { + continue + } + + var peerIsUntagged bool + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := peer.Status() + assert.NoError(c, err) + + peerIsUntagged = status.Self.Tags == nil || status.Self.Tags.Len() == 0 + assert.True(c, peerIsUntagged, "Expected peer %s to be untagged for autogroup:member test", peer.Hostname()) + }, 10*time.Second, 200*time.Millisecond, "Waiting for peer %s to be untagged", peer.Hostname()) + + if !peerIsUntagged { + continue + } + + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s to %s", client.Hostname(), url) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 20*time.Second, 500*time.Millisecond, "Verifying autogroup:member connectivity") + } + } +} + +func TestACLAutogroupTagged(t *testing.T) { + 
IntegrationSkip(t) + + // Create a custom scenario for testing autogroup:tagged + spec := ScenarioSpec{ + NodesPerUser: 2, // 2 nodes per user - one tagged, one untagged + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + policy := &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:test": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupTagged)}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(ptr.To(policyv2.AutoGroupTagged), tailcfg.PortRangeAny), + }, + }, + }, + } + + // Create only the headscale server (not the full environment with users/nodes) + headscale, err := scenario.Headscale( + hsic.WithACLPolicy(policy), + hsic.WithTestName("acl-autogroup-tagged"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + require.NoError(t, err) + + // Create users and nodes manually with specific tags + // Tags are now set via PreAuthKey (tags-as-identity model), not via --advertise-tags + for _, userStr := range spec.Users { + user, err := scenario.CreateUser(userStr) + require.NoError(t, err) + + // Create two pre-auth keys per user: one tagged, one untagged + taggedAuthKey, err := scenario.CreatePreAuthKeyWithTags(user.GetId(), true, false, []string{"tag:test"}) + require.NoError(t, err) + + untaggedAuthKey, err := scenario.CreatePreAuthKey(user.GetId(), true, false) + require.NoError(t, err) + + // Create nodes with proper naming + for i := range spec.NodesPerUser { + var ( + authKey string + version string + ) + + if i == 0 { + // First node is tagged - use tagged PreAuthKey + authKey = taggedAuthKey.GetKey() + version = "head" + + t.Logf("Creating tagged node for %s", userStr) + } else { + // Second node is untagged - use untagged PreAuthKey + authKey = untaggedAuthKey.GetKey() + version = "unstable" + + t.Logf("Creating untagged node for %s", userStr) + } + + // Get the network for this scenario + networks := scenario.Networks() + + var network *dockertest.Network + if len(networks) > 0 { + network = networks[0] + } + + // Create the tailscale node with appropriate options + opts := []tsic.Option{ + tsic.WithCACert(headscale.GetCert()), + tsic.WithHeadscaleName(headscale.GetHostname()), + tsic.WithNetwork(network), + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + } + + tsClient, err := tsic.New( + scenario.Pool(), + version, + opts..., + ) + require.NoError(t, err) + + err = tsClient.WaitForNeedsLogin(integrationutil.PeerSyncTimeout()) + require.NoError(t, err) + + // Login with the appropriate auth key (tags come from the PreAuthKey) + err = tsClient.Login(headscale.GetEndpoint(), authKey) + require.NoError(t, err) + + err = tsClient.WaitForRunning(integrationutil.PeerSyncTimeout()) + require.NoError(t, err) + + // Add client to user + userObj := scenario.GetOrCreateUser(userStr) + userObj.Clients[tsClient.Hostname()] = tsClient + } + } + + allClients, err := scenario.ListTailscaleClients() + require.NoError(t, err) + require.Len(t, allClients, 4) // 2 users * 2 nodes each + + // Wait for nodes to see only their allowed peers + // Tagged nodes should see each other (2 tagged nodes total) + // Untagged nodes should see no one + var ( + taggedClients []TailscaleClient + untaggedClients []TailscaleClient + ) + + // First, categorize nodes by checking their tags + + for _, 
client := range allClients { + hostname := client.Hostname() + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + + if status.Self.Tags != nil && status.Self.Tags.Len() > 0 { + // This is a tagged node + assert.Len(ct, status.Peers(), 1, "tagged node %s should see exactly 1 peer", hostname) + + // Add to tagged list only once we've verified it + found := false + + for _, tc := range taggedClients { + if tc.Hostname() == hostname { + found = true + break + } + } + + if !found { + taggedClients = append(taggedClients, client) + } + } else { + // This is an untagged node + assert.Empty(ct, status.Peers(), "untagged node %s should see 0 peers", hostname) + + // Add to untagged list only once we've verified it + found := false + + for _, uc := range untaggedClients { + if uc.Hostname() == hostname { + found = true + break + } + } + + if !found { + untaggedClients = append(untaggedClients, client) + } + } + }, 30*time.Second, 1*time.Second, "verifying peer visibility for node %s", hostname) + } + + // Verify we have the expected number of tagged and untagged nodes + require.Len(t, taggedClients, 2, "should have exactly 2 tagged nodes") + require.Len(t, untaggedClients, 2, "should have exactly 2 untagged nodes") + + // Explicitly verify tags on tagged nodes + for _, client := range taggedClients { + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + assert.NotNil(c, status.Self.Tags, "tagged node %s should have tags", client.Hostname()) + assert.Positive(c, status.Self.Tags.Len(), "tagged node %s should have at least one tag", client.Hostname()) + }, 10*time.Second, 200*time.Millisecond, "Waiting for tags to be applied to tagged nodes") + } + + // Verify untagged nodes have no tags + for _, client := range untaggedClients { + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + if status.Self.Tags != nil { + assert.Equal(c, 0, status.Self.Tags.Len(), "untagged node %s should have no tags", client.Hostname()) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting to verify untagged nodes have no tags") + } + + // Test that tagged nodes can communicate with each other + for _, client := range taggedClients { + for _, peer := range taggedClients { + if client.Hostname() == peer.Hostname() { + continue + } + + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + + t.Logf("Testing connection from tagged node %s to tagged node %s", client.Hostname(), peer.Hostname()) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(ct, err) + assert.Len(ct, result, 13) + }, 20*time.Second, 500*time.Millisecond, "tagged nodes should be able to communicate") + } + } + + // Test that untagged nodes cannot communicate with anyone + for _, client := range untaggedClients { + // Try to reach tagged nodes (should fail) + for _, peer := range taggedClients { + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + + t.Logf("Testing connection from untagged node %s to tagged node %s (should fail)", client.Hostname(), peer.Hostname()) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + result, err := client.CurlFailFast(url) + assert.Empty(ct, result) + assert.Error(ct, err) + }, 5*time.Second, 200*time.Millisecond, "untagged nodes should not be able to reach tagged 
nodes") + } + + // Try to reach other untagged nodes (should also fail) + for _, peer := range untaggedClients { + if client.Hostname() == peer.Hostname() { + continue + } + + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + + t.Logf("Testing connection from untagged node %s to untagged node %s (should fail)", client.Hostname(), peer.Hostname()) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + result, err := client.CurlFailFast(url) + assert.Empty(ct, result) + assert.Error(ct, err) + }, 5*time.Second, 200*time.Millisecond, "untagged nodes should not be able to reach other untagged nodes") + } + } + + // Test that tagged nodes cannot reach untagged nodes + for _, client := range taggedClients { + for _, peer := range untaggedClients { + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + + t.Logf("Testing connection from tagged node %s to untagged node %s (should fail)", client.Hostname(), peer.Hostname()) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + result, err := client.CurlFailFast(url) + assert.Empty(ct, result) + assert.Error(ct, err) + }, 5*time.Second, 200*time.Millisecond, "tagged nodes should not be able to reach untagged nodes") + } + } +} + +// Test that only devices owned by the same user can access each other and cannot access devices of other users +// Test structure: +// - user1: 2 regular nodes (tests autogroup:self for same-user access) +// - user2: 2 regular nodes (tests autogroup:self for same-user access and cross-user isolation) +// - user-router: 1 node with tag:router-node (tests that autogroup:self doesn't interfere with other rules). +func TestACLAutogroupSelf(t *testing.T) { + IntegrationSkip(t) + + // Policy with TWO separate ACL rules: + // 1. autogroup:member -> autogroup:self (same-user access) + // 2. 
group:home -> tag:router-node (router access) + // This tests that autogroup:self doesn't prevent other rules from working + policy := &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:home"): []policyv2.Username{ + policyv2.Username("user1@"), + policyv2.Username("user2@"), + }, + }, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:router-node"): policyv2.Owners{ + usernameOwner("user-router@"), + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:home")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:router-node"), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{tagp("tag:router-node")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(groupp("group:home"), tailcfg.PortRangeAny), + }, + }, + }, + } + + // Create custom scenario: user1 and user2 with regular nodes, plus user-router with tagged node + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(policy), + hsic.WithTestName("acl-autogroup-self"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + require.NoError(t, err) + + // Add router node for user-router (single shared router node) + networks := scenario.Networks() + + var network *dockertest.Network + if len(networks) > 0 { + network = networks[0] + } + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + routerUser, err := scenario.CreateUser("user-router") + require.NoError(t, err) + + // Create a tagged PreAuthKey for the router node (tags-as-identity model) + authKey, err := scenario.CreatePreAuthKeyWithTags(routerUser.GetId(), true, false, []string{"tag:router-node"}) + require.NoError(t, err) + + // Create router node (tags come from the PreAuthKey) + routerClient, err := tsic.New( + scenario.Pool(), + "unstable", + tsic.WithCACert(headscale.GetCert()), + tsic.WithHeadscaleName(headscale.GetHostname()), + tsic.WithNetwork(network), + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + ) + require.NoError(t, err) + + err = routerClient.WaitForNeedsLogin(integrationutil.PeerSyncTimeout()) + require.NoError(t, err) + + err = routerClient.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + err = routerClient.WaitForRunning(integrationutil.PeerSyncTimeout()) + require.NoError(t, err) + + userRouterObj := scenario.GetOrCreateUser("user-router") + userRouterObj.Clients[routerClient.Hostname()] = routerClient + + user1Clients, err := scenario.GetClients("user1") + require.NoError(t, err) + user2Clients, err := scenario.GetClients("user2") + require.NoError(t, err) + + var user1Regular, user2Regular []TailscaleClient + + for _, client := range user1Clients { + status, err := client.Status() + require.NoError(t, err) + + if status.Self != nil && (status.Self.Tags == nil || status.Self.Tags.Len() == 0) { + user1Regular = append(user1Regular, client) + } + } + + for _, client := range 
user2Clients { + status, err := client.Status() + require.NoError(t, err) + + if status.Self != nil && (status.Self.Tags == nil || status.Self.Tags.Len() == 0) { + user2Regular = append(user2Regular, client) + } + } + + require.NotEmpty(t, user1Regular, "user1 should have regular (untagged) devices") + require.NotEmpty(t, user2Regular, "user2 should have regular (untagged) devices") + require.NotNil(t, routerClient, "router node should exist") + + // Wait for all nodes to sync with their expected peer counts + // With our ACL policy: + // - Regular nodes (user1/user2): 1 same-user regular peer + 1 router-node = 2 peers + // - Router node: 2 user1 regular + 2 user2 regular = 4 peers + for _, client := range user1Regular { + err := client.WaitForPeers(2, integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval()) + require.NoError(t, err, "user1 regular device %s should see 2 peers (1 same-user peer + 1 router)", client.Hostname()) + } + + for _, client := range user2Regular { + err := client.WaitForPeers(2, integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval()) + require.NoError(t, err, "user2 regular device %s should see 2 peers (1 same-user peer + 1 router)", client.Hostname()) + } + + err = routerClient.WaitForPeers(4, integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval()) + require.NoError(t, err, "router should see 4 peers (all group:home regular nodes)") + + // Test that user1's regular devices can access each other + for _, client := range user1Regular { + for _, peer := range user1Regular { + if client.Hostname() == peer.Hostname() { + continue + } + + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s (user1) to %s (user1)", client.Hostname(), fqdn) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 10*time.Second, 200*time.Millisecond, "user1 device should reach other user1 device via autogroup:self") + } + } + + // Test that user2's regular devices can access each other + for _, client := range user2Regular { + for _, peer := range user2Regular { + if client.Hostname() == peer.Hostname() { + continue + } + + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s (user2) to %s (user2)", client.Hostname(), fqdn) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 10*time.Second, 200*time.Millisecond, "user2 device should reach other user2 device via autogroup:self") + } + } + + // Test that user1's regular devices can access router-node + for _, client := range user1Regular { + fqdn, err := routerClient.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s (user1) to %s (router-node) - should SUCCEED", client.Hostname(), fqdn) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.NotEmpty(c, result, "user1 should be able to access router-node via group:home -> tag:router-node rule") + }, 10*time.Second, 200*time.Millisecond, "user1 device should reach router-node (proves autogroup:self doesn't interfere)") + } + + // Test that user2's regular devices can access router-node + for _, client := range user2Regular { + fqdn, err := routerClient.FQDN() + 
require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s (user2) to %s (router-node) - should SUCCEED", client.Hostname(), fqdn) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.NotEmpty(c, result, "user2 should be able to access router-node via group:home -> tag:router-node rule") + }, 10*time.Second, 200*time.Millisecond, "user2 device should reach router-node (proves autogroup:self doesn't interfere)") + } + + // Test that devices from different users cannot access each other's regular devices + for _, client := range user1Regular { + for _, peer := range user2Regular { + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s (user1) to %s (user2 regular) - should FAIL", client.Hostname(), fqdn) + + result, err := client.Curl(url) + assert.Empty(t, result, "user1 should not be able to access user2's regular devices (autogroup:self isolation)") + assert.Error(t, err, "connection from user1 to user2 regular device should fail") + } + } + + for _, client := range user2Regular { + for _, peer := range user1Regular { + fqdn, err := peer.FQDN() + require.NoError(t, err) + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + t.Logf("url from %s (user2) to %s (user1 regular) - should FAIL", client.Hostname(), fqdn) + + result, err := client.Curl(url) + assert.Empty(t, result, "user2 should not be able to access user1's regular devices (autogroup:self isolation)") + assert.Error(t, err, "connection from user2 to user1 regular device should fail") + } + } +} + +func TestACLPolicyPropagationOverTime(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + // Install iptables to enable packet filtering for ACL tests. + // Packet filters are essential for testing autogroup:self and other ACL policies. + tsic.WithPackages("curl", "iptables", "ip6tables"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithTestName("aclpropagation"), + hsic.WithPolicyMode(types.PolicyModeDB), + ) + require.NoError(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + require.NoError(t, err) + + allClients := append(user1Clients, user2Clients...) 
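+
+	// The iterations below cycle through an allow-all baseline, an
+	// autogroup:self policy (same-user access only), and a directional
+	// user1 -> user2 rule, re-verifying connectivity after each policy
+	// change. Phase 2b additionally adds and then removes a user1 node
+	// while autogroup:self is still active.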
+ + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Define the four policies we'll cycle through + allowAllPolicy := &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + } + + autogroupSelfPolicy := &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{ptr.To(policyv2.AutoGroupMember)}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(ptr.To(policyv2.AutoGroupSelf), tailcfg.PortRangeAny), + }, + }, + }, + } + + user1ToUser2Policy := &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + } + + // Run through the policy cycle 5 times + for i := range 5 { + iteration := i + 1 // range 5 gives 0-4, we want 1-5 for logging + t.Logf("=== Iteration %d/5 ===", iteration) + + // Phase 1: Allow all policy + t.Logf("Iteration %d: Setting allow-all policy", iteration) + + err = headscale.SetPolicy(allowAllPolicy) + require.NoError(t, err) + + // Wait for peer lists to sync with allow-all policy + t.Logf("Iteration %d: Phase 1 - Waiting for peer lists to sync with allow-all policy", iteration) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err, "iteration %d: Phase 1 - failed to sync after allow-all policy", iteration) + + // Test all-to-all connectivity after state is settled + t.Logf("Iteration %d: Phase 1 - Testing all-to-all connectivity", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + for _, peer := range allClients { + if client.ContainerID() == peer.ContainerID() { + continue + } + + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: %s should reach %s with allow-all policy", iteration, client.Hostname(), fqdn) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn) + } + } + }, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 1 - all connectivity tests with allow-all policy", iteration) + + // Phase 2: Autogroup:self policy (only same user can access) + t.Logf("Iteration %d: Phase 2 - Setting autogroup:self policy", iteration) + + err = headscale.SetPolicy(autogroupSelfPolicy) + require.NoError(t, err) + + // Wait for peer lists to sync with autogroup:self - ensures cross-user peers are removed + t.Logf("Iteration %d: Phase 2 - Waiting for peer lists to sync with autogroup:self", iteration) + + err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond) + require.NoError(t, err, "iteration %d: Phase 2 - failed to sync after autogroup:self policy", iteration) + + // Test ALL connectivity (positive and negative) in one block after state is settled + t.Logf("Iteration %d: Phase 2 - Testing all connectivity with autogroup:self", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Positive: user1 can access user1's nodes + for _, client := range user1Clients { + for _, peer := range user1Clients { + if client.ContainerID() == peer.ContainerID() { + continue + } + + fqdn, err := 
peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname()) + } + } + + // Positive: user2 can access user2's nodes + for _, client := range user2Clients { + for _, peer := range user2Clients { + if client.ContainerID() == peer.ContainerID() { + continue + } + + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: user2 %s should reach user2's node %s", iteration, client.Hostname(), fqdn) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), fqdn) + } + } + + // Negative: user1 cannot access user2's nodes + for _, client := range user1Clients { + for _, peer := range user2Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.Error(ct, err, "iteration %d: user1 %s should NOT reach user2's node %s with autogroup:self", iteration, client.Hostname(), fqdn) + assert.Empty(ct, result, "iteration %d: user1 %s->user2 %s should fail", iteration, client.Hostname(), fqdn) + } + } + + // Negative: user2 cannot access user1's nodes + for _, client := range user2Clients { + for _, peer := range user1Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Empty(ct, result, "iteration %d: user2->user1 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname()) + } + } + }, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2 - all connectivity tests with autogroup:self", iteration) + + // Phase 2b: Add a new node to user1 and validate policy propagation + t.Logf("Iteration %d: Phase 2b - Adding new node to user1 during autogroup:self policy", iteration) + + // Add a new node with the same options as the initial setup + // Get the network to use (scenario uses first network in list) + networks := scenario.Networks() + require.NotEmpty(t, networks, "scenario should have at least one network") + + newClient := scenario.MustAddAndLoginClient(t, "user1", "all", headscale, + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + tsic.WithNetwork(networks[0]), + ) + t.Logf("Iteration %d: Phase 2b - Added and logged in new node %s", iteration, newClient.Hostname()) + + // Wait for peer lists to sync after new node addition (now 3 user1 nodes, still autogroup:self) + t.Logf("Iteration %d: Phase 2b - Waiting for peer lists to sync after new node addition", iteration) + + 
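+		// WaitForTailscaleSyncPerUser is used here (rather than the global
+		// WaitForTailscaleSync) because autogroup:self is still the active
+		// policy, so each node only expects to see peers owned by its own user.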
err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond) + require.NoError(t, err, "iteration %d: Phase 2b - failed to sync after new node addition", iteration) + + // Test ALL connectivity (positive and negative) in one block after state is settled + t.Logf("Iteration %d: Phase 2b - Testing all connectivity after new node addition", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Re-fetch client list to ensure latest state + user1ClientsWithNew, err := scenario.ListTailscaleClients("user1") + assert.NoError(ct, err, "iteration %d: failed to list user1 clients", iteration) + assert.Len(ct, user1ClientsWithNew, 3, "iteration %d: user1 should have 3 nodes", iteration) + + // Positive: all user1 nodes can access each other + for _, client := range user1ClientsWithNew { + for _, peer := range user1ClientsWithNew { + if client.ContainerID() == peer.ContainerID() { + continue + } + + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: user1 node %s should reach user1 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname()) + } + } + + // Negative: user1 nodes cannot access user2's nodes + for _, client := range user1ClientsWithNew { + for _, peer := range user2Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.Error(ct, err, "iteration %d: user1 node %s should NOT reach user2 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Empty(ct, result, "iteration %d: user1->user2 connection from %s to %s should fail", iteration, client.Hostname(), peer.Hostname()) + } + } + }, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - all connectivity tests after new node addition", iteration) + + // Delete the newly added node before Phase 3 + t.Logf("Iteration %d: Phase 2b - Deleting the newly added node from user1", iteration) + + // Get the node list and find the newest node (highest ID) + var ( + nodeList []*v1.Node + nodeToDeleteID uint64 + ) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodeList, err = headscale.ListNodes("user1") + assert.NoError(ct, err) + assert.Len(ct, nodeList, 3, "should have 3 user1 nodes before deletion") + + // Find the node with the highest ID (the newest one) + for _, node := range nodeList { + if node.GetId() > nodeToDeleteID { + nodeToDeleteID = node.GetId() + } + } + }, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - listing nodes before deletion", iteration) + + // Delete the node via headscale helper + t.Logf("Iteration %d: Phase 2b - Deleting node ID %d from headscale", iteration, nodeToDeleteID) + err = headscale.DeleteNode(nodeToDeleteID) + require.NoError(t, err, "iteration %d: failed to delete node %d", iteration, nodeToDeleteID) + + // Remove the deleted client from the scenario's user.Clients map + // This is necessary for WaitForTailscaleSyncPerUser to calculate correct peer counts + t.Logf("Iteration %d: Phase 2b - Removing deleted client from scenario", iteration) + + for clientName, client := range 
scenario.users["user1"].Clients { + status := client.MustStatus() + + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + if err != nil { + continue + } + + if nodeID == nodeToDeleteID { + delete(scenario.users["user1"].Clients, clientName) + t.Logf("Iteration %d: Phase 2b - Removed client %s (node ID %d) from scenario", iteration, clientName, nodeToDeleteID) + + break + } + } + + // Verify the node has been deleted + t.Logf("Iteration %d: Phase 2b - Verifying node deletion (expecting 2 user1 nodes)", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodeListAfter, err := headscale.ListNodes("user1") + assert.NoError(ct, err, "failed to list nodes after deletion") + assert.Len(ct, nodeListAfter, 2, "iteration %d: should have 2 user1 nodes after deletion, got %d", iteration, len(nodeListAfter)) + }, 10*time.Second, 500*time.Millisecond, "iteration %d: Phase 2b - node should be deleted", iteration) + + // Wait for sync after deletion to ensure peer counts are correct + // Use WaitForTailscaleSyncPerUser because autogroup:self is still active, + // so nodes only see same-user peers, not all nodes + t.Logf("Iteration %d: Phase 2b - Waiting for sync after node deletion (with autogroup:self)", iteration) + + err = scenario.WaitForTailscaleSyncPerUser(60*time.Second, 500*time.Millisecond) + require.NoError(t, err, "iteration %d: failed to sync after node deletion", iteration) + + // Refresh client lists after deletion to ensure we don't reference the deleted node + user1Clients, err = scenario.ListTailscaleClients("user1") + require.NoError(t, err, "iteration %d: failed to refresh user1 client list after deletion", iteration) + user2Clients, err = scenario.ListTailscaleClients("user2") + require.NoError(t, err, "iteration %d: failed to refresh user2 client list after deletion", iteration) + // Create NEW slice instead of appending to old allClients which still has deleted client + allClients = make([]TailscaleClient, 0, len(user1Clients)+len(user2Clients)) + allClients = append(allClients, user1Clients...) + allClients = append(allClients, user2Clients...) + + t.Logf("Iteration %d: Phase 2b completed - New node added, validated, and removed successfully", iteration) + + // Phase 3: User1 can access user2 but not reverse + t.Logf("Iteration %d: Phase 3 - Setting user1->user2 directional policy", iteration) + + err = headscale.SetPolicy(user1ToUser2Policy) + require.NoError(t, err) + + // Note: Cannot use WaitForTailscaleSync() here because directional policy means + // user2 nodes don't see user1 nodes in their peer list (asymmetric visibility). + // The EventuallyWithT block below will handle waiting for policy propagation. 
+ + // Test ALL connectivity (positive and negative) in one block after policy settles + t.Logf("Iteration %d: Phase 3 - Testing all connectivity with directional policy", iteration) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Positive: user1 can access user2's nodes + for _, client := range user1Clients { + for _, peer := range user2Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user2 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.NoError(ct, err, "iteration %d: user1 node %s should reach user2 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Len(ct, result, 13, "iteration %d: response from %s to %s should be valid", iteration, client.Hostname(), peer.Hostname()) + } + } + + // Negative: user2 cannot access user1's nodes + for _, client := range user2Clients { + for _, peer := range user1Clients { + fqdn, err := peer.FQDN() + if !assert.NoError(ct, err, "iteration %d: failed to get FQDN for user1 peer %s", iteration, peer.Hostname()) { + continue + } + + url := fmt.Sprintf("http://%s/etc/hostname", fqdn) + result, err := client.Curl(url) + assert.Error(ct, err, "iteration %d: user2 node %s should NOT reach user1 node %s", iteration, client.Hostname(), peer.Hostname()) + assert.Empty(ct, result, "iteration %d: user2->user1 from %s to %s should fail", iteration, client.Hostname(), peer.Hostname()) + } + } + }, 90*time.Second, 500*time.Millisecond, "iteration %d: Phase 3 - all connectivity tests with directional policy", iteration) + + t.Logf("=== Iteration %d/5 completed successfully - All 3 phases passed ===", iteration) + } + + t.Log("All 5 iterations completed successfully - ACL propagation is working correctly") +} + +// TestACLTagPropagation validates that tag changes propagate immediately +// to ACLs without requiring a Headscale restart. +// This is the primary test for GitHub issue #2389. +func TestACLTagPropagation(t *testing.T) { + IntegrationSkip(t) + + tests := []struct { + name string + policy *policyv2.Policy + spec ScenarioSpec + // setup returns clients and any initial state needed + setup func(t *testing.T, scenario *Scenario, headscale ControlServer) ( + sourceClient TailscaleClient, + targetClient TailscaleClient, + targetNodeID uint64, + ) + // initialAccess: should source be able to reach target before tag change? + initialAccess bool + // tagChange: what tags to set on target node (nil = test uses custom logic) + tagChange []string + // finalAccess: should source be able to reach target after tag change? 
+ finalAccess bool + }{ + { + name: "add-tag-grants-access", + policy: &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:shared": policyv2.Owners{usernameOwner("user1@")}, + }, + ACLs: []policyv2.ACL{ + // user1 self-access + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user1@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user1@"), tailcfg.PortRangeAny), + }, + }, + // user2 self-access + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + // user2 can access tag:shared + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:shared"), tailcfg.PortRangeAny), + }, + }, + // tag:shared can respond to user2 (return path) + { + Action: "accept", + Sources: []policyv2.Alias{tagp("tag:shared")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + }, + spec: ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + }, + setup: func(t *testing.T, scenario *Scenario, headscale ControlServer) (TailscaleClient, TailscaleClient, uint64) { + t.Helper() + + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + user2Clients, err := scenario.ListTailscaleClients("user2") + require.NoError(t, err) + + nodes, err := headscale.ListNodes("user1") + require.NoError(t, err) + + return user2Clients[0], user1Clients[0], nodes[0].GetId() + }, + initialAccess: false, // user2 cannot access user1 (no tag) + tagChange: []string{"tag:shared"}, // add tag:shared + finalAccess: true, // user2 can now access user1 + }, + { + name: "remove-tag-revokes-access", + policy: &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:shared": policyv2.Owners{usernameOwner("user1@")}, + "tag:other": policyv2.Owners{usernameOwner("user1@")}, + }, + ACLs: []policyv2.ACL{ + // user2 self-access + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + // user2 can access tag:shared only + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:shared"), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{tagp("tag:shared")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + }, + spec: ScenarioSpec{ + NodesPerUser: 0, // manual creation for tagged node + Users: []string{"user1", "user2"}, + }, + setup: func(t *testing.T, scenario *Scenario, headscale ControlServer) (TailscaleClient, TailscaleClient, uint64) { + t.Helper() + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + // Create user1's node WITH tag:shared via PreAuthKey + taggedKey, err := scenario.CreatePreAuthKeyWithTags( + userMap["user1"].GetId(), false, false, []string{"tag:shared"}, + ) + require.NoError(t, err) + + user1Node, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + 
tsic.WithDockerWorkdir("/"), + tsic.WithNetfilter("off"), + ) + require.NoError(t, err) + err = user1Node.Login(headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Create user2's node (untagged) + untaggedKey, err := scenario.CreatePreAuthKey(userMap["user2"].GetId(), false, false) + require.NoError(t, err) + + user2Node, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + tsic.WithNetfilter("off"), + ) + require.NoError(t, err) + err = user2Node.Login(headscale.GetEndpoint(), untaggedKey.GetKey()) + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + nodes, err := headscale.ListNodes("user1") + require.NoError(t, err) + + return user2Node, user1Node, nodes[0].GetId() + }, + initialAccess: true, // user2 can access user1 (has tag:shared) + tagChange: []string{"tag:other"}, // replace with tag:other + finalAccess: false, // user2 cannot access (no ACL for tag:other) + }, + { + name: "change-tag-changes-access", + policy: &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:team-a": policyv2.Owners{usernameOwner("user1@")}, + "tag:team-b": policyv2.Owners{usernameOwner("user1@")}, + }, + ACLs: []policyv2.ACL{ + // user2 self-access + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + // user2 can access tag:team-b only (NOT tag:team-a) + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:team-b"), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{tagp("tag:team-b")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + }, + spec: ScenarioSpec{ + NodesPerUser: 0, + Users: []string{"user1", "user2"}, + }, + setup: func(t *testing.T, scenario *Scenario, headscale ControlServer) (TailscaleClient, TailscaleClient, uint64) { + t.Helper() + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + // Create user1's node with tag:team-a (user2 has NO ACL for this) + taggedKey, err := scenario.CreatePreAuthKeyWithTags( + userMap["user1"].GetId(), false, false, []string{"tag:team-a"}, + ) + require.NoError(t, err) + + user1Node, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + tsic.WithNetfilter("off"), + ) + require.NoError(t, err) + err = user1Node.Login(headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Create user2's node + untaggedKey, err := scenario.CreatePreAuthKey(userMap["user2"].GetId(), false, false) + require.NoError(t, err) + + user2Node, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m 
http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + tsic.WithNetfilter("off"), + ) + require.NoError(t, err) + err = user2Node.Login(headscale.GetEndpoint(), untaggedKey.GetKey()) + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + nodes, err := headscale.ListNodes("user1") + require.NoError(t, err) + + return user2Node, user1Node, nodes[0].GetId() + }, + initialAccess: false, // user2 cannot access (tag:team-a not in ACL) + tagChange: []string{"tag:team-b"}, // change to tag:team-b + finalAccess: true, // user2 can now access (tag:team-b in ACL) + }, + { + name: "multiple-tags-partial-removal", + policy: &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:web": policyv2.Owners{usernameOwner("user1@")}, + "tag:internal": policyv2.Owners{usernameOwner("user1@")}, + }, + ACLs: []policyv2.ACL{ + // user2 self-access + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + // user2 can access tag:web + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:web"), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{tagp("tag:web")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + }, + spec: ScenarioSpec{ + NodesPerUser: 0, + Users: []string{"user1", "user2"}, + }, + setup: func(t *testing.T, scenario *Scenario, headscale ControlServer) (TailscaleClient, TailscaleClient, uint64) { + t.Helper() + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + // Create user1's node with BOTH tags + taggedKey, err := scenario.CreatePreAuthKeyWithTags( + userMap["user1"].GetId(), false, false, []string{"tag:web", "tag:internal"}, + ) + require.NoError(t, err) + + user1Node, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + tsic.WithNetfilter("off"), + ) + require.NoError(t, err) + err = user1Node.Login(headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Create user2's node + untaggedKey, err := scenario.CreatePreAuthKey(userMap["user2"].GetId(), false, false) + require.NoError(t, err) + + user2Node, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + tsic.WithNetfilter("off"), + ) + require.NoError(t, err) + err = user2Node.Login(headscale.GetEndpoint(), untaggedKey.GetKey()) + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + nodes, err := headscale.ListNodes("user1") + require.NoError(t, err) + + return user2Node, user1Node, nodes[0].GetId() + }, + initialAccess: true, // user2 can access (has tag:web) + tagChange: []string{"tag:internal"}, // remove tag:web, keep tag:internal + finalAccess: false, // user2 cannot access (no ACL for 
tag:internal) + }, + { + name: "tag-change-updates-peer-identity", + policy: &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:server": policyv2.Owners{usernameOwner("user1@")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:server"), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{tagp("tag:server")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + }, + spec: ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + }, + setup: func(t *testing.T, scenario *Scenario, headscale ControlServer) (TailscaleClient, TailscaleClient, uint64) { + t.Helper() + + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + user2Clients, err := scenario.ListTailscaleClients("user2") + require.NoError(t, err) + + nodes, err := headscale.ListNodes("user1") + require.NoError(t, err) + + return user2Clients[0], user1Clients[0], nodes[0].GetId() + }, + initialAccess: false, // user2 cannot access user1 (no tag yet) + tagChange: []string{"tag:server"}, // assign tag:server + finalAccess: true, // user2 can now access via tag:server + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + scenario, err := NewScenario(tt.spec) + require.NoError(t, err) + + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithNetfilter("off"), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(tt.policy), + hsic.WithTestName("acl-tag-"+tt.name), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Run test-specific setup + sourceClient, targetClient, targetNodeID := tt.setup(t, scenario, headscale) + + targetFQDN, err := targetClient.FQDN() + require.NoError(t, err) + + targetURL := fmt.Sprintf("http://%s/etc/hostname", targetFQDN) + + // Step 1: Verify initial access state + t.Logf("Step 1: Verifying initial access (expect success=%v)", tt.initialAccess) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := sourceClient.Curl(targetURL) + if tt.initialAccess { + assert.NoError(c, err, "Initial access should succeed") + assert.NotEmpty(c, result, "Initial access should return content") + } else { + assert.Error(c, err, "Initial access should fail") + } + }, 30*time.Second, 500*time.Millisecond, "verifying initial access state") + + // Step 1b: Verify initial NetMap visibility + t.Logf("Step 1b: Verifying initial NetMap visibility (expect visible=%v)", tt.initialAccess) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := sourceClient.Status() + assert.NoError(c, err) + + targetHostname := targetClient.Hostname() + found := false + + for _, peer := range status.Peer { + if strings.Contains(peer.HostName, targetHostname) { + found = true + break + } + } + + if tt.initialAccess { + assert.True(c, found, "Target should be visible in NetMap initially") + } 
else { + assert.False(c, found, "Target should NOT be visible in NetMap initially") + } + }, 30*time.Second, 500*time.Millisecond, "verifying initial NetMap visibility") + + // Step 2: Apply tag change + t.Logf("Step 2: Setting tags on node %d to %v", targetNodeID, tt.tagChange) + err = headscale.SetNodeTags(targetNodeID, tt.tagChange) + require.NoError(t, err) + + // Verify tag was applied + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // List nodes by iterating through all users since tagged nodes may "move" + var node *v1.Node + + for _, user := range tt.spec.Users { + nodes, err := headscale.ListNodes(user) + if err != nil { + continue + } + + for _, n := range nodes { + if n.GetId() == targetNodeID { + node = n + break + } + } + } + // Also check nodes without user filter + if node == nil { + // Try listing all nodes + allNodes, _ := headscale.ListNodes("") + for _, n := range allNodes { + if n.GetId() == targetNodeID { + node = n + break + } + } + } + + assert.NotNil(c, node, "Node should still exist") + + if node != nil { + assert.ElementsMatch(c, tt.tagChange, node.GetTags(), "Tags should be updated") + } + }, 10*time.Second, 500*time.Millisecond, "verifying tag change applied") + + // Step 3: Verify final access state (this is the key test for #2389) + t.Logf("Step 3: Verifying final access after tag change (expect success=%v)", tt.finalAccess) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := sourceClient.Curl(targetURL) + if tt.finalAccess { + assert.NoError(c, err, "Final access should succeed after tag change") + assert.NotEmpty(c, result, "Final access should return content") + } else { + assert.Error(c, err, "Final access should fail after tag change") + } + }, 30*time.Second, 500*time.Millisecond, "verifying access propagated after tag change") + + // Step 3b: Verify final NetMap visibility + t.Logf("Step 3b: Verifying final NetMap visibility (expect visible=%v)", tt.finalAccess) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := sourceClient.Status() + assert.NoError(c, err) + + targetHostname := targetClient.Hostname() + found := false + + for _, peer := range status.Peer { + if strings.Contains(peer.HostName, targetHostname) { + found = true + break + } + } + + if tt.finalAccess { + assert.True(c, found, "Target should be visible in NetMap after tag change") + } else { + assert.False(c, found, "Target should NOT be visible in NetMap after tag change") + } + }, 60*time.Second, 500*time.Millisecond, "verifying NetMap visibility propagated after tag change") + + t.Logf("Test %s PASSED: Tag change propagated correctly", tt.name) + }) + } +} + +// TestACLTagPropagationPortSpecific validates that tag changes correctly update +// port-specific ACLs. When a tag change restricts access to specific ports, +// the peer should remain visible but only the allowed ports should be accessible. 
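+// The test brings up user1's node with tag:webserver (port 80 allowed for
+// user2), switches the node to tag:sshonly (port 22 only), and then checks
+// that the peer stays visible in user2's netmap while HTTP requests to
+// port 80 start to fail.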
+func TestACLTagPropagationPortSpecific(t *testing.T) { + IntegrationSkip(t) + + // Policy: tag:webserver allows port 80, tag:sshonly allows port 22 + // When we change from tag:webserver to tag:sshonly, HTTP should fail but ping should still work + policy := &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:webserver": policyv2.Owners{usernameOwner("user1@")}, + "tag:sshonly": policyv2.Owners{usernameOwner("user1@")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + // user2 can access tag:webserver on port 80 only + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:webserver"), tailcfg.PortRange{First: 80, Last: 80}), + }, + }, + // user2 can access tag:sshonly on port 22 only + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:sshonly"), tailcfg.PortRange{First: 22, Last: 22}), + }, + }, + // Allow ICMP for ping tests + { + Action: "accept", + Sources: []policyv2.Alias{usernamep("user2@")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(tagp("tag:webserver"), tailcfg.PortRangeAny), + aliasWithPorts(tagp("tag:sshonly"), tailcfg.PortRangeAny), + }, + Protocol: "icmp", + }, + // Return path + { + Action: "accept", + Sources: []policyv2.Alias{tagp("tag:webserver"), tagp("tag:sshonly")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(usernamep("user2@"), tailcfg.PortRangeAny), + }, + }, + }, + } + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithNetfilter("off"), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(policy), + hsic.WithTestName("acl-tag-port-specific"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + // Create user1's node WITH tag:webserver + taggedKey, err := scenario.CreatePreAuthKeyWithTags( + userMap["user1"].GetId(), false, false, []string{"tag:webserver"}, + ) + require.NoError(t, err) + + user1Node, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + "/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; python3 -m http.server --bind :: 80 & tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + tsic.WithNetfilter("off"), + ) + require.NoError(t, err) + + err = user1Node.Login(headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + // Create user2's node + untaggedKey, err := scenario.CreatePreAuthKey(userMap["user2"].GetId(), false, false) + require.NoError(t, err) + + user2Node, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", "-c", + 
"/bin/sleep 3 ; apk add python3 curl ; update-ca-certificates ; tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + tsic.WithNetfilter("off"), + ) + require.NoError(t, err) + + err = user2Node.Login(headscale.GetEndpoint(), untaggedKey.GetKey()) + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + nodes, err := headscale.ListNodes("user1") + require.NoError(t, err) + + targetNodeID := nodes[0].GetId() + + targetFQDN, err := user1Node.FQDN() + require.NoError(t, err) + + targetURL := fmt.Sprintf("http://%s/etc/hostname", targetFQDN) + + // Step 1: Verify initial state - HTTP on port 80 should work with tag:webserver + t.Log("Step 1: Verifying HTTP access with tag:webserver (should succeed)") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := user2Node.Curl(targetURL) + assert.NoError(c, err, "HTTP should work with tag:webserver") + assert.NotEmpty(c, result) + }, 30*time.Second, 500*time.Millisecond, "initial HTTP access with tag:webserver") + + // Step 2: Change tag from webserver to sshonly + t.Logf("Step 2: Changing tag from webserver to sshonly on node %d", targetNodeID) + err = headscale.SetNodeTags(targetNodeID, []string{"tag:sshonly"}) + require.NoError(t, err) + + // Step 3: Verify peer is still visible in NetMap (partial access, not full removal) + t.Log("Step 3: Verifying peer remains visible in NetMap after tag change") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := user2Node.Status() + assert.NoError(c, err) + + targetHostname := user1Node.Hostname() + found := false + + for _, peer := range status.Peer { + if strings.Contains(peer.HostName, targetHostname) { + found = true + break + } + } + + assert.True(c, found, "Peer should still be visible with tag:sshonly (port 22 access)") + }, 60*time.Second, 500*time.Millisecond, "peer visibility after tag change") + + // Step 4: Verify HTTP on port 80 now fails (tag:sshonly only allows port 22) + t.Log("Step 4: Verifying HTTP access is now blocked (tag:sshonly only allows port 22)") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + _, err := user2Node.Curl(targetURL) + assert.Error(c, err, "HTTP should fail with tag:sshonly (only port 22 allowed)") + }, 60*time.Second, 500*time.Millisecond, "HTTP blocked after tag change to sshonly") + + t.Log("Test PASSED: Port-specific ACL changes propagated correctly") +} + +// TestACLGroupWithUnknownUser tests issue #2967 where a group containing +// a reference to a non-existent user should not break connectivity for +// valid users in the same group. The expected behavior is that unknown +// users are silently ignored during group resolution. +func TestACLGroupWithUnknownUser(t *testing.T) { + IntegrationSkip(t) + + // This test verifies that when a group contains a reference to a + // non-existent user (e.g., "nonexistent@"), the valid users in + // the group should still be able to connect to each other. + // + // Issue: https://github.com/juanfont/headscale/issues/2967 + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + // Create a policy with a group that includes a non-existent user + // alongside valid users. The group should still work for valid users. 
+ policy := &policyv2.Policy{ + Groups: policyv2.Groups{ + // This group contains a reference to "nonexistent@" which does not exist + policyv2.Group("group:test"): []policyv2.Username{ + policyv2.Username("user1@"), + policyv2.Username("user2@"), + policyv2.Username("nonexistent@"), // This user does not exist + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:test")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(groupp("group:test"), tailcfg.PortRangeAny), + }, + }, + }, + } + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(policy), + hsic.WithTestName("acl-unknown-user"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + require.NoError(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + require.Len(t, user1Clients, 1) + + user2Clients, err := scenario.ListTailscaleClients("user2") + require.NoError(t, err) + require.Len(t, user2Clients, 1) + + user1 := user1Clients[0] + user2 := user2Clients[0] + + // Get FQDNs for connectivity test + user1FQDN, err := user1.FQDN() + require.NoError(t, err) + user2FQDN, err := user2.FQDN() + require.NoError(t, err) + + // Test that user1 can reach user2 (valid users should be able to communicate) + // This is the key assertion for issue #2967: valid users should work + // even if the group contains references to non-existent users. + t.Log("Testing connectivity: user1 -> user2 (should succeed despite unknown user in group)") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user2FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should be able to reach user2") + assert.Len(c, result, 13, "expected hostname response") + }, 30*time.Second, 500*time.Millisecond, "user1 should reach user2") + + // Test that user2 can reach user1 (bidirectional) + t.Log("Testing connectivity: user2 -> user1 (should succeed despite unknown user in group)") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user2.Curl(url) + assert.NoError(c, err, "user2 should be able to reach user1") + assert.Len(c, result, 13, "expected hostname response") + }, 30*time.Second, 500*time.Millisecond, "user2 should reach user1") + + t.Log("Test PASSED: Valid users can communicate despite unknown user reference in group") +} + +// TestACLGroupAfterUserDeletion tests issue #2967 scenario where a user +// is deleted but their reference remains in an ACL group. The remaining +// valid users should still be able to communicate. +func TestACLGroupAfterUserDeletion(t *testing.T) { + IntegrationSkip(t) + + // This test verifies that when a user is deleted from headscale but + // their reference remains in an ACL group, the remaining valid users + // in the group should still be able to connect to each other. 
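+	// The test deletes user3's node and then the user itself, creates a new
+	// user to trigger a policy re-evaluation, and verifies that user1 and
+	// user2 can still reach each other afterwards.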
+ // + // Issue: https://github.com/juanfont/headscale/issues/2967 + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2", "user3"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + // Create a policy with a group containing all three users + policy := &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:all"): []policyv2.Username{ + policyv2.Username("user1@"), + policyv2.Username("user2@"), + policyv2.Username("user3@"), + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:all")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(groupp("group:all"), tailcfg.PortRangeAny), + }, + }, + }, + } + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(policy), + hsic.WithTestName("acl-deleted-user"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + hsic.WithPolicyMode(types.PolicyModeDB), // Use DB mode so policy persists after user deletion + ) + require.NoError(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + require.Len(t, user1Clients, 1) + + user2Clients, err := scenario.ListTailscaleClients("user2") + require.NoError(t, err) + require.Len(t, user2Clients, 1) + + user3Clients, err := scenario.ListTailscaleClients("user3") + require.NoError(t, err) + require.Len(t, user3Clients, 1) + + user1 := user1Clients[0] + user2 := user2Clients[0] + + // Get FQDNs for connectivity test + user1FQDN, err := user1.FQDN() + require.NoError(t, err) + user2FQDN, err := user2.FQDN() + require.NoError(t, err) + + // Step 1: Verify initial connectivity - all users can reach each other + t.Log("Step 1: Verifying initial connectivity between all users") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user2FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should be able to reach user2 initially") + assert.Len(c, result, 13, "expected hostname response") + }, 30*time.Second, 500*time.Millisecond, "initial user1 -> user2 connectivity") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user2.Curl(url) + assert.NoError(c, err, "user2 should be able to reach user1 initially") + assert.Len(c, result, 13, "expected hostname response") + }, 30*time.Second, 500*time.Millisecond, "initial user2 -> user1 connectivity") + + // Step 2: Get user3's node and user, then delete them + t.Log("Step 2: Deleting user3's node and user from headscale") + + // First, get user3's node ID + nodes, err := headscale.ListNodes("user3") + require.NoError(t, err) + require.Len(t, nodes, 1, "user3 should have exactly one node") + user3NodeID := nodes[0].GetId() + + // Delete user3's node first (required before deleting the user) + err = headscale.DeleteNode(user3NodeID) + require.NoError(t, err, "failed to delete user3's node") + + // Now get user3's user ID and delete the user + user3, err := GetUserByName(headscale, "user3") + require.NoError(t, err, "user3 should exist") + + // Now delete user3 (after their nodes are 
deleted) + err = headscale.DeleteUser(user3.GetId()) + require.NoError(t, err) + + // Verify user3 is deleted + _, err = GetUserByName(headscale, "user3") + require.Error(t, err, "user3 should be deleted") + + // Step 3: Verify that user1 and user2 can still communicate (before triggering policy refresh) + // The policy still references "user3@" in the group, but since user3 is deleted, + // connectivity may still work due to cached/stale policy state. + t.Log("Step 3: Verifying connectivity still works immediately after user3 deletion (stale cache)") + + // Test that user1 can still reach user2 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user2FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should still be able to reach user2 after user3 deletion (stale cache)") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user1 -> user2 after user3 deletion") + + // Step 4: Create a NEW user - this triggers updatePolicyManagerUsers() which + // re-evaluates the policy. According to issue #2967, this is when the bug manifests: + // the deleted user3@ in the group causes the entire group to fail resolution. + t.Log("Step 4: Creating a new user (user4) to trigger policy re-evaluation") + + _, err = headscale.CreateUser("user4") + require.NoError(t, err, "failed to create user4") + + // Verify user4 was created + _, err = GetUserByName(headscale, "user4") + require.NoError(t, err, "user4 should exist after creation") + + // Step 5: THIS IS THE CRITICAL TEST - verify connectivity STILL works after + // creating a new user. Without the fix, the group containing the deleted user3@ + // would fail to resolve, breaking connectivity for user1 and user2. + t.Log("Step 5: Verifying connectivity AFTER creating new user (this triggers the bug)") + + // Test that user1 can still reach user2 AFTER the policy refresh triggered by user creation + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user2FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should still reach user2 after policy refresh (BUG if this fails)") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user1 -> user2 after policy refresh (issue #2967)") + + // Test that user2 can still reach user1 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user2.Curl(url) + assert.NoError(c, err, "user2 should still reach user1 after policy refresh (BUG if this fails)") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user2 -> user1 after policy refresh (issue #2967)") + + t.Log("Test PASSED: Remaining users can communicate after deleted user and policy refresh") +} + +// TestACLGroupDeletionExactReproduction reproduces issue #2967 exactly as reported: +// The reporter had ACTIVE pinging between nodes while making changes. +// The bug is that deleting a user and then creating a new user causes +// connectivity to break for remaining users in the group. +// +// Key difference from other tests: We keep multiple nodes ACTIVE and pinging +// each other throughout the test, just like the reporter's scenario. +// +// Reporter's steps (v0.28.0-beta.1): +// 1. Start pinging between nodes +// 2. Create policy with group:admin = [user1@] +// 3. 
Create users "deleteable" and "existinguser" +// 4. Add deleteable@ to ACL: Pinging continues +// 5. Delete deleteable: Pinging continues +// 6. Add existinguser@ to ACL: Pinging continues +// 7. Create new user "anotheruser": Pinging continues +// 8. Add anotherinvaliduser@ to ACL: Pinging stops. +func TestACLGroupDeletionExactReproduction(t *testing.T) { + IntegrationSkip(t) + + // Issue: https://github.com/juanfont/headscale/issues/2967 + + const userToDelete = "user2" + + // We need 3 users with active nodes to properly test this: + // - user1: will remain throughout (like "ritty" in the issue) + // - user2: will be deleted (like "deleteable" in the issue) + // - user3: will remain and should still be able to ping user1 after user2 deletion + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", userToDelete, "user3"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + + defer scenario.ShutdownAssertNoPanics(t) + + // Initial policy: all three users in group, can communicate with each other + initialPolicy := &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:admin"): []policyv2.Username{ + policyv2.Username("user1@"), + policyv2.Username(userToDelete + "@"), + policyv2.Username("user3@"), + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:admin")}, + Destinations: []policyv2.AliasWithPorts{ + // Use *:* like the reporter's ACL + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + } + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(initialPolicy), + hsic.WithTestName("acl-exact-repro"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + hsic.WithPolicyMode(types.PolicyModeDB), + ) + require.NoError(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Get all clients + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + require.Len(t, user1Clients, 1) + user1 := user1Clients[0] + + user3Clients, err := scenario.ListTailscaleClients("user3") + require.NoError(t, err) + require.Len(t, user3Clients, 1) + user3 := user3Clients[0] + + user1FQDN, err := user1.FQDN() + require.NoError(t, err) + user3FQDN, err := user3.FQDN() + require.NoError(t, err) + + // Step 1: Verify initial connectivity - user1 and user3 can ping each other + t.Log("Step 1: Verifying initial connectivity (user1 <-> user3)") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user3FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should reach user3") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user1 -> user3") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user3.Curl(url) + assert.NoError(c, err, "user3 should reach user1") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user3 -> user1") + + t.Log("Step 1: PASSED - initial connectivity works") + + // Step 2: Delete user2's node and user (like reporter deleting "deleteable") + // The ACL still references user2@ but user2 no longer exists + 
t.Log("Step 2: Deleting user2 (node + user) from database - ACL still references user2@") + + nodes, err := headscale.ListNodes(userToDelete) + require.NoError(t, err) + require.Len(t, nodes, 1) + err = headscale.DeleteNode(nodes[0].GetId()) + require.NoError(t, err) + + userToDeleteObj, err := GetUserByName(headscale, userToDelete) + require.NoError(t, err, "user to delete should exist") + + err = headscale.DeleteUser(userToDeleteObj.GetId()) + require.NoError(t, err) + + t.Log("Step 2: DONE - user2 deleted, ACL still has user2@ reference") + + // Step 3: Verify connectivity still works after user2 deletion + // This tests the immediate effect of the fix - policy should be updated + t.Log("Step 3: Verifying connectivity STILL works after user2 deletion") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user3FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should still reach user3 after user2 deletion") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user1 -> user3 after user2 deletion") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user3.Curl(url) + assert.NoError(c, err, "user3 should still reach user1 after user2 deletion") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user3 -> user1 after user2 deletion") + + t.Log("Step 3: PASSED - connectivity works after user2 deletion") + + // Step 4: Create a NEW user - this triggers updatePolicyManagerUsers() + // According to the reporter, this is when the bug manifests + t.Log("Step 4: Creating new user (user4) - this triggers policy re-evaluation") + + _, err = headscale.CreateUser("user4") + require.NoError(t, err) + + // Step 5: THE CRITICAL TEST - verify connectivity STILL works + // Without the fix: DeleteUser didn't update policy, so when CreateUser + // triggers updatePolicyManagerUsers(), the stale user2@ is now unknown, + // potentially breaking the group. 
+ t.Log("Step 5: Verifying connectivity AFTER creating new user (BUG trigger point)") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user3FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "BUG #2967: user1 should still reach user3 after user4 creation") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user1 -> user3 after user4 creation (issue #2967)") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user3.Curl(url) + assert.NoError(c, err, "BUG #2967: user3 should still reach user1 after user4 creation") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user3 -> user1 after user4 creation (issue #2967)") + + // Additional verification: check filter rules are not empty + filter, err := headscale.DebugFilter() + require.NoError(t, err) + t.Logf("Filter rules: %d", len(filter)) + require.NotEmpty(t, filter, "Filter rules should not be empty") + + t.Log("Test PASSED: Connectivity maintained throughout user deletion and creation") + t.Log("Issue #2967 would cause 'pinging to stop' at Step 5") +} + +// TestACLDynamicUnknownUserAddition tests the v0.28.0-beta.1 scenario from issue #2967: +// "Pinging still stops when a non-registered user is added to a group" +// +// This test verifies that when a policy is DYNAMICALLY updated (via SetPolicy) +// to include a non-existent user in a group, connectivity for valid users +// is maintained. The v2 policy engine should gracefully handle unknown users. +// +// Steps: +// 1. Start with a valid policy (only existing users in group) +// 2. Verify connectivity works +// 3. Update policy to add unknown user to the group +// 4. Verify connectivity STILL works for valid users. 
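+//
+// For reference, the dynamically applied policy corresponds roughly to the
+// following HuJSON (field names are approximate and shown only to illustrate
+// the shape of the change an operator would make):
+//
+//	{
+//	  "groups": {"group:test": ["user1@", "user2@", "nonexistent@"]},
+//	  "acls":   [{"action": "accept", "src": ["group:test"], "dst": ["*:*"]}]
+//	}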
+func TestACLDynamicUnknownUserAddition(t *testing.T) { + IntegrationSkip(t) + + // Issue: https://github.com/juanfont/headscale/issues/2967 + // Comment: "Pinging still stops when a non-registered user is added to a group" + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + + defer scenario.ShutdownAssertNoPanics(t) + + // Start with a VALID policy - only existing users in the group + validPolicy := &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:test"): []policyv2.Username{ + policyv2.Username("user1@"), + policyv2.Username("user2@"), + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:test")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + } + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(validPolicy), + hsic.WithTestName("acl-dynamic-unknown"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + hsic.WithPolicyMode(types.PolicyModeDB), + ) + require.NoError(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + require.Len(t, user1Clients, 1) + user1 := user1Clients[0] + + user2Clients, err := scenario.ListTailscaleClients("user2") + require.NoError(t, err) + require.Len(t, user2Clients, 1) + user2 := user2Clients[0] + + user1FQDN, err := user1.FQDN() + require.NoError(t, err) + user2FQDN, err := user2.FQDN() + require.NoError(t, err) + + // Step 1: Verify initial connectivity with VALID policy + t.Log("Step 1: Verifying initial connectivity with valid policy (no unknown users)") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user2FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should reach user2") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "initial user1 -> user2") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user2.Curl(url) + assert.NoError(c, err, "user2 should reach user1") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "initial user2 -> user1") + + t.Log("Step 1: PASSED - connectivity works with valid policy") + + // Step 2: DYNAMICALLY update policy to add unknown user + // This mimics the v0.28.0-beta.1 scenario where a non-existent user is added + t.Log("Step 2: Updating policy to add unknown user (nonexistent@) to the group") + + policyWithUnknown := &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:test"): []policyv2.Username{ + policyv2.Username("user1@"), + policyv2.Username("user2@"), + policyv2.Username("nonexistent@"), // Added unknown user + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:test")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + } + + err = headscale.SetPolicy(policyWithUnknown) + require.NoError(t, err) + + // Wait 
for policy to propagate + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + // Step 3: THE CRITICAL TEST - verify connectivity STILL works + // v0.28.0-beta.1 issue: "Pinging still stops when a non-registered user is added to a group" + // With v2 policy graceful error handling, this should pass + t.Log("Step 3: Verifying connectivity AFTER adding unknown user to policy") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user2FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should STILL reach user2 after adding unknown user") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user1 -> user2 after unknown user added") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user2.Curl(url) + assert.NoError(c, err, "user2 should STILL reach user1 after adding unknown user") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user2 -> user1 after unknown user added") + + t.Log("Step 3: PASSED - connectivity maintained after adding unknown user") + t.Log("Test PASSED: v0.28.0-beta.1 scenario - unknown user added dynamically, valid users still work") +} + +// TestACLDynamicUnknownUserRemoval tests the scenario from issue #2967 comments: +// "Removing all invalid users from ACL restores connectivity" +// +// This test verifies that: +// 1. Start with a policy containing unknown user +// 2. Connectivity still works (v2 graceful handling) +// 3. Update policy to remove unknown user +// 4. Connectivity remains working +// +// This ensures the fix handles both: +// - Adding unknown users (tested above) +// - Removing unknown users from policy. 
+func TestACLDynamicUnknownUserRemoval(t *testing.T) { + IntegrationSkip(t) + + // Issue: https://github.com/juanfont/headscale/issues/2967 + // Comment: "Removing all invalid users from ACL restores connectivity" + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + // Start with a policy that INCLUDES an unknown user + policyWithUnknown := &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:test"): []policyv2.Username{ + policyv2.Username("user1@"), + policyv2.Username("user2@"), + policyv2.Username("invaliduser@"), // Unknown user from the start + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:test")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + } + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithNetfilter("off"), + tsic.WithPackages("curl"), + tsic.WithWebserver(80), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(policyWithUnknown), + hsic.WithTestName("acl-unknown-removal"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + hsic.WithPolicyMode(types.PolicyModeDB), + ) + require.NoError(t, err) + + _, err = scenario.ListTailscaleClientsFQDNs() + require.NoError(t, err) + + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + user1Clients, err := scenario.ListTailscaleClients("user1") + require.NoError(t, err) + require.Len(t, user1Clients, 1) + user1 := user1Clients[0] + + user2Clients, err := scenario.ListTailscaleClients("user2") + require.NoError(t, err) + require.Len(t, user2Clients, 1) + user2 := user2Clients[0] + + user1FQDN, err := user1.FQDN() + require.NoError(t, err) + user2FQDN, err := user2.FQDN() + require.NoError(t, err) + + // Step 1: Verify initial connectivity WITH unknown user in policy + // With v2 graceful handling, this should work + t.Log("Step 1: Verifying connectivity with unknown user in policy (v2 graceful handling)") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user2FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should reach user2 even with unknown user in policy") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "initial user1 -> user2 with unknown") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user2.Curl(url) + assert.NoError(c, err, "user2 should reach user1 even with unknown user in policy") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "initial user2 -> user1 with unknown") + + t.Log("Step 1: PASSED - connectivity works even with unknown user (v2 graceful handling)") + + // Step 2: Update policy to REMOVE the unknown user + t.Log("Step 2: Updating policy to remove unknown user") + + cleanPolicy := &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:test"): []policyv2.Username{ + policyv2.Username("user1@"), + policyv2.Username("user2@"), + // invaliduser@ removed + }, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:test")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + 
}, + } + + err = headscale.SetPolicy(cleanPolicy) + require.NoError(t, err) + + // Wait for policy to propagate + err = scenario.WaitForTailscaleSync() + require.NoError(t, err) + + // Step 3: Verify connectivity after removing unknown user + // Issue comment: "Removing all invalid users from ACL restores connectivity" + t.Log("Step 3: Verifying connectivity AFTER removing unknown user") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user2FQDN) + result, err := user1.Curl(url) + assert.NoError(c, err, "user1 should reach user2 after removing unknown user") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user1 -> user2 after unknown removed") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + url := fmt.Sprintf("http://%s/etc/hostname", user1FQDN) + result, err := user2.Curl(url) + assert.NoError(c, err, "user2 should reach user1 after removing unknown user") + assert.Len(c, result, 13, "expected hostname response") + }, 60*time.Second, 500*time.Millisecond, "user2 -> user1 after unknown removed") + + t.Log("Step 3: PASSED - connectivity maintained after removing unknown user") + t.Log("Test PASSED: Removing unknown users from policy works correctly") +} diff --git a/integration/api_auth_test.go b/integration/api_auth_test.go new file mode 100644 index 00000000..df5f2455 --- /dev/null +++ b/integration/api_auth_test.go @@ -0,0 +1,657 @@ +package integration + +import ( + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "testing" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/encoding/protojson" +) + +// TestAPIAuthenticationBypass tests that the API authentication middleware +// properly blocks unauthorized requests and does not leak sensitive data. +// This test reproduces the security issue described in: +// - https://github.com/juanfont/headscale/issues/2809 +// - https://github.com/juanfont/headscale/pull/2810 +// +// The bug: When authentication fails, the middleware writes "Unauthorized" +// but doesn't return early, allowing the handler to execute and append +// sensitive data to the response. 
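+//
+// A minimal sketch of that failure mode (hypothetical names, not the actual
+// headscale middleware):
+//
+//	func authMiddleware(next http.Handler) http.Handler {
+//		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+//			if !validAPIKey(r.Header.Get("Authorization")) {
+//				http.Error(w, "Unauthorized", http.StatusUnauthorized)
+//				// BUG: without an early `return` here, next.ServeHTTP still
+//				// runs and appends the handler's payload after the 401 body.
+//			}
+//			next.ServeHTTP(w, r)
+//		})
+//	}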
+func TestAPIAuthenticationBypass(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + Users: []string{"user1", "user2", "user3"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("apiauthbypass")) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Create an API key using the CLI + var validAPIKey string + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + apiKeyOutput, err := headscale.Execute( + []string{ + "headscale", + "apikeys", + "create", + "--expiration", + "24h", + }, + ) + assert.NoError(ct, err) + assert.NotEmpty(ct, apiKeyOutput) + validAPIKey = strings.TrimSpace(apiKeyOutput) + }, 20*time.Second, 1*time.Second) + + // Get the API endpoint + endpoint := headscale.GetEndpoint() + apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + + // Create HTTP client + client := &http.Client{ + Timeout: 10 * time.Second, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec + }, + } + + t.Run("HTTP_NoAuthHeader", func(t *testing.T) { + // Test 1: Request without any Authorization header + // Expected: Should return 401 with ONLY "Unauthorized" text, no user data + req, err := http.NewRequest("GET", apiURL, nil) + require.NoError(t, err) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // Should return 401 Unauthorized + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, + "Expected 401 status code for request without auth header") + + bodyStr := string(body) + + // Should contain "Unauthorized" message + assert.Contains(t, bodyStr, "Unauthorized", + "Response should contain 'Unauthorized' message") + + // Should NOT contain user data after "Unauthorized" + // This is the security bypass - if users array is present, auth was bypassed + var jsonCheck map[string]any + jsonErr := json.Unmarshal(body, &jsonCheck) + + // If we can unmarshal JSON and it contains "users", that's the bypass + if jsonErr == nil { + assert.NotContains(t, jsonCheck, "users", + "SECURITY ISSUE: Response should NOT contain 'users' data when unauthorized") + assert.NotContains(t, jsonCheck, "user", + "SECURITY ISSUE: Response should NOT contain 'user' data when unauthorized") + } + + // Additional check: response should not contain "user1", "user2", "user3" + assert.NotContains(t, bodyStr, "user1", + "SECURITY ISSUE: Response should NOT leak user 'user1' data") + assert.NotContains(t, bodyStr, "user2", + "SECURITY ISSUE: Response should NOT leak user 'user2' data") + assert.NotContains(t, bodyStr, "user3", + "SECURITY ISSUE: Response should NOT leak user 'user3' data") + + // Response should be minimal, just "Unauthorized" + // Allow some variation in response format but body should be small + assert.Less(t, len(bodyStr), 100, + "SECURITY ISSUE: Unauthorized response body should be minimal, got: %s", bodyStr) + }) + + t.Run("HTTP_InvalidAuthHeader", func(t *testing.T) { + // Test 2: Request with invalid Authorization header (missing "Bearer " prefix) + // Expected: Should return 401 with ONLY "Unauthorized" text, no user data + req, err := http.NewRequest("GET", apiURL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "InvalidToken") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := 
io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, + "Expected 401 status code for invalid auth header format") + + bodyStr := string(body) + assert.Contains(t, bodyStr, "Unauthorized") + + // Should not leak user data + assert.NotContains(t, bodyStr, "user1", + "SECURITY ISSUE: Response should NOT leak user data") + assert.NotContains(t, bodyStr, "user2", + "SECURITY ISSUE: Response should NOT leak user data") + assert.NotContains(t, bodyStr, "user3", + "SECURITY ISSUE: Response should NOT leak user data") + + assert.Less(t, len(bodyStr), 100, + "SECURITY ISSUE: Unauthorized response should be minimal") + }) + + t.Run("HTTP_InvalidBearerToken", func(t *testing.T) { + // Test 3: Request with Bearer prefix but invalid token + // Expected: Should return 401 with ONLY "Unauthorized" text, no user data + // Note: Both malformed and properly formatted invalid tokens should return 401 + req, err := http.NewRequest("GET", apiURL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", "Bearer invalid-token-12345") + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + assert.Equal(t, http.StatusUnauthorized, resp.StatusCode, + "Expected 401 status code for invalid bearer token") + + bodyStr := string(body) + assert.Contains(t, bodyStr, "Unauthorized") + + // Should not leak user data + assert.NotContains(t, bodyStr, "user1", + "SECURITY ISSUE: Response should NOT leak user data") + assert.NotContains(t, bodyStr, "user2", + "SECURITY ISSUE: Response should NOT leak user data") + assert.NotContains(t, bodyStr, "user3", + "SECURITY ISSUE: Response should NOT leak user data") + + assert.Less(t, len(bodyStr), 100, + "SECURITY ISSUE: Unauthorized response should be minimal") + }) + + t.Run("HTTP_ValidAPIKey", func(t *testing.T) { + // Test 4: Request with valid API key + // Expected: Should return 200 with user data (this is the authorized case) + req, err := http.NewRequest("GET", apiURL, nil) + require.NoError(t, err) + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", validAPIKey)) + + resp, err := client.Do(req) + require.NoError(t, err) + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + require.NoError(t, err) + + // Should succeed with valid auth + assert.Equal(t, http.StatusOK, resp.StatusCode, + "Expected 200 status code with valid API key") + + // Should be able to parse as protobuf JSON + var response v1.ListUsersResponse + err = protojson.Unmarshal(body, &response) + assert.NoError(t, err, "Response should be valid protobuf JSON with valid API key") + + // Should contain our test users + users := response.GetUsers() + assert.Len(t, users, 3, "Should have 3 users") + userNames := make([]string, len(users)) + for i, u := range users { + userNames[i] = u.GetName() + } + assert.Contains(t, userNames, "user1") + assert.Contains(t, userNames, "user2") + assert.Contains(t, userNames, "user3") + }) +} + +// TestAPIAuthenticationBypassCurl tests the same security issue using curl +// from inside a container, which is closer to how the issue was discovered. 
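+//
+// Each subtest is essentially the in-container equivalent of:
+//
+//	curl -s -w '\nHTTP_CODE:%{http_code}' <headscale-endpoint>/api/v1/user
+//
+// followed by asserting that an unauthenticated response carries only
+// "Unauthorized" and never a users payload.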
+func TestAPIAuthenticationBypassCurl(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + Users: []string{"testuser1", "testuser2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("apiauthcurl")) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Create a valid API key + apiKeyOutput, err := headscale.Execute( + []string{ + "headscale", + "apikeys", + "create", + "--expiration", + "24h", + }, + ) + require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) + + endpoint := headscale.GetEndpoint() + apiURL := fmt.Sprintf("%s/api/v1/user", endpoint) + + t.Run("Curl_NoAuth", func(t *testing.T) { + // Execute curl from inside the headscale container without auth + curlOutput, err := headscale.Execute( + []string{ + "curl", + "-s", + "-w", + "\nHTTP_CODE:%{http_code}", + apiURL, + }, + ) + require.NoError(t, err) + + // Parse the output + lines := strings.Split(curlOutput, "\n") + var httpCode string + var responseBody string + + for _, line := range lines { + if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { + httpCode = after + } else { + responseBody += line + } + } + + // Should return 401 + assert.Equal(t, "401", httpCode, + "Curl without auth should return 401") + + // Should contain Unauthorized + assert.Contains(t, responseBody, "Unauthorized", + "Response should contain 'Unauthorized'") + + // Should NOT leak user data + assert.NotContains(t, responseBody, "testuser1", + "SECURITY ISSUE: Should not leak user data") + assert.NotContains(t, responseBody, "testuser2", + "SECURITY ISSUE: Should not leak user data") + + // Response should be small (just "Unauthorized") + assert.Less(t, len(responseBody), 100, + "SECURITY ISSUE: Unauthorized response should be minimal, got: %s", responseBody) + }) + + t.Run("Curl_InvalidAuth", func(t *testing.T) { + // Execute curl with invalid auth header + curlOutput, err := headscale.Execute( + []string{ + "curl", + "-s", + "-H", + "Authorization: InvalidToken", + "-w", + "\nHTTP_CODE:%{http_code}", + apiURL, + }, + ) + require.NoError(t, err) + + lines := strings.Split(curlOutput, "\n") + var httpCode string + var responseBody string + + for _, line := range lines { + if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { + httpCode = after + } else { + responseBody += line + } + } + + assert.Equal(t, "401", httpCode) + assert.Contains(t, responseBody, "Unauthorized") + assert.NotContains(t, responseBody, "testuser1", + "SECURITY ISSUE: Should not leak user data") + assert.NotContains(t, responseBody, "testuser2", + "SECURITY ISSUE: Should not leak user data") + }) + + t.Run("Curl_ValidAuth", func(t *testing.T) { + // Execute curl with valid API key + curlOutput, err := headscale.Execute( + []string{ + "curl", + "-s", + "-H", + fmt.Sprintf("Authorization: Bearer %s", validAPIKey), + "-w", + "\nHTTP_CODE:%{http_code}", + apiURL, + }, + ) + require.NoError(t, err) + + lines := strings.Split(curlOutput, "\n") + var httpCode string + var responseBody string + + for _, line := range lines { + if after, ok := strings.CutPrefix(line, "HTTP_CODE:"); ok { + httpCode = after + } else { + responseBody += line + } + } + + // Should succeed + assert.Equal(t, "200", httpCode, + "Curl with valid API key should return 200") + + // Should contain user data + var response v1.ListUsersResponse + err = protojson.Unmarshal([]byte(responseBody), 
&response) + assert.NoError(t, err, "Response should be valid protobuf JSON") + users := response.GetUsers() + assert.Len(t, users, 2, "Should have 2 users") + }) +} + +// TestGRPCAuthenticationBypass tests that the gRPC authentication interceptor +// properly blocks unauthorized requests. +// This test verifies that the gRPC API does not have the same bypass issue +// as the HTTP API middleware. +func TestGRPCAuthenticationBypass(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + Users: []string{"grpcuser1", "grpcuser2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + // We need TLS for remote gRPC connections + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("grpcauthtest"), + hsic.WithTLS(), + hsic.WithConfigEnv(map[string]string{ + // Enable gRPC on the standard port + "HEADSCALE_GRPC_LISTEN_ADDR": "0.0.0.0:50443", + }), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Create a valid API key + apiKeyOutput, err := headscale.Execute( + []string{ + "headscale", + "apikeys", + "create", + "--expiration", + "24h", + }, + ) + require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) + + // Get the gRPC endpoint + // For gRPC, we need to use the hostname and port 50443 + grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + + t.Run("gRPC_NoAPIKey", func(t *testing.T) { + // Test 1: Try to use CLI without API key (should fail) + // When HEADSCALE_CLI_ADDRESS is set but HEADSCALE_CLI_API_KEY is not set, + // the CLI should fail immediately + _, err := headscale.Execute( + []string{ + "sh", "-c", + fmt.Sprintf("HEADSCALE_CLI_ADDRESS=%s HEADSCALE_CLI_INSECURE=true headscale users list --output json 2>&1", grpcAddress), + }, + ) + + // Should fail - CLI exits when API key is missing + assert.Error(t, err, + "gRPC connection without API key should fail") + }) + + t.Run("gRPC_InvalidAPIKey", func(t *testing.T) { + // Test 2: Try to use CLI with invalid API key (should fail with auth error) + output, err := headscale.Execute( + []string{ + "sh", "-c", + fmt.Sprintf("HEADSCALE_CLI_ADDRESS=%s HEADSCALE_CLI_API_KEY=invalid-key-12345 HEADSCALE_CLI_INSECURE=true headscale users list --output json 2>&1", grpcAddress), + }, + ) + + // Should fail with authentication error + assert.Error(t, err, + "gRPC connection with invalid API key should fail") + + // Should contain authentication error message + outputStr := strings.ToLower(output) + assert.True(t, + strings.Contains(outputStr, "unauthenticated") || + strings.Contains(outputStr, "invalid token") || + strings.Contains(outputStr, "failed to validate token") || + strings.Contains(outputStr, "authentication"), + "Error should indicate authentication failure, got: %s", output) + + // Should NOT leak user data + assert.NotContains(t, output, "grpcuser1", + "SECURITY ISSUE: gRPC should not leak user data with invalid auth") + assert.NotContains(t, output, "grpcuser2", + "SECURITY ISSUE: gRPC should not leak user data with invalid auth") + }) + + t.Run("gRPC_ValidAPIKey", func(t *testing.T) { + // Test 3: Use CLI with valid API key (should succeed) + output, err := headscale.Execute( + []string{ + "sh", "-c", + fmt.Sprintf("HEADSCALE_CLI_ADDRESS=%s HEADSCALE_CLI_API_KEY=%s HEADSCALE_CLI_INSECURE=true headscale users list --output json", grpcAddress, validAPIKey), + }, + ) + + // Should succeed + assert.NoError(t, err, + "gRPC connection with valid API key should succeed, 
output: %s", output) + + // CLI outputs the users array directly, not wrapped in ListUsersResponse + // Parse as JSON array (CLI uses json.Marshal, not protojson) + var users []*v1.User + err = json.Unmarshal([]byte(output), &users) + assert.NoError(t, err, "Response should be valid JSON array") + assert.Len(t, users, 2, "Should have 2 users") + + userNames := make([]string, len(users)) + for i, u := range users { + userNames[i] = u.GetName() + } + assert.Contains(t, userNames, "grpcuser1") + assert.Contains(t, userNames, "grpcuser2") + }) +} + +// TestCLIWithConfigAuthenticationBypass tests that the headscale CLI +// with --config flag does not have authentication bypass issues when +// connecting to a remote server. +// Note: When using --config with local unix socket, no auth is needed. +// This test focuses on remote gRPC connections which require API keys. +func TestCLIWithConfigAuthenticationBypass(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + Users: []string{"cliuser1", "cliuser2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("cliconfigauth"), + hsic.WithTLS(), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_GRPC_LISTEN_ADDR": "0.0.0.0:50443", + }), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Create a valid API key + apiKeyOutput, err := headscale.Execute( + []string{ + "headscale", + "apikeys", + "create", + "--expiration", + "24h", + }, + ) + require.NoError(t, err) + validAPIKey := strings.TrimSpace(apiKeyOutput) + + grpcAddress := fmt.Sprintf("%s:50443", headscale.GetHostname()) + + // Create a config file for testing + configWithoutKey := fmt.Sprintf(` +cli: + address: %s + timeout: 5s + insecure: true +`, grpcAddress) + + configWithInvalidKey := fmt.Sprintf(` +cli: + address: %s + api_key: invalid-key-12345 + timeout: 5s + insecure: true +`, grpcAddress) + + configWithValidKey := fmt.Sprintf(` +cli: + address: %s + api_key: %s + timeout: 5s + insecure: true +`, grpcAddress, validAPIKey) + + t.Run("CLI_Config_NoAPIKey", func(t *testing.T) { + // Create config file without API key + err := headscale.WriteFile("/tmp/config_no_key.yaml", []byte(configWithoutKey)) + require.NoError(t, err) + + // Try to use CLI with config that has no API key + _, err = headscale.Execute( + []string{ + "headscale", + "--config", "/tmp/config_no_key.yaml", + "users", "list", + "--output", "json", + }, + ) + + // Should fail + assert.Error(t, err, + "CLI with config missing API key should fail") + }) + + t.Run("CLI_Config_InvalidAPIKey", func(t *testing.T) { + // Create config file with invalid API key + err := headscale.WriteFile("/tmp/config_invalid_key.yaml", []byte(configWithInvalidKey)) + require.NoError(t, err) + + // Try to use CLI with invalid API key + output, err := headscale.Execute( + []string{ + "sh", "-c", + "headscale --config /tmp/config_invalid_key.yaml users list --output json 2>&1", + }, + ) + + // Should fail + assert.Error(t, err, + "CLI with invalid API key should fail") + + // Should indicate authentication failure + outputStr := strings.ToLower(output) + assert.True(t, + strings.Contains(outputStr, "unauthenticated") || + strings.Contains(outputStr, "invalid token") || + strings.Contains(outputStr, "failed to validate token") || + strings.Contains(outputStr, "authentication"), + "Error should indicate authentication failure, got: %s", output) + + // 
Should NOT leak user data + assert.NotContains(t, output, "cliuser1", + "SECURITY ISSUE: CLI should not leak user data with invalid auth") + assert.NotContains(t, output, "cliuser2", + "SECURITY ISSUE: CLI should not leak user data with invalid auth") + }) + + t.Run("CLI_Config_ValidAPIKey", func(t *testing.T) { + // Create config file with valid API key + err := headscale.WriteFile("/tmp/config_valid_key.yaml", []byte(configWithValidKey)) + require.NoError(t, err) + + // Use CLI with valid API key + output, err := headscale.Execute( + []string{ + "headscale", + "--config", "/tmp/config_valid_key.yaml", + "users", "list", + "--output", "json", + }, + ) + + // Should succeed + assert.NoError(t, err, + "CLI with valid API key should succeed") + + // CLI outputs the users array directly, not wrapped in ListUsersResponse + // Parse as JSON array (CLI uses json.Marshal, not protojson) + var users []*v1.User + err = json.Unmarshal([]byte(output), &users) + assert.NoError(t, err, "Response should be valid JSON array") + assert.Len(t, users, 2, "Should have 2 users") + + userNames := make([]string, len(users)) + for i, u := range users { + userNames[i] = u.GetName() + } + assert.Contains(t, userNames, "cliuser1") + assert.Contains(t, userNames, "cliuser2") + }) +} diff --git a/integration/auth_key_test.go b/integration/auth_key_test.go new file mode 100644 index 00000000..ba6a195b --- /dev/null +++ b/integration/auth_key_test.go @@ -0,0 +1,739 @@ +package integration + +import ( + "fmt" + "net/netip" + "slices" + "strconv" + "testing" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" + "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +func TestAuthKeyLogoutAndReloginSameUser(t *testing.T) { + IntegrationSkip(t) + + for _, https := range []bool{true, false} { + t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + opts := []hsic.Option{ + hsic.WithTestName("pingallbyip"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithDERPAsIP(), + } + if https { + opts = append(opts, []hsic.Option{ + hsic.WithTLS(), + }...) + } + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) 
+ requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + expectedNodes := collectExpectedNodeIDs(t, allClients) + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected", 120*time.Second) + + // Validate that all nodes have NetInfo and DERP servers before logout + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP before logout", 3*time.Minute) + + // assertClientsState(t, allClients) + + clientIPs := make(map[TailscaleClient][]netip.Addr) + for _, client := range allClients { + ips, err := client.IPs() + if err != nil { + t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) + } + clientIPs[client] = ips + } + + var listNodes []*v1.Node + var nodeCountBeforeLogout int + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, len(allClients)) + + for _, node := range listNodes { + assertLastSeenSetWithCollect(c, node) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout") + + nodeCountBeforeLogout = len(listNodes) + t.Logf("node count before logout: %d", nodeCountBeforeLogout) + + for _, client := range allClients { + err := client.Logout() + if err != nil { + t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) + } + } + + err = scenario.WaitForTailscaleLogout() + requireNoErrLogout(t, err) + + // After taking down all nodes, verify all systems show nodes offline + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should have logged out", 120*time.Second) + + t.Logf("all clients logged out") + + t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after logout") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should match before logout count - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) + }, 30*time.Second, 2*time.Second, "validating node persistence after logout (nodes should remain in database)") + + for _, node := range listNodes { + assertLastSeenSet(t, node) + } + + // if the server is not running with HTTPS, we have to wait a bit before + // reconnection as the newest Tailscale client has a measure that will only + // reconnect over HTTPS if they saw a noise connection previously. 
+ // https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38 + // https://github.com/juanfont/headscale/issues/2164 + if !https { + //nolint:forbidigo // Intentional delay: Tailscale client requires 5 min wait before reconnecting over non-HTTPS + time.Sleep(5 * time.Minute) + } + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + for _, userName := range spec.Users { + key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false) + if err != nil { + t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) + } + + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + } + } + + t.Logf("Validating node persistence after relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after relogin") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after relogin - expected %d nodes, got %d", nodeCountBeforeLogout, len(listNodes)) + }, 60*time.Second, 2*time.Second, "validating node count stability after same-user auth key relogin") + + for _, node := range listNodes { + assertLastSeenSet(t, node) + } + + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 120*time.Second) + + // Wait for Tailscale sync before validating NetInfo to ensure proper state propagation + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // Validate that all nodes have NetInfo and DERP servers after reconnection + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after reconnection", 3*time.Minute) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + for _, client := range allClients { + ips, err := client.IPs() + if err != nil { + t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) + } + + // lets check if the IPs are the same + if len(ips) != len(clientIPs[client]) { + t.Fatalf("IPs changed for client %s", client.Hostname()) + } + + for _, ip := range ips { + if !slices.Contains(clientIPs[client], ip) { + t.Fatalf( + "IPs changed for client %s. Used to be %v now %v", + client.Hostname(), + clientIPs[client], + ips, + ) + } + } + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, nodeCountBeforeLogout) + + for _, node := range listNodes { + assertLastSeenSetWithCollect(c, node) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for node list after relogin") + }) + } +} + +// This test will first log in two sets of nodes to two sets of users, then +// it will log out all nodes and log them in as user1 using a pre-auth key. +// This should create new nodes for user1 while preserving the original nodes for user2. +// Pre-auth key re-authentication with a different user creates new nodes, not transfers. 
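+//
+// Concretely, the expected end state is: user1 owns one node per client in the
+// scenario, user2 still owns its original nodes, and the total node count in
+// the database therefore grows rather than staying constant.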
+func TestAuthKeyLogoutAndReloginNewUser(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, + hsic.WithTestName("keyrelognewuser"), + hsic.WithTLS(), + hsic.WithDERPAsIP(), + ) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) + + var listNodes []*v1.Node + var nodeCountBeforeLogout int + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, len(allClients)) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout") + + nodeCountBeforeLogout = len(listNodes) + t.Logf("node count before logout: %d", nodeCountBeforeLogout) + + for _, client := range allClients { + err := client.Logout() + if err != nil { + t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) + } + } + + err = scenario.WaitForTailscaleLogout() + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) + + t.Logf("all clients logged out") + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + // Create a new authkey for user1, to be used for all clients + key, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), true, false) + if err != nil { + t.Fatalf("failed to create pre-auth key for user1: %s", err) + } + + // Log in all clients as user1, iterating over the spec only returns the + // clients, not the usernames. 
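+	// In other words: RunTailscaleUp is invoked once per username so that every
+	// user's clients are brought back up, but all of them authenticate with
+	// user1's pre-auth key and therefore register under user1.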
+ for _, userName := range spec.Users { + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + } + } + + var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + user1Nodes, err = headscale.ListNodes("user1") + assert.NoError(ct, err, "Failed to list nodes for user1 after relogin") + assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after relogin, got %d nodes", len(allClients), len(user1Nodes)) + }, 60*time.Second, 2*time.Second, "validating user1 has all client nodes after auth key relogin") + + // Collect expected node IDs for user1 after relogin + expectedUser1Nodes := make([]types.NodeID, 0, len(user1Nodes)) + for _, node := range user1Nodes { + expectedUser1Nodes = append(expectedUser1Nodes, types.NodeID(node.GetId())) + } + + // Validate connection state after relogin as user1 + requireAllClientsOnline(t, headscale, expectedUser1Nodes, true, "all user1 nodes should be connected after relogin", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedUser1Nodes, "all user1 nodes should have NetInfo and DERP after relogin", 3*time.Minute) + + // Validate that user2 still has their original nodes after user1's re-authentication + // When nodes re-authenticate with a different user's pre-auth key, NEW nodes are created + // for the new user. The original nodes remain with the original user. + var user2Nodes []*v1.Node + t.Logf("Validating user2 node persistence after user1 relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + user2Nodes, err = headscale.ListNodes("user2") + assert.NoError(ct, err, "Failed to list nodes for user2 after user1 relogin") + assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d clients after user1 relogin, got %d nodes", len(allClients)/2, len(user2Nodes)) + }, 30*time.Second, 2*time.Second, "validating user2 nodes persist after user1 relogin (should not be affected)") + + t.Logf("Validating client login states after user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) + assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) + }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after auth key user switch", client.Hostname())) + } +} + +func TestAuthKeyLogoutAndReloginSameUserExpiredKey(t *testing.T) { + IntegrationSkip(t) + + for _, https := range []bool{true, false} { + t.Run(fmt.Sprintf("with-https-%t", https), func(t *testing.T) { + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + opts := []hsic.Option{ + hsic.WithTestName("pingallbyip"), + hsic.WithDERPAsIP(), + } + if https { + opts = append(opts, []hsic.Option{ + hsic.WithTLS(), + }...) + } + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, opts...) 
+			requireNoErrHeadscaleEnv(t, err)
+
+			allClients, err := scenario.ListTailscaleClients()
+			requireNoErrListClients(t, err)
+
+			err = scenario.WaitForTailscaleSync()
+			requireNoErrSync(t, err)
+
+			// assertClientsState(t, allClients)
+
+			clientIPs := make(map[TailscaleClient][]netip.Addr)
+			for _, client := range allClients {
+				ips, err := client.IPs()
+				if err != nil {
+					t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err)
+				}
+				clientIPs[client] = ips
+			}
+
+			headscale, err := scenario.Headscale()
+			requireNoErrGetHeadscale(t, err)
+
+			// Collect expected node IDs for validation
+			expectedNodes := collectExpectedNodeIDs(t, allClients)
+
+			// Validate initial connection state
+			requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second)
+			requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute)
+
+			var listNodes []*v1.Node
+			var nodeCountBeforeLogout int
+			assert.EventuallyWithT(t, func(c *assert.CollectT) {
+				var err error
+				listNodes, err = headscale.ListNodes()
+				assert.NoError(c, err)
+				assert.Len(c, listNodes, len(allClients))
+			}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list before logout")
+
+			nodeCountBeforeLogout = len(listNodes)
+			t.Logf("node count before logout: %d", nodeCountBeforeLogout)
+
+			for _, client := range allClients {
+				err := client.Logout()
+				if err != nil {
+					t.Fatalf("failed to logout client %s: %s", client.Hostname(), err)
+				}
+			}
+
+			err = scenario.WaitForTailscaleLogout()
+			requireNoErrLogout(t, err)
+
+			// Validate that all nodes are offline after logout
+			requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second)
+
+			t.Logf("all clients logged out")
+
+			// if the server is not running with HTTPS, we have to wait a bit before
+			// reconnection as the newest Tailscale client has a measure that will only
+			// reconnect over HTTPS if they saw a noise connection previously.
+			// https://github.com/tailscale/tailscale/commit/1eaad7d3deb0815e8932e913ca1a862afa34db38
+			// https://github.com/juanfont/headscale/issues/2164
+			if !https {
+				//nolint:forbidigo // Intentional delay: Tailscale client requires 5 min wait before reconnecting over non-HTTPS
+				time.Sleep(5 * time.Minute)
+			}
+
+			userMap, err := headscale.MapUsers()
+			require.NoError(t, err)
+
+			for _, userName := range spec.Users {
+				key, err := scenario.CreatePreAuthKey(userMap[userName].GetId(), true, false)
+				if err != nil {
+					t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err)
+				}
+
+				// Expire the key so it can't be used
+				_, err = headscale.Execute(
+					[]string{
+						"headscale",
+						"preauthkeys",
+						"expire",
+						"--id",
+						strconv.FormatUint(key.GetId(), 10),
+					})
+				require.NoError(t, err)
+
+				err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey())
+				assert.ErrorContains(t, err, "authkey expired")
+			}
+		})
+	}
+}
+
+// TestAuthKeyDeleteKey tests Issue #2830: node with deleted auth key should still reconnect.
+// Scenario from user report: "create node, delete the auth key, restart to validate it can connect"
+// Steps:
+// 1. Create node with auth key
+// 2. DELETE the auth key from database (completely remove it)
+// 3. Restart node - should successfully reconnect using MachineKey identity.
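+//
+// The underlying expectation is that a pre-auth key is only consumed at the
+// initial registration; on later reconnects the node should be recognized by
+// its stored machine identity, so removing the key afterwards must not lock
+// the node out.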
+func TestAuthKeyDeleteKey(t *testing.T) { + IntegrationSkip(t) + + // Create scenario with NO nodes - we'll create the node manually so we can capture the auth key + scenario, err := NewScenario(ScenarioSpec{ + NodesPerUser: 0, // No nodes created automatically + Users: []string{"user1"}, + }) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("delkey"), hsic.WithTLS(), hsic.WithDERPAsIP()) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Get the user + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap["user1"].GetId() + + // Create a pre-auth key - we keep the full key string before it gets redacted + authKey, err := scenario.CreatePreAuthKey(userID, false, false) + require.NoError(t, err) + + authKeyString := authKey.GetKey() + authKeyID := authKey.GetId() + t.Logf("Created pre-auth key ID %d: %s", authKeyID, authKeyString) + + // Create a tailscale client and log it in with the auth key + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + err = client.Login(headscale.GetEndpoint(), authKeyString) + require.NoError(t, err) + + // Wait for the node to be registered + var user1Nodes []*v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + + user1Nodes, err = headscale.ListNodes("user1") + assert.NoError(c, err) + assert.Len(c, user1Nodes, 1) + }, 30*time.Second, 500*time.Millisecond, "waiting for node to be registered") + + nodeID := user1Nodes[0].GetId() + nodeName := user1Nodes[0].GetName() + t.Logf("Node %d (%s) created successfully with auth_key_id=%d", nodeID, nodeName, authKeyID) + + // Verify node is online + requireAllClientsOnline(t, headscale, []types.NodeID{types.NodeID(nodeID)}, true, "node should be online initially", 120*time.Second) + + // DELETE the pre-auth key using the API + t.Logf("Deleting pre-auth key ID %d using API", authKeyID) + + err = headscale.DeleteAuthKey(authKeyID) + require.NoError(t, err) + t.Logf("Successfully deleted auth key") + + // Simulate node restart (down + up) + t.Logf("Restarting node after deleting its auth key") + + err = client.Down() + require.NoError(t, err) + + // Wait for client to fully stop before bringing it back up + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + assert.Equal(c, "Stopped", status.BackendState) + }, 10*time.Second, 200*time.Millisecond, "client should be stopped") + + err = client.Up() + require.NoError(t, err) + + // Verify node comes back online + // This will FAIL without the fix because auth key validation will reject deleted key + // With the fix, MachineKey identity allows reconnection even with deleted key + requireAllClientsOnline(t, headscale, []types.NodeID{types.NodeID(nodeID)}, true, "node should reconnect after restart despite deleted key", 120*time.Second) + + t.Logf("✓ Node successfully reconnected after its auth key was deleted") +} + +// TestAuthKeyLogoutAndReloginRoutesPreserved tests that routes remain serving +// after a node logs out and re-authenticates with the same user. 
+// +// This test validates the fix for issue #2896: +// https://github.com/juanfont/headscale/issues/2896 +// +// Bug: When a node with already-approved routes restarts/re-authenticates, +// the routes show as "Approved" and "Available" but NOT "Serving" (Primary). +// A headscale restart would fix it, indicating a state management issue. +// +// The test scenario: +// 1. Node registers with auth key and advertises routes +// 2. Routes are auto-approved and verified as serving +// 3. Node logs out +// 4. Node re-authenticates with same auth key +// 5. Routes should STILL be serving (this is where the bug manifests). +func TestAuthKeyLogoutAndReloginRoutesPreserved(t *testing.T) { + IntegrationSkip(t) + + user := "routeuser" + advertiseRoute := "10.55.0.0/24" + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{user}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{ + tsic.WithAcceptRoutes(), + // Advertise route on initial login + tsic.WithExtraLoginArgs([]string{"--advertise-routes=" + advertiseRoute}), + }, + hsic.WithTestName("routelogout"), + hsic.WithTLS(), + hsic.WithACLPolicy( + &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{policyv2.Wildcard}, + Destinations: []policyv2.AliasWithPorts{{Alias: policyv2.Wildcard, Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}}}, + }, + }, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + netip.MustParsePrefix(advertiseRoute): {ptr.To(policyv2.Username(user + "@test.no"))}, + }, + }, + }, + ), + ) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + require.Len(t, allClients, 1) + + client := allClients[0] + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Step 1: Verify initial route is advertised, approved, and SERVING + t.Logf("Step 1: Verifying initial route is advertised, approved, and SERVING at %s", time.Now().Format(TimestampFormat)) + + var initialNode *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + initialNode = nodes[0] + // Check: 1 announced, 1 approved, 1 serving (subnet route) + assert.Lenf(c, initialNode.GetAvailableRoutes(), 1, + "Node should have 1 available route, got %v", initialNode.GetAvailableRoutes()) + assert.Lenf(c, initialNode.GetApprovedRoutes(), 1, + "Node should have 1 approved route, got %v", initialNode.GetApprovedRoutes()) + assert.Lenf(c, initialNode.GetSubnetRoutes(), 1, + "Node should have 1 serving (subnet) route, got %v - THIS IS THE BUG if empty", initialNode.GetSubnetRoutes()) + assert.Contains(c, initialNode.GetSubnetRoutes(), advertiseRoute, + "Subnet routes should contain %s", advertiseRoute) + } + }, 30*time.Second, 500*time.Millisecond, "initial route should be serving") + + require.NotNil(t, initialNode, "Initial node should be found") + initialNodeID := initialNode.GetId() + t.Logf("Initial node ID: %d, Available: %v, Approved: %v, Serving: %v", + initialNodeID, initialNode.GetAvailableRoutes(), initialNode.GetApprovedRoutes(), initialNode.GetSubnetRoutes()) + + // Step 2: Logout + t.Logf("Step 2: Logging out at %s", 
time.Now().Format(TimestampFormat)) + + err = client.Logout() + require.NoError(t, err) + + // Wait for logout to complete + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout") + }, 30*time.Second, 1*time.Second, "waiting for logout to complete") + + t.Logf("Logout completed, node should still exist in database") + + // Verify node still exists (routes should still be in DB) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Node should persist in database after logout") + }, 10*time.Second, 500*time.Millisecond, "node should persist after logout") + + // Step 3: Re-authenticate with the SAME user (using auth key) + t.Logf("Step 3: Re-authenticating with same user at %s", time.Now().Format(TimestampFormat)) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + key, err := scenario.CreatePreAuthKey(userMap[user].GetId(), true, false) + require.NoError(t, err) + + // Re-login - the container already has extraLoginArgs with --advertise-routes + // from the initial setup, so routes will be advertised on re-login + err = scenario.RunTailscaleUp(user, headscale.GetEndpoint(), key.GetKey()) + require.NoError(t, err) + + // Wait for client to be running + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + assert.Equal(ct, "Running", status.BackendState, "Expected Running state after relogin") + }, 30*time.Second, 1*time.Second, "waiting for relogin to complete") + + t.Logf("Re-authentication completed at %s", time.Now().Format(TimestampFormat)) + + // Step 4: THE CRITICAL TEST - Verify routes are STILL SERVING after re-authentication + t.Logf("Step 4: Verifying routes are STILL SERVING after re-authentication at %s", time.Now().Format(TimestampFormat)) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should still have exactly 1 node after relogin") + + if len(nodes) == 1 { + node := nodes[0] + t.Logf("After relogin - Available: %v, Approved: %v, Serving: %v", + node.GetAvailableRoutes(), node.GetApprovedRoutes(), node.GetSubnetRoutes()) + + // This is where issue #2896 manifests: + // - Available shows the route (from Hostinfo.RoutableIPs) + // - Approved shows the route (from ApprovedRoutes) + // - BUT Serving (SubnetRoutes/PrimaryRoutes) is EMPTY! 
+ assert.Lenf(c, node.GetAvailableRoutes(), 1, + "Node should have 1 available route after relogin, got %v", node.GetAvailableRoutes()) + assert.Lenf(c, node.GetApprovedRoutes(), 1, + "Node should have 1 approved route after relogin, got %v", node.GetApprovedRoutes()) + assert.Lenf(c, node.GetSubnetRoutes(), 1, + "BUG #2896: Node should have 1 SERVING route after relogin, got %v", node.GetSubnetRoutes()) + assert.Contains(c, node.GetSubnetRoutes(), advertiseRoute, + "BUG #2896: Subnet routes should contain %s after relogin", advertiseRoute) + + // Also verify node ID was preserved (same node, not new registration) + assert.Equal(c, initialNodeID, node.GetId(), + "Node ID should be preserved after same-user relogin") + } + }, 30*time.Second, 500*time.Millisecond, + "BUG #2896: routes should remain SERVING after logout/relogin with same user") + + t.Logf("Test completed - verifying issue #2896 fix") +} diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 7a0ed9c7..359dd456 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -1,87 +1,446 @@ package integration import ( - "context" - "crypto/tls" - "errors" - "fmt" - "io" - "log" - "net" - "net/http" + "maps" "net/netip" + "net/url" + "sort" "strconv" "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" - "github.com/ory/dockertest/v3" - "github.com/ory/dockertest/v3/docker" + "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/ipn/ipnstate" + "tailscale.com/tailcfg" ) -const ( - dockerContextPath = "../." - hsicOIDCMockHashLength = 6 - defaultAccessTTL = 10 * time.Minute -) - -var errStatusCodeNotOK = errors.New("status code not OK") - -type AuthOIDCScenario struct { - *Scenario - - mockOIDC *dockertest.Resource -} - func TestOIDCAuthenticationPingAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() - baseScenario, err := NewScenario() - assertNoErr(t, err) - - scenario := AuthOIDCScenario{ - Scenario: baseScenario, - } - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": len(MustTestVersions), + // Logins to MockOIDC is served by a queue with a strict order, + // if we use more than one node per user, the order of the logins + // will not be deterministic and the test will fail. 
+ spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", false), + }, } - oidcConfig, err := scenario.runMockOIDC(defaultAccessTTL) - assertNoErrf(t, "failed to run mock OIDC server: %s", err) + scenario, err := NewScenario(spec) + require.NoError(t, err) + + defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ - "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, - "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), "CREDENTIALS_DIRECTORY_TEST": "/tmp", "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", - "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain), } - err = scenario.CreateHeadscaleEnv( - spec, + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, hsic.WithTestName("oidcauthping"), hsic.WithConfigEnv(oidcMap), - hsic.WithHostnameAsServerURL(), - hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(oidcConfig.ClientSecret)), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + listUsers, err := headscale.ListUsers() + require.NoError(t, err) + + want := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@test.no", + }, + { + Id: 2, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + { + Id: 3, + Name: "user2", + Email: "user2@test.no", + }, + { + Id: 4, + Name: "user2", + Email: "", // Unverified + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user2", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } +} + +// TestOIDCExpireNodesBasedOnTokenExpiry validates that nodes correctly transition to NeedsLogin +// state when their OIDC tokens expire. This test uses a short token TTL to validate the +// expiration behavior without waiting for production-length timeouts. +// +// The test verifies: +// - Nodes can successfully authenticate via OIDC and establish connectivity +// - When OIDC tokens expire, nodes transition to NeedsLogin state +// - The expiration is based on individual token issue times, not a global timer +// +// Known timing considerations: +// - Nodes may expire at different times due to sequential login processing +// - The test must account for login time spread between first and last node. 
+func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) {
+	IntegrationSkip(t)
+
+	shortAccessTTL := 5 * time.Minute
+
+	spec := ScenarioSpec{
+		NodesPerUser: 1,
+		Users:        []string{"user1", "user2"},
+		OIDCUsers: []mockoidc.MockUser{
+			oidcMockUser("user1", true),
+			oidcMockUser("user2", false),
+		},
+		OIDCAccessTTL: shortAccessTTL,
+	}
+
+	scenario, err := NewScenario(spec)
+	require.NoError(t, err)
+	defer scenario.ShutdownAssertNoPanics(t)
+
+	oidcMap := map[string]string{
+		"HEADSCALE_OIDC_ISSUER":                scenario.mockOIDC.Issuer(),
+		"HEADSCALE_OIDC_CLIENT_ID":             scenario.mockOIDC.ClientID(),
+		"HEADSCALE_OIDC_CLIENT_SECRET":         scenario.mockOIDC.ClientSecret(),
+		"HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1",
+	}
+
+	err = scenario.CreateHeadscaleEnvWithLoginURL(
+		nil,
+		hsic.WithTestName("oidcexpirenodes"),
+		hsic.WithConfigEnv(oidcMap),
+	)
+	requireNoErrHeadscaleEnv(t, err)
+
+	allClients, err := scenario.ListTailscaleClients()
+	requireNoErrListClients(t, err)
+
+	allIps, err := scenario.ListTailscaleClientsIPs()
+	requireNoErrListClientIPs(t, err)
+
+	// Record when login/sync starts so we can measure how long it takes;
+	// the spread between the first and last login feeds into the token expiry timing below.
+	syncStartTime := time.Now()
+	err = scenario.WaitForTailscaleSync()
+	requireNoErrSync(t, err)
+	loginDuration := time.Since(syncStartTime)
+	t.Logf("Login and sync completed in %v", loginDuration)
+
+	// assertClientsState(t, allClients)
+
+	allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string {
+		return x.String()
+	})
+
+	success := pingAllHelper(t, allClients, allAddrs)
+	t.Logf("%d successful pings out of %d (before expiry)", success, len(allClients)*len(allIps))
+
+	// Wait for OIDC token expiry and verify all nodes transition to NeedsLogin.
+	// We add extra time to account for:
+	// - Sequential login processing causing different token issue times
+	// - Network and processing delays
+	// - Safety margin for test reliability
+	loginTimeSpread := 1 * time.Minute // Account for sequential login delays
+	safetyBuffer := 30 * time.Second   // Additional safety margin
+	totalWaitTime := shortAccessTTL + loginTimeSpread + safetyBuffer
+
+	t.Logf("Waiting %v for OIDC tokens to expire (TTL: %v, spread: %v, buffer: %v)",
+		totalWaitTime, shortAccessTTL, loginTimeSpread, safetyBuffer)
+
+	// EventuallyWithT retries the test function until it passes or times out.
+	// IMPORTANT: Use 'ct' (CollectT) for all assertions inside the function, not 't'.
+	// Using 't' would cause immediate test failure without retries, defeating the purpose
+	// of EventuallyWithT which is designed to handle timing-dependent conditions.
+ assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check each client's status individually to provide better diagnostics + expiredCount := 0 + for _, client := range allClients { + status, err := client.Status() + if assert.NoError(ct, err, "failed to get status for client %s", client.Hostname()) { + if status.BackendState == "NeedsLogin" { + expiredCount++ + } + } + } + + // Log progress for debugging + if expiredCount < len(allClients) { + t.Logf("Token expiry progress: %d/%d clients in NeedsLogin state", expiredCount, len(allClients)) + } + + // All clients must be in NeedsLogin state + assert.Equal(ct, len(allClients), expiredCount, + "expected all %d clients to be in NeedsLogin state, but only %d are", + len(allClients), expiredCount) + + // Only check detailed logout state if all clients are expired + if expiredCount == len(allClients) { + assertTailscaleNodesLogout(ct, allClients) + } + }, totalWaitTime, 5*time.Second) +} + +func TestOIDC024UserCreation(t *testing.T) { + IntegrationSkip(t) + + tests := []struct { + name string + config map[string]string + emailVerified bool + cliUsers []string + oidcUsers []string + want func(iss string) []*v1.User + }{ + { + name: "no-migration-verified-email", + emailVerified: true, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []*v1.User { + return []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@test.no", + }, + { + Id: 2, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: 3, + Name: "user2", + Email: "user2@test.no", + }, + { + Id: 4, + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "no-migration-not-verified-email", + emailVerified: false, + cliUsers: []string{"user1", "user2"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []*v1.User { + return []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@test.no", + }, + { + Id: 2, + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: 3, + Name: "user2", + Email: "user2@test.no", + }, + { + Id: 4, + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + { + name: "migration-no-strip-domains-not-verified-email", + emailVerified: false, + cliUsers: []string{"user1.headscale.net", "user2.headscale.net"}, + oidcUsers: []string{"user1", "user2"}, + want: func(iss string) []*v1.User { + return []*v1.User{ + { + Id: 1, + Name: "user1.headscale.net", + Email: "user1.headscale.net@test.no", + }, + { + Id: 2, + Name: "user1", + Provider: "oidc", + ProviderId: iss + "/user1", + }, + { + Id: 3, + Name: "user2.headscale.net", + Email: "user2.headscale.net@test.no", + }, + { + Id: 4, + Name: "user2", + Provider: "oidc", + ProviderId: iss + "/user2", + }, + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + spec := ScenarioSpec{ + NodesPerUser: 1, + } + spec.Users = append(spec.Users, tt.cliUsers...) 
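+			// MockOIDC serves logins from a queue in strict order, so the OIDC mock
+			// users appended below must be queued in the same order the clients log in.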
+ + for _, user := range tt.oidcUsers { + spec.OIDCUsers = append(spec.OIDCUsers, oidcMockUser(user, tt.emailVerified)) + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + maps.Copy(oidcMap, tt.config) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcmigration"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + ) + requireNoErrHeadscaleEnv(t, err) + + // Ensure that the nodes have logged in, this is what + // triggers user creation via OIDC. + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + want := tt.want(scenario.mockOIDC.Issuer()) + + listUsers, err := headscale.ListUsers() + require.NoError(t, err) + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(want, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Errorf("unexpected users: %s", diff) + } + }) + } +} + +func TestOIDCAuthenticationWithPKCE(t *testing.T) { + IntegrationSkip(t) + + // Single user with one node for testing PKCE flow + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1"}, + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_PKCE_ENABLED": "1", // Enable PKCE + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcauthpkce"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + ) + requireNoErrHeadscaleEnv(t, err) + + // Get all clients and verify they can connect + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -91,301 +450,1468 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) } -// This test is really flaky. -func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { +// TestOIDCReloginSameNodeNewUser tests the scenario where: +// 1. A Tailscale client logs in with user1 (creates node1 for user1) +// 2. The same client logs out and logs in with user2 (creates node2 for user2) +// 3. The same client logs out and logs in with user1 again (reuses node1, node2 remains) +// This validates that OIDC relogin properly handles node reuse and cleanup. 
+func TestOIDCReloginSameNodeNewUser(t *testing.T) { IntegrationSkip(t) - t.Parallel() - shortAccessTTL := 5 * time.Minute - - baseScenario, err := NewScenario() - assertNoErr(t, err) - - baseScenario.pool.MaxWait = 5 * time.Minute - - scenario := AuthOIDCScenario{ - Scenario: baseScenario, - } - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": 3, - } - - oidcConfig, err := scenario.runMockOIDC(shortAccessTTL) - assertNoErrf(t, "failed to run mock OIDC server: %s", err) + // Create no nodes and no users + scenario, err := NewScenario(ScenarioSpec{ + // First login creates the first OIDC user + // Second login logs in the same node, which creates a new node + // Third login logs in the same node back into the original user + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + oidcMockUser("user2", true), + oidcMockUser("user1", true), + }, + }) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) oidcMap := map[string]string{ - "HEADSCALE_OIDC_ISSUER": oidcConfig.Issuer, - "HEADSCALE_OIDC_CLIENT_ID": oidcConfig.ClientID, - "HEADSCALE_OIDC_CLIENT_SECRET": oidcConfig.ClientSecret, - "HEADSCALE_OIDC_STRIP_EMAIL_DOMAIN": fmt.Sprintf("%t", oidcConfig.StripEmaildomain), - "HEADSCALE_OIDC_USE_EXPIRY_FROM_TOKEN": "1", + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", } - err = scenario.CreateHeadscaleEnv( - spec, - hsic.WithTestName("oidcexpirenodes"), + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcauthrelog"), hsic.WithConfigEnv(oidcMap), - hsic.WithHostnameAsServerURL(), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithDERPAsIP(), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) - allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + headscale, err := scenario.Headscale() + require.NoError(t, err) - allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) + require.NoError(t, err) - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) - allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { - return x.String() + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + t.Logf("Validating initial user creation at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assert.NoError(ct, err, "Failed to list users during initial validation") + assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + ct.Errorf("User validation failed after first login - unexpected 
users: %s", diff) + } + }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") + + t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var listNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during initial validation") + assert.Len(ct, listNodes, 1, "Expected exactly 1 node after first login, got %d", len(listNodes)) + }, 30*time.Second, 1*time.Second, "validating initial node creation for user1 after OIDC login") + + // Collect expected node IDs for validation after user1 initial login + expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status := ts.MustStatus() + assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) + assert.NoError(ct, err, "Failed to parse node ID from status") + }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + + // Validate initial connection state for user1 + validateInitialConnection(t, headscale, expectedNodes) + + // Log out user1 and log in user2, this should create a new node + // for user2, the node should have the same machine key and + // a new node key. + err = ts.Logout() + require.NoError(t, err) + + // TODO(kradalby): Not sure why we need to logout twice, but it fails and + // logs in immediately after the first logout and I cannot reproduce it + // manually. + err = ts.Logout() + require.NoError(t, err) + + // Wait for logout to complete and then do second logout + t.Logf("Waiting for user1 logout completion at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check that the first logout completed + status, err := ts.Status() + assert.NoError(ct, err, "Failed to get client status during logout validation") + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 logout to complete before user2 login") + + u, err = ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + t.Logf("Validating user2 creation at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assert.NoError(ct, err, "Failed to list users after user2 login") + assert.Len(ct, listUsers, 2, "Expected exactly 2 users after user2 login, got %d users", len(listUsers)) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + { + Id: 2, + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user2", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + ct.Errorf("User validation failed after user2 login - expected both user1 and user2: %s", diff) + } + }, 30*time.Second, 1*time.Second, "validating both 
user1 and user2 exist after second OIDC login") + + var listNodesAfterNewUserLogin []*v1.Node + // First, wait for the new node to be created + t.Logf("Waiting for user2 node creation at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listNodesAfterNewUserLogin, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after user2 login") + // We might temporarily have more than 2 nodes during cleanup, so check for at least 2 + assert.GreaterOrEqual(ct, len(listNodesAfterNewUserLogin), 2, "Should have at least 2 nodes after user2 login, got %d (may include temporary nodes during cleanup)", len(listNodesAfterNewUserLogin)) + }, 30*time.Second, 1*time.Second, "waiting for user2 node creation (allowing temporary extra nodes during cleanup)") + + // Then wait for cleanup to stabilize at exactly 2 nodes + t.Logf("Waiting for node cleanup stabilization at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listNodesAfterNewUserLogin, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during cleanup validation") + assert.Len(ct, listNodesAfterNewUserLogin, 2, "Should have exactly 2 nodes after cleanup (1 for user1, 1 for user2), got %d nodes", len(listNodesAfterNewUserLogin)) + + // Validate that both nodes have the same machine key but different node keys + if len(listNodesAfterNewUserLogin) >= 2 { + // Machine key is the same as the "machine" has not changed, + // but Node key is not as it is a new node + assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Machine key should be preserved from original node") + assert.Equal(ct, listNodesAfterNewUserLogin[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "Both nodes should share the same machine key") + assert.NotEqual(ct, listNodesAfterNewUserLogin[0].GetNodeKey(), listNodesAfterNewUserLogin[1].GetNodeKey(), "Node keys should be different between user1 and user2 nodes") + } + }, 90*time.Second, 2*time.Second, "waiting for node count stabilization at exactly 2 nodes after user2 login") + + // Security validation: Only user2's node should be active after user switch + var activeUser2NodeID types.NodeID + for _, node := range listNodesAfterNewUserLogin { + if node.GetUser().GetId() == 2 { // user2 + activeUser2NodeID = types.NodeID(node.GetId()) + t.Logf("Active user2 node: %d (User: %s)", node.GetId(), node.GetUser().GetName()) + break + } + } + + // Validate only user2's node is online (security requirement) + t.Logf("Validating only user2 node is online at %s", time.Now().Format(TimestampFormat)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + + // Check user2 node is online + if node, exists := nodeStore[activeUser2NodeID]; exists { + assert.NotNil(c, node.IsOnline, "User2 node should have online status") + if node.IsOnline != nil { + assert.True(c, *node.IsOnline, "User2 node should be online after login") + } + } else { + assert.Fail(c, "User2 node not found in nodestore") + } + }, 60*time.Second, 2*time.Second, "validating only user2 node is online after user switch") + + // Before logging out user2, validate we have exactly 2 nodes and both are stable + t.Logf("Pre-logout validation: checking node stability at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + currentNodes, err 
:= headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes before user2 logout") + assert.Len(ct, currentNodes, 2, "Should have exactly 2 stable nodes before user2 logout, got %d", len(currentNodes)) + + // Validate node stability - ensure no phantom nodes + for i, node := range currentNodes { + assert.NotNil(ct, node.GetUser(), "Node %d should have a valid user before logout", i) + assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should have a valid machine key before logout", i) + t.Logf("Pre-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...") + } + }, 60*time.Second, 2*time.Second, "validating stable node count and integrity before user2 logout") + + // Log out user2, and log into user1, no new node should be created, + // the node should now "become" node1 again + err = ts.Logout() + require.NoError(t, err) + + t.Logf("Logged out take one") + t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") + + // TODO(kradalby): Not sure why we need to logout twice, but it fails and + // logs in immediately after the first logout and I cannot reproduce it + // manually. + err = ts.Logout() + require.NoError(t, err) + + t.Logf("Logged out take two") + t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") + + // Wait for logout to complete and then do second logout + t.Logf("Waiting for user2 logout completion at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check that the first logout completed + status, err := ts.Status() + assert.NoError(ct, err, "Failed to get client status during user2 logout validation") + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after user2 logout, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user2 logout to complete before user1 relogin") + + // Before logging back in, ensure we still have exactly 2 nodes + // Note: We skip validateLogoutComplete here since it expects all nodes to be offline, + // but in OIDC scenario we maintain both nodes in DB with only active user online + + // Additional validation that nodes are properly maintained during logout + t.Logf("Post-logout validation: checking node persistence at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + currentNodes, err := headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after user2 logout") + assert.Len(ct, currentNodes, 2, "Should still have exactly 2 nodes after user2 logout (nodes should persist), got %d", len(currentNodes)) + + // Ensure both nodes are still valid (not cleaned up incorrectly) + for i, node := range currentNodes { + assert.NotNil(ct, node.GetUser(), "Node %d should still have a valid user after user2 logout", i) + assert.NotEmpty(ct, node.GetMachineKey(), "Node %d should still have a valid machine key after user2 logout", i) + t.Logf("Post-logout node %d: User=%s, MachineKey=%s", i, node.GetUser().GetName(), node.GetMachineKey()[:16]+"...") + } + }, 60*time.Second, 2*time.Second, "validating node persistence and integrity after user2 logout") + + // We do not actually "change" the user here, it is done by logging in again + // as the OIDC mock server is kind of like a stack, and the next user is + // prepared and ready to go. 
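+	// In this scenario the mock users were queued as user1, user2, user1, so this
+	// third login brings the node back to user1.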
+ u, err = ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + t.Logf("Waiting for user1 relogin completion at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := ts.Status() + assert.NoError(ct, err, "Failed to get client status during user1 relogin validation") + assert.Equal(ct, "Running", status.BackendState, "Expected Running state after user1 relogin, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 relogin to complete (final login)") + + t.Logf("Logged back in") + t.Log("timestamp: " + time.Now().Format(TimestampFormat) + "\n") + + t.Logf("Final validation: checking user persistence at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assert.NoError(ct, err, "Failed to list users during final validation") + assert.Len(ct, listUsers, 2, "Should still have exactly 2 users after user1 relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + { + Id: 2, + Name: "user2", + Email: "user2@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user2", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + ct.Errorf("Final user validation failed - both users should persist after relogin cycle: %s", diff) + } + }, 30*time.Second, 1*time.Second, "validating user persistence after complete relogin cycle (user1->user2->user1)") + + var listNodesAfterLoggingBackIn []*v1.Node + // Wait for login to complete and nodes to stabilize + t.Logf("Final node validation: checking node stability after user1 relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listNodesAfterLoggingBackIn, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during final validation") + + // Allow for temporary instability during login process + if len(listNodesAfterLoggingBackIn) < 2 { + ct.Errorf("Not enough nodes yet during final validation, got %d, want at least 2", len(listNodesAfterLoggingBackIn)) + return + } + + // Final check should have exactly 2 nodes + assert.Len(ct, listNodesAfterLoggingBackIn, 2, "Should have exactly 2 nodes after complete relogin cycle, got %d", len(listNodesAfterLoggingBackIn)) + + // Validate that the machine we had when we logged in the first time, has the same + // machine key, but a different ID than the newly logged in version of the same + // machine. 
+		assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[0].GetMachineKey(), "Original user1 machine key should match user1 node after user switch")
+		assert.Equal(ct, listNodes[0].GetNodeKey(), listNodesAfterNewUserLogin[0].GetNodeKey(), "Original user1 node key should match user1 node after user switch")
+		assert.Equal(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[0].GetId(), "Original user1 node ID should match user1 node after user switch")
+		assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterNewUserLogin[1].GetMachineKey(), "User1 and user2 nodes should share the same machine key")
+		assert.NotEqual(ct, listNodes[0].GetId(), listNodesAfterNewUserLogin[1].GetId(), "User1 and user2 nodes should have different node IDs")
+		assert.NotEqual(ct, listNodes[0].GetUser().GetId(), listNodesAfterNewUserLogin[1].GetUser().GetId(), "User1 and user2 nodes should belong to different users")
+
+		// Even though we are logging in again with the same user, the previous key has been expired
+		// and a new one has been generated. The node entry in the database should be the same
+		// because the user + machine key combination still matches.
+		assert.Equal(ct, listNodes[0].GetMachineKey(), listNodesAfterLoggingBackIn[0].GetMachineKey(), "Machine key should remain consistent after user1 relogin")
+		assert.NotEqual(ct, listNodes[0].GetNodeKey(), listNodesAfterLoggingBackIn[0].GetNodeKey(), "Node key should be regenerated after user1 relogin")
+		assert.Equal(ct, listNodes[0].GetId(), listNodesAfterLoggingBackIn[0].GetId(), "Node ID should be preserved for user1 after relogin")
+
+		// The "logged back in" machine should have the same machine key but a different node key
+		// than the version logged in with a different user.
+		assert.Equal(ct, listNodesAfterLoggingBackIn[0].GetMachineKey(), listNodesAfterLoggingBackIn[1].GetMachineKey(), "Both final nodes should share the same machine key")
+		assert.NotEqual(ct, listNodesAfterLoggingBackIn[0].GetNodeKey(), listNodesAfterLoggingBackIn[1].GetNodeKey(), "Final nodes should have different node keys for different users")
+
+		t.Logf("Final validation complete - node counts and key relationships verified at %s", time.Now().Format(TimestampFormat))
+	}, 60*time.Second, 2*time.Second, "validating final node state after complete user1->user2->user1 relogin cycle with detailed key validation")
+
+	// Security validation: Only user1's node should be active after relogin
+	var activeUser1NodeID types.NodeID
+	for _, node := range listNodesAfterLoggingBackIn {
+		if node.GetUser().GetId() == 1 { // user1
+			activeUser1NodeID = types.NodeID(node.GetId())
+			t.Logf("Active user1 node after relogin: %d (User: %s)", node.GetId(), node.GetUser().GetName())
+			break
+		}
+	}
+
+	// Validate only user1's node is online (security requirement)
+	t.Logf("Validating only user1 node is online after relogin at %s", time.Now().Format(TimestampFormat))
+	require.EventuallyWithT(t, func(c *assert.CollectT) {
+		nodeStore, err := headscale.DebugNodeStore()
+		assert.NoError(c, err, "Failed to get nodestore debug info")
+
+		// Check user1 node is online
+		if node, exists := nodeStore[activeUser1NodeID]; exists {
+			assert.NotNil(c, node.IsOnline, "User1 node should have online status after relogin")
+			if node.IsOnline != nil {
+				assert.True(c, *node.IsOnline, "User1 node should be online after relogin")
+			}
+		} else {
+			assert.Fail(c, "User1 node not found in nodestore after relogin")
+		}
+	}, 60*time.Second, 2*time.Second, "validating only user1 node is online after final relogin")
+}
+
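+// NOTE: illustrative sketch only, added for clarity and not part of the original
+// change. The OIDC relogin tests above repeat the same EventuallyWithT pattern to
+// wait for a client's BackendState ("NeedsLogin" after logout, "Running" after a
+// successful login). A hypothetical helper along these lines could express that
+// wait in one place; the name waitForBackendState is an assumption and does not
+// refer to an existing helper in this package.
+func waitForBackendState(t *testing.T, ts TailscaleClient, want, msg string) {
+	t.Helper()
+	assert.EventuallyWithT(t, func(ct *assert.CollectT) {
+		// Poll the local tailscaled status until it reports the desired state.
+		status, err := ts.Status()
+		assert.NoError(ct, err, "failed to get status for client %s", ts.Hostname())
+		assert.Equal(ct, want, status.BackendState)
+	}, 30*time.Second, 1*time.Second, msg)
+}
+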
+// TestOIDCFollowUpUrl validates the follow-up login flow +// Prerequisites: +// - short TTL for the registration cache via HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION +// Scenario: +// - client starts a login process and gets initial AuthURL +// - time.sleep(HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION + 30 secs) waits for the cache to expire +// - client checks its status to verify that AuthUrl has changed (by followup URL) +// - client uses the new AuthURL to log in. It should complete successfully. +func TestOIDCFollowUpUrl(t *testing.T) { + IntegrationSkip(t) + + // Create no nodes and no users + scenario, err := NewScenario( + ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, + }, + ) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + // smaller cache expiration time to quickly expire AuthURL + "HEADSCALE_TUNING_REGISTER_CACHE_CLEANUP": "10s", + "HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION": "1m30s", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcauthrelog"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + listUsers, err := headscale.ListUsers() + require.NoError(t, err) + assert.Empty(t, listUsers) + + ts, err := scenario.CreateTailscaleNode( + "unstable", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + // wait for the registration cache to expire + // a little bit more than HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION (1m30s) + //nolint:forbidigo // Intentional delay: must wait for real-time cache expiration (HEADSCALE_TUNING_REGISTER_CACHE_EXPIRATION=1m30s) + time.Sleep(2 * time.Minute) + + var newUrl *url.URL + assert.EventuallyWithT(t, func(c *assert.CollectT) { + st, err := ts.Status() + assert.NoError(c, err) + assert.Equal(c, "NeedsLogin", st.BackendState) + + // get new AuthURL from daemon + newUrl, err = url.Parse(st.AuthURL) + assert.NoError(c, err) + + assert.NotEqual(c, u.String(), st.AuthURL, "AuthURL should change") + }, 10*time.Second, 200*time.Millisecond, "Waiting for registration cache to expire and status to reflect NeedsLogin") + + _, err = doLoginURL(ts.Hostname(), newUrl) + require.NoError(t, err) + + listUsers, err = headscale.ListUsers() + require.NoError(t, err) + assert.Len(t, listUsers, 1) + + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice( + listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }, + ) + + if diff := cmp.Diff( + wantUsers, + listUsers, + cmpopts.IgnoreUnexported(v1.User{}), + cmpopts.IgnoreFields(v1.User{}, "CreatedAt"), + ); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + listNodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, 1) + 
}, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login") +} + +// TestOIDCMultipleOpenedLoginUrls tests the scenario: +// - client (mostly Windows) opens multiple browser tabs with different login URLs +// - client performs auth on the first opened browser tab +// +// This test makes sure that cookies are still valid for the first browser tab. +func TestOIDCMultipleOpenedLoginUrls(t *testing.T) { + IntegrationSkip(t) + + scenario, err := NewScenario( + ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, + }, + ) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcauthrelog"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + listUsers, err := headscale.ListUsers() + require.NoError(t, err) + assert.Empty(t, listUsers) + + ts, err := scenario.CreateTailscaleNode( + "unstable", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + u1, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + u2, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + // make sure login URLs are different + require.NotEqual(t, u1.String(), u2.String()) + + loginClient, err := newLoginHTTPClient(ts.Hostname()) + require.NoError(t, err) + + // open the first login URL "in browser" + _, redirect1, err := doLoginURLWithClient(ts.Hostname(), u1, loginClient, false) + require.NoError(t, err) + // open the second login URL "in browser" + _, redirect2, err := doLoginURLWithClient(ts.Hostname(), u2, loginClient, false) + require.NoError(t, err) + + // two valid redirects with different state/nonce params + require.NotEqual(t, redirect1.String(), redirect2.String()) + + // complete auth with the first opened "browser tab" + _, redirect1, err = doLoginURLWithClient(ts.Hostname(), redirect1, loginClient, true) + require.NoError(t, err) + + listUsers, err = headscale.ListUsers() + require.NoError(t, err) + assert.Len(t, listUsers, 1) + + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice( + listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }, + ) + + if diff := cmp.Diff( + wantUsers, + listUsers, + cmpopts.IgnoreUnexported(v1.User{}), + cmpopts.IgnoreFields(v1.User{}, "CreatedAt"), + ); diff != "" { + t.Fatalf("unexpected users: %s", diff) + } + + assert.EventuallyWithT( + t, func(c *assert.CollectT) { + listNodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, listNodes, 1) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected node list after OIDC login", + ) +} + +// TestOIDCReloginSameNodeSameUser tests the scenario where a single Tailscale client +// authenticates using OIDC (OpenID Connect), logs out, and then logs back in as the same user. 
+// +// OIDC is an authentication layer built on top of OAuth 2.0 that allows users to authenticate +// using external identity providers (like Google, Microsoft, etc.) rather than managing +// credentials directly in headscale. +// +// This test validates the "same user relogin" behavior in headscale's OIDC authentication flow: +// - A single client authenticates via OIDC as user1 +// - The client logs out, ending the session +// - The same client logs back in via OIDC as the same user (user1) +// - The test verifies that the user account persists correctly +// - The test verifies that the machine key is preserved (since it's the same physical device) +// - The test verifies that the node ID is preserved (since it's the same user on the same device) +// - The test verifies that the node key is regenerated (since it's a new session) +// - The test verifies that the client comes back online properly +// +// This scenario is important for normal user workflows where someone might need to restart +// their Tailscale client, reboot their computer, or temporarily disconnect and reconnect. +// It ensures that headscale properly handles session management while preserving device +// identity and user associations. +// +// The test uses a single node scenario (unlike multi-node tests) to focus specifically on +// the authentication and session management aspects rather than network topology changes. +// The "same node" in the name refers to the same physical device/client, while "same user" +// refers to authenticating with the same OIDC identity. +func TestOIDCReloginSameNodeSameUser(t *testing.T) { + IntegrationSkip(t) + + // Create scenario with same user for both login attempts + scenario, err := NewScenario(ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), // Initial login + oidcMockUser("user1", true), // Relogin with same user + }, + }) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + } + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcsameuser"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithDERPAsIP(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) + require.NoError(t, err) + + // Initial login as user1 + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + t.Logf("Validating initial user1 creation at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assert.NoError(ct, err, "Failed to list users during initial validation") + assert.Len(ct, listUsers, 1, "Expected exactly 1 user after first login, got %d", len(listUsers)) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice(listUsers, 
func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + ct.Errorf("User validation failed after first login - unexpected users: %s", diff) + } + }, 30*time.Second, 1*time.Second, "validating user1 creation after initial OIDC login") + + t.Logf("Validating initial node creation at %s", time.Now().Format(TimestampFormat)) + var initialNodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + initialNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during initial validation") + assert.Len(ct, initialNodes, 1, "Expected exactly 1 node after first login, got %d", len(initialNodes)) + }, 30*time.Second, 1*time.Second, "validating initial node creation for user1 after OIDC login") + + // Collect expected node IDs for validation after user1 initial login + expectedNodes := make([]types.NodeID, 0, 1) + var nodeID uint64 + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status := ts.MustStatus() + assert.NotEmpty(ct, status.Self.ID, "Node ID should be populated in status") + var err error + nodeID, err = strconv.ParseUint(string(status.Self.ID), 10, 64) + assert.NoError(ct, err, "Failed to parse node ID from status") + }, 30*time.Second, 1*time.Second, "waiting for node ID to be populated in status after initial login") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + + // Validate initial connection state for user1 + validateInitialConnection(t, headscale, expectedNodes) + + // Store initial node keys for comparison + initialMachineKey := initialNodes[0].GetMachineKey() + initialNodeKey := initialNodes[0].GetNodeKey() + initialNodeID := initialNodes[0].GetId() + + // Logout user1 + err = ts.Logout() + require.NoError(t, err) + + // TODO(kradalby): Not sure why we need to logout twice, but it fails and + // logs in immediately after the first logout and I cannot reproduce it + // manually. 
+ err = ts.Logout() + require.NoError(t, err) + + // Wait for logout to complete + t.Logf("Waiting for user1 logout completion at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Check that the logout completed + status, err := ts.Status() + assert.NoError(ct, err, "Failed to get client status during logout validation") + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 logout to complete before same-user relogin") + + // Validate node persistence during logout (node should remain in DB) + t.Logf("Validating node persistence during logout at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listNodes, err := headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during logout validation") + assert.Len(ct, listNodes, 1, "Should still have exactly 1 node during logout (node should persist in DB), got %d", len(listNodes)) + }, 30*time.Second, 1*time.Second, "validating node persistence in database during same-user logout") + + // Login again as the same user (user1) + u, err = ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + t.Logf("Waiting for user1 relogin completion at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := ts.Status() + assert.NoError(ct, err, "Failed to get client status during relogin validation") + assert.Equal(ct, "Running", status.BackendState, "Expected Running state after user1 relogin, got %s", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for user1 relogin to complete (same user)") + + t.Logf("Final validation: checking user persistence after same-user relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + listUsers, err := headscale.ListUsers() + assert.NoError(ct, err, "Failed to list users during final validation") + assert.Len(ct, listUsers, 1, "Should still have exactly 1 user after same-user relogin, got %d", len(listUsers)) + wantUsers := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@headscale.net", + Provider: "oidc", + ProviderId: scenario.mockOIDC.Issuer() + "/user1", + }, + } + + sort.Slice(listUsers, func(i, j int) bool { + return listUsers[i].GetId() < listUsers[j].GetId() + }) + + if diff := cmp.Diff(wantUsers, listUsers, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + ct.Errorf("Final user validation failed - user1 should persist after same-user relogin: %s", diff) + } + }, 30*time.Second, 1*time.Second, "validating user1 persistence after same-user OIDC relogin cycle") + + var finalNodes []*v1.Node + t.Logf("Final node validation: checking node stability after same-user relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + finalNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes during final validation") + assert.Len(ct, finalNodes, 1, "Should have exactly 1 node after same-user relogin, got %d", len(finalNodes)) + + // Validate node key behavior for same user relogin + finalNode := finalNodes[0] + + // Machine key should be preserved (same physical machine) + assert.Equal(ct, initialMachineKey, finalNode.GetMachineKey(), "Machine key 
should be preserved for same user same node relogin") + + // Node ID should be preserved (same user, same machine) + assert.Equal(ct, initialNodeID, finalNode.GetId(), "Node ID should be preserved for same user same node relogin") + + // Node key should be regenerated (new session after logout) + assert.NotEqual(ct, initialNodeKey, finalNode.GetNodeKey(), "Node key should be regenerated after logout/relogin even for same user") + + t.Logf("Final validation complete - same user relogin key relationships verified at %s", time.Now().Format(TimestampFormat)) + }, 60*time.Second, 2*time.Second, "validating final node state after same-user OIDC relogin cycle with key preservation validation") + + // Security validation: user1's node should be active after relogin + activeUser1NodeID := types.NodeID(finalNodes[0].GetId()) + t.Logf("Validating user1 node is online after same-user relogin at %s", time.Now().Format(TimestampFormat)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + + // Check user1 node is online + if node, exists := nodeStore[activeUser1NodeID]; exists { + assert.NotNil(c, node.IsOnline, "User1 node should have online status after same-user relogin") + if node.IsOnline != nil { + assert.True(c, *node.IsOnline, "User1 node should be online after same-user relogin") + } + } else { + assert.Fail(c, "User1 node not found in nodestore after same-user relogin") + } + }, 60*time.Second, 2*time.Second, "validating user1 node is online after same-user OIDC relogin") +} + +// TestOIDCExpiryAfterRestart validates that node expiry is preserved +// when a tailscaled client restarts and reconnects to headscale. +// +// This test reproduces the bug reported in https://github.com/juanfont/headscale/issues/2862 +// where OIDC expiry was reset to 0001-01-01 00:00:00 after tailscaled restart. +// +// Test flow: +// 1. Node logs in with OIDC (gets 72h expiry) +// 2. Verify expiry is set correctly in headscale +// 3. Restart tailscaled container (simulates daemon restart) +// 4. Wait for reconnection +// 5. Verify expiry is still set correctly (not zero). +func TestOIDCExpiryAfterRestart(t *testing.T) { + IntegrationSkip(t) + + scenario, err := NewScenario(ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), + }, }) - success := pingAllHelper(t, allClients, allAddrs) - t.Logf("%d successful pings out of %d (before expiry)", success, len(allClients)*len(allIps)) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) - // This is not great, but this sadly is a time dependent test, so the - // safe thing to do is wait out the whole TTL time before checking if - // the clients have logged out. The Wait function cant do it itself - // as it has an upper bound of 1 min. 
- time.Sleep(shortAccessTTL) + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", + "HEADSCALE_OIDC_EXPIRY": "72h", + } - assertTailscaleNodesLogout(t, allClients) + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("oidcexpiry"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithDERPAsIP(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Create and login tailscale client + ts, err := scenario.CreateTailscaleNode("unstable", tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) + require.NoError(t, err) + + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + t.Logf("Validating initial login and expiry at %s", time.Now().Format(TimestampFormat)) + + // Verify initial expiry is set + var initialExpiry time.Time + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, nodes, 1) + + node := nodes[0] + assert.NotNil(ct, node.GetExpiry(), "Expiry should be set after OIDC login") + + if node.GetExpiry() != nil { + expiryTime := node.GetExpiry().AsTime() + assert.False(ct, expiryTime.IsZero(), "Expiry should not be zero time") + + initialExpiry = expiryTime + t.Logf("Initial expiry set to: %v (expires in %v)", expiryTime, time.Until(expiryTime)) + } + }, 30*time.Second, 1*time.Second, "validating initial expiry after OIDC login") + + // Now restart the tailscaled container + t.Logf("Restarting tailscaled container at %s", time.Now().Format(TimestampFormat)) + + err = ts.Restart() + require.NoError(t, err, "Failed to restart tailscaled container") + + t.Logf("Tailscaled restarted, waiting for reconnection at %s", time.Now().Format(TimestampFormat)) + + // Wait for the node to come back online + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := ts.Status() + if !assert.NoError(ct, err) { + return + } + + if !assert.NotNil(ct, status) { + return + } + + assert.Equal(ct, "Running", status.BackendState) + }, 60*time.Second, 2*time.Second, "waiting for tailscale to reconnect after restart") + + // THE CRITICAL TEST: Verify expiry is still set correctly after restart + t.Logf("Validating expiry preservation after restart at %s", time.Now().Format(TimestampFormat)) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, nodes, 1, "Should still have exactly 1 node after restart") + + node := nodes[0] + assert.NotNil(ct, node.GetExpiry(), "Expiry should NOT be nil after restart") + + if node.GetExpiry() != nil { + expiryTime := node.GetExpiry().AsTime() + + // This is the bug check - expiry should NOT be zero time + assert.False(ct, expiryTime.IsZero(), + "BUG: Expiry was reset to zero time after tailscaled restart! 
This is issue #2862") + + // Expiry should be exactly the same as before restart + assert.Equal(ct, initialExpiry, expiryTime, + "Expiry should be exactly the same after restart, got %v, expected %v", + expiryTime, initialExpiry) + + t.Logf("SUCCESS: Expiry preserved after restart: %v (expires in %v)", + expiryTime, time.Until(expiryTime)) + } + }, 30*time.Second, 1*time.Second, "validating expiry preservation after restart") } -func (s *AuthOIDCScenario) CreateHeadscaleEnv( - users map[string]int, - opts ...hsic.Option, -) error { - headscale, err := s.Headscale(opts...) - if err != nil { - return err - } +// TestOIDCACLPolicyOnJoin validates that ACL policies are correctly applied +// to newly joined OIDC nodes without requiring a client restart. +// +// This test validates the fix for issue #2888: +// https://github.com/juanfont/headscale/issues/2888 +// +// Bug: Nodes joining via OIDC authentication did not get the appropriate ACL +// policy applied until they restarted their client. This was a regression +// introduced in v0.27.0. +// +// The test scenario: +// 1. Creates a CLI user (gateway) with a node advertising a route +// 2. Sets up ACL policy allowing all nodes to access advertised routes +// 3. OIDC user authenticates and joins with a new node +// 4. Verifies that the OIDC user's node IMMEDIATELY sees the advertised route +// +// Expected behavior: +// - Without fix: OIDC node cannot see the route (PrimaryRoutes is nil/empty) +// - With fix: OIDC node immediately sees the route in PrimaryRoutes +// +// Root cause: The buggy code called a.h.Change(c) immediately after user +// creation but BEFORE node registration completed, creating a race condition +// where policy change notifications were sent asynchronously before the node +// was fully registered. 
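+//
+// The test exercises both directions: the OIDC node must immediately see the gateway's
+// manually approved route (10.33.0.0/24), and the gateway must immediately see the OIDC
+// node's auto-approved route (10.44.0.0/24), so both the ApproveRoutes path and the
+// AutoApprovers path are covered.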
+func TestOIDCACLPolicyOnJoin(t *testing.T) { + IntegrationSkip(t) - err = headscale.WaitForRunning() - if err != nil { - return err - } + gatewayUser := "gateway" + oidcUser := "oidcuser" - for userName, clientCount := range users { - log.Printf("creating user %s with %d clients", userName, clientCount) - err = s.CreateUser(userName) - if err != nil { - return err - } - - err = s.CreateTailscaleNodesInUser(userName, "all", clientCount) - if err != nil { - return err - } - - err = s.runTailscaleUp(userName, headscale.GetEndpoint()) - if err != nil { - return err - } - } - - return nil -} - -func (s *AuthOIDCScenario) runMockOIDC(accessTTL time.Duration) (*types.OIDCConfig, error) { - port, err := dockertestutil.RandomFreeHostPort() - if err != nil { - log.Fatalf("could not find an open port: %s", err) - } - portNotation := fmt.Sprintf("%d/tcp", port) - - hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) - - hostname := fmt.Sprintf("hs-oidcmock-%s", hash) - - mockOidcOptions := &dockertest.RunOptions{ - Name: hostname, - Cmd: []string{"headscale", "mockoidc"}, - ExposedPorts: []string{portNotation}, - PortBindings: map[docker.Port][]docker.PortBinding{ - docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, - }, - Networks: []*dockertest.Network{s.Scenario.network}, - Env: []string{ - fmt.Sprintf("MOCKOIDC_ADDR=%s", hostname), - fmt.Sprintf("MOCKOIDC_PORT=%d", port), - "MOCKOIDC_CLIENT_ID=superclient", - "MOCKOIDC_CLIENT_SECRET=supersecret", - fmt.Sprintf("MOCKOIDC_ACCESS_TTL=%s", accessTTL.String()), + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{gatewayUser}, + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser(oidcUser, true), }, } - headscaleBuildOptions := &dockertest.BuildOptions{ - Dockerfile: "Dockerfile.debug", - ContextDir: dockerContextPath, + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", } - err = s.pool.RemoveContainerByName(hostname) - if err != nil { - return nil, err - } - - if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( - headscaleBuildOptions, - mockOidcOptions, - dockertestutil.DockerRestartPolicy); err == nil { - s.mockOIDC = pmockoidc - } else { - return nil, err - } - - log.Println("Waiting for headscale mock oidc to be ready for tests") - hostEndpoint := fmt.Sprintf("%s:%d", s.mockOIDC.GetIPInNetwork(s.network), port) - - if err := s.pool.Retry(func() error { - oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) - httpClient := &http.Client{} - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil) - resp, err := httpClient.Do(req) - if err != nil { - log.Printf("headscale mock OIDC tests is not ready: %s\n", err) - - return err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return errStatusCodeNotOK - } - - return nil - }); err != nil { - return nil, err - } - - log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint) - - return &types.OIDCConfig{ - Issuer: fmt.Sprintf( - "http://%s/oidc", - net.JoinHostPort(s.mockOIDC.GetIPInNetwork(s.network), strconv.Itoa(port)), + // Create headscale environment with ACL policy that allows OIDC user + // to access 
routes advertised by gateway user + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{ + tsic.WithAcceptRoutes(), + }, + hsic.WithTestName("oidcaclpolicy"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithACLPolicy( + &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{prefixp("100.64.0.0/10")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("100.64.0.0/10"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("10.33.0.0/24"), tailcfg.PortRangeAny), + aliasWithPorts(prefixp("10.44.0.0/24"), tailcfg.PortRangeAny), + }, + }, + }, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + netip.MustParsePrefix("10.33.0.0/24"): {usernameApprover("gateway@test.no"), usernameApprover("oidcuser@headscale.net"), usernameApprover("jane.doe@example.com")}, + netip.MustParsePrefix("10.44.0.0/24"): {usernameApprover("gateway@test.no"), usernameApprover("oidcuser@headscale.net"), usernameApprover("jane.doe@example.com")}, + }, + }, + }, ), - ClientID: "superclient", - ClientSecret: "supersecret", - StripEmaildomain: true, - OnlyStartIfOIDCIsAvailable: true, - }, nil -} + ) + requireNoErrHeadscaleEnv(t, err) -func (s *AuthOIDCScenario) runTailscaleUp( - userStr, loginServer string, -) error { - headscale, err := s.Headscale() - if err != nil { - return err + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Get the gateway client (CLI user) - only one client at first + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + require.Len(t, allClients, 1, "Should have exactly 1 client (gateway) before OIDC login") + + gatewayClient := allClients[0] + + // Wait for initial sync (gateway logs in) + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // Gateway advertises route 10.33.0.0/24 + advertiseRoute := "10.33.0.0/24" + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + advertiseRoute, } + _, _, err = gatewayClient.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) - log.Printf("running tailscale up for user %s", userStr) - if user, ok := s.users[userStr]; ok { - for _, client := range user.Clients { - c := client - user.joinWaitGroup.Go(func() error { - loginURL, err := c.LoginWithURL(loginServer) - if err != nil { - log.Printf("%s failed to run tailscale up: %s", c.Hostname(), err) - } + // Wait for route advertisement to propagate + var gatewayNodeID uint64 - loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) - loginURL.Scheme = "http" + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, nodes, 1) - insecureTransport := &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint - } + gatewayNode := nodes[0] + gatewayNodeID = gatewayNode.GetId() + assert.Len(ct, gatewayNode.GetAvailableRoutes(), 1) + assert.Contains(ct, gatewayNode.GetAvailableRoutes(), advertiseRoute) + }, 10*time.Second, 500*time.Millisecond, "route advertisement should propagate to headscale") - log.Printf("%s login url: %s\n", c.Hostname(), loginURL.String()) + // Approve the advertised route + _, err = headscale.ApproveRoutes( + gatewayNodeID, + []netip.Prefix{netip.MustParsePrefix(advertiseRoute)}, + ) + require.NoError(t, err) - if err := s.pool.Retry(func() error { - 
log.Printf("%s logging in with url", c.Hostname()) - httpClient := &http.Client{Transport: insecureTransport} - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := httpClient.Do(req) - if err != nil { - log.Printf( - "%s failed to login using url %s: %s", - c.Hostname(), - loginURL, - err, - ) + // Wait for route approval to propagate + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, nodes, 1) - return err - } + gatewayNode := nodes[0] + assert.Len(ct, gatewayNode.GetApprovedRoutes(), 1) + assert.Contains(ct, gatewayNode.GetApprovedRoutes(), advertiseRoute) + }, 10*time.Second, 500*time.Millisecond, "route approval should propagate to headscale") - if resp.StatusCode != http.StatusOK { - log.Printf("%s response code of oidc login request was %s", c.Hostname(), resp.Status) + // NOW create the OIDC user by having them join + // This is where issue #2888 manifests - the new OIDC node should immediately + // see the gateway's advertised route + t.Logf("OIDC user joining at %s", time.Now().Format(TimestampFormat)) - return errStatusCodeNotOK - } + // Create OIDC user's tailscale node + oidcAdvertiseRoute := "10.44.0.0/24" + oidcClient, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithAcceptRoutes(), + tsic.WithExtraLoginArgs([]string{"--advertise-routes=" + oidcAdvertiseRoute}), + ) + require.NoError(t, err) - defer resp.Body.Close() + // OIDC login happens automatically via LoginWithURL + loginURL, err := oidcClient.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) - _, err = io.ReadAll(resp.Body) - if err != nil { - log.Printf("%s failed to read response body: %s", c.Hostname(), err) + _, err = doLoginURL(oidcClient.Hostname(), loginURL) + require.NoError(t, err) - return err - } + t.Logf("OIDC user logged in successfully at %s", time.Now().Format(TimestampFormat)) - return nil - }); err != nil { - return err - } + // THE CRITICAL TEST: Verify that the OIDC user's node can IMMEDIATELY + // see the gateway's advertised route WITHOUT needing a client restart. 
+ // + // This is where the bug manifests: + // - Without fix: PrimaryRoutes will be nil/empty + // - With fix: PrimaryRoutes immediately contains the advertised route + t.Logf("Verifying OIDC user can immediately see advertised routes at %s", time.Now().Format(TimestampFormat)) - log.Printf("Finished request for %s to join tailnet", c.Hostname()) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := oidcClient.Status() + assert.NoError(ct, err) - return nil - }) + // Find the gateway peer in the OIDC user's peer list + var gatewayPeer *ipnstate.PeerStatus - log.Printf("client %s is ready", client.Hostname()) - } - - if err := user.joinWaitGroup.Wait(); err != nil { - return err - } - - for _, client := range user.Clients { - err := client.WaitForRunning() - if err != nil { - return fmt.Errorf( - "%s tailscale node has not reached running: %w", - client.Hostname(), - err, - ) + for _, peerKey := range status.Peers() { + peer := status.Peer[peerKey] + // Gateway is the peer that's not the OIDC user + if peer.UserID != status.Self.UserID { + gatewayPeer = peer + break } } - return nil - } + assert.NotNil(ct, gatewayPeer, "OIDC user should see gateway as peer") - return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) + if gatewayPeer != nil { + // This is the critical assertion - PrimaryRoutes should NOT be nil + assert.NotNil(ct, gatewayPeer.PrimaryRoutes, + "BUG #2888: Gateway peer PrimaryRoutes is nil - ACL policy not applied to new OIDC node!") + + if gatewayPeer.PrimaryRoutes != nil { + routes := gatewayPeer.PrimaryRoutes.AsSlice() + assert.Contains(ct, routes, netip.MustParsePrefix(advertiseRoute), + "OIDC user should immediately see gateway's advertised route %s in PrimaryRoutes", advertiseRoute) + t.Logf("SUCCESS: OIDC user can see advertised route %s in gateway's PrimaryRoutes", advertiseRoute) + } + + // Also verify AllowedIPs includes the route + if gatewayPeer.AllowedIPs != nil && gatewayPeer.AllowedIPs.Len() > 0 { + allowedIPs := gatewayPeer.AllowedIPs.AsSlice() + t.Logf("Gateway peer AllowedIPs: %v", allowedIPs) + } + } + }, 15*time.Second, 500*time.Millisecond, + "OIDC user should immediately see gateway's advertised route without client restart (issue #2888)") + + // Verify that the Gateway node sees the OIDC node's advertised route (AutoApproveRoutes check) + t.Logf("Verifying Gateway user can immediately see OIDC advertised routes at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := gatewayClient.Status() + assert.NoError(ct, err) + + // Find the OIDC peer in the Gateway user's peer list + var oidcPeer *ipnstate.PeerStatus + + for _, peerKey := range status.Peers() { + peer := status.Peer[peerKey] + if peer.UserID != status.Self.UserID { + oidcPeer = peer + break + } + } + + assert.NotNil(ct, oidcPeer, "Gateway user should see OIDC user as peer") + + if oidcPeer != nil { + assert.NotNil(ct, oidcPeer.PrimaryRoutes, + "BUG: OIDC peer PrimaryRoutes is nil - AutoApproveRoutes failed or overwritten!") + + if oidcPeer.PrimaryRoutes != nil { + routes := oidcPeer.PrimaryRoutes.AsSlice() + assert.Contains(ct, routes, netip.MustParsePrefix(oidcAdvertiseRoute), + "Gateway user should immediately see OIDC's advertised route %s in PrimaryRoutes", oidcAdvertiseRoute) + } + } + }, 15*time.Second, 500*time.Millisecond, + "Gateway user should immediately see OIDC's advertised route (AutoApproveRoutes check)") + + // Additional validation: Verify nodes in headscale match expectations + 
assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, nodes, 2, "Should have 2 nodes (gateway + oidcuser)") + + // Verify OIDC user was created correctly + users, err := headscale.ListUsers() + assert.NoError(ct, err) + // Note: mockoidc may create additional default users (like jane.doe) + // so we check for at least 2 users, not exactly 2 + assert.GreaterOrEqual(ct, len(users), 2, "Should have at least 2 users (gateway CLI user + oidcuser)") + + // Find gateway CLI user + var gatewayUser *v1.User + + for _, user := range users { + if user.GetName() == "gateway" && user.GetProvider() == "" { + gatewayUser = user + break + } + } + + assert.NotNil(ct, gatewayUser, "Should have gateway CLI user") + + if gatewayUser != nil { + assert.Equal(ct, "gateway", gatewayUser.GetName()) + } + + // Find OIDC user + var oidcUserFound *v1.User + + for _, user := range users { + if user.GetName() == "oidcuser" && user.GetProvider() == "oidc" { + oidcUserFound = user + break + } + } + + assert.NotNil(ct, oidcUserFound, "Should have OIDC user") + + if oidcUserFound != nil { + assert.Equal(ct, "oidcuser", oidcUserFound.GetName()) + assert.Equal(ct, "oidcuser@headscale.net", oidcUserFound.GetEmail()) + } + }, 10*time.Second, 500*time.Millisecond, "headscale should have correct users and nodes") + + t.Logf("Test completed successfully - issue #2888 fix validated") } -func (s *AuthOIDCScenario) Shutdown() { - err := s.pool.Purge(s.mockOIDC) - if err != nil { - log.Printf("failed to remove mock oidc container") +// TestOIDCReloginSameUserRoutesPreserved tests the scenario where: +// - A node logs in via OIDC and advertises routes +// - Routes are auto-approved and verified as SERVING +// - The node logs out +// - The node logs back in as the same user +// - Routes should STILL be SERVING (not just approved/available) +// +// This test validates the fix for issue #2896: +// https://github.com/juanfont/headscale/issues/2896 +// +// Bug: When a node with already-approved routes restarts/re-authenticates, +// the routes show as "Approved" and "Available" but NOT "Serving" (Primary). +// A headscale restart would fix it, indicating a state management issue. 
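+//
+// The test distinguishes three per-node route states: "available" (advertised via
+// Hostinfo.RoutableIPs), "approved" (ApprovedRoutes), and "serving" (SubnetRoutes /
+// PrimaryRoutes). Only the last one regresses in this bug, so the assertions check
+// all three both before and after the relogin.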
+func TestOIDCReloginSameUserRoutesPreserved(t *testing.T) { + IntegrationSkip(t) + + advertiseRoute := "10.55.0.0/24" + + // Create scenario with same user for both login attempts + scenario, err := NewScenario(ScenarioSpec{ + OIDCUsers: []mockoidc.MockUser{ + oidcMockUser("user1", true), // Initial login + oidcMockUser("user1", true), // Relogin with same user + }, + }) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + oidcMap := map[string]string{ + "HEADSCALE_OIDC_ISSUER": scenario.mockOIDC.Issuer(), + "HEADSCALE_OIDC_CLIENT_ID": scenario.mockOIDC.ClientID(), + "CREDENTIALS_DIRECTORY_TEST": "/tmp", + "HEADSCALE_OIDC_CLIENT_SECRET_PATH": "${CREDENTIALS_DIRECTORY_TEST}/hs_client_oidc_secret", } - s.Scenario.Shutdown() -} - -func assertTailscaleNodesLogout(t *testing.T, clients []TailscaleClient) { - t.Helper() - - for _, client := range clients { - status, err := client.Status() - assertNoErr(t, err) - - assert.Equal(t, "NeedsLogin", status.BackendState) - } + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{ + tsic.WithAcceptRoutes(), + }, + hsic.WithTestName("oidcrouterelogin"), + hsic.WithConfigEnv(oidcMap), + hsic.WithTLS(), + hsic.WithFileInContainer("/tmp/hs_client_oidc_secret", []byte(scenario.mockOIDC.ClientSecret())), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithDERPAsIP(), + hsic.WithACLPolicy( + &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{policyv2.Wildcard}, + Destinations: []policyv2.AliasWithPorts{{Alias: policyv2.Wildcard, Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}}}, + }, + }, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + netip.MustParsePrefix(advertiseRoute): {usernameApprover("user1@headscale.net")}, + }, + }, + }, + ), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Create client with route advertisement + ts, err := scenario.CreateTailscaleNode( + "unstable", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithAcceptRoutes(), + tsic.WithExtraLoginArgs([]string{"--advertise-routes=" + advertiseRoute}), + ) + require.NoError(t, err) + + // Initial login as user1 + u, err := ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + // Wait for client to be running + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := ts.Status() + assert.NoError(ct, err) + assert.Equal(ct, "Running", status.BackendState) + }, 30*time.Second, 1*time.Second, "waiting for initial login to complete") + + // Step 1: Verify initial route is advertised, approved, and SERVING + t.Logf("Step 1: Verifying initial route is advertised, approved, and SERVING at %s", time.Now().Format(TimestampFormat)) + + var initialNode *v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + initialNode = nodes[0] + // Check: 1 announced, 1 approved, 1 serving (subnet route) + assert.Lenf(c, initialNode.GetAvailableRoutes(), 1, + "Node should have 1 available route, got %v", initialNode.GetAvailableRoutes()) + assert.Lenf(c, initialNode.GetApprovedRoutes(), 1, + "Node should have 1 approved route, got %v", initialNode.GetApprovedRoutes()) + assert.Lenf(c, initialNode.GetSubnetRoutes(), 1, + "Node should have 1 serving 
(subnet) route, got %v - THIS IS THE BUG if empty", initialNode.GetSubnetRoutes()) + assert.Contains(c, initialNode.GetSubnetRoutes(), advertiseRoute, + "Subnet routes should contain %s", advertiseRoute) + } + }, 30*time.Second, 500*time.Millisecond, "initial route should be serving") + + require.NotNil(t, initialNode, "Initial node should be found") + initialNodeID := initialNode.GetId() + t.Logf("Initial node ID: %d, Available: %v, Approved: %v, Serving: %v", + initialNodeID, initialNode.GetAvailableRoutes(), initialNode.GetApprovedRoutes(), initialNode.GetSubnetRoutes()) + + // Step 2: Logout + t.Logf("Step 2: Logging out at %s", time.Now().Format(TimestampFormat)) + + err = ts.Logout() + require.NoError(t, err) + + // Wait for logout to complete + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := ts.Status() + assert.NoError(ct, err) + assert.Equal(ct, "NeedsLogin", status.BackendState, "Expected NeedsLogin state after logout") + }, 30*time.Second, 1*time.Second, "waiting for logout to complete") + + t.Logf("Logout completed, node should still exist in database") + + // Verify node still exists (routes should still be in DB) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Node should persist in database after logout") + }, 10*time.Second, 500*time.Millisecond, "node should persist after logout") + + // Step 3: Re-authenticate via OIDC as the same user + t.Logf("Step 3: Re-authenticating with same user via OIDC at %s", time.Now().Format(TimestampFormat)) + + u, err = ts.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + _, err = doLoginURL(ts.Hostname(), u) + require.NoError(t, err) + + // Wait for client to be running + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := ts.Status() + assert.NoError(ct, err) + assert.Equal(ct, "Running", status.BackendState, "Expected Running state after relogin") + }, 30*time.Second, 1*time.Second, "waiting for relogin to complete") + + t.Logf("Re-authentication completed at %s", time.Now().Format(TimestampFormat)) + + // Step 4: THE CRITICAL TEST - Verify routes are STILL SERVING after re-authentication + t.Logf("Step 4: Verifying routes are STILL SERVING after re-authentication at %s", time.Now().Format(TimestampFormat)) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should still have exactly 1 node after relogin") + + if len(nodes) == 1 { + node := nodes[0] + t.Logf("After relogin - Available: %v, Approved: %v, Serving: %v", + node.GetAvailableRoutes(), node.GetApprovedRoutes(), node.GetSubnetRoutes()) + + // This is where issue #2896 manifests: + // - Available shows the route (from Hostinfo.RoutableIPs) + // - Approved shows the route (from ApprovedRoutes) + // - BUT Serving (SubnetRoutes/PrimaryRoutes) is EMPTY! 
+ assert.Lenf(c, node.GetAvailableRoutes(), 1, + "Node should have 1 available route after relogin, got %v", node.GetAvailableRoutes()) + assert.Lenf(c, node.GetApprovedRoutes(), 1, + "Node should have 1 approved route after relogin, got %v", node.GetApprovedRoutes()) + assert.Lenf(c, node.GetSubnetRoutes(), 1, + "BUG #2896: Node should have 1 SERVING route after relogin, got %v", node.GetSubnetRoutes()) + assert.Contains(c, node.GetSubnetRoutes(), advertiseRoute, + "BUG #2896: Subnet routes should contain %s after relogin", advertiseRoute) + + // Also verify node ID was preserved (same node, not new registration) + assert.Equal(c, initialNodeID, node.GetId(), + "Node ID should be preserved after same-user relogin") + } + }, 30*time.Second, 500*time.Millisecond, + "BUG #2896: routes should remain SERVING after OIDC logout/relogin with same user") + + t.Logf("Test completed - verifying issue #2896 fix for OIDC") } diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 90ce571b..5dd546f3 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -1,57 +1,54 @@ package integration import ( - "context" - "errors" "fmt" - "io" - "log" - "net/http" "net/netip" - "net/url" - "strings" + "slices" "testing" + "time" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var errParseAuthPage = errors.New("failed to parse auth page") - -type AuthWebFlowScenario struct { - *Scenario -} - func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() - baseScenario, err := NewScenario() + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) if err != nil { t.Fatalf("failed to create scenario: %s", err) } + defer scenario.ShutdownAssertNoPanics(t) - scenario := AuthWebFlowScenario{ - Scenario: baseScenario, - } - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), - } - - err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("webauthping")) - assertNoErrHeadscaleEnv(t, err) + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("webauthping"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithDERPAsIP(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -61,34 +58,36 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) } -func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { +func TestAuthWebFlowLogoutAndReloginSameUser(t *testing.T) { IntegrationSkip(t) - t.Parallel() - baseScenario, err := NewScenario() - assertNoErr(t, err) - - scenario := AuthWebFlowScenario{ - Scenario: baseScenario, - } - defer scenario.Shutdown() - - spec := 
map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, } - err = scenario.CreateHeadscaleEnv(spec, hsic.WithTestName("weblogout")) - assertNoErrHeadscaleEnv(t, err) + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("weblogout"), + hsic.WithDERPAsIP(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -97,6 +96,26 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + validateInitialConnection(t, headscale, expectedNodes) + + var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after web authentication") + assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) + }, 30*time.Second, 2*time.Second, "validating node count matches client count after web authentication") + nodeCountBeforeLogout := len(listNodes) + t.Logf("node count before logout: %d", nodeCountBeforeLogout) + clientIPs := make(map[TailscaleClient][]netip.Addr) for _, client := range allClients { ips, err := client.IPs() @@ -114,27 +133,36 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + validateLogoutComplete(t, headscale, expectedNodes) t.Logf("all clients logged out") - headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) - - for userName := range spec { - err = scenario.runTailscaleUp(userName, headscale.GetEndpoint()) + for _, userName := range spec.Users { + err = scenario.RunTailscaleUpWithURL(userName, headscale.GetEndpoint()) if err != nil { - t.Fatalf("failed to run tailscale up: %s", err) + t.Fatalf("failed to run tailscale up (%q): %s", headscale.GetEndpoint(), err) } } t.Logf("all clients logged in again") - allClients, err = scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + t.Logf("Validating node persistence after logout at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after web flow logout") + assert.Len(ct, listNodes, nodeCountBeforeLogout, "Node count should remain unchanged after logout - expected %d nodes, 
got %d", nodeCountBeforeLogout, len(listNodes)) + }, 60*time.Second, 2*time.Second, "validating node persistence in database after web flow logout") + t.Logf("node count first login: %d, after relogin: %d", nodeCountBeforeLogout, len(listNodes)) + + // Validate connection state after relogin + validateReloginComplete(t, headscale, expectedNodes) allIps, err = scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) allAddrs = lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -155,14 +183,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { } for _, ip := range ips { - found := false - for _, oldIP := range clientIPs[client] { - if ip == oldIP { - found = true - - break - } - } + found := slices.Contains(clientIPs[client], ip) if !found { t.Fatalf( @@ -178,139 +199,165 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { t.Logf("all clients IPs are the same") } -func (s *AuthWebFlowScenario) CreateHeadscaleEnv( - users map[string]int, - opts ...hsic.Option, -) error { - headscale, err := s.Headscale(opts...) - if err != nil { - return err +// TestAuthWebFlowLogoutAndReloginNewUser tests the scenario where multiple Tailscale clients +// initially authenticate using the web-based authentication flow (where users visit a URL +// in their browser to authenticate), then all clients log out and log back in as a different user. +// +// This test validates the "user switching" behavior in headscale's web authentication flow: +// - Multiple clients authenticate via web flow, each to their respective users (user1, user2) +// - All clients log out simultaneously +// - All clients log back in via web flow, but this time they all authenticate as user1 +// - The test verifies that user1 ends up with all the client nodes +// - The test verifies that user2's original nodes still exist in the database but are offline +// - The test verifies network connectivity works after the user switch +// +// This scenario is important for organizations that need to reassign devices between users +// or when consolidating multiple user accounts. It ensures that headscale properly handles +// the security implications of user switching while maintaining node persistence in the database. +// +// The test uses headscale's web authentication flow, which is the most user-friendly method +// where authentication happens through a web browser rather than pre-shared keys or OIDC. 
+func TestAuthWebFlowLogoutAndReloginNewUser(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, } - err = headscale.WaitForRunning() - if err != nil { - return err - } + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) - for userName, clientCount := range users { - log.Printf("creating user %s with %d clients", userName, clientCount) - err = s.CreateUser(userName) - if err != nil { - return err - } + err = scenario.CreateHeadscaleEnvWithLoginURL( + nil, + hsic.WithTestName("webflowrelnewuser"), + hsic.WithDERPAsIP(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) - err = s.CreateTailscaleNodesInUser(userName, "all", clientCount) - if err != nil { - return err - } + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) - err = s.runTailscaleUp(userName, headscale.GetEndpoint()) + allIps, err := scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Collect expected node IDs for validation + expectedNodes := collectExpectedNodeIDs(t, allClients) + + // Validate initial connection state + validateInitialConnection(t, headscale, expectedNodes) + + var listNodes []*v1.Node + t.Logf("Validating initial node count after web auth at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err, "Failed to list nodes after initial web authentication") + assert.Len(ct, listNodes, len(allClients), "Expected %d nodes after web auth, got %d", len(allClients), len(listNodes)) + }, 30*time.Second, 2*time.Second, "validating node count matches client count after initial web authentication") + nodeCountBeforeLogout := len(listNodes) + t.Logf("node count before logout: %d", nodeCountBeforeLogout) + + // Log out all clients + for _, client := range allClients { + err := client.Logout() if err != nil { - return err + t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) } } - return nil -} - -func (s *AuthWebFlowScenario) runTailscaleUp( - userStr, loginServer string, -) error { - log.Printf("running tailscale up for user %s", userStr) - if user, ok := s.users[userStr]; ok { - for _, client := range user.Clients { - c := client - user.joinWaitGroup.Go(func() error { - loginURL, err := c.LoginWithURL(loginServer) - if err != nil { - log.Printf("failed to run tailscale up (%s): %s", c.Hostname(), err) - - return err - } - - err = s.runHeadscaleRegister(userStr, loginURL) - if err != nil { - log.Printf("failed to register client (%s): %s", c.Hostname(), err) - - return err - } - - return nil - }) - - err := client.WaitForRunning() - if err != nil { - log.Printf("error waiting for client %s to be ready: %s", client.Hostname(), err) - } - } - - if err := user.joinWaitGroup.Wait(); err != nil { - return err - } - - for _, client := range user.Clients { - err := client.WaitForRunning() - if err != nil { - return fmt.Errorf("%s failed to up tailscale node: %w", client.Hostname(), err) - } - } - - return nil - } - - return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) -} - -func (s *AuthWebFlowScenario) runHeadscaleRegister(userStr string, loginURL *url.URL) error { - headscale, err := s.Headscale() - if err != nil { - 
return err - } - - log.Printf("loginURL: %s", loginURL) - loginURL.Host = fmt.Sprintf("%s:8080", headscale.GetIP()) - loginURL.Scheme = "http" - - httpClient := &http.Client{} - ctx := context.Background() - req, _ := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) - resp, err := httpClient.Do(req) - if err != nil { - return err - } - - body, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - - defer resp.Body.Close() - - // see api.go HTML template - codeSep := strings.Split(string(body), "</code>") - if len(codeSep) != 2 { - return errParseAuthPage - } - - keySep := strings.Split(codeSep[0], "key ") - if len(keySep) != 2 { - return errParseAuthPage - } - key := keySep[1] - log.Printf("registering node %s", key) - - if headscale, err := s.Headscale(); err == nil { - _, err = headscale.Execute( - []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, - ) - if err != nil { - log.Printf("failed to register node: %s", err) - - return err - } - - return nil - } - - return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable) + err = scenario.WaitForTailscaleLogout() + requireNoErrLogout(t, err) + + // Validate that all nodes are offline after logout + validateLogoutComplete(t, headscale, expectedNodes) + + t.Logf("all clients logged out") + + // Log all clients back in as user1 using web flow + // We manually iterate over all clients and authenticate each one as user1 + // This tests the cross-user re-authentication behavior where ALL clients + // (including those originally from user2) are registered to user1 + for _, client := range allClients { + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + if err != nil { + t.Fatalf("failed to get login URL for client %s: %s", client.Hostname(), err) + } + + body, err := doLoginURL(client.Hostname(), loginURL) + if err != nil { + t.Fatalf("failed to complete login for client %s: %s", client.Hostname(), err) + } + + // Register all clients as user1 (this is where cross-user registration happens) + // This simulates: headscale nodes register --user user1 --key <key> + scenario.runHeadscaleRegister("user1", body) + } + + // Wait for all clients to reach running state + for _, client := range allClients { + err := client.WaitForRunning(integrationutil.PeerSyncTimeout()) + if err != nil { + t.Fatalf("%s tailscale node has not reached running: %s", client.Hostname(), err) + } + } + + t.Logf("all clients logged back in as user1") + + var user1Nodes []*v1.Node + t.Logf("Validating user1 node count after relogin at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + user1Nodes, err = headscale.ListNodes("user1") + assert.NoError(ct, err, "Failed to list nodes for user1 after web flow relogin") + assert.Len(ct, user1Nodes, len(allClients), "User1 should have all %d clients after web flow relogin, got %d nodes", len(allClients), len(user1Nodes)) + }, 60*time.Second, 2*time.Second, "validating user1 has all client nodes after web flow user switch relogin") + + // Collect expected node IDs for user1 after relogin + expectedUser1Nodes := make([]types.NodeID, 0, len(user1Nodes)) + for _, node := range user1Nodes { + expectedUser1Nodes = append(expectedUser1Nodes, types.NodeID(node.GetId())) + } + + // Validate connection state after relogin as user1 + validateReloginComplete(t, headscale, expectedUser1Nodes) + + // Validate that user2's old nodes still exist in database (but are expired/offline) + // When CLI 
registration creates new nodes for user1, user2's old nodes remain + var user2Nodes []*v1.Node + t.Logf("Validating user2 old nodes remain in database after CLI registration to user1 at %s", time.Now().Format(TimestampFormat)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + user2Nodes, err = headscale.ListNodes("user2") + assert.NoError(ct, err, "Failed to list nodes for user2 after CLI registration to user1") + assert.Len(ct, user2Nodes, len(allClients)/2, "User2 should still have %d old nodes (likely expired) after CLI registration to user1, got %d nodes", len(allClients)/2, len(user2Nodes)) + }, 30*time.Second, 2*time.Second, "validating user2 old nodes remain in database after CLI registration to user1") + + t.Logf("Validating client login states after web flow user switch at %s", time.Now().Format(TimestampFormat)) + for _, client := range allClients { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) + assert.Equal(ct, "user1@test.no", status.User[status.Self.UserID].LoginName, "Client %s should be logged in as user1 after web flow user switch, got %s", client.Hostname(), status.User[status.Self.UserID].LoginName) + }, 30*time.Second, 2*time.Second, fmt.Sprintf("validating %s is logged in as user1 after web flow user switch", client.Hostname())) + } + + // Test connectivity after user switch + allIps, err = scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d after web flow user switch", success, len(allClients)*len(allIps)) } diff --git a/integration/cli_test.go b/integration/cli_test.go index 0ff0ffca..65d82444 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -1,16 +1,25 @@ package integration import ( + "cmp" "encoding/json" "fmt" - "sort" + "slices" + "strconv" + "strings" "testing" "time" + tcmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" ) func executeAndUnmarshal[T any](headscale ControlServer, command []string, result T) error { @@ -21,134 +30,291 @@ func executeAndUnmarshal[T any](headscale ControlServer, command []string, resul err = json.Unmarshal([]byte(str), result) if err != nil { - return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str) + return fmt.Errorf("failed to unmarshal: %w\n command err: %s", err, str) } return nil } +// Interface ensuring that we can sort structs from gRPC that +// have an ID field. 
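+// v1.User (and any other message exposing GetId) satisfies it, so sortWithID below
+// can be passed to slices.SortFunc to order CLI output deterministically before
+// comparing it against expected values.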
+type GRPCSortable interface { + GetId() uint64 +} + +func sortWithID[T GRPCSortable](a, b T) int { + return cmp.Compare(a.GetId(), b.GetId()) +} + func TestUserCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": 0, - "user2": 0, + spec := ScenarioSpec{ + Users: []string{"user1", "user2"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) - var listUsers []v1.User - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "users", - "list", - "--output", - "json", - }, - &listUsers, + var ( + listUsers []*v1.User + result []string ) - assertNoErr(t, err) - result := []string{listUsers[0].GetName(), listUsers[1].GetName()} - sort.Strings(result) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err := executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listUsers, + ) + assert.NoError(ct, err) - assert.Equal( - t, - []string{"user1", "user2"}, - result, - ) + slices.SortFunc(listUsers, sortWithID) + result = []string{listUsers[0].GetName(), listUsers[1].GetName()} + + assert.Equal( + ct, + []string{"user1", "user2"}, + result, + "Should have user1 and user2 in users list", + ) + }, 20*time.Second, 1*time.Second) _, err = headscale.Execute( []string{ "headscale", "users", "rename", - "--output", - "json", - "user2", - "newname", + "--output=json", + fmt.Sprintf("--identifier=%d", listUsers[1].GetId()), + "--new-name=newname", }, ) - assertNoErr(t, err) + require.NoError(t, err) - var listAfterRenameUsers []v1.User - err = executeAndUnmarshal(headscale, + var listAfterRenameUsers []*v1.User + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err := executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listAfterRenameUsers, + ) + assert.NoError(ct, err) + + slices.SortFunc(listAfterRenameUsers, sortWithID) + result = []string{listAfterRenameUsers[0].GetName(), listAfterRenameUsers[1].GetName()} + + assert.Equal( + ct, + []string{"user1", "newname"}, + result, + "Should have user1 and newname after rename operation", + ) + }, 20*time.Second, 1*time.Second) + + var listByUsername []*v1.User + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + "--name=user1", + }, + &listByUsername, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for user list by username") + + slices.SortFunc(listByUsername, sortWithID) + + want := []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@test.no", + }, + } + + if diff := tcmp.Diff(want, listByUsername, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Errorf("unexpected users (-want +got):\n%s", diff) + } + + var listByID []*v1.User + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + "--identifier=1", + }, + &listByID, + ) + 
assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for user list by ID") + + slices.SortFunc(listByID, sortWithID) + + want = []*v1.User{ + { + Id: 1, + Name: "user1", + Email: "user1@test.no", + }, + } + + if diff := tcmp.Diff(want, listByID, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + t.Errorf("unexpected users (-want +got):\n%s", diff) + } + + deleteResult, err := headscale.Execute( []string{ "headscale", "users", - "list", - "--output", - "json", + "destroy", + "--force", + // Delete "user1" + "--identifier=1", }, - &listAfterRenameUsers, ) - assertNoErr(t, err) + assert.NoError(t, err) + assert.Contains(t, deleteResult, "User destroyed") - result = []string{listAfterRenameUsers[0].GetName(), listAfterRenameUsers[1].GetName()} - sort.Strings(result) + var listAfterIDDelete []*v1.User - assert.Equal( - t, - []string{"newname", "user1"}, - result, + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err := executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listAfterIDDelete, + ) + assert.NoError(ct, err) + + slices.SortFunc(listAfterIDDelete, sortWithID) + + want := []*v1.User{ + { + Id: 2, + Name: "newname", + Email: "user2@test.no", + }, + } + + if diff := tcmp.Diff(want, listAfterIDDelete, cmpopts.IgnoreUnexported(v1.User{}), cmpopts.IgnoreFields(v1.User{}, "CreatedAt")); diff != "" { + assert.Fail(ct, "unexpected users", "diff (-want +got):\n%s", diff) + } + }, 20*time.Second, 1*time.Second) + + deleteResult, err = headscale.Execute( + []string{ + "headscale", + "users", + "destroy", + "--force", + "--name=newname", + }, ) + assert.NoError(t, err) + assert.Contains(t, deleteResult, "User destroyed") + + var listAfterNameDelete []v1.User + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "users", + "list", + "--output", + "json", + }, + &listAfterNameDelete, + ) + assert.NoError(c, err) + assert.Empty(c, listAfterNameDelete) + }, 10*time.Second, 200*time.Millisecond, "Waiting for user list after name delete") } func TestPreAuthKeyCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "preauthkeyspace" count := 3 - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - user: 0, + spec := ScenarioSpec{ + Users: []string{user}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak")) - assertNoErr(t, err) + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipak")) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) keys := make([]*v1.PreAuthKey, count) - assertNoErr(t, err) - for index := 0; index < count; index++ { + require.NoError(t, err) + + for index := range count { var preAuthKey v1.PreAuthKey - err := executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - user, - "create", - "--reusable", - "--expiration", - "24h", - "--output", - "json", - "--tags", - "tag:test1,tag:test2", - }, - &preAuthKey, - ) - assertNoErr(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err := executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "create", + "--reusable", + "--expiration", + "24h", + 
"--output", + "json", + "--tags", + "tag:test1,tag:test2", + }, + &preAuthKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth key creation") keys[index] = &preAuthKey } @@ -156,37 +322,39 @@ func TestPreAuthKeyCommand(t *testing.T) { assert.Len(t, keys, 3) var listedPreAuthKeys []v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - user, - "list", - "--output", - "json", - }, - &listedPreAuthKeys, - ) - assertNoErr(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "list", + "--output", + "json", + }, + &listedPreAuthKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth keys list") // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 4) assert.Equal( t, - []string{keys[0].GetId(), keys[1].GetId(), keys[2].GetId()}, - []string{ + []uint64{keys[0].GetId(), keys[1].GetId(), keys[2].GetId()}, + []uint64{ listedPreAuthKeys[1].GetId(), listedPreAuthKeys[2].GetId(), listedPreAuthKeys[3].GetId(), }, ) - assert.NotEmpty(t, listedPreAuthKeys[1].GetKey()) - assert.NotEmpty(t, listedPreAuthKeys[2].GetKey()) - assert.NotEmpty(t, listedPreAuthKeys[3].GetKey()) + // New keys show prefix after listing, so check the created keys instead + assert.NotEmpty(t, keys[0].GetKey()) + assert.NotEmpty(t, keys[1].GetKey()) + assert.NotEmpty(t, keys[2].GetKey()) assert.True(t, listedPreAuthKeys[1].GetExpiration().AsTime().After(time.Now())) assert.True(t, listedPreAuthKeys[2].GetExpiration().AsTime().After(time.Now())) @@ -210,7 +378,11 @@ func TestPreAuthKeyCommand(t *testing.T) { continue } - assert.Equal(t, listedPreAuthKeys[index].GetAclTags(), []string{"tag:test1", "tag:test2"}) + assert.Equal( + t, + []string{"tag:test1", "tag:test2"}, + listedPreAuthKeys[index].GetAclTags(), + ) } // Test key expiry @@ -218,29 +390,29 @@ func TestPreAuthKeyCommand(t *testing.T) { []string{ "headscale", "preauthkeys", - "--user", - user, "expire", - listedPreAuthKeys[1].GetKey(), + "--id", + strconv.FormatUint(keys[0].GetId(), 10), }, ) - assertNoErr(t, err) + require.NoError(t, err) var listedPreAuthKeysAfterExpire []v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - user, - "list", - "--output", - "json", - }, - &listedPreAuthKeysAfterExpire, - ) - assertNoErr(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "list", + "--output", + "json", + }, + &listedPreAuthKeysAfterExpire, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth keys list after expire") assert.True(t, listedPreAuthKeysAfterExpire[1].GetExpiration().AsTime().Before(time.Now())) assert.True(t, listedPreAuthKeysAfterExpire[2].GetExpiration().AsTime().After(time.Now())) @@ -249,56 +421,59 @@ func TestPreAuthKeyCommand(t *testing.T) { func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "pre-auth-key-without-exp-user" - - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - user: 0, + spec := ScenarioSpec{ + Users: []string{user}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipaknaexp")) - assertNoErr(t, err) + scenario, err := 
NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipaknaexp")) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) var preAuthKey v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - user, - "create", - "--reusable", - "--output", - "json", - }, - &preAuthKey, - ) - assertNoErr(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "create", + "--reusable", + "--output", + "json", + }, + &preAuthKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth key creation without expiry") var listedPreAuthKeys []v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - user, - "list", - "--output", - "json", - }, - &listedPreAuthKeys, - ) - assertNoErr(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "list", + "--output", + "json", + }, + &listedPreAuthKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth keys list") // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 2) @@ -312,105 +487,337 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { IntegrationSkip(t) - t.Parallel() user := "pre-auth-key-reus-ephm-user" - - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - user: 0, + spec := ScenarioSpec{ + Users: []string{user}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipakresueeph")) - assertNoErr(t, err) + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clipakresueeph")) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) var preAuthReusableKey v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - user, - "create", - "--reusable=true", - "--output", - "json", - }, - &preAuthReusableKey, - ) - assertNoErr(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "create", + "--reusable=true", + "--output", + "json", + }, + &preAuthReusableKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for reusable preauth key creation") var preAuthEphemeralKey v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - user, - "create", - "--ephemeral=true", - "--output", - "json", - }, - &preAuthEphemeralKey, - ) - assertNoErr(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + "1", + "create", + "--ephemeral=true", + "--output", + "json", + }, + &preAuthEphemeralKey, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for ephemeral 
preauth key creation") assert.True(t, preAuthEphemeralKey.GetEphemeral()) assert.False(t, preAuthEphemeralKey.GetReusable()) var listedPreAuthKeys []v1.PreAuthKey - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "preauthkeys", - "--user", - user, - "list", - "--output", - "json", - }, - &listedPreAuthKeys, - ) - assertNoErr(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "list", + "--output", + "json", + }, + &listedPreAuthKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for preauth keys list after reusable/ephemeral creation") // There is one key created by "scenario.CreateHeadscaleEnv" assert.Len(t, listedPreAuthKeys, 3) } +func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) { + IntegrationSkip(t) + + user1 := "user1" + user2 := "user2" + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{user1}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("clipak"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + u2, err := headscale.CreateUser(user2) + require.NoError(t, err) + + var user2Key v1.PreAuthKey + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + strconv.FormatUint(u2.GetId(), 10), + "create", + "--reusable", + "--expiration", + "24h", + "--output", + "json", + "--tags", + "tag:test1,tag:test2", + }, + &user2Key, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for user2 preauth key creation") + + var listNodes []*v1.Node + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, listNodes, 1, "Should have exactly 1 node for user1") + assert.Equal(ct, user1, listNodes[0].GetUser().GetName(), "Node should belong to user1") + }, 15*time.Second, 1*time.Second) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + require.Len(t, allClients, 1) + + client := allClients[0] + + // Log out from user1 + err = client.Logout() + require.NoError(t, err) + + err = scenario.WaitForTailscaleLogout() + require.NoError(t, err) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + assert.NotContains(ct, []string{"Starting", "Running"}, status.BackendState, + "Expected node to be logged out, backend state: %s", status.BackendState) + }, 30*time.Second, 2*time.Second) + + err = client.Login(headscale.GetEndpoint(), user2Key.GetKey()) + require.NoError(t, err) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + assert.Equal(ct, "Running", status.BackendState, "Expected node to be logged in, backend state: %s", status.BackendState) + // With tags-as-identity model, tagged nodes show as TaggedDevices user (2147455555) + // The PreAuthKey was created with tags, so the node is tagged + assert.Equal(ct, "userid:2147455555", status.Self.UserID.String(), "Expected node to be logged in as tagged-devices user") + }, 30*time.Second, 2*time.Second) + + assert.EventuallyWithT(t, func(ct 
*assert.CollectT) { + var err error + + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, listNodes, 2, "Should have 2 nodes after re-login") + assert.Equal(ct, user1, listNodes[0].GetUser().GetName(), "First node should belong to user1") + // Second node is tagged (created with tagged PreAuthKey), so it shows as "tagged-devices" + assert.Equal(ct, "tagged-devices", listNodes[1].GetUser().GetName(), "Second node should be tagged-devices") + }, 20*time.Second, 1*time.Second) +} + +func TestTaggedNodesCLIOutput(t *testing.T) { + IntegrationSkip(t) + + user1 := "user1" + user2 := "user2" + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{user1}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("tagcli"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + u2, err := headscale.CreateUser(user2) + require.NoError(t, err) + + var user2Key v1.PreAuthKey + + // Create a tagged PreAuthKey for user2 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "preauthkeys", + "--user", + strconv.FormatUint(u2.GetId(), 10), + "create", + "--reusable", + "--expiration", + "24h", + "--output", + "json", + "--tags", + "tag:test1,tag:test2", + }, + &user2Key, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for user2 tagged preauth key creation") + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + require.Len(t, allClients, 1) + + client := allClients[0] + + // Log out from user1 + err = client.Logout() + require.NoError(t, err) + + err = scenario.WaitForTailscaleLogout() + require.NoError(t, err) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + assert.NotContains(ct, []string{"Starting", "Running"}, status.BackendState, + "Expected node to be logged out, backend state: %s", status.BackendState) + }, 30*time.Second, 2*time.Second) + + // Log in with the tagged PreAuthKey (from user2, with tags) + err = client.Login(headscale.GetEndpoint(), user2Key.GetKey()) + require.NoError(t, err) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + assert.Equal(ct, "Running", status.BackendState, "Expected node to be logged in, backend state: %s", status.BackendState) + // With tags-as-identity model, tagged nodes show as TaggedDevices user (2147455555) + assert.Equal(ct, "userid:2147455555", status.Self.UserID.String(), "Expected node to be logged in as tagged-devices user") + }, 30*time.Second, 2*time.Second) + + // Wait for the second node to appear + var listNodes []*v1.Node + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + + listNodes, err = headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, listNodes, 2, "Should have 2 nodes after re-login with tagged key") + assert.Equal(ct, user1, listNodes[0].GetUser().GetName(), "First node should belong to user1") + assert.Equal(ct, "tagged-devices", listNodes[1].GetUser().GetName(), "Second node should be tagged-devices") + }, 20*time.Second, 1*time.Second) + + // Test: tailscale status output should show "tagged-devices" not "userid:2147455555" + // This is the fix for issue 
#2970 - the Tailscale client should display user-friendly names + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + stdout, stderr, err := client.Execute([]string{"tailscale", "status"}) + assert.NoError(ct, err, "tailscale status command should succeed, stderr: %s", stderr) + + t.Logf("Tailscale status output:\n%s", stdout) + + // The output should contain "tagged-devices" for tagged nodes + assert.Contains(ct, stdout, "tagged-devices", "Tailscale status should show 'tagged-devices' for tagged nodes") + + // The output should NOT show the raw numeric userid to the user + assert.NotContains(ct, stdout, "userid:2147455555", "Tailscale status should not show numeric userid for tagged nodes") + }, 20*time.Second, 1*time.Second) +} + func TestApiKeyCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() count := 5 - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": 0, - "user2": 0, + spec := ScenarioSpec{ + Users: []string{"user1", "user2"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) keys := make([]string, count) - for idx := 0; idx < count; idx++ { + for idx := range count { apiResult, err := headscale.Execute( []string{ "headscale", @@ -422,7 +829,7 @@ func TestApiKeyCommand(t *testing.T) { "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) assert.NotEmpty(t, apiResult) keys[idx] = apiResult @@ -431,17 +838,20 @@ func TestApiKeyCommand(t *testing.T) { assert.Len(t, keys, 5) var listedAPIKeys []v1.ApiKey - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "apikeys", - "list", - "--output", - "json", - }, - &listedAPIKeys, - ) - assert.Nil(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "apikeys", + "list", + "--output", + "json", + }, + &listedAPIKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for API keys list") assert.Len(t, listedAPIKeys, 5) @@ -487,7 +897,7 @@ func TestApiKeyCommand(t *testing.T) { expiredPrefixes := make(map[string]bool) // Expire three keys - for idx := 0; idx < 3; idx++ { + for idx := range 3 { _, err := headscale.Execute( []string{ "headscale", @@ -497,23 +907,26 @@ func TestApiKeyCommand(t *testing.T) { listedAPIKeys[idx].GetPrefix(), }, ) - assert.Nil(t, err) + assert.NoError(t, err) expiredPrefixes[listedAPIKeys[idx].GetPrefix()] = true } var listedAfterExpireAPIKeys []v1.ApiKey - err = executeAndUnmarshal(headscale, - []string{ - "headscale", - "apikeys", - "list", - "--output", - "json", - }, - &listedAfterExpireAPIKeys, - ) - assert.Nil(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "apikeys", + "list", + "--output", + "json", + }, + &listedAfterExpireAPIKeys, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for API keys list after expire") for index := range listedAfterExpireAPIKeys { if _, ok := expiredPrefixes[listedAfterExpireAPIKeys[index].GetPrefix()]; ok { @@ -530,172 +943,137 @@ func TestApiKeyCommand(t *testing.T) { ) } } -} -func TestNodeTagCommand(t 
*testing.T) { - IntegrationSkip(t) - t.Parallel() + _, err = headscale.Execute( + []string{ + "headscale", + "apikeys", + "delete", + "--prefix", + listedAPIKeys[0].GetPrefix(), + }) + assert.NoError(t, err) - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() + var listedAPIKeysAfterDelete []v1.ApiKey - spec := map[string]int{ - "user1": 0, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) - - headscale, err := scenario.Headscale() - assertNoErr(t, err) - - machineKeys := []string{ - "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - } - nodes := make([]*v1.Node, len(machineKeys)) - assert.Nil(t, err) - - for index, machineKey := range machineKeys { - _, err := headscale.Execute( + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, []string{ "headscale", - "debug", - "create-node", - "--name", - fmt.Sprintf("node-%d", index+1), - "--user", - "user1", - "--key", - machineKey, + "apikeys", + "list", "--output", "json", }, + &listedAPIKeysAfterDelete, ) - assert.Nil(t, err) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for API keys list after delete") - var node v1.Node - err = executeAndUnmarshal( - headscale, + assert.Len(t, listedAPIKeysAfterDelete, 4) + + // Test expire by ID (using key at index 0) + _, err = headscale.Execute( + []string{ + "headscale", + "apikeys", + "expire", + "--id", + strconv.FormatUint(listedAPIKeysAfterDelete[0].GetId(), 10), + }) + require.NoError(t, err) + + var listedAPIKeysAfterExpireByID []v1.ApiKey + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, []string{ "headscale", - "nodes", - "--user", - "user1", - "register", - "--key", - machineKey, + "apikeys", + "list", "--output", "json", }, - &node, + &listedAPIKeysAfterExpireByID, ) - assert.Nil(t, err) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for API keys list after expire by ID") - nodes[index] = &node - } - assert.Len(t, nodes, len(machineKeys)) - - var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "tag", - "-i", "1", - "-t", "tag:test", - "--output", "json", - }, - &node, - ) - assert.Nil(t, err) - - assert.Equal(t, []string{"tag:test"}, node.GetForcedTags()) - - // try to set a wrong tag and retrieve the error - type errOutput struct { - Error string `json:"error"` - } - var errorOutput errOutput - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "tag", - "-i", "2", - "-t", "wrong-tag", - "--output", "json", - }, - &errorOutput, - ) - assert.Nil(t, err) - assert.Contains(t, errorOutput.Error, "tag must start with the string 'tag:'") - - // Test list all nodes after added seconds - resultMachines := make([]*v1.Node, len(machineKeys)) - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", "json", - }, - &resultMachines, - ) - assert.Nil(t, err) - found := false - for _, node := range resultMachines { - if node.GetForcedTags() != nil { - for _, tag := range node.GetForcedTags() { - if tag == "tag:test" { - found = true - } - } + // Verify the key was expired + for idx := range listedAPIKeysAfterExpireByID { + if listedAPIKeysAfterExpireByID[idx].GetId() == listedAPIKeysAfterDelete[0].GetId() { + assert.True(t, 
listedAPIKeysAfterExpireByID[idx].GetExpiration().AsTime().Before(time.Now()), + "Key expired by ID should have expiration in the past") } } - assert.Equal( - t, - true, - found, - "should find a node with the tag 'tag:test' in the list of nodes", - ) + + // Test delete by ID (using key at index 1) + deletedKeyID := listedAPIKeysAfterExpireByID[1].GetId() + _, err = headscale.Execute( + []string{ + "headscale", + "apikeys", + "delete", + "--id", + strconv.FormatUint(deletedKeyID, 10), + }) + require.NoError(t, err) + + var listedAPIKeysAfterDeleteByID []v1.ApiKey + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal(headscale, + []string{ + "headscale", + "apikeys", + "list", + "--output", + "json", + }, + &listedAPIKeysAfterDeleteByID, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for API keys list after delete by ID") + + assert.Len(t, listedAPIKeysAfterDeleteByID, 3) + + // Verify the specific key was deleted + for idx := range listedAPIKeysAfterDeleteByID { + assert.NotEqual(t, deletedKeyID, listedAPIKeysAfterDeleteByID[idx].GetId(), + "Deleted key should not be present in the list") + } } func TestNodeCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "node-user": 0, - "other-user": 0, + spec := ScenarioSpec{ + Users: []string{"node-user", "other-user"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) - // Pregenerated machine keys - machineKeys := []string{ - "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - "mkey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", - "mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", + regIDs := []string{ + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), } - nodes := make([]*v1.Node, len(machineKeys)) - assert.Nil(t, err) + nodes := make([]*v1.Node, len(regIDs)) - for index, machineKey := range machineKeys { + assert.NoError(t, err) + + for index, regID := range regIDs { _, err := headscale.Execute( []string{ "headscale", @@ -706,52 +1084,59 @@ func TestNodeCommand(t *testing.T) { "--user", "node-user", "--key", - machineKey, + regID, "--output", "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "node-user", - "register", - "--key", - machineKey, - "--output", - "json", - }, - &node, - ) - assert.Nil(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "--user", + "node-user", + "register", + "--key", + regID, + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node 
registration") nodes[index] = &node } - assert.Len(t, nodes, len(machineKeys)) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + assert.Len(ct, nodes, len(regIDs), "Should have correct number of nodes after CLI operations") + }, 15*time.Second, 1*time.Second) // Test list all nodes after added seconds var listAll []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAll, - ) - assert.Nil(t, err) - assert.Len(t, listAll, 5) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err := executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAll, + ) + assert.NoError(ct, err) + assert.Len(ct, listAll, len(regIDs), "Should list all nodes after CLI operations") + }, 20*time.Second, 1*time.Second) assert.Equal(t, uint64(1), listAll[0].GetId()) assert.Equal(t, uint64(2), listAll[1].GetId()) @@ -765,68 +1150,77 @@ func TestNodeCommand(t *testing.T) { assert.Equal(t, "node-4", listAll[3].GetName()) assert.Equal(t, "node-5", listAll[4].GetName()) - otherUserMachineKeys := []string{ - "mkey:b5b444774186d4217adcec407563a1223929465ee2c68a4da13af0d0185b4f8e", - "mkey:dc721977ac7415aafa87f7d4574cbe07c6b171834a6d37375782bdc1fb6b3584", + otherUserRegIDs := []string{ + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), } - otherUserMachines := make([]*v1.Node, len(otherUserMachineKeys)) - assert.Nil(t, err) + otherUserMachines := make([]*v1.Node, len(otherUserRegIDs)) - for index, machineKey := range otherUserMachineKeys { + assert.NoError(t, err) + + for index, regID := range otherUserRegIDs { _, err := headscale.Execute( []string{ "headscale", "debug", "create-node", "--name", - fmt.Sprintf("otherUser-node-%d", index+1), + fmt.Sprintf("otheruser-node-%d", index+1), "--user", "other-user", "--key", - machineKey, + regID, "--output", "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "--user", + "other-user", + "register", + "--key", + regID, + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for other-user node registration") + + otherUserMachines[index] = &node + } + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + assert.Len(ct, otherUserMachines, len(otherUserRegIDs), "Should have correct number of otherUser machines after CLI operations") + }, 15*time.Second, 1*time.Second) + + // Test list all nodes after added otherUser + var listAllWithotherUser []v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { err = executeAndUnmarshal( headscale, []string{ "headscale", "nodes", - "--user", - "other-user", - "register", - "--key", - machineKey, + "list", "--output", "json", }, - &node, + &listAllWithotherUser, ) - assert.Nil(t, err) - - otherUserMachines[index] = &node - } - - assert.Len(t, otherUserMachines, len(otherUserMachineKeys)) - - // Test list all nodes after added otherUser - var listAllWithotherUser []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAllWithotherUser, - ) - assert.Nil(t, err) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after adding other-user nodes") // All nodes, nodes + otherUser assert.Len(t, listAllWithotherUser, 7) @@ 
-834,25 +1228,28 @@ func TestNodeCommand(t *testing.T) { assert.Equal(t, uint64(6), listAllWithotherUser[5].GetId()) assert.Equal(t, uint64(7), listAllWithotherUser[6].GetId()) - assert.Equal(t, "otherUser-node-1", listAllWithotherUser[5].GetName()) - assert.Equal(t, "otherUser-node-2", listAllWithotherUser[6].GetName()) + assert.Equal(t, "otheruser-node-1", listAllWithotherUser[5].GetName()) + assert.Equal(t, "otheruser-node-2", listAllWithotherUser[6].GetName()) // Test list all nodes after added otherUser var listOnlyotherUserMachineUser []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--user", - "other-user", - "--output", - "json", - }, - &listOnlyotherUserMachineUser, - ) - assert.Nil(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--user", + "other-user", + "--output", + "json", + }, + &listOnlyotherUserMachineUser, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list filtered by other-user") assert.Len(t, listOnlyotherUserMachineUser, 2) @@ -861,12 +1258,12 @@ func TestNodeCommand(t *testing.T) { assert.Equal( t, - "otherUser-node-1", + "otheruser-node-1", listOnlyotherUserMachineUser[0].GetName(), ) assert.Equal( t, - "otherUser-node-2", + "otheruser-node-2", listOnlyotherUserMachineUser[1].GetName(), ) @@ -884,57 +1281,58 @@ func TestNodeCommand(t *testing.T) { "--force", }, ) - assert.Nil(t, err) + assert.NoError(t, err) // Test: list main user after node is deleted var listOnlyMachineUserAfterDelete []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--user", - "node-user", - "--output", - "json", - }, - &listOnlyMachineUserAfterDelete, - ) - assert.Nil(t, err) - assert.Len(t, listOnlyMachineUserAfterDelete, 4) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err := executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--user", + "node-user", + "--output", + "json", + }, + &listOnlyMachineUserAfterDelete, + ) + assert.NoError(ct, err) + assert.Len(ct, listOnlyMachineUserAfterDelete, 4, "Should have 4 nodes for node-user after deletion") + }, 20*time.Second, 1*time.Second) } func TestNodeExpireCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "node-expire-user": 0, + spec := ScenarioSpec{ + Users: []string{"node-expire-user"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) - // Pregenerated machine keys - machineKeys := []string{ - "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", - "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - "mkey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", - "mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", + regIDs := []string{ + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + 
types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), } - nodes := make([]*v1.Node, len(machineKeys)) + nodes := make([]*v1.Node, len(regIDs)) - for index, machineKey := range machineKeys { + for index, regID := range regIDs { _, err := headscale.Execute( []string{ "headscale", @@ -945,49 +1343,55 @@ func TestNodeExpireCommand(t *testing.T) { "--user", "node-expire-user", "--key", - machineKey, + regID, "--output", "json", }, ) - assert.Nil(t, err) + assert.NoError(t, err) var node v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "--user", + "node-expire-user", + "register", + "--key", + regID, + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node-expire-user node registration") + + nodes[index] = &node + } + + assert.Len(t, nodes, len(regIDs)) + + var listAll []v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { err = executeAndUnmarshal( headscale, []string{ "headscale", "nodes", - "--user", - "node-expire-user", - "register", - "--key", - machineKey, + "list", "--output", "json", }, - &node, + &listAll, ) - assert.Nil(t, err) - - nodes[index] = &node - } - - assert.Len(t, nodes, len(machineKeys)) - - var listAll []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAll, - ) - assert.Nil(t, err) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list in expire test") assert.Len(t, listAll, 5) @@ -997,32 +1401,35 @@ func TestNodeExpireCommand(t *testing.T) { assert.True(t, listAll[3].GetExpiry().AsTime().IsZero()) assert.True(t, listAll[4].GetExpiry().AsTime().IsZero()) - for idx := 0; idx < 3; idx++ { + for idx := range 3 { _, err := headscale.Execute( []string{ "headscale", "nodes", "expire", "--identifier", - fmt.Sprintf("%d", listAll[idx].GetId()), + strconv.FormatUint(listAll[idx].GetId(), 10), }, ) - assert.Nil(t, err) + assert.NoError(t, err) } var listAllAfterExpiry []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAllAfterExpiry, - ) - assert.Nil(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAllAfterExpiry, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after expiry") assert.Len(t, listAllAfterExpiry, 5) @@ -1035,34 +1442,34 @@ func TestNodeExpireCommand(t *testing.T) { func TestNodeRenameCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "node-rename-command": 0, + spec := ScenarioSpec{ + Users: []string{"node-rename-command"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("clins")) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) - // Pregenerated machine keys - machineKeys := []string{ - 
"mkey:cf7b0fd05da556fdc3bab365787b506fd82d64a70745db70e00e86c1b1c03084", - "mkey:8bc13285cee598acf76b1824a6f4490f7f2e3751b201e28aeb3b07fe81d5b4a1", - "mkey:f08305b4ee4250b95a70f3b7504d048d75d899993c624a26d422c67af0422507", - "mkey:6abd00bb5fdda622db51387088c68e97e71ce58e7056aa54f592b6a8219d524c", - "mkey:9b2ffa7e08cc421a3d2cca9012280f6a236fd0de0b4ce005b30a98ad930306fe", + regIDs := []string{ + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), + types.MustRegistrationID().String(), } - nodes := make([]*v1.Node, len(machineKeys)) - assert.Nil(t, err) + nodes := make([]*v1.Node, len(regIDs)) - for index, machineKey := range machineKeys { + assert.NoError(t, err) + + for index, regID := range regIDs { _, err := headscale.Execute( []string{ "headscale", @@ -1073,49 +1480,55 @@ func TestNodeRenameCommand(t *testing.T) { "--user", "node-rename-command", "--key", - machineKey, + regID, "--output", "json", }, ) - assertNoErr(t, err) + require.NoError(t, err) var node v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "--user", + "node-rename-command", + "register", + "--key", + regID, + "--output", + "json", + }, + &node, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for node-rename-command node registration") + + nodes[index] = &node + } + + assert.Len(t, nodes, len(regIDs)) + + var listAll []v1.Node + + assert.EventuallyWithT(t, func(c *assert.CollectT) { err = executeAndUnmarshal( headscale, []string{ "headscale", "nodes", - "--user", - "node-rename-command", - "register", - "--key", - machineKey, + "list", "--output", "json", }, - &node, + &listAll, ) - assertNoErr(t, err) - - nodes[index] = &node - } - - assert.Len(t, nodes, len(machineKeys)) - - var listAll []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAll, - ) - assert.Nil(t, err) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list in rename test") assert.Len(t, listAll, 5) @@ -1125,33 +1538,38 @@ func TestNodeRenameCommand(t *testing.T) { assert.Contains(t, listAll[3].GetGivenName(), "node-4") assert.Contains(t, listAll[4].GetGivenName(), "node-5") - for idx := 0; idx < 3; idx++ { - _, err := headscale.Execute( + for idx := range 3 { + res, err := headscale.Execute( []string{ "headscale", "nodes", "rename", "--identifier", - fmt.Sprintf("%d", listAll[idx].GetId()), + strconv.FormatUint(listAll[idx].GetId(), 10), fmt.Sprintf("newnode-%d", idx+1), }, ) - assert.Nil(t, err) + assert.NoError(t, err) + + assert.Contains(t, res, "Node renamed") } var listAllAfterRename []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAllAfterRename, - ) - assert.Nil(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAllAfterRename, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after rename") assert.Len(t, listAllAfterRename, 5) @@ -1162,32 +1580,34 @@ func TestNodeRenameCommand(t *testing.T) { assert.Contains(t, listAllAfterRename[4].GetGivenName(), "node-5") // Test failure for too long names - result, err := headscale.Execute( + _, 
err = headscale.Execute( []string{ "headscale", "nodes", "rename", "--identifier", - fmt.Sprintf("%d", listAll[4].GetId()), - "testmaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaachine12345678901234567890", + strconv.FormatUint(listAll[4].GetId(), 10), + strings.Repeat("t", 64), }, ) - assert.Nil(t, err) - assert.Contains(t, result, "not be over 63 chars") + assert.ErrorContains(t, err, "must not exceed 63 characters") var listAllAfterRenameAttempt []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &listAllAfterRenameAttempt, - ) - assert.Nil(t, err) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &listAllAfterRenameAttempt, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after failed rename attempt") assert.Len(t, listAllAfterRenameAttempt, 5) @@ -1198,165 +1618,163 @@ func TestNodeRenameCommand(t *testing.T) { assert.Contains(t, listAllAfterRenameAttempt[4].GetGivenName(), "node-5") } -func TestNodeMoveCommand(t *testing.T) { +func TestPolicyCommand(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "old-user": 0, - "new-user": 0, + spec := ScenarioSpec{ + Users: []string{"user1"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clins")) - assertNoErr(t, err) + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("clins"), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_POLICY_MODE": "database", + }), + ) + require.NoError(t, err) headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) - // Randomly generated node key - machineKey := "mkey:688411b767663479632d44140f08a9fde87383adc7cdeb518f62ce28a17ef0aa" + p := policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:exists"): policyv2.Owners{usernameOwner("user1@")}, + }, + } + pBytes, _ := json.Marshal(p) + + policyFilePath := "/etc/headscale/policy.json" + + err = headscale.WriteFile(policyFilePath, pBytes) + require.NoError(t, err) + + // No policy is present at this time. + // Add a new policy from a file. _, err = headscale.Execute( []string{ "headscale", - "debug", - "create-node", - "--name", - "nomad-node", - "--user", - "old-user", - "--key", - machineKey, - "--output", - "json", + "policy", + "set", + "-f", + policyFilePath, }, ) - assert.Nil(t, err) - var node v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "--user", - "old-user", - "register", - "--key", - machineKey, - "--output", - "json", - }, - &node, - ) - assert.Nil(t, err) + require.NoError(t, err) - assert.Equal(t, uint64(1), node.GetId()) - assert.Equal(t, "nomad-node", node.GetName()) - assert.Equal(t, node.GetUser().GetName(), "old-user") + // Get the current policy and check + // if it is the same as the one we set. 
+ var output *policyv2.Policy - nodeID := fmt.Sprintf("%d", node.GetId()) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "policy", + "get", + "--output", + "json", + }, + &output, + ) + assert.NoError(c, err) + }, 10*time.Second, 200*time.Millisecond, "Waiting for policy get command") - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "move", - "--identifier", - nodeID, - "--user", - "new-user", - "--output", - "json", - }, - &node, - ) - assert.Nil(t, err) - - assert.Equal(t, node.GetUser().GetName(), "new-user") - - var allNodes []v1.Node - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "list", - "--output", - "json", - }, - &allNodes, - ) - assert.Nil(t, err) - - assert.Len(t, allNodes, 1) - - assert.Equal(t, allNodes[0].GetId(), node.GetId()) - assert.Equal(t, allNodes[0].GetUser(), node.GetUser()) - assert.Equal(t, allNodes[0].GetUser().GetName(), "new-user") - - moveToNonExistingNSResult, err := headscale.Execute( - []string{ - "headscale", - "nodes", - "move", - "--identifier", - nodeID, - "--user", - "non-existing-user", - "--output", - "json", - }, - ) - assert.Nil(t, err) - - assert.Contains( - t, - moveToNonExistingNSResult, - "user not found", - ) - assert.Equal(t, node.GetUser().GetName(), "new-user") - - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "move", - "--identifier", - nodeID, - "--user", - "old-user", - "--output", - "json", - }, - &node, - ) - assert.Nil(t, err) - - assert.Equal(t, node.GetUser().GetName(), "old-user") - - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "nodes", - "move", - "--identifier", - nodeID, - "--user", - "old-user", - "--output", - "json", - }, - &node, - ) - assert.Nil(t, err) - - assert.Equal(t, node.GetUser().GetName(), "old-user") + assert.Len(t, output.TagOwners, 1) + assert.Len(t, output.ACLs, 1) +} + +func TestPolicyBrokenConfigCommand(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("clins"), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_POLICY_MODE": "database", + }), + ) + require.NoError(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + p := policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + // This is an unknown action, so it will return an error + // and the config will not be applied. + Action: "unknown-action", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:exists"): policyv2.Owners{usernameOwner("user1@")}, + }, + } + + pBytes, _ := json.Marshal(p) + + policyFilePath := "/etc/headscale/policy.json" + + err = headscale.WriteFile(policyFilePath, pBytes) + require.NoError(t, err) + + // No policy is present at this time. + // Add a new policy from a file. + _, err = headscale.Execute( + []string{ + "headscale", + "policy", + "set", + "-f", + policyFilePath, + }, + ) + assert.ErrorContains(t, err, `invalid action "unknown-action"`) + + // The new policy was invalid, the old one should still be in place, which + // is none. 
+ _, err = headscale.Execute( + []string{ + "headscale", + "policy", + "get", + "--output", + "json", + }, + ) + assert.ErrorContains(t, err, "acl policy not found") } diff --git a/integration/control.go b/integration/control.go index f5557495..58a061e3 100644 --- a/integration/control.go +++ b/integration/control.go @@ -1,23 +1,49 @@ package integration import ( + "net/netip" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/routes" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/integration/hsic" "github.com/ory/dockertest/v3" + "tailscale.com/tailcfg" ) type ControlServer interface { - Shutdown() error - SaveLog(string) error + Shutdown() (string, string, error) + SaveLog(string) (string, string, error) SaveProfile(string) error Execute(command []string) (string, error) + WriteFile(path string, content []byte) error ConnectToNetwork(network *dockertest.Network) error GetHealthEndpoint() string GetEndpoint() string WaitForRunning() error - CreateUser(user string) error - CreateAuthKey(user string, reusable bool, ephemeral bool) (*v1.PreAuthKey, error) - ListNodesInUser(user string) ([]*v1.Node, error) + CreateUser(user string) (*v1.User, error) + CreateAuthKey(user uint64, reusable bool, ephemeral bool) (*v1.PreAuthKey, error) + CreateAuthKeyWithTags(user uint64, reusable bool, ephemeral bool, tags []string) (*v1.PreAuthKey, error) + CreateAuthKeyWithOptions(opts hsic.AuthKeyOptions) (*v1.PreAuthKey, error) + DeleteAuthKey(id uint64) error + ListNodes(users ...string) ([]*v1.Node, error) + DeleteNode(nodeID uint64) error + NodesByUser() (map[string][]*v1.Node, error) + NodesByName() (map[string]*v1.Node, error) + ListUsers() ([]*v1.User, error) + MapUsers() (map[string]*v1.User, error) + DeleteUser(userID uint64) error + ApproveRoutes(uint64, []netip.Prefix) (*v1.Node, error) + SetNodeTags(nodeID uint64, tags []string) error GetCert() []byte GetHostname() string - GetIP() string + GetIPInNetwork(network *dockertest.Network) string + SetPolicy(*policyv2.Policy) error + GetAllMapReponses() (map[types.NodeID][]tailcfg.MapResponse, error) + PrimaryRoutes() (*routes.DebugRoutes, error) + DebugBatcher() (*hscontrol.DebugBatcherInfo, error) + DebugNodeStore() (map[types.NodeID]types.Node, error) + DebugFilter() ([]tailcfg.FilterRule, error) } diff --git a/integration/derp_verify_endpoint_test.go b/integration/derp_verify_endpoint_test.go new file mode 100644 index 00000000..60260bb1 --- /dev/null +++ b/integration/derp_verify_endpoint_test.go @@ -0,0 +1,121 @@ +package integration + +import ( + "fmt" + "net" + "strconv" + "testing" + + "github.com/juanfont/headscale/hscontrol/util" + "github.com/juanfont/headscale/integration/dsic" + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/require" + "tailscale.com/derp" + "tailscale.com/derp/derphttp" + "tailscale.com/net/netmon" + "tailscale.com/tailcfg" + "tailscale.com/types/key" +) + +func TestDERPVerifyEndpoint(t *testing.T) { + IntegrationSkip(t) + + // Generate random hostname for the headscale instance + hash, err := util.GenerateRandomStringDNSSafe(6) + require.NoError(t, err) + testName := "derpverify" + hostname := fmt.Sprintf("hs-%s-%s", testName, hash) + + headscalePort := 8080 + + // Create cert for 
headscale + certHeadscale, keyHeadscale, err := integrationutil.CreateCertificate(hostname) + require.NoError(t, err) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + derper, err := scenario.CreateDERPServer("head", + dsic.WithCACert(certHeadscale), + dsic.WithVerifyClientURL(fmt.Sprintf("https://%s/verify", net.JoinHostPort(hostname, strconv.Itoa(headscalePort)))), + ) + require.NoError(t, err) + + derpRegion := tailcfg.DERPRegion{ + RegionCode: "test-derpverify", + RegionName: "TestDerpVerify", + Nodes: []*tailcfg.DERPNode{ + { + Name: "TestDerpVerify", + RegionID: 900, + HostName: derper.GetHostname(), + STUNPort: derper.GetSTUNPort(), + STUNOnly: false, + DERPPort: derper.GetDERPPort(), + InsecureForTests: true, + }, + }, + } + derpMap := tailcfg.DERPMap{ + Regions: map[int]*tailcfg.DERPRegion{ + 900: &derpRegion, + }, + } + + err = scenario.CreateHeadscaleEnv([]tsic.Option{tsic.WithCACert(derper.GetCert())}, + hsic.WithHostname(hostname), + hsic.WithPort(headscalePort), + hsic.WithCustomTLS(certHeadscale, keyHeadscale), + hsic.WithDERPConfig(derpMap)) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + fakeKey := key.NewNode() + DERPVerify(t, fakeKey, derpRegion, false) + + for _, client := range allClients { + nodeKey, err := client.GetNodePrivateKey() + require.NoError(t, err) + DERPVerify(t, *nodeKey, derpRegion, true) + } +} + +func DERPVerify( + t *testing.T, + nodeKey key.NodePrivate, + region tailcfg.DERPRegion, + expectSuccess bool, +) { + t.Helper() + + c := derphttp.NewRegionClient(nodeKey, t.Logf, netmon.NewStatic(), func() *tailcfg.DERPRegion { + return ®ion + }) + defer c.Close() + + var result error + if err := c.Connect(t.Context()); err != nil { + result = fmt.Errorf("client Connect: %w", err) + } + if m, err := c.Recv(); err != nil { + result = fmt.Errorf("client first Recv: %w", err) + } else if v, ok := m.(derp.ServerInfoMessage); !ok { + result = fmt.Errorf("client first Recv was unexpected type %T", v) + } + + if expectSuccess && result != nil { + t.Fatalf("DERP verify failed unexpectedly for client %s. Expected success but got error: %v", nodeKey.Public(), result) + } else if !expectSuccess && result == nil { + t.Fatalf("DERP verify succeeded unexpectedly for client %s. 
Expected failure but it succeeded.", nodeKey.Public()) + } +} diff --git a/integration/dns_test.go b/integration/dns_test.go new file mode 100644 index 00000000..e937a421 --- /dev/null +++ b/integration/dns_test.go @@ -0,0 +1,226 @@ +package integration + +import ( + "encoding/json" + "fmt" + "strings" + "testing" + "time" + + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" +) + +func TestResolveMagicDNS(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("magicdns")) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) + + // Poor mans cache + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + _, err = scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + for _, client := range allClients { + for _, peer := range allClients { + // It is safe to ignore this error as we handled it when caching it + peerFQDN, _ := peer.FQDN() + + assert.Equal(t, peer.Hostname()+".headscale.net.", peerFQDN) + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + command := []string{ + "tailscale", + "ip", peerFQDN, + } + result, _, err := client.Execute(command) + assert.NoError(ct, err, "Failed to execute resolve/ip command %s from %s", peerFQDN, client.Hostname()) + + ips, err := peer.IPs() + assert.NoError(ct, err, "Failed to get IPs for %s", peer.Hostname()) + + for _, ip := range ips { + assert.Contains(ct, result, ip.String(), "IP %s should be found in DNS resolution result from %s to %s", ip.String(), client.Hostname(), peer.Hostname()) + } + }, 30*time.Second, 2*time.Second) + } + } +} + +func TestResolveMagicDNSExtraRecordsPath(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + const erPath = "/tmp/extra_records.json" + + extraRecords := []tailcfg.DNSRecord{ + { + Name: "test.myvpn.example.com", + Type: "A", + Value: "6.6.6.6", + }, + } + b, _ := json.Marshal(extraRecords) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{ + tsic.WithPackages("python3", "curl", "bind-tools"), + }, + hsic.WithTestName("extrarecords"), + hsic.WithConfigEnv(map[string]string{ + // Disable global nameservers to make the test run offline. 
+ "HEADSCALE_DNS_NAMESERVERS_GLOBAL": "", + "HEADSCALE_DNS_EXTRA_RECORDS_PATH": erPath, + }), + hsic.WithFileInContainer(erPath, b), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) + + // Poor mans cache + _, err = scenario.ListTailscaleClientsFQDNs() + requireNoErrListFQDN(t, err) + + _, err = scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + for _, client := range allClients { + assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6") + } + + hs, err := scenario.Headscale() + require.NoError(t, err) + + // Write the file directly into place from the docker API. + b0, _ := json.Marshal([]tailcfg.DNSRecord{ + { + Name: "docker.myvpn.example.com", + Type: "A", + Value: "2.2.2.2", + }, + }) + + err = hs.WriteFile(erPath, b0) + require.NoError(t, err) + + for _, client := range allClients { + assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "2.2.2.2") + } + + // Write a new file and move it to the path to ensure the reload + // works when a file is moved atomically into place. + extraRecords = append(extraRecords, tailcfg.DNSRecord{ + Name: "otherrecord.myvpn.example.com", + Type: "A", + Value: "7.7.7.7", + }) + b2, _ := json.Marshal(extraRecords) + + err = hs.WriteFile(erPath+"2", b2) + require.NoError(t, err) + _, err = hs.Execute([]string{"mv", erPath + "2", erPath}) + require.NoError(t, err) + + for _, client := range allClients { + assertCommandOutputContains(t, client, []string{"dig", "test.myvpn.example.com"}, "6.6.6.6") + assertCommandOutputContains(t, client, []string{"dig", "otherrecord.myvpn.example.com"}, "7.7.7.7") + } + + // Write a new file and copy it to the path to ensure the reload + // works when a file is copied into place. + b3, _ := json.Marshal([]tailcfg.DNSRecord{ + { + Name: "copy.myvpn.example.com", + Type: "A", + Value: "8.8.8.8", + }, + }) + + err = hs.WriteFile(erPath+"3", b3) + require.NoError(t, err) + _, err = hs.Execute([]string{"cp", erPath + "3", erPath}) + require.NoError(t, err) + + for _, client := range allClients { + assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8") + } + + // Write in place to ensure pipe like behaviour works + b4, _ := json.Marshal([]tailcfg.DNSRecord{ + { + Name: "docker.myvpn.example.com", + Type: "A", + Value: "9.9.9.9", + }, + }) + command := []string{"echo", fmt.Sprintf("'%s'", string(b4)), ">", erPath} + _, err = hs.Execute([]string{"bash", "-c", strings.Join(command, " ")}) + require.NoError(t, err) + + for _, client := range allClients { + assertCommandOutputContains(t, client, []string{"dig", "docker.myvpn.example.com"}, "9.9.9.9") + } + + // Delete the file and create a new one to ensure it is picked up again. + _, err = hs.Execute([]string{"rm", erPath}) + require.NoError(t, err) + + // The same paths should still be available as it is not cleared on delete. + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + result, _, err := client.Execute([]string{"dig", "docker.myvpn.example.com"}) + assert.NoError(ct, err) + assert.Contains(ct, result, "9.9.9.9") + } + }, 10*time.Second, 1*time.Second) + + // Write a new file, the backoff mechanism should make the filewatcher pick it up + // again. 
+ err = hs.WriteFile(erPath, b3) + require.NoError(t, err) + + for _, client := range allClients { + assertCommandOutputContains(t, client, []string{"dig", "copy.myvpn.example.com"}, "8.8.8.8") + } +} diff --git a/integration/dockertestutil/build.go b/integration/dockertestutil/build.go new file mode 100644 index 00000000..dd082d22 --- /dev/null +++ b/integration/dockertestutil/build.go @@ -0,0 +1,25 @@ +package dockertestutil + +import ( + "context" + "os/exec" + "time" +) + +// RunDockerBuildForDiagnostics runs docker build manually to get detailed error output. +// This is used when a docker build fails to provide more detailed diagnostic information +// than what dockertest typically provides. +// +// Returns the build output regardless of success/failure, and an error if the build failed. +func RunDockerBuildForDiagnostics(contextDir, dockerfile string) (string, error) { + // Use a context with timeout to prevent hanging builds + const buildTimeout = 10 * time.Minute + + ctx, cancel := context.WithTimeout(context.Background(), buildTimeout) + defer cancel() + + cmd := exec.CommandContext(ctx, "docker", "build", "--progress=plain", "--no-cache", "-f", dockerfile, contextDir) + output, err := cmd.CombinedOutput() + + return string(output), err +} diff --git a/integration/dockertestutil/config.go b/integration/dockertestutil/config.go index 4dc3ee33..c0c57a3e 100644 --- a/integration/dockertestutil/config.go +++ b/integration/dockertestutil/config.go @@ -1,40 +1,72 @@ package dockertestutil import ( + "fmt" "os" + "strings" + "time" - "github.com/ory/dockertest/v3/docker" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/ory/dockertest/v3" ) +const ( + // TimestampFormatRunID is used for generating unique run identifiers + // Format: "20060102-150405" provides compact date-time for file/directory names. + TimestampFormatRunID = "20060102-150405" +) + +// GetIntegrationRunID returns the run ID for the current integration test session. +// This is set by the hi tool and passed through environment variables. +func GetIntegrationRunID() string { + return os.Getenv("HEADSCALE_INTEGRATION_RUN_ID") +} + +// DockerAddIntegrationLabels adds integration test labels to Docker RunOptions. +// This allows the hi tool to identify containers belonging to specific test runs. +// This function should be called before passing RunOptions to dockertest functions. +func DockerAddIntegrationLabels(opts *dockertest.RunOptions, testType string) { + runID := GetIntegrationRunID() + if runID == "" { + panic("HEADSCALE_INTEGRATION_RUN_ID environment variable is required") + } + + if opts.Labels == nil { + opts.Labels = make(map[string]string) + } + opts.Labels["hi.run-id"] = runID + opts.Labels["hi.test-type"] = testType +} + +// GenerateRunID creates a unique run identifier with timestamp and random hash. +// Format: YYYYMMDD-HHMMSS-HASH (e.g., 20250619-143052-a1b2c3). +func GenerateRunID() string { + now := time.Now() + timestamp := now.Format(TimestampFormatRunID) + + // Add a short random hash to ensure uniqueness + randomHash := util.MustGenerateRandomStringDNSSafe(6) + + return fmt.Sprintf("%s-%s", timestamp, randomHash) +} + +// ExtractRunIDFromContainerName extracts the run ID from container name. +// Expects format: "prefix-YYYYMMDD-HHMMSS-HASH". 
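+// For example, "hs-pingallbyip-20250619-143052-a1b2c3" yields "20250619-143052-a1b2c3".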
+func ExtractRunIDFromContainerName(containerName string) string { + parts := strings.Split(containerName, "-") + if len(parts) >= 3 { + // Return the last three parts as the run ID (YYYYMMDD-HHMMSS-HASH) + return strings.Join(parts[len(parts)-3:], "-") + } + + panic("unexpected container name format: " + containerName) +} + +// IsRunningInContainer checks if the current process is running inside a Docker container. +// This is used by tests to determine if they should run integration tests. func IsRunningInContainer() bool { - if _, err := os.Stat("/.dockerenv"); err != nil { - return false - } - - return true -} - -func DockerRestartPolicy(config *docker.HostConfig) { - // set AutoRemove to true so that stopped container goes away by itself on error *immediately*. - // when set to false, containers remain until the end of the integration test. - config.AutoRemove = false - config.RestartPolicy = docker.RestartPolicy{ - Name: "no", - } -} - -func DockerAllowLocalIPv6(config *docker.HostConfig) { - if config.Sysctls == nil { - config.Sysctls = make(map[string]string, 1) - } - config.Sysctls["net.ipv6.conf.all.disable_ipv6"] = "0" -} - -func DockerAllowNetworkAdministration(config *docker.HostConfig) { - config.CapAdd = append(config.CapAdd, "NET_ADMIN") - config.Mounts = append(config.Mounts, docker.HostMount{ - Type: "bind", - Source: "/dev/net/tun", - Target: "/dev/net/tun", - }) + // Check for the common indicator that we're in a container + // This could be improved with more robust detection if needed + _, err := os.Stat("/.dockerenv") + return err == nil } diff --git a/integration/dockertestutil/execute.go b/integration/dockertestutil/execute.go index 5a8e92b3..b09e0d40 100644 --- a/integration/dockertestutil/execute.go +++ b/integration/dockertestutil/execute.go @@ -4,12 +4,13 @@ import ( "bytes" "errors" "fmt" + "sync" "time" "github.com/ory/dockertest/v3" ) -const dockerExecuteTimeout = time.Second * 30 +const dockerExecuteTimeout = time.Second * 10 var ( ErrDockertestCommandFailed = errors.New("dockertest command failed") @@ -25,19 +26,40 @@ type ExecuteCommandOption func(*ExecuteCommandConfig) error func ExecuteCommandTimeout(timeout time.Duration) ExecuteCommandOption { return ExecuteCommandOption(func(conf *ExecuteCommandConfig) error { conf.timeout = timeout - return nil }) } +// buffer is a goroutine safe bytes.buffer. +type buffer struct { + store bytes.Buffer + mutex sync.Mutex +} + +// Write appends the contents of p to the buffer, growing the buffer as needed. It returns +// the number of bytes written. +func (b *buffer) Write(p []byte) (n int, err error) { + b.mutex.Lock() + defer b.mutex.Unlock() + return b.store.Write(p) +} + +// String returns the contents of the unread portion of the buffer +// as a string. 
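+// Taking the mutex makes it safe to read while the exec goroutine may still be
+// writing, for example after a command timeout.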
+func (b *buffer) String() string { + b.mutex.Lock() + defer b.mutex.Unlock() + return b.store.String() +} + func ExecuteCommand( resource *dockertest.Resource, cmd []string, env []string, options ...ExecuteCommandOption, ) (string, string, error) { - var stdout bytes.Buffer - var stderr bytes.Buffer + stdout := buffer{} + stderr := buffer{} execConfig := ExecuteCommandConfig{ timeout: dockerExecuteTimeout, @@ -62,11 +84,12 @@ func ExecuteCommand( exitCode, err := resource.Exec( cmd, dockertest.ExecOptions{ - Env: append(env, "HEADSCALE_LOG_LEVEL=disabled"), + Env: append(env, "HEADSCALE_LOG_LEVEL=info"), StdOut: &stdout, StdErr: &stderr, }, ) + resultChan <- result{exitCode, err} }() @@ -74,7 +97,7 @@ func ExecuteCommand( select { case res := <-resultChan: if res.err != nil { - return stdout.String(), stderr.String(), res.err + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), res.err) } if res.exitCode != 0 { @@ -83,12 +106,11 @@ func ExecuteCommand( // log.Println("stdout: ", stdout.String()) // log.Println("stderr: ", stderr.String()) - return stdout.String(), stderr.String(), ErrDockertestCommandFailed + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandFailed) } return stdout.String(), stderr.String(), nil case <-time.After(execConfig.timeout): - - return stdout.String(), stderr.String(), ErrDockertestCommandTimeout + return stdout.String(), stderr.String(), fmt.Errorf("command failed, stderr: %s: %w", stderr.String(), ErrDockertestCommandTimeout) } } diff --git a/integration/dockertestutil/logs.go b/integration/dockertestutil/logs.go index 98ba970a..7d104e43 100644 --- a/integration/dockertestutil/logs.go +++ b/integration/dockertestutil/logs.go @@ -3,6 +3,7 @@ package dockertestutil import ( "bytes" "context" + "io" "log" "os" "path" @@ -13,25 +14,18 @@ import ( const filePerm = 0o644 -func SaveLog( +func WriteLog( pool *dockertest.Pool, resource *dockertest.Resource, - basePath string, + stdout io.Writer, + stderr io.Writer, ) error { - err := os.MkdirAll(basePath, os.ModePerm) - if err != nil { - return err - } - - var stdout bytes.Buffer - var stderr bytes.Buffer - - err = pool.Client.Logs( + return pool.Client.Logs( docker.LogsOptions{ Context: context.TODO(), Container: resource.Container.ID, - OutputStream: &stdout, - ErrorStream: &stderr, + OutputStream: stdout, + ErrorStream: stderr, Tail: "all", RawTerminal: false, Stdout: true, @@ -40,29 +34,45 @@ func SaveLog( Timestamps: false, }, ) +} + +func SaveLog( + pool *dockertest.Pool, + resource *dockertest.Resource, + basePath string, +) (string, string, error) { + err := os.MkdirAll(basePath, os.ModePerm) if err != nil { - return err + return "", "", err + } + + var stdout, stderr bytes.Buffer + err = WriteLog(pool, resource, &stdout, &stderr) + if err != nil { + return "", "", err } log.Printf("Saving logs for %s to %s\n", resource.Container.Name, basePath) + stdoutPath := path.Join(basePath, resource.Container.Name+".stdout.log") err = os.WriteFile( - path.Join(basePath, resource.Container.Name+".stdout.log"), + stdoutPath, stdout.Bytes(), filePerm, ) if err != nil { - return err + return "", "", err } + stderrPath := path.Join(basePath, resource.Container.Name+".stderr.log") err = os.WriteFile( - path.Join(basePath, resource.Container.Name+".stderr.log"), + stderrPath, stderr.Bytes(), filePerm, ) if err != nil { - return err + return "", "", err } - return nil + return stdoutPath, stderrPath, nil } 
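A minimal sketch (not part of the diff) of how callers can use the reworked log helpers above: `SaveLog` now returns the paths of the stdout/stderr files it wrote, while `WriteLog` streams the same container logs to arbitrary writers. The package placement and helper name are assumptions for illustration.

```go
package dockertestutil_test // illustrative placement, not part of the diff

import (
	"bytes"
	"log"

	"github.com/juanfont/headscale/integration/dockertestutil"
	"github.com/ory/dockertest/v3"
)

// saveContainerLogs is a hypothetical helper showing the reworked API.
func saveContainerLogs(pool *dockertest.Pool, resource *dockertest.Resource) {
	// SaveLog writes <name>.stdout.log and <name>.stderr.log under the base path
	// and returns the paths so callers can report where artifacts ended up.
	stdoutPath, stderrPath, err := dockertestutil.SaveLog(pool, resource, "/tmp/control")
	if err != nil {
		log.Printf("failed to save logs for %s: %s", resource.Container.Name, err)
		return
	}
	log.Printf("saved logs to %s and %s", stdoutPath, stderrPath)

	// WriteLog streams to arbitrary writers, e.g. in-memory buffers for assertions.
	var stdout, stderr bytes.Buffer
	if err := dockertestutil.WriteLog(pool, resource, &stdout, &stderr); err != nil {
		log.Printf("failed to stream logs for %s: %s", resource.Container.Name, err)
	}
}
```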
diff --git a/integration/dockertestutil/network.go b/integration/dockertestutil/network.go index 89fdc8ec..42483247 100644 --- a/integration/dockertestutil/network.go +++ b/integration/dockertestutil/network.go @@ -2,8 +2,11 @@ package dockertestutil import ( "errors" + "fmt" + "log" "net" + "github.com/juanfont/headscale/hscontrol/util" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" ) @@ -12,7 +15,10 @@ var ErrContainerNotFound = errors.New("container not found") func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (*dockertest.Network, error) { networks, err := pool.NetworksByName(name) - if err != nil || len(networks) == 0 { + if err != nil { + return nil, fmt.Errorf("looking up network names: %w", err) + } + if len(networks) == 0 { if _, err := pool.CreateNetwork(name); err == nil { // Create does not give us an updated version of the resource, so we need to // get it again. @@ -22,6 +28,8 @@ func GetFirstOrCreateNetwork(pool *dockertest.Pool, name string) (*dockertest.Ne } return &networks[0], nil + } else { + return nil, fmt.Errorf("creating network: %w", err) } } @@ -50,7 +58,7 @@ func AddContainerToNetwork( return err } - // TODO(kradalby): This doesnt work reliably, but calling the exact same functions + // TODO(kradalby): This doesn't work reliably, but calling the exact same functions // seem to work fine... // if container, ok := pool.ContainerByName("/" + testContainer); ok { // err := container.ConnectToNetwork(network) @@ -78,3 +86,89 @@ func RandomFreeHostPort() (int, error) { //nolint:forcetypeassert return listener.Addr().(*net.TCPAddr).Port, nil } + +// CleanUnreferencedNetworks removes networks that are not referenced by any containers. +func CleanUnreferencedNetworks(pool *dockertest.Pool) error { + filter := "name=hs-" + networks, err := pool.NetworksByName(filter) + if err != nil { + return fmt.Errorf("getting networks by filter %q: %w", filter, err) + } + + for _, network := range networks { + if len(network.Network.Containers) == 0 { + err := pool.RemoveNetwork(&network) + if err != nil { + log.Printf("removing network %s: %s", network.Network.Name, err) + } + } + } + + return nil +} + +// CleanImagesInCI removes images if running in CI. +// It only removes dangling (untagged) images to avoid forcing rebuilds. +// Tagged images (golang:*, tailscale/tailscale:*, etc.) are automatically preserved. +func CleanImagesInCI(pool *dockertest.Pool) error { + if !util.IsCI() { + log.Println("Skipping image cleanup outside of CI") + return nil + } + + images, err := pool.Client.ListImages(docker.ListImagesOptions{}) + if err != nil { + return fmt.Errorf("getting images: %w", err) + } + + removedCount := 0 + for _, image := range images { + // Only remove dangling (untagged) images to avoid forcing rebuilds + // Dangling images have no RepoTags or only have "<none>:<none>" + if len(image.RepoTags) == 0 || (len(image.RepoTags) == 1 && image.RepoTags[0] == "<none>:<none>") { + log.Printf("Removing dangling image: %s", image.ID[:12]) + + err := pool.Client.RemoveImage(image.ID) + if err != nil { + log.Printf("Warning: failed to remove image %s: %v", image.ID[:12], err) + } else { + removedCount++ + } + } + } + + if removedCount > 0 { + log.Printf("Removed %d dangling images in CI", removedCount) + } else { + log.Println("No dangling images to remove in CI") + } + + return nil +} + +// DockerRestartPolicy sets the restart policy for containers. 
+func DockerRestartPolicy(config *docker.HostConfig) { + config.RestartPolicy = docker.RestartPolicy{ + Name: "unless-stopped", + } +} + +// DockerAllowLocalIPv6 allows IPv6 traffic within the container. +func DockerAllowLocalIPv6(config *docker.HostConfig) { + config.NetworkMode = "default" + config.Sysctls = map[string]string{ + "net.ipv6.conf.all.disable_ipv6": "0", + } +} + +// DockerAllowNetworkAdministration gives the container network administration capabilities. +func DockerAllowNetworkAdministration(config *docker.HostConfig) { + config.CapAdd = append(config.CapAdd, "NET_ADMIN") + config.Privileged = true +} + +// DockerMemoryLimit sets memory limit and disables OOM kill for containers. +func DockerMemoryLimit(config *docker.HostConfig) { + config.Memory = 2 * 1024 * 1024 * 1024 // 2GB in bytes + config.OOMKillDisable = true +} diff --git a/integration/dsic/dsic.go b/integration/dsic/dsic.go new file mode 100644 index 00000000..d8a77575 --- /dev/null +++ b/integration/dsic/dsic.go @@ -0,0 +1,366 @@ +package dsic + +import ( + "crypto/tls" + "errors" + "fmt" + "log" + "net" + "net/http" + "strconv" + "strings" + "time" + + "github.com/juanfont/headscale/hscontrol/util" + "github.com/juanfont/headscale/integration/dockertestutil" + "github.com/juanfont/headscale/integration/integrationutil" + "github.com/ory/dockertest/v3" + "github.com/ory/dockertest/v3/docker" +) + +const ( + dsicHashLength = 6 + dockerContextPath = "../." + caCertRoot = "/usr/local/share/ca-certificates" + DERPerCertRoot = "/usr/local/share/derper-certs" + dockerExecuteTimeout = 60 * time.Second +) + +var errDERPerStatusCodeNotOk = errors.New("DERPer status code not OK") + +// DERPServerInContainer represents DERP Server in Container (DSIC). +type DERPServerInContainer struct { + version string + hostname string + + pool *dockertest.Pool + container *dockertest.Resource + networks []*dockertest.Network + + stunPort int + derpPort int + caCerts [][]byte + tlsCert []byte + tlsKey []byte + withExtraHosts []string + withVerifyClientURL string + workdir string +} + +// Option represent optional settings that can be given to a +// DERPer instance. +type Option = func(c *DERPServerInContainer) + +// WithCACert adds it to the trusted surtificate of the Tailscale container. +func WithCACert(cert []byte) Option { + return func(dsic *DERPServerInContainer) { + dsic.caCerts = append(dsic.caCerts, cert) + } +} + +// WithOrCreateNetwork sets the Docker container network to use with +// the DERPer instance, if the parameter is nil, a new network, +// isolating the DERPer, will be created. If a network is +// passed, the DERPer instance will join the given network. +func WithOrCreateNetwork(network *dockertest.Network) Option { + return func(dsic *DERPServerInContainer) { + if network != nil { + dsic.networks = append(dsic.networks, network) + + return + } + + network, err := dockertestutil.GetFirstOrCreateNetwork( + dsic.pool, + dsic.hostname+"-network", + ) + if err != nil { + log.Fatalf("failed to create network: %s", err) + } + + dsic.networks = append(dsic.networks, network) + } +} + +// WithDockerWorkdir allows the docker working directory to be set. +func WithDockerWorkdir(dir string) Option { + return func(tsic *DERPServerInContainer) { + tsic.workdir = dir + } +} + +// WithVerifyClientURL sets the URL to verify the client. +func WithVerifyClientURL(url string) Option { + return func(tsic *DERPServerInContainer) { + tsic.withVerifyClientURL = url + } +} + +// WithExtraHosts adds extra hosts to the container. 
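+// Each entry uses Docker's "hostname:IP" form, e.g. "derper.example.test:172.18.0.10" (illustrative).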
+func WithExtraHosts(hosts []string) Option { + return func(tsic *DERPServerInContainer) { + tsic.withExtraHosts = hosts + } +} + +// buildEntrypoint builds the container entrypoint command based on configuration. +// It constructs proper wait conditions instead of fixed sleeps: +// 1. Wait for network to be ready +// 2. Wait for TLS cert to be written (always written after container start) +// 3. Wait for CA certs if configured +// 4. Update CA certificates +// 5. Run derper with provided arguments. +func (dsic *DERPServerInContainer) buildEntrypoint(derperArgs string) []string { + var commands []string + + // Wait for network to be ready + commands = append(commands, "while ! ip route show default >/dev/null 2>&1; do sleep 0.1; done") + + // Wait for TLS cert to be written (always written after container start) + commands = append(commands, + fmt.Sprintf("while [ ! -f %s/%s.crt ]; do sleep 0.1; done", DERPerCertRoot, dsic.hostname)) + + // If CA certs are configured, wait for them to be written + if len(dsic.caCerts) > 0 { + commands = append(commands, + fmt.Sprintf("while [ ! -f %s/user-0.crt ]; do sleep 0.1; done", caCertRoot)) + } + + // Update CA certificates + commands = append(commands, "update-ca-certificates") + + // Run derper + commands = append(commands, "derper "+derperArgs) + + return []string{"/bin/sh", "-c", strings.Join(commands, " ; ")} +} + +// New returns a new TailscaleInContainer instance. +func New( + pool *dockertest.Pool, + version string, + networks []*dockertest.Network, + opts ...Option, +) (*DERPServerInContainer, error) { + hash, err := util.GenerateRandomStringDNSSafe(dsicHashLength) + if err != nil { + return nil, err + } + + // Include run ID in hostname for easier identification of which test run owns this container + runID := dockertestutil.GetIntegrationRunID() + + var hostname string + + if runID != "" { + // Use last 6 chars of run ID (the random hash part) for brevity + runIDShort := runID[len(runID)-6:] + hostname = fmt.Sprintf("derp-%s-%s-%s", runIDShort, strings.ReplaceAll(version, ".", "-"), hash) + } else { + hostname = fmt.Sprintf("derp-%s-%s", strings.ReplaceAll(version, ".", "-"), hash) + } + tlsCert, tlsKey, err := integrationutil.CreateCertificate(hostname) + if err != nil { + return nil, fmt.Errorf("failed to create certificates for headscale test: %w", err) + } + dsic := &DERPServerInContainer{ + version: version, + hostname: hostname, + pool: pool, + networks: networks, + tlsCert: tlsCert, + tlsKey: tlsKey, + stunPort: 3478, //nolint + derpPort: 443, //nolint + } + + for _, opt := range opts { + opt(dsic) + } + + var cmdArgs strings.Builder + fmt.Fprintf(&cmdArgs, "--hostname=%s", hostname) + fmt.Fprintf(&cmdArgs, " --certmode=manual") + fmt.Fprintf(&cmdArgs, " --certdir=%s", DERPerCertRoot) + fmt.Fprintf(&cmdArgs, " --a=:%d", dsic.derpPort) + fmt.Fprintf(&cmdArgs, " --stun=true") + fmt.Fprintf(&cmdArgs, " --stun-port=%d", dsic.stunPort) + if dsic.withVerifyClientURL != "" { + fmt.Fprintf(&cmdArgs, " --verify-client-url=%s", dsic.withVerifyClientURL) + } + + runOptions := &dockertest.RunOptions{ + Name: hostname, + Networks: dsic.networks, + ExtraHosts: dsic.withExtraHosts, + Entrypoint: dsic.buildEntrypoint(cmdArgs.String()), + ExposedPorts: []string{ + "80/tcp", + fmt.Sprintf("%d/tcp", dsic.derpPort), + fmt.Sprintf("%d/udp", dsic.stunPort), + }, + } + + if dsic.workdir != "" { + runOptions.WorkingDir = dsic.workdir + } + + // dockertest isn't very good at handling containers that has already + // been created, this is an attempt to 
make sure this container isn't + // present. + err = pool.RemoveContainerByName(hostname) + if err != nil { + return nil, err + } + + var container *dockertest.Resource + buildOptions := &dockertest.BuildOptions{ + Dockerfile: "Dockerfile.derper", + ContextDir: dockerContextPath, + BuildArgs: []docker.BuildArg{}, + } + switch version { + case "head": + buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{ + Name: "VERSION_BRANCH", + Value: "main", + }) + default: + buildOptions.BuildArgs = append(buildOptions.BuildArgs, docker.BuildArg{ + Name: "VERSION_BRANCH", + Value: "v" + version, + }) + } + // Add integration test labels if running under hi tool + dockertestutil.DockerAddIntegrationLabels(runOptions, "derp") + + container, err = pool.BuildAndRunWithBuildOptions( + buildOptions, + runOptions, + dockertestutil.DockerRestartPolicy, + dockertestutil.DockerAllowLocalIPv6, + dockertestutil.DockerAllowNetworkAdministration, + ) + if err != nil { + return nil, fmt.Errorf( + "%s could not start tailscale DERPer container (version: %s): %w", + hostname, + version, + err, + ) + } + log.Printf("Created %s container\n", hostname) + + dsic.container = container + + for i, cert := range dsic.caCerts { + err = dsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert) + if err != nil { + return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) + } + } + if len(dsic.tlsCert) != 0 { + err = dsic.WriteFile(fmt.Sprintf("%s/%s.crt", DERPerCertRoot, dsic.hostname), dsic.tlsCert) + if err != nil { + return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) + } + } + if len(dsic.tlsKey) != 0 { + err = dsic.WriteFile(fmt.Sprintf("%s/%s.key", DERPerCertRoot, dsic.hostname), dsic.tlsKey) + if err != nil { + return nil, fmt.Errorf("failed to write TLS key to container: %w", err) + } + } + + return dsic, nil +} + +// Shutdown stops and cleans up the DERPer container. +func (t *DERPServerInContainer) Shutdown() error { + err := t.SaveLog("/tmp/control") + if err != nil { + log.Printf( + "Failed to save log from %s: %s", + t.hostname, + fmt.Errorf("failed to save log: %w", err), + ) + } + + return t.pool.Purge(t.container) +} + +// GetCert returns the TLS certificate of the DERPer instance. +func (t *DERPServerInContainer) GetCert() []byte { + return t.tlsCert +} + +// Hostname returns the hostname of the DERPer instance. +func (t *DERPServerInContainer) Hostname() string { + return t.hostname +} + +// Version returns the running DERPer version of the instance. +func (t *DERPServerInContainer) Version() string { + return t.version +} + +// ID returns the Docker container ID of the DERPServerInContainer +// instance. +func (t *DERPServerInContainer) ID() string { + return t.container.Container.ID +} + +func (t *DERPServerInContainer) GetHostname() string { + return t.hostname +} + +// GetSTUNPort returns the STUN port of the DERPer instance. +func (t *DERPServerInContainer) GetSTUNPort() int { + return t.stunPort +} + +// GetDERPPort returns the DERP port of the DERPer instance. +func (t *DERPServerInContainer) GetDERPPort() int { + return t.derpPort +} + +// WaitForRunning blocks until the DERPer instance is ready to be used. 
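+// It polls the DERPer's HTTPS endpoint (with certificate verification disabled)
+// until it responds with 200 OK.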
+func (t *DERPServerInContainer) WaitForRunning() error { + url := "https://" + net.JoinHostPort(t.GetHostname(), strconv.Itoa(t.GetDERPPort())) + "/" + log.Printf("waiting for DERPer to be ready at %s", url) + + insecureTransport := http.DefaultTransport.(*http.Transport).Clone() //nolint + insecureTransport.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} //nolint + client := &http.Client{Transport: insecureTransport} + + return t.pool.Retry(func() error { + resp, err := client.Get(url) //nolint + if err != nil { + return fmt.Errorf("headscale is not ready: %w", err) + } + + if resp.StatusCode != http.StatusOK { + return errDERPerStatusCodeNotOk + } + + return nil + }) +} + +// ConnectToNetwork connects the DERPer instance to a network. +func (t *DERPServerInContainer) ConnectToNetwork(network *dockertest.Network) error { + return t.container.ConnectToNetwork(network) +} + +// WriteFile save file inside the container. +func (t *DERPServerInContainer) WriteFile(path string, data []byte) error { + return integrationutil.WriteFileToContainer(t.pool, t.container, path, data) +} + +// SaveLog saves the current stdout log of the container to a path +// on the host system. +func (t *DERPServerInContainer) SaveLog(path string) error { + _, _, err := dockertestutil.SaveLog(t.pool, t.container, path) + + return err +} diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index 3a407496..89154f63 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -1,231 +1,209 @@ package integration import ( - "fmt" - "log" - "net/url" "testing" + "time" - "github.com/juanfont/headscale/hscontrol/util" - "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" - "github.com/ory/dockertest/v3" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/key" ) -type EmbeddedDERPServerScenario struct { - *Scenario - - tsicNetworks map[string]*dockertest.Network +type ClientsSpec struct { + Plain int + WebsocketDERP int } func TestDERPServerScenario(t *testing.T) { + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2", "user3"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + "usernet3": {"user3"}, + }, + } + + derpServerScenario(t, spec, false, func(scenario *Scenario) { + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + t.Logf("checking %d clients for websocket connections", len(allClients)) + + for _, client := range allClients { + if didClientUseWebsocketForDERP(t, client) { + t.Logf( + "client %q used websocket a connection, but was not expected to", + client.Hostname(), + ) + t.Fail() + } + } + + hsServer, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + derpRegion := tailcfg.DERPRegion{ + RegionCode: "test-derpverify", + RegionName: "TestDerpVerify", + Nodes: []*tailcfg.DERPNode{ + { + Name: "TestDerpVerify", + RegionID: 900, + HostName: hsServer.GetHostname(), + STUNPort: 3478, + STUNOnly: false, + DERPPort: 443, + InsecureForTests: true, + }, + }, + } + + fakeKey := key.NewNode() + DERPVerify(t, fakeKey, derpRegion, false) + }) +} + +func TestDERPServerWebsocketScenario(t *testing.T) { + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2", "user3"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + 
"usernet3": {"user3"}, + }, + } + + derpServerScenario(t, spec, true, func(scenario *Scenario) { + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + t.Logf("checking %d clients for websocket connections", len(allClients)) + + for _, client := range allClients { + if !didClientUseWebsocketForDERP(t, client) { + t.Logf( + "client %q does not seem to have used a websocket connection, even though it was expected to do so", + client.Hostname(), + ) + t.Fail() + } + } + }) +} + +// This function implements the common parts of a DERP scenario, +// we *want* it to show up in stacktraces, +// so marking it as a test helper would be counterproductive. +// +//nolint:thelper +func derpServerScenario( + t *testing.T, + spec ScenarioSpec, + websocket bool, + furtherAssertions ...func(*Scenario), +) { IntegrationSkip(t) - // t.Parallel() - baseScenario, err := NewScenario() - assertNoErr(t, err) + scenario, err := NewScenario(spec) + require.NoError(t, err) - scenario := EmbeddedDERPServerScenario{ - Scenario: baseScenario, - tsicNetworks: map[string]*dockertest.Network{}, - } - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": len(MustTestVersions), - } - - headscaleConfig := map[string]string{} - headscaleConfig["HEADSCALE_DERP_URLS"] = "" - headscaleConfig["HEADSCALE_DERP_SERVER_ENABLED"] = "true" - headscaleConfig["HEADSCALE_DERP_SERVER_REGION_ID"] = "999" - headscaleConfig["HEADSCALE_DERP_SERVER_REGION_CODE"] = "headscale" - headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP" - headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478" - headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key" - // Envknob for enabling DERP debug logs - headscaleConfig["DERP_DEBUG_LOGS"] = "true" - headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true" + defer scenario.ShutdownAssertNoPanics(t) err = scenario.CreateHeadscaleEnv( - spec, - hsic.WithConfigEnv(headscaleConfig), + []tsic.Option{ + tsic.WithWebsocketDERP(websocket), + }, hsic.WithTestName("derpserver"), hsic.WithExtraPorts([]string{"3478/udp"}), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithPort(443), hsic.WithTLS(), - hsic.WithHostnameAsServerURL(), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_DERP_AUTO_UPDATE_ENABLED": "true", + "HEADSCALE_DERP_UPDATE_FREQUENCY": "10s", + "HEADSCALE_LISTEN_ADDR": "0.0.0.0:443", + "HEADSCALE_DERP_SERVER_VERIFY_CLIENTS": "true", + }), ) - assertNoErrHeadscaleEnv(t, err) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) - - allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allHostnames, err := scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) + + for _, client := range allClients { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) + + for _, health := range status.Health { + assert.NotContains(ct, health, "could not connect to any relay server", + "Client %s should be connected to DERP relay", client.Hostname()) + assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.", + "Client %s should be connected to Headscale Embedded DERP", client.Hostname()) + } + }, 
30*time.Second, 2*time.Second) + } success := pingDerpAllHelper(t, allClients, allHostnames) + if len(allHostnames)*len(allClients) > success { + t.FailNow() - t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) -} - -func (s *EmbeddedDERPServerScenario) CreateHeadscaleEnv( - users map[string]int, - opts ...hsic.Option, -) error { - hsServer, err := s.Headscale(opts...) - if err != nil { - return err + return } - headscaleEndpoint := hsServer.GetEndpoint() - headscaleURL, err := url.Parse(headscaleEndpoint) - if err != nil { - return err - } + for _, client := range allClients { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) - headscaleURL.Host = fmt.Sprintf("%s:%s", hsServer.GetHostname(), headscaleURL.Port()) - - err = hsServer.WaitForRunning() - if err != nil { - return err - } - - hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) - if err != nil { - return err - } - - for userName, clientCount := range users { - err = s.CreateUser(userName) - if err != nil { - return err - } - - err = s.CreateTailscaleIsolatedNodesInUser( - hash, - userName, - "all", - clientCount, - ) - if err != nil { - return err - } - - key, err := s.CreatePreAuthKey(userName, true, false) - if err != nil { - return err - } - - err = s.RunTailscaleUp(userName, headscaleURL.String(), key.GetKey()) - if err != nil { - return err - } - } - - return nil -} - -func (s *EmbeddedDERPServerScenario) CreateTailscaleIsolatedNodesInUser( - hash string, - userStr string, - requestedVersion string, - count int, - opts ...tsic.Option, -) error { - hsServer, err := s.Headscale() - if err != nil { - return err - } - - if user, ok := s.users[userStr]; ok { - for clientN := 0; clientN < count; clientN++ { - networkName := fmt.Sprintf("tsnet-%s-%s-%d", - hash, - userStr, - clientN, - ) - network, err := dockertestutil.GetFirstOrCreateNetwork( - s.pool, - networkName, - ) - if err != nil { - return fmt.Errorf("failed to create or get %s network: %w", networkName, err) + for _, health := range status.Health { + assert.NotContains(ct, health, "could not connect to any relay server", + "Client %s should be connected to DERP relay after first run", client.Hostname()) + assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.", + "Client %s should be connected to Headscale Embedded DERP after first run", client.Hostname()) } - - s.tsicNetworks[networkName] = network - - err = hsServer.ConnectToNetwork(network) - if err != nil { - return fmt.Errorf("failed to connect headscale to %s network: %w", networkName, err) - } - - version := requestedVersion - if requestedVersion == "all" { - version = MustTestVersions[clientN%len(MustTestVersions)] - } - - cert := hsServer.GetCert() - - opts = append(opts, - tsic.WithHeadscaleTLS(cert), - ) - - user.createWaitGroup.Go(func() error { - tsClient, err := tsic.New( - s.pool, - version, - network, - opts..., - ) - if err != nil { - return fmt.Errorf( - "failed to create tailscale (%s) node: %w", - tsClient.Hostname(), - err, - ) - } - - err = tsClient.WaitForNeedsLogin() - if err != nil { - return fmt.Errorf( - "failed to wait for tailscaled (%s) to need login: %w", - tsClient.Hostname(), - err, - ) - } - - s.mu.Lock() - user.Clients[tsClient.Hostname()] = tsClient - s.mu.Unlock() - - return nil - }) - } - - if err := user.createWaitGroup.Wait(); err != nil { - return err - } - - return nil + }, 
30*time.Second, 2*time.Second) } - return fmt.Errorf("failed to add tailscale nodes: %w", errNoUserAvailable) -} + t.Logf("Run 1: %d successful pings out of %d", success, len(allClients)*len(allHostnames)) -func (s *EmbeddedDERPServerScenario) Shutdown() { - for _, network := range s.tsicNetworks { - err := s.pool.RemoveNetwork(network) - if err != nil { - log.Printf("failed to remove DERP network %s", network.Network.Name) - } + // Let the DERP updater run a couple of times to ensure it does not + // break the DERPMap. The updater runs on a 10s interval by default. + //nolint:forbidigo // Intentional delay: must wait for DERP updater to run multiple times (interval-based) + time.Sleep(30 * time.Second) + + success = pingDerpAllHelper(t, allClients, allHostnames) + if len(allHostnames)*len(allClients) > success { + t.Fail() } - s.Scenario.Shutdown() + for _, client := range allClients { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err, "Failed to get status for client %s", client.Hostname()) + + for _, health := range status.Health { + assert.NotContains(ct, health, "could not connect to any relay server", + "Client %s should be connected to DERP relay after second run", client.Hostname()) + assert.NotContains(ct, health, "could not connect to the 'Headscale Embedded DERP' relay server.", + "Client %s should be connected to Headscale Embedded DERP after second run", client.Hostname()) + } + }, 30*time.Second, 2*time.Second) + } + + t.Logf("Run2: %d successful pings out of %d", success, len(allClients)*len(allHostnames)) + + for _, check := range furtherAssertions { + check(scenario) + } } diff --git a/integration/general_test.go b/integration/general_test.go index 2e0f7fe6..f44a0f03 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -1,122 +1,120 @@ package integration import ( + "context" "encoding/json" "fmt" "net/netip" + "strconv" "strings" "testing" "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/tsic" "github.com/rs/zerolog/log" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" "tailscale.com/client/tailscale/apitype" "tailscale.com/types/key" ) func TestPingAllByIP(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + MaxWait: dockertestMaxWait(), } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyip")) - assertNoErrHeadscaleEnv(t, err) + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("pingallbyip"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + hsic.WithIPAllocationStrategy(types.IPAllocationStrategyRandom), + ) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := 
scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) + + hs, err := scenario.Headscale() + require.NoError(t, err) + + // Extract node IDs for validation + expectedNodes := make([]types.NodeID, 0, len(allClients)) + for _, client := range allClients { + status := client.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + require.NoError(t, err, "failed to parse node ID") + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + } + requireAllClientsOnline(t, hs, expectedNodes, true, "all clients should be online across all systems", 30*time.Second) + + // assertClientsState(t, allClients) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() }) + // Get headscale instance for batcher debug check + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Test our DebugBatcher functionality + t.Logf("Testing DebugBatcher functionality...") + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to the batcher", 30*time.Second) + success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) } -func TestAuthKeyLogoutAndRelogin(t *testing.T) { +func TestPingAllByIPPublicDERP(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyip")) - assertNoErrHeadscaleEnv(t, err) + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("pingallbyippubderp"), + ) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) - - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - - clientIPs := make(map[TailscaleClient][]netip.Addr) - for _, client := range allClients { - ips, err := client.IPs() - if err != nil { - t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) - } - clientIPs[client] = ips - } - - for _, client := range allClients { - err := client.Logout() - if err != nil { - t.Fatalf("failed to logout client %s: %s", client.Hostname(), err) - } - } - - err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) - - t.Logf("all clients logged out") - - headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) - - for userName := range spec { - key, err := scenario.CreatePreAuthKey(userName, true, false) - if err != nil { - t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) - } - - err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) - if err != nil { - t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) - } - } - - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - - allClients, err = scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) + + 
err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -124,68 +122,47 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) { success := pingAllHelper(t, allClients, allAddrs) t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) - - for _, client := range allClients { - ips, err := client.IPs() - if err != nil { - t.Fatalf("failed to get IPs for client %s: %s", client.Hostname(), err) - } - - // lets check if the IPs are the same - if len(ips) != len(clientIPs[client]) { - t.Fatalf("IPs changed for client %s", client.Hostname()) - } - - for _, ip := range ips { - found := false - for _, oldIP := range clientIPs[client] { - if ip == oldIP { - found = true - - break - } - } - - if !found { - t.Fatalf( - "IPs changed for client %s. Used to be %v now %v", - client.Hostname(), - clientIPs[client], - ips, - ) - } - } - } } func TestEphemeral(t *testing.T) { + testEphemeralWithOptions(t, hsic.WithTestName("ephemeral")) +} + +func TestEphemeralInAlternateTimezone(t *testing.T) { + testEphemeralWithOptions( + t, + hsic.WithTestName("ephemeral-tz"), + hsic.WithTimezone("America/Los_Angeles"), + ) +} + +func testEphemeralWithOptions(t *testing.T, opts ...hsic.Option) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": len(MustTestVersions), - "user2": len(MustTestVersions), + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, } - headscale, err := scenario.Headscale(hsic.WithTestName("ephemeral")) - assertNoErrHeadscaleEnv(t, err) + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) - for userName, clientCount := range spec { - err = scenario.CreateUser(userName) + headscale, err := scenario.Headscale(opts...) + requireNoErrHeadscaleEnv(t, err) + + for _, userName := range spec.Users { + user, err := scenario.CreateUser(userName) if err != nil { t.Fatalf("failed to create user %s: %s", userName, err) } - err = scenario.CreateTailscaleNodesInUser(userName, "all", clientCount, []tsic.Option{}...) 
+ err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) if err != nil { t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) } - key, err := scenario.CreatePreAuthKey(userName, true, true) + key, err := scenario.CreatePreAuthKey(user.GetId(), true, true) if err != nil { t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) } @@ -197,13 +174,13 @@ func TestEphemeral(t *testing.T) { } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -220,12 +197,126 @@ func TestEphemeral(t *testing.T) { } err = scenario.WaitForTailscaleLogout() - assertNoErrLogout(t, err) + requireNoErrLogout(t, err) t.Logf("all clients logged out") - for userName := range spec { - nodes, err := headscale.ListNodesInUser(userName) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, nodes, 0, "All ephemeral nodes should be cleaned up after logout") + }, 30*time.Second, 2*time.Second) +} + +// TestEphemeral2006DeletedTooQuickly verifies that ephemeral nodes are not +// deleted by accident if they are still online and active. +func TestEphemeral2006DeletedTooQuickly(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + headscale, err := scenario.Headscale( + hsic.WithTestName("ephemeral2006"), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "1m6s", + }), + ) + requireNoErrHeadscaleEnv(t, err) + + for _, userName := range spec.Users { + user, err := scenario.CreateUser(userName) + if err != nil { + t.Fatalf("failed to create user %s: %s", userName, err) + } + + err = scenario.CreateTailscaleNodesInUser(userName, "all", spec.NodesPerUser, tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) + if err != nil { + t.Fatalf("failed to create tailscale nodes in user %s: %s", userName, err) + } + + key, err := scenario.CreatePreAuthKey(user.GetId(), true, true) + if err != nil { + t.Fatalf("failed to create pre-auth key for user %s: %s", userName, err) + } + + err = scenario.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + t.Fatalf("failed to run tailscale up for user %s: %s", userName, err) + } + } + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + // All ephemeral nodes should be online and reachable. + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + // Take down all clients, this should start an expiry timer for each. 
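+	// The inactivity timeout is configured as 1m6s above, so bringing the clients
+	// back up promptly below should keep them from being expired here.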
+ for _, client := range allClients { + err := client.Down() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + + // Wait a bit and bring up the clients again before the expiry + // time of the ephemeral nodes. + // Nodes should be able to reconnect and work fine. + for _, client := range allClients { + err := client.Up() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + + // Wait for clients to sync and be able to ping each other after reconnection + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err = scenario.WaitForTailscaleSync() + assert.NoError(ct, err) + + success = pingAllHelper(t, allClients, allAddrs) + assert.Greater(ct, success, 0, "Ephemeral nodes should be able to reconnect and ping") + }, 60*time.Second, 2*time.Second) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + // Take down all clients, this should start an expiry timer for each. + for _, client := range allClients { + err := client.Down() + if err != nil { + t.Fatalf("failed to take down client %s: %s", client.Hostname(), err) + } + } + + // This time wait for all of the nodes to expire and check that they are no longer + // registered. + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, userName := range spec.Users { + nodes, err := headscale.ListNodes(userName) + assert.NoError(ct, err) + assert.Len(ct, nodes, 0, "Ephemeral nodes should be expired and removed for user %s", userName) + } + }, 4*time.Minute, 10*time.Second) + + for _, userName := range spec.Users { + nodes, err := headscale.ListNodes(userName) if err != nil { log.Error(). Err(err). @@ -243,28 +334,29 @@ func TestEphemeral(t *testing.T) { func TestPingAllByHostname(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "user3": len(MustTestVersions), - "user4": len(MustTestVersions), + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyname")) - assertNoErrHeadscaleEnv(t, err) + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("pingallbyname")) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) allHostnames, err := scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) success := pingAllHelper(t, allClients, allHostnames) @@ -275,44 +367,124 @@ func TestPingAllByHostname(t *testing.T) { // This might mean we approach setup slightly wrong, but for now, ignore // the linter // nolint:tparallel +// TestTaildrop tests the Taildrop file sharing functionality across multiple scenarios: +// 1. Same-user transfers: Nodes owned by the same user can send files to each other +// 2. Cross-user transfers: Nodes owned by different users cannot send files to each other +// 3. Tagged device transfers: Tagged devices cannot send nor receive files +// +// Each user gets len(MustTestVersions) nodes to ensure compatibility across all supported versions. 
func TestTaildrop(t *testing.T) { IntegrationSkip(t) - t.Parallel() - retry := func(times int, sleepInverval time.Duration, doWork func() error) error { - var err error - for attempts := 0; attempts < times; attempts++ { - err = doWork() - if err == nil { - return nil - } - time.Sleep(sleepInverval) - } - - return err + spec := ScenarioSpec{ + NodesPerUser: 0, // We'll create nodes manually to control tags + Users: []string{"user1", "user2"}, } - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) - spec := map[string]int{ - "taildrop": len(MustTestVersions), + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, + hsic.WithTestName("taildrop"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + networks := scenario.Networks() + require.NotEmpty(t, networks, "scenario should have at least one network") + network := networks[0] + + // Create untagged nodes for user1 using all test versions + user1Key, err := scenario.CreatePreAuthKey(userMap["user1"].GetId(), true, false) + require.NoError(t, err) + + var user1Clients []TailscaleClient + for i, version := range MustTestVersions { + t.Logf("Creating user1 client %d with version %s", i, version) + client, err := scenario.CreateTailscaleNode( + version, + tsic.WithNetwork(network), + ) + require.NoError(t, err) + + err = client.Login(headscale.GetEndpoint(), user1Key.GetKey()) + require.NoError(t, err) + + err = client.WaitForRunning(integrationutil.PeerSyncTimeout()) + require.NoError(t, err) + + user1Clients = append(user1Clients, client) + scenario.GetOrCreateUser("user1").Clients[client.Hostname()] = client } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("taildrop")) - assertNoErrHeadscaleEnv(t, err) + // Create untagged nodes for user2 using all test versions + user2Key, err := scenario.CreatePreAuthKey(userMap["user2"].GetId(), true, false) + require.NoError(t, err) + + var user2Clients []TailscaleClient + for i, version := range MustTestVersions { + t.Logf("Creating user2 client %d with version %s", i, version) + client, err := scenario.CreateTailscaleNode( + version, + tsic.WithNetwork(network), + ) + require.NoError(t, err) + + err = client.Login(headscale.GetEndpoint(), user2Key.GetKey()) + require.NoError(t, err) + + err = client.WaitForRunning(integrationutil.PeerSyncTimeout()) + require.NoError(t, err) + + user2Clients = append(user2Clients, client) + scenario.GetOrCreateUser("user2").Clients[client.Hostname()] = client + } + + // Create a tagged device (tags-as-identity: tags come from PreAuthKey) + // Use "head" version to test latest behavior + taggedKey, err := scenario.CreatePreAuthKeyWithTags(userMap["user1"].GetId(), true, false, []string{"tag:server"}) + require.NoError(t, err) + + taggedClient, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(network), + ) + require.NoError(t, err) + + err = taggedClient.Login(headscale.GetEndpoint(), taggedKey.GetKey()) + require.NoError(t, err) + + err = taggedClient.WaitForRunning(integrationutil.PeerSyncTimeout()) + require.NoError(t, err) + + // Add tagged client to user1 for tracking (though it's tagged, not user-owned) + scenario.GetOrCreateUser("user1").Clients[taggedClient.Hostname()] = taggedClient allClients, 
err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) + + // Expected: len(MustTestVersions) for user1 + len(MustTestVersions) for user2 + 1 tagged + expectedClientCount := len(MustTestVersions)*2 + 1 + require.Len(t, allClients, expectedClientCount, + "should have %d clients: %d user1 + %d user2 + 1 tagged", + expectedClientCount, len(MustTestVersions), len(MustTestVersions)) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) - // This will essentially fetch and cache all the FQDNs + // Cache FQDNs _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) + // Install curl on all clients for _, client := range allClients { if !strings.Contains(client.Hostname(), "head") { command := []string{"apk", "add", "curl"} @@ -320,210 +492,474 @@ func TestTaildrop(t *testing.T) { if err != nil { t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err) } - } - curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"} - err = retry(10, 1*time.Second, func() error { - result, _, err := client.Execute(curlCommand) - if err != nil { - return err - } - var fts []apitype.FileTarget - err = json.Unmarshal([]byte(result), &fts) - if err != nil { - return err - } + } - if len(fts) != len(allClients)-1 { - ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname()) - for _, ft := range fts { - ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name) - } - return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr) - } - - return err - }) + // Helper to get FileTargets for a client. 
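+	// FileTargets are read from tailscaled's LocalAPI (/localapi/v0/file-targets)
+	// over the unix socket, using the curl installed above.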
+ getFileTargets := func(client TailscaleClient) ([]apitype.FileTarget, error) { + curlCommand := []string{ + "curl", + "--unix-socket", + "/var/run/tailscale/tailscaled.sock", + "http://local-tailscaled.sock/localapi/v0/file-targets", + } + result, _, err := client.Execute(curlCommand) if err != nil { - t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err) + return nil, err } + + var fts []apitype.FileTarget + if err := json.Unmarshal([]byte(result), &fts); err != nil { + return nil, fmt.Errorf("failed to parse file-targets response: %w (response: %s)", err, result) + } + + return fts, nil } - for _, client := range allClients { - command := []string{"touch", fmt.Sprintf("/tmp/file_from_%s", client.Hostname())} - - if _, _, err := client.Execute(command); err != nil { - t.Fatalf("failed to create taildrop file on %s, err: %s", client.Hostname(), err) - } - - for _, peer := range allClients { - if client.Hostname() == peer.Hostname() { - continue + // Helper to check if a client is in the FileTargets list + isInFileTargets := func(fts []apitype.FileTarget, targetHostname string) bool { + for _, ft := range fts { + if strings.Contains(ft.Node.Name, targetHostname) { + return true } + } + return false + } - // It is safe to ignore this error as we handled it when caching it - peerFQDN, _ := peer.FQDN() + // Test 1: Verify user1 nodes can see each other in FileTargets but not user2 nodes or tagged node + t.Run("FileTargets-user1", func(t *testing.T) { + for _, client := range user1Clients { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + fts, err := getFileTargets(client) + assert.NoError(ct, err) - t.Run(fmt.Sprintf("%s-%s", client.Hostname(), peer.Hostname()), func(t *testing.T) { - command := []string{ - "tailscale", "file", "cp", - fmt.Sprintf("/tmp/file_from_%s", client.Hostname()), - fmt.Sprintf("%s:", peerFQDN), + // Should see the other user1 clients + for _, peer := range user1Clients { + if peer.Hostname() == client.Hostname() { + continue + } + assert.True(ct, isInFileTargets(fts, peer.Hostname()), + "user1 client %s should see user1 peer %s in FileTargets", client.Hostname(), peer.Hostname()) } - err := retry(10, 1*time.Second, func() error { - t.Logf( - "Sending file from %s to %s\n", - client.Hostname(), - peer.Hostname(), - ) - _, _, err := client.Execute(command) + // Should NOT see user2 clients + for _, peer := range user2Clients { + assert.False(ct, isInFileTargets(fts, peer.Hostname()), + "user1 client %s should NOT see user2 peer %s in FileTargets", client.Hostname(), peer.Hostname()) + } - return err + // Should NOT see tagged client + assert.False(ct, isInFileTargets(fts, taggedClient.Hostname()), + "user1 client %s should NOT see tagged client %s in FileTargets", client.Hostname(), taggedClient.Hostname()) + }, 10*time.Second, 1*time.Second) + } + }) + + // Test 2: Verify user2 nodes can see each other in FileTargets but not user1 nodes or tagged node + t.Run("FileTargets-user2", func(t *testing.T) { + for _, client := range user2Clients { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + fts, err := getFileTargets(client) + assert.NoError(ct, err) + + // Should see the other user2 clients + for _, peer := range user2Clients { + if peer.Hostname() == client.Hostname() { + continue + } + assert.True(ct, isInFileTargets(fts, peer.Hostname()), + "user2 client %s should see user2 peer %s in FileTargets", client.Hostname(), peer.Hostname()) + } + + // Should NOT see user1 clients + for _, peer := range user1Clients { + 
assert.False(ct, isInFileTargets(fts, peer.Hostname()), + "user2 client %s should NOT see user1 peer %s in FileTargets", client.Hostname(), peer.Hostname()) + } + + // Should NOT see tagged client + assert.False(ct, isInFileTargets(fts, taggedClient.Hostname()), + "user2 client %s should NOT see tagged client %s in FileTargets", client.Hostname(), taggedClient.Hostname()) + }, 10*time.Second, 1*time.Second) + } + }) + + // Test 3: Verify tagged device has no FileTargets (empty list) + t.Run("FileTargets-tagged", func(t *testing.T) { + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + fts, err := getFileTargets(taggedClient) + assert.NoError(ct, err) + assert.Empty(ct, fts, "tagged client %s should have no FileTargets", taggedClient.Hostname()) + }, 10*time.Second, 1*time.Second) + }) + + // Test 4: Same-user file transfer works (user1 -> user1) for all version combinations + t.Run("SameUserTransfer", func(t *testing.T) { + for _, sender := range user1Clients { + // Create file on sender + filename := fmt.Sprintf("file_from_%s", sender.Hostname()) + command := []string{"touch", fmt.Sprintf("/tmp/%s", filename)} + _, _, err := sender.Execute(command) + require.NoError(t, err, "failed to create taildrop file on %s", sender.Hostname()) + + for _, receiver := range user1Clients { + if sender.Hostname() == receiver.Hostname() { + continue + } + + receiverFQDN, _ := receiver.FQDN() + + t.Run(fmt.Sprintf("%s->%s", sender.Hostname(), receiver.Hostname()), func(t *testing.T) { + sendCommand := []string{ + "tailscale", "file", "cp", + fmt.Sprintf("/tmp/%s", filename), + fmt.Sprintf("%s:", receiverFQDN), + } + + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + t.Logf("Sending file from %s to %s", sender.Hostname(), receiver.Hostname()) + _, _, err := sender.Execute(sendCommand) + assert.NoError(ct, err) + }, 10*time.Second, 1*time.Second) }) - if err != nil { - t.Fatalf( - "failed to send taildrop file on %s with command %q, err: %s", - client.Hostname(), - strings.Join(command, " "), - err, - ) - } - }) - } - } - - for _, client := range allClients { - command := []string{ - "tailscale", "file", - "get", - "/tmp/", - } - if _, _, err := client.Execute(command); err != nil { - t.Fatalf("failed to get taildrop file on %s, err: %s", client.Hostname(), err) - } - - for _, peer := range allClients { - if client.Hostname() == peer.Hostname() { - continue } - - t.Run(fmt.Sprintf("%s-%s", client.Hostname(), peer.Hostname()), func(t *testing.T) { - command := []string{ - "ls", - fmt.Sprintf("/tmp/file_from_%s", peer.Hostname()), - } - log.Printf( - "Checking file in %s from %s\n", - client.Hostname(), - peer.Hostname(), - ) - - result, _, err := client.Execute(command) - assertNoErrf(t, "failed to execute command to ls taildrop: %s", err) - - log.Printf("Result for %s: %s\n", peer.Hostname(), result) - if fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()) != result { - t.Fatalf( - "taildrop result is not correct %s, wanted %s", - result, - fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()), - ) - } - }) } - } + + // Receive files on all user1 clients + for _, client := range user1Clients { + getCommand := []string{"tailscale", "file", "get", "/tmp/"} + _, _, err := client.Execute(getCommand) + require.NoError(t, err, "failed to get taildrop file on %s", client.Hostname()) + + // Verify files from all other user1 clients exist + for _, peer := range user1Clients { + if client.Hostname() == peer.Hostname() { + continue + } + + t.Run(fmt.Sprintf("verify-%s-received-from-%s", client.Hostname(), 
peer.Hostname()), func(t *testing.T) { + lsCommand := []string{"ls", fmt.Sprintf("/tmp/file_from_%s", peer.Hostname())} + result, _, err := client.Execute(lsCommand) + require.NoErrorf(t, err, "failed to ls taildrop file from %s", peer.Hostname()) + assert.Equal(t, fmt.Sprintf("/tmp/file_from_%s\n", peer.Hostname()), result) + }) + } + } + }) + + // Test 5: Cross-user file transfer fails (user1 -> user2) + t.Run("CrossUserTransferBlocked", func(t *testing.T) { + sender := user1Clients[0] + receiver := user2Clients[0] + + // Create file on sender + filename := fmt.Sprintf("cross_user_file_from_%s", sender.Hostname()) + command := []string{"touch", fmt.Sprintf("/tmp/%s", filename)} + _, _, err := sender.Execute(command) + require.NoError(t, err, "failed to create taildrop file on %s", sender.Hostname()) + + // Attempt to send file - this should fail + receiverFQDN, _ := receiver.FQDN() + sendCommand := []string{ + "tailscale", "file", "cp", + fmt.Sprintf("/tmp/%s", filename), + fmt.Sprintf("%s:", receiverFQDN), + } + + t.Logf("Attempting cross-user file send from %s to %s (should fail)", sender.Hostname(), receiver.Hostname()) + _, stderr, err := sender.Execute(sendCommand) + + // The file transfer should fail because user2 is not in user1's FileTargets + // Either the command errors, or it silently fails (check stderr for error message) + if err != nil { + t.Logf("Cross-user transfer correctly failed with error: %v", err) + } else if strings.Contains(stderr, "not a valid peer") || strings.Contains(stderr, "unknown target") { + t.Logf("Cross-user transfer correctly rejected: %s", stderr) + } else { + // Even if command succeeded, verify the file was NOT received + getCommand := []string{"tailscale", "file", "get", "/tmp/"} + receiver.Execute(getCommand) + + lsCommand := []string{"ls", fmt.Sprintf("/tmp/%s", filename)} + _, _, lsErr := receiver.Execute(lsCommand) + assert.Error(t, lsErr, "Cross-user file should NOT have been received") + } + }) + + // Test 6: Tagged device cannot send files + t.Run("TaggedCannotSend", func(t *testing.T) { + // Create file on tagged client + filename := fmt.Sprintf("file_from_tagged_%s", taggedClient.Hostname()) + command := []string{"touch", fmt.Sprintf("/tmp/%s", filename)} + _, _, err := taggedClient.Execute(command) + require.NoError(t, err, "failed to create taildrop file on tagged client") + + // Attempt to send to user1 client - should fail because tagged client has no FileTargets + receiver := user1Clients[0] + receiverFQDN, _ := receiver.FQDN() + sendCommand := []string{ + "tailscale", "file", "cp", + fmt.Sprintf("/tmp/%s", filename), + fmt.Sprintf("%s:", receiverFQDN), + } + + t.Logf("Attempting tagged->user file send from %s to %s (should fail)", taggedClient.Hostname(), receiver.Hostname()) + _, stderr, err := taggedClient.Execute(sendCommand) + + if err != nil { + t.Logf("Tagged client send correctly failed with error: %v", err) + } else if strings.Contains(stderr, "not a valid peer") || strings.Contains(stderr, "unknown target") || strings.Contains(stderr, "no matches for") { + t.Logf("Tagged client send correctly rejected: %s", stderr) + } else { + // Verify file was NOT received + getCommand := []string{"tailscale", "file", "get", "/tmp/"} + receiver.Execute(getCommand) + + lsCommand := []string{"ls", fmt.Sprintf("/tmp/%s", filename)} + _, _, lsErr := receiver.Execute(lsCommand) + assert.Error(t, lsErr, "Tagged client's file should NOT have been received") + } + }) + + // Test 7: Tagged device cannot receive files (user1 tries to send to 
tagged) + t.Run("TaggedCannotReceive", func(t *testing.T) { + sender := user1Clients[0] + + // Create file on sender + filename := fmt.Sprintf("file_to_tagged_from_%s", sender.Hostname()) + command := []string{"touch", fmt.Sprintf("/tmp/%s", filename)} + _, _, err := sender.Execute(command) + require.NoError(t, err, "failed to create taildrop file on %s", sender.Hostname()) + + // Attempt to send to tagged client - should fail because tagged is not in user1's FileTargets + taggedFQDN, _ := taggedClient.FQDN() + sendCommand := []string{ + "tailscale", "file", "cp", + fmt.Sprintf("/tmp/%s", filename), + fmt.Sprintf("%s:", taggedFQDN), + } + + t.Logf("Attempting user->tagged file send from %s to %s (should fail)", sender.Hostname(), taggedClient.Hostname()) + _, stderr, err := sender.Execute(sendCommand) + + if err != nil { + t.Logf("Send to tagged client correctly failed with error: %v", err) + } else if strings.Contains(stderr, "not a valid peer") || strings.Contains(stderr, "unknown target") || strings.Contains(stderr, "no matches for") { + t.Logf("Send to tagged client correctly rejected: %s", stderr) + } else { + // Verify file was NOT received by tagged client + getCommand := []string{"tailscale", "file", "get", "/tmp/"} + taggedClient.Execute(getCommand) + + lsCommand := []string{"ls", fmt.Sprintf("/tmp/%s", filename)} + _, _, lsErr := taggedClient.Execute(lsCommand) + assert.Error(t, lsErr, "File to tagged client should NOT have been received") + } + }) } -func TestResolveMagicDNS(t *testing.T) { +func TestUpdateHostnameFromClient(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "magicdns1": len(MustTestVersions), - "magicdns2": len(MustTestVersions), + hostnames := map[string]string{ + "1": "user1-host", + "2": "user2-host", + "3": "user3-host", } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("magicdns")) - assertNoErrHeadscaleEnv(t, err) + spec := ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario") + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("updatehostname")) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) - // Poor mans cache - _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) - - _, err = scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + // update hostnames using the up command for _, client := range allClients { - for _, peer := range allClients { - // It is safe to ignore this error as we handled it when caching it - peerFQDN, _ := peer.FQDN() + status := client.MustStatus() - command := []string{ - "tailscale", - "ip", peerFQDN, - } - result, _, err := client.Execute(command) - if err != nil { - t.Fatalf( - "failed to execute resolve/ip command %s from %s: %s", - peerFQDN, - client.Hostname(), - err, - ) - } + command := []string{ + "tailscale", + "set", + "--hostname=" + hostnames[string(status.Self.ID)], + } + _, _, err = client.Execute(command) + require.NoErrorf(t, err, "failed to set hostname") + } - ips, err := peer.IPs() - if err != nil 
{ - t.Fatalf( - "failed to get ips for %s: %s", - peer.Hostname(), - err, - ) - } + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) - for _, ip := range ips { - if !strings.Contains(result, ip.String()) { - t.Fatalf("ip %s is not found in \n%s\n", ip.String(), result) + // Wait for nodestore batch processing to complete + // NodeStore batching timeout is 500ms, so we wait up to 1 second + var nodes []*v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err := executeAndUnmarshal( + headscale, + []string{ + "headscale", + "node", + "list", + "--output", + "json", + }, + &nodes, + ) + assert.NoError(ct, err) + assert.Len(ct, nodes, 3, "Should have 3 nodes after hostname updates") + + for _, node := range nodes { + hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] + assert.Equal(ct, hostname, node.GetName(), "Node name should match hostname") + + // GivenName is normalized (lowercase, invalid chars stripped) + normalised, err := util.NormaliseHostname(hostname) + assert.NoError(ct, err) + assert.Equal(ct, normalised, node.GetGivenName(), "Given name should match FQDN rules") + } + }, 20*time.Second, 1*time.Second) + + // Rename givenName in nodes + for _, node := range nodes { + givenName := fmt.Sprintf("%d-givenname", node.GetId()) + _, err = headscale.Execute( + []string{ + "headscale", + "node", + "rename", + givenName, + "--identifier", + strconv.FormatUint(node.GetId(), 10), + }) + require.NoError(t, err) + } + + // Verify that the server-side rename is reflected in DNSName while HostName remains unchanged + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // Build a map of expected DNSNames by node ID + expectedDNSNames := make(map[string]string) + for _, node := range nodes { + nodeID := strconv.FormatUint(node.GetId(), 10) + expectedDNSNames[nodeID] = fmt.Sprintf("%d-givenname.headscale.net.", node.GetId()) + } + + // Verify from each client's perspective + for _, client := range allClients { + status, err := client.Status() + assert.NoError(ct, err) + + // Check self node + selfID := string(status.Self.ID) + expectedDNS := expectedDNSNames[selfID] + assert.Equal(ct, expectedDNS, status.Self.DNSName, + "Self DNSName should be renamed for client %s (ID: %s)", client.Hostname(), selfID) + + // HostName should remain as the original client-reported hostname + originalHostname := hostnames[selfID] + assert.Equal(ct, originalHostname, status.Self.HostName, + "Self HostName should remain unchanged for client %s (ID: %s)", client.Hostname(), selfID) + + // Check peers + for _, peer := range status.Peer { + peerID := string(peer.ID) + if expectedDNS, ok := expectedDNSNames[peerID]; ok { + assert.Equal(ct, expectedDNS, peer.DNSName, + "Peer DNSName should be renamed for peer ID %s as seen by client %s", peerID, client.Hostname()) + + // HostName should remain as the original client-reported hostname + originalHostname := hostnames[peerID] + assert.Equal(ct, originalHostname, peer.HostName, + "Peer HostName should remain unchanged for peer ID %s as seen by client %s", peerID, client.Hostname()) } } } + }, 60*time.Second, 2*time.Second) + + for _, client := range allClients { + status := client.MustStatus() + + command := []string{ + "tailscale", + "set", + "--hostname=" + hostnames[string(status.Self.ID)] + "NEW", + } + _, _, err = client.Execute(command) + require.NoErrorf(t, err, "failed to set hostname") } + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // Wait for nodestore batch processing to complete + // NodeStore 
batching timeout is 500ms, so we wait up to 1 second + assert.Eventually(t, func() bool { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "node", + "list", + "--output", + "json", + }, + &nodes, + ) + + if err != nil || len(nodes) != 3 { + return false + } + + for _, node := range nodes { + hostname := hostnames[strconv.FormatUint(node.GetId(), 10)] + givenName := fmt.Sprintf("%d-givenname", node.GetId()) + // Hostnames are lowercased before being stored, so "NEW" becomes "new" + if node.GetName() != hostname+"new" || node.GetGivenName() != givenName { + return false + } + } + return true + }, time.Second, 50*time.Millisecond, "hostname updates should be reflected in node list with new suffix") } func TestExpireNode(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": len(MustTestVersions), + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("expirenode")) - assertNoErrHeadscaleEnv(t, err) + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenode")) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClientIPs(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -533,87 +969,228 @@ func TestExpireNode(t *testing.T) { t.Logf("before expire: %d successful pings out of %d", success, len(allClients)*len(allIps)) for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) - // Assert that we have the original count - self - assert.Len(t, status.Peers(), len(MustTestVersions)-1) + // Assert that we have the original count - self + assert.Len(ct, status.Peers(), spec.NodesPerUser-1, "Client %s should see correct number of peers", client.Hostname()) + }, 30*time.Second, 1*time.Second) } headscale, err := scenario.Headscale() - assertNoErr(t, err) + require.NoError(t, err) // TODO(kradalby): This is Headscale specific and would not play nicely // with other implementations of the ControlServer interface result, err := headscale.Execute([]string{ - "headscale", "nodes", "expire", "--identifier", "0", "--output", "json", + "headscale", "nodes", "expire", "--identifier", "1", "--output", "json", }) - assertNoErr(t, err) + require.NoError(t, err) var node v1.Node err = json.Unmarshal([]byte(result), &node) - assertNoErr(t, err) + require.NoError(t, err) var expiredNodeKey key.NodePublic err = expiredNodeKey.UnmarshalText([]byte(node.GetNodeKey())) - assertNoErr(t, err) + require.NoError(t, err) t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String()) - time.Sleep(30 * time.Second) + // Verify that the expired node has been marked in all peers list. 
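+	// Expiry is propagated to peers via map responses, so allow several minutes for every client to observe it before failing.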
+ assert.EventuallyWithT(t, func(ct *assert.CollectT) { + for _, client := range allClients { + status, err := client.Status() + assert.NoError(ct, err) + + if client.Hostname() != node.GetName() { + // Check if the expired node appears as expired in this client's peer list + for key, peer := range status.Peer { + if key == expiredNodeKey { + assert.True(ct, peer.Expired, "Node should be marked as expired for client %s", client.Hostname()) + break + } + } + } + } + }, 3*time.Minute, 10*time.Second) now := time.Now() // Verify that the expired node has been marked in all peers list. for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) - - if client.Hostname() != node.GetName() { - t.Logf("available peers of %s: %v", client.Hostname(), status.Peers()) - - // In addition to marking nodes expired, we filter them out during the map response - // this check ensures that the node is either not present, or that it is expired - // if it is in the map response. - if peerStatus, ok := status.Peer[expiredNodeKey]; ok { - assertNotNil(t, peerStatus.Expired) - assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %s should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) - assert.Truef(t, peerStatus.Expired, "node %s should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired) - } - - // TODO(kradalby): We do not propogate expiry correctly, nodes should be aware - // of their status, and this should be sent directly to the node when its - // expired. This needs a notifier that goes directly to the node (currently we only do peers) - // so fix this in a follow up PR. - // } else { - // assert.True(t, status.Self.Expired) + if client.Hostname() == node.GetName() { + continue } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + // Ensures that the node is present, and that it is expired. 
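+			// Beyond the status fields, "tailscale ping" against the expired node is expected to report "node key has expired".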
+ peerStatus, ok := status.Peer[expiredNodeKey] + assert.True(c, ok, "expired node key should be present in peer list") + + if ok { + assert.NotNil(c, peerStatus.Expired) + assert.NotNil(c, peerStatus.KeyExpiry) + + if peerStatus.KeyExpiry != nil { + assert.Truef( + c, + peerStatus.KeyExpiry.Before(now), + "node %q should have a key expire before %s, was %s", + peerStatus.HostName, + now.String(), + peerStatus.KeyExpiry, + ) + } + + assert.Truef( + c, + peerStatus.Expired, + "node %q should be expired, expired is %v", + peerStatus.HostName, + peerStatus.Expired, + ) + + _, stderr, _ := client.Execute([]string{"tailscale", "ping", node.GetName()}) + if !strings.Contains(stderr, "node key has expired") { + c.Errorf( + "expected to be unable to ping expired host %q from %q", + node.GetName(), + client.Hostname(), + ) + } + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for expired node status to propagate") } } -func TestNodeOnlineLastSeenStatus(t *testing.T) { +// TestSetNodeExpiryInFuture tests setting arbitrary expiration date +// New expiration date should be stored in the db and propagated to all peers +func TestSetNodeExpiryInFuture(t *testing.T) { IntegrationSkip(t) - t.Parallel() - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - spec := map[string]int{ - "user1": len(MustTestVersions), + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("onlinelastseen")) - assertNoErrHeadscaleEnv(t, err) + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("expirenodefuture")) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) - - allIps, err := scenario.ListTailscaleClientsIPs() - assertNoErrListClientIPs(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + targetExpiry := time.Now().Add(2 * time.Hour).Round(time.Second).UTC() + + result, err := headscale.Execute( + []string{ + "headscale", "nodes", "expire", + "--identifier", "1", + "--output", "json", + "--expiry", targetExpiry.Format(time.RFC3339), + }, + ) + require.NoError(t, err) + + var node v1.Node + err = json.Unmarshal([]byte(result), &node) + require.NoError(t, err) + + require.True(t, node.GetExpiry().AsTime().After(time.Now())) + require.WithinDuration(t, targetExpiry, node.GetExpiry().AsTime(), 2*time.Second) + + var nodeKey key.NodePublic + err = nodeKey.UnmarshalText([]byte(node.GetNodeKey())) + require.NoError(t, err) + + for _, client := range allClients { + if client.Hostname() == node.GetName() { + continue + } + + assert.EventuallyWithT( + t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + + peerStatus, ok := status.Peer[nodeKey] + assert.True(ct, ok, "node key should be present in peer list") + + if !ok { + return + } + + assert.NotNil(ct, peerStatus.KeyExpiry) + assert.NotNil(ct, peerStatus.Expired) + + if peerStatus.KeyExpiry != nil { + assert.WithinDuration( + ct, + targetExpiry, + *peerStatus.KeyExpiry, + 5*time.Second, + "node %q should have key expiry near the requested future time", + peerStatus.HostName, + ) + + assert.Truef( + ct, + peerStatus.KeyExpiry.After(time.Now()), + "node 
%q should have a key expiry timestamp in the future", + peerStatus.HostName, + ) + } + + assert.Falsef( + ct, + peerStatus.Expired, + "node %q should not be marked as expired", + peerStatus.HostName, + ) + }, 3*time.Minute, 5*time.Second, "Waiting for future expiry to propagate", + ) + } +} + +func TestNodeOnlineStatus(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("online")) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { return x.String() @@ -623,17 +1200,17 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) { t.Logf("before expire: %d successful pings out of %d", success, len(allClients)*len(allIps)) for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) - // Assert that we have the original count - self - assert.Len(t, status.Peers(), len(MustTestVersions)-1) + // Assert that we have the original count - self + assert.Len(c, status.Peers(), len(MustTestVersions)-1) + }, 10*time.Second, 200*time.Millisecond, "Waiting for expected peer count") } headscale, err := scenario.Headscale() - assertNoErr(t, err) - - keepAliveInterval := 60 * time.Second + require.NoError(t, err) // Duration is chosen arbitrarily, 10m is reported in #1561 testDuration := 12 * time.Minute @@ -649,81 +1226,284 @@ func TestNodeOnlineLastSeenStatus(t *testing.T) { return } - result, err := headscale.Execute([]string{ - "headscale", "nodes", "list", "--output", "json", - }) - assertNoErr(t, err) - var nodes []*v1.Node - err = json.Unmarshal([]byte(result), &nodes) - assertNoErr(t, err) + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + result, err := headscale.Execute([]string{ + "headscale", "nodes", "list", "--output", "json", + }) + assert.NoError(ct, err) - now := time.Now() + err = json.Unmarshal([]byte(result), &nodes) + assert.NoError(ct, err) - // Threshold with some leeway - lastSeenThreshold := now.Add(-keepAliveInterval - (10 * time.Second)) - - // Verify that headscale reports the nodes as online - for _, node := range nodes { - // All nodes should be online - assert.Truef( - t, - node.GetOnline(), - "expected %s to have online status in Headscale, marked as offline %s after start", - node.GetName(), - time.Since(start), - ) - - lastSeen := node.GetLastSeen().AsTime() - // All nodes should have been last seen between now and the keepAliveInterval - assert.Truef( - t, - lastSeen.After(lastSeenThreshold), - "lastSeen (%v) was not %s after the threshold (%v)", - lastSeen, - keepAliveInterval, - lastSeenThreshold, - ) - } + // Verify that headscale reports the nodes as online + for _, node := range nodes { + // All nodes should be online + assert.Truef( + ct, + node.GetOnline(), + "expected %s to have online status in Headscale, marked as offline %s after start", + node.GetName(), + time.Since(start), + ) + } + }, 15*time.Second, 1*time.Second) // Verify 
that all nodes report all nodes to be online for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) - - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - // .Online is only available from CapVer 16, which - // is not present in 1.18 which is the lowest we - // test. - if strings.Contains(client.Hostname(), "1-18") { - continue + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + status, err := client.Status() + assert.NoError(ct, err) + if status == nil { + assert.Fail(ct, "status is nil") + return } - // All peers of this nodess are reporting to be - // connected to the control server - assert.Truef( - t, - peerStatus.Online, - "expected node %s to be marked as online in %s peer list, marked as offline %s after start", - peerStatus.HostName, - client.Hostname(), - time.Since(start), - ) + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - // from docs: last seen to tailcontrol; only present if offline - // assert.Nilf( - // t, - // peerStatus.LastSeen, - // "expected node %s to not have LastSeen set, got %s", - // peerStatus.HostName, - // peerStatus.LastSeen, - // ) - } + // .Online is only available from CapVer 16, which + // is not present in 1.18 which is the lowest we + // test. + if strings.Contains(client.Hostname(), "1-18") { + continue + } + + // All peers of this nodes are reporting to be + // connected to the control server + assert.Truef( + ct, + peerStatus.Online, + "expected node %s to be marked as online in %s peer list, marked as offline %s after start", + peerStatus.HostName, + client.Hostname(), + time.Since(start), + ) + } + }, 15*time.Second, 1*time.Second) } // Check maximum once per second time.Sleep(time.Second) } } + +// TestPingAllByIPManyUpDown is a variant of the PingAll +// test which will take the tailscale node up and down +// five times ensuring they are able to restablish connectivity. 
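+// Between each cycle the batcher, nodestore and map-response state is validated before pinging again.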
+func TestPingAllByIPManyUpDown(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: len(MustTestVersions), + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("pingallbyipmany"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithDERPAsIP(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // assertClientsState(t, allClients) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + // Get headscale instance for batcher debug checks + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Initial check: all nodes should be connected to batcher + // Extract node IDs for validation + expectedNodes := make([]types.NodeID, 0, len(allClients)) + for _, client := range allClients { + status := client.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + } + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected to batcher", 30*time.Second) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + for run := range 3 { + t.Logf("Starting DownUpPing run %d at %s", run+1, time.Now().Format(TimestampFormat)) + + // Create fresh errgroup with timeout for each run + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + wg, _ := errgroup.WithContext(ctx) + + for _, client := range allClients { + c := client + wg.Go(func() error { + t.Logf("taking down %q", c.Hostname()) + return c.Down() + }) + } + + if err := wg.Wait(); err != nil { + t.Fatalf("failed to take down all nodes: %s", err) + } + t.Logf("All nodes taken down at %s", time.Now().Format(TimestampFormat)) + + // After taking down all nodes, verify all systems show nodes offline + requireAllClientsOnline(t, headscale, expectedNodes, false, fmt.Sprintf("Run %d: all nodes should be offline after Down()", run+1), 120*time.Second) + + for _, client := range allClients { + c := client + wg.Go(func() error { + t.Logf("bringing up %q", c.Hostname()) + return c.Up() + }) + } + + if err := wg.Wait(); err != nil { + t.Fatalf("failed to bring up all nodes: %s", err) + } + t.Logf("All nodes brought up at %s", time.Now().Format(TimestampFormat)) + + // After bringing up all nodes, verify batcher shows all reconnected + requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all nodes should be reconnected after Up()", run+1), 120*time.Second) + + // Wait for sync and successful pings after nodes come back up + err = scenario.WaitForTailscaleSync() + assert.NoError(t, err) + + t.Logf("All nodes synced up %s", time.Now().Format(TimestampFormat)) + + requireAllClientsOnline(t, headscale, expectedNodes, true, fmt.Sprintf("Run %d: all systems should show nodes online after reconnection", run+1), 60*time.Second) + + success := pingAllHelper(t, allClients, allAddrs) + assert.Equalf(t, len(allClients)*len(allIps), success, "%d successful pings out of %d", success, 
len(allClients)*len(allIps)) + + // Clean up context for this run + cancel() + } +} + +func Test2118DeletingOnlineNodePanics(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + } + + scenario, err := NewScenario(spec) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithTestName("deletenocrash"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + requireNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("%d successful pings out of %d", success, len(allClients)*len(allIps)) + + headscale, err := scenario.Headscale() + require.NoError(t, err) + + // Test list all nodes after added otherUser + var nodeList []v1.Node + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &nodeList, + ) + require.NoError(t, err) + assert.Len(t, nodeList, 2) + assert.True(t, nodeList[0].GetOnline()) + assert.True(t, nodeList[1].GetOnline()) + + // Delete the first node, which is online + _, err = headscale.Execute( + []string{ + "headscale", + "nodes", + "delete", + "--identifier", + // Delete the last added machine + fmt.Sprintf("%d", nodeList[0].GetId()), + "--output", + "json", + "--force", + }, + ) + require.NoError(t, err) + + // Ensure that the node has been deleted, this did not occur due to a panic. 
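+	// Poll the node list until only one node remains, confirming the delete completed without crashing headscale.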
+ var nodeListAfter []v1.Node + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &nodeListAfter, + ) + assert.NoError(ct, err) + assert.Len(ct, nodeListAfter, 1, "Node should be deleted from list") + }, 10*time.Second, 1*time.Second) + + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "nodes", + "list", + "--output", + "json", + }, + &nodeListAfter, + ) + require.NoError(t, err) + assert.Len(t, nodeListAfter, 1) + assert.True(t, nodeListAfter[0].GetOnline()) + assert.Equal(t, nodeList[1].GetId(), nodeListAfter[0].GetId()) +} diff --git a/integration/helpers.go b/integration/helpers.go new file mode 100644 index 00000000..7d40c8e6 --- /dev/null +++ b/integration/helpers.go @@ -0,0 +1,1042 @@ +package integration + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/netip" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/cenkalti/backoff/v5" + "github.com/google/go-cmp/cmp" + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/juanfont/headscale/integration/integrationutil" + "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/exp/maps" + "golang.org/x/exp/slices" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +const ( + // derpPingTimeout defines the timeout for individual DERP ping operations + // Used in DERP connectivity tests to verify relay server communication. + derpPingTimeout = 2 * time.Second + + // derpPingCount defines the number of ping attempts for DERP connectivity tests + // Higher count provides better reliability assessment of DERP connectivity. + derpPingCount = 10 + + // TimestampFormat is the standard timestamp format used across all integration tests + // Format: "2006-01-02T15-04-05.999999999" provides high precision timestamps + // suitable for debugging and log correlation in integration tests. + TimestampFormat = "2006-01-02T15-04-05.999999999" + + // TimestampFormatRunID is used for generating unique run identifiers + // Format: "20060102-150405" provides compact date-time for file/directory names. + TimestampFormatRunID = "20060102-150405" +) + +// NodeSystemStatus represents the status of a node across different systems +type NodeSystemStatus struct { + Batcher bool + BatcherConnCount int + MapResponses bool + NodeStore bool +} + +// requireNoErrHeadscaleEnv validates that headscale environment creation succeeded. +// Provides specific error context for headscale environment setup failures. +func requireNoErrHeadscaleEnv(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to create headscale environment") +} + +// requireNoErrGetHeadscale validates that headscale server retrieval succeeded. +// Provides specific error context for headscale server access failures. +func requireNoErrGetHeadscale(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to get headscale") +} + +// requireNoErrListClients validates that client listing operations succeeded. +// Provides specific error context for client enumeration failures. 
+func requireNoErrListClients(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to list clients") +} + +// requireNoErrListClientIPs validates that client IP retrieval succeeded. +// Provides specific error context for client IP address enumeration failures. +func requireNoErrListClientIPs(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to get client IPs") +} + +// requireNoErrSync validates that client synchronization operations succeeded. +// Provides specific error context for client sync failures across the network. +func requireNoErrSync(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to have all clients sync up") +} + +// requireNoErrListFQDN validates that FQDN listing operations succeeded. +// Provides specific error context for DNS name enumeration failures. +func requireNoErrListFQDN(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to list FQDNs") +} + +// requireNoErrLogout validates that tailscale node logout operations succeeded. +// Provides specific error context for client logout failures. +func requireNoErrLogout(t *testing.T, err error) { + t.Helper() + require.NoError(t, err, "failed to log out tailscale nodes") +} + +// collectExpectedNodeIDs extracts node IDs from a list of TailscaleClients for validation purposes +func collectExpectedNodeIDs(t *testing.T, clients []TailscaleClient) []types.NodeID { + t.Helper() + + expectedNodes := make([]types.NodeID, 0, len(clients)) + for _, client := range clients { + status := client.MustStatus() + nodeID, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + require.NoError(t, err) + expectedNodes = append(expectedNodes, types.NodeID(nodeID)) + } + return expectedNodes +} + +// validateInitialConnection performs comprehensive validation after initial client login. +// Validates that all nodes are online and have proper NetInfo/DERP configuration, +// essential for ensuring successful initial connection state in relogin tests. +func validateInitialConnection(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { + t.Helper() + + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after initial login", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after initial login", 3*time.Minute) +} + +// validateLogoutComplete performs comprehensive validation after client logout. +// Ensures all nodes are properly offline across all headscale systems, +// critical for validating clean logout state in relogin tests. +func validateLogoutComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { + t.Helper() + + requireAllClientsOnline(t, headscale, expectedNodes, false, "all nodes should be offline after logout", 120*time.Second) +} + +// validateReloginComplete performs comprehensive validation after client relogin. +// Validates that all nodes are back online with proper NetInfo/DERP configuration, +// ensuring successful relogin state restoration in integration tests. 
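+// Functionally mirrors validateInitialConnection; the separate helper keeps relogin failures clearly labelled.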
+func validateReloginComplete(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID) { + t.Helper() + + requireAllClientsOnline(t, headscale, expectedNodes, true, "all clients should be connected after relogin", 120*time.Second) + requireAllClientsNetInfoAndDERP(t, headscale, expectedNodes, "all clients should have NetInfo and DERP after relogin", 3*time.Minute) +} + +// requireAllClientsOnline validates that all nodes are online/offline across all headscale systems +// requireAllClientsOnline verifies all expected nodes are in the specified online state across all systems +func requireAllClientsOnline(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { + t.Helper() + + startTime := time.Now() + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + t.Logf("requireAllSystemsOnline: Starting %s validation for %d nodes at %s - %s", stateStr, len(expectedNodes), startTime.Format(TimestampFormat), message) + + if expectedOnline { + // For online validation, use the existing logic with full timeout + requireAllClientsOnlineWithSingleTimeout(t, headscale, expectedNodes, expectedOnline, message, timeout) + } else { + // For offline validation, use staged approach with component-specific timeouts + requireAllClientsOfflineStaged(t, headscale, expectedNodes, message, timeout) + } + + endTime := time.Now() + t.Logf("requireAllSystemsOnline: Completed %s validation for %d nodes at %s - Duration: %s - %s", stateStr, len(expectedNodes), endTime.Format(TimestampFormat), endTime.Sub(startTime), message) +} + +// requireAllClientsOnlineWithSingleTimeout is the original validation logic for online state +func requireAllClientsOnlineWithSingleTimeout(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, expectedOnline bool, message string, timeout time.Duration) { + t.Helper() + + var prevReport string + require.EventuallyWithT(t, func(c *assert.CollectT) { + // Get batcher state + debugInfo, err := headscale.DebugBatcher() + assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { + return + } + + // Get map responses + mapResponses, err := headscale.GetAllMapReponses() + assert.NoError(c, err, "Failed to get map responses") + if err != nil { + return + } + + // Get nodestore state + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + // Validate that all expected nodes are present in nodeStore + for _, nodeID := range expectedNodes { + _, exists := nodeStore[nodeID] + assert.True(c, exists, "Expected node %d not found in nodeStore", nodeID) + } + + // Check that we have map responses for expected nodes + mapResponseCount := len(mapResponses) + expectedCount := len(expectedNodes) + assert.GreaterOrEqual(c, mapResponseCount, expectedCount, "MapResponses insufficient - expected at least %d responses, got %d", expectedCount, mapResponseCount) + + // Build status map for each node + nodeStatus := make(map[types.NodeID]NodeSystemStatus) + + // Initialize all expected nodes + for _, nodeID := range expectedNodes { + nodeStatus[nodeID] = NodeSystemStatus{} + } + + // Check batcher state for expected nodes + for _, nodeID := range expectedNodes { + nodeIDStr := fmt.Sprintf("%d", nodeID) + if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists { + if status, exists := nodeStatus[nodeID]; exists { + status.Batcher = nodeInfo.Connected + status.BatcherConnCount = 
nodeInfo.ActiveConnections + nodeStatus[nodeID] = status + } + } else { + // Node not found in batcher, mark as disconnected + if status, exists := nodeStatus[nodeID]; exists { + status.Batcher = false + status.BatcherConnCount = 0 + nodeStatus[nodeID] = status + } + } + } + + // Check map responses using buildExpectedOnlineMap + onlineFromMaps := make(map[types.NodeID]bool) + onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses) + + // For single node scenarios, we can't validate peer visibility since there are no peers + if len(expectedNodes) == 1 { + // For single node, just check that we have map responses for the node + for nodeID := range nodeStatus { + if _, exists := onlineMap[nodeID]; exists { + onlineFromMaps[nodeID] = true + } else { + onlineFromMaps[nodeID] = false + } + } + } else { + // Multi-node scenario: check peer visibility + for nodeID := range nodeStatus { + // Initialize as offline - will be set to true only if visible in all relevant peer maps + onlineFromMaps[nodeID] = false + + // Count how many peer maps should show this node + expectedPeerMaps := 0 + foundOnlinePeerMaps := 0 + + for id, peerMap := range onlineMap { + if id == nodeID { + continue // Skip self-references + } + expectedPeerMaps++ + + if online, exists := peerMap[nodeID]; exists && online { + foundOnlinePeerMaps++ + } + } + + // Node is considered online if it appears online in all peer maps + // (or if there are no peer maps to check) + if expectedPeerMaps == 0 || foundOnlinePeerMaps == expectedPeerMaps { + onlineFromMaps[nodeID] = true + } + } + } + assert.Lenf(c, onlineFromMaps, expectedCount, "MapResponses missing nodes in status check") + + // Update status with map response data + for nodeID, online := range onlineFromMaps { + if status, exists := nodeStatus[nodeID]; exists { + status.MapResponses = online + nodeStatus[nodeID] = status + } + } + + // Check nodestore state for expected nodes + for _, nodeID := range expectedNodes { + if node, exists := nodeStore[nodeID]; exists { + if status, exists := nodeStatus[nodeID]; exists { + // Check if node is online in nodestore + status.NodeStore = node.IsOnline != nil && *node.IsOnline + nodeStatus[nodeID] = status + } + } + } + + // Verify all systems show nodes in expected state and report failures + allMatch := true + var failureReport strings.Builder + + ids := types.NodeIDs(maps.Keys(nodeStatus)) + slices.Sort(ids) + for _, nodeID := range ids { + status := nodeStatus[nodeID] + systemsMatch := (status.Batcher == expectedOnline) && + (status.MapResponses == expectedOnline) && + (status.NodeStore == expectedOnline) + + if !systemsMatch { + allMatch = false + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + failureReport.WriteString(fmt.Sprintf("node:%d is not fully %s (timestamp: %s):\n", nodeID, stateStr, time.Now().Format(TimestampFormat))) + failureReport.WriteString(fmt.Sprintf(" - batcher: %t (expected: %t)\n", status.Batcher, expectedOnline)) + failureReport.WriteString(fmt.Sprintf(" - conn count: %d\n", status.BatcherConnCount)) + failureReport.WriteString(fmt.Sprintf(" - mapresponses: %t (expected: %t, down with at least one peer)\n", status.MapResponses, expectedOnline)) + failureReport.WriteString(fmt.Sprintf(" - nodestore: %t (expected: %t)\n", status.NodeStore, expectedOnline)) + } + } + + if !allMatch { + if diff := cmp.Diff(prevReport, failureReport.String()); diff != "" { + t.Logf("Node state validation report changed at %s:", time.Now().Format(TimestampFormat)) + t.Logf("Previous report:\n%s", 
prevReport) + t.Logf("Current report:\n%s", failureReport.String()) + t.Logf("Report diff:\n%s", diff) + prevReport = failureReport.String() + } + + failureReport.WriteString(fmt.Sprintf("validation_timestamp: %s\n", time.Now().Format(TimestampFormat))) + // Note: timeout_remaining not available in this context + + assert.Fail(c, failureReport.String()) + } + + stateStr := "offline" + if expectedOnline { + stateStr = "online" + } + assert.True(c, allMatch, fmt.Sprintf("Not all %d nodes are %s across all systems (batcher, mapresponses, nodestore)", len(expectedNodes), stateStr)) + }, timeout, 2*time.Second, message) +} + +// requireAllClientsOfflineStaged validates offline state with staged timeouts for different components +func requireAllClientsOfflineStaged(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, totalTimeout time.Duration) { + t.Helper() + + // Stage 1: Verify batcher disconnection (should be immediate) + t.Logf("Stage 1: Verifying batcher disconnection for %d nodes", len(expectedNodes)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + debugInfo, err := headscale.DebugBatcher() + assert.NoError(c, err, "Failed to get batcher debug info") + if err != nil { + return + } + + allBatcherOffline := true + for _, nodeID := range expectedNodes { + nodeIDStr := fmt.Sprintf("%d", nodeID) + if nodeInfo, exists := debugInfo.ConnectedNodes[nodeIDStr]; exists && nodeInfo.Connected { + allBatcherOffline = false + assert.False(c, nodeInfo.Connected, "Node %d should not be connected in batcher", nodeID) + } + } + assert.True(c, allBatcherOffline, "All nodes should be disconnected from batcher") + }, 15*time.Second, 1*time.Second, "batcher disconnection validation") + + // Stage 2: Verify nodestore offline status (up to 15 seconds due to disconnect detection delay) + t.Logf("Stage 2: Verifying nodestore offline status for %d nodes (allowing for 10s disconnect detection delay)", len(expectedNodes)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + allNodeStoreOffline := true + for _, nodeID := range expectedNodes { + if node, exists := nodeStore[nodeID]; exists { + isOnline := node.IsOnline != nil && *node.IsOnline + if isOnline { + allNodeStoreOffline = false + assert.False(c, isOnline, "Node %d should be offline in nodestore", nodeID) + } + } + } + assert.True(c, allNodeStoreOffline, "All nodes should be offline in nodestore") + }, 20*time.Second, 1*time.Second, "nodestore offline validation") + + // Stage 3: Verify map response propagation (longest delay due to peer update timing) + t.Logf("Stage 3: Verifying map response propagation for %d nodes (allowing for peer map update delays)", len(expectedNodes)) + require.EventuallyWithT(t, func(c *assert.CollectT) { + mapResponses, err := headscale.GetAllMapReponses() + assert.NoError(c, err, "Failed to get map responses") + if err != nil { + return + } + + onlineMap := integrationutil.BuildExpectedOnlineMap(mapResponses) + allMapResponsesOffline := true + + if len(expectedNodes) == 1 { + // Single node: check if it appears in map responses + for nodeID := range onlineMap { + if slices.Contains(expectedNodes, nodeID) { + allMapResponsesOffline = false + assert.False(c, true, "Node %d should not appear in map responses", nodeID) + } + } + } else { + // Multi-node: check peer visibility + for _, nodeID := range expectedNodes { + for id, peerMap := range 
onlineMap { + if id == nodeID { + continue // Skip self-references + } + if online, exists := peerMap[nodeID]; exists && online { + allMapResponsesOffline = false + assert.False(c, online, "Node %d should not be visible in node %d's map response", nodeID, id) + } + } + } + } + assert.True(c, allMapResponsesOffline, "All nodes should be absent from peer map responses") + }, 60*time.Second, 2*time.Second, "map response propagation validation") + + t.Logf("All stages completed: nodes are fully offline across all systems") +} + +// requireAllClientsNetInfoAndDERP validates that all nodes have NetInfo in the database +// and a valid DERP server based on the NetInfo. This function follows the pattern of +// requireAllClientsOnline by using hsic.DebugNodeStore to get the database state. +func requireAllClientsNetInfoAndDERP(t *testing.T, headscale ControlServer, expectedNodes []types.NodeID, message string, timeout time.Duration) { + t.Helper() + + startTime := time.Now() + t.Logf("requireAllClientsNetInfoAndDERP: Starting NetInfo/DERP validation for %d nodes at %s - %s", len(expectedNodes), startTime.Format(TimestampFormat), message) + + require.EventuallyWithT(t, func(c *assert.CollectT) { + // Get nodestore state + nodeStore, err := headscale.DebugNodeStore() + assert.NoError(c, err, "Failed to get nodestore debug info") + if err != nil { + return + } + + // Validate that all expected nodes are present in nodeStore + for _, nodeID := range expectedNodes { + _, exists := nodeStore[nodeID] + assert.True(c, exists, "Expected node %d not found in nodeStore during NetInfo validation", nodeID) + } + + // Check each expected node + for _, nodeID := range expectedNodes { + node, exists := nodeStore[nodeID] + assert.True(c, exists, "Node %d not found in nodestore during NetInfo validation", nodeID) + if !exists { + continue + } + + // Validate that the node has Hostinfo + assert.NotNil(c, node.Hostinfo, "Node %d (%s) should have Hostinfo for NetInfo validation", nodeID, node.Hostname) + if node.Hostinfo == nil { + t.Logf("Node %d (%s) missing Hostinfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) + continue + } + + // Validate that the node has NetInfo + assert.NotNil(c, node.Hostinfo.NetInfo, "Node %d (%s) should have NetInfo in Hostinfo for DERP connectivity", nodeID, node.Hostname) + if node.Hostinfo.NetInfo == nil { + t.Logf("Node %d (%s) missing NetInfo at %s", nodeID, node.Hostname, time.Now().Format(TimestampFormat)) + continue + } + + // Validate that the node has a valid DERP server (PreferredDERP should be > 0) + preferredDERP := node.Hostinfo.NetInfo.PreferredDERP + assert.Greater(c, preferredDERP, 0, "Node %d (%s) should have a valid DERP server (PreferredDERP > 0) for relay connectivity, got %d", nodeID, node.Hostname, preferredDERP) + + t.Logf("Node %d (%s) has valid NetInfo with DERP server %d at %s", nodeID, node.Hostname, preferredDERP, time.Now().Format(TimestampFormat)) + } + }, timeout, 5*time.Second, message) + + endTime := time.Now() + duration := endTime.Sub(startTime) + t.Logf("requireAllClientsNetInfoAndDERP: Completed NetInfo/DERP validation for %d nodes at %s - Duration: %v - %s", len(expectedNodes), endTime.Format(TimestampFormat), duration, message) +} + +// assertLastSeenSet validates that a node has a non-nil LastSeen timestamp. +// Critical for ensuring node activity tracking is functioning properly. 
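+// Typically used on nodes returned from "headscale nodes list --output json"; a nil LastSeen suggests the node never checked in.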
+func assertLastSeenSet(t *testing.T, node *v1.Node) { + assert.NotNil(t, node) + assert.NotNil(t, node.GetLastSeen()) +} + +func assertLastSeenSetWithCollect(c *assert.CollectT, node *v1.Node) { + assert.NotNil(c, node) + assert.NotNil(c, node.GetLastSeen()) +} + +// assertTailscaleNodesLogout verifies that all provided Tailscale clients +// are in the logged-out state (NeedsLogin). +func assertTailscaleNodesLogout(t assert.TestingT, clients []TailscaleClient) { + if h, ok := t.(interface{ Helper() }); ok { + h.Helper() + } + + for _, client := range clients { + status, err := client.Status() + assert.NoError(t, err, "failed to get status for client %s", client.Hostname()) + assert.Equal(t, "NeedsLogin", status.BackendState, + "client %s should be logged out", client.Hostname()) + } +} + +// pingAllHelper performs ping tests between all clients and addresses, returning success count. +// This is used to validate network connectivity in integration tests. +// Returns the total number of successful ping operations. +func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { + t.Helper() + success := 0 + + for _, client := range clients { + for _, addr := range addrs { + err := client.Ping(addr, opts...) + if err != nil { + t.Errorf("failed to ping %s from %s: %s", addr, client.Hostname(), err) + } else { + success++ + } + } + } + + return success +} + +// pingDerpAllHelper performs DERP-based ping tests between all clients and addresses. +// This specifically tests connectivity through DERP relay servers, which is important +// for validating NAT traversal and relay functionality. Returns success count. +func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { + t.Helper() + success := 0 + + for _, client := range clients { + for _, addr := range addrs { + if isSelfClient(client, addr) { + continue + } + + err := client.Ping( + addr, + tsic.WithPingTimeout(derpPingTimeout), + tsic.WithPingCount(derpPingCount), + tsic.WithPingUntilDirect(false), + ) + if err != nil { + t.Logf("failed to ping %s from %s: %s", addr, client.Hostname(), err) + } else { + success++ + } + } + } + + return success +} + +// isSelfClient determines if the given address belongs to the client itself. +// Used to avoid self-ping operations in connectivity tests by checking +// hostname and IP address matches. +func isSelfClient(client TailscaleClient, addr string) bool { + if addr == client.Hostname() { + return true + } + + ips, err := client.IPs() + if err != nil { + return false + } + + for _, ip := range ips { + if ip.String() == addr { + return true + } + } + + return false +} + +// assertClientsState validates the status and netmap of a list of clients for general connectivity. +// Runs parallel validation of status, netcheck, and netmap for all clients to ensure +// they have proper network configuration for all-to-all connectivity tests. +func assertClientsState(t *testing.T, clients []TailscaleClient) { + t.Helper() + + var wg sync.WaitGroup + + for _, client := range clients { + wg.Add(1) + c := client // Avoid loop pointer + go func() { + defer wg.Done() + assertValidStatus(t, c) + assertValidNetcheck(t, c) + assertValidNetmap(t, c) + }() + } + + t.Logf("waiting for client state checks to finish") + wg.Wait() +} + +// assertValidNetmap validates that a client's netmap has all required fields for proper operation. 
+// Checks self node and all peers for essential networking data including hostinfo, addresses, +// endpoints, and DERP configuration. Skips validation for Tailscale versions below 1.56. +// This test is not suitable for ACL/partial connection tests. +func assertValidNetmap(t *testing.T, client TailscaleClient) { + t.Helper() + + if !util.TailscaleVersionNewerOrEqual("1.56", client.Version()) { + t.Logf("%q has version %q, skipping netmap check...", client.Hostname(), client.Version()) + + return + } + + t.Logf("Checking netmap of %q", client.Hostname()) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + netmap, err := client.Netmap() + assert.NoError(c, err, "getting netmap for %q", client.Hostname()) + + assert.Truef(c, netmap.SelfNode.Hostinfo().Valid(), "%q does not have Hostinfo", client.Hostname()) + if hi := netmap.SelfNode.Hostinfo(); hi.Valid() { + assert.LessOrEqual(c, 1, netmap.SelfNode.Hostinfo().Services().Len(), "%q does not have enough services, got: %v", client.Hostname(), netmap.SelfNode.Hostinfo().Services()) + } + + assert.NotEmptyf(c, netmap.SelfNode.AllowedIPs(), "%q does not have any allowed IPs", client.Hostname()) + assert.NotEmptyf(c, netmap.SelfNode.Addresses(), "%q does not have any addresses", client.Hostname()) + + assert.Truef(c, netmap.SelfNode.Online().Get(), "%q is not online", client.Hostname()) + + assert.Falsef(c, netmap.SelfNode.Key().IsZero(), "%q does not have a valid NodeKey", client.Hostname()) + assert.Falsef(c, netmap.SelfNode.Machine().IsZero(), "%q does not have a valid MachineKey", client.Hostname()) + assert.Falsef(c, netmap.SelfNode.DiscoKey().IsZero(), "%q does not have a valid DiscoKey", client.Hostname()) + + for _, peer := range netmap.Peers { + assert.NotEqualf(c, "127.3.3.40:0", peer.LegacyDERPString(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.LegacyDERPString()) + assert.NotEqualf(c, 0, peer.HomeDERP(), "peer (%s) has no home DERP in %q's netmap, got: %d", peer.ComputedName(), client.Hostname(), peer.HomeDERP()) + + assert.Truef(c, peer.Hostinfo().Valid(), "peer (%s) of %q does not have Hostinfo", peer.ComputedName(), client.Hostname()) + if hi := peer.Hostinfo(); hi.Valid() { + assert.LessOrEqualf(c, 3, peer.Hostinfo().Services().Len(), "peer (%s) of %q does not have enough services, got: %v", peer.ComputedName(), client.Hostname(), peer.Hostinfo().Services()) + + // Netinfo is not always set + // assert.Truef(c, hi.NetInfo().Valid(), "peer (%s) of %q does not have NetInfo", peer.ComputedName(), client.Hostname()) + if ni := hi.NetInfo(); ni.Valid() { + assert.NotEqualf(c, 0, ni.PreferredDERP(), "peer (%s) has no home DERP in %q's netmap, got: %s", peer.ComputedName(), client.Hostname(), peer.Hostinfo().NetInfo().PreferredDERP()) + } + } + + assert.NotEmptyf(c, peer.Endpoints(), "peer (%s) of %q does not have any endpoints", peer.ComputedName(), client.Hostname()) + assert.NotEmptyf(c, peer.AllowedIPs(), "peer (%s) of %q does not have any allowed IPs", peer.ComputedName(), client.Hostname()) + assert.NotEmptyf(c, peer.Addresses(), "peer (%s) of %q does not have any addresses", peer.ComputedName(), client.Hostname()) + + assert.Truef(c, peer.Online().Get(), "peer (%s) of %q is not online", peer.ComputedName(), client.Hostname()) + + assert.Falsef(c, peer.Key().IsZero(), "peer (%s) of %q does not have a valid NodeKey", peer.ComputedName(), client.Hostname()) + assert.Falsef(c, peer.Machine().IsZero(), "peer (%s) of %q does not have a valid MachineKey", 
peer.ComputedName(), client.Hostname()) + assert.Falsef(c, peer.DiscoKey().IsZero(), "peer (%s) of %q does not have a valid DiscoKey", peer.ComputedName(), client.Hostname()) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for valid netmap for %q", client.Hostname()) +} + +// assertValidStatus validates that a client's status has all required fields for proper operation. +// Checks self and peer status for essential data including hostinfo, tailscale IPs, endpoints, +// and network map presence. This test is not suitable for ACL/partial connection tests. +func assertValidStatus(t *testing.T, client TailscaleClient) { + t.Helper() + status, err := client.Status(true) + if err != nil { + t.Fatalf("getting status for %q: %s", client.Hostname(), err) + } + + assert.NotEmptyf(t, status.Self.HostName, "%q does not have HostName set, likely missing Hostinfo", client.Hostname()) + assert.NotEmptyf(t, status.Self.OS, "%q does not have OS set, likely missing Hostinfo", client.Hostname()) + assert.NotEmptyf(t, status.Self.Relay, "%q does not have a relay, likely missing Hostinfo/Netinfo", client.Hostname()) + + assert.NotEmptyf(t, status.Self.TailscaleIPs, "%q does not have Tailscale IPs", client.Hostname()) + + // This seems to not appear until version 1.56 + if status.Self.AllowedIPs != nil { + assert.NotEmptyf(t, status.Self.AllowedIPs, "%q does not have any allowed IPs", client.Hostname()) + } + + assert.NotEmptyf(t, status.Self.Addrs, "%q does not have any endpoints", client.Hostname()) + + assert.Truef(t, status.Self.Online, "%q is not online", client.Hostname()) + + assert.Truef(t, status.Self.InNetworkMap, "%q is not in network map", client.Hostname()) + + // This isn't really relevant for Self as it won't be in its own socket/wireguard. + // assert.Truef(t, status.Self.InMagicSock, "%q is not tracked by magicsock", client.Hostname()) + // assert.Truef(t, status.Self.InEngine, "%q is not in wireguard engine", client.Hostname()) + + for _, peer := range status.Peer { + assert.NotEmptyf(t, peer.HostName, "peer (%s) of %q does not have HostName set, likely missing Hostinfo", peer.DNSName, client.Hostname()) + assert.NotEmptyf(t, peer.OS, "peer (%s) of %q does not have OS set, likely missing Hostinfo", peer.DNSName, client.Hostname()) + assert.NotEmptyf(t, peer.Relay, "peer (%s) of %q does not have a relay, likely missing Hostinfo/Netinfo", peer.DNSName, client.Hostname()) + + assert.NotEmptyf(t, peer.TailscaleIPs, "peer (%s) of %q does not have Tailscale IPs", peer.DNSName, client.Hostname()) + + // This seems to not appear until version 1.56 + if peer.AllowedIPs != nil { + assert.NotEmptyf(t, peer.AllowedIPs, "peer (%s) of %q does not have any allowed IPs", peer.DNSName, client.Hostname()) + } + + // Addrs does not seem to appear in the status from peers. + // assert.NotEmptyf(t, peer.Addrs, "peer (%s) of %q does not have any endpoints", peer.DNSName, client.Hostname()) + + assert.Truef(t, peer.Online, "peer (%s) of %q is not online", peer.DNSName, client.Hostname()) + + assert.Truef(t, peer.InNetworkMap, "peer (%s) of %q is not in network map", peer.DNSName, client.Hostname()) + assert.Truef(t, peer.InMagicSock, "peer (%s) of %q is not tracked by magicsock", peer.DNSName, client.Hostname()) + + // TODO(kradalby): InEngine is only true when a proper tunnel is set up, + // there might be some interesting stuff to test here in the future.
+ // assert.Truef(t, peer.InEngine, "peer (%s) of %q is not in wireguard engine", peer.DNSName, client.Hostname()) + } +} + +// assertValidNetcheck validates that a client has a proper DERP relay configured. +// Ensures the client has discovered and selected a DERP server for relay functionality, +// which is essential for NAT traversal and connectivity in restricted networks. +func assertValidNetcheck(t *testing.T, client TailscaleClient) { + t.Helper() + report, err := client.Netcheck() + if err != nil { + t.Fatalf("getting status for %q: %s", client.Hostname(), err) + } + + assert.NotEqualf(t, 0, report.PreferredDERP, "%q does not have a DERP relay", client.Hostname()) +} + +// assertCommandOutputContains executes a command with exponential backoff retry until the output +// contains the expected string or timeout is reached (10 seconds). +// This implements eventual consistency patterns and should be used instead of time.Sleep +// before executing commands that depend on network state propagation. +// +// Timeout: 10 seconds with exponential backoff +// Use cases: DNS resolution, route propagation, policy updates. +func assertCommandOutputContains(t *testing.T, c TailscaleClient, command []string, contains string) { + t.Helper() + + _, err := backoff.Retry(t.Context(), func() (struct{}, error) { + stdout, stderr, err := c.Execute(command) + if err != nil { + return struct{}{}, fmt.Errorf("executing command, stdout: %q stderr: %q, err: %w", stdout, stderr, err) + } + + if !strings.Contains(stdout, contains) { + return struct{}{}, fmt.Errorf("executing command, expected string %q not found in %q", contains, stdout) + } + + return struct{}{}, nil + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + + assert.NoError(t, err) +} + +// dockertestMaxWait returns the maximum wait time for Docker-based test operations. +// Uses longer timeouts in CI environments to account for slower resource allocation +// and higher system load during automated testing. +func dockertestMaxWait() time.Duration { + wait := 300 * time.Second //nolint + + if util.IsCI() { + wait = 600 * time.Second //nolint + } + + return wait +} + +// didClientUseWebsocketForDERP analyzes client logs to determine if WebSocket was used for DERP. +// Searches for WebSocket connection indicators in client logs to validate +// DERP relay communication method for debugging connectivity issues. +func didClientUseWebsocketForDERP(t *testing.T, client TailscaleClient) bool { + t.Helper() + + buf := &bytes.Buffer{} + err := client.WriteLogs(buf, buf) + if err != nil { + t.Fatalf("failed to fetch client logs: %s: %s", client.Hostname(), err) + } + + count, err := countMatchingLines(buf, func(line string) bool { + return strings.Contains(line, "websocket: connected to ") + }) + if err != nil { + t.Fatalf("failed to process client logs: %s: %s", client.Hostname(), err) + } + + return count > 0 +} + +// countMatchingLines counts lines in a reader that match the given predicate function. +// Uses optimized buffering for log analysis and provides flexible line-by-line +// filtering for log parsing and pattern matching in integration tests. 
+func countMatchingLines(in io.Reader, predicate func(string) bool) (int, error) { + count := 0 + scanner := bufio.NewScanner(in) + { + const logBufferInitialSize = 1024 << 10 // preallocate 1 MiB + buff := make([]byte, logBufferInitialSize) + scanner.Buffer(buff, len(buff)) + scanner.Split(bufio.ScanLines) + } + + for scanner.Scan() { + if predicate(scanner.Text()) { + count += 1 + } + } + + return count, scanner.Err() +} + +// wildcard returns a wildcard alias (*) for use in policy v2 configurations. +// Provides a convenient helper for creating permissive policy rules. +func wildcard() policyv2.Alias { + return policyv2.Wildcard +} + +// usernamep returns a pointer to a Username as an Alias for policy v2 configurations. +// Used in ACL rules to reference specific users in network access policies. +func usernamep(name string) policyv2.Alias { + return ptr.To(policyv2.Username(name)) +} + +// hostp returns a pointer to a Host as an Alias for policy v2 configurations. +// Used in ACL rules to reference specific hosts in network access policies. +func hostp(name string) policyv2.Alias { + return ptr.To(policyv2.Host(name)) +} + +// groupp returns a pointer to a Group as an Alias for policy v2 configurations. +// Used in ACL rules to reference user groups in network access policies. +func groupp(name string) policyv2.Alias { + return ptr.To(policyv2.Group(name)) +} + +// tagp returns a pointer to a Tag as an Alias for policy v2 configurations. +// Used in ACL rules to reference node tags in network access policies. +func tagp(name string) policyv2.Alias { + return ptr.To(policyv2.Tag(name)) +} + +// prefixp returns a pointer to a Prefix from a CIDR string for policy v2 configurations. +// Converts CIDR notation to policy prefix format for network range specifications. +func prefixp(cidr string) policyv2.Alias { + prefix := netip.MustParsePrefix(cidr) + return ptr.To(policyv2.Prefix(prefix)) +} + +// aliasWithPorts creates an AliasWithPorts structure from an alias and port ranges. +// Combines network targets with specific port restrictions for fine-grained +// access control in policy v2 configurations. +func aliasWithPorts(alias policyv2.Alias, ports ...tailcfg.PortRange) policyv2.AliasWithPorts { + return policyv2.AliasWithPorts{ + Alias: alias, + Ports: ports, + } +} + +// usernameOwner returns a Username as an Owner for use in TagOwners policies. +// Specifies which users can assign and manage specific tags in ACL configurations. +func usernameOwner(name string) policyv2.Owner { + return ptr.To(policyv2.Username(name)) +} + +// groupOwner returns a Group as an Owner for use in TagOwners policies. +// Specifies which groups can assign and manage specific tags in ACL configurations. +func groupOwner(name string) policyv2.Owner { + return ptr.To(policyv2.Group(name)) +} + +// usernameApprover returns a Username as an AutoApprover for subnet route policies. +// Specifies which users can automatically approve subnet route advertisements. +func usernameApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Username(name)) +} + +// groupApprover returns a Group as an AutoApprover for subnet route policies. +// Specifies which groups can automatically approve subnet route advertisements. +func groupApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Group(name)) +} + +// tagApprover returns a Tag as an AutoApprover for subnet route policies. +// Specifies which tagged nodes can automatically approve subnet route advertisements. 
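+//
+// Example (illustrative sketch; the policyv2 field names and the route prefix
+// below are assumptions):
+//
+//	approvers := policyv2.AutoApproverPolicy{
+//		Routes: map[netip.Prefix]policyv2.AutoApprovers{
+//			netip.MustParsePrefix("10.33.0.0/16"): {tagApprover("tag:router")},
+//		},
+//	}
+//	// approvers would then be set on the policy handed to hsic.WithACLPolicy.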
+func tagApprover(name string) policyv2.AutoApprover { + return ptr.To(policyv2.Tag(name)) +} + +// oidcMockUser creates a MockUser for OIDC authentication testing. +// Generates consistent test user data with configurable email verification status +// for validating OIDC integration flows in headscale authentication tests. +func oidcMockUser(username string, emailVerified bool) mockoidc.MockUser { + return mockoidc.MockUser{ + Subject: username, + PreferredUsername: username, + Email: username + "@headscale.net", + EmailVerified: emailVerified, + } +} + +// GetUserByName retrieves a user by name from the headscale server. +// This is a common pattern used when creating preauth keys or managing users. +func GetUserByName(headscale ControlServer, username string) (*v1.User, error) { + users, err := headscale.ListUsers() + if err != nil { + return nil, fmt.Errorf("failed to list users: %w", err) + } + + for _, u := range users { + if u.GetName() == username { + return u, nil + } + } + + return nil, fmt.Errorf("user %s not found", username) +} + +// FindNewClient finds a client that is in the new list but not in the original list. +// This is useful when dynamically adding nodes during tests and needing to identify +// which client was just added. +func FindNewClient(original, updated []TailscaleClient) (TailscaleClient, error) { + for _, client := range updated { + isOriginal := false + for _, origClient := range original { + if client.Hostname() == origClient.Hostname() { + isOriginal = true + break + } + } + if !isOriginal { + return client, nil + } + } + return nil, fmt.Errorf("no new client found") +} + +// AddAndLoginClient adds a new tailscale client to a user and logs it in. +// This combines the common pattern of: +// 1. Creating a new node +// 2. Finding the new node in the client list +// 3. Getting the user to create a preauth key +// 4. Logging in the new node +func (s *Scenario) AddAndLoginClient( + t *testing.T, + username string, + version string, + headscale ControlServer, + tsOpts ...tsic.Option, +) (TailscaleClient, error) { + t.Helper() + + // Get the original client list + originalClients, err := s.ListTailscaleClients(username) + if err != nil { + return nil, fmt.Errorf("failed to list original clients: %w", err) + } + + // Create the new node + err = s.CreateTailscaleNodesInUser(username, version, 1, tsOpts...) 
+ if err != nil { + return nil, fmt.Errorf("failed to create tailscale node: %w", err) + } + + // Wait for the new node to appear in the client list + var newClient TailscaleClient + + _, err = backoff.Retry(t.Context(), func() (struct{}, error) { + updatedClients, err := s.ListTailscaleClients(username) + if err != nil { + return struct{}{}, fmt.Errorf("failed to list updated clients: %w", err) + } + + if len(updatedClients) != len(originalClients)+1 { + return struct{}{}, fmt.Errorf("expected %d clients, got %d", len(originalClients)+1, len(updatedClients)) + } + + newClient, err = FindNewClient(originalClients, updatedClients) + if err != nil { + return struct{}{}, fmt.Errorf("failed to find new client: %w", err) + } + + return struct{}{}, nil + }, backoff.WithBackOff(backoff.NewConstantBackOff(500*time.Millisecond)), backoff.WithMaxElapsedTime(10*time.Second)) + if err != nil { + return nil, fmt.Errorf("timeout waiting for new client: %w", err) + } + + // Get the user and create preauth key + user, err := GetUserByName(headscale, username) + if err != nil { + return nil, fmt.Errorf("failed to get user: %w", err) + } + + authKey, err := s.CreatePreAuthKey(user.GetId(), true, false) + if err != nil { + return nil, fmt.Errorf("failed to create preauth key: %w", err) + } + + // Login the new client + err = newClient.Login(headscale.GetEndpoint(), authKey.GetKey()) + if err != nil { + return nil, fmt.Errorf("failed to login new client: %w", err) + } + + return newClient, nil +} + +// MustAddAndLoginClient is like AddAndLoginClient but fails the test on error. +func (s *Scenario) MustAddAndLoginClient( + t *testing.T, + username string, + version string, + headscale ControlServer, + tsOpts ...tsic.Option, +) TailscaleClient { + t.Helper() + + client, err := s.AddAndLoginClient(t, username, version, headscale, tsOpts...) + require.NoError(t, err) + return client +} diff --git a/integration/hsic/config.go b/integration/hsic/config.go index 00c1770c..8ceca90f 100644 --- a/integration/hsic/config.go +++ b/integration/hsic/config.go @@ -1,102 +1,6 @@ package hsic -// const ( -// defaultEphemeralNodeInactivityTimeout = time.Second * 30 -// defaultNodeUpdateCheckInterval = time.Second * 10 -// ) - -// TODO(kradalby): This approach doesnt work because we cannot -// serialise our config object to YAML or JSON. 
-// func DefaultConfig() headscale.Config { -// derpMap, _ := url.Parse("https://controlplane.tailscale.com/derpmap/default") -// -// config := headscale.Config{ -// Log: headscale.LogConfig{ -// Level: zerolog.TraceLevel, -// }, -// ACL: headscale.GetACLConfig(), -// DBtype: "sqlite3", -// EphemeralNodeInactivityTimeout: defaultEphemeralNodeInactivityTimeout, -// NodeUpdateCheckInterval: defaultNodeUpdateCheckInterval, -// IPPrefixes: []netip.Prefix{ -// netip.MustParsePrefix("fd7a:115c:a1e0::/48"), -// netip.MustParsePrefix("100.64.0.0/10"), -// }, -// DNSConfig: &tailcfg.DNSConfig{ -// Proxied: true, -// Nameservers: []netip.Addr{ -// netip.MustParseAddr("127.0.0.11"), -// netip.MustParseAddr("1.1.1.1"), -// }, -// Resolvers: []*dnstype.Resolver{ -// { -// Addr: "127.0.0.11", -// }, -// { -// Addr: "1.1.1.1", -// }, -// }, -// }, -// BaseDomain: "headscale.net", -// -// DBpath: "/tmp/integration_test_db.sqlite3", -// -// PrivateKeyPath: "/tmp/integration_private.key", -// NoisePrivateKeyPath: "/tmp/noise_integration_private.key", -// Addr: "0.0.0.0:8080", -// MetricsAddr: "127.0.0.1:9090", -// ServerURL: "http://headscale:8080", -// -// DERP: headscale.DERPConfig{ -// URLs: []url.URL{ -// *derpMap, -// }, -// AutoUpdate: false, -// UpdateFrequency: 1 * time.Minute, -// }, -// } -// -// return config -// } - -// TODO: Reuse the actual configuration object above. -// Deprecated: use env function instead as it is easier to -// override. -func DefaultConfigYAML() string { - yaml := ` -log: - level: trace -acl_policy_path: "" -db_type: sqlite3 -db_path: /tmp/integration_test_db.sqlite3 -ephemeral_node_inactivity_timeout: 30m -node_update_check_interval: 10s -ip_prefixes: - - fd7a:115c:a1e0::/48 - - 100.64.0.0/10 -dns_config: - base_domain: headscale.net - magic_dns: true - domains: [] - nameservers: - - 127.0.0.11 - - 1.1.1.1 -private_key_path: /tmp/private.key -noise: - private_key_path: /tmp/noise_private.key -listen_addr: 0.0.0.0:8080 -metrics_listen_addr: 127.0.0.1:9090 -server_url: http://headscale:8080 - -derp: - urls: - - https://controlplane.tailscale.com/derpmap/default - auto_update_enabled: false - update_frequency: 1m -` - - return yaml -} +import "github.com/juanfont/headscale/hscontrol/types" func MinimumConfigYAML() string { return ` @@ -109,23 +13,28 @@ noise: func DefaultConfigEnv() map[string]string { return map[string]string{ "HEADSCALE_LOG_LEVEL": "trace", - "HEADSCALE_ACL_POLICY_PATH": "", - "HEADSCALE_DB_TYPE": "sqlite3", - "HEADSCALE_DB_PATH": "/tmp/integration_test_db.sqlite3", + "HEADSCALE_POLICY_PATH": "", + "HEADSCALE_DATABASE_TYPE": "sqlite", + "HEADSCALE_DATABASE_SQLITE_PATH": "/tmp/integration_test_db.sqlite3", + "HEADSCALE_DATABASE_DEBUG": "0", + "HEADSCALE_DATABASE_GORM_SLOW_THRESHOLD": "1", "HEADSCALE_EPHEMERAL_NODE_INACTIVITY_TIMEOUT": "30m", - "HEADSCALE_NODE_UPDATE_CHECK_INTERVAL": "10s", - "HEADSCALE_IP_PREFIXES": "fd7a:115c:a1e0::/48 100.64.0.0/10", - "HEADSCALE_DNS_CONFIG_BASE_DOMAIN": "headscale.net", - "HEADSCALE_DNS_CONFIG_MAGIC_DNS": "true", - "HEADSCALE_DNS_CONFIG_DOMAINS": "", - "HEADSCALE_DNS_CONFIG_NAMESERVERS": "127.0.0.11 1.1.1.1", + "HEADSCALE_PREFIXES_V4": "100.64.0.0/10", + "HEADSCALE_PREFIXES_V6": "fd7a:115c:a1e0::/48", + "HEADSCALE_DNS_BASE_DOMAIN": "headscale.net", + "HEADSCALE_DNS_MAGIC_DNS": "true", + "HEADSCALE_DNS_OVERRIDE_LOCAL_DNS": "false", + "HEADSCALE_DNS_NAMESERVERS_GLOBAL": "127.0.0.11 1.1.1.1", "HEADSCALE_PRIVATE_KEY_PATH": "/tmp/private.key", "HEADSCALE_NOISE_PRIVATE_KEY_PATH": "/tmp/noise_private.key", - 
"HEADSCALE_LISTEN_ADDR": "0.0.0.0:8080", - "HEADSCALE_METRICS_LISTEN_ADDR": "127.0.0.1:9090", - "HEADSCALE_SERVER_URL": "http://headscale:8080", + "HEADSCALE_METRICS_LISTEN_ADDR": "0.0.0.0:9090", "HEADSCALE_DERP_URLS": "https://controlplane.tailscale.com/derpmap/default", "HEADSCALE_DERP_AUTO_UPDATE_ENABLED": "false", "HEADSCALE_DERP_UPDATE_FREQUENCY": "1m", + "HEADSCALE_DEBUG_PORT": "40000", + + // a bunch of tests (ACL/Policy) rely on predictable IP alloc, + // so ensure the sequential alloc is used by default. + "HEADSCALE_PREFIXES_ALLOCATION": string(types.IPAllocationStrategySequential), } } diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index 5019895a..42bb8e93 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -1,46 +1,59 @@ package hsic import ( + "archive/tar" "bytes" - "crypto/rand" - "crypto/rsa" + "cmp" "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" "encoding/json" - "encoding/pem" "errors" "fmt" + "io" "log" - "math/big" - "net" + "maps" "net/http" - "net/url" + "net/netip" "os" "path" + "path/filepath" + "sort" + "strconv" "strings" "time" "github.com/davecgh/go-spew/spew" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" - "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/hscontrol" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/routes" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/integrationutil" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" + "gopkg.in/yaml.v3" + "tailscale.com/tailcfg" + "tailscale.com/util/mak" ) const ( - hsicHashLength = 6 - dockerContextPath = "../." - aclPolicyPath = "/etc/headscale/acl.hujson" - tlsCertPath = "/etc/headscale/tls.cert" - tlsKeyPath = "/etc/headscale/tls.key" - headscaleDefaultPort = 8080 + hsicHashLength = 6 + dockerContextPath = "../." 
+	caCertRoot = "/usr/local/share/ca-certificates" + aclPolicyPath = "/etc/headscale/acl.hujson" + tlsCertPath = "/etc/headscale/tls.cert" + tlsKeyPath = "/etc/headscale/tls.key" + headscaleDefaultPort = 8080 + IntegrationTestDockerFileName = "Dockerfile.integration" ) -var errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok") +var ( + errHeadscaleStatusCodeNotOk = errors.New("headscale status code not ok") + errInvalidHeadscaleImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_HEADSCALE_IMAGE format, expected repository:tag") + errHeadscaleImageRequiredInCI = errors.New("HEADSCALE_INTEGRATION_HEADSCALE_IMAGE must be set in CI") + errInvalidPostgresImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_POSTGRES_IMAGE format, expected repository:tag") +) type fileInContainer struct { path string @@ -54,17 +67,23 @@ type HeadscaleInContainer struct { pool *dockertest.Pool container *dockertest.Resource - network *dockertest.Network + networks []*dockertest.Network + + pgContainer *dockertest.Resource // optional config port int extraPorts []string + hostMetricsPort string // Dynamically assigned host port for metrics/pprof access + caCerts [][]byte hostPortBindings map[string][]string - aclPolicy *policy.ACLPolicy + aclPolicy *policyv2.Policy env map[string]string tlsCert []byte tlsKey []byte filesInContainer []fileInContainer + postgres bool + policyMode types.PolicyMode } // Option represent optional settings that can be given to a @@ -73,27 +92,42 @@ type Option = func(c *HeadscaleInContainer) // WithACLPolicy adds a hscontrol.ACLPolicy policy to the // HeadscaleInContainer instance. -func WithACLPolicy(acl *policy.ACLPolicy) Option { +func WithACLPolicy(acl *policyv2.Policy) Option { return func(hsic *HeadscaleInContainer) { + if acl == nil { + return + } + // TODO(kradalby): Move somewhere appropriate - hsic.env["HEADSCALE_ACL_POLICY_PATH"] = aclPolicyPath + hsic.env["HEADSCALE_POLICY_PATH"] = aclPolicyPath hsic.aclPolicy = acl } } +// WithCACert adds the given certificate to the container's trusted certificates. +func WithCACert(cert []byte) Option { + return func(hsic *HeadscaleInContainer) { + hsic.caCerts = append(hsic.caCerts, cert) + } +} + // WithTLS creates certificates and enables HTTPS. func WithTLS() Option { return func(hsic *HeadscaleInContainer) { - cert, key, err := createCertificate(hsic.hostname) + cert, key, err := integrationutil.CreateCertificate(hsic.hostname) if err != nil { log.Fatalf("failed to create certificates for headscale test: %s", err) } - // TODO(kradalby): Move somewhere appropriate - hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath - hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath + hsic.tlsCert = cert + hsic.tlsKey = key + } +} +// WithCustomTLS uses the given certificates for the Headscale instance. +func WithCustomTLS(cert, key []byte) Option { + return func(hsic *HeadscaleInContainer) { hsic.tlsCert = cert hsic.tlsKey = key } @@ -103,9 +137,7 @@ func WithTLS() Option { // can be used to override Headscale configuration. func WithConfigEnv(configEnv map[string]string) Option { return func(hsic *HeadscaleInContainer) { - for key, value := range configEnv { - hsic.env[key] = value - } + maps.Copy(hsic.env, configEnv) } } @@ -140,14 +172,10 @@ func WithTestName(testName string) Option { } } -// WithHostnameAsServerURL sets the Headscale ServerURL based on -// the Hostname. -func WithHostnameAsServerURL() Option { +// WithHostname sets the hostname of the Headscale instance.
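+//
+// Example (illustrative sketch; the hostname value is hypothetical):
+//
+//	hs, err := hsic.New(pool, networks,
+//		hsic.WithHostname("hs-example"),
+//		hsic.WithTLS(),
+//	)
+//	if err != nil {
+//		t.Fatalf("starting headscale container: %s", err)
+//	}
+//	defer hs.Shutdown()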
+func WithHostname(hostname string) Option { return func(hsic *HeadscaleInContainer) { - hsic.env["HEADSCALE_SERVER_URL"] = fmt.Sprintf("http://%s", - net.JoinHostPort(hsic.GetHostname(), - fmt.Sprintf("%d", hsic.port)), - ) + hsic.hostname = hostname } } @@ -162,10 +190,140 @@ func WithFileInContainer(path string, contents []byte) Option { } } +// WithPostgres spins up a Postgres container and +// sets it as the main database. +func WithPostgres() Option { + return func(hsic *HeadscaleInContainer) { + hsic.postgres = true + } +} + +// WithPolicyMode sets the policy mode for headscale. +func WithPolicyMode(mode types.PolicyMode) Option { + return func(hsic *HeadscaleInContainer) { + hsic.policyMode = mode + hsic.env["HEADSCALE_POLICY_MODE"] = string(mode) + } +} + +// WithIPAllocationStrategy sets the test's IP allocation strategy. +func WithIPAllocationStrategy(strategy types.IPAllocationStrategy) Option { + return func(hsic *HeadscaleInContainer) { + hsic.env["HEADSCALE_PREFIXES_ALLOCATION"] = string(strategy) + } +} + +// WithEmbeddedDERPServerOnly configures Headscale to start +// and only use the embedded DERP server. +// It requires WithTLS to be +// set. +func WithEmbeddedDERPServerOnly() Option { + return func(hsic *HeadscaleInContainer) { + hsic.env["HEADSCALE_DERP_URLS"] = "" + hsic.env["HEADSCALE_DERP_SERVER_ENABLED"] = "true" + hsic.env["HEADSCALE_DERP_SERVER_REGION_ID"] = "999" + hsic.env["HEADSCALE_DERP_SERVER_REGION_CODE"] = "headscale" + hsic.env["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP" + hsic.env["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478" + hsic.env["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key" + + // Envknob for enabling DERP debug logs + hsic.env["DERP_DEBUG_LOGS"] = "true" + hsic.env["DERP_PROBER_DEBUG_LOGS"] = "true" + } +} + +// WithDERPConfig configures Headscale to use a custom +// DERP server only. +func WithDERPConfig(derpMap tailcfg.DERPMap) Option { + return func(hsic *HeadscaleInContainer) { + contents, err := yaml.Marshal(derpMap) + if err != nil { + log.Fatalf("failed to marshal DERP map: %s", err) + + return + } + + hsic.env["HEADSCALE_DERP_PATHS"] = "/etc/headscale/derp.yml" + hsic.filesInContainer = append(hsic.filesInContainer, + fileInContainer{ + path: "/etc/headscale/derp.yml", + contents: contents, + }) + + // Disable global DERP server and embedded DERP server + hsic.env["HEADSCALE_DERP_URLS"] = "" + hsic.env["HEADSCALE_DERP_SERVER_ENABLED"] = "false" + + // Envknob for enabling DERP debug logs + hsic.env["DERP_DEBUG_LOGS"] = "true" + hsic.env["DERP_PROBER_DEBUG_LOGS"] = "true" + } +} + +// WithTuning allows changing the tuning settings easily. +func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option { + return func(hsic *HeadscaleInContainer) { + hsic.env["HEADSCALE_TUNING_BATCH_CHANGE_DELAY"] = batchTimeout.String() + hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa( + mapSessionChanSize, + ) + } +} + +func WithTimezone(timezone string) Option { + return func(hsic *HeadscaleInContainer) { + hsic.env["TZ"] = timezone + } +} + +// WithDERPAsIP enables using IP address instead of hostname for DERP server. +// This is useful for integration tests where DNS resolution may be unreliable. +func WithDERPAsIP() Option { + return func(hsic *HeadscaleInContainer) { + hsic.env["HEADSCALE_DEBUG_DERP_USE_IP"] = "1" + } +} + +// buildEntrypoint builds the container entrypoint command based on configuration.
+// It constructs proper wait conditions instead of fixed sleeps: +// 1. Wait for network to be ready +// 2. Wait for config.yaml (always written after container start) +// 3. Wait for CA certs if configured +// 4. Update CA certificates +// 5. Run headscale serve +// 6. Sleep at end to keep container alive for log collection on shutdown. +func (hsic *HeadscaleInContainer) buildEntrypoint() []string { + var commands []string + + // Wait for network to be ready + commands = append(commands, "while ! ip route show default >/dev/null 2>&1; do sleep 0.1; done") + + // Wait for config.yaml to be written (always written after container start) + commands = append(commands, "while [ ! -f /etc/headscale/config.yaml ]; do sleep 0.1; done") + + // If CA certs are configured, wait for them to be written + if len(hsic.caCerts) > 0 { + commands = append(commands, + fmt.Sprintf("while [ ! -f %s/user-0.crt ]; do sleep 0.1; done", caCertRoot)) + } + + // Update CA certificates + commands = append(commands, "update-ca-certificates") + + // Run headscale serve + commands = append(commands, "/usr/local/bin/headscale serve") + + // Keep container alive after headscale exits for log collection + commands = append(commands, "/bin/sleep 30") + + return []string{"/bin/bash", "-c", strings.Join(commands, " ; ")} +} + // New returns a new HeadscaleInContainer instance. func New( pool *dockertest.Pool, - network *dockertest.Network, + networks []*dockertest.Network, opts ...Option, ) (*HeadscaleInContainer, error) { hash, err := util.GenerateRandomStringDNSSafe(hsicHashLength) @@ -173,17 +331,29 @@ func New( return nil, err } - hostname := fmt.Sprintf("hs-%s", hash) + // Include run ID in hostname for easier identification of which test run owns this container + runID := dockertestutil.GetIntegrationRunID() + + var hostname string + + if runID != "" { + // Use last 6 chars of run ID (the random hash part) for brevity + runIDShort := runID[len(runID)-6:] + hostname = fmt.Sprintf("hs-%s-%s", runIDShort, hash) + } else { + hostname = "hs-" + hash + } hsic := &HeadscaleInContainer{ hostname: hostname, port: headscaleDefaultPort, - pool: pool, - network: network, + pool: pool, + networks: networks, env: DefaultConfigEnv(), filesInContainer: []fileInContainer{}, + policyMode: types.PolicyModeFile, } for _, opt := range opts { @@ -194,26 +364,75 @@ func New( portProto := fmt.Sprintf("%d/tcp", hsic.port) - serverURL, err := url.Parse(hsic.env["HEADSCALE_SERVER_URL"]) - if err != nil { - return nil, err - } - - if len(hsic.tlsCert) != 0 && len(hsic.tlsKey) != 0 { - serverURL.Scheme = "https" - hsic.env["HEADSCALE_SERVER_URL"] = serverURL.String() - } - headscaleBuildOptions := &dockertest.BuildOptions{ - Dockerfile: "Dockerfile.debug", + Dockerfile: IntegrationTestDockerFileName, ContextDir: dockerContextPath, } - env := []string{ - "HEADSCALE_PROFILING_ENABLED=1", - "HEADSCALE_PROFILING_PATH=/tmp/profile", - "HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH=/tmp/mapresponses", + if hsic.postgres { + hsic.env["HEADSCALE_DATABASE_TYPE"] = "postgres" + hsic.env["HEADSCALE_DATABASE_POSTGRES_HOST"] = "postgres-" + hash + hsic.env["HEADSCALE_DATABASE_POSTGRES_USER"] = "headscale" + hsic.env["HEADSCALE_DATABASE_POSTGRES_PASS"] = "headscale" + hsic.env["HEADSCALE_DATABASE_POSTGRES_NAME"] = "headscale" + delete(hsic.env, "HEADSCALE_DATABASE_SQLITE_PATH") + + // Determine postgres image - use prebuilt if available, otherwise pull from registry + pgRepo := "postgres" + pgTag := "latest" + + if prebuiltImage := 
os.Getenv("HEADSCALE_INTEGRATION_POSTGRES_IMAGE"); prebuiltImage != "" { + repo, tag, found := strings.Cut(prebuiltImage, ":") + if !found { + return nil, errInvalidPostgresImageFormat + } + + pgRepo = repo + pgTag = tag + } + + pgRunOptions := &dockertest.RunOptions{ + Name: "postgres-" + hash, + Repository: pgRepo, + Tag: pgTag, + Networks: networks, + Env: []string{ + "POSTGRES_USER=headscale", + "POSTGRES_PASSWORD=headscale", + "POSTGRES_DB=headscale", + }, + } + + // Add integration test labels if running under hi tool + dockertestutil.DockerAddIntegrationLabels(pgRunOptions, "postgres") + + pg, err := pool.RunWithOptions(pgRunOptions) + if err != nil { + return nil, fmt.Errorf("starting postgres container: %w", err) + } + + hsic.pgContainer = pg } + + env := []string{ + "HEADSCALE_DEBUG_PROFILING_ENABLED=1", + "HEADSCALE_DEBUG_PROFILING_PATH=/tmp/profile", + "HEADSCALE_DEBUG_DUMP_MAPRESPONSE_PATH=/tmp/mapresponses", + "HEADSCALE_DEBUG_DEADLOCK=1", + "HEADSCALE_DEBUG_DEADLOCK_TIMEOUT=5s", + "HEADSCALE_DEBUG_HIGH_CARDINALITY_METRICS=1", + "HEADSCALE_DEBUG_DUMP_CONFIG=1", + } + if hsic.hasTLS() { + hsic.env["HEADSCALE_TLS_CERT_PATH"] = tlsCertPath + hsic.env["HEADSCALE_TLS_KEY_PATH"] = tlsKeyPath + } + + // Server URL and Listen Addr should not be overridable outside of + // the configuration passed to docker. + hsic.env["HEADSCALE_SERVER_URL"] = hsic.GetEndpoint() + hsic.env["HEADSCALE_LISTEN_ADDR"] = fmt.Sprintf("0.0.0.0:%d", hsic.port) + for key, value := range hsic.env { env = append(env, fmt.Sprintf("%s=%s", key, value)) } @@ -222,17 +441,25 @@ func New( runOptions := &dockertest.RunOptions{ Name: hsic.hostname, - ExposedPorts: append([]string{portProto}, hsic.extraPorts...), - Networks: []*dockertest.Network{network}, + ExposedPorts: append([]string{portProto, "9090/tcp"}, hsic.extraPorts...), + Networks: networks, // Cmd: []string{"headscale", "serve"}, // TODO(kradalby): Get rid of this hack, we currently need to give us some // to inject the headscale configuration further down. - Entrypoint: []string{"/bin/bash", "-c", "/bin/sleep 3 ; headscale serve ; /bin/sleep 30"}, + Entrypoint: hsic.buildEntrypoint(), Env: env, } - if len(hsic.hostPortBindings) > 0 { + // Bind metrics port to dynamic host port (kernel assigns free port) + if runOptions.PortBindings == nil { runOptions.PortBindings = map[docker.Port][]docker.PortBinding{} + } + + runOptions.PortBindings["9090/tcp"] = []docker.PortBinding{ + {HostPort: "0"}, // Let kernel assign a free port + } + + if len(hsic.hostPortBindings) > 0 { for port, hostPorts := range hsic.hostPortBindings { runOptions.PortBindings[docker.Port(port)] = []docker.PortBinding{} for _, hostPort := range hostPorts { @@ -243,42 +470,115 @@ func New( } } - // dockertest isnt very good at handling containers that has already - // been created, this is an attempt to make sure this container isnt + // dockertest isn't very good at handling containers that has already + // been created, this is an attempt to make sure this container isn't // present. 
err = pool.RemoveContainerByName(hsic.hostname) if err != nil { return nil, err } - container, err := pool.BuildAndRunWithBuildOptions( - headscaleBuildOptions, - runOptions, - dockertestutil.DockerRestartPolicy, - dockertestutil.DockerAllowLocalIPv6, - dockertestutil.DockerAllowNetworkAdministration, - ) - if err != nil { - return nil, fmt.Errorf("could not start headscale container: %w", err) + // Add integration test labels if running under hi tool + dockertestutil.DockerAddIntegrationLabels(runOptions, "headscale") + + var container *dockertest.Resource + + // Check if a pre-built image is available via environment variable + prebuiltImage := os.Getenv("HEADSCALE_INTEGRATION_HEADSCALE_IMAGE") + + if prebuiltImage != "" { + log.Printf("Using pre-built headscale image: %s", prebuiltImage) + // Parse image into repository and tag + repo, tag, ok := strings.Cut(prebuiltImage, ":") + if !ok { + return nil, errInvalidHeadscaleImageFormat + } + + runOptions.Repository = repo + runOptions.Tag = tag + + container, err = pool.RunWithOptions( + runOptions, + dockertestutil.DockerRestartPolicy, + dockertestutil.DockerAllowLocalIPv6, + dockertestutil.DockerAllowNetworkAdministration, + ) + if err != nil { + return nil, fmt.Errorf("could not run pre-built headscale container %q: %w", prebuiltImage, err) + } + } else if util.IsCI() { + return nil, errHeadscaleImageRequiredInCI + } else { + container, err = pool.BuildAndRunWithBuildOptions( + headscaleBuildOptions, + runOptions, + dockertestutil.DockerRestartPolicy, + dockertestutil.DockerAllowLocalIPv6, + dockertestutil.DockerAllowNetworkAdministration, + ) + if err != nil { + // Try to get more detailed build output + log.Printf("Docker build/run failed, attempting to get detailed output...") + + buildOutput, buildErr := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, IntegrationTestDockerFileName) + + // Show the last 100 lines of build output to avoid overwhelming the logs + lines := strings.Split(buildOutput, "\n") + + const maxLines = 100 + + startLine := 0 + if len(lines) > maxLines { + startLine = len(lines) - maxLines + } + + relevantOutput := strings.Join(lines[startLine:], "\n") + + if buildErr != nil { + // The diagnostic build also failed - this is the real error + return nil, fmt.Errorf("could not start headscale container: %w\n\nDocker build failed. Last %d lines of output:\n%s", err, maxLines, relevantOutput) + } + + if buildOutput != "" { + // Build succeeded on retry but container creation still failed + return nil, fmt.Errorf("could not start headscale container: %w\n\nDocker build succeeded on retry, but container creation failed. 
Last %d lines of build output:\n%s", err, maxLines, relevantOutput) + } + + // No output at all - diagnostic build command may have failed + return nil, fmt.Errorf("could not start headscale container: %w\n\nUnable to get diagnostic build output (command may have failed silently)", err) + } } log.Printf("Created %s container\n", hsic.hostname) hsic.container = container + // Get the dynamically assigned host port for metrics/pprof + hsic.hostMetricsPort = container.GetHostPort("9090/tcp") + + log.Printf( + "Headscale %s metrics available at http://localhost:%s/metrics (debug at http://localhost:%s/debug/)\n", + hsic.hostname, + hsic.hostMetricsPort, + hsic.hostMetricsPort, + ) + + // Write the CA certificates to the container + for i, cert := range hsic.caCerts { + err = hsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert) + if err != nil { + return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) + } + } + err = hsic.WriteFile("/etc/headscale/config.yaml", []byte(MinimumConfigYAML())) if err != nil { return nil, fmt.Errorf("failed to write headscale config to container: %w", err) } if hsic.aclPolicy != nil { - data, err := json.Marshal(hsic.aclPolicy) + err = hsic.writePolicy(hsic.aclPolicy) if err != nil { - return nil, fmt.Errorf("failed to marshal ACL Policy to JSON: %w", err) - } - - err = hsic.WriteFile(aclPolicyPath, data) - if err != nil { - return nil, fmt.Errorf("failed to write ACL policy to container: %w", err) + return nil, fmt.Errorf("writing policy: %w", err) } } @@ -300,6 +600,15 @@ func New( } } + // Load the database from policy file on repeat until it succeeds, + // this is done as the container sleeps before starting headscale. + if hsic.aclPolicy != nil && hsic.policyMode == types.PolicyModeDB { + err := pool.Retry(hsic.reloadDatabasePolicy) + if err != nil { + return nil, fmt.Errorf("loading database policy on startup: %w", err) + } + } + return hsic, nil } @@ -312,8 +621,8 @@ func (t *HeadscaleInContainer) hasTLS() bool { } // Shutdown stops and cleans up the Headscale container. -func (t *HeadscaleInContainer) Shutdown() error { - err := t.SaveLog("/tmp/control") +func (t *HeadscaleInContainer) Shutdown() (string, string, error) { + stdoutPath, stderrPath, err := t.SaveLog("/tmp/control") if err != nil { log.Printf( "Failed to save log from control: %s", @@ -321,6 +630,14 @@ func (t *HeadscaleInContainer) Shutdown() error { ) } + err = t.SaveMetrics(fmt.Sprintf("/tmp/control/%s_metrics.txt", t.hostname)) + if err != nil { + log.Printf( + "Failed to metrics from control: %s", + err, + ) + } + // Send a interrupt signal to the "headscale" process inside the container // allowing it to shut down gracefully and flush the profile to disk. // The container will live for a bit longer due to the sleep at the end. @@ -348,39 +665,156 @@ func (t *HeadscaleInContainer) Shutdown() error { ) } - err = t.SaveDatabase("/tmp/control") - if err != nil { - log.Printf( - "Failed to save database from control: %s", - fmt.Errorf("failed to save database from control: %w", err), - ) + // We dont have a database to save if we use postgres + if !t.postgres { + err = t.SaveDatabase("/tmp/control") + if err != nil { + log.Printf( + "Failed to save database from control: %s", + fmt.Errorf("failed to save database from control: %w", err), + ) + } } - return t.pool.Purge(t.container) + // Cleanup postgres container if enabled. 
+ if t.postgres { + t.pool.Purge(t.pgContainer) + } + + return stdoutPath, stderrPath, t.pool.Purge(t.container) +} + +// WriteLogs writes the current stdout/stderr log of the container to +// the given io.Writers. +func (t *HeadscaleInContainer) WriteLogs(stdout, stderr io.Writer) error { + return dockertestutil.WriteLog(t.pool, t.container, stdout, stderr) } // SaveLog saves the current stdout log of the container to a path // on the host system. -func (t *HeadscaleInContainer) SaveLog(path string) error { +func (t *HeadscaleInContainer) SaveLog(path string) (string, string, error) { return dockertestutil.SaveLog(t.pool, t.container, path) } +func (t *HeadscaleInContainer) SaveMetrics(savePath string) error { + resp, err := http.Get(fmt.Sprintf("http://%s:9090/metrics", t.hostname)) + if err != nil { + return fmt.Errorf("getting metrics: %w", err) + } + defer resp.Body.Close() + out, err := os.Create(savePath) + if err != nil { + return fmt.Errorf("creating file for metrics: %w", err) + } + defer out.Close() + _, err = io.Copy(out, resp.Body) + if err != nil { + return fmt.Errorf("copy response to file: %w", err) + } + + return nil +} + +// extractTarToDirectory extracts a tar archive to a directory. +func extractTarToDirectory(tarData []byte, targetDir string) error { + if err := os.MkdirAll(targetDir, 0o755); err != nil { + return fmt.Errorf("failed to create directory %s: %w", targetDir, err) + } + + tarReader := tar.NewReader(bytes.NewReader(tarData)) + + // Find the top-level directory to strip + var topLevelDir string + firstPass := tar.NewReader(bytes.NewReader(tarData)) + for { + header, err := firstPass.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar header: %w", err) + } + + if header.Typeflag == tar.TypeDir && topLevelDir == "" { + topLevelDir = strings.TrimSuffix(header.Name, "/") + break + } + } + + tarReader = tar.NewReader(bytes.NewReader(tarData)) + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar header: %w", err) + } + + // Clean the path to prevent directory traversal + cleanName := filepath.Clean(header.Name) + if strings.Contains(cleanName, "..") { + continue // Skip potentially dangerous paths + } + + // Strip the top-level directory + if topLevelDir != "" && strings.HasPrefix(cleanName, topLevelDir+"/") { + cleanName = strings.TrimPrefix(cleanName, topLevelDir+"/") + } else if cleanName == topLevelDir { + // Skip the top-level directory itself + continue + } + + // Skip empty paths after stripping + if cleanName == "" { + continue + } + + targetPath := filepath.Join(targetDir, cleanName) + + switch header.Typeflag { + case tar.TypeDir: + // Create directory + if err := os.MkdirAll(targetPath, os.FileMode(header.Mode)); err != nil { + return fmt.Errorf("failed to create directory %s: %w", targetPath, err) + } + case tar.TypeReg: + // Ensure parent directories exist + if err := os.MkdirAll(filepath.Dir(targetPath), 0o755); err != nil { + return fmt.Errorf("failed to create parent directories for %s: %w", targetPath, err) + } + + // Create file + outFile, err := os.Create(targetPath) + if err != nil { + return fmt.Errorf("failed to create file %s: %w", targetPath, err) + } + + if _, err := io.Copy(outFile, tarReader); err != nil { + outFile.Close() + return fmt.Errorf("failed to copy file contents: %w", err) + } + outFile.Close() + + // Set file permissions + if err := os.Chmod(targetPath, os.FileMode(header.Mode)); err != nil { + return 
fmt.Errorf("failed to set file permissions: %w", err) + } + } + } + + return nil +} + func (t *HeadscaleInContainer) SaveProfile(savePath string) error { tarFile, err := t.FetchPath("/tmp/profile") if err != nil { return err } - err = os.WriteFile( - path.Join(savePath, t.hostname+".pprof.tar"), - tarFile, - os.ModePerm, - ) - if err != nil { - return err - } + targetDir := path.Join(savePath, "pprof") - return nil + return extractTarToDirectory(tarFile, targetDir) } func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error { @@ -389,34 +823,101 @@ func (t *HeadscaleInContainer) SaveMapResponses(savePath string) error { return err } - err = os.WriteFile( - path.Join(savePath, t.hostname+".maps.tar"), - tarFile, - os.ModePerm, - ) - if err != nil { - return err - } + targetDir := path.Join(savePath, "mapresponses") - return nil + return extractTarToDirectory(tarFile, targetDir) } func (t *HeadscaleInContainer) SaveDatabase(savePath string) error { + // If using PostgreSQL, skip database file extraction + if t.postgres { + return nil + } + + // Also check for any .sqlite files + sqliteFiles, err := t.Execute([]string{"find", "/tmp", "-name", "*.sqlite*", "-type", "f"}) + if err != nil { + log.Printf("Warning: could not find sqlite files: %v", err) + } else { + log.Printf("SQLite files found in %s:\n%s", t.hostname, sqliteFiles) + } + + // Check if the database file exists and has a schema + dbPath := "/tmp/integration_test_db.sqlite3" + fileInfo, err := t.Execute([]string{"ls", "-la", dbPath}) + if err != nil { + return fmt.Errorf("database file does not exist at %s: %w", dbPath, err) + } + log.Printf("Database file info: %s", fileInfo) + + // Check if the database has any tables (schema) + schemaCheck, err := t.Execute([]string{"sqlite3", dbPath, ".schema"}) + if err != nil { + return fmt.Errorf("failed to check database schema (sqlite3 command failed): %w", err) + } + + if strings.TrimSpace(schemaCheck) == "" { + return errors.New("database file exists but has no schema (empty database)") + } + tarFile, err := t.FetchPath("/tmp/integration_test_db.sqlite3") if err != nil { - return err + return fmt.Errorf("failed to fetch database file: %w", err) } - err = os.WriteFile( - path.Join(savePath, t.hostname+".db.tar"), - tarFile, - os.ModePerm, - ) - if err != nil { - return err + // For database, extract the first regular file (should be the SQLite file) + tarReader := tar.NewReader(bytes.NewReader(tarFile)) + for { + header, err := tarReader.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar header: %w", err) + } + + log.Printf( + "Found file in tar: %s (type: %d, size: %d)", + header.Name, + header.Typeflag, + header.Size, + ) + + // Extract the first regular file we find + if header.Typeflag == tar.TypeReg { + dbPath := path.Join(savePath, t.hostname+".db") + outFile, err := os.Create(dbPath) + if err != nil { + return fmt.Errorf("failed to create database file: %w", err) + } + + written, err := io.Copy(outFile, tarReader) + outFile.Close() + if err != nil { + return fmt.Errorf("failed to copy database file: %w", err) + } + + log.Printf( + "Extracted database file: %s (%d bytes written, header claimed %d bytes)", + dbPath, + written, + header.Size, + ) + + // Check if we actually wrote something + if written == 0 { + return fmt.Errorf( + "database file is empty (size: %d, header size: %d)", + written, + header.Size, + ) + } + + return nil + } } - return nil + return errors.New("no regular file found in database tar archive") } // 
Execute runs a command inside the Headscale container and returns the @@ -430,45 +931,64 @@ func (t *HeadscaleInContainer) Execute( []string{}, ) if err != nil { + log.Printf("command: %v", command) log.Printf("command stderr: %s\n", stderr) if stdout != "" { log.Printf("command stdout: %s\n", stdout) } - return "", err + return stdout, fmt.Errorf("executing command in docker: %w, stderr: %s", err, stderr) } return stdout, nil } -// GetIP returns the docker container IP as a string. -func (t *HeadscaleInContainer) GetIP() string { - return t.container.GetIPInNetwork(t.network) -} - // GetPort returns the docker container port as a string. func (t *HeadscaleInContainer) GetPort() string { - return fmt.Sprintf("%d", t.port) + return strconv.Itoa(t.port) +} + +// GetHostMetricsPort returns the dynamically assigned host port for metrics/pprof access. +// This port can be used by operators to access metrics at http://localhost:{port}/metrics +// and debug endpoints at http://localhost:{port}/debug/ while tests are running. +func (t *HeadscaleInContainer) GetHostMetricsPort() string { + return t.hostMetricsPort } // GetHealthEndpoint returns a health endpoint for the HeadscaleInContainer // instance. func (t *HeadscaleInContainer) GetHealthEndpoint() string { - return fmt.Sprintf("%s/health", t.GetEndpoint()) + return t.GetEndpoint() + "/health" } // GetEndpoint returns the Headscale endpoint for the HeadscaleInContainer. func (t *HeadscaleInContainer) GetEndpoint() string { - hostEndpoint := fmt.Sprintf("%s:%d", - t.GetIP(), - t.port) + return t.getEndpoint(false) +} - if t.hasTLS() { - return fmt.Sprintf("https://%s", hostEndpoint) +// GetIPEndpoint returns the Headscale endpoint using IP address instead of hostname. +func (t *HeadscaleInContainer) GetIPEndpoint() string { + return t.getEndpoint(true) +} + +// getEndpoint returns the Headscale endpoint, optionally using IP address instead of hostname. +func (t *HeadscaleInContainer) getEndpoint(useIP bool) string { + var host string + if useIP && len(t.networks) > 0 { + // Use IP address from the first network + host = t.GetIPInNetwork(t.networks[0]) + } else { + host = t.GetHostname() } - return fmt.Sprintf("http://%s", hostEndpoint) + hostEndpoint := fmt.Sprintf("%s:%d", host, t.port) + + if t.hasTLS() { + return "https://" + hostEndpoint + } + + return "http://" + hostEndpoint } // GetCert returns the public certificate of the HeadscaleInContainer. @@ -481,6 +1001,11 @@ func (t *HeadscaleInContainer) GetHostname() string { return t.hostname } +// GetIPInNetwork returns the IP address of the HeadscaleInContainer in the given network. +func (t *HeadscaleInContainer) GetIPInNetwork(network *dockertest.Network) string { + return t.container.GetIPInNetwork(network) +} + // WaitForRunning blocks until the Headscale instance is ready to // serve clients. func (t *HeadscaleInContainer) WaitForRunning() error { @@ -513,48 +1038,81 @@ func (t *HeadscaleInContainer) WaitForRunning() error { // CreateUser adds a new user to the Headscale instance. 
func (t *HeadscaleInContainer) CreateUser( user string, -) error { - command := []string{"headscale", "users", "create", user} +) (*v1.User, error) { + command := []string{ + "headscale", + "users", + "create", + user, + fmt.Sprintf("--email=%s@test.no", user), + "--output", + "json", + } - _, _, err := dockertestutil.ExecuteCommand( + result, _, err := dockertestutil.ExecuteCommand( t.container, command, []string{}, ) if err != nil { - return err + return nil, err } - return nil + var u v1.User + err = json.Unmarshal([]byte(result), &u) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal user: %w", err) + } + + return &u, nil } -// CreateAuthKey creates a new "authorisation key" for a User that can be used -// to authorise a TailscaleClient with the Headscale instance. -func (t *HeadscaleInContainer) CreateAuthKey( - user string, - reusable bool, - ephemeral bool, -) (*v1.PreAuthKey, error) { +// AuthKeyOptions defines options for creating an auth key. +type AuthKeyOptions struct { + // User is the user ID that owns the auth key. If nil and Tags are specified, + // the auth key is owned by the tags only (tags-as-identity model). + User *uint64 + // Reusable indicates if the key can be used multiple times + Reusable bool + // Ephemeral indicates if nodes registered with this key should be ephemeral + Ephemeral bool + // Tags are the tags to assign to the auth key + Tags []string +} + +// CreateAuthKeyWithOptions creates a new "authorisation key" with the specified options. +// This supports both user-owned and tags-only auth keys. +func (t *HeadscaleInContainer) CreateAuthKeyWithOptions(opts AuthKeyOptions) (*v1.PreAuthKey, error) { command := []string{ "headscale", - "--user", - user, + } + + // Only add --user flag if User is specified + if opts.User != nil { + command = append(command, "--user", strconv.FormatUint(*opts.User, 10)) + } + + command = append(command, "preauthkeys", "create", "--expiration", "24h", "--output", "json", - } + ) - if reusable { + if opts.Reusable { command = append(command, "--reusable") } - if ephemeral { + if opts.Ephemeral { command = append(command, "--ephemeral") } + if len(opts.Tags) > 0 { + command = append(command, "--tags", strings.Join(opts.Tags, ",")) + } + result, _, err := dockertestutil.ExecuteCommand( t.container, command, @@ -565,6 +1123,7 @@ func (t *HeadscaleInContainer) CreateAuthKey( } var preAuthKey v1.PreAuthKey + err = json.Unmarshal([]byte(result), &preAuthKey) if err != nil { return nil, fmt.Errorf("failed to unmarshal auth key: %w", err) @@ -573,12 +1132,172 @@ func (t *HeadscaleInContainer) CreateAuthKey( return &preAuthKey, nil } -// ListNodesInUser list the TailscaleClients (Node, Headscale internal representation) -// associated with a user. -func (t *HeadscaleInContainer) ListNodesInUser( - user string, +// CreateAuthKey creates a new "authorisation key" for a User that can be used +// to authorise a TailscaleClient with the Headscale instance. +func (t *HeadscaleInContainer) CreateAuthKey( + user uint64, + reusable bool, + ephemeral bool, +) (*v1.PreAuthKey, error) { + return t.CreateAuthKeyWithOptions(AuthKeyOptions{ + User: &user, + Reusable: reusable, + Ephemeral: ephemeral, + }) +} + +// CreateAuthKeyWithTags creates a new "authorisation key" for a User with the specified tags. +// This is used to create tagged PreAuthKeys for testing the tags-as-identity model. 
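+//
+// Example (illustrative sketch; the user ID and tag name are hypothetical):
+//
+//	key, err := headscale.CreateAuthKeyWithTags(1, true, false, []string{"tag:server"})
+//	if err != nil {
+//		t.Fatalf("creating tagged auth key: %s", err)
+//	}
+//	// key.GetKey() can then be passed to a TailscaleClient's Login call
+//	// against headscale.GetEndpoint().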
+func (t *HeadscaleInContainer) CreateAuthKeyWithTags( + user uint64, + reusable bool, + ephemeral bool, + tags []string, +) (*v1.PreAuthKey, error) { + return t.CreateAuthKeyWithOptions(AuthKeyOptions{ + User: &user, + Reusable: reusable, + Ephemeral: ephemeral, + Tags: tags, + }) +} + +// DeleteAuthKey deletes an "authorisation key" by ID. +func (t *HeadscaleInContainer) DeleteAuthKey( + id uint64, +) error { + command := []string{ + "headscale", + "preauthkeys", + "delete", + "--id", + strconv.FormatUint(id, 10), + "--output", + "json", + } + + _, _, err := dockertestutil.ExecuteCommand( + t.container, + command, + []string{}, + ) + if err != nil { + return fmt.Errorf("failed to execute delete auth key command: %w", err) + } + + return nil +} + +// ListNodes lists the currently registered Nodes in headscale. +// Optionally a list of usernames can be passed to get users for +// specific users. +func (t *HeadscaleInContainer) ListNodes( + users ...string, ) ([]*v1.Node, error) { - command := []string{"headscale", "--user", user, "nodes", "list", "--output", "json"} + var ret []*v1.Node + execUnmarshal := func(command []string) error { + result, _, err := dockertestutil.ExecuteCommand( + t.container, + command, + []string{}, + ) + if err != nil { + return fmt.Errorf("failed to execute list node command: %w", err) + } + + var nodes []*v1.Node + err = json.Unmarshal([]byte(result), &nodes) + if err != nil { + return fmt.Errorf("failed to unmarshal nodes: %w", err) + } + + ret = append(ret, nodes...) + + return nil + } + + if len(users) == 0 { + err := execUnmarshal([]string{"headscale", "nodes", "list", "--output", "json"}) + if err != nil { + return nil, err + } + } else { + for _, user := range users { + command := []string{"headscale", "--user", user, "nodes", "list", "--output", "json"} + + err := execUnmarshal(command) + if err != nil { + return nil, err + } + } + } + + sort.Slice(ret, func(i, j int) bool { + return cmp.Compare(ret[i].GetId(), ret[j].GetId()) == -1 + }) + + return ret, nil +} + +func (t *HeadscaleInContainer) DeleteNode(nodeID uint64) error { + command := []string{ + "headscale", + "nodes", + "delete", + "--identifier", + fmt.Sprintf("%d", nodeID), + "--output", + "json", + "--force", + } + + _, _, err := dockertestutil.ExecuteCommand( + t.container, + command, + []string{}, + ) + if err != nil { + return fmt.Errorf("failed to execute delete node command: %w", err) + } + + return nil +} + +func (t *HeadscaleInContainer) NodesByUser() (map[string][]*v1.Node, error) { + nodes, err := t.ListNodes() + if err != nil { + return nil, err + } + + var userMap map[string][]*v1.Node + for _, node := range nodes { + if _, ok := userMap[node.GetUser().GetName()]; !ok { + mak.Set(&userMap, node.GetUser().GetName(), []*v1.Node{node}) + } else { + userMap[node.GetUser().GetName()] = append(userMap[node.GetUser().GetName()], node) + } + } + + return userMap, nil +} + +func (t *HeadscaleInContainer) NodesByName() (map[string]*v1.Node, error) { + nodes, err := t.ListNodes() + if err != nil { + return nil, err + } + + var nameMap map[string]*v1.Node + for _, node := range nodes { + mak.Set(&nameMap, node.GetName(), node) + } + + return nameMap, nil +} + +// ListUsers returns a list of users from Headscale. 
+func (t *HeadscaleInContainer) ListUsers() ([]*v1.User, error) { + command := []string{"headscale", "users", "list", "--output", "json"} result, _, err := dockertestutil.ExecuteCommand( t.container, @@ -589,13 +1308,228 @@ func (t *HeadscaleInContainer) ListNodesInUser( return nil, fmt.Errorf("failed to execute list node command: %w", err) } - var nodes []*v1.Node - err = json.Unmarshal([]byte(result), &nodes) + var users []*v1.User + err = json.Unmarshal([]byte(result), &users) if err != nil { return nil, fmt.Errorf("failed to unmarshal nodes: %w", err) } - return nodes, nil + return users, nil +} + +// MapUsers returns a map of users from Headscale. It is keyed by the +// user name. +func (t *HeadscaleInContainer) MapUsers() (map[string]*v1.User, error) { + users, err := t.ListUsers() + if err != nil { + return nil, err + } + + var userMap map[string]*v1.User + for _, user := range users { + mak.Set(&userMap, user.GetName(), user) + } + + return userMap, nil +} + +// DeleteUser deletes a user from the Headscale instance. +func (t *HeadscaleInContainer) DeleteUser(userID uint64) error { + command := []string{ + "headscale", + "users", + "delete", + "--identifier", + strconv.FormatUint(userID, 10), + "--force", + "--output", + "json", + } + + _, _, err := dockertestutil.ExecuteCommand( + t.container, + command, + []string{}, + ) + if err != nil { + return fmt.Errorf("failed to execute delete user command: %w", err) + } + + return nil +} + +func (h *HeadscaleInContainer) SetPolicy(pol *policyv2.Policy) error { + err := h.writePolicy(pol) + if err != nil { + return fmt.Errorf("writing policy file: %w", err) + } + + switch h.policyMode { + case types.PolicyModeDB: + err := h.reloadDatabasePolicy() + if err != nil { + return fmt.Errorf("reloading database policy: %w", err) + } + case types.PolicyModeFile: + err := h.Reload() + if err != nil { + return fmt.Errorf("reloading policy file: %w", err) + } + default: + panic("policy mode is not valid: " + h.policyMode) + } + + return nil +} + +func (h *HeadscaleInContainer) reloadDatabasePolicy() error { + _, err := h.Execute( + []string{ + "headscale", + "policy", + "set", + "-f", + aclPolicyPath, + }, + ) + if err != nil { + return fmt.Errorf("setting policy with db command: %w", err) + } + + return nil +} + +func (h *HeadscaleInContainer) writePolicy(pol *policyv2.Policy) error { + pBytes, err := json.Marshal(pol) + if err != nil { + return fmt.Errorf("marshalling pol: %w", err) + } + + err = h.WriteFile(aclPolicyPath, pBytes) + if err != nil { + return fmt.Errorf("writing policy to headscale container: %w", err) + } + + return nil +} + +func (h *HeadscaleInContainer) PID() (int, error) { + // Use pidof to find the headscale process, which is more reliable than grep + // as it only looks for the actual binary name, not processes that contain + // "headscale" in their command line (like the dlv debugger). 
+ output, err := h.Execute([]string{"pidof", "headscale"}) + if err != nil { + // pidof returns exit code 1 when no process is found + return 0, os.ErrNotExist + } + + // pidof returns space-separated PIDs on a single line + pidStrs := strings.Fields(strings.TrimSpace(output)) + if len(pidStrs) == 0 { + return 0, os.ErrNotExist + } + + pids := make([]int, 0, len(pidStrs)) + for _, pidStr := range pidStrs { + pidInt, err := strconv.Atoi(pidStr) + if err != nil { + return 0, fmt.Errorf("parsing PID %q: %w", pidStr, err) + } + // We dont care about the root pid for the container + if pidInt == 1 { + continue + } + pids = append(pids, pidInt) + } + + switch len(pids) { + case 0: + return 0, os.ErrNotExist + case 1: + return pids[0], nil + default: + // If we still have multiple PIDs, return the first one as a fallback + // This can happen in edge cases during startup/shutdown + return pids[0], nil + } +} + +// Reload sends a SIGHUP to the headscale process to reload internals, +// for example Policy from file. +func (h *HeadscaleInContainer) Reload() error { + pid, err := h.PID() + if err != nil { + return fmt.Errorf("getting headscale PID: %w", err) + } + + _, err = h.Execute([]string{"kill", "-HUP", strconv.Itoa(pid)}) + if err != nil { + return fmt.Errorf("reloading headscale with HUP: %w", err) + } + + return nil +} + +// ApproveRoutes approves routes for a node. +func (t *HeadscaleInContainer) ApproveRoutes(id uint64, routes []netip.Prefix) (*v1.Node, error) { + command := []string{ + "headscale", "nodes", "approve-routes", + "--output", "json", + "--identifier", strconv.FormatUint(id, 10), + "--routes=" + strings.Join(util.PrefixesToString(routes), ","), + } + + result, _, err := dockertestutil.ExecuteCommand( + t.container, + command, + []string{}, + ) + if err != nil { + return nil, fmt.Errorf( + "failed to execute approve routes command (node %d, routes %v): %w", + id, + routes, + err, + ) + } + + var node *v1.Node + err = json.Unmarshal([]byte(result), &node) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal node response: %q, error: %w", result, err) + } + + return node, nil +} + +// SetNodeTags sets tags on a node via the headscale CLI. +// This simulates what the Tailscale admin console UI does - it calls the headscale +// SetTags API which is exposed via the CLI command: headscale nodes tag -i <id> -t <tags>. +func (t *HeadscaleInContainer) SetNodeTags(nodeID uint64, tags []string) error { + command := []string{ + "headscale", "nodes", "tag", + "--identifier", strconv.FormatUint(nodeID, 10), + "--output", "json", + } + + // Add tags - the CLI expects -t flag for each tag or comma-separated + if len(tags) > 0 { + command = append(command, "--tags", strings.Join(tags, ",")) + } else { + // Empty tags to clear all tags + command = append(command, "--tags", "") + } + + _, _, err := dockertestutil.ExecuteCommand( + t.container, + command, + []string{}, + ) + if err != nil { + return fmt.Errorf("failed to execute set tags command (node %d, tags %v): %w", nodeID, tags, err) + } + + return nil } // WriteFile save file inside the Headscale container. 
@@ -623,85 +1557,116 @@ func (t *HeadscaleInContainer) SendInterrupt() error { return nil } -// nolint -func createCertificate(hostname string) ([]byte, []byte, error) { - // From: - // https://shaneutt.com/blog/golang-ca-and-signed-cert-go/ - - ca := &x509.Certificate{ - SerialNumber: big.NewInt(2019), - Subject: pkix.Name{ - Organization: []string{"Headscale testing INC"}, - Country: []string{"NL"}, - Locality: []string{"Leiden"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(60 * time.Minute), - IsCA: true, - ExtKeyUsage: []x509.ExtKeyUsage{ - x509.ExtKeyUsageClientAuth, - x509.ExtKeyUsageServerAuth, - }, - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, - BasicConstraintsValid: true, +func (t *HeadscaleInContainer) GetAllMapReponses() (map[types.NodeID][]tailcfg.MapResponse, error) { + // Execute curl inside the container to access the debug endpoint locally + command := []string{ + "curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/mapresponses", } - caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) + result, err := t.Execute(command) if err != nil { - return nil, nil, err + return nil, fmt.Errorf("fetching mapresponses from debug endpoint: %w", err) } - cert := &x509.Certificate{ - SerialNumber: big.NewInt(1658), - Subject: pkix.Name{ - CommonName: hostname, - Organization: []string{"Headscale testing INC"}, - Country: []string{"NL"}, - Locality: []string{"Leiden"}, - }, - NotBefore: time.Now(), - NotAfter: time.Now().Add(60 * time.Minute), - SubjectKeyId: []byte{1, 2, 3, 4, 6}, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, - KeyUsage: x509.KeyUsageDigitalSignature, - DNSNames: []string{hostname}, + var res map[types.NodeID][]tailcfg.MapResponse + if err := json.Unmarshal([]byte(result), &res); err != nil { + return nil, fmt.Errorf("decoding routes response: %w", err) } - certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) - if err != nil { - return nil, nil, err - } - - certBytes, err := x509.CreateCertificate( - rand.Reader, - cert, - ca, - &certPrivKey.PublicKey, - caPrivKey, - ) - if err != nil { - return nil, nil, err - } - - certPEM := new(bytes.Buffer) - - err = pem.Encode(certPEM, &pem.Block{ - Type: "CERTIFICATE", - Bytes: certBytes, - }) - if err != nil { - return nil, nil, err - } - - certPrivKeyPEM := new(bytes.Buffer) - - err = pem.Encode(certPrivKeyPEM, &pem.Block{ - Type: "RSA PRIVATE KEY", - Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), - }) - if err != nil { - return nil, nil, err - } - - return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil + return res, nil +} + +// PrimaryRoutes fetches the primary routes from the debug endpoint. +func (t *HeadscaleInContainer) PrimaryRoutes() (*routes.DebugRoutes, error) { + // Execute curl inside the container to access the debug endpoint locally + command := []string{ + "curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/routes", + } + + result, err := t.Execute(command) + if err != nil { + return nil, fmt.Errorf("fetching routes from debug endpoint: %w", err) + } + + var debugRoutes routes.DebugRoutes + if err := json.Unmarshal([]byte(result), &debugRoutes); err != nil { + return nil, fmt.Errorf("decoding routes response: %w", err) + } + + return &debugRoutes, nil +} + +// DebugBatcher fetches the batcher debug information from the debug endpoint. 
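+//
+// Like the other Debug* helpers, this shells out to curl against the embedded
+// debug server on localhost:9090 inside the container and decodes the JSON
+// reply. While a test is running, the same endpoints should also be reachable
+// from the host through the dynamically assigned port returned by
+// GetHostMetricsPort, roughly:
+//
+//	curl -s http://localhost:<host-metrics-port>/debug/batcher
+//
+// (the concrete port differs per run).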
+func (t *HeadscaleInContainer) DebugBatcher() (*hscontrol.DebugBatcherInfo, error) { + // Execute curl inside the container to access the debug endpoint locally + command := []string{ + "curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/batcher", + } + + result, err := t.Execute(command) + if err != nil { + return nil, fmt.Errorf("fetching batcher debug info: %w", err) + } + + var debugInfo hscontrol.DebugBatcherInfo + if err := json.Unmarshal([]byte(result), &debugInfo); err != nil { + return nil, fmt.Errorf("decoding batcher debug response: %w", err) + } + + return &debugInfo, nil +} + +// DebugNodeStore fetches the NodeStore data from the debug endpoint. +func (t *HeadscaleInContainer) DebugNodeStore() (map[types.NodeID]types.Node, error) { + // Execute curl inside the container to access the debug endpoint locally + command := []string{ + "curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/nodestore", + } + + result, err := t.Execute(command) + if err != nil { + return nil, fmt.Errorf("fetching nodestore debug info: %w", err) + } + + var nodeStore map[types.NodeID]types.Node + if err := json.Unmarshal([]byte(result), &nodeStore); err != nil { + return nil, fmt.Errorf("decoding nodestore debug response: %w", err) + } + + return nodeStore, nil +} + +// DebugFilter fetches the current filter rules from the debug endpoint. +func (t *HeadscaleInContainer) DebugFilter() ([]tailcfg.FilterRule, error) { + // Execute curl inside the container to access the debug endpoint locally + command := []string{ + "curl", "-s", "-H", "Accept: application/json", "http://localhost:9090/debug/filter", + } + + result, err := t.Execute(command) + if err != nil { + return nil, fmt.Errorf("fetching filter from debug endpoint: %w", err) + } + + var filterRules []tailcfg.FilterRule + if err := json.Unmarshal([]byte(result), &filterRules); err != nil { + return nil, fmt.Errorf("decoding filter response: %w", err) + } + + return filterRules, nil +} + +// DebugPolicy fetches the current policy from the debug endpoint. +func (t *HeadscaleInContainer) DebugPolicy() (string, error) { + // Execute curl inside the container to access the debug endpoint locally + command := []string{ + "curl", "-s", "http://localhost:9090/debug/policy", + } + + result, err := t.Execute(command) + if err != nil { + return "", fmt.Errorf("fetching policy from debug endpoint: %w", err) + } + + return result, nil } diff --git a/integration/integrationutil/util.go b/integration/integrationutil/util.go index 59eeeb17..4ddc7ae9 100644 --- a/integration/integrationutil/util.go +++ b/integration/integrationutil/util.go @@ -3,15 +3,39 @@ package integrationutil import ( "archive/tar" "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" "fmt" "io" + "math/big" "path/filepath" + "time" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" + "tailscale.com/tailcfg" ) +// PeerSyncTimeout returns the timeout for peer synchronization based on environment: +// 60s for dev, 120s for CI. +func PeerSyncTimeout() time.Duration { + if util.IsCI() { + return 120 * time.Second + } + return 60 * time.Second +} + +// PeerSyncRetryInterval returns the retry interval for peer synchronization checks. 
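+//
+// PeerSyncTimeout and PeerSyncRetryInterval are intended to be used together as
+// the wait/tick pair of a polling assertion, mirroring the assert.EventuallyWithT
+// pattern used throughout the integration tests; a sketch from a test package
+// (the check body is hypothetical):
+//
+//	assert.EventuallyWithT(t, func(c *assert.CollectT) {
+//		// assert that all expected peers are visible / online here
+//	}, integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval())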
+func PeerSyncRetryInterval() time.Duration { + return 100 * time.Millisecond +} + func WriteFileToContainer( pool *dockertest.Pool, container *dockertest.Resource, @@ -93,3 +117,113 @@ func FetchPathFromContainer( return buf.Bytes(), nil } + +// nolint +func CreateCertificate(hostname string) ([]byte, []byte, error) { + // From: + // https://shaneutt.com/blog/golang-ca-and-signed-cert-go/ + + ca := &x509.Certificate{ + SerialNumber: big.NewInt(2019), + Subject: pkix.Name{ + Organization: []string{"Headscale testing INC"}, + Country: []string{"NL"}, + Locality: []string{"Leiden"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(60 * time.Hour), + IsCA: true, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageClientAuth, + x509.ExtKeyUsageServerAuth, + }, + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign, + BasicConstraintsValid: true, + } + + caPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, err + } + + cert := &x509.Certificate{ + SerialNumber: big.NewInt(1658), + Subject: pkix.Name{ + CommonName: hostname, + Organization: []string{"Headscale testing INC"}, + Country: []string{"NL"}, + Locality: []string{"Leiden"}, + }, + NotBefore: time.Now(), + NotAfter: time.Now().Add(60 * time.Minute), + SubjectKeyId: []byte{1, 2, 3, 4, 6}, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth, x509.ExtKeyUsageServerAuth}, + KeyUsage: x509.KeyUsageDigitalSignature, + DNSNames: []string{hostname}, + } + + certPrivKey, err := rsa.GenerateKey(rand.Reader, 4096) + if err != nil { + return nil, nil, err + } + + certBytes, err := x509.CreateCertificate( + rand.Reader, + cert, + ca, + &certPrivKey.PublicKey, + caPrivKey, + ) + if err != nil { + return nil, nil, err + } + + certPEM := new(bytes.Buffer) + + err = pem.Encode(certPEM, &pem.Block{ + Type: "CERTIFICATE", + Bytes: certBytes, + }) + if err != nil { + return nil, nil, err + } + + certPrivKeyPEM := new(bytes.Buffer) + + err = pem.Encode(certPrivKeyPEM, &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(certPrivKey), + }) + if err != nil { + return nil, nil, err + } + + return certPEM.Bytes(), certPrivKeyPEM.Bytes(), nil +} + +func BuildExpectedOnlineMap(all map[types.NodeID][]tailcfg.MapResponse) map[types.NodeID]map[types.NodeID]bool { + res := make(map[types.NodeID]map[types.NodeID]bool) + for nid, mrs := range all { + res[nid] = make(map[types.NodeID]bool) + for _, mr := range mrs { + for _, peer := range mr.Peers { + if peer.Online != nil { + res[nid][types.NodeID(peer.ID)] = *peer.Online + } + } + + for _, peer := range mr.PeersChanged { + if peer.Online != nil { + res[nid][types.NodeID(peer.ID)] = *peer.Online + } + } + + for _, peer := range mr.PeersChangedPatch { + if peer.Online != nil { + res[nid][types.NodeID(peer.NodeID)] = *peer.Online + } + } + } + } + return res +} diff --git a/integration/route_test.go b/integration/route_test.go index 489165a8..0460b5ef 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -1,47 +1,70 @@ package integration import ( + "cmp" + "encoding/json" "fmt" - "log" + "maps" "net/netip" + "slices" "sort" "strconv" + "strings" "testing" "time" + cmpdiff "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/routes" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" 
"github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + xmaps "golang.org/x/exp/maps" + "tailscale.com/ipn/ipnstate" + "tailscale.com/net/tsaddr" + "tailscale.com/tailcfg" + "tailscale.com/types/ipproto" + "tailscale.com/types/views" + "tailscale.com/util/must" + "tailscale.com/util/slicesx" + "tailscale.com/wgengine/filter" ) +var allPorts = filter.PortRange{First: 0, Last: 0xffff} + // This test is both testing the routes command and the propagation of // routes. func TestEnablingRoutes(t *testing.T) { IntegrationSkip(t) - t.Parallel() - user := "enable-routing" - - scenario, err := NewScenario() - assertNoErrf(t, "failed to create scenario: %s", err) - defer scenario.Shutdown() - - spec := map[string]int{ - user: 3, + spec := ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1"}, } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute")) - assertNoErrHeadscaleEnv(t, err) + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{tsic.WithAcceptRoutes()}, + hsic.WithTestName("clienableroute")) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) expectedRoutes := map[string]string{ "1": "10.0.0.0/24", @@ -51,730 +74,2997 @@ func TestEnablingRoutes(t *testing.T) { // advertise routes using the up command for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) - + status := client.MustStatus() command := []string{ "tailscale", "set", "--advertise-routes=" + expectedRoutes[string(status.Self.ID)], } _, _, err = client.Execute(command) - assertNoErrf(t, "failed to advertise route: %s", err) + require.NoErrorf(t, err, "failed to advertise route: %s", err) } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) - var routes []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routes, - ) + var nodes []*v1.Node + // Wait for route advertisements to propagate to NodeStore + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + nodes, err = headscale.ListNodes() + assert.NoError(ct, err) - assertNoErr(t, err) - assert.Len(t, routes, 3) - - for _, route := range routes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) - } + for _, node := range nodes { + assert.Len(ct, node.GetAvailableRoutes(), 1) + assert.Empty(ct, node.GetApprovedRoutes()) + assert.Empty(ct, node.GetSubnetRoutes()) + } + }, 10*time.Second, 100*time.Millisecond, "route advertisements should propagate to all nodes") // Verify that no routes has been sent to the client, // they are not yet enabled. 
for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - assert.Nil(t, peerStatus.PrimaryRoutes) + assert.Nil(c, peerStatus.PrimaryRoutes) + } + }, 5*time.Second, 200*time.Millisecond, "Verifying no routes are active before approval") + } + + for _, node := range nodes { + _, err := headscale.ApproveRoutes( + node.GetId(), + util.MustStringsToPrefixes(node.GetAvailableRoutes()), + ) + require.NoError(t, err) + } + + // Wait for route approvals to propagate to NodeStore + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + nodes, err = headscale.ListNodes() + assert.NoError(ct, err) + + for _, node := range nodes { + assert.Len(ct, node.GetAvailableRoutes(), 1) + assert.Len(ct, node.GetApprovedRoutes(), 1) + assert.Len(ct, node.GetSubnetRoutes(), 1) } - } + }, 10*time.Second, 100*time.Millisecond, "route approvals should propagate to all nodes") - // Enable all routes - for _, route := range routes { - _, err = headscale.Execute( - []string{ - "headscale", - "routes", - "enable", - "--route", - strconv.Itoa(int(route.GetId())), - }) - assertNoErr(t, err) - } + // Wait for route state changes to propagate to clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Verify that the clients can see the new routes + for _, client := range allClients { + status, err := client.Status() + assert.NoError(c, err) - var enablingRoutes []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &enablingRoutes, + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.NotNil(c, peerStatus.PrimaryRoutes) + assert.NotNil(c, peerStatus.AllowedIPs) + if peerStatus.AllowedIPs != nil { + assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 3) + } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes[string(peerStatus.ID)])}) + } + } + }, 10*time.Second, 500*time.Millisecond, "clients should see new routes") + + _, err = headscale.ApproveRoutes( + 1, + []netip.Prefix{netip.MustParsePrefix("10.0.1.0/24")}, ) - assertNoErr(t, err) - assert.Len(t, enablingRoutes, 3) + require.NoError(t, err) - for _, route := range enablingRoutes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) - } + _, err = headscale.ApproveRoutes( + 2, + []netip.Prefix{}, + ) + require.NoError(t, err) - time.Sleep(5 * time.Second) + // Wait for route state changes to propagate to nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + + for _, node := range nodes { + if node.GetId() == 1 { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.0.0/24 + assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.1.0/24 + assert.Empty(c, node.GetSubnetRoutes()) + } else if node.GetId() == 2 { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.1.0/24 + assert.Empty(c, node.GetApprovedRoutes()) + assert.Empty(c, node.GetSubnetRoutes()) + } else { + assert.Len(c, node.GetAvailableRoutes(), 1) // 10.0.2.0/24 + assert.Len(c, node.GetApprovedRoutes(), 1) // 10.0.2.0/24 + assert.Len(c, 
node.GetSubnetRoutes(), 1) // 10.0.2.0/24 + } + } + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes") // Verify that the clients can see the new routes for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] - assert.NotNil(t, peerStatus.PrimaryRoutes) - if peerStatus.PrimaryRoutes == nil { - continue + switch peerStatus.ID { + case "1": + requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) + case "2": + requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) + default: + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{netip.MustParsePrefix("10.0.2.0/24")}) + } } - - pRoutes := peerStatus.PrimaryRoutes.AsSlice() - - assert.Len(t, pRoutes, 1) - - if len(pRoutes) > 0 { - peerRoute := peerStatus.PrimaryRoutes.AsSlice()[0] - - // id starts at 1, we created routes with 0 index - assert.Equalf( - t, - expectedRoutes[string(peerStatus.ID)], - peerRoute.String(), - "expected route %s to be present on peer %s (%s) in %s (%s) status", - expectedRoutes[string(peerStatus.ID)], - peerStatus.HostName, - peerStatus.ID, - client.Hostname(), - client.ID(), - ) - } - } - } - - routeToBeDisabled := enablingRoutes[0] - log.Printf("preparing to disable %v", routeToBeDisabled) - - _, err = headscale.Execute( - []string{ - "headscale", - "routes", - "disable", - "--route", - strconv.Itoa(int(routeToBeDisabled.GetId())), - }) - assertNoErr(t, err) - - var disablingRoutes []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &disablingRoutes, - ) - assertNoErr(t, err) - - for _, route := range disablingRoutes { - assert.Equal(t, true, route.GetAdvertised()) - - if route.GetId() == routeToBeDisabled.GetId() { - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) - } else { - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) - } - } - - time.Sleep(5 * time.Second) - - // Verify that the clients can see the new routes - for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) - - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - if string(peerStatus.ID) == fmt.Sprintf("%d", routeToBeDisabled.GetNode().GetId()) { - assert.Nilf( - t, - peerStatus.PrimaryRoutes, - "expected node %s to have no routes, got primary route (%v)", - peerStatus.HostName, - peerStatus.PrimaryRoutes, - ) - } - } + }, 5*time.Second, 200*time.Millisecond, "Verifying final route state visible to clients") } } func TestHASubnetRouterFailover(t *testing.T) { IntegrationSkip(t) - t.Parallel() - user := "enable-routing" + propagationTime := 60 * time.Second - scenario, err := NewScenario() - assertNoErrf(t, "failed to create scenario: %s", err) - defer scenario.Shutdown() + // Helper function to validate primary routes table state + validatePrimaryRoutes := func(t *testing.T, headscale ControlServer, expectedRoutes *routes.DebugRoutes, message string) { + t.Helper() + assert.EventuallyWithT(t, func(c *assert.CollectT) { + primaryRoutesState, err := headscale.PrimaryRoutes() + assert.NoError(c, err) - spec := map[string]int{ - user: 3, + if diff := 
cmpdiff.Diff(expectedRoutes, primaryRoutesState, util.PrefixComparer); diff != "" { + t.Log(message) + t.Errorf("validatePrimaryRoutes mismatch (-want +got):\n%s", diff) + } + }, propagationTime, 200*time.Millisecond, "Validating primary routes table") } - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute")) - assertNoErrHeadscaleEnv(t, err) + spec := ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. + Versions: []string{"head"}, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + // defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{tsic.WithAcceptRoutes()}, + hsic.WithTestName("clienableroute"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) + requireNoErrGetHeadscale(t, err) - expectedRoutes := map[string]string{ - "1": "10.0.0.0/24", - "2": "10.0.0.0/24", - } + prefp, err := scenario.SubnetOfNetwork("usernet1") + require.NoError(t, err) + pref := *prefp + t.Logf("usernet1 prefix: %s", pref.String()) + + usernet1, err := scenario.Network("usernet1") + require.NoError(t, err) + + services, err := scenario.Services("usernet1") + require.NoError(t, err) + require.Len(t, services, 1) + + web := services[0] + webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) + weburl := fmt.Sprintf("http://%s/etc/hostname", webip) + t.Logf("webservice: %s, %s", webip.String(), weburl) // Sort nodes by ID sort.SliceStable(allClients, func(i, j int) bool { - statusI, err := allClients[i].Status() - if err != nil { - return false - } - - statusJ, err := allClients[j].Status() - if err != nil { - return false - } + statusI := allClients[i].MustStatus() + statusJ := allClients[j].MustStatus() return statusI.Self.ID < statusJ.Self.ID }) + // This is ok because the scenario makes users in order, so the three first + // nodes, which are subnet routes, will be created first, and the last user + // will be created with the second. 
subRouter1 := allClients[0] subRouter2 := allClients[1] + subRouter3 := allClients[2] - client := allClients[2] + client := allClients[3] - // advertise HA route on node 1 and 2 - // ID 1 will be primary - // ID 2 will be secondary - for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) - - if route, ok := expectedRoutes[string(status.Self.ID)]; ok { - command := []string{ - "tailscale", - "set", - "--advertise-routes=" + route, - } - _, _, err = client.Execute(command) - assertNoErrf(t, "failed to advertise route: %s", err) + t.Logf("%s (%s) picked as client", client.Hostname(), client.MustID()) + t.Logf("=== Initial Route Advertisement - Setting up HA configuration with 3 routers ===") + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" - Router 1 (%s): Advertising route %s - will become PRIMARY when approved", subRouter1.Hostname(), pref.String()) + t.Logf(" - Router 2 (%s): Advertising route %s - will be STANDBY when approved", subRouter2.Hostname(), pref.String()) + t.Logf(" - Router 3 (%s): Advertising route %s - will be STANDBY when approved", subRouter3.Hostname(), pref.String()) + t.Logf(" Expected: All 3 routers advertise the same route for redundancy, but only one will be primary at a time") + for _, client := range allClients[:3] { + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + pref.String(), } + _, _, err = client.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) } err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) - var routes []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routes, - ) - - assertNoErr(t, err) - assert.Len(t, routes, 2) - - for _, route := range routes { - assert.Equal(t, true, route.GetAdvertised()) - assert.Equal(t, false, route.GetEnabled()) - assert.Equal(t, false, route.GetIsPrimary()) - } + // Wait for route configuration changes after advertising routes + var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) + require.GreaterOrEqual(t, len(nodes), 3, "need at least 3 nodes to avoid panic") + requireNodeRouteCountWithCollect(c, nodes[0], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, propagationTime, 200*time.Millisecond, "Waiting for route advertisements: All 3 routers should have advertised routes (available=1) but none approved yet (approved=0, subnet=0)") // Verify that no routes has been sent to the client, // they are not yet enabled. 
for _, client := range allClients { - status, err := client.Status() - assertNoErr(t, err) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(c, peerStatus.PrimaryRoutes) + requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) + } + }, propagationTime, 200*time.Millisecond, "Verifying no routes are active before approval") + } + + // Declare variables that will be used across multiple EventuallyWithT blocks + var ( + srs1, srs2, srs3 *ipnstate.Status + clientStatus *ipnstate.Status + srs1PeerStatus *ipnstate.PeerStatus + srs2PeerStatus *ipnstate.PeerStatus + srs3PeerStatus *ipnstate.PeerStatus + ) + + // Helper function to check test failure and print route map if needed + checkFailureAndPrintRoutes := func(t *testing.T, client TailscaleClient) { + if t.Failed() { + t.Logf("[%s] Test failed at this checkpoint", time.Now().Format(TimestampFormat)) + status, err := client.Status() + if err == nil { + printCurrentRouteMap(t, xmaps.Values(status.Peer)...) + } + t.FailNow() + } + } + + // Validate primary routes table state - no routes approved yet + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{}, + PrimaryRoutes: map[string]types.NodeID{}, // No primary routes yet + }, "Primary routes table should be empty (no approved routes yet)") + + checkFailureAndPrintRoutes(t, client) + + // Enable route on node 1 + t.Logf("=== Approving route on router 1 (%s) - Single router mode (no HA yet) ===", subRouter1.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Expected: Router 1 becomes PRIMARY with route %s active", pref.String()) + t.Logf(" Expected: Routers 2 & 3 remain with advertised but unapproved routes") + t.Logf(" Expected: Client can access webservice through router 1 only") + _, err = headscale.ApproveRoutes( + MustFindNode(subRouter1.Hostname(), nodes).GetId(), + []netip.Prefix{pref}, + ) + require.NoError(t, err) + + // Wait for route approval on first subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) + require.GreaterOrEqual(t, len(nodes), 3, "need at least 3 nodes to avoid panic") + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + }, propagationTime, 200*time.Millisecond, "Router 1 approval verification: Should be PRIMARY (available=1, approved=1, subnet=1), others still unapproved (available=1, approved=0, subnet=0)") + + // Verify that the client has routes from the primary machine and can access + // the webservice. 
+ assert.EventuallyWithT(t, func(c *assert.CollectT) { + srs1 = subRouter1.MustStatus() + srs2 = subRouter2.MustStatus() + srs3 = subRouter3.MustStatus() + clientStatus = client.MustStatus() + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.True(c, srs1PeerStatus.Online, "Router 1 should be online and serving as PRIMARY") + assert.True(c, srs2PeerStatus.Online, "Router 2 should be online but NOT serving routes (unapproved)") + assert.True(c, srs3PeerStatus.Online, "Router 3 should be online but NOT serving routes (unapproved)") + + assert.Nil(c, srs2PeerStatus.PrimaryRoutes) + assert.Nil(c, srs3PeerStatus.PrimaryRoutes) + assert.NotNil(c, srs1PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutesWithCollect(c, srs2PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs3PeerStatus, nil) + + if srs1PeerStatus.PrimaryRoutes != nil { + t.Logf("got list: %v, want in: %v", srs1PeerStatus.PrimaryRoutes.AsSlice(), pref) + assert.Contains(c, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + } + }, propagationTime, 200*time.Millisecond, "Verifying Router 1 is PRIMARY with routes after approval") + + t.Logf("=== Validating connectivity through PRIMARY router 1 (%s) to webservice at %s ===", must.Get(subRouter1.IPv4()).String(), webip.String()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Expected: Traffic flows through router 1 as it's the only approved route") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can reach webservice through router 1") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := subRouter1.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1") + + // Validate primary routes table state - router 1 is primary + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()): {pref}, + // Note: Router 2 and 3 are available but not approved + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()), + }, + }, "Router 1 should be primary for route "+pref.String()) + + checkFailureAndPrintRoutes(t, client) + + // Enable route on node 2, now we will have a HA subnet router + t.Logf("=== Enabling High Availability by approving route on router 2 (%s) ===", subRouter2.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 is PRIMARY and actively serving traffic") + t.Logf(" Expected: Router 2 becomes STANDBY (approved but not primary)") + t.Logf(" Expected: Router 1 remains PRIMARY (no flapping - 
stability preferred)") + t.Logf(" Expected: HA is now active - if router 1 fails, router 2 can take over") + _, err = headscale.ApproveRoutes( + MustFindNode(subRouter2.Hostname(), nodes).GetId(), + []netip.Prefix{pref}, + ) + require.NoError(t, err) + + // Wait for route approval on second subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) + if len(nodes) >= 3 { + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 0, 0) + } + }, 3*time.Second, 200*time.Millisecond, "HA setup verification: Router 2 approved as STANDBY (available=1, approved=1, subnet=0), Router 1 stays PRIMARY (subnet=1)") + + // Verify that the client has routes from the primary machine + assert.EventuallyWithT(t, func(c *assert.CollectT) { + srs1 = subRouter1.MustStatus() + srs2 = subRouter2.MustStatus() + srs3 = subRouter3.MustStatus() + clientStatus = client.MustStatus() + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.True(c, srs1PeerStatus.Online, "Router 1 should be online and remain PRIMARY") + assert.True(c, srs2PeerStatus.Online, "Router 2 should be online and now approved as STANDBY") + assert.True(c, srs3PeerStatus.Online, "Router 3 should be online but still unapproved") + + assert.Nil(c, srs2PeerStatus.PrimaryRoutes) + assert.Nil(c, srs3PeerStatus.PrimaryRoutes) + assert.NotNil(c, srs1PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutesWithCollect(c, srs2PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs3PeerStatus, nil) + + if srs1PeerStatus.PrimaryRoutes != nil { + t.Logf("got list: %v, want in: %v", srs1PeerStatus.PrimaryRoutes.AsSlice(), pref) + assert.Contains(c, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + } + }, propagationTime, 200*time.Millisecond, "Verifying Router 1 remains PRIMARY after Router 2 approval") + + // Validate primary routes table state - router 1 still primary, router 2 approved but standby + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()): {pref}, + // Note: Router 3 is available but not approved + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()), + }, + }, "Router 1 should remain primary after router 2 approval") + + checkFailureAndPrintRoutes(t, client) + + t.Logf("=== Validating HA configuration - Router 1 PRIMARY, Router 2 STANDBY ===") + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current routing: Traffic through router 1 (%s) to %s", must.Get(subRouter1.IPv4()), webip.String()) + t.Logf(" Expected: Router 1 continues to handle all traffic (no change from before)") + t.Logf(" Expected: Router 2 is ready to take over if router 1 
fails") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can reach webservice through router 1 in HA mode") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := subRouter1.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 1 in HA mode") + + // Validate primary routes table state - router 1 primary, router 2 approved (standby) + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()): {pref}, + // Note: Router 3 is available but not approved + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()), + }, + }, "Router 1 primary with router 2 as standby") + + checkFailureAndPrintRoutes(t, client) + + // Enable route on node 3, now we will have a second standby and all will + // be enabled. + t.Logf("=== Adding second STANDBY router by approving route on router 3 (%s) ===", subRouter3.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 PRIMARY, Router 2 STANDBY") + t.Logf(" Expected: Router 3 becomes second STANDBY (approved but not primary)") + t.Logf(" Expected: Router 1 remains PRIMARY, Router 2 remains first STANDBY") + t.Logf(" Expected: Full HA configuration with 1 PRIMARY + 2 STANDBY routers") + _, err = headscale.ApproveRoutes( + MustFindNode(subRouter3.Hostname(), nodes).GetId(), + []netip.Prefix{pref}, + ) + require.NoError(t, err) + + // Wait for route approval on third subnet router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) + require.GreaterOrEqual(t, len(nodes), 3, "need at least 3 nodes to avoid panic") + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0) + }, 3*time.Second, 200*time.Millisecond, "Full HA verification: Router 3 approved as second STANDBY (available=1, approved=1, subnet=0), Router 1 PRIMARY, Router 2 first STANDBY") + + // Verify that the client has routes from the primary machine + assert.EventuallyWithT(t, func(c *assert.CollectT) { + srs1 = subRouter1.MustStatus() + srs2 = subRouter2.MustStatus() + srs3 = subRouter3.MustStatus() + clientStatus = client.MustStatus() + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.True(c, srs1PeerStatus.Online, "Router 1 should be online and remain PRIMARY") + assert.True(c, srs2PeerStatus.Online, "Router 2 should be online as first STANDBY") + 
assert.True(c, srs3PeerStatus.Online, "Router 3 should be online as second STANDBY") + + assert.Nil(c, srs2PeerStatus.PrimaryRoutes) + assert.Nil(c, srs3PeerStatus.PrimaryRoutes) + assert.NotNil(c, srs1PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutesWithCollect(c, srs2PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs3PeerStatus, nil) + + if srs1PeerStatus.PrimaryRoutes != nil { + t.Logf("got list: %v, want in: %v", srs1PeerStatus.PrimaryRoutes.AsSlice(), pref) + assert.Contains(c, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + } + }, propagationTime, 200*time.Millisecond, "Verifying full HA with 3 routers: Router 1 PRIMARY, Routers 2 & 3 STANDBY") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can reach webservice through router 1 with full HA") + + // Wait for traceroute to work correctly through the expected router + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + + // Get the expected router IP - use a more robust approach to handle temporary disconnections + ips, err := subRouter1.IPs() + assert.NoError(c, err) + assert.NotEmpty(c, ips, "subRouter1 should have IP addresses") + + var expectedIP netip.Addr + for _, ip := range ips { + if ip.Is4() { + expectedIP = ip + break + } + } + assert.True(c, expectedIP.IsValid(), "subRouter1 should have a valid IPv4 address") + + assertTracerouteViaIPWithCollect(c, tr, expectedIP) + }, propagationTime, 200*time.Millisecond, "Verifying traffic still flows through PRIMARY router 1 with full HA setup active") + + // Validate primary routes table state - all 3 routers approved, router 1 still primary + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter3.Hostname(), nodes).GetId()): {pref}, + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()), + }, + }, "Router 1 primary with all 3 routers approved") + + checkFailureAndPrintRoutes(t, client) + + // Take down the current primary + t.Logf("=== FAILOVER TEST: Taking down PRIMARY router 1 (%s) ===", subRouter1.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 PRIMARY (serving traffic), Router 2 & 3 STANDBY") + t.Logf(" Action: Shutting down router 1 to simulate failure") + t.Logf(" Expected: Router 2 (%s) should automatically become new PRIMARY", subRouter2.Hostname()) + t.Logf(" Expected: Router 3 remains STANDBY") + t.Logf(" Expected: Traffic seamlessly fails over to router 2") + err = subRouter1.Down() + require.NoError(t, err) + + // Wait for router status changes after r1 goes down + assert.EventuallyWithT(t, func(c *assert.CollectT) { + srs2 = subRouter2.MustStatus() + clientStatus = client.MustStatus() + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 
2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.False(c, srs1PeerStatus.Online, "r1 should be offline") + assert.True(c, srs2PeerStatus.Online, "r2 should be online") + assert.True(c, srs3PeerStatus.Online, "r3 should be online") + + assert.Nil(c, srs1PeerStatus.PrimaryRoutes) + assert.NotNil(c, srs2PeerStatus.PrimaryRoutes) + assert.Nil(c, srs3PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs2PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutesWithCollect(c, srs3PeerStatus, nil) + + if srs2PeerStatus.PrimaryRoutes != nil { + assert.Contains(c, + srs2PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + } + }, propagationTime, 200*time.Millisecond, "Failover verification: Router 1 offline, Router 2 should be new PRIMARY with routes, Router 3 still STANDBY") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can reach webservice through router 2 after failover") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := subRouter2.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after failover") + + // Validate primary routes table state - router 2 is now primary after router 1 failure + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + // Router 1 is disconnected, so not in AvailableRoutes + types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter3.Hostname(), nodes).GetId()): {pref}, + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()), + }, + }, "Router 2 should be primary after router 1 failure") + + checkFailureAndPrintRoutes(t, client) + + // Take down subnet router 2, leaving none available + t.Logf("=== FAILOVER TEST: Taking down NEW PRIMARY router 2 (%s) ===", subRouter2.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 OFFLINE, Router 2 PRIMARY (serving traffic), Router 3 STANDBY") + t.Logf(" Action: Shutting down router 2 to simulate cascading failure") + t.Logf(" Expected: Router 3 (%s) should become new PRIMARY (last remaining router)", subRouter3.Hostname()) + t.Logf(" Expected: With only 1 router left, HA is effectively disabled") + t.Logf(" Expected: Traffic continues through router 3") + err = subRouter2.Down() + require.NoError(t, err) + + // Wait for router status changes after r2 goes down + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if 
srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.False(c, srs1PeerStatus.Online, "Router 1 should still be offline") + assert.False(c, srs2PeerStatus.Online, "Router 2 should now be offline after failure") + assert.True(c, srs3PeerStatus.Online, "Router 3 should be online and taking over as PRIMARY") + + assert.Nil(c, srs1PeerStatus.PrimaryRoutes) + assert.Nil(c, srs2PeerStatus.PrimaryRoutes) + assert.NotNil(c, srs3PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs2PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs3PeerStatus, []netip.Prefix{pref}) + }, propagationTime, 200*time.Millisecond, "Second failover verification: Router 1 & 2 offline, Router 3 should be new PRIMARY (last router standing) with routes") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can reach webservice through router 3 after second failover") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := subRouter3.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after second failover") + + // Validate primary routes table state - router 3 is now primary after router 2 failure + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + // Routers 1 and 2 are disconnected, so not in AvailableRoutes + types.NodeID(MustFindNode(subRouter3.Hostname(), nodes).GetId()): {pref}, + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter3.Hostname(), nodes).GetId()), + }, + }, "Router 3 should be primary after router 2 failure") + + checkFailureAndPrintRoutes(t, client) + + // Bring up subnet router 1, making the route available from there. 
+ t.Logf("=== RECOVERY TEST: Bringing router 1 (%s) back online ===", subRouter1.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 OFFLINE, Router 2 OFFLINE, Router 3 PRIMARY (only router)") + t.Logf(" Action: Starting router 1 to restore HA capability") + t.Logf(" Expected: Router 3 remains PRIMARY (stability - no unnecessary failover)") + t.Logf(" Expected: Router 1 becomes STANDBY (ready for HA)") + t.Logf(" Expected: HA is restored with 2 routers available") + err = subRouter1.Up() + require.NoError(t, err) + + // Wait for router status changes after r1 comes back up + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.True(c, srs1PeerStatus.Online, "Router 1 should be back online as STANDBY") + assert.False(c, srs2PeerStatus.Online, "Router 2 should still be offline") + assert.True(c, srs3PeerStatus.Online, "Router 3 should remain online as PRIMARY") + + assert.Nil(c, srs1PeerStatus.PrimaryRoutes) + assert.Nil(c, srs2PeerStatus.PrimaryRoutes) + assert.NotNil(c, srs3PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs2PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs3PeerStatus, []netip.Prefix{pref}) + + if srs3PeerStatus.PrimaryRoutes != nil { + assert.Contains(c, + srs3PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + } + }, propagationTime, 200*time.Millisecond, "Recovery verification: Router 1 back online as STANDBY, Router 3 remains PRIMARY (no flapping) with routes") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can still reach webservice through router 3 after router 1 recovery") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := subRouter3.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 3 after router 1 recovery") + + // Validate primary routes table state - router 3 remains primary after router 1 comes back + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()): {pref}, + // Router 2 is still disconnected + types.NodeID(MustFindNode(subRouter3.Hostname(), nodes).GetId()): {pref}, + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter3.Hostname(), nodes).GetId()), + }, + }, "Router 3 should remain primary after router 1 recovery") + + checkFailureAndPrintRoutes(t, client) + + // Bring up subnet router 2, should result in no change. 
+ t.Logf("=== FULL RECOVERY TEST: Bringing router 2 (%s) back online ===", subRouter2.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 STANDBY, Router 2 OFFLINE, Router 3 PRIMARY") + t.Logf(" Action: Starting router 2 to restore full HA (3 routers)") + t.Logf(" Expected: Router 3 (%s) remains PRIMARY (stability - avoid unnecessary failovers)", subRouter3.Hostname()) + t.Logf(" Expected: Router 1 (%s) remains first STANDBY", subRouter1.Hostname()) + t.Logf(" Expected: Router 2 (%s) becomes second STANDBY", subRouter2.Hostname()) + t.Logf(" Expected: Full HA restored with all 3 routers online") + err = subRouter2.Up() + require.NoError(t, err) + + // Wait for nodestore batch processing to complete and online status to be updated + // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for all routers to be online + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.True(c, srs1PeerStatus.Online, "Router 1 should be online as STANDBY") + assert.True(c, srs2PeerStatus.Online, "Router 2 should be back online as STANDBY") + assert.True(c, srs3PeerStatus.Online, "Router 3 should remain online as PRIMARY") + + assert.Nil(c, srs1PeerStatus.PrimaryRoutes) + assert.Nil(c, srs2PeerStatus.PrimaryRoutes) + assert.NotNil(c, srs3PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs2PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs3PeerStatus, []netip.Prefix{pref}) + + if srs3PeerStatus.PrimaryRoutes != nil { + assert.Contains(c, + srs3PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + } + }, 10*time.Second, 500*time.Millisecond, "Full recovery verification: All 3 routers online, Router 3 remains PRIMARY (no flapping) with routes") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can reach webservice through router 3 after full recovery") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := subRouter3.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for subRouter3") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 3 after full recovery") + + // Validate primary routes table state - router 3 remains primary after all routers back online + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter3.Hostname(), nodes).GetId()): {pref}, + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): 
types.NodeID(MustFindNode(subRouter3.Hostname(), nodes).GetId()), + }, + }, "Router 3 should remain primary after full recovery") + + checkFailureAndPrintRoutes(t, client) + + t.Logf("=== ROUTE DISABLE TEST: Removing approved route from PRIMARY router 3 (%s) ===", subRouter3.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 STANDBY, Router 2 STANDBY, Router 3 PRIMARY") + t.Logf(" Action: Disabling route approval on router 3 (route still advertised but not approved)") + t.Logf(" Expected: Router 1 (%s) should become new PRIMARY (lowest ID with approved route)", subRouter1.Hostname()) + t.Logf(" Expected: Router 2 (%s) remains STANDBY", subRouter2.Hostname()) + t.Logf(" Expected: Router 3 (%s) goes to advertised-only state (no longer serving)", subRouter3.Hostname()) + _, err = headscale.ApproveRoutes(MustFindNode(subRouter3.Hostname(), nodes).GetId(), []netip.Prefix{}) + + // Wait for nodestore batch processing and route state changes to complete + // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) + + // After disabling route on r3, r1 should become primary with 1 subnet route + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + }, 10*time.Second, 500*time.Millisecond, "Route disable verification: Router 3 route disabled, Router 1 should be new PRIMARY, Router 2 STANDBY") + + // Verify that the route is announced from subnet router 1 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.NotNil(c, srs1PeerStatus.PrimaryRoutes) + assert.Nil(c, srs2PeerStatus.PrimaryRoutes) + assert.Nil(c, srs3PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutesWithCollect(c, srs2PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs3PeerStatus, nil) + + if srs1PeerStatus.PrimaryRoutes != nil { + assert.Contains(c, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + } + }, propagationTime, 200*time.Millisecond, "Verifying Router 1 becomes PRIMARY after Router 3 route disabled") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can reach webservice through router 1 after route disable") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := subRouter1.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for subRouter1") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, 
propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 1 after route disable") + + // Validate primary routes table state - router 1 is primary after router 3 route disabled + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()): {pref}, + // Router 3's route is no longer approved, so not in AvailableRoutes + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()), + }, + }, "Router 1 should be primary after router 3 route disabled") + + checkFailureAndPrintRoutes(t, client) + + // Disable the route of subnet router 1, making it failover to 2 + t.Logf("=== ROUTE DISABLE TEST: Removing approved route from NEW PRIMARY router 1 (%s) ===", subRouter1.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 PRIMARY, Router 2 STANDBY, Router 3 advertised-only") + t.Logf(" Action: Disabling route approval on router 1") + t.Logf(" Expected: Router 2 (%s) should become new PRIMARY (only remaining approved route)", subRouter2.Hostname()) + t.Logf(" Expected: Router 1 (%s) goes to advertised-only state", subRouter1.Hostname()) + t.Logf(" Expected: Router 3 (%s) remains advertised-only", subRouter3.Hostname()) + _, err = headscale.ApproveRoutes(MustFindNode(subRouter1.Hostname(), nodes).GetId(), []netip.Prefix{}) + + // Wait for nodestore batch processing and route state changes to complete + // NodeStore batching timeout is 500ms, so we wait up to 10 seconds for route failover + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) + + // After disabling route on r1, r2 should become primary with 1 subnet route + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 0, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + }, 10*time.Second, 500*time.Millisecond, "Second route disable verification: Router 1 route disabled, Router 2 should be new PRIMARY") + + // Verify that the route is announced from subnet router 1 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.Nil(c, srs1PeerStatus.PrimaryRoutes) + assert.NotNil(c, srs2PeerStatus.PrimaryRoutes) + assert.Nil(c, srs3PeerStatus.PrimaryRoutes) + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, nil) + requirePeerSubnetRoutesWithCollect(c, srs2PeerStatus, []netip.Prefix{pref}) + requirePeerSubnetRoutesWithCollect(c, srs3PeerStatus, nil) + + if srs2PeerStatus.PrimaryRoutes != nil { + assert.Contains(c, + srs2PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + } + }, propagationTime, 200*time.Millisecond, 
"Verifying Router 2 becomes PRIMARY after Router 1 route disabled") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can reach webservice through router 2 after second route disable") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := subRouter2.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, propagationTime, 200*time.Millisecond, "Verifying traceroute goes through router 2 after second route disable") + + // Validate primary routes table state - router 2 is primary after router 1 route disabled + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + // Router 1's route is no longer approved, so not in AvailableRoutes + types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()): {pref}, + // Router 3's route is still not approved + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()), + }, + }, "Router 2 should be primary after router 1 route disabled") + + checkFailureAndPrintRoutes(t, client) + + // enable the route of subnet router 1, no change expected + t.Logf("=== ROUTE RE-ENABLE TEST: Re-approving route on router 1 (%s) ===", subRouter1.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 advertised-only, Router 2 PRIMARY, Router 3 advertised-only") + t.Logf(" Action: Re-enabling route approval on router 1") + t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability - no unnecessary flapping)", subRouter2.Hostname()) + t.Logf(" Expected: Router 1 (%s) becomes STANDBY (approved but not primary)", subRouter1.Hostname()) + t.Logf(" Expected: HA fully restored with Router 2 PRIMARY and Router 1 STANDBY") + r1Node := MustFindNode(subRouter1.Hostname(), nodes) + _, err = headscale.ApproveRoutes( + r1Node.GetId(), + util.MustStringsToPrefixes(r1Node.GetAvailableRoutes()), + ) + + // Wait for route state changes after re-enabling r1 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) + + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter1.Hostname(), nodes), 1, 1, 0) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter2.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, MustFindNode(subRouter3.Hostname(), nodes), 1, 0, 0) + }, propagationTime, 200*time.Millisecond, "Re-enable verification: Router 1 approved as STANDBY, Router 2 remains PRIMARY (no flapping), full HA restored") + + // Verify that the route is announced from subnet router 1 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientStatus, err = client.Status() + assert.NoError(c, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + srs3PeerStatus = clientStatus.Peer[srs3.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + assert.NotNil(c, srs2PeerStatus, "Router 2 peer should exist") + assert.NotNil(c, srs3PeerStatus, "Router 3 peer should exist") + + if srs1PeerStatus == nil || srs2PeerStatus == nil || srs3PeerStatus == nil { + return + } + + assert.Nil(c, 
srs1PeerStatus.PrimaryRoutes) + assert.NotNil(c, srs2PeerStatus.PrimaryRoutes) + assert.Nil(c, srs3PeerStatus.PrimaryRoutes) + + if srs2PeerStatus.PrimaryRoutes != nil { + assert.Contains(c, + srs2PeerStatus.PrimaryRoutes.AsSlice(), + pref, + ) + } + }, propagationTime, 200*time.Millisecond, "Verifying Router 2 remains PRIMARY after Router 1 route re-enabled") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, propagationTime, 200*time.Millisecond, "Verifying client can reach webservice through router 2 after route re-enable") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := subRouter2.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for subRouter2") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, propagationTime, 200*time.Millisecond, "Verifying traceroute still goes through router 2 after route re-enable") + + // Validate primary routes table state after router 1 re-approval + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()): {pref}, + // Router 3 route is still not approved + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()), + }, + }, "Router 2 should remain primary after router 1 re-approval") + + checkFailureAndPrintRoutes(t, client) + + // Enable route on node 3, we now have all routes re-enabled + t.Logf("=== ROUTE RE-ENABLE TEST: Re-approving route on router 3 (%s) - Full HA Restoration ===", subRouter3.Hostname()) + t.Logf("[%s] Starting test section", time.Now().Format(TimestampFormat)) + t.Logf(" Current state: Router 1 STANDBY, Router 2 PRIMARY, Router 3 advertised-only") + t.Logf(" Action: Re-enabling route approval on router 3") + t.Logf(" Expected: Router 2 (%s) remains PRIMARY (stability preferred)", subRouter2.Hostname()) + t.Logf(" Expected: Routers 1 & 3 are both STANDBY") + t.Logf(" Expected: Full HA restored with all 3 routers available") + r3Node := MustFindNode(subRouter3.Hostname(), nodes) + _, err = headscale.ApproveRoutes( + r3Node.GetId(), + util.MustStringsToPrefixes(r3Node.GetAvailableRoutes()), + ) + + // Wait for route state changes after re-enabling r3 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 6) + require.GreaterOrEqual(t, len(nodes), 3, "need at least 3 nodes to avoid panic") + // After router 3 re-approval: Router 2 remains PRIMARY, Routers 1&3 are STANDBY + // SubnetRoutes should only show routes for PRIMARY node (actively serving) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 0) // Router 1: STANDBY (available, approved, but not serving) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 1) // Router 2: PRIMARY (available, approved, and serving) + requireNodeRouteCountWithCollect(c, nodes[2], 1, 1, 0) // Router 3: STANDBY (available, approved, but not serving) + }, propagationTime, 200*time.Millisecond, "Waiting for route state after router 3 re-approval") + + // Validate primary routes table state after router 3 re-approval + validatePrimaryRoutes(t, headscale, &routes.DebugRoutes{ + AvailableRoutes: map[types.NodeID][]netip.Prefix{ + 
types.NodeID(MustFindNode(subRouter1.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()): {pref}, + types.NodeID(MustFindNode(subRouter3.Hostname(), nodes).GetId()): {pref}, + }, + PrimaryRoutes: map[string]types.NodeID{ + pref.String(): types.NodeID(MustFindNode(subRouter2.Hostname(), nodes).GetId()), + }, + }, "Router 2 should remain primary after router 3 re-approval") + + checkFailureAndPrintRoutes(t, client) +} + +// TestSubnetRouteACL verifies that Subnet routes are distributed +// as expected when ACLs are activated. +// It implements the issue from +// https://github.com/juanfont/headscale/issues/1604 +func TestSubnetRouteACL(t *testing.T) { + IntegrationSkip(t) + + user := "user4" + + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{user}, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{ + tsic.WithAcceptRoutes(), + }, hsic.WithTestName("clienableroute"), hsic.WithACLPolicy( + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:admins"): []policyv2.Username{policyv2.Username(user + "@")}, + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:admins")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(groupp("group:admins"), tailcfg.PortRangeAny), + }, + }, + { + Action: "accept", + Sources: []policyv2.Alias{groupp("group:admins")}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp("10.33.0.0/16"), tailcfg.PortRangeAny), + }, + }, + }, + }, + )) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + expectedRoutes := map[string]string{ + "1": "10.33.0.0/16", + } + + // Sort nodes by ID + sort.SliceStable(allClients, func(i, j int) bool { + statusI := allClients[i].MustStatus() + statusJ := allClients[j].MustStatus() + return statusI.Self.ID < statusJ.Self.ID + }) + + subRouter1 := allClients[0] + + client := allClients[1] + + for _, client := range allClients { + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + if route, ok := expectedRoutes[string(status.Self.ID)]; ok { + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + route, + } + _, _, err = client.Execute(command) + assert.NoErrorf(c, err, "failed to advertise route: %s", err) + } + }, 5*time.Second, 200*time.Millisecond, "Configuring route advertisements") + } + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // Wait for route advertisements to propagate to the server + var nodes []*v1.Node + require.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + + // Find the node that should have the route by checking node IDs + var routeNode *v1.Node + var otherNode *v1.Node + for _, node := range nodes { + nodeIDStr := strconv.FormatUint(node.GetId(), 10) + if _, shouldHaveRoute := expectedRoutes[nodeIDStr]; shouldHaveRoute { + routeNode = node + } else { + otherNode = node + } + } + + assert.NotNil(c, routeNode, "could not find node that should have route") + assert.NotNil(c, otherNode, "could not find node that 
should not have route") + + // After NodeStore fix: routes are properly tracked in route manager + // This test uses a policy with NO auto-approvers, so routes should be: + // announced=1, approved=0, subnet=0 (routes announced but not approved) + requireNodeRouteCountWithCollect(c, routeNode, 1, 0, 0) + requireNodeRouteCountWithCollect(c, otherNode, 0, 0, 0) + }, 10*time.Second, 100*time.Millisecond, "route advertisements should propagate to server") + + // Verify that no routes has been sent to the client, + // they are not yet enabled. + for _, client := range allClients { + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(c, peerStatus.PrimaryRoutes) + requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) + } + }, 5*time.Second, 200*time.Millisecond, "Verifying no routes are active before approval") + } + + _, err = headscale.ApproveRoutes( + 1, + []netip.Prefix{netip.MustParsePrefix(expectedRoutes["1"])}, + ) + require.NoError(t, err) + + // Wait for route state changes to propagate to nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 0, 0, 0) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes") + + // Verify that the client has routes from the primary machine + assert.EventuallyWithT(t, func(c *assert.CollectT) { + srs1, err := subRouter1.Status() + assert.NoError(c, err) + + clientStatus, err := client.Status() + assert.NoError(c, err) + + srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] + + assert.NotNil(c, srs1PeerStatus, "Router 1 peer should exist") + if srs1PeerStatus == nil { + return + } + + requirePeerSubnetRoutesWithCollect(c, srs1PeerStatus, []netip.Prefix{netip.MustParsePrefix(expectedRoutes["1"])}) + }, 5*time.Second, 200*time.Millisecond, "Verifying client can see subnet routes from router") + + // Wait for packet filter updates to propagate to client netmap + wantClientFilter := []filter.Match{ + { + IPProto: views.SliceOf([]ipproto.Proto{ + ipproto.TCP, ipproto.UDP, + }), + Srcs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + }, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("100.64.0.2/32"), + Ports: allPorts, + }, + { + Net: netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + Ports: allPorts, + }, + }, + Caps: []filter.CapMatch{}, + }, + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + clientNm, err := client.Netmap() + assert.NoError(c, err) + + if diff := cmpdiff.Diff(wantClientFilter, clientNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { + assert.Fail(c, fmt.Sprintf("Client (%s) filter, unexpected result (-want +got):\n%s", client.Hostname(), diff)) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for client packet filter to update") + + // Wait for packet filter updates to propagate to subnet router netmap + wantSubnetFilter := []filter.Match{ + { + IPProto: views.SliceOf([]ipproto.Proto{ + ipproto.TCP, ipproto.UDP, + }), + Srcs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("100.64.0.2/32"), + 
netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + }, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("100.64.0.1/32"), + Ports: allPorts, + }, + { + Net: netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + Ports: allPorts, + }, + }, + Caps: []filter.CapMatch{}, + }, + { + IPProto: views.SliceOf([]ipproto.Proto{ + ipproto.TCP, ipproto.UDP, + }), + Srcs: []netip.Prefix{ + netip.MustParsePrefix("100.64.0.1/32"), + netip.MustParsePrefix("100.64.0.2/32"), + netip.MustParsePrefix("fd7a:115c:a1e0::1/128"), + netip.MustParsePrefix("fd7a:115c:a1e0::2/128"), + }, + Dsts: []filter.NetPortRange{ + { + Net: netip.MustParsePrefix("10.33.0.0/16"), + Ports: allPorts, + }, + }, + Caps: []filter.CapMatch{}, + }, + } + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + subnetNm, err := subRouter1.Netmap() + assert.NoError(c, err) + + if diff := cmpdiff.Diff(wantSubnetFilter, subnetNm.PacketFilter, util.ViewSliceIPProtoComparer, util.PrefixComparer); diff != "" { + assert.Fail(c, fmt.Sprintf("Subnet (%s) filter, unexpected result (-want +got):\n%s", subRouter1.Hostname(), diff)) + } + }, 10*time.Second, 200*time.Millisecond, "Waiting for subnet router packet filter to update") +} + +// TestEnablingExitRoutes tests enabling exit routes for clients. +// Its more or less the same as TestEnablingRoutes, but with the --advertise-exit-node flag +// set during login instead of set. +func TestEnablingExitRoutes(t *testing.T) { + IntegrationSkip(t) + + user := "user2" + + spec := ScenarioSpec{ + NodesPerUser: 2, + Users: []string{user}, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario") + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{ + tsic.WithExtraLoginArgs([]string{"--advertise-exit-node"}), + }, hsic.WithTestName("clienableroute")) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + var nodes []*v1.Node + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + + requireNodeRouteCountWithCollect(c, nodes[0], 2, 0, 0) + requireNodeRouteCountWithCollect(c, nodes[1], 2, 0, 0) + }, 10*time.Second, 200*time.Millisecond, "Waiting for route advertisements to propagate") + + // Verify that no routes has been sent to the client, + // they are not yet enabled. + for _, client := range allClients { + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(c, peerStatus.PrimaryRoutes) + } + }, 5*time.Second, 200*time.Millisecond, "Verifying no exit routes are active before approval") + } + + // Enable all routes, but do v4 on one and v6 on other to ensure they + // are both added since they are exit routes. 
+ _, err = headscale.ApproveRoutes( + nodes[0].GetId(), + []netip.Prefix{tsaddr.AllIPv4()}, + ) + require.NoError(t, err) + _, err = headscale.ApproveRoutes( + nodes[1].GetId(), + []netip.Prefix{tsaddr.AllIPv6()}, + ) + require.NoError(t, err) + + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + + requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2) + requireNodeRouteCountWithCollect(c, nodes[1], 2, 2, 2) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to both nodes") + + // Wait for route state changes to propagate to clients + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Verify that the clients can see the new routes + for _, client := range allClients { + status, err := client.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.NotNil(c, peerStatus.AllowedIPs) + if peerStatus.AllowedIPs != nil { + assert.Len(c, peerStatus.AllowedIPs.AsSlice(), 4) + assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv4()) + assert.Contains(c, peerStatus.AllowedIPs.AsSlice(), tsaddr.AllIPv6()) + } + } + } + }, 10*time.Second, 500*time.Millisecond, "clients should see new routes") +} + +// TestSubnetRouterMultiNetwork is an evolution of the subnet router test. +// This test will set up multiple docker networks and use two isolated tailscale +// clients and a service available in one of the networks to validate that a +// subnet router is working as expected. +func TestSubnetRouterMultiNetwork(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{tsic.WithAcceptRoutes()}, + hsic.WithTestName("clienableroute"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + assert.NotNil(t, headscale) + + pref, err := scenario.SubnetOfNetwork("usernet1") + require.NoError(t, err) + + var user1c, user2c TailscaleClient + + for _, c := range allClients { + s := c.MustStatus() + if s.User[s.Self.UserID].LoginName == "user1@test.no" { + user1c = c + } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { + user2c = c + } + } + require.NotNil(t, user1c) + require.NotNil(t, user2c) + + // Advertise the route for the dockersubnet of user1 + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + pref.String(), + } + _, _, err = user1c.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) + + var nodes []*v1.Node + // Wait for route advertisements to propagate to NodeStore + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + nodes, err = headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, nodes, 2) + requireNodeRouteCountWithCollect(ct, nodes[0], 1, 0, 0) + }, 
10*time.Second, 100*time.Millisecond, "route advertisements should propagate") + + // Verify that no routes has been sent to the client, + // they are not yet enabled. + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := user1c.Status() + assert.NoError(c, err) for _, peerKey := range status.Peers() { peerStatus := status.Peer[peerKey] - assert.Nil(t, peerStatus.PrimaryRoutes) + assert.Nil(c, peerStatus.PrimaryRoutes) + requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) + } + }, 5*time.Second, 200*time.Millisecond, "Verifying no routes are active before approval") + + // Enable route + _, err = headscale.ApproveRoutes( + nodes[0].GetId(), + []netip.Prefix{*pref}, + ) + require.NoError(t, err) + + // Wait for route state changes to propagate to nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + var err error + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + requireNodeRouteCountWithCollect(c, nodes[0], 1, 1, 1) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes") + + // Verify that the routes have been sent to the client + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := user2c.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + if peerStatus.PrimaryRoutes != nil { + assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *pref) + } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*pref}) + } + }, 10*time.Second, 500*time.Millisecond, "routes should be visible to client") + + usernet1, err := scenario.Network("usernet1") + require.NoError(t, err) + + services, err := scenario.Services("usernet1") + require.NoError(t, err) + require.Len(t, services, 1) + + web := services[0] + webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) + + url := fmt.Sprintf("http://%s/etc/hostname", webip) + t.Logf("url from %s to %s", user2c.Hostname(), url) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := user2c.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 5*time.Second, 200*time.Millisecond, "Verifying client can reach webservice through subnet route") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := user2c.Traceroute(webip) + assert.NoError(c, err) + ip, err := user1c.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for user1c") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, 5*time.Second, 200*time.Millisecond, "Verifying traceroute goes through subnet router") +} + +func TestSubnetRouterMultiNetworkExitNode(t *testing.T) { + IntegrationSkip(t) + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{}, + hsic.WithTestName("clienableroute"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + assert.NotNil(t, headscale) + + var 
user1c, user2c TailscaleClient + + for _, c := range allClients { + s := c.MustStatus() + if s.User[s.Self.UserID].LoginName == "user1@test.no" { + user1c = c + } + if s.User[s.Self.UserID].LoginName == "user2@test.no" { + user2c = c } } + require.NotNil(t, user1c) + require.NotNil(t, user2c) - // Enable all routes - for _, route := range routes { - _, err = headscale.Execute( - []string{ - "headscale", - "routes", - "enable", - "--route", - strconv.Itoa(int(route.GetId())), - }) - assertNoErr(t, err) + // Advertise the exit nodes for the dockersubnet of user1 + command := []string{ + "tailscale", + "set", + "--advertise-exit-node", + } + _, _, err = user1c.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) - time.Sleep(time.Second) + var nodes []*v1.Node + // Wait for route advertisements to propagate to NodeStore + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + var err error + nodes, err = headscale.ListNodes() + assert.NoError(ct, err) + assert.Len(ct, nodes, 2) + requireNodeRouteCountWithCollect(ct, nodes[0], 2, 0, 0) + }, 10*time.Second, 100*time.Millisecond, "route advertisements should propagate") + + // Verify that no routes has been sent to the client, + // they are not yet enabled. + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := user1c.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(c, peerStatus.PrimaryRoutes) + requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) + } + }, 5*time.Second, 200*time.Millisecond, "Verifying no routes sent to client before approval") + + // Enable route + _, err = headscale.ApproveRoutes(nodes[0].GetId(), []netip.Prefix{tsaddr.AllIPv4()}) + require.NoError(t, err) + + // Wait for route state changes to propagate to nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + requireNodeRouteCountWithCollect(c, nodes[0], 2, 2, 2) + }, 10*time.Second, 500*time.Millisecond, "route state changes should propagate to nodes") + + // Verify that the routes have been sent to the client + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := user2c.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}) + } + }, 10*time.Second, 500*time.Millisecond, "routes should be visible to client") + + // Tell user2c to use user1c as an exit node. + command = []string{ + "tailscale", + "set", + "--exit-node", + user1c.Hostname(), + } + _, _, err = user2c.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) + + usernet1, err := scenario.Network("usernet1") + require.NoError(t, err) + + services, err := scenario.Services("usernet1") + require.NoError(t, err) + require.Len(t, services, 1) + + web := services[0] + webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) + + // We can't mess to much with ip forwarding in containers so + // we settle for a simple ping here. + // Direct is false since we use internal DERP which means we + // can't discover a direct path between docker networks. 
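+ // A single ping with a generous timeout is enough here; WithPingUntilDirect(false) lets
+ // the check succeed over the relayed (DERP) path described above.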
+ err = user2c.Ping(webip.String(), + tsic.WithPingUntilDirect(false), + tsic.WithPingCount(1), + tsic.WithPingTimeout(7*time.Second), + ) + require.NoError(t, err) +} + +func MustFindNode(hostname string, nodes []*v1.Node) *v1.Node { + for _, node := range nodes { + if node.GetName() == hostname { + return node + } + } + panic("node not found") +} + +// TestAutoApproveMultiNetwork tests auto approving of routes +// by setting up two networks where network1 has three subnet +// routers: +// - routerUsernet1: advertising the docker network +// - routerSubRoute: advertising a subroute, a /24 inside a auto approved /16 +// - routeExitNode: advertising an exit node +// +// Each router is tested step by step through the following scenarios +// - Policy is set to auto approve the nodes route +// - Node advertises route and it is verified that it is auto approved and sent to nodes +// - Policy is changed to _not_ auto approve the route +// - Verify that peers can still see the node +// - Disable route, making it unavailable +// - Verify that peers can no longer use node +// - Policy is changed back to auto approve route, check that routes already existing is approved. +// - Verify that routes can now be seen by peers. +func TestAutoApproveMultiNetwork(t *testing.T) { + IntegrationSkip(t) + + // Timeout for EventuallyWithT assertions. + // Set generously to account for CI infrastructure variability. + assertTimeout := 60 * time.Second + + bigRoute := netip.MustParsePrefix("10.42.0.0/16") + subRoute := netip.MustParsePrefix("10.42.7.0/24") + notApprovedRoute := netip.MustParsePrefix("192.168.0.0/24") + + tests := []struct { + name string + pol *policyv2.Policy + approver string + spec ScenarioSpec + withURL bool + }{ + { + name: "authkey-tag", + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:approve"): policyv2.Owners{usernameOwner("user1@")}, + }, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {tagApprover("tag:approve")}, + }, + ExitNode: policyv2.AutoApprovers{tagApprover("tag:approve")}, + }, + }, + approver: "tag:approve", + spec: ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. + Versions: []string{"head"}, + }, + }, + { + name: "authkey-user", + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {usernameApprover("user1@")}, + }, + ExitNode: policyv2.AutoApprovers{usernameApprover("user1@")}, + }, + }, + approver: "user1@", + spec: ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. 
+ Versions: []string{"head"}, + }, + }, + { + name: "authkey-group", + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + Groups: policyv2.Groups{ + policyv2.Group("group:approve"): []policyv2.Username{policyv2.Username("user1@")}, + }, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {groupApprover("group:approve")}, + }, + ExitNode: policyv2.AutoApprovers{groupApprover("group:approve")}, + }, + }, + approver: "group:approve", + spec: ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. + Versions: []string{"head"}, + }, + }, + { + name: "webauth-user", + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {usernameApprover("user1@")}, + }, + ExitNode: policyv2.AutoApprovers{usernameApprover("user1@")}, + }, + }, + approver: "user1@", + spec: ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. + Versions: []string{"head"}, + }, + withURL: true, + }, + { + name: "webauth-tag", + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + TagOwners: policyv2.TagOwners{ + policyv2.Tag("tag:approve"): policyv2.Owners{usernameOwner("user1@")}, + }, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {tagApprover("tag:approve")}, + }, + ExitNode: policyv2.AutoApprovers{tagApprover("tag:approve")}, + }, + }, + approver: "tag:approve", + spec: ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. 
+ Versions: []string{"head"}, + }, + withURL: true, + }, + { + name: "webauth-group", + pol: &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + Groups: policyv2.Groups{ + policyv2.Group("group:approve"): []policyv2.Username{policyv2.Username("user1@")}, + }, + AutoApprovers: policyv2.AutoApproverPolicy{ + Routes: map[netip.Prefix]policyv2.AutoApprovers{ + bigRoute: {groupApprover("group:approve")}, + }, + ExitNode: policyv2.AutoApprovers{groupApprover("group:approve")}, + }, + }, + approver: "group:approve", + spec: ScenarioSpec{ + NodesPerUser: 3, + Users: []string{"user1", "user2"}, + Networks: map[string][]string{ + "usernet1": {"user1"}, + "usernet2": {"user2"}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. + Versions: []string{"head"}, + }, + withURL: true, + }, } - var enablingRoutes []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &enablingRoutes, - ) - assertNoErr(t, err) - assert.Len(t, enablingRoutes, 2) - - // Node 1 is primary - assert.Equal(t, true, enablingRoutes[0].GetAdvertised()) - assert.Equal(t, true, enablingRoutes[0].GetEnabled()) - assert.Equal(t, true, enablingRoutes[0].GetIsPrimary()) - - // Node 2 is not primary - assert.Equal(t, true, enablingRoutes[1].GetAdvertised()) - assert.Equal(t, true, enablingRoutes[1].GetEnabled()) - assert.Equal(t, false, enablingRoutes[1].GetIsPrimary()) - - // Verify that the client has routes from the primary machine - srs1, err := subRouter1.Status() - srs2, err := subRouter2.Status() - - clientStatus, err := client.Status() - assertNoErr(t, err) - - srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey] - - assertNotNil(t, srs1PeerStatus.PrimaryRoutes) - assert.Nil(t, srs2PeerStatus.PrimaryRoutes) - - assert.Contains( - t, - srs1PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), - ) - - // Take down the current primary - t.Logf("taking down subnet router 1 (%s)", subRouter1.Hostname()) - err = subRouter1.Down() - assertNoErr(t, err) - - time.Sleep(5 * time.Second) - - var routesAfterMove []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routesAfterMove, - ) - assertNoErr(t, err) - assert.Len(t, routesAfterMove, 2) - - // Node 1 is not primary - assert.Equal(t, true, routesAfterMove[0].GetAdvertised()) - assert.Equal(t, true, routesAfterMove[0].GetEnabled()) - assert.Equal(t, false, routesAfterMove[0].GetIsPrimary()) - - // Node 2 is primary - assert.Equal(t, true, routesAfterMove[1].GetAdvertised()) - assert.Equal(t, true, routesAfterMove[1].GetEnabled()) - assert.Equal(t, true, routesAfterMove[1].GetIsPrimary()) - - // TODO(kradalby): Check client status - // Route is expected to be on SR2 - - srs2, err = subRouter2.Status() - - clientStatus, err = client.Status() - assertNoErr(t, err) - - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - - assert.Nil(t, srs1PeerStatus.PrimaryRoutes) - assertNotNil(t, srs2PeerStatus.PrimaryRoutes) - - if srs2PeerStatus.PrimaryRoutes != nil { - assert.Contains( - t, - 
srs2PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), - ) - } - - // Take down subnet router 2, leaving none available - t.Logf("taking down subnet router 2 (%s)", subRouter2.Hostname()) - err = subRouter2.Down() - assertNoErr(t, err) - - time.Sleep(5 * time.Second) - - var routesAfterBothDown []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routesAfterBothDown, - ) - assertNoErr(t, err) - assert.Len(t, routesAfterBothDown, 2) - - // Node 1 is not primary - assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised()) - assert.Equal(t, true, routesAfterBothDown[0].GetEnabled()) - assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary()) - - // Node 2 is primary - // if the node goes down, but no other suitable route is - // available, keep the last known good route. - assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised()) - assert.Equal(t, true, routesAfterBothDown[1].GetEnabled()) - assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary()) - - // TODO(kradalby): Check client status - // Both are expected to be down - - // Verify that the route is not presented from either router - clientStatus, err = client.Status() - assertNoErr(t, err) - - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - - assert.Nil(t, srs1PeerStatus.PrimaryRoutes) - assertNotNil(t, srs2PeerStatus.PrimaryRoutes) - - if srs2PeerStatus.PrimaryRoutes != nil { - assert.Contains( - t, - srs2PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), - ) - } - - // Bring up subnet router 1, making the route available from there. - t.Logf("bringing up subnet router 1 (%s)", subRouter1.Hostname()) - err = subRouter1.Up() - assertNoErr(t, err) - - time.Sleep(5 * time.Second) - - var routesAfter1Up []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routesAfter1Up, - ) - assertNoErr(t, err) - assert.Len(t, routesAfter1Up, 2) - - // Node 1 is primary - assert.Equal(t, true, routesAfter1Up[0].GetAdvertised()) - assert.Equal(t, true, routesAfter1Up[0].GetEnabled()) - assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary()) - - // Node 2 is not primary - assert.Equal(t, true, routesAfter1Up[1].GetAdvertised()) - assert.Equal(t, true, routesAfter1Up[1].GetEnabled()) - assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary()) - - // Verify that the route is announced from subnet router 1 - clientStatus, err = client.Status() - assertNoErr(t, err) - - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - - assert.NotNil(t, srs1PeerStatus.PrimaryRoutes) - assert.Nil(t, srs2PeerStatus.PrimaryRoutes) - - if srs1PeerStatus.PrimaryRoutes != nil { - assert.Contains( - t, - srs1PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), - ) - } - - // Bring up subnet router 2, should result in no change. 
- t.Logf("bringing up subnet router 2 (%s)", subRouter2.Hostname()) - err = subRouter2.Up() - assertNoErr(t, err) - - time.Sleep(5 * time.Second) - - var routesAfter2Up []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routesAfter2Up, - ) - assertNoErr(t, err) - assert.Len(t, routesAfter2Up, 2) - - // Node 1 is not primary - assert.Equal(t, true, routesAfter2Up[0].GetAdvertised()) - assert.Equal(t, true, routesAfter2Up[0].GetEnabled()) - assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary()) - - // Node 2 is primary - assert.Equal(t, true, routesAfter2Up[1].GetAdvertised()) - assert.Equal(t, true, routesAfter2Up[1].GetEnabled()) - assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary()) - - // Verify that the route is announced from subnet router 1 - clientStatus, err = client.Status() - assertNoErr(t, err) - - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - - assert.NotNil(t, srs1PeerStatus.PrimaryRoutes) - assert.Nil(t, srs2PeerStatus.PrimaryRoutes) - - if srs1PeerStatus.PrimaryRoutes != nil { - assert.Contains( - t, - srs1PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), - ) - } - - // Disable the route of subnet router 1, making it failover to 2 - t.Logf("disabling route in subnet router 1 (%s)", subRouter1.Hostname()) - _, err = headscale.Execute( - []string{ - "headscale", - "routes", - "disable", - "--route", - fmt.Sprintf("%d", routesAfter2Up[0].GetId()), - }) - assertNoErr(t, err) - - time.Sleep(5 * time.Second) - - var routesAfterDisabling1 []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routesAfterDisabling1, - ) - assertNoErr(t, err) - assert.Len(t, routesAfterDisabling1, 2) - - // Node 1 is not primary - assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised()) - assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled()) - assert.Equal(t, false, routesAfterDisabling1[0].GetIsPrimary()) - - // Node 2 is primary - assert.Equal(t, true, routesAfterDisabling1[1].GetAdvertised()) - assert.Equal(t, true, routesAfterDisabling1[1].GetEnabled()) - assert.Equal(t, true, routesAfterDisabling1[1].GetIsPrimary()) - - // Verify that the route is announced from subnet router 1 - clientStatus, err = client.Status() - assertNoErr(t, err) - - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - - assert.Nil(t, srs1PeerStatus.PrimaryRoutes) - assert.NotNil(t, srs2PeerStatus.PrimaryRoutes) - - if srs2PeerStatus.PrimaryRoutes != nil { - assert.Contains( - t, - srs2PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), - ) - } - - // enable the route of subnet router 1, no change expected - t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname()) - _, err = headscale.Execute( - []string{ - "headscale", - "routes", - "enable", - "--route", - fmt.Sprintf("%d", routesAfter2Up[0].GetId()), - }) - assertNoErr(t, err) - - time.Sleep(5 * time.Second) - - var routesAfterEnabling1 []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routesAfterEnabling1, - ) - assertNoErr(t, err) - assert.Len(t, routesAfterEnabling1, 2) - - // Node 1 is not primary - assert.Equal(t, true, routesAfterEnabling1[0].GetAdvertised()) - 
assert.Equal(t, true, routesAfterEnabling1[0].GetEnabled()) - assert.Equal(t, false, routesAfterEnabling1[0].GetIsPrimary()) - - // Node 2 is primary - assert.Equal(t, true, routesAfterEnabling1[1].GetAdvertised()) - assert.Equal(t, true, routesAfterEnabling1[1].GetEnabled()) - assert.Equal(t, true, routesAfterEnabling1[1].GetIsPrimary()) - - // Verify that the route is announced from subnet router 1 - clientStatus, err = client.Status() - assertNoErr(t, err) - - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - - assert.Nil(t, srs1PeerStatus.PrimaryRoutes) - assert.NotNil(t, srs2PeerStatus.PrimaryRoutes) - - if srs2PeerStatus.PrimaryRoutes != nil { - assert.Contains( - t, - srs2PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), - ) - } - - // delete the route of subnet router 2, failover to one expected - t.Logf("deleting route in subnet router 2 (%s)", subRouter2.Hostname()) - _, err = headscale.Execute( - []string{ - "headscale", - "routes", - "delete", - "--route", - fmt.Sprintf("%d", routesAfterEnabling1[1].GetId()), - }) - assertNoErr(t, err) - - time.Sleep(5 * time.Second) - - var routesAfterDeleting2 []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routesAfterDeleting2, - ) - assertNoErr(t, err) - assert.Len(t, routesAfterDeleting2, 1) - - t.Logf("routes after deleting2 %#v", routesAfterDeleting2) - - // Node 1 is primary - assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised()) - assert.Equal(t, true, routesAfterDeleting2[0].GetEnabled()) - assert.Equal(t, true, routesAfterDeleting2[0].GetIsPrimary()) - - // Verify that the route is announced from subnet router 1 - clientStatus, err = client.Status() - assertNoErr(t, err) - - srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] - srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] - - assertNotNil(t, srs1PeerStatus.PrimaryRoutes) - assert.Nil(t, srs2PeerStatus.PrimaryRoutes) - - if srs1PeerStatus.PrimaryRoutes != nil { - assert.Contains( - t, - srs1PeerStatus.PrimaryRoutes.AsSlice(), - netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), - ) + for _, tt := range tests { + for _, polMode := range []types.PolicyMode{types.PolicyModeDB, types.PolicyModeFile} { + for _, advertiseDuringUp := range []bool{false, true} { + name := fmt.Sprintf("%s-advertiseduringup-%t-pol-%s", tt.name, advertiseDuringUp, polMode) + t.Run(name, func(t *testing.T) { + // Create a deep copy of the policy to avoid mutating the shared test case. + // Each subtest modifies AutoApprovers.Routes (add then delete), so we need + // an isolated copy to prevent state leakage between sequential test runs. 
+ pol := &policyv2.Policy{ + ACLs: slices.Clone(tt.pol.ACLs), + Groups: maps.Clone(tt.pol.Groups), + TagOwners: maps.Clone(tt.pol.TagOwners), + AutoApprovers: policyv2.AutoApproverPolicy{ + ExitNode: slices.Clone(tt.pol.AutoApprovers.ExitNode), + Routes: maps.Clone(tt.pol.AutoApprovers.Routes), + }, + } + + scenario, err := NewScenario(tt.spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + var nodes []*v1.Node + opts := []hsic.Option{ + hsic.WithTestName("autoapprovemulti"), + hsic.WithEmbeddedDERPServerOnly(), + hsic.WithTLS(), + hsic.WithACLPolicy(pol), + hsic.WithPolicyMode(polMode), + } + + tsOpts := []tsic.Option{ + tsic.WithAcceptRoutes(), + } + + route, err := scenario.SubnetOfNetwork("usernet1") + require.NoError(t, err) + + // For tag-based approvers, nodes must be tagged with that tag + // (tags-as-identity model: tagged nodes are identified by their tags) + var ( + preAuthKeyTags []string + webauthTagUser string + ) + + if strings.HasPrefix(tt.approver, "tag:") { + preAuthKeyTags = []string{tt.approver} + if tt.withURL { + // For webauth, only user1 can request tags (per tagOwners policy) + webauthTagUser = "user1" + } + } + + err = scenario.createHeadscaleEnvWithTags(tt.withURL, tsOpts, preAuthKeyTags, webauthTagUser, + opts..., + ) + requireNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + services, err := scenario.Services("usernet1") + require.NoError(t, err) + require.Len(t, services, 1) + + usernet1, err := scenario.Network("usernet1") + require.NoError(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + assert.NotNil(t, headscale) + + // Add the Docker network route to the auto-approvers + // Keep existing auto-approvers (like bigRoute) in place + var approvers policyv2.AutoApprovers + switch { + case strings.HasPrefix(tt.approver, "tag:"): + approvers = append(approvers, tagApprover(tt.approver)) + case strings.HasPrefix(tt.approver, "group:"): + approvers = append(approvers, groupApprover(tt.approver)) + default: + approvers = append(approvers, usernameApprover(tt.approver)) + } + // pol.AutoApprovers.Routes is already initialized in the deep copy above + prefix := *route + pol.AutoApprovers.Routes[prefix] = approvers + err = headscale.SetPolicy(pol) + require.NoError(t, err) + + if advertiseDuringUp { + tsOpts = append(tsOpts, + tsic.WithExtraLoginArgs([]string{"--advertise-routes=" + route.String()}), + ) + } + + // For webauth with tag approver, the node needs to advertise the tag during registration + // (tags-as-identity model: webauth nodes can use --advertise-tags if authorized by tagOwners) + if tt.withURL && strings.HasPrefix(tt.approver, "tag:") { + tsOpts = append(tsOpts, tsic.WithTags([]string{tt.approver})) + } + + tsOpts = append(tsOpts, tsic.WithNetwork(usernet1)) + + // This whole dance is to add a node _after_ all the other nodes + // with an additional tsOpt which advertises the route as part + // of the `tailscale up` command. If we do this as part of the + // scenario creation, it will be added to all nodes and turn + // into a HA node, which isn't something we are testing here. + routerUsernet1, err := scenario.CreateTailscaleNode("head", tsOpts...) 
+ require.NoError(t, err) + + defer func() { + _, _, err := routerUsernet1.Shutdown() + require.NoError(t, err) + }() + + if tt.withURL { + u, err := routerUsernet1.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + body, err := doLoginURL(routerUsernet1.Hostname(), u) + require.NoError(t, err) + + err = scenario.runHeadscaleRegister("user1", body) + require.NoError(t, err) + + // Wait for the client to sync with the server after webauth registration. + // Unlike authkey login which blocks until complete, webauth registration + // happens on the server side and the client needs time to receive the network map. + err = routerUsernet1.WaitForRunning(integrationutil.PeerSyncTimeout()) + require.NoError(t, err, "webauth client failed to reach Running state") + } else { + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + // If the approver is a tag, create a tagged PreAuthKey + // (tags-as-identity model: tags come from PreAuthKey, not --advertise-tags) + var pak *v1.PreAuthKey + if strings.HasPrefix(tt.approver, "tag:") { + pak, err = scenario.CreatePreAuthKeyWithTags(userMap["user1"].GetId(), false, false, []string{tt.approver}) + } else { + pak, err = scenario.CreatePreAuthKey(userMap["user1"].GetId(), false, false) + } + require.NoError(t, err) + + err = routerUsernet1.Login(headscale.GetEndpoint(), pak.GetKey()) + require.NoError(t, err) + } + // extra creation end. + + // Wait for the node to be fully running before getting its ID + // This is especially important for webauth flow where login is asynchronous + err = routerUsernet1.WaitForRunning(30 * time.Second) + require.NoError(t, err) + + // Wait for bidirectional peer synchronization. + // Both the router and all existing clients must see each other. + // This is critical for connectivity - without this, the WireGuard + // tunnels may not be established despite peers appearing in netmaps. + + // Router waits for all existing clients + err = routerUsernet1.WaitForPeers(len(allClients), 60*time.Second, 1*time.Second) + require.NoError(t, err, "router failed to see all peers") + + // All clients wait for the router (they should see 6 peers including the router) + for _, existingClient := range allClients { + err = existingClient.WaitForPeers(len(allClients), 60*time.Second, 1*time.Second) + require.NoErrorf(t, err, "client %s failed to see all peers including router", existingClient.Hostname()) + } + + routerUsernet1ID := routerUsernet1.MustID() + + web := services[0] + webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) + weburl := fmt.Sprintf("http://%s/etc/hostname", webip) + t.Logf("webservice: %s, %s", webip.String(), weburl) + + // Sort nodes by ID + sort.SliceStable(allClients, func(i, j int) bool { + statusI := allClients[i].MustStatus() + statusJ := allClients[j].MustStatus() + + return statusI.Self.ID < statusJ.Self.ID + }) + + // This is ok because the scenario makes users in order, so the three first + // nodes, which are subnet routes, will be created first, and the last user + // will be created with the second. + routerSubRoute := allClients[1] + routerExitNode := allClients[2] + + client := allClients[3] + + if !advertiseDuringUp { + // Advertise the route for the dockersubnet of user1 + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + route.String(), + } + _, _, err = routerUsernet1.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) + } + + // Wait for route state changes to propagate. 
+ // Use a longer timeout (30s) to account for CI infrastructure variability - + // when advertiseDuringUp=true, routes are sent during registration and may + // take longer to propagate through the server's auto-approval logic in slow + // environments. + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // These route should auto approve, so the node is expected to have a route + // for all counts. + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + + routerNode := MustFindNode(routerUsernet1.Hostname(), nodes) + t.Logf("Initial auto-approval check - Router node %s: announced=%v, approved=%v, subnet=%v", + routerNode.GetName(), + routerNode.GetAvailableRoutes(), + routerNode.GetApprovedRoutes(), + routerNode.GetSubnetRoutes()) + + requireNodeRouteCountWithCollect(c, routerNode, 1, 1, 1) + }, assertTimeout, 500*time.Millisecond, "Initial route auto-approval: Route should be approved via policy") + + // Verify that the routes have been sent to the client. + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + // Debug output to understand peer visibility + t.Logf("Client %s sees %d peers", client.Hostname(), len(status.Peers())) + + routerPeerFound := false + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + if peerStatus.ID == routerUsernet1ID.StableID() { + routerPeerFound = true + t.Logf("Client sees router peer %s (ID=%s): AllowedIPs=%v, PrimaryRoutes=%v", + peerStatus.HostName, + peerStatus.ID, + peerStatus.AllowedIPs, + peerStatus.PrimaryRoutes) + + assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { + assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) + } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) + } else { + requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) + } + } + + assert.True(c, routerPeerFound, "Client should see the router peer") + }, assertTimeout, 200*time.Millisecond, "Verifying routes sent to client after auto-approval") + + // Verify WireGuard tunnel connectivity to the router before testing route. + // The client may have the route in its netmap but the actual tunnel may not + // be established yet, especially in CI environments with higher latency. 
+ routerIPv4, err := routerUsernet1.IPv4() + require.NoError(t, err, "failed to get router IPv4") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + err := client.Ping( + routerIPv4.String(), + tsic.WithPingUntilDirect(false), // DERP relay is fine + tsic.WithPingCount(1), + tsic.WithPingTimeout(5*time.Second), + ) + assert.NoError(c, err, "ping to router should succeed") + }, assertTimeout, 200*time.Millisecond, "Verifying WireGuard tunnel to router is established") + + url := fmt.Sprintf("http://%s/etc/hostname", webip) + t.Logf("url from %s to %s", client.Hostname(), url) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, assertTimeout, 200*time.Millisecond, "Verifying client can reach webservice through auto-approved route") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through auto-approved router") + + // Remove the auto approval from the policy, any routes already enabled should be allowed. + prefix = *route + delete(pol.AutoApprovers.Routes, prefix) + err = headscale.SetPolicy(pol) + require.NoError(t, err) + t.Logf("Policy updated: removed auto-approver for route %s", prefix) + + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // Routes already approved should remain approved even after policy change + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + + routerNode := MustFindNode(routerUsernet1.Hostname(), nodes) + t.Logf("After policy removal - Router node %s: announced=%v, approved=%v, subnet=%v", + routerNode.GetName(), + routerNode.GetAvailableRoutes(), + routerNode.GetApprovedRoutes(), + routerNode.GetSubnetRoutes()) + + requireNodeRouteCountWithCollect(c, routerNode, 1, 1, 1) + }, assertTimeout, 500*time.Millisecond, "Routes should remain approved after auto-approver removal") + + // Verify that the routes have been sent to the client. 
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ status, err := client.Status()
+ assert.NoError(c, err)
+
+ for _, peerKey := range status.Peers() {
+ peerStatus := status.Peer[peerKey]
+
+ if peerStatus.ID == routerUsernet1ID.StableID() {
+ assert.NotNil(c, peerStatus.PrimaryRoutes)
+ if peerStatus.PrimaryRoutes != nil {
+ assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
+ }
+ requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
+ } else {
+ requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
+ }
+ }
+ }, assertTimeout, 200*time.Millisecond, "Verifying routes remain after policy change")
+
+ url = fmt.Sprintf("http://%s/etc/hostname", webip)
+ t.Logf("url from %s to %s", client.Hostname(), url)
+
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ result, err := client.Curl(url)
+ assert.NoError(c, err)
+ assert.Len(c, result, 13)
+ }, assertTimeout, 200*time.Millisecond, "Verifying client can still reach webservice after policy change")
+
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ tr, err := client.Traceroute(webip)
+ assert.NoError(c, err)
+ ip, err := routerUsernet1.IPv4()
+ if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") {
+ return
+ }
+ assertTracerouteViaIPWithCollect(c, tr, ip)
+ }, assertTimeout, 200*time.Millisecond, "Verifying traceroute still goes through router after policy change")
+
+ // Disable the route, making it unavailable since it is no longer auto-approved
+ _, err = headscale.ApproveRoutes(
+ MustFindNode(routerUsernet1.Hostname(), nodes).GetId(),
+ []netip.Prefix{},
+ )
+ require.NoError(t, err)
+
+ // Wait for route state changes to propagate
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ // The route's approval has just been removed, so it should remain announced
+ // but no longer be approved or served.
+ nodes, err = headscale.ListNodes()
+ assert.NoError(c, err)
+ requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 0, 0)
+ }, assertTimeout, 500*time.Millisecond, "route state changes should propagate")
+
+ // Verify that the routes have been sent to the client.
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ status, err := client.Status()
+ assert.NoError(c, err)
+
+ for _, peerKey := range status.Peers() {
+ peerStatus := status.Peer[peerKey]
+ requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
+ }
+ }, assertTimeout, 200*time.Millisecond, "Verifying routes disabled after route removal")
+
+ // Add the route back to the auto approver in the policy, the route should
+ // now become available again.
+ var newApprovers policyv2.AutoApprovers
+ switch {
+ case strings.HasPrefix(tt.approver, "tag:"):
+ newApprovers = append(newApprovers, tagApprover(tt.approver))
+ case strings.HasPrefix(tt.approver, "group:"):
+ newApprovers = append(newApprovers, groupApprover(tt.approver))
+ default:
+ newApprovers = append(newApprovers, usernameApprover(tt.approver))
+ }
+ // pol.AutoApprovers.Routes is already initialized in the deep copy above
+ prefix = *route
+ pol.AutoApprovers.Routes[prefix] = newApprovers
+ err = headscale.SetPolicy(pol)
+ require.NoError(t, err)
+
+ // Wait for route state changes to propagate
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ // These routes should be auto-approved again, so the node is expected to have a route
+ // for all counts.
+ nodes, err = headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + }, assertTimeout, 500*time.Millisecond, "route state changes should propagate") + + // Verify that the routes have been sent to the client. + assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + if peerStatus.ID == routerUsernet1ID.StableID() { + assert.NotNil(c, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes != nil { + assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) + } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) + } else { + requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) + } + } + }, assertTimeout, 200*time.Millisecond, "Verifying routes re-enabled after policy re-approval") + + url = fmt.Sprintf("http://%s/etc/hostname", webip) + t.Logf("url from %s to %s", client.Hostname(), url) + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := client.Curl(url) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, assertTimeout, 200*time.Millisecond, "Verifying client can reach webservice after route re-approval") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := client.Traceroute(webip) + assert.NoError(c, err) + ip, err := routerUsernet1.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for routerUsernet1") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, assertTimeout, 200*time.Millisecond, "Verifying traceroute goes through router after re-approval") + + // Advertise and validate a subnet of an auto approved route, /24 inside the + // auto approved /16. + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + subRoute.String(), + } + _, _, err = routerSubRoute.Execute(command) + require.NoErrorf(t, err, "failed to advertise route: %s", err) + + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // These route should auto approve, so the node is expected to have a route + // for all counts. + nodes, err = headscale.ListNodes() + assert.NoError(c, err) + requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1) + requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 1) + }, assertTimeout, 500*time.Millisecond, "route state changes should propagate") + + // Verify that the routes have been sent to the client. 
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ status, err := client.Status()
+ assert.NoError(c, err)
+
+ for _, peerKey := range status.Peers() {
+ peerStatus := status.Peer[peerKey]
+
+ if peerStatus.ID == routerUsernet1ID.StableID() {
+ if peerStatus.PrimaryRoutes != nil {
+ assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
+ }
+ requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
+ } else if peerStatus.ID == "2" {
+ if peerStatus.PrimaryRoutes != nil {
+ assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), subRoute)
+ }
+ requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{subRoute})
+ } else {
+ requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
+ }
+ }
+ }, assertTimeout, 200*time.Millisecond, "Verifying sub-route propagated to client")
+
+ // Advertising a route that is not auto-approved should not result in it being served anywhere.
+ command = []string{
+ "tailscale",
+ "set",
+ "--advertise-routes=" + notApprovedRoute.String(),
+ }
+ _, _, err = routerSubRoute.Execute(command)
+ require.NoErrorf(t, err, "failed to advertise route: %s", err)
+
+ // Wait for route state changes to propagate
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ // The newly advertised route is not auto-approved, so it should not be served;
+ // only routerUsernet1's previously approved route remains active.
+ nodes, err = headscale.ListNodes()
+ assert.NoError(c, err)
+ requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+ requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
+ requireNodeRouteCountWithCollect(c, nodes[2], 0, 0, 0)
+ }, assertTimeout, 500*time.Millisecond, "route state changes should propagate")
+
+ // Verify that the routes have been sent to the client.
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ status, err := client.Status()
+ assert.NoError(c, err)
+
+ for _, peerKey := range status.Peers() {
+ peerStatus := status.Peer[peerKey]
+
+ if peerStatus.ID == routerUsernet1ID.StableID() {
+ assert.NotNil(c, peerStatus.PrimaryRoutes)
+ if peerStatus.PrimaryRoutes != nil {
+ assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route)
+ }
+ requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route})
+ } else {
+ requirePeerSubnetRoutesWithCollect(c, peerStatus, nil)
+ }
+ }
+ }, assertTimeout, 200*time.Millisecond, "Verifying unapproved route not propagated")
+
+ // Exit routes are also automatically approved
+ command = []string{
+ "tailscale",
+ "set",
+ "--advertise-exit-node",
+ }
+ _, _, err = routerExitNode.Execute(command)
+ require.NoErrorf(t, err, "failed to advertise exit node: %s", err)
+
+ // Wait for route state changes to propagate
+ assert.EventuallyWithT(t, func(c *assert.CollectT) {
+ nodes, err = headscale.ListNodes()
+ assert.NoError(c, err)
+ requireNodeRouteCountWithCollect(c, MustFindNode(routerUsernet1.Hostname(), nodes), 1, 1, 1)
+ requireNodeRouteCountWithCollect(c, nodes[1], 1, 1, 0)
+ requireNodeRouteCountWithCollect(c, nodes[2], 2, 2, 2)
+ }, assertTimeout, 500*time.Millisecond, "route state changes should propagate")
+
+ // Verify that the routes have been sent to the client.
+ assert.EventuallyWithT(t, func(c *assert.CollectT) { + status, err := client.Status() + assert.NoError(c, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + if peerStatus.ID == routerUsernet1ID.StableID() { + if peerStatus.PrimaryRoutes != nil { + assert.Contains(c, peerStatus.PrimaryRoutes.AsSlice(), *route) + } + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{*route}) + } else if peerStatus.ID == "3" { + requirePeerSubnetRoutesWithCollect(c, peerStatus, []netip.Prefix{tsaddr.AllIPv4(), tsaddr.AllIPv6()}) + } else { + requirePeerSubnetRoutesWithCollect(c, peerStatus, nil) + } + } + }, assertTimeout, 200*time.Millisecond, "Verifying exit node routes propagated to client") + }) + } + } } } + +// assertTracerouteViaIPWithCollect is a version of assertTracerouteViaIP that works with assert.CollectT. +func assertTracerouteViaIPWithCollect(c *assert.CollectT, tr util.Traceroute, ip netip.Addr) { + assert.NotNil(c, tr) + assert.True(c, tr.Success) + assert.NoError(c, tr.Err) + assert.NotEmpty(c, tr.Route) + // Since we're inside EventuallyWithT, we can't use require.Greater with t + // but assert.NotEmpty above ensures len(tr.Route) > 0 + if len(tr.Route) > 0 { + assert.Equal(c, tr.Route[0].IP.String(), ip.String()) + } +} + +func SortPeerStatus(a, b *ipnstate.PeerStatus) int { + return cmp.Compare(a.ID, b.ID) +} + +func printCurrentRouteMap(t *testing.T, routers ...*ipnstate.PeerStatus) { + t.Logf("== Current routing map ==") + slices.SortFunc(routers, SortPeerStatus) + for _, router := range routers { + got := filterNonRoutes(router) + t.Logf(" Router %s (%s) is serving:", router.HostName, router.ID) + t.Logf(" AllowedIPs: %v", got) + if router.PrimaryRoutes != nil { + t.Logf(" PrimaryRoutes: %v", router.PrimaryRoutes.AsSlice()) + } + } +} + +// filterNonRoutes returns the list of routes that a [ipnstate.PeerStatus] is serving. 
+func filterNonRoutes(status *ipnstate.PeerStatus) []netip.Prefix { + return slicesx.Filter(nil, status.AllowedIPs.AsSlice(), func(p netip.Prefix) bool { + if tsaddr.IsExitRoute(p) { + return true + } + return !slices.ContainsFunc(status.TailscaleIPs, p.Contains) + }) +} + +func requirePeerSubnetRoutesWithCollect(c *assert.CollectT, status *ipnstate.PeerStatus, expected []netip.Prefix) { + if status.AllowedIPs.Len() <= 2 && len(expected) != 0 { + assert.Fail(c, fmt.Sprintf("peer %s (%s) has no subnet routes, expected %v", status.HostName, status.ID, expected)) + return + } + + if len(expected) == 0 { + expected = []netip.Prefix{} + } + + got := filterNonRoutes(status) + + if diff := cmpdiff.Diff(expected, got, util.PrefixComparer, cmpopts.EquateEmpty()); diff != "" { + assert.Fail(c, fmt.Sprintf("peer %s (%s) subnet routes, unexpected result (-want +got):\n%s", status.HostName, status.ID, diff)) + } +} + +func requireNodeRouteCountWithCollect(c *assert.CollectT, node *v1.Node, announced, approved, subnet int) { + assert.Lenf(c, node.GetAvailableRoutes(), announced, "expected %q announced routes(%v) to have %d route, had %d", node.GetName(), node.GetAvailableRoutes(), announced, len(node.GetAvailableRoutes())) + assert.Lenf(c, node.GetApprovedRoutes(), approved, "expected %q approved routes(%v) to have %d route, had %d", node.GetName(), node.GetApprovedRoutes(), approved, len(node.GetApprovedRoutes())) + assert.Lenf(c, node.GetSubnetRoutes(), subnet, "expected %q subnet routes(%v) to have %d route, had %d", node.GetName(), node.GetSubnetRoutes(), subnet, len(node.GetSubnetRoutes())) +} + +// TestSubnetRouteACLFiltering tests that a node can only access subnet routes +// that are explicitly allowed in the ACL. +func TestSubnetRouteACLFiltering(t *testing.T) { + IntegrationSkip(t) + + // Use router and node users for better clarity + routerUser := "router" + nodeUser := "node" + + spec := ScenarioSpec{ + NodesPerUser: 1, + Users: []string{routerUser, nodeUser}, + Networks: map[string][]string{ + "usernet1": {routerUser, nodeUser}, + }, + ExtraService: map[string][]extraServiceFunc{ + "usernet1": {Webservice}, + }, + // We build the head image with curl and traceroute, so only use + // that for this test. 
+ Versions: []string{"head"}, + } + + scenario, err := NewScenario(spec) + require.NoErrorf(t, err, "failed to create scenario: %s", err) + defer scenario.ShutdownAssertNoPanics(t) + + // Set up the ACL policy that allows the node to access only one of the subnet routes (10.10.10.0/24) + aclPolicyStr := `{ + "hosts": { + "router": "100.64.0.1/32", + "node": "100.64.0.2/32" + }, + "acls": [ + { + "action": "accept", + "src": [ + "*" + ], + "dst": [ + "router:8000" + ] + }, + { + "action": "accept", + "src": [ + "node" + ], + "dst": [ + "*:*" + ] + } + ] + }` + + route, err := scenario.SubnetOfNetwork("usernet1") + require.NoError(t, err) + + services, err := scenario.Services("usernet1") + require.NoError(t, err) + require.Len(t, services, 1) + + usernet1, err := scenario.Network("usernet1") + require.NoError(t, err) + + web := services[0] + webip := netip.MustParseAddr(web.GetIPInNetwork(usernet1)) + weburl := fmt.Sprintf("http://%s/etc/hostname", webip) + t.Logf("webservice: %s, %s", webip.String(), weburl) + + aclPolicy := &policyv2.Policy{} + err = json.Unmarshal([]byte(aclPolicyStr), aclPolicy) + require.NoError(t, err) + + err = scenario.CreateHeadscaleEnv([]tsic.Option{ + tsic.WithAcceptRoutes(), + }, hsic.WithTestName("routeaclfilter"), + hsic.WithACLPolicy(aclPolicy), + hsic.WithPolicyMode(types.PolicyModeDB), + ) + requireNoErrHeadscaleEnv(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Get the router and node clients by user + routerClients, err := scenario.ListTailscaleClients(routerUser) + require.NoError(t, err) + require.Len(t, routerClients, 1) + routerClient := routerClients[0] + + nodeClients, err := scenario.ListTailscaleClients(nodeUser) + require.NoError(t, err) + require.Len(t, nodeClients, 1) + nodeClient := nodeClients[0] + + routerIP, err := routerClient.IPv4() + require.NoError(t, err, "failed to get router IPv4") + nodeIP, err := nodeClient.IPv4() + require.NoError(t, err, "failed to get node IPv4") + + aclPolicy.Hosts = policyv2.Hosts{ + policyv2.Host(routerUser): policyv2.Prefix(must.Get(routerIP.Prefix(32))), + policyv2.Host(nodeUser): policyv2.Prefix(must.Get(nodeIP.Prefix(32))), + } + aclPolicy.ACLs[1].Destinations = []policyv2.AliasWithPorts{ + aliasWithPorts(prefixp(route.String()), tailcfg.PortRangeAny), + } + require.NoError(t, headscale.SetPolicy(aclPolicy)) + + // Set up the subnet routes for the router + routes := []netip.Prefix{ + *route, // This should be accessible by the client + netip.MustParsePrefix("10.10.11.0/24"), // These should NOT be accessible + netip.MustParsePrefix("10.10.12.0/24"), + } + + routeArg := "--advertise-routes=" + routes[0].String() + "," + routes[1].String() + "," + routes[2].String() + command := []string{ + "tailscale", + "set", + routeArg, + } + + _, _, err = routerClient.Execute(command) + require.NoErrorf(t, err, "failed to advertise routes: %s", err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + var routerNode, nodeNode *v1.Node + // Wait for route advertisements to propagate to NodeStore + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + // List nodes and verify the router has 3 available routes + nodes, err := headscale.NodesByUser() + assert.NoError(ct, err) + assert.Len(ct, nodes, 2) + + // Find the router node + routerNode = nodes[routerUser][0] + nodeNode = nodes[nodeUser][0] + + assert.NotNil(ct, routerNode, "Router node not found") + assert.NotNil(ct, nodeNode, 
"Client node not found") + + // Check that the router has 3 routes available but not approved yet + requireNodeRouteCountWithCollect(ct, routerNode, 3, 0, 0) + requireNodeRouteCountWithCollect(ct, nodeNode, 0, 0, 0) + }, 10*time.Second, 100*time.Millisecond, "route advertisements should propagate to router node") + + // Approve all routes for the router + _, err = headscale.ApproveRoutes( + routerNode.GetId(), + util.MustStringsToPrefixes(routerNode.GetAvailableRoutes()), + ) + require.NoError(t, err) + + // Wait for route state changes to propagate + assert.EventuallyWithT(t, func(c *assert.CollectT) { + // List nodes and verify the router has 3 available routes + var err error + nodes, err := headscale.NodesByUser() + assert.NoError(c, err) + assert.Len(c, nodes, 2) + + // Find the router node + routerNode = nodes[routerUser][0] + + // Check that the router has 3 routes now approved and available + requireNodeRouteCountWithCollect(c, routerNode, 3, 3, 3) + }, 15*time.Second, 500*time.Millisecond, "route state changes should propagate") + + // Now check the client node status + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodeStatus, err := nodeClient.Status() + assert.NoError(c, err) + + routerStatus, err := routerClient.Status() + assert.NoError(c, err) + + // Check that the node can see the subnet routes from the router + routerPeerStatus := nodeStatus.Peer[routerStatus.Self.PublicKey] + + // The node should only have 1 subnet route + requirePeerSubnetRoutesWithCollect(c, routerPeerStatus, []netip.Prefix{*route}) + }, 5*time.Second, 200*time.Millisecond, "Verifying node sees filtered subnet routes") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + result, err := nodeClient.Curl(weburl) + assert.NoError(c, err) + assert.Len(c, result, 13) + }, 60*time.Second, 200*time.Millisecond, "Verifying node can reach webservice through allowed route") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + tr, err := nodeClient.Traceroute(webip) + assert.NoError(c, err) + ip, err := routerClient.IPv4() + if !assert.NoError(c, err, "failed to get IPv4 for routerClient") { + return + } + assertTracerouteViaIPWithCollect(c, tr, ip) + }, 60*time.Second, 200*time.Millisecond, "Verifying traceroute goes through router") +} diff --git a/integration/run.sh b/integration/run.sh index 8c1fb016..137bcfb7 100755 --- a/integration/run.sh +++ b/integration/run.sh @@ -13,8 +13,10 @@ run_tests() { for ((i = 1; i <= num_tests; i++)); do docker network prune -f >/dev/null 2>&1 - docker rm headscale-test-suite || true - docker kill "$(docker ps -q)" || true + docker rm headscale-test-suite >/dev/null 2>&1 || true + docker kill "$(docker ps -q)" >/dev/null 2>&1 || true + + echo "Run $i" start=$(date +%s) docker run \ @@ -24,6 +26,7 @@ run_tests() { --volume "$PWD:$PWD" -w "$PWD"/integration \ --volume /var/run/docker.sock:/var/run/docker.sock \ --volume "$PWD"/control_logs:/tmp/control \ + -e "HEADSCALE_INTEGRATION_POSTGRES" \ golang:1 \ go test ./... 
\ -failfast \ diff --git a/integration/scenario.go b/integration/scenario.go index 6bcd5852..35fee73e 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -1,80 +1,60 @@ package integration import ( + "context" + "crypto/tls" + "encoding/json" "errors" "fmt" + "io" "log" + "net" + "net/http" + "net/http/cookiejar" "net/netip" + "net/url" "os" + "slices" + "strconv" + "strings" "sync" + "testing" + "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/hscontrol/capver" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" + "github.com/juanfont/headscale/integration/dsic" "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/integrationutil" "github.com/juanfont/headscale/integration/tsic" + "github.com/oauth2-proxy/mockoidc" "github.com/ory/dockertest/v3" - "github.com/puzpuzpuz/xsync/v3" + "github.com/ory/dockertest/v3/docker" + "github.com/puzpuzpuz/xsync/v4" "github.com/samber/lo" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + xmaps "golang.org/x/exp/maps" "golang.org/x/sync/errgroup" + "tailscale.com/envknob" + "tailscale.com/util/mak" + "tailscale.com/util/multierr" ) const ( scenarioHashLength = 6 ) -func enabledVersions(vs map[string]bool) []string { - var ret []string - for version, enabled := range vs { - if enabled { - ret = append(ret, version) - } - } - - return ret -} +var usePostgresForTest = envknob.Bool("HEADSCALE_INTEGRATION_POSTGRES") var ( errNoHeadscaleAvailable = errors.New("no headscale available") errNoUserAvailable = errors.New("no user available") errNoClientFound = errors.New("client not found") - // Tailscale started adding TS2021 support in CapabilityVersion>=28 (v1.24.0), but - // proper support in Headscale was only added for CapabilityVersion>=39 clients (v1.30.0). - tailscaleVersions2021 = map[string]bool{ - "head": true, - "unstable": true, - "1.52": true, // CapVer: - "1.50": true, // CapVer: 74 - "1.48": true, // CapVer: 68 - "1.46": true, // CapVer: 65 - "1.44": true, // CapVer: 63 - "1.42": true, // CapVer: 61 - "1.40": true, // CapVer: 61 - "1.38": true, // CapVer: 58 - "1.36": true, // CapVer: 56 - "1.34": true, // CapVer: 51 - "1.32": true, // Oldest supported version, CapVer: 46 - "1.30": false, - } - - tailscaleVersions2019 = map[string]bool{ - "1.28": false, - "1.26": false, - "1.24": false, // Tailscale SSH - "1.22": false, - "1.20": false, - "1.18": false, - } - - // tailscaleVersionsUnavailable = []string{ - // // These versions seem to fail when fetching from apt. - // "1.14.6", - // "1.12.4", - // "1.10.2", - // "1.8.7", - // }. - // AllVersions represents a list of Tailscale versions the suite // uses to test compatibility with the ControlServer. // @@ -84,10 +64,7 @@ var ( // // The rest of the version represents Tailscale versions that can be // found in Tailscale's apt repository. - AllVersions = append( - enabledVersions(tailscaleVersions2021), - enabledVersions(tailscaleVersions2019)..., - ) + AllVersions = append([]string{"head", "unstable"}, capver.TailscaleLatestMajorMinor(capver.SupportedMajorMinorVersions, true)...) // MustTestVersions is the minimum set of versions we should test. // At the moment, this is arbitrarily chosen as: @@ -120,64 +97,222 @@ type Scenario struct { // TODO(kradalby): support multiple headcales for later, currently only // use one. 
controlServers *xsync.MapOf[string, ControlServer] + derpServers []*dsic.DERPServerInContainer users map[string]*User - pool *dockertest.Pool - network *dockertest.Network + pool *dockertest.Pool + networks map[string]*dockertest.Network + mockOIDC scenarioOIDC + extraServices map[string][]*dockertest.Resource mu sync.Mutex + + spec ScenarioSpec + userToNetwork map[string]*dockertest.Network + + testHashPrefix string + testDefaultNetwork string +} + +// ScenarioSpec describes the users, nodes, and network topology to +// set up for a given scenario. +type ScenarioSpec struct { + // Users is a list of usernames that will be created. + // Each created user will get nodes equivalent to NodesPerUser + Users []string + + // NodesPerUser is how many nodes should be attached to each user. + NodesPerUser int + + // Networks, if set, is the separate Docker networks that should be + // created and a list of the users that should be placed in those networks. + // If not set, a single network will be created and all users+nodes will be + // added there. + // Please note that Docker networks are not necessarily routable and + // connections between them might fall back to DERP. + Networks map[string][]string + + // ExtraService, if set, is additional a map of network to additional + // container services that should be set up. These container services + // typically dont run Tailscale, e.g. web service to test subnet router. + ExtraService map[string][]extraServiceFunc + + // Versions is specific list of versions to use for the test. + Versions []string + + // OIDCUsers, if populated, will start a Mock OIDC server and populate + // the user login stack with the given users. + // If the NodesPerUser is set, it should align with this list to ensure + // the correct users are logged in. + // This is because the MockOIDC server can only serve login + // requests based on a queue it has been given on startup. + // We currently only populates it with one login request per user. + OIDCUsers []mockoidc.MockUser + OIDCAccessTTL time.Duration + + MaxWait time.Duration +} + +func (s *Scenario) prefixedNetworkName(name string) string { + return s.testHashPrefix + "-" + name } // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with // a set of Users and TailscaleClients. -func NewScenario() (*Scenario, error) { - hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) - if err != nil { - return nil, err - } - +func NewScenario(spec ScenarioSpec) (*Scenario, error) { pool, err := dockertest.NewPool("") if err != nil { return nil, fmt.Errorf("could not connect to docker: %w", err) } - pool.MaxWait = dockertestMaxWait() + // Opportunity to clean up unreferenced networks. + // This might be a no op, but it is worth a try as we sometime + // dont clean up nicely after ourselves. 
+ dockertestutil.CleanUnreferencedNetworks(pool) + dockertestutil.CleanImagesInCI(pool) - networkName := fmt.Sprintf("hs-%s", hash) - if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" { - networkName = overrideNetworkName + if spec.MaxWait == 0 { + pool.MaxWait = dockertestMaxWait() + } else { + pool.MaxWait = spec.MaxWait } - network, err := dockertestutil.GetFirstOrCreateNetwork(pool, networkName) + testHashPrefix := "hs-" + util.MustGenerateRandomStringDNSSafe(scenarioHashLength) + s := &Scenario{ + controlServers: xsync.NewMapOf[string, ControlServer](), + users: make(map[string]*User), + + pool: pool, + spec: spec, + + testHashPrefix: testHashPrefix, + testDefaultNetwork: testHashPrefix + "-default", + } + + var userToNetwork map[string]*dockertest.Network + if spec.Networks != nil || len(spec.Networks) != 0 { + for name, users := range s.spec.Networks { + networkName := testHashPrefix + "-" + name + network, err := s.AddNetwork(networkName) + if err != nil { + return nil, err + } + + for _, user := range users { + if n2, ok := userToNetwork[user]; ok { + return nil, fmt.Errorf("users can only have nodes placed in one network: %s into %s but already in %s", user, network.Network.Name, n2.Network.Name) + } + mak.Set(&userToNetwork, user, network) + } + } + } else { + _, err := s.AddNetwork(s.testDefaultNetwork) + if err != nil { + return nil, err + } + } + + for network, extras := range spec.ExtraService { + for _, extra := range extras { + svc, err := extra(s, network) + if err != nil { + return nil, err + } + mak.Set(&s.extraServices, s.prefixedNetworkName(network), append(s.extraServices[s.prefixedNetworkName(network)], svc)) + } + } + + s.userToNetwork = userToNetwork + + if len(spec.OIDCUsers) != 0 { + ttl := defaultAccessTTL + if spec.OIDCAccessTTL != 0 { + ttl = spec.OIDCAccessTTL + } + err = s.runMockOIDC(ttl, spec.OIDCUsers) + if err != nil { + return nil, err + } + } + + return s, nil +} + +func (s *Scenario) AddNetwork(name string) (*dockertest.Network, error) { + network, err := dockertestutil.GetFirstOrCreateNetwork(s.pool, name) if err != nil { return nil, fmt.Errorf("failed to create or get network: %w", err) } // We run the test suite in a docker container that calls a couple of endpoints for // readiness checks, this ensures that we can run the tests with individual networks - // and have the client reach the different containers - err = dockertestutil.AddContainerToNetwork(pool, network, "headscale-test-suite") + // and have the client reach the different containers. + // The container name includes the run ID to support multiple concurrent test runs. + testSuiteName := "headscale-test-suite" + if runID := dockertestutil.GetIntegrationRunID(); runID != "" { + testSuiteName = "headscale-test-suite-" + runID + } + + err = dockertestutil.AddContainerToNetwork(s.pool, network, testSuiteName) if err != nil { return nil, fmt.Errorf("failed to add test suite container to network: %w", err) } - return &Scenario{ - controlServers: xsync.NewMapOf[string, ControlServer](), - users: make(map[string]*User), + mak.Set(&s.networks, name, network) - pool: pool, - network: network, - }, nil + return network, nil } -// Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient) -// and networks associated with it. -// In addition, it will save the logs of the ControlServer to `/tmp/control` in the -// environment running the tests. 
-func (s *Scenario) Shutdown() { +func (s *Scenario) Networks() []*dockertest.Network { + if len(s.networks) == 0 { + panic("Scenario.Networks called with empty network list") + } + return xmaps.Values(s.networks) +} + +func (s *Scenario) Network(name string) (*dockertest.Network, error) { + net, ok := s.networks[s.prefixedNetworkName(name)] + if !ok { + return nil, fmt.Errorf("no network named: %s", name) + } + + return net, nil +} + +func (s *Scenario) SubnetOfNetwork(name string) (*netip.Prefix, error) { + net, ok := s.networks[s.prefixedNetworkName(name)] + if !ok { + return nil, fmt.Errorf("no network named: %s", name) + } + + if len(net.Network.IPAM.Config) == 0 { + return nil, fmt.Errorf("no IPAM config found in network: %s", name) + } + + pref, err := netip.ParsePrefix(net.Network.IPAM.Config[0].Subnet) + if err != nil { + return nil, err + } + + return &pref, nil +} + +func (s *Scenario) Services(name string) ([]*dockertest.Resource, error) { + res, ok := s.extraServices[s.prefixedNetworkName(name)] + if !ok { + return nil, fmt.Errorf("no network named: %s", name) + } + + return res, nil +} + +func (s *Scenario) ShutdownAssertNoPanics(t *testing.T) { + defer dockertestutil.CleanUnreferencedNetworks(s.pool) + defer dockertestutil.CleanImagesInCI(s.pool) + s.controlServers.Range(func(_ string, control ControlServer) bool { - err := control.Shutdown() + stdoutPath, stderrPath, err := control.Shutdown() if err != nil { log.Printf( "Failed to shut down control: %s", @@ -185,27 +320,77 @@ func (s *Scenario) Shutdown() { ) } + if t != nil { + stdout, err := os.ReadFile(stdoutPath) + require.NoError(t, err) + assert.NotContains(t, string(stdout), "panic") + + stderr, err := os.ReadFile(stderrPath) + require.NoError(t, err) + assert.NotContains(t, string(stderr), "panic") + } + return true }) + s.mu.Lock() for userName, user := range s.users { for _, client := range user.Clients { log.Printf("removing client %s in user %s", client.Hostname(), userName) - err := client.Shutdown() + stdoutPath, stderrPath, err := client.Shutdown() if err != nil { log.Printf("failed to tear down client: %s", err) } + + if t != nil { + stdout, err := os.ReadFile(stdoutPath) + require.NoError(t, err) + assert.NotContains(t, string(stdout), "panic") + + stderr, err := os.ReadFile(stderrPath) + require.NoError(t, err) + assert.NotContains(t, string(stderr), "panic") + } + } + } + s.mu.Unlock() + + for _, derp := range s.derpServers { + err := derp.Shutdown() + if err != nil { + log.Printf("failed to tear down derp server: %s", err) + } + } + + for _, svcs := range s.extraServices { + for _, svc := range svcs { + err := svc.Close() + if err != nil { + log.Printf("failed to tear down service %q: %s", svc.Container.Name, err) + } } } - if err := s.pool.RemoveNetwork(s.network); err != nil { - log.Printf("failed to remove network: %s", err) + if s.mockOIDC.r != nil { + s.mockOIDC.r.Close() + if err := s.mockOIDC.r.Close(); err != nil { + log.Printf("failed to tear down oidc server: %s", err) + } } - // TODO(kradalby): This seem redundant to the previous call - // if err := s.network.Close(); err != nil { - // return fmt.Errorf("failed to tear down network: %w", err) - // } + for _, network := range s.networks { + if err := network.Close(); err != nil { + log.Printf("failed to tear down network: %s", err) + } + } +} + +// Shutdown shuts down and cleans up all the containers (ControlServer, TailscaleClient) +// and networks associated with it. 
+// In addition, it will save the logs of the ControlServer to `/tmp/control` in the +// environment running the tests. +func (s *Scenario) Shutdown() { + s.ShutdownAssertNoPanics(nil) } // Users returns the name of all users associated with the Scenario. @@ -233,7 +418,11 @@ func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) { return headscale, nil } - headscale, err := hsic.New(s.pool, s.network, opts...) + if usePostgresForTest { + opts = append(opts, hsic.WithPostgres()) + } + + headscale, err := hsic.New(s.pool, s.Networks(), opts...) if err != nil { return nil, fmt.Errorf("failed to create headscale container: %w", err) } @@ -248,10 +437,32 @@ func (s *Scenario) Headscale(opts ...hsic.Option) (ControlServer, error) { return headscale, nil } +// Pool returns the dockertest pool for the scenario. +func (s *Scenario) Pool() *dockertest.Pool { + return s.pool +} + +// GetOrCreateUser gets or creates a user in the scenario. +func (s *Scenario) GetOrCreateUser(userStr string) *User { + s.mu.Lock() + defer s.mu.Unlock() + + if user, ok := s.users[userStr]; ok { + return user + } + + user := &User{ + Clients: make(map[string]TailscaleClient), + } + s.users[userStr] = user + + return user +} + // CreatePreAuthKey creates a "pre authentorised key" to be created in the // Headscale instance on behalf of the Scenario. func (s *Scenario) CreatePreAuthKey( - user string, + user uint64, reusable bool, ephemeral bool, ) (*v1.PreAuthKey, error) { @@ -267,27 +478,109 @@ func (s *Scenario) CreatePreAuthKey( return nil, fmt.Errorf("failed to create user: %w", errNoHeadscaleAvailable) } +// CreatePreAuthKeyWithOptions creates a "pre authorised key" with the specified options +// to be created in the Headscale instance on behalf of the Scenario. +func (s *Scenario) CreatePreAuthKeyWithOptions(opts hsic.AuthKeyOptions) (*v1.PreAuthKey, error) { + headscale, err := s.Headscale() + if err != nil { + return nil, fmt.Errorf("failed to create preauth key with options: %w", errNoHeadscaleAvailable) + } + + key, err := headscale.CreateAuthKeyWithOptions(opts) + if err != nil { + return nil, fmt.Errorf("failed to create preauth key with options: %w", err) + } + + return key, nil +} + +// CreatePreAuthKeyWithTags creates a "pre authorised key" with the specified tags +// to be created in the Headscale instance on behalf of the Scenario. +func (s *Scenario) CreatePreAuthKeyWithTags( + user uint64, + reusable bool, + ephemeral bool, + tags []string, +) (*v1.PreAuthKey, error) { + headscale, err := s.Headscale() + if err != nil { + return nil, fmt.Errorf("failed to create preauth key with tags: %w", errNoHeadscaleAvailable) + } + + key, err := headscale.CreateAuthKeyWithTags(user, reusable, ephemeral, tags) + if err != nil { + return nil, fmt.Errorf("failed to create preauth key with tags: %w", err) + } + + return key, nil +} + // CreateUser creates a User to be created in the // Headscale instance on behalf of the Scenario. 
-func (s *Scenario) CreateUser(user string) error { +func (s *Scenario) CreateUser(user string) (*v1.User, error) { if headscale, err := s.Headscale(); err == nil { - err := headscale.CreateUser(user) + u, err := headscale.CreateUser(user) if err != nil { - return fmt.Errorf("failed to create user: %w", err) + return nil, fmt.Errorf("failed to create user: %w", err) } + s.mu.Lock() s.users[user] = &User{ Clients: make(map[string]TailscaleClient), } + s.mu.Unlock() - return nil + return u, nil } - return fmt.Errorf("failed to create user: %w", errNoHeadscaleAvailable) + return nil, fmt.Errorf("failed to create user: %w", errNoHeadscaleAvailable) } /// Client related stuff +func (s *Scenario) CreateTailscaleNode( + version string, + opts ...tsic.Option, +) (TailscaleClient, error) { + headscale, err := s.Headscale() + if err != nil { + return nil, fmt.Errorf("failed to create tailscale node (version: %s): %w", version, err) + } + + cert := headscale.GetCert() + hostname := headscale.GetHostname() + + s.mu.Lock() + defer s.mu.Unlock() + opts = append(opts, + tsic.WithCACert(cert), + tsic.WithHeadscaleName(hostname), + ) + + tsClient, err := tsic.New( + s.pool, + version, + opts..., + ) + if err != nil { + return nil, fmt.Errorf( + "failed to create tailscale node: %w", + err, + ) + } + + err = tsClient.WaitForNeedsLogin(integrationutil.PeerSyncTimeout()) + if err != nil { + return nil, fmt.Errorf( + "failed to wait for tailscaled (%s) to need login: %w", + tsClient.Hostname(), + err, + ) + } + + return tsClient, nil +} + // CreateTailscaleNodesInUser creates and adds a new TailscaleClient to a // User in the Scenario. func (s *Scenario) CreateTailscaleNodesInUser( @@ -298,10 +591,14 @@ func (s *Scenario) CreateTailscaleNodesInUser( ) error { if user, ok := s.users[userStr]; ok { var versions []string - for i := 0; i < count; i++ { + for i := range count { version := requestedVersion if requestedVersion == "all" { - version = MustTestVersions[i%len(MustTestVersions)] + if s.spec.Versions != nil { + version = s.spec.Versions[i%len(s.spec.Versions)] + } else { + version = MustTestVersions[i%len(MustTestVersions)] + } } versions = append(versions, version) @@ -313,27 +610,43 @@ func (s *Scenario) CreateTailscaleNodesInUser( cert := headscale.GetCert() hostname := headscale.GetHostname() + // Determine which network this tailscale client will be in + var network *dockertest.Network + if s.userToNetwork != nil && s.userToNetwork[userStr] != nil { + network = s.userToNetwork[userStr] + } else { + network = s.networks[s.testDefaultNetwork] + } + + // Get headscale IP in this network for /etc/hosts fallback DNS + headscaleIP := headscale.GetIPInNetwork(network) + extraHosts := []string{hostname + ":" + headscaleIP} + + s.mu.Lock() opts = append(opts, - tsic.WithHeadscaleTLS(cert), + tsic.WithCACert(cert), tsic.WithHeadscaleName(hostname), + tsic.WithExtraHosts(extraHosts), ) + s.mu.Unlock() + user.createWaitGroup.Go(func() error { + s.mu.Lock() tsClient, err := tsic.New( s.pool, version, - s.network, opts..., ) + s.mu.Unlock() if err != nil { return fmt.Errorf( - "failed to create tailscale (%s) node: %w", - tsClient.Hostname(), + "failed to create tailscale node: %w", err, ) } - err = tsClient.WaitForNeedsLogin() + err = tsClient.WaitForNeedsLogin(integrationutil.PeerSyncTimeout()) if err != nil { return fmt.Errorf( "failed to wait for tailscaled (%s) to need login: %w", @@ -353,7 +666,7 @@ func (s *Scenario) CreateTailscaleNodesInUser( return err } - log.Printf("testing versions %v", 
lo.Uniq(versions)) + log.Printf("testing versions %v, MustTestVersions %v", lo.Uniq(versions), MustTestVersions) return nil } @@ -379,7 +692,7 @@ func (s *Scenario) RunTailscaleUp( } for _, client := range user.Clients { - err := client.WaitForRunning() + err := client.WaitForRunning(integrationutil.PeerSyncTimeout()) if err != nil { return fmt.Errorf("%s failed to up tailscale node: %w", client.Hostname(), err) } @@ -408,12 +721,14 @@ func (s *Scenario) CountTailscale() int { func (s *Scenario) WaitForTailscaleSync() error { tsCount := s.CountTailscale() - err := s.WaitForTailscaleSyncWithPeerCount(tsCount - 1) + err := s.WaitForTailscaleSyncWithPeerCount(tsCount-1, integrationutil.PeerSyncTimeout(), integrationutil.PeerSyncRetryInterval()) if err != nil { for _, user := range s.users { for _, client := range user.Clients { - peers, _ := client.PrettyPeers() - log.Println(peers) + peers, allOnline, _ := client.FailingPeersAsString() + if !allOnline { + log.Println(peers) + } } } } @@ -421,62 +736,439 @@ func (s *Scenario) WaitForTailscaleSync() error { return err } -// WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports -// to have all other TailscaleClients present in their netmap.NetworkMap. -func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int) error { +// WaitForTailscaleSyncPerUser blocks execution until each TailscaleClient has the expected +// number of peers for its user. This is useful for policies like autogroup:self where nodes +// only see same-user peers, not all nodes in the network. +func (s *Scenario) WaitForTailscaleSyncPerUser(timeout, retryInterval time.Duration) error { + var allErrors []error + for _, user := range s.users { + // Calculate expected peer count: number of nodes in this user minus 1 (self) + expectedPeers := len(user.Clients) - 1 + for _, client := range user.Clients { c := client + expectedCount := expectedPeers user.syncWaitGroup.Go(func() error { - return c.WaitForPeers(peerCount) + return c.WaitForPeers(expectedCount, timeout, retryInterval) }) } if err := user.syncWaitGroup.Wait(); err != nil { - return err + allErrors = append(allErrors, err) } } + if len(allErrors) > 0 { + return multierr.New(allErrors...) + } + return nil } -// CreateHeadscaleEnv is a conventient method returning a complete Headcale -// test environment with nodes of all versions, joined to the server with X -// users. -func (s *Scenario) CreateHeadscaleEnv( - users map[string]int, +// WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports +// to have all other TailscaleClients present in their netmap.NetworkMap. +func (s *Scenario) WaitForTailscaleSyncWithPeerCount(peerCount int, timeout, retryInterval time.Duration) error { + var allErrors []error + + for _, user := range s.users { + for _, client := range user.Clients { + c := client + user.syncWaitGroup.Go(func() error { + return c.WaitForPeers(peerCount, timeout, retryInterval) + }) + } + if err := user.syncWaitGroup.Wait(); err != nil { + allErrors = append(allErrors, err) + } + } + + if len(allErrors) > 0 { + return multierr.New(allErrors...) + } + + return nil +} + +func (s *Scenario) CreateHeadscaleEnvWithLoginURL( tsOpts []tsic.Option, opts ...hsic.Option, +) error { + return s.createHeadscaleEnv(true, tsOpts, opts...) +} + +func (s *Scenario) CreateHeadscaleEnv( + tsOpts []tsic.Option, + opts ...hsic.Option, +) error { + return s.createHeadscaleEnv(false, tsOpts, opts...) 
+} + +// CreateHeadscaleEnv starts the headscale environment and the clients +// according to the ScenarioSpec passed to the Scenario. +func (s *Scenario) createHeadscaleEnv( + withURL bool, + tsOpts []tsic.Option, + opts ...hsic.Option, +) error { + return s.createHeadscaleEnvWithTags(withURL, tsOpts, nil, "", opts...) +} + +// createHeadscaleEnvWithTags starts the headscale environment and the clients +// according to the ScenarioSpec passed to the Scenario. If preAuthKeyTags is +// non-empty and withURL is false, the tags will be applied to the PreAuthKey +// (tags-as-identity model). +// +// For webauth (withURL=true), if webauthTagUser is non-empty and preAuthKeyTags +// is non-empty, only nodes belonging to that user will request tags via +// --advertise-tags. This is necessary because tagOwners ACL controls which +// users can request specific tags. +func (s *Scenario) createHeadscaleEnvWithTags( + withURL bool, + tsOpts []tsic.Option, + preAuthKeyTags []string, + webauthTagUser string, + opts ...hsic.Option, ) error { headscale, err := s.Headscale(opts...) if err != nil { return err } - for userName, clientCount := range users { - err = s.CreateUser(userName) + for _, user := range s.spec.Users { + u, err := s.CreateUser(user) if err != nil { return err } - err = s.CreateTailscaleNodesInUser(userName, "all", clientCount, tsOpts...) + var userOpts []tsic.Option + if s.userToNetwork != nil { + userOpts = append(tsOpts, tsic.WithNetwork(s.userToNetwork[user])) + } else { + userOpts = append(tsOpts, tsic.WithNetwork(s.networks[s.testDefaultNetwork])) + } + + // For webauth with tags, only apply tags to the specified webauthTagUser + // (other users may not be authorized via tagOwners) + if withURL && webauthTagUser != "" && len(preAuthKeyTags) > 0 && user == webauthTagUser { + userOpts = append(userOpts, tsic.WithTags(preAuthKeyTags)) + } + + err = s.CreateTailscaleNodesInUser(user, "all", s.spec.NodesPerUser, userOpts...) if err != nil { return err } - key, err := s.CreatePreAuthKey(userName, true, false) - if err != nil { - return err - } + if withURL { + err = s.RunTailscaleUpWithURL(user, headscale.GetEndpoint()) + if err != nil { + return err + } + } else { + // Use tagged PreAuthKey if tags are provided (tags-as-identity model) + var key *v1.PreAuthKey + if len(preAuthKeyTags) > 0 { + key, err = s.CreatePreAuthKeyWithTags(u.GetId(), true, false, preAuthKeyTags) + } else { + key, err = s.CreatePreAuthKey(u.GetId(), true, false) + } + if err != nil { + return err + } - err = s.RunTailscaleUp(userName, headscale.GetEndpoint(), key.GetKey()) - if err != nil { - return err + err = s.RunTailscaleUp(user, headscale.GetEndpoint(), key.GetKey()) + if err != nil { + return err + } } } return nil } +func (s *Scenario) RunTailscaleUpWithURL(userStr, loginServer string) error { + log.Printf("running tailscale up for user %s", userStr) + if user, ok := s.users[userStr]; ok { + for _, client := range user.Clients { + tsc := client + user.joinWaitGroup.Go(func() error { + loginURL, err := tsc.LoginWithURL(loginServer) + if err != nil { + log.Printf("%s failed to run tailscale up: %s", tsc.Hostname(), err) + } + + body, err := doLoginURL(tsc.Hostname(), loginURL) + if err != nil { + return err + } + + // If the URL is not a OIDC URL, then we need to + // run the register command to fully log in the client. 
+ if !strings.Contains(loginURL.String(), "/oidc/") { + s.runHeadscaleRegister(userStr, body) + } + + return nil + }) + + log.Printf("client %s is ready", client.Hostname()) + } + + if err := user.joinWaitGroup.Wait(); err != nil { + return err + } + + for _, client := range user.Clients { + err := client.WaitForRunning(integrationutil.PeerSyncTimeout()) + if err != nil { + return fmt.Errorf( + "%s tailscale node has not reached running: %w", + client.Hostname(), + err, + ) + } + } + + return nil + } + + return fmt.Errorf("failed to up tailscale node: %w", errNoUserAvailable) +} + +type debugJar struct { + inner *cookiejar.Jar + mu sync.RWMutex + store map[string]map[string]map[string]*http.Cookie // domain -> path -> name -> cookie +} + +func newDebugJar() (*debugJar, error) { + jar, err := cookiejar.New(nil) + if err != nil { + return nil, err + } + return &debugJar{ + inner: jar, + store: make(map[string]map[string]map[string]*http.Cookie), + }, nil +} + +func (j *debugJar) SetCookies(u *url.URL, cookies []*http.Cookie) { + j.inner.SetCookies(u, cookies) + + j.mu.Lock() + defer j.mu.Unlock() + + for _, c := range cookies { + if c == nil || c.Name == "" { + continue + } + domain := c.Domain + if domain == "" { + domain = u.Hostname() + } + path := c.Path + if path == "" { + path = "/" + } + if _, ok := j.store[domain]; !ok { + j.store[domain] = make(map[string]map[string]*http.Cookie) + } + if _, ok := j.store[domain][path]; !ok { + j.store[domain][path] = make(map[string]*http.Cookie) + } + j.store[domain][path][c.Name] = copyCookie(c) + } +} + +func (j *debugJar) Cookies(u *url.URL) []*http.Cookie { + return j.inner.Cookies(u) +} + +func (j *debugJar) Dump(w io.Writer) { + j.mu.RLock() + defer j.mu.RUnlock() + + for domain, paths := range j.store { + fmt.Fprintf(w, "Domain: %s\n", domain) + for path, byName := range paths { + fmt.Fprintf(w, " Path: %s\n", path) + for _, c := range byName { + fmt.Fprintf( + w, " %s=%s; Expires=%v; Secure=%v; HttpOnly=%v; SameSite=%v\n", + c.Name, c.Value, c.Expires, c.Secure, c.HttpOnly, c.SameSite, + ) + } + } + } +} + +func copyCookie(c *http.Cookie) *http.Cookie { + cc := *c + return &cc +} + +func newLoginHTTPClient(hostname string) (*http.Client, error) { + hc := &http.Client{ + Transport: LoggingRoundTripper{Hostname: hostname}, + } + + jar, err := newDebugJar() + if err != nil { + return nil, fmt.Errorf("%s failed to create cookiejar: %w", hostname, err) + } + + hc.Jar = jar + + return hc, nil +} + +// doLoginURL visits the given login URL and returns the body as a string. +func doLoginURL(hostname string, loginURL *url.URL) (string, error) { + log.Printf("%s login url: %s\n", hostname, loginURL.String()) + + hc, err := newLoginHTTPClient(hostname) + if err != nil { + return "", err + } + + body, _, err := doLoginURLWithClient(hostname, loginURL, hc, true) + if err != nil { + return "", err + } + + return body, nil +} + +// doLoginURLWithClient performs the login request using the provided HTTP client. +// When followRedirects is false, it will return the first redirect without following it. 
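Aside for reviewers: the "return the first redirect without following it" behaviour in the function below relies on net/http's convention that a `CheckRedirect` callback returning `http.ErrUseLastResponse` makes the client hand back the redirect response as-is. A minimal, self-contained sketch of just that mechanism (illustrative only, not part of this patch; the URL is a placeholder):

```go
package main

import (
	"fmt"
	"net/http"
)

// firstRedirect performs a GET but stops at the first redirect, returning
// the 3xx response itself instead of the final destination.
func firstRedirect(c *http.Client, url string) (*http.Response, error) {
	orig := c.CheckRedirect
	c.CheckRedirect = func(req *http.Request, via []*http.Request) error {
		// Tells the client to return the most recent response unfollowed.
		return http.ErrUseLastResponse
	}
	defer func() { c.CheckRedirect = orig }()

	return c.Get(url)
}

func main() {
	resp, err := firstRedirect(&http.Client{}, "http://example.com/login")
	if err != nil {
		fmt.Println(err)
		return
	}
	defer resp.Body.Close()

	// For a 3xx response the target is exposed via the Location header.
	fmt.Println(resp.StatusCode, resp.Header.Get("Location"))
}
```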
+func doLoginURLWithClient(hostname string, loginURL *url.URL, hc *http.Client, followRedirects bool) ( + string, + *url.URL, + error, +) { + if hc == nil { + return "", nil, fmt.Errorf("%s http client is nil", hostname) + } + + if loginURL == nil { + return "", nil, fmt.Errorf("%s login url is nil", hostname) + } + + log.Printf("%s logging in with url: %s", hostname, loginURL.String()) + ctx := context.Background() + req, err := http.NewRequestWithContext(ctx, http.MethodGet, loginURL.String(), nil) + if err != nil { + return "", nil, fmt.Errorf("%s failed to create http request: %w", hostname, err) + } + + originalRedirect := hc.CheckRedirect + if !followRedirects { + hc.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + } + defer func() { + hc.CheckRedirect = originalRedirect + }() + + resp, err := hc.Do(req) + if err != nil { + return "", nil, fmt.Errorf("%s failed to send http request: %w", hostname, err) + } + defer resp.Body.Close() + + bodyBytes, err := io.ReadAll(resp.Body) + if err != nil { + return "", nil, fmt.Errorf("%s failed to read response body: %w", hostname, err) + } + body := string(bodyBytes) + + var redirectURL *url.URL + if resp.StatusCode >= http.StatusMultipleChoices && resp.StatusCode < http.StatusBadRequest { + redirectURL, err = resp.Location() + if err != nil { + return body, nil, fmt.Errorf("%s failed to resolve redirect location: %w", hostname, err) + } + } + + if followRedirects && resp.StatusCode != http.StatusOK { + log.Printf("body: %s", body) + + return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + } + + if resp.StatusCode >= http.StatusBadRequest { + log.Printf("body: %s", body) + + return body, redirectURL, fmt.Errorf("%s unexpected status code %d", hostname, resp.StatusCode) + } + + if hc.Jar != nil { + if jar, ok := hc.Jar.(*debugJar); ok { + jar.Dump(os.Stdout) + } else { + log.Printf("cookies: %+v", hc.Jar.Cookies(loginURL)) + } + } + + return body, redirectURL, nil +} + +var errParseAuthPage = errors.New("failed to parse auth page") + +func (s *Scenario) runHeadscaleRegister(userStr string, body string) error { + // see api.go HTML template + codeSep := strings.Split(string(body), "</code>") + if len(codeSep) != 2 { + return errParseAuthPage + } + + keySep := strings.Split(codeSep[0], "key ") + if len(keySep) != 2 { + return errParseAuthPage + } + key := keySep[1] + key = strings.SplitN(key, " ", 2)[0] + log.Printf("registering node %s", key) + + if headscale, err := s.Headscale(); err == nil { + _, err = headscale.Execute( + []string{"headscale", "nodes", "register", "--user", userStr, "--key", key}, + ) + if err != nil { + log.Printf("failed to register node: %s", err) + + return err + } + + return nil + } + + return fmt.Errorf("failed to find headscale: %w", errNoHeadscaleAvailable) +} + +type LoggingRoundTripper struct { + Hostname string +} + +func (t LoggingRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + noTls := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, // nolint + } + resp, err := noTls.RoundTrip(req) + if err != nil { + return nil, err + } + + log.Printf(` +--- +%s - method: %s | url: %s +%s - status: %d | cookies: %+v +--- +`, t.Hostname, req.Method, req.URL.String(), t.Hostname, resp.StatusCode, resp.Cookies()) + + return resp, nil +} + // GetIPs returns all netip.Addr of TailscaleClients associated with a User // in a Scenario. 
func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) { @@ -496,7 +1188,7 @@ func (s *Scenario) GetIPs(user string) ([]netip.Addr, error) { return ips, fmt.Errorf("failed to get ips: %w", errNoUserAvailable) } -// GetIPs returns all TailscaleClients associated with a User in a Scenario. +// GetClients returns all TailscaleClients associated with a User in a Scenario. func (s *Scenario) GetClients(user string) ([]TailscaleClient, error) { var clients []TailscaleClient if ns, ok := s.users[user]; ok { @@ -541,10 +1233,8 @@ func (s *Scenario) FindTailscaleClientByIP(ip netip.Addr) (TailscaleClient, erro for _, client := range clients { ips, _ := client.IPs() - for _, ip2 := range ips { - if ip == ip2 { - return client, nil - } + if slices.Contains(ips, ip) { + return client, nil } } @@ -572,7 +1262,7 @@ func (s *Scenario) ListTailscaleClientsIPs(users ...string) ([]netip.Addr, error return allIps, nil } -// ListTailscaleClientsIPs returns a list of FQDN based on Users +// ListTailscaleClientsFQDNs returns a list of FQDN based on Users // passed as parameters. func (s *Scenario) ListTailscaleClientsFQDNs(users ...string) ([]string, error) { allFQDNs := make([]string, 0) @@ -601,7 +1291,7 @@ func (s *Scenario) WaitForTailscaleLogout() error { for _, client := range user.Clients { c := client user.syncWaitGroup.Go(func() error { - return c.WaitForNeedsLogin() + return c.WaitForNeedsLogin(integrationutil.PeerSyncTimeout()) }) } if err := user.syncWaitGroup.Wait(); err != nil { @@ -611,3 +1301,239 @@ func (s *Scenario) WaitForTailscaleLogout() error { return nil } + +// CreateDERPServer creates a new DERP server in a container. +func (s *Scenario) CreateDERPServer(version string, opts ...dsic.Option) (*dsic.DERPServerInContainer, error) { + derp, err := dsic.New(s.pool, version, s.Networks(), opts...) + if err != nil { + return nil, fmt.Errorf("failed to create DERP server: %w", err) + } + + err = derp.WaitForRunning() + if err != nil { + return nil, fmt.Errorf("failed to reach DERP server: %w", err) + } + + s.derpServers = append(s.derpServers, derp) + + return derp, nil +} + +type scenarioOIDC struct { + r *dockertest.Resource + cfg *types.OIDCConfig +} + +func (o *scenarioOIDC) Issuer() string { + if o.cfg == nil { + panic("OIDC has not been created") + } + + return o.cfg.Issuer +} + +func (o *scenarioOIDC) ClientSecret() string { + if o.cfg == nil { + panic("OIDC has not been created") + } + + return o.cfg.ClientSecret +} + +func (o *scenarioOIDC) ClientID() string { + if o.cfg == nil { + panic("OIDC has not been created") + } + + return o.cfg.ClientID +} + +const ( + dockerContextPath = "../." 
+ hsicOIDCMockHashLength = 6 + defaultAccessTTL = 10 * time.Minute +) + +var errStatusCodeNotOK = errors.New("status code not OK") + +func (s *Scenario) runMockOIDC(accessTTL time.Duration, users []mockoidc.MockUser) error { + port, err := dockertestutil.RandomFreeHostPort() + if err != nil { + log.Fatalf("could not find an open port: %s", err) + } + portNotation := fmt.Sprintf("%d/tcp", port) + + hash, _ := util.GenerateRandomStringDNSSafe(hsicOIDCMockHashLength) + + hostname := "hs-oidcmock-" + hash + + usersJSON, err := json.Marshal(users) + if err != nil { + return err + } + + mockOidcOptions := &dockertest.RunOptions{ + Name: hostname, + Cmd: []string{"headscale", "mockoidc"}, + ExposedPorts: []string{portNotation}, + PortBindings: map[docker.Port][]docker.PortBinding{ + docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, + }, + Networks: s.Networks(), + Env: []string{ + "MOCKOIDC_ADDR=" + hostname, + fmt.Sprintf("MOCKOIDC_PORT=%d", port), + "MOCKOIDC_CLIENT_ID=superclient", + "MOCKOIDC_CLIENT_SECRET=supersecret", + "MOCKOIDC_ACCESS_TTL=" + accessTTL.String(), + "MOCKOIDC_USERS=" + string(usersJSON), + }, + } + + headscaleBuildOptions := &dockertest.BuildOptions{ + Dockerfile: hsic.IntegrationTestDockerFileName, + ContextDir: dockerContextPath, + } + + err = s.pool.RemoveContainerByName(hostname) + if err != nil { + return err + } + + s.mockOIDC = scenarioOIDC{} + + // Add integration test labels if running under hi tool + dockertestutil.DockerAddIntegrationLabels(mockOidcOptions, "oidc") + + if pmockoidc, err := s.pool.BuildAndRunWithBuildOptions( + headscaleBuildOptions, + mockOidcOptions, + dockertestutil.DockerRestartPolicy); err == nil { + s.mockOIDC.r = pmockoidc + } else { + return err + } + + // headscale needs to set up the provider with a specific + // IP addr to ensure we get the correct config from the well-known + // endpoint. 
+ network := s.Networks()[0] + ipAddr := s.mockOIDC.r.GetIPInNetwork(network) + + log.Println("Waiting for headscale mock oidc to be ready for tests") + hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) + + if err := s.pool.Retry(func() error { + oidcConfigURL := fmt.Sprintf("http://%s/oidc/.well-known/openid-configuration", hostEndpoint) + httpClient := &http.Client{} + ctx := context.Background() + req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil) + resp, err := httpClient.Do(req) + if err != nil { + log.Printf("headscale mock OIDC tests is not ready: %s\n", err) + + return err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errStatusCodeNotOK + } + + return nil + }); err != nil { + return err + } + + s.mockOIDC.cfg = &types.OIDCConfig{ + Issuer: fmt.Sprintf( + "http://%s/oidc", + hostEndpoint, + ), + ClientID: "superclient", + ClientSecret: "supersecret", + OnlyStartIfOIDCIsAvailable: true, + } + + log.Printf("headscale mock oidc is ready for tests at %s", hostEndpoint) + + return nil +} + +type extraServiceFunc func(*Scenario, string) (*dockertest.Resource, error) + +func Webservice(s *Scenario, networkName string) (*dockertest.Resource, error) { + // port, err := dockertestutil.RandomFreeHostPort() + // if err != nil { + // log.Fatalf("could not find an open port: %s", err) + // } + // portNotation := fmt.Sprintf("%d/tcp", port) + + hash := util.MustGenerateRandomStringDNSSafe(hsicOIDCMockHashLength) + + hostname := "hs-webservice-" + hash + + network, ok := s.networks[s.prefixedNetworkName(networkName)] + if !ok { + return nil, fmt.Errorf("network does not exist: %s", networkName) + } + + webOpts := &dockertest.RunOptions{ + Name: hostname, + Cmd: []string{"/bin/sh", "-c", "cd / ; python3 -m http.server --bind :: 80"}, + // ExposedPorts: []string{portNotation}, + // PortBindings: map[docker.Port][]docker.PortBinding{ + // docker.Port(portNotation): {{HostPort: strconv.Itoa(port)}}, + // }, + Networks: []*dockertest.Network{network}, + Env: []string{}, + } + + // Add integration test labels if running under hi tool + dockertestutil.DockerAddIntegrationLabels(webOpts, "web") + + webBOpts := &dockertest.BuildOptions{ + Dockerfile: hsic.IntegrationTestDockerFileName, + ContextDir: dockerContextPath, + } + + web, err := s.pool.BuildAndRunWithBuildOptions( + webBOpts, + webOpts, + dockertestutil.DockerRestartPolicy) + if err != nil { + return nil, err + } + + // headscale needs to set up the provider with a specific + // IP addr to ensure we get the correct config from the well-known + // endpoint. 
+ // ipAddr := web.GetIPInNetwork(network) + + // log.Println("Waiting for headscale mock oidc to be ready for tests") + // hostEndpoint := net.JoinHostPort(ipAddr, strconv.Itoa(port)) + + // if err := s.pool.Retry(func() error { + // oidcConfigURL := fmt.Sprintf("http://%s/etc/hostname", hostEndpoint) + // httpClient := &http.Client{} + // ctx := context.Background() + // req, _ := http.NewRequestWithContext(ctx, http.MethodGet, oidcConfigURL, nil) + // resp, err := httpClient.Do(req) + // if err != nil { + // log.Printf("headscale mock OIDC tests is not ready: %s\n", err) + + // return err + // } + // defer resp.Body.Close() + + // if resp.StatusCode != http.StatusOK { + // return errStatusCodeNotOK + // } + + // return nil + // }); err != nil { + // return err + // } + + return web, nil +} diff --git a/integration/scenario_test.go b/integration/scenario_test.go index 59b6a33c..1e2a151a 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -4,10 +4,12 @@ import ( "testing" "github.com/juanfont/headscale/integration/dockertestutil" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/require" ) // This file is intended to "test the test framework", by proxy it will also test -// some Headcsale/Tailscale stuff, but mostly in very simple ways. +// some Headscale/Tailscale stuff, but mostly in very simple ways. func IntegrationSkip(t *testing.T) { t.Helper() @@ -27,15 +29,14 @@ func IntegrationSkip(t *testing.T) { // nolint:tparallel func TestHeadscale(t *testing.T) { IntegrationSkip(t) - t.Parallel() var err error user := "test-space" - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() + scenario, err := NewScenario(ScenarioSpec{}) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) t.Run("start-headscale", func(t *testing.T) { headscale, err := scenario.Headscale() @@ -50,7 +51,7 @@ func TestHeadscale(t *testing.T) { }) t.Run("create-user", func(t *testing.T) { - err := scenario.CreateUser(user) + _, err := scenario.CreateUser(user) if err != nil { t.Fatalf("failed to create user: %s", err) } @@ -61,52 +62,19 @@ func TestHeadscale(t *testing.T) { }) t.Run("create-auth-key", func(t *testing.T) { - _, err := scenario.CreatePreAuthKey(user, true, false) + _, err := scenario.CreatePreAuthKey(1, true, false) if err != nil { t.Fatalf("failed to create preauthkey: %s", err) } }) } -// If subtests are parallel, then they will start before setup is run. -// This might mean we approach setup slightly wrong, but for now, ignore -// the linter -// nolint:tparallel -func TestCreateTailscale(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - user := "only-create-containers" - - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() - - scenario.users[user] = &User{ - Clients: make(map[string]TailscaleClient), - } - - t.Run("create-tailscale", func(t *testing.T) { - err := scenario.CreateTailscaleNodesInUser(user, "all", 3) - if err != nil { - t.Fatalf("failed to add tailscale nodes: %s", err) - } - - if clients := len(scenario.users[user].Clients); clients != 3 { - t.Fatalf("wrong number of tailscale clients: %d != %d", clients, 3) - } - - // TODO(kradalby): Test "all" version logic - }) -} - // If subtests are parallel, then they will start before setup is run. 
// This might mean we approach setup slightly wrong, but for now, ignore // the linter // nolint:tparallel func TestTailscaleNodesJoiningHeadcale(t *testing.T) { IntegrationSkip(t) - t.Parallel() var err error @@ -114,9 +82,9 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { count := 1 - scenario, err := NewScenario() - assertNoErr(t, err) - defer scenario.Shutdown() + scenario, err := NewScenario(ScenarioSpec{}) + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) t.Run("start-headscale", func(t *testing.T) { headscale, err := scenario.Headscale() @@ -131,7 +99,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { }) t.Run("create-user", func(t *testing.T) { - err := scenario.CreateUser(user) + _, err := scenario.CreateUser(user) if err != nil { t.Fatalf("failed to create user: %s", err) } @@ -142,7 +110,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { }) t.Run("create-tailscale", func(t *testing.T) { - err := scenario.CreateTailscaleNodesInUser(user, "1.30.2", count) + err := scenario.CreateTailscaleNodesInUser(user, "unstable", count, tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork])) if err != nil { t.Fatalf("failed to add tailscale nodes: %s", err) } @@ -153,7 +121,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { }) t.Run("join-headscale", func(t *testing.T) { - key, err := scenario.CreatePreAuthKey(user, true, false) + key, err := scenario.CreatePreAuthKey(1, true, false) if err != nil { t.Fatalf("failed to create preauthkey: %s", err) } diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 587190e4..2986bcea 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -7,52 +7,34 @@ import ( "testing" "time" - "github.com/juanfont/headscale/hscontrol/policy" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/integration/hsic" "github.com/juanfont/headscale/integration/tsic" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" ) -var retry = func(times int, sleepInterval time.Duration, - doWork func() (string, string, error), -) (string, string, error) { - var result string - var stderr string - var err error - - for attempts := 0; attempts < times; attempts++ { - tempResult, tempStderr, err := doWork() - - result += tempResult - stderr += tempStderr - - if err == nil { - return result, stderr, nil - } - - // If we get a permission denied error, we can fail immediately - // since that is something we wont recover from by retrying. 
- if err != nil && strings.Contains(stderr, "Permission denied (tailscale)") { - return result, stderr, err - } - - time.Sleep(sleepInterval) - } - - return result, stderr, err +func isSSHNoAccessStdError(stderr string) bool { + return strings.Contains(stderr, "Permission denied (tailscale)") || + // Since https://github.com/tailscale/tailscale/pull/14853 + strings.Contains(stderr, "failed to evaluate SSH policy") || + // Since https://github.com/tailscale/tailscale/pull/16127 + strings.Contains(stderr, "tailnet policy does not permit you to SSH to this node") } -func sshScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario { +func sshScenario(t *testing.T, policy *policyv2.Policy, clientsPerUser int) *Scenario { t.Helper() - scenario, err := NewScenario() - assertNoErr(t, err) - spec := map[string]int{ - "user1": clientsPerUser, - "user2": clientsPerUser, + spec := ScenarioSpec{ + NodesPerUser: clientsPerUser, + Users: []string{"user1", "user2"}, } + scenario, err := NewScenario(spec) + require.NoError(t, err) - err = scenario.CreateHeadscaleEnv(spec, + err = scenario.CreateHeadscaleEnv( []tsic.Option{ tsic.WithSSH(), @@ -60,73 +42,74 @@ func sshScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Sc // tailscaled to stop configuring the wgengine, causing it // to not configure DNS. tsic.WithNetfilter("off"), - tsic.WithDockerEntrypoint([]string{ - "/bin/sh", - "-c", - "/bin/sleep 3 ; apk add openssh ; adduser ssh-it-user ; update-ca-certificates ; tailscaled --tun=tsdev", - }), + tsic.WithPackages("openssh"), + tsic.WithExtraCommands("adduser ssh-it-user"), tsic.WithDockerWorkdir("/"), }, hsic.WithACLPolicy(policy), hsic.WithTestName("ssh"), - hsic.WithConfigEnv(map[string]string{ - "HEADSCALE_EXPERIMENTAL_FEATURE_SSH": "1", - }), ) - assertNoErr(t, err) + require.NoError(t, err) err = scenario.WaitForTailscaleSync() - assertNoErr(t, err) + require.NoError(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErr(t, err) + require.NoError(t, err) return scenario } func TestSSHOneUserToAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, - &policy.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-test": {"user1"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - SSHs: []policy.SSH{ + SSHs: []policyv2.SSH{ { - Action: "accept", - Sources: []string{"group:integration-test"}, - Destinations: []string{"*"}, - Users: []string{"ssh-it-user"}, + Action: "accept", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + // Use autogroup:member and autogroup:tagged instead of wildcard + // since wildcard (*) is no longer supported for SSH destinations + Destinations: policyv2.SSHDstAliases{ + ptr.To(policyv2.AutoGroupMember), + ptr.To(policyv2.AutoGroupTagged), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, }, }, len(MustTestVersions), ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) user1Clients, err := 
scenario.ListTailscaleClients("user1") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) user2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range user1Clients { for _, peer := range allClients { @@ -149,89 +132,125 @@ func TestSSHOneUserToAll(t *testing.T) { } } +// TestSSHMultipleUsersAllToAll tests that users in a group can SSH to each other's devices +// using autogroup:self as the destination, which allows same-user SSH access. func TestSSHMultipleUsersAllToAll(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, - &policy.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-test": {"user1", "user2"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@"), policyv2.Username("user2@")}, }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - SSHs: []policy.SSH{ + SSHs: []policyv2.SSH{ { - Action: "accept", - Sources: []string{"group:integration-test"}, - Destinations: []string{"group:integration-test"}, - Users: []string{"ssh-it-user"}, + Action: "accept", + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + // Use autogroup:self to allow users to SSH to their own devices. + // Username destinations (e.g., "user1@") now require the source + // to be that exact same user only. For group-to-group SSH access, + // use autogroup:self instead. + Destinations: policyv2.SSHDstAliases{ptr.To(policyv2.AutoGroupSelf)}, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, }, }, len(MustTestVersions), ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) nsOneClients, err := scenario.ListTailscaleClients("user1") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) nsTwoClients, err := scenario.ListTailscaleClients("user2") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) - testInterUserSSH := func(sourceClients []TailscaleClient, targetClients []TailscaleClient) { - for _, client := range sourceClients { - for _, peer := range targetClients { - assertSSHHostname(t, client, peer) + // With autogroup:self, users can SSH to their own devices, but not to other users' devices. 
+ // Test that user1's devices can SSH to each other + for _, client := range nsOneClients { + for _, peer := range nsOneClients { + if client.Hostname() == peer.Hostname() { + continue } + + assertSSHHostname(t, client, peer) } } - testInterUserSSH(nsOneClients, nsTwoClients) - testInterUserSSH(nsTwoClients, nsOneClients) + // Test that user2's devices can SSH to each other + for _, client := range nsTwoClients { + for _, peer := range nsTwoClients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHHostname(t, client, peer) + } + } + + // Test that user1 cannot SSH to user2's devices (autogroup:self only allows same-user) + for _, client := range nsOneClients { + for _, peer := range nsTwoClients { + assertSSHPermissionDenied(t, client, peer) + } + } + + // Test that user2 cannot SSH to user1's devices (autogroup:self only allows same-user) + for _, client := range nsTwoClients { + for _, peer := range nsOneClients { + assertSSHPermissionDenied(t, client, peer) + } + } } func TestSSHNoSSHConfigured(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, - &policy.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-test": {"user1"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - SSHs: []policy.SSH{}, + SSHs: []policyv2.SSH{}, }, len(MustTestVersions), ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { for _, peer := range allClients { @@ -246,41 +265,43 @@ func TestSSHNoSSHConfigured(t *testing.T) { func TestSSHIsBlockedInACL(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, - &policy.ACLPolicy{ - Groups: map[string][]string{ - "group:integration-test": {"user1"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:integration-test"): []policyv2.Username{policyv2.Username("user1@")}, }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:80"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRange{First: 80, Last: 80}), + }, }, }, - SSHs: []policy.SSH{ + SSHs: []policyv2.SSH{ { Action: "accept", - Sources: []string{"group:integration-test"}, - Destinations: []string{"group:integration-test"}, - Users: []string{"ssh-it-user"}, + Sources: policyv2.SSHSrcAliases{groupp("group:integration-test")}, + Destinations: policyv2.SSHDstAliases{ptr.To(policyv2.AutoGroupSelf)}, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, }, }, len(MustTestVersions), ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) + requireNoErrListClients(t, 
err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range allClients { for _, peer := range allClients { @@ -295,51 +316,56 @@ func TestSSHIsBlockedInACL(t *testing.T) { func TestSSHUserOnlyIsolation(t *testing.T) { IntegrationSkip(t) - t.Parallel() scenario := sshScenario(t, - &policy.ACLPolicy{ - Groups: map[string][]string{ - "group:ssh1": {"user1"}, - "group:ssh2": {"user2"}, + &policyv2.Policy{ + Groups: policyv2.Groups{ + policyv2.Group("group:ssh1"): []policyv2.Username{policyv2.Username("user1@")}, + policyv2.Group("group:ssh2"): []policyv2.Username{policyv2.Username("user2@")}, }, - ACLs: []policy.ACL{ + ACLs: []policyv2.ACL{ { - Action: "accept", - Sources: []string{"*"}, - Destinations: []string{"*:*"}, + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, }, }, - SSHs: []policy.SSH{ + SSHs: []policyv2.SSH{ + // Use autogroup:self to allow users in each group to SSH to their own devices. + // Username destinations (e.g., "user1@") require the source to be that + // exact same user only, not a group containing that user. { Action: "accept", - Sources: []string{"group:ssh1"}, - Destinations: []string{"group:ssh1"}, - Users: []string{"ssh-it-user"}, + Sources: policyv2.SSHSrcAliases{groupp("group:ssh1")}, + Destinations: policyv2.SSHDstAliases{ptr.To(policyv2.AutoGroupSelf)}, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, { Action: "accept", - Sources: []string{"group:ssh2"}, - Destinations: []string{"group:ssh2"}, - Users: []string{"ssh-it-user"}, + Sources: policyv2.SSHSrcAliases{groupp("group:ssh2")}, + Destinations: policyv2.SSHDstAliases{ptr.To(policyv2.AutoGroupSelf)}, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, }, }, }, len(MustTestVersions), ) - defer scenario.Shutdown() + defer scenario.ShutdownAssertNoPanics(t) ssh1Clients, err := scenario.ListTailscaleClients("user1") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) ssh2Clients, err := scenario.ListTailscaleClients("user2") - assertNoErrListClients(t, err) + requireNoErrListClients(t, err) err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) + requireNoErrSync(t, err) _, err = scenario.ListTailscaleClientsFQDNs() - assertNoErrListFQDN(t, err) + requireNoErrListFQDN(t, err) for _, client := range ssh1Clients { for _, peer := range ssh2Clients { @@ -384,6 +410,16 @@ func TestSSHUserOnlyIsolation(t *testing.T) { func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) { t.Helper() + return doSSHWithRetry(t, client, peer, true) +} + +func doSSHWithoutRetry(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, string, error) { + t.Helper() + return doSSHWithRetry(t, client, peer, false) +} + +func doSSHWithRetry(t *testing.T, client TailscaleClient, peer TailscaleClient, retry bool) (string, string, error) { + t.Helper() peerFQDN, _ := peer.FQDN() @@ -396,34 +432,56 @@ func doSSH(t *testing.T, client TailscaleClient, peer TailscaleClient) (string, log.Printf("Running from %s to %s", client.Hostname(), peer.Hostname()) log.Printf("Command: %s", strings.Join(command, " ")) - return retry(10, 1*time.Second, func() (string, string, error) { - return client.Execute(command) - }) + var ( + result, stderr string + err error + ) 
+ + if retry { + // Use assert.EventuallyWithT to retry SSH connections for success cases + assert.EventuallyWithT(t, func(ct *assert.CollectT) { + result, stderr, err = client.Execute(command) + + // If we get a permission denied error, we can fail immediately + // since that is something we won't recover from by retrying. + if err != nil && isSSHNoAccessStdError(stderr) { + return // Don't retry permission denied errors + } + + // For all other errors, assert no error to trigger retry + assert.NoError(ct, err) + }, 10*time.Second, 200*time.Millisecond) + } else { + // For failure cases, just execute once + result, stderr, err = client.Execute(command) + } + + return result, stderr, err } func assertSSHHostname(t *testing.T, client TailscaleClient, peer TailscaleClient) { t.Helper() result, _, err := doSSH(t, client, peer) - assertNoErr(t, err) + require.NoError(t, err) - assertContains(t, peer.ID(), strings.ReplaceAll(result, "\n", "")) + require.Contains(t, peer.ContainerID(), strings.ReplaceAll(result, "\n", "")) } func assertSSHPermissionDenied(t *testing.T, client TailscaleClient, peer TailscaleClient) { t.Helper() - result, stderr, _ := doSSH(t, client, peer) + result, stderr, err := doSSHWithoutRetry(t, client, peer) assert.Empty(t, result) - assertContains(t, stderr, "Permission denied (tailscale)") + assertSSHNoAccessStdError(t, err, stderr) } func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient) { t.Helper() - result, stderr, _ := doSSH(t, client, peer) + result, stderr, _ := doSSHWithoutRetry(t, client, peer) assert.Empty(t, result) @@ -432,3 +490,93 @@ func assertSSHTimeout(t *testing.T, client TailscaleClient, peer TailscaleClient t.Fatalf("connection did not time out") } } + +func assertSSHNoAccessStdError(t *testing.T, err error, stderr string) { + t.Helper() + assert.Error(t, err) + + if !isSSHNoAccessStdError(stderr) { + t.Errorf("expected stderr output suggesting access denied, got: %s", stderr) + } +} + +// TestSSHAutogroupSelf tests that SSH with autogroup:self works correctly: +// - Users can SSH to their own devices +// - Users cannot SSH to other users' devices. 
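The test body that follows spells out this allow/deny matrix as four nested loops. If the pattern keeps growing, the cross-product could be factored into a small helper; a hypothetical sketch (the name `forAllPairs` is not part of the patch), reusing the `TailscaleClient` interface and the SSH assertion helpers already defined in this file:

```go
// forAllPairs runs check for every (client, peer) combination across the two
// slices, skipping self-pairs. Hypothetical helper, not part of this change.
func forAllPairs(
	t *testing.T,
	from, to []TailscaleClient,
	check func(t *testing.T, client, peer TailscaleClient),
) {
	t.Helper()

	for _, client := range from {
		for _, peer := range to {
			if client.Hostname() == peer.Hostname() {
				continue
			}

			check(t, client, peer)
		}
	}
}
```

Usage would then read as, for example, `forAllPairs(t, user1Clients, user1Clients, assertSSHHostname)` for the same-user case and `forAllPairs(t, user1Clients, user2Clients, assertSSHPermissionDenied)` for the cross-user case.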
+func TestSSHAutogroupSelf(t *testing.T) { + IntegrationSkip(t) + + scenario := sshScenario(t, + &policyv2.Policy{ + ACLs: []policyv2.ACL{ + { + Action: "accept", + Protocol: "tcp", + Sources: []policyv2.Alias{wildcard()}, + Destinations: []policyv2.AliasWithPorts{ + aliasWithPorts(wildcard(), tailcfg.PortRangeAny), + }, + }, + }, + SSHs: []policyv2.SSH{ + { + Action: "accept", + Sources: policyv2.SSHSrcAliases{ + ptr.To(policyv2.AutoGroupMember), + }, + Destinations: policyv2.SSHDstAliases{ + ptr.To(policyv2.AutoGroupSelf), + }, + Users: []policyv2.SSHUser{policyv2.SSHUser("ssh-it-user")}, + }, + }, + }, + 2, // 2 clients per user + ) + defer scenario.ShutdownAssertNoPanics(t) + + user1Clients, err := scenario.ListTailscaleClients("user1") + requireNoErrListClients(t, err) + + user2Clients, err := scenario.ListTailscaleClients("user2") + requireNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + requireNoErrSync(t, err) + + // Test that user1's devices can SSH to each other + for _, client := range user1Clients { + for _, peer := range user1Clients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHHostname(t, client, peer) + } + } + + // Test that user2's devices can SSH to each other + for _, client := range user2Clients { + for _, peer := range user2Clients { + if client.Hostname() == peer.Hostname() { + continue + } + + assertSSHHostname(t, client, peer) + } + } + + // Test that user1 cannot SSH to user2's devices + for _, client := range user1Clients { + for _, peer := range user2Clients { + assertSSHPermissionDenied(t, client, peer) + } + } + + // Test that user2 cannot SSH to user1's devices + for _, client := range user2Clients { + for _, peer := range user1Clients { + assertSSHPermissionDenied(t, client, peer) + } + } +} diff --git a/integration/tags_test.go b/integration/tags_test.go new file mode 100644 index 00000000..5dad36e5 --- /dev/null +++ b/integration/tags_test.go @@ -0,0 +1,3118 @@ +package integration + +import ( + "sort" + "testing" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" + "github.com/juanfont/headscale/hscontrol/util" + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "tailscale.com/tailcfg" + "tailscale.com/types/ptr" +) + +const tagTestUser = "taguser" + +// ============================================================================= +// Helper Functions +// ============================================================================= + +// tagsTestPolicy creates a policy for tag tests with: +// - tag:valid-owned: owned by the specified user +// - tag:second: owned by the specified user +// - tag:valid-unowned: owned by "other-user" (not the test user) +// - tag:nonexistent is deliberately NOT defined. 
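For easier review, roughly the same policy rendered as HuJSON. This is purely illustrative: the tests construct the policy as Go structs via the helper below, and the key names here follow headscale's policy file format only approximately.

```go
// Unused constant, shown only to make the tag ownership relationships easy to scan.
const tagsTestPolicyHuJSON = `{
	"tagOwners": {
		"tag:valid-owned":   ["taguser@"],
		"tag:second":        ["taguser@"],
		"tag:valid-unowned": ["other-user@"],
		// "tag:nonexistent" is deliberately not defined.
	},
	"acls": [
		{"action": "accept", "src": ["*"], "dst": ["*:*"]},
	],
}`
```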
+func tagsTestPolicy() *policyv2.Policy { + return &policyv2.Policy{ + TagOwners: policyv2.TagOwners{ + "tag:valid-owned": policyv2.Owners{ptr.To(policyv2.Username(tagTestUser + "@"))}, + "tag:second": policyv2.Owners{ptr.To(policyv2.Username(tagTestUser + "@"))}, + "tag:valid-unowned": policyv2.Owners{ptr.To(policyv2.Username("other-user@"))}, + // Note: tag:nonexistent deliberately NOT defined + }, + ACLs: []policyv2.ACL{ + { + Action: "accept", + Sources: []policyv2.Alias{policyv2.Wildcard}, + Destinations: []policyv2.AliasWithPorts{{Alias: policyv2.Wildcard, Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}}}, + }, + }, + } +} + +// tagsEqual compares two tag slices as unordered sets. +func tagsEqual(actual, expected []string) bool { + if len(actual) != len(expected) { + return false + } + + sortedActual := append([]string{}, actual...) + sortedExpected := append([]string{}, expected...) + + sort.Strings(sortedActual) + sort.Strings(sortedExpected) + + for i := range sortedActual { + if sortedActual[i] != sortedExpected[i] { + return false + } + } + + return true +} + +// assertNodeHasTagsWithCollect asserts that a node has exactly the expected tags (order-independent). +func assertNodeHasTagsWithCollect(c *assert.CollectT, node *v1.Node, expectedTags []string) { + actualTags := node.GetTags() + sortedActual := append([]string{}, actualTags...) + sortedExpected := append([]string{}, expectedTags...) + + sort.Strings(sortedActual) + sort.Strings(sortedExpected) + assert.Equal(c, sortedExpected, sortedActual, "Node %s tags mismatch", node.GetName()) +} + +// assertNodeHasNoTagsWithCollect asserts that a node has no tags. +func assertNodeHasNoTagsWithCollect(c *assert.CollectT, node *v1.Node) { + assert.Empty(c, node.GetTags(), "Node %s should have no tags, but has: %v", node.GetName(), node.GetTags()) +} + +// assertNodeSelfHasTagsWithCollect asserts that a client's self view has exactly the expected tags. +// This validates that tag updates have propagated to the node's own status (issue #2978). +func assertNodeSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClient, expectedTags []string) { + status, err := client.Status() + //nolint:testifylint // must use assert with CollectT in EventuallyWithT + assert.NoError(c, err, "failed to get client status") + + if status == nil || status.Self == nil { + assert.Fail(c, "client status or self is nil") + return + } + + var actualTagsSlice []string + + if status.Self.Tags != nil { + for _, tag := range status.Self.Tags.All() { + actualTagsSlice = append(actualTagsSlice, tag) + } + } + + sortedActual := append([]string{}, actualTagsSlice...) + sortedExpected := append([]string{}, expectedTags...) + + sort.Strings(sortedActual) + sort.Strings(sortedExpected) + assert.Equal(c, sortedExpected, sortedActual, "Client %s self tags mismatch", client.Hostname()) +} + +// ============================================================================= +// Test Suite 2: Auth Key WITH Pre-assigned Tags +// ============================================================================= + +// TestTagsAuthKeyWithTagRequestDifferentTag tests that requesting a different tag +// than what the auth key provides results in registration failure. +// +// Test 2.1: Request different tag than key provides +// Setup: Run `tailscale up --advertise-tags="tag:second" --auth-key AUTH_KEY_WITH_TAG` +// Expected: Registration fails with error containing "requested tags [tag:second] are invalid or not permitted". 
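A note on the verification style used throughout this suite: state checks poll with testify's `assert.EventuallyWithT`, re-running the closure every tick until all assertions against the `*assert.CollectT` pass or the overall timeout expires. A minimal, self-contained illustration of the idiom (the polled condition is a stand-in; the real tests poll `headscale.ListNodes` and the client status):

```go
func waitForCondition(t *testing.T) {
	t.Helper()

	start := time.Now()

	// The closure is retried every 500ms for up to 10s; the test only fails
	// if the assertions still do not hold when the timeout is reached.
	assert.EventuallyWithT(t, func(c *assert.CollectT) {
		assert.True(c, time.Since(start) > 2*time.Second, "still settling")
	}, 10*time.Second, 500*time.Millisecond, "condition never became true")
}
```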
+func TestTagsAuthKeyWithTagRequestDifferentTag(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, // We'll create the node manually + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-diff"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey with tag:valid-owned + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + t.Logf("Created tagged PreAuthKey with tags: %v", authKey.GetAclTags()) + + // Create a tailscale client that will try to use --advertise-tags with a DIFFERENT tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:second"}), + ) + require.NoError(t, err) + + // Login should fail because the advertised tags don't match the auth key's tags + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + + // Document actual behavior - we expect this to fail + if err != nil { + t.Logf("Test 2.1 PASS: Registration correctly rejected with error: %v", err) + assert.ErrorContains(t, err, "requested tags") + } else { + // If it succeeded, document this unexpected behavior + t.Logf("Test 2.1 UNEXPECTED: Registration succeeded when it should have failed") + + // Check what tags the node actually has + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("Node registered with tags: %v (expected rejection)", nodes[0].GetTags()) + } + }, 10*time.Second, 500*time.Millisecond, "checking node state") + + t.Fail() + } +} + +// TestTagsAuthKeyWithTagNoAdvertiseFlag tests that registering with a tagged auth key +// but no --advertise-tags flag results in the node inheriting the key's tags. +// +// Test 2.2: Register with no advertise-tags flag +// Setup: Run `tailscale up --auth-key AUTH_KEY_WITH_TAG` (no --advertise-tags) +// Expected: Registration succeeds, node has ["tag:valid-owned"] (inherited from key). 
+func TestTagsAuthKeyWithTagNoAdvertiseFlag(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-inherit"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey with tag:valid-owned + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + t.Logf("Created tagged PreAuthKey with tags: %v", authKey.GetAclTags()) + + // Create a tailscale client WITHOUT --advertise-tags + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + // Note: NO WithExtraLoginArgs for --advertise-tags + ) + require.NoError(t, err) + + // Login with the tagged PreAuthKey + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for node to be registered and verify it has the key's tags + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + node := nodes[0] + t.Logf("Node registered with tags: %v", node.GetTags()) + assertNodeHasTagsWithCollect(c, node, []string{"tag:valid-owned"}) + } + }, 30*time.Second, 500*time.Millisecond, "verifying node inherited tags from auth key") + + t.Logf("Test 2.2 completed - node inherited tags from auth key") +} + +// TestTagsAuthKeyWithTagCannotAddViaCLI tests that nodes registered with a tagged auth key +// cannot add additional tags via the client CLI. +// +// Test 2.3: Cannot add tags via CLI after registration +// Setup: +// 1. Register with --auth-key AUTH_KEY_WITH_TAG +// 2. Run `tailscale up --advertise-tags="tag:valid-owned,tag:second" --auth-key AUTH_KEY_WITH_TAG` +// +// Expected: Command fails with error containing "requested tags [tag:second] are invalid or not permitted". 
+func TestTagsAuthKeyWithTagCannotAddViaCLI(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-noadd"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey with tag:valid-owned + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + // Initial login + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for initial registration + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + t.Logf("Node registered with tag:valid-owned, now attempting to add tag:second via CLI") + + // Attempt to add additional tags via tailscale up + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--authkey=" + authKey.GetKey(), + "--advertise-tags=tag:valid-owned,tag:second", + } + _, stderr, err := client.Execute(command) + + // Document actual behavior + if err != nil { + t.Logf("Test 2.3 PASS: CLI correctly rejected adding tags: %v, stderr: %s", err, stderr) + } else { + t.Logf("Test 2.3: CLI command succeeded, checking if tags actually changed") + + // Check if tags actually changed + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + // If still only has original tag, that's the expected behavior + if tagsEqual(nodes[0].GetTags(), []string{"tag:valid-owned"}) { + t.Logf("Test 2.3 PASS: Tags unchanged after CLI attempt: %v", nodes[0].GetTags()) + } else { + t.Logf("Test 2.3 FAIL: Tags changed unexpectedly to: %v", nodes[0].GetTags()) + assert.Fail(c, "Tags should not have changed") + } + } + }, 10*time.Second, 500*time.Millisecond, "verifying tags unchanged") + } +} + +// TestTagsAuthKeyWithTagCannotChangeViaCLI tests that nodes registered with a tagged auth key +// cannot change to a completely different tag set via the client CLI. +// +// Test 2.4: Cannot change to different tag set via CLI +// Setup: +// 1. Register with --auth-key AUTH_KEY_WITH_TAG +// 2. Run `tailscale up --advertise-tags="tag:second" --auth-key AUTH_KEY_WITH_TAG` +// +// Expected: Command fails, tags remain ["tag:valid-owned"]. 
+func TestTagsAuthKeyWithTagCannotChangeViaCLI(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-nochange"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey with tag:valid-owned + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + // Initial login + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for initial registration + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + t.Logf("Node registered, now attempting to change to different tag via CLI") + + // Attempt to change to a different tag via tailscale up + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--authkey=" + authKey.GetKey(), + "--advertise-tags=tag:second", + } + _, stderr, err := client.Execute(command) + + // Document actual behavior + if err != nil { + t.Logf("Test 2.4 PASS: CLI correctly rejected changing tags: %v, stderr: %s", err, stderr) + } else { + t.Logf("Test 2.4: CLI command succeeded, checking if tags actually changed") + + // Check if tags remain unchanged + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + if tagsEqual(nodes[0].GetTags(), []string{"tag:valid-owned"}) { + t.Logf("Test 2.4 PASS: Tags unchanged: %v", nodes[0].GetTags()) + } else { + t.Logf("Test 2.4 FAIL: Tags changed unexpectedly to: %v", nodes[0].GetTags()) + assert.Fail(c, "Tags should not have changed") + } + } + }, 10*time.Second, 500*time.Millisecond, "verifying tags unchanged") + } +} + +// TestTagsAuthKeyWithTagAdminOverrideReauthPreserves tests that admin-assigned tags +// are preserved even after reauthentication - admin decisions are authoritative. +// +// Test 2.5: Admin assignment is preserved through reauth +// Setup: +// 1. Register with --auth-key AUTH_KEY_WITH_TAG +// 2. Assign ["tag:second"] via headscale CLI +// 3. Run `tailscale up --auth-key AUTH_KEY_WITH_TAG --force-reauth` +// +// Expected: After step 2 tags are ["tag:second"], after step 3 tags remain ["tag:second"]. 
+func TestTagsAuthKeyWithTagAdminOverrideReauthPreserves(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-admin"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey with tag:valid-owned + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, true, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + // Initial login + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for initial registration and get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + t.Logf("Step 1 complete: Node %d registered with tag:valid-owned", nodeID) + + // Step 2: Admin assigns different tags via headscale CLI + err = headscale.SetNodeTags(nodeID, []string{"tag:second"}) + require.NoError(t, err) + + // Verify admin assignment took effect (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("After admin assignment, server tags are: %v", nodes[0].GetTags()) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying admin tag assignment on server") + + // Verify admin assignment propagated to node's self view (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:second"}) + }, 30*time.Second, 500*time.Millisecond, "verifying admin tag assignment propagated to node self") + + t.Logf("Step 2 complete: Admin assigned tag:second (verified on both server and node self)") + + // Step 3: Force reauthentication + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--authkey=" + authKey.GetKey(), + "--force-reauth", + } + //nolint:errcheck // Intentionally ignoring error - we check results below + client.Execute(command) + + // Verify admin tags are preserved even after reauth - admin decisions are authoritative (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.GreaterOrEqual(c, len(nodes), 1, "Should have at least 1 node") + + if len(nodes) >= 1 { + // Find the most recently updated node (in case a new one was created) + node := nodes[len(nodes)-1] + t.Logf("After reauth, server tags are: %v", node.GetTags()) + + // Expected: admin-assigned tags are preserved through 
reauth + assertNodeHasTagsWithCollect(c, node, []string{"tag:second"}) + } + }, 30*time.Second, 500*time.Millisecond, "admin tags should be preserved after reauth on server") + + // Verify admin tags are preserved in node's self view after reauth (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:second"}) + }, 30*time.Second, 500*time.Millisecond, "admin tags should be preserved after reauth in node self") + + t.Logf("Test 2.5 PASS: Admin tags preserved through reauth (admin decisions are authoritative)") +} + +// TestTagsAuthKeyWithTagCLICannotModifyAdminTags tests that the client CLI +// cannot modify admin-assigned tags. +// +// Test 2.6: Client CLI cannot modify admin-assigned tags +// Setup: +// 1. Register with --auth-key AUTH_KEY_WITH_TAG +// 2. Assign ["tag:valid-owned", "tag:second"] via headscale CLI +// 3. Run `tailscale up --advertise-tags="tag:valid-owned" --auth-key AUTH_KEY_WITH_TAG` +// +// Expected: Command either fails or is no-op, tags remain ["tag:valid-owned", "tag:second"]. +func TestTagsAuthKeyWithTagCLICannotModifyAdminTags(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-noadmin"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey with tag:valid-owned + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, true, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + // Initial login + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for initial registration and get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + // Step 2: Admin assigns multiple tags via headscale CLI + err = headscale.SetNodeTags(nodeID, []string{"tag:valid-owned", "tag:second"}) + require.NoError(t, err) + + // Verify admin assignment (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned", "tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying admin tag assignment on server") + + // Verify admin assignment propagated to node's self view (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned", "tag:second"}) + }, 30*time.Second, 500*time.Millisecond, "verifying admin tag assignment propagated to node self") + + t.Logf("Admin assigned both tags, now 
attempting to reduce via CLI") + + // Step 3: Attempt to reduce tags via CLI + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--authkey=" + authKey.GetKey(), + "--advertise-tags=tag:valid-owned", + } + _, stderr, err := client.Execute(command) + + t.Logf("CLI command result: err=%v, stderr=%s", err, stderr) + + // Verify admin tags are preserved - CLI should not be able to reduce them (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + t.Logf("After CLI attempt, server tags are: %v", nodes[0].GetTags()) + + // Expected: tags should remain unchanged (admin wins) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned", "tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "admin tags should be preserved after CLI attempt on server") + + // Verify admin tags are preserved in node's self view (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned", "tag:second"}) + }, 30*time.Second, 500*time.Millisecond, "admin tags should be preserved after CLI attempt in node self") + + t.Logf("Test 2.6 PASS: Admin tags preserved - CLI cannot modify admin-assigned tags") +} + +// ============================================================================= +// Test Suite 3: Auth Key WITHOUT Tags +// ============================================================================= + +// TestTagsAuthKeyWithoutTagCannotRequestTags tests that nodes cannot request tags +// when using an auth key that has no tags. +// +// Test 3.1: Cannot request tags with tagless key +// Setup: Run `tailscale up --advertise-tags="tag:valid-owned" --auth-key AUTH_KEY_WITHOUT_TAG` +// Expected: Registration fails with error containing "requested tags [tag:valid-owned] are invalid or not permitted". 
+func TestTagsAuthKeyWithoutTagCannotRequestTags(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-nokey-req"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create an auth key WITHOUT tags + authKey, err := scenario.CreatePreAuthKey(userID, false, false) + require.NoError(t, err) + t.Logf("Created PreAuthKey without tags") + + // Create a tailscale client that will try to request tags + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned"}), + ) + require.NoError(t, err) + + // Login should fail because the auth key has no tags + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + if err != nil { + t.Logf("Test 3.1 PASS: Registration correctly rejected: %v", err) + assert.ErrorContains(t, err, "requested tags") + } else { + // If it succeeded, document this unexpected behavior + t.Logf("Test 3.1 UNEXPECTED: Registration succeeded when it should have failed") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("Node registered with tags: %v (expected rejection)", nodes[0].GetTags()) + } + }, 10*time.Second, 500*time.Millisecond, "checking node state") + + t.Fail() + } +} + +// TestTagsAuthKeyWithoutTagRegisterNoTags tests that registering with a tagless auth key +// and no --advertise-tags results in a node with no tags. +// +// Test 3.2: Register with no tags +// Setup: Run `tailscale up --auth-key AUTH_KEY_WITHOUT_TAG` (no --advertise-tags) +// Expected: Registration succeeds, node has no tags (empty tag set). 
+func TestTagsAuthKeyWithoutTagRegisterNoTags(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-nokey-noreg"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create an auth key WITHOUT tags + authKey, err := scenario.CreatePreAuthKey(userID, false, false) + require.NoError(t, err) + + // Create a tailscale client without --advertise-tags + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + // Login should succeed + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Verify node has no tags + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + t.Logf("Node registered with tags: %v", nodes[0].GetTags()) + assertNodeHasNoTagsWithCollect(c, nodes[0]) + } + }, 30*time.Second, 500*time.Millisecond, "verifying node has no tags") + + t.Logf("Test 3.2 completed - node registered without tags") +} + +// TestTagsAuthKeyWithoutTagCannotAddViaCLI tests that nodes registered with a tagless +// auth key cannot add tags via the client CLI. +// +// Test 3.3: Cannot add tags via CLI after registration +// Setup: +// 1. Register with --auth-key AUTH_KEY_WITHOUT_TAG +// 2. Run `tailscale up --advertise-tags="tag:valid-owned" --auth-key AUTH_KEY_WITHOUT_TAG` +// +// Expected: Command fails, node remains with no tags. 
+func TestTagsAuthKeyWithoutTagCannotAddViaCLI(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-nokey-noadd"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create an auth key WITHOUT tags + authKey, err := scenario.CreatePreAuthKey(userID, true, false) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + // Initial login + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for initial registration + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + assertNodeHasNoTagsWithCollect(c, nodes[0]) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + t.Logf("Node registered without tags, attempting to add via CLI") + + // Attempt to add tags via tailscale up + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--authkey=" + authKey.GetKey(), + "--advertise-tags=tag:valid-owned", + } + _, stderr, err := client.Execute(command) + + // Document actual behavior + if err != nil { + t.Logf("Test 3.3 PASS: CLI correctly rejected adding tags: %v, stderr: %s", err, stderr) + } else { + t.Logf("Test 3.3: CLI command succeeded, checking if tags actually changed") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + if len(nodes[0].GetTags()) == 0 { + t.Logf("Test 3.3 PASS: Tags still empty after CLI attempt") + } else { + t.Logf("Test 3.3 FAIL: Tags changed to: %v", nodes[0].GetTags()) + assert.Fail(c, "Tags should not have changed") + } + } + }, 10*time.Second, 500*time.Millisecond, "verifying tags unchanged") + } +} + +// TestTagsAuthKeyWithoutTagCLINoOpAfterAdminWithReset tests that the client CLI +// is a no-op after admin tag assignment, even with --reset flag. +// +// Test 3.4: CLI no-op after admin tag assignment (with --reset) +// Setup: +// 1. Register with --auth-key AUTH_KEY_WITHOUT_TAG +// 2. Assign ["tag:valid-owned"] via headscale CLI +// 3. Run `tailscale up --auth-key AUTH_KEY_WITHOUT_TAG --reset` +// +// Expected: Command is no-op, tags remain ["tag:valid-owned"]. 
+func TestTagsAuthKeyWithoutTagCLINoOpAfterAdminWithReset(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-nokey-reset"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create an auth key WITHOUT tags + authKey, err := scenario.CreatePreAuthKey(userID, true, false) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + // Initial login + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for initial registration and get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + assertNodeHasNoTagsWithCollect(c, nodes[0]) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + // Step 2: Admin assigns tags + err = headscale.SetNodeTags(nodeID, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Verify admin assignment (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying admin tag assignment on server") + + // Verify admin assignment propagated to node's self view (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned"}) + }, 30*time.Second, 500*time.Millisecond, "verifying admin tag assignment propagated to node self") + + t.Logf("Admin assigned tag, now running CLI with --reset") + + // Step 3: Run tailscale up with --reset + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--authkey=" + authKey.GetKey(), + "--reset", + } + _, stderr, err := client.Execute(command) + t.Logf("CLI --reset result: err=%v, stderr=%s", err, stderr) + + // Verify admin tags are preserved - --reset should not remove them (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + t.Logf("After --reset, server tags are: %v", nodes[0].GetTags()) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 10*time.Second, 500*time.Millisecond, "admin tags should be preserved after --reset on server") + + // Verify admin tags are preserved in node's self view after --reset (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned"}) + }, 30*time.Second, 500*time.Millisecond, "admin tags should be preserved after --reset in node 
self") + + t.Logf("Test 3.4 PASS: Admin tags preserved after --reset") +} + +// TestTagsAuthKeyWithoutTagCLINoOpAfterAdminWithEmptyAdvertise tests that the client CLI +// is a no-op after admin tag assignment, even with empty --advertise-tags. +// +// Test 3.5: CLI no-op after admin tag assignment (with empty advertise-tags) +// Setup: +// 1. Register with --auth-key AUTH_KEY_WITHOUT_TAG +// 2. Assign ["tag:valid-owned"] via headscale CLI +// 3. Run `tailscale up --auth-key AUTH_KEY_WITHOUT_TAG --advertise-tags=""` +// +// Expected: Command is no-op, tags remain ["tag:valid-owned"]. +func TestTagsAuthKeyWithoutTagCLINoOpAfterAdminWithEmptyAdvertise(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-nokey-empty"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create an auth key WITHOUT tags + authKey, err := scenario.CreatePreAuthKey(userID, true, false) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + // Initial login + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for initial registration and get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + // Step 2: Admin assigns tags + err = headscale.SetNodeTags(nodeID, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Verify admin assignment (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying admin tag assignment on server") + + // Verify admin assignment propagated to node's self view (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned"}) + }, 30*time.Second, 500*time.Millisecond, "verifying admin tag assignment propagated to node self") + + t.Logf("Admin assigned tag, now running CLI with empty --advertise-tags") + + // Step 3: Run tailscale up with empty --advertise-tags + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--authkey=" + authKey.GetKey(), + "--advertise-tags=", + } + _, stderr, err := client.Execute(command) + t.Logf("CLI empty advertise-tags result: err=%v, stderr=%s", err, stderr) + + // Verify admin tags are preserved - empty --advertise-tags should not remove them (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, 
nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + t.Logf("After empty --advertise-tags, server tags are: %v", nodes[0].GetTags()) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 10*time.Second, 500*time.Millisecond, "admin tags should be preserved after empty --advertise-tags on server") + + // Verify admin tags are preserved in node's self view after empty --advertise-tags (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned"}) + }, 30*time.Second, 500*time.Millisecond, "admin tags should be preserved after empty --advertise-tags in node self") + + t.Logf("Test 3.5 PASS: Admin tags preserved after empty --advertise-tags") +} + +// TestTagsAuthKeyWithoutTagCLICannotReduceAdminMultiTag tests that the client CLI +// cannot reduce an admin-assigned multi-tag set. +// +// Test 3.6: Client CLI cannot reduce admin-assigned multi-tag set +// Setup: +// 1. Register with --auth-key AUTH_KEY_WITHOUT_TAG +// 2. Assign ["tag:valid-owned", "tag:second"] via headscale CLI +// 3. Run `tailscale up --advertise-tags="tag:valid-owned" --auth-key AUTH_KEY_WITHOUT_TAG` +// +// Expected: Command is no-op (or fails), tags remain ["tag:valid-owned", "tag:second"]. +func TestTagsAuthKeyWithoutTagCLICannotReduceAdminMultiTag(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-nokey-reduce"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create an auth key WITHOUT tags + authKey, err := scenario.CreatePreAuthKey(userID, true, false) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + // Initial login + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for initial registration and get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + // Step 2: Admin assigns multiple tags + err = headscale.SetNodeTags(nodeID, []string{"tag:valid-owned", "tag:second"}) + require.NoError(t, err) + + // Verify admin assignment (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned", "tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying admin tag assignment on server") + + // Verify admin assignment propagated to node's self view (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned", "tag:second"}) + }, 
30*time.Second, 500*time.Millisecond, "verifying admin tag assignment propagated to node self") + + t.Logf("Admin assigned both tags, now attempting to reduce via CLI") + + // Step 3: Attempt to reduce tags via CLI + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--authkey=" + authKey.GetKey(), + "--advertise-tags=tag:valid-owned", + } + _, stderr, err := client.Execute(command) + t.Logf("CLI reduce result: err=%v, stderr=%s", err, stderr) + + // Verify admin tags are preserved - CLI should not be able to reduce them (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + t.Logf("After CLI reduce attempt, server tags are: %v", nodes[0].GetTags()) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned", "tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "admin tags should be preserved after CLI reduce attempt on server") + + // Verify admin tags are preserved in node's self view after CLI reduce attempt (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned", "tag:second"}) + }, 30*time.Second, 500*time.Millisecond, "admin tags should be preserved after CLI reduce attempt in node self") + + t.Logf("Test 3.6 PASS: Admin tags preserved - CLI cannot reduce admin-assigned multi-tag set") +} + +// ============================================================================= +// Test Suite 1: User Login Authentication (Web Auth Flow) +// ============================================================================= + +// TestTagsUserLoginOwnedTagAtRegistration tests that a user can advertise an owned tag +// during web auth registration. +// +// Test 1.1: Advertise owned tag at registration +// Setup: Web auth login with --advertise-tags="tag:valid-owned" +// Expected: Node has ["tag:valid-owned"]. 
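+//
+// The web auth flow exercised here corresponds roughly to these manual steps
+// (sketch only; exact flags depend on the client and server versions in use):
+//
+//	tailscale up --login-server=https://<headscale> --advertise-tags=tag:valid-owned
+//	headscale nodes register --user taguser --key <registration-key>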
+func TestTagsUserLoginOwnedTagAtRegistration(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, // We'll create the node manually + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{ + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned"}), + }, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-webauth-owned"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Create a tailscale client with --advertise-tags + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned"}), + ) + require.NoError(t, err) + + // Login via web auth flow + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + // Complete the web auth by visiting the login URL + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + // Register the node via headscale CLI + err = scenario.runHeadscaleRegister(tagTestUser, body) + require.NoError(t, err) + + // Wait for client to be running + err = client.WaitForRunning(120 * time.Second) + require.NoError(t, err) + + // Verify node has the advertised tag + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + t.Logf("Node registered with tags: %v", nodes[0].GetTags()) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 30*time.Second, 500*time.Millisecond, "verifying node has advertised tag") + + t.Logf("Test 1.1 completed - web auth with owned tag succeeded") +} + +// TestTagsUserLoginNonExistentTagAtRegistration tests that advertising a non-existent tag +// during web auth registration fails. +// +// Test 1.2: Advertise non-existent tag at registration +// Setup: Web auth login with --advertise-tags="tag:nonexistent" +// Expected: Registration fails - node should not be registered OR should have no tags. 
+func TestTagsUserLoginNonExistentTagAtRegistration(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-webauth-nonexist"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Create a tailscale client with non-existent tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:nonexistent"}), + ) + require.NoError(t, err) + + // Login via web auth flow + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + // Complete the web auth by visiting the login URL + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + // Register the node via headscale CLI - this should fail due to non-existent tag + err = scenario.runHeadscaleRegister(tagTestUser, body) + + // We expect registration to fail with an error about invalid/unauthorized tags + if err != nil { + t.Logf("Test 1.2 PASS: Registration correctly rejected with error: %v", err) + assert.ErrorContains(t, err, "requested tags") + } else { + // Check the result - if registration succeeded, the node should not have the invalid tag + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err, "Should be able to list nodes") + + if len(nodes) == 0 { + t.Logf("Test 1.2 PASS: Registration rejected - no nodes registered") + } else { + // If a node was registered, it should NOT have the non-existent tag + assert.NotContains(c, nodes[0].GetTags(), "tag:nonexistent", + "Non-existent tag should not be applied to node") + t.Logf("Test 1.2: Node registered with tags: %v (non-existent tag correctly rejected)", nodes[0].GetTags()) + } + }, 10*time.Second, 500*time.Millisecond, "checking node registration result") + } +} + +// TestTagsUserLoginUnownedTagAtRegistration tests that advertising an unowned tag +// during web auth registration is rejected. +// +// Test 1.3: Advertise unowned tag at registration +// Setup: Web auth login with --advertise-tags="tag:valid-unowned" +// Expected: Registration fails - node should not be registered OR should have no tags. 
+func TestTagsUserLoginUnownedTagAtRegistration(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-webauth-unowned"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Create a tailscale client with unowned tag (tag:valid-unowned is owned by "other-user", not "taguser") + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-unowned"}), + ) + require.NoError(t, err) + + // Login via web auth flow + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + // Complete the web auth + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + // Register the node - should fail or reject the unowned tag + _ = scenario.runHeadscaleRegister(tagTestUser, body) + + // Check the result - user should NOT be able to claim an unowned tag + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err, "Should be able to list nodes") + + // Either: no nodes registered (ideal), or node registered without the unowned tag + if len(nodes) == 0 { + t.Logf("Test 1.3 PASS: Registration rejected - no nodes registered") + } else { + // If a node was registered, it should NOT have the unowned tag + assert.NotContains(c, nodes[0].GetTags(), "tag:valid-unowned", + "Unowned tag should not be applied to node (tag:valid-unowned is owned by other-user)") + t.Logf("Test 1.3: Node registered with tags: %v (unowned tag correctly rejected)", nodes[0].GetTags()) + } + }, 10*time.Second, 500*time.Millisecond, "checking node registration result") +} + +// TestTagsUserLoginAddTagViaCLIReauth tests that a user can add tags via CLI reauthentication. +// +// Test 1.4: Add tag via CLI reauthentication +// Setup: +// 1. Register with --advertise-tags="tag:valid-owned" +// 2. Run tailscale up --advertise-tags="tag:valid-owned,tag:second" +// +// Expected: Triggers full reauthentication, node has both tags. 
+func TestTagsUserLoginAddTagViaCLIReauth(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-webauth-addtag"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Step 1: Create and register with one tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned"}), + ) + require.NoError(t, err) + + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + err = scenario.runHeadscaleRegister(tagTestUser, body) + require.NoError(t, err) + + err = client.WaitForRunning(120 * time.Second) + require.NoError(t, err) + + // Verify initial tag + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("Initial tags: %v", nodes[0].GetTags()) + } + }, 30*time.Second, 500*time.Millisecond, "checking initial tags") + + // Step 2: Try to add second tag via CLI + t.Logf("Attempting to add second tag via CLI reauth") + + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--advertise-tags=tag:valid-owned,tag:second", + } + _, stderr, err := client.Execute(command) + t.Logf("CLI result: err=%v, stderr=%s", err, stderr) + + // Check final state - EventuallyWithT handles waiting for propagation + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) >= 1 { + t.Logf("Test 1.4: After CLI, tags are: %v", nodes[0].GetTags()) + + if tagsEqual(nodes[0].GetTags(), []string{"tag:valid-owned", "tag:second"}) { + t.Logf("Test 1.4 PASS: Both tags present after reauth") + } else { + t.Logf("Test 1.4: Tags are %v (may require manual reauth completion)", nodes[0].GetTags()) + } + } + }, 30*time.Second, 500*time.Millisecond, "checking tags after CLI") +} + +// TestTagsUserLoginRemoveTagViaCLIReauth tests that a user can remove tags via CLI reauthentication. +// +// Test 1.5: Remove tag via CLI reauthentication +// Setup: +// 1. Register with --advertise-tags="tag:valid-owned,tag:second" +// 2. Run tailscale up --advertise-tags="tag:valid-owned" +// +// Expected: Triggers full reauthentication, node has only ["tag:valid-owned"]. 
+func TestTagsUserLoginRemoveTagViaCLIReauth(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-webauth-rmtag"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Step 1: Create and register with two tags + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned,tag:second"}), + ) + require.NoError(t, err) + + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + err = scenario.runHeadscaleRegister(tagTestUser, body) + require.NoError(t, err) + + err = client.WaitForRunning(120 * time.Second) + require.NoError(t, err) + + // Verify initial tags + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("Initial tags: %v", nodes[0].GetTags()) + } + }, 30*time.Second, 500*time.Millisecond, "checking initial tags") + + // Step 2: Try to remove second tag via CLI + t.Logf("Attempting to remove tag via CLI reauth") + + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--advertise-tags=tag:valid-owned", + } + _, stderr, err := client.Execute(command) + t.Logf("CLI result: err=%v, stderr=%s", err, stderr) + + // Check final state - EventuallyWithT handles waiting for propagation + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) >= 1 { + t.Logf("Test 1.5: After CLI, tags are: %v", nodes[0].GetTags()) + + if tagsEqual(nodes[0].GetTags(), []string{"tag:valid-owned"}) { + t.Logf("Test 1.5 PASS: Only one tag after removal") + } + } + }, 30*time.Second, 500*time.Millisecond, "checking tags after CLI") +} + +// TestTagsUserLoginCLINoOpAfterAdminAssignment tests that CLI advertise-tags becomes +// a no-op after admin tag assignment. +// +// Test 1.6: CLI advertise-tags becomes no-op after admin tag assignment +// Setup: +// 1. Register with --advertise-tags="tag:valid-owned" +// 2. Assign ["tag:second"] via headscale CLI +// 3. Run tailscale up --advertise-tags="tag:valid-owned" +// +// Expected: Step 3 does NOT trigger reauthentication, tags remain ["tag:second"]. 
+func TestTagsUserLoginCLINoOpAfterAdminAssignment(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-webauth-adminwin"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Step 1: Register with one tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned"}), + ) + require.NoError(t, err) + + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + err = scenario.runHeadscaleRegister(tagTestUser, body) + require.NoError(t, err) + + err = client.WaitForRunning(120 * time.Second) + require.NoError(t, err) + + // Get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + t.Logf("Step 1: Node %d registered with tags: %v", nodeID, nodes[0].GetTags()) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + // Step 2: Admin assigns different tag + err = headscale.SetNodeTags(nodeID, []string{"tag:second"}) + require.NoError(t, err) + + // Verify admin assignment (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("Step 2: After admin assignment, server tags: %v", nodes[0].GetTags()) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying admin assignment on server") + + // Verify admin assignment propagated to node's self view (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:second"}) + }, 30*time.Second, 500*time.Millisecond, "verifying admin assignment propagated to node self") + + // Step 3: Try to change tags via CLI + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--advertise-tags=tag:valid-owned", + } + _, stderr, err := client.Execute(command) + t.Logf("Step 3 CLI result: err=%v, stderr=%s", err, stderr) + + // Verify admin tags are preserved - CLI advertise-tags should be a no-op after admin assignment (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + t.Logf("Step 3: After CLI, server tags are: %v", nodes[0].GetTags()) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "admin tags should be preserved - CLI advertise-tags should be no-op on server") + + // Verify admin tags are preserved in node's self view after CLI attempt (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, 
[]string{"tag:second"}) + }, 30*time.Second, 500*time.Millisecond, "admin tags should be preserved - CLI advertise-tags should be no-op in node self") + + t.Logf("Test 1.6 PASS: Admin tags preserved (CLI was no-op)") +} + +// TestTagsUserLoginCLICannotRemoveAdminTags tests that CLI cannot remove admin-assigned tags. +// +// Test 1.7: CLI cannot remove admin-assigned tags +// Setup: +// 1. Register with --advertise-tags="tag:valid-owned" +// 2. Assign ["tag:valid-owned", "tag:second"] via headscale CLI +// 3. Run tailscale up --advertise-tags="tag:valid-owned" +// +// Expected: Command is no-op, tags remain ["tag:valid-owned", "tag:second"]. +func TestTagsUserLoginCLICannotRemoveAdminTags(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-webauth-norem"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Step 1: Register with one tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned"}), + ) + require.NoError(t, err) + + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + err = scenario.runHeadscaleRegister(tagTestUser, body) + require.NoError(t, err) + + err = client.WaitForRunning(120 * time.Second) + require.NoError(t, err) + + // Get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + // Step 2: Admin assigns both tags + err = headscale.SetNodeTags(nodeID, []string{"tag:valid-owned", "tag:second"}) + require.NoError(t, err) + + // Verify admin assignment (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("After admin assignment, server tags: %v", nodes[0].GetTags()) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned", "tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying admin assignment on server") + + // Verify admin assignment propagated to node's self view (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned", "tag:second"}) + }, 30*time.Second, 500*time.Millisecond, "verifying admin assignment propagated to node self") + + // Step 3: Try to reduce tags via CLI + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--advertise-tags=tag:valid-owned", + } + _, stderr, err := client.Execute(command) + t.Logf("CLI result: err=%v, stderr=%s", err, stderr) + + // Verify admin tags are preserved - CLI should not be able to remove admin-assigned tags (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := 
headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + t.Logf("Test 1.7: After CLI, server tags are: %v", nodes[0].GetTags()) + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned", "tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "admin tags should be preserved - CLI cannot remove them on server") + + // Verify admin tags are preserved in node's self view after CLI attempt (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned", "tag:second"}) + }, 30*time.Second, 500*time.Millisecond, "admin tags should be preserved - CLI cannot remove them in node self") + + t.Logf("Test 1.7 PASS: Admin tags preserved (CLI cannot remove)") +} + +// ============================================================================= +// Test Suite 2 (continued): Additional Auth Key WITH Tags Tests +// ============================================================================= + +// TestTagsAuthKeyWithTagRequestNonExistentTag tests that requesting a non-existent tag +// with a tagged auth key results in registration failure. +// +// Test 2.7: Request non-existent tag with tagged key +// Setup: Run `tailscale up --advertise-tags="tag:nonexistent" --auth-key AUTH_KEY_WITH_TAG` +// Expected: Registration fails with error containing "requested tags". +func TestTagsAuthKeyWithTagRequestNonExistentTag(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-nonexist"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey with tag:valid-owned + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + t.Logf("Created tagged PreAuthKey with tags: %v", authKey.GetAclTags()) + + // Create a tailscale client that will try to use --advertise-tags with a NON-EXISTENT tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:nonexistent"}), + ) + require.NoError(t, err) + + // Login should fail because ANY advertise-tags is rejected for PreAuthKey registrations + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + if err != nil { + t.Logf("Test 2.7 PASS: Registration correctly rejected with error: %v", err) + assert.ErrorContains(t, err, "requested tags") + } else { + t.Logf("Test 2.7 UNEXPECTED: Registration succeeded when it should have failed") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("Node registered with tags: %v (expected rejection)", nodes[0].GetTags()) + } + }, 10*time.Second, 500*time.Millisecond, "checking node state") + + t.Fail() + } +} + +// TestTagsAuthKeyWithTagRequestUnownedTag tests that requesting an unowned tag +// with a tagged auth key 
results in registration failure. +// +// Test 2.8: Request unowned tag with tagged key +// Setup: Run `tailscale up --advertise-tags="tag:valid-unowned" --auth-key AUTH_KEY_WITH_TAG` +// Expected: Registration fails with error containing "requested tags". +func TestTagsAuthKeyWithTagRequestUnownedTag(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-unowned"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey with tag:valid-owned + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + t.Logf("Created tagged PreAuthKey with tags: %v", authKey.GetAclTags()) + + // Create a tailscale client that will try to use --advertise-tags with an UNOWNED tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-unowned"}), + ) + require.NoError(t, err) + + // Login should fail because ANY advertise-tags is rejected for PreAuthKey registrations + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + if err != nil { + t.Logf("Test 2.8 PASS: Registration correctly rejected with error: %v", err) + assert.ErrorContains(t, err, "requested tags") + } else { + t.Logf("Test 2.8 UNEXPECTED: Registration succeeded when it should have failed") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("Node registered with tags: %v (expected rejection)", nodes[0].GetTags()) + } + }, 10*time.Second, 500*time.Millisecond, "checking node state") + + t.Fail() + } +} + +// ============================================================================= +// Test Suite 3 (continued): Additional Auth Key WITHOUT Tags Tests +// ============================================================================= + +// TestTagsAuthKeyWithoutTagRequestNonExistentTag tests that requesting a non-existent tag +// with a tagless auth key results in registration failure. +// +// Test 3.7: Request non-existent tag with tagless key +// Setup: Run `tailscale up --advertise-tags="tag:nonexistent" --auth-key AUTH_KEY_WITHOUT_TAG` +// Expected: Registration fails with error containing "requested tags". 
+func TestTagsAuthKeyWithoutTagRequestNonExistentTag(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-nokey-nonexist"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create an auth key WITHOUT tags + authKey, err := scenario.CreatePreAuthKey(userID, false, false) + require.NoError(t, err) + t.Logf("Created PreAuthKey without tags") + + // Create a tailscale client that will try to request a NON-EXISTENT tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:nonexistent"}), + ) + require.NoError(t, err) + + // Login should fail because ANY advertise-tags is rejected for PreAuthKey registrations + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + if err != nil { + t.Logf("Test 3.7 PASS: Registration correctly rejected: %v", err) + assert.ErrorContains(t, err, "requested tags") + } else { + t.Logf("Test 3.7 UNEXPECTED: Registration succeeded when it should have failed") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("Node registered with tags: %v (expected rejection)", nodes[0].GetTags()) + } + }, 10*time.Second, 500*time.Millisecond, "checking node state") + + t.Fail() + } +} + +// TestTagsAuthKeyWithoutTagRequestUnownedTag tests that requesting an unowned tag +// with a tagless auth key results in registration failure. +// +// Test 3.8: Request unowned tag with tagless key +// Setup: Run `tailscale up --advertise-tags="tag:valid-unowned" --auth-key AUTH_KEY_WITHOUT_TAG` +// Expected: Registration fails with error containing "requested tags". 
+func TestTagsAuthKeyWithoutTagRequestUnownedTag(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-nokey-unowned"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create an auth key WITHOUT tags + authKey, err := scenario.CreatePreAuthKey(userID, false, false) + require.NoError(t, err) + t.Logf("Created PreAuthKey without tags") + + // Create a tailscale client that will try to request an UNOWNED tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-unowned"}), + ) + require.NoError(t, err) + + // Login should fail because ANY advertise-tags is rejected for PreAuthKey registrations + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + if err != nil { + t.Logf("Test 3.8 PASS: Registration correctly rejected: %v", err) + assert.ErrorContains(t, err, "requested tags") + } else { + t.Logf("Test 3.8 UNEXPECTED: Registration succeeded when it should have failed") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + t.Logf("Node registered with tags: %v (expected rejection)", nodes[0].GetTags()) + } + }, 10*time.Second, 500*time.Millisecond, "checking node state") + + t.Fail() + } +} + +// ============================================================================= +// Test Suite 4: Admin API (SetNodeTags) Validation Tests +// ============================================================================= + +// TestTagsAdminAPICannotSetNonExistentTag tests that the admin API rejects +// setting a tag that doesn't exist in the policy. +// +// Test 4.1: Admin cannot set non-existent tag +// Setup: Create node, then call SetNodeTags with ["tag:nonexistent"] +// Expected: SetNodeTags returns error. 
+func TestTagsAdminAPICannotSetNonExistentTag(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-admin-nonexist"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey to register a node + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for registration and get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + t.Logf("Node %d registered with tags: %v", nodeID, nodes[0].GetTags()) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for registration") + + // Try to set a non-existent tag via admin API - should fail + err = headscale.SetNodeTags(nodeID, []string{"tag:nonexistent"}) + + require.Error(t, err, "SetNodeTags should fail for non-existent tag") + t.Logf("Test 4.1 PASS: Admin API correctly rejected non-existent tag: %v", err) +} + +// TestTagsAdminAPICanSetUnownedTag tests that the admin API CAN set a tag +// that exists in policy but is owned by a different user. +// Admin has full authority over tags - ownership only matters for client requests. +// +// Test 4.2: Admin CAN set unowned tag (admin has full authority) +// Setup: Create node, then call SetNodeTags with ["tag:valid-unowned"] +// Expected: SetNodeTags succeeds (admin can assign any existing tag). 
+func TestTagsAdminAPICanSetUnownedTag(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-admin-unowned"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey to register a node + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for registration and get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + t.Logf("Node %d registered with tags: %v", nodeID, nodes[0].GetTags()) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for registration") + + // Admin sets an "unowned" tag - should SUCCEED because admin has full authority + // (tag:valid-unowned is owned by other-user, but admin can assign it) + err = headscale.SetNodeTags(nodeID, []string{"tag:valid-unowned"}) + require.NoError(t, err, "SetNodeTags should succeed for admin setting any existing tag") + + // Verify the tag was applied (server-side) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-unowned"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying unowned tag was applied on server") + + // Verify the tag was propagated to node's self view (issue #2978) + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-unowned"}) + }, 30*time.Second, 500*time.Millisecond, "verifying unowned tag propagated to node self") + + t.Logf("Test 4.2 PASS: Admin API correctly allowed setting unowned tag") +} + +// TestTagsAdminAPICannotRemoveAllTags tests that the admin API rejects +// removing all tags from a node (would orphan the node). +// +// Test 4.3: Admin cannot remove all tags +// Setup: Create tagged node, then call SetNodeTags with [] +// Expected: SetNodeTags returns error. 
+func TestTagsAdminAPICannotRemoveAllTags(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-admin-empty"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey to register a node + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for registration and get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + t.Logf("Node %d registered with tags: %v", nodeID, nodes[0].GetTags()) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for registration") + + // Try to remove all tags - should fail + err = headscale.SetNodeTags(nodeID, []string{}) + + require.Error(t, err, "SetNodeTags should fail when trying to remove all tags") + t.Logf("Test 4.3 PASS: Admin API correctly rejected removing all tags: %v", err) + + // Verify original tags are preserved + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying original tags preserved") +} + +// assertNetmapSelfHasTagsWithCollect asserts that the client's netmap self node has expected tags. +// This validates at a deeper level than status - directly from tailscale debug netmap. +func assertNetmapSelfHasTagsWithCollect(c *assert.CollectT, client TailscaleClient, expectedTags []string) { + nm, err := client.Netmap() + //nolint:testifylint // must use assert with CollectT in EventuallyWithT + assert.NoError(c, err, "failed to get client netmap") + + if nm == nil { + assert.Fail(c, "client netmap is nil") + return + } + + var actualTagsSlice []string + + if nm.SelfNode.Valid() { + for _, tag := range nm.SelfNode.Tags().All() { + actualTagsSlice = append(actualTagsSlice, tag) + } + } + + sortedActual := append([]string{}, actualTagsSlice...) + sortedExpected := append([]string{}, expectedTags...) + + sort.Strings(sortedActual) + sort.Strings(sortedExpected) + assert.Equal(c, sortedExpected, sortedActual, "Client %s netmap self tags mismatch", client.Hostname()) +} + +// TestTagsIssue2978ReproTagReplacement specifically tests issue #2978: +// When tags are changed on the server, the node's self view should update. +// This test performs multiple tag replacements and checks for immediate propagation. +// +// Issue scenario (from nblock's report): +// 1. 
Node registers via CLI auth with --advertise-tags=tag:foo +// 2. Admin changes tag to tag:bar via headscale CLI/API +// 3. Node's self view should show tag:bar (not tag:foo). +// +// This test uses web auth with --advertise-tags to match the reporter's flow. +func TestTagsIssue2978ReproTagReplacement(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + // Use CreateHeadscaleEnvWithLoginURL for web auth flow + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{ + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned"}), + }, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-issue-2978"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Create a tailscale client with --advertise-tags (matching nblock's "cli auth with --advertise-tags=tag:foo") + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned"}), + ) + require.NoError(t, err) + + // Login via web auth flow (this is "cli auth" - tailscale up triggers web auth) + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + // Complete the web auth by visiting the login URL + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + // Register the node via headscale CLI + err = scenario.runHeadscaleRegister(tagTestUser, body) + require.NoError(t, err) + + // Wait for client to be running + err = client.WaitForRunning(120 * time.Second) + require.NoError(t, err) + + // Wait for initial registration with tag:valid-owned + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for initial registration") + + // Verify client initially sees tag:valid-owned + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-owned"}) + }, 30*time.Second, 500*time.Millisecond, "client should see initial tag") + + t.Logf("Step 1: Node %d registered via web auth with --advertise-tags=tag:valid-owned, client sees it", nodeID) + + // Step 2: Admin changes tag to tag:second (FIRST CALL - this is "tag:bar" in issue terms) + // According to issue #2978, the first SetNodeTags call updates the server but + // the client's self view does NOT update until a SECOND call with the same tag. 
+ t.Log("Step 2: Calling SetNodeTags FIRST time with tag:second") + + err = headscale.SetNodeTags(nodeID, []string{"tag:second"}) + require.NoError(t, err) + + // Verify server-side update happened + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:second"}) + } + }, 10*time.Second, 500*time.Millisecond, "server should show tag:second after first call") + + t.Log("Step 2a: Server shows tag:second after first call") + + // CRITICAL BUG CHECK: According to nblock, after the first SetNodeTags call, + // the client's self view does NOT update even after waiting ~1 minute. + // We wait 10 seconds and check - if the client STILL shows the OLD tag, + // that demonstrates the bug. If the client shows the NEW tag, the bug is fixed. + t.Log("Step 2b: Waiting 10 seconds to see if client self view updates (bug: it should NOT)") + //nolint:forbidigo // intentional sleep to demonstrate bug timing - client should get update immediately, not after waiting + time.Sleep(10 * time.Second) + + // Check client status after waiting + status, err := client.Status() + require.NoError(t, err) + + var selfTagsAfterFirstCall []string + + if status.Self != nil && status.Self.Tags != nil { + for _, tag := range status.Self.Tags.All() { + selfTagsAfterFirstCall = append(selfTagsAfterFirstCall, tag) + } + } + + t.Logf("Step 2c: Client self tags after FIRST SetNodeTags + 10s wait: %v", selfTagsAfterFirstCall) + + // Also check netmap + nm, nmErr := client.Netmap() + + var netmapTagsAfterFirstCall []string + + if nmErr == nil && nm != nil && nm.SelfNode.Valid() { + for _, tag := range nm.SelfNode.Tags().All() { + netmapTagsAfterFirstCall = append(netmapTagsAfterFirstCall, tag) + } + } + + t.Logf("Step 2d: Client netmap self tags after FIRST SetNodeTags + 10s wait: %v", netmapTagsAfterFirstCall) + + // Step 3: Call SetNodeTags AGAIN with the SAME tag (SECOND CALL) + // According to nblock, this second call with the same tag triggers the update. 
+ t.Log("Step 3: Calling SetNodeTags SECOND time with SAME tag:second") + + err = headscale.SetNodeTags(nodeID, []string{"tag:second"}) + require.NoError(t, err) + + // Now the client should see the update quickly (within a few seconds) + t.Log("Step 3a: Verifying client self view updates after SECOND call") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:second"}) + }, 10*time.Second, 500*time.Millisecond, "client status.Self should update to tag:second after SECOND call") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNetmapSelfHasTagsWithCollect(c, client, []string{"tag:second"}) + }, 10*time.Second, 500*time.Millisecond, "client netmap.SelfNode should update to tag:second after SECOND call") + + t.Log("Step 3b: Client self view updated to tag:second after SECOND call") + + // Step 4: Do another tag change to verify the pattern repeats + t.Log("Step 4: Calling SetNodeTags FIRST time with tag:valid-unowned") + + err = headscale.SetNodeTags(nodeID, []string{"tag:valid-unowned"}) + require.NoError(t, err) + + // Verify server-side update + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-unowned"}) + } + }, 10*time.Second, 500*time.Millisecond, "server should show tag:valid-unowned") + + t.Log("Step 4a: Server shows tag:valid-unowned after first call") + + // Wait and check - bug means client still shows old tag + t.Log("Step 4b: Waiting 10 seconds to see if client self view updates (bug: it should NOT)") + //nolint:forbidigo // intentional sleep to demonstrate bug timing - client should get update immediately, not after waiting + time.Sleep(10 * time.Second) + + status, err = client.Status() + require.NoError(t, err) + + var selfTagsAfterSecondChange []string + + if status.Self != nil && status.Self.Tags != nil { + for _, tag := range status.Self.Tags.All() { + selfTagsAfterSecondChange = append(selfTagsAfterSecondChange, tag) + } + } + + t.Logf("Step 4c: Client self tags after FIRST SetNodeTags(tag:valid-unowned) + 10s wait: %v", selfTagsAfterSecondChange) + + // Step 5: Call SetNodeTags AGAIN with the SAME tag + t.Log("Step 5: Calling SetNodeTags SECOND time with SAME tag:valid-unowned") + + err = headscale.SetNodeTags(nodeID, []string{"tag:valid-unowned"}) + require.NoError(t, err) + + // Now the client should see the update quickly + t.Log("Step 5a: Verifying client self view updates after SECOND call") + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNodeSelfHasTagsWithCollect(c, client, []string{"tag:valid-unowned"}) + }, 10*time.Second, 500*time.Millisecond, "client status.Self should update to tag:valid-unowned after SECOND call") + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + assertNetmapSelfHasTagsWithCollect(c, client, []string{"tag:valid-unowned"}) + }, 10*time.Second, 500*time.Millisecond, "client netmap.SelfNode should update to tag:valid-unowned after SECOND call") + + t.Log("Test complete - see logs for bug reproduction details") +} + +// TestTagsAdminAPICannotSetInvalidFormat tests that the admin API rejects +// tags that don't have the correct format (must start with "tag:"). +// +// Test 4.4: Admin cannot set invalid format tag +// Setup: Create node, then call SetNodeTags with ["invalid-no-prefix"] +// Expected: SetNodeTags returns error. 
+func TestTagsAdminAPICannotSetInvalidFormat(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-admin-invalid"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + userMap, err := headscale.MapUsers() + require.NoError(t, err) + + userID := userMap[tagTestUser].GetId() + + // Create a tagged PreAuthKey to register a node + authKey, err := scenario.CreatePreAuthKeyWithTags(userID, false, false, []string{"tag:valid-owned"}) + require.NoError(t, err) + + // Create and register a tailscale client + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + ) + require.NoError(t, err) + + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for registration and get node ID + var nodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + nodeID = nodes[0].GetId() + t.Logf("Node %d registered with tags: %v", nodeID, nodes[0].GetTags()) + } + }, 30*time.Second, 500*time.Millisecond, "waiting for registration") + + // Try to set a tag without the "tag:" prefix - should fail + err = headscale.SetNodeTags(nodeID, []string{"invalid-no-prefix"}) + + require.Error(t, err, "SetNodeTags should fail for invalid tag format") + t.Logf("Test 4.4 PASS: Admin API correctly rejected invalid tag format: %v", err) + + // Verify original tags are preserved + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1) + + if len(nodes) == 1 { + assertNodeHasTagsWithCollect(c, nodes[0], []string{"tag:valid-owned"}) + } + }, 10*time.Second, 500*time.Millisecond, "verifying original tags preserved") +} + +// ============================================================================= +// Test for Issue #2979: Reauth to untag a device +// ============================================================================= + +// TestTagsUserLoginReauthWithEmptyTagsRemovesAllTags tests that reauthenticating +// with an empty tag list (--advertise-tags= --force-reauth) removes all tags +// and returns ownership to the user. +// +// Bug #2979: Reauth to untag a device keeps it tagged +// Setup: Register a node with tags via user login, then reauth with --advertise-tags= --force-reauth +// Expected: Node should have no tags and ownership should return to the user. +// +// Note: This only works with --force-reauth because without it, the Tailscale +// client doesn't trigger a full reauth to the server - it only updates local state. 
+func TestTagsUserLoginReauthWithEmptyTagsRemovesAllTags(t *testing.T) { + IntegrationSkip(t) + + t.Run("with force-reauth", func(t *testing.T) { + tc := struct { + name string + testName string + forceReauth bool + }{ + name: "with force-reauth", + testName: "with-force-reauth", + forceReauth: true, + } + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnvWithLoginURL( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-reauth-untag-2979-"+tc.testName), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Step 1: Create and register a node with tags + t.Logf("Step 1: Registering node with tags") + + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:valid-owned,tag:second"}), + ) + require.NoError(t, err) + + loginURL, err := client.LoginWithURL(headscale.GetEndpoint()) + require.NoError(t, err) + + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + err = scenario.runHeadscaleRegister(tagTestUser, body) + require.NoError(t, err) + + err = client.WaitForRunning(120 * time.Second) + require.NoError(t, err) + + // Verify initial tags + var initialNodeID uint64 + + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Expected exactly one node") + + if len(nodes) == 1 { + node := nodes[0] + initialNodeID = node.GetId() + t.Logf("Initial state - Node ID: %d, Tags: %v, User: %s", + node.GetId(), node.GetTags(), node.GetUser().GetName()) + + // Verify node has the expected tags + assertNodeHasTagsWithCollect(c, node, []string{"tag:valid-owned", "tag:second"}) + } + }, 30*time.Second, 500*time.Millisecond, "checking initial tags") + + // Step 2: Reauth with empty tags to remove all tags + t.Logf("Step 2: Reauthenticating with empty tag list to untag device (%s)", tc.name) + + if tc.forceReauth { + // Manually run tailscale up with --force-reauth and empty tags + // This will output a login URL that we need to complete + // Include --hostname to match the initial login command + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--hostname=" + client.Hostname(), + "--advertise-tags=", + "--force-reauth", + } + + stdout, stderr, _ := client.Execute(command) + t.Logf("Reauth command stderr: %s", stderr) + + // Parse the login URL from the command output + loginURL, err := util.ParseLoginURLFromCLILogin(stdout + stderr) + require.NoError(t, err, "Failed to parse login URL from reauth command") + t.Logf("Reauth login URL: %s", loginURL) + + body, err := doLoginURL(client.Hostname(), loginURL) + require.NoError(t, err) + + err = scenario.runHeadscaleRegister(tagTestUser, body) + require.NoError(t, err) + + err = client.WaitForRunning(120 * time.Second) + require.NoError(t, err) + t.Logf("Completed reauth with empty tags") + } else { + // Without force-reauth, just try tailscale up + // Include --hostname to match the initial login command + command := []string{ + "tailscale", "up", + "--login-server=" + headscale.GetEndpoint(), + "--hostname=" + client.Hostname(), + "--advertise-tags=", + } + 
stdout, stderr, err := client.Execute(command) + t.Logf("CLI reauth result: err=%v, stdout=%s, stderr=%s", err, stdout, stderr) + } + + // Step 3: Verify tags are removed and ownership is returned to user + // This is the key assertion for bug #2979 + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes(tagTestUser) + assert.NoError(c, err) + + if len(nodes) >= 1 { + node := nodes[0] + t.Logf("After reauth - Node ID: %d, Tags: %v, User: %s", + node.GetId(), node.GetTags(), node.GetUser().GetName()) + + // Assert: Node should have NO tags + assertNodeHasNoTagsWithCollect(c, node) + + // Assert: Node should be owned by the user (not tagged-devices) + assert.Equal(c, tagTestUser, node.GetUser().GetName(), + "Node ownership should return to user %s after untagging", tagTestUser) + + // Verify the node ID is still the same (not a new registration) + assert.Equal(c, initialNodeID, node.GetId(), + "Node ID should remain the same after reauth") + + if len(node.GetTags()) == 0 && node.GetUser().GetName() == tagTestUser { + t.Logf("Test #2979 (%s) PASS: Node successfully untagged and ownership returned to user", tc.name) + } else { + t.Logf("Test #2979 (%s) FAIL: Expected no tags and user=%s, got tags=%v user=%s", + tc.name, tagTestUser, node.GetTags(), node.GetUser().GetName()) + } + } + }, 60*time.Second, 1*time.Second, "verifying tags removed and ownership returned") + }) +} + +// ============================================================================= +// Test Suite 5: Auth Key WITHOUT User (Tags-Only Ownership) +// ============================================================================= + +// TestTagsAuthKeyWithoutUserInheritsTags tests that when an auth key without a user +// (tags-only) is used without --advertise-tags, the node inherits the key's tags. +// +// Test 5.1: Auth key without user, no --advertise-tags flag +// Setup: Run `tailscale up --auth-key AUTH_KEY_WITH_TAGS_NO_USER` +// Expected: Node registers with the tags from the auth key. 
+func TestTagsAuthKeyWithoutUserInheritsTags(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-no-user-inherit"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Create an auth key with tags but WITHOUT a user + authKey, err := scenario.CreatePreAuthKeyWithOptions(hsic.AuthKeyOptions{ + User: nil, + Reusable: false, + Ephemeral: false, + Tags: []string{"tag:valid-owned"}, + }) + require.NoError(t, err) + t.Logf("Created tags-only PreAuthKey with tags: %v", authKey.GetAclTags()) + + // Create a tailscale client WITHOUT --advertise-tags + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + // Note: NO WithExtraLoginArgs for --advertise-tags + ) + require.NoError(t, err) + + // Login with the tags-only auth key + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + require.NoError(t, err) + + // Wait for node to be registered and verify it has the key's tags + // Note: Tags-only nodes don't have a user, so we list all nodes + assert.EventuallyWithT(t, func(c *assert.CollectT) { + nodes, err := headscale.ListNodes() + assert.NoError(c, err) + assert.Len(c, nodes, 1, "Should have exactly 1 node") + + if len(nodes) == 1 { + node := nodes[0] + t.Logf("Node registered with tags: %v", node.GetTags()) + assertNodeHasTagsWithCollect(c, node, []string{"tag:valid-owned"}) + } + }, 30*time.Second, 500*time.Millisecond, "verifying node inherited tags from auth key") + + t.Logf("Test 5.1 PASS: Node inherited tags from tags-only auth key") +} + +// TestTagsAuthKeyWithoutUserRejectsAdvertisedTags tests that when an auth key without +// a user (tags-only) is used WITH --advertise-tags, the registration is rejected. +// PreAuthKey registrations do not allow client-requested tags. +// +// Test 5.2: Auth key without user, with --advertise-tags (should be rejected) +// Setup: Run `tailscale up --advertise-tags="tag:second" --auth-key AUTH_KEY_WITH_TAGS_NO_USER` +// Expected: Registration fails with error containing "requested tags". 
+func TestTagsAuthKeyWithoutUserRejectsAdvertisedTags(t *testing.T) { + IntegrationSkip(t) + + policy := tagsTestPolicy() + + spec := ScenarioSpec{ + NodesPerUser: 0, + Users: []string{tagTestUser}, + } + + scenario, err := NewScenario(spec) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + err = scenario.CreateHeadscaleEnv( + []tsic.Option{}, + hsic.WithACLPolicy(policy), + hsic.WithTestName("tags-authkey-no-user-reject-advertise"), + hsic.WithTLS(), + ) + requireNoErrHeadscaleEnv(t, err) + + headscale, err := scenario.Headscale() + requireNoErrGetHeadscale(t, err) + + // Create an auth key with tags but WITHOUT a user + authKey, err := scenario.CreatePreAuthKeyWithOptions(hsic.AuthKeyOptions{ + User: nil, + Reusable: false, + Ephemeral: false, + Tags: []string{"tag:valid-owned"}, + }) + require.NoError(t, err) + t.Logf("Created tags-only PreAuthKey with tags: %v", authKey.GetAclTags()) + + // Create a tailscale client WITH --advertise-tags for a DIFFERENT tag + client, err := scenario.CreateTailscaleNode( + "head", + tsic.WithNetwork(scenario.networks[scenario.testDefaultNetwork]), + tsic.WithExtraLoginArgs([]string{"--advertise-tags=tag:second"}), + ) + require.NoError(t, err) + + // Login should fail because ANY advertise-tags is rejected for PreAuthKey registrations + err = client.Login(headscale.GetEndpoint(), authKey.GetKey()) + if err != nil { + t.Logf("Test 5.2 PASS: Registration correctly rejected with error: %v", err) + assert.ErrorContains(t, err, "requested tags") + } else { + t.Logf("Test 5.2 UNEXPECTED: Registration succeeded when it should have failed") + t.Fail() + } +} diff --git a/integration/tailscale.go b/integration/tailscale.go index e7bf71b9..f397133e 100644 --- a/integration/tailscale.go +++ b/integration/tailscale.go @@ -1,18 +1,26 @@ package integration import ( + "io" "net/netip" "net/url" + "time" + "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/tsic" "tailscale.com/ipn/ipnstate" + "tailscale.com/net/netcheck" + "tailscale.com/types/key" + "tailscale.com/types/netmap" + "tailscale.com/wgengine/filter" ) // nolint type TailscaleClient interface { Hostname() string - Shutdown() error + Shutdown() (string, string, error) Version() string Execute( command []string, @@ -21,16 +29,37 @@ type TailscaleClient interface { Login(loginServer, authKey string) error LoginWithURL(loginServer string) (*url.URL, error) Logout() error + Restart() error Up() error Down() error IPs() ([]netip.Addr, error) + MustIPs() []netip.Addr + IPv4() (netip.Addr, error) + MustIPv4() netip.Addr + MustIPv6() netip.Addr FQDN() (string, error) - Status() (*ipnstate.Status, error) - WaitForNeedsLogin() error - WaitForRunning() error - WaitForPeers(expected int) error + MustFQDN() string + Status(...bool) (*ipnstate.Status, error) + MustStatus() *ipnstate.Status + Netmap() (*netmap.NetworkMap, error) + DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error) + GetNodePrivateKey() (*key.NodePrivate, error) + Netcheck() (*netcheck.Report, error) + WaitForNeedsLogin(timeout time.Duration) error + WaitForRunning(timeout time.Duration) error + WaitForPeers(expected int, timeout, retryInterval time.Duration) error Ping(hostnameOrIP string, opts ...tsic.PingOption) error Curl(url string, opts ...tsic.CurlOption) (string, error) - ID() string - PrettyPeers() (string, error) + CurlFailFast(url string) (string, error) 
+ Traceroute(netip.Addr) (util.Traceroute, error) + ContainerID() string + MustID() types.NodeID + ReadFile(path string) ([]byte, error) + PacketFilter() ([]filter.Match, error) + + // FailingPeersAsString returns a formatted-ish multi-line-string of peers in the client + // and a bool indicating if the clients online count and peer count is equal. + FailingPeersAsString() (string, bool, error) + + WriteLogs(stdout, stderr io.Writer) error } diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index 7404f6ea..fb07896b 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -1,30 +1,48 @@ package tsic import ( + "archive/tar" + "bytes" + "context" "encoding/json" "errors" "fmt" + "io" "log" "net/netip" "net/url" + "os" + "reflect" + "runtime/debug" + "slices" "strconv" "strings" "time" + "github.com/cenkalti/backoff/v5" + "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/juanfont/headscale/integration/dockertestutil" "github.com/juanfont/headscale/integration/integrationutil" "github.com/ory/dockertest/v3" "github.com/ory/dockertest/v3/docker" + "tailscale.com/ipn" "tailscale.com/ipn/ipnstate" + "tailscale.com/ipn/store/mem" + "tailscale.com/net/netcheck" + "tailscale.com/paths" + "tailscale.com/types/key" + "tailscale.com/types/netmap" + "tailscale.com/util/multierr" + "tailscale.com/wgengine/filter" ) const ( tsicHashLength = 6 - defaultPingTimeout = 300 * time.Millisecond - defaultPingCount = 10 + defaultPingTimeout = 200 * time.Millisecond + defaultPingCount = 5 dockerContextPath = "../." - headscaleCertPath = "/usr/local/share/ca-certificates/headscale.crt" + caCertRoot = "/usr/local/share/ca-certificates" dockerExecuteTimeout = 60 * time.Second ) @@ -36,6 +54,15 @@ var ( errTailscaleCannotUpWithoutAuthkey = errors.New("cannot up without authkey") errTailscaleNotConnected = errors.New("tailscale not connected") errTailscaledNotReadyForLogin = errors.New("tailscaled not ready for login") + errInvalidClientConfig = errors.New("verifiably invalid client config requested") + errInvalidTailscaleImageFormat = errors.New("invalid HEADSCALE_INTEGRATION_TAILSCALE_IMAGE format, expected repository:tag") + errTailscaleImageRequiredInCI = errors.New("HEADSCALE_INTEGRATION_TAILSCALE_IMAGE must be set in CI for HEAD version") + errContainerNotInitialized = errors.New("container not initialized") + errFQDNNotYetAvailable = errors.New("FQDN not yet available") +) + +const ( + VersionHead = "head" ) func errTailscaleStatus(hostname string, err error) error { @@ -57,54 +84,50 @@ type TailscaleInContainer struct { fqdn string // optional config - headscaleCert []byte + caCerts [][]byte headscaleHostname string + withWebsocketDERP bool withSSH bool withTags []string withEntrypoint []string withExtraHosts []string workdir string netfilter string + extraLoginArgs []string + withAcceptRoutes bool + withPackages []string // Alpine packages to install at container start + withWebserverPort int // Port for built-in HTTP server (0 = disabled) + withExtraCommands []string // Extra shell commands to run before tailscaled + + // build options, solely for HEAD + buildConfig TailscaleInContainerBuildConfig +} + +type TailscaleInContainerBuildConfig struct { + tags []string } // Option represent optional settings that can be given to a // Tailscale instance. 
type Option = func(c *TailscaleInContainer) -// WithHeadscaleTLS takes the certificate of the Headscale instance -// and adds it to the trusted surtificate of the Tailscale container. -func WithHeadscaleTLS(cert []byte) Option { +// WithCACert adds the given CA certificate to the trusted certificates of the Tailscale container. +func WithCACert(cert []byte) Option { return func(tsic *TailscaleInContainer) { - tsic.headscaleCert = cert + tsic.caCerts = append(tsic.caCerts, cert) } } -// WithOrCreateNetwork sets the Docker container network to use with -// the Tailscale instance, if the parameter is nil, a new network, -// isolating the TailscaleClient, will be created. If a network is -// passed, the Tailscale instance will join the given network. -func WithOrCreateNetwork(network *dockertest.Network) Option { +// WithNetwork sets the Docker container network to use with +// the Tailscale instance. +func WithNetwork(network *dockertest.Network) Option { return func(tsic *TailscaleInContainer) { - if network != nil { - tsic.network = network - - return - } - - network, err := dockertestutil.GetFirstOrCreateNetwork( - tsic.pool, - fmt.Sprintf("%s-network", tsic.hostname), - ) - if err != nil { - log.Fatalf("failed to create network: %s", err) - } - tsic.network = network } } // WithHeadscaleName set the name of the headscale instance, -// mostly useful in combination with TLS and WithHeadscaleTLS. +// mostly useful in combination with TLS and WithCACert. func WithHeadscaleName(hsName string) Option { return func(tsic *TailscaleInContainer) { tsic.headscaleHostname = hsName @@ -118,6 +141,14 @@ func WithTags(tags []string) Option { } } +// WithWebsocketDERP toggles a development knob to +// force enable DERP connection through the new websocket protocol. +func WithWebsocketDERP(enabled bool) Option { + return func(tsic *TailscaleInContainer) { + tsic.withWebsocketDERP = enabled + } +} + // WithSSH enables SSH for the Tailscale instance. func WithSSH() Option { return func(tsic *TailscaleInContainer) { @@ -158,11 +189,117 @@ func WithNetfilter(state string) Option { } } +// WithBuildTag adds an additional value to the `-tags=` parameter +// of the Go compiler, allowing callers to customize the Tailscale client build. +// This option is only meaningful when invoked on **HEAD** versions of the client. +// Attempting to use it with any other version is a bug in the calling code. +func WithBuildTag(tag string) Option { + return func(tsic *TailscaleInContainer) { + if tsic.version != VersionHead { + panic(errInvalidClientConfig) + } + + tsic.buildConfig.tags = append( + tsic.buildConfig.tags, tag, + ) + } +} + +// WithExtraLoginArgs adds additional arguments to the `tailscale up` command +// as part of the Login function. +func WithExtraLoginArgs(args []string) Option { + return func(tsic *TailscaleInContainer) { + tsic.extraLoginArgs = append(tsic.extraLoginArgs, args...) + } +} + +// WithAcceptRoutes tells the node to accept incoming routes. +func WithAcceptRoutes() Option { + return func(tsic *TailscaleInContainer) { + tsic.withAcceptRoutes = true + } +} + +// WithPackages specifies Alpine packages to install when the container starts. +// This requires internet access and uses `apk add`. Common packages: +// - "python3" for HTTP server +// - "curl" for HTTP client +// - "bind-tools" for dig command +// - "iptables", "ip6tables" for firewall rules +// Note: Tests using this option require internet access and cannot use +// the built-in DERP server in offline mode.
+func WithPackages(packages ...string) Option { + return func(tsic *TailscaleInContainer) { + tsic.withPackages = append(tsic.withPackages, packages...) + } +} + +// WithWebserver starts a Python HTTP server on the specified port +// alongside tailscaled. This is useful for testing subnet routing +// and ACL connectivity. Automatically adds "python3" to packages if needed. +// The server serves files from the root directory (/). +func WithWebserver(port int) Option { + return func(tsic *TailscaleInContainer) { + tsic.withWebserverPort = port + } +} + +// WithExtraCommands adds extra shell commands to run before tailscaled starts. +// Commands are run after package installation and CA certificate updates. +func WithExtraCommands(commands ...string) Option { + return func(tsic *TailscaleInContainer) { + tsic.withExtraCommands = append(tsic.withExtraCommands, commands...) + } +} + +// buildEntrypoint constructs the container entrypoint command based on +// configured options (packages, webserver, etc.). +func (t *TailscaleInContainer) buildEntrypoint() []string { + var commands []string + + // Wait for network to be ready + commands = append(commands, "while ! ip route show default >/dev/null 2>&1; do sleep 0.1; done") + + // If CA certs are configured, wait for them to be written by the Go code + // (certs are written after container start via tsic.WriteFile) + if len(t.caCerts) > 0 { + commands = append(commands, + fmt.Sprintf("while [ ! -f %s/user-0.crt ]; do sleep 0.1; done", caCertRoot)) + } + + // Install packages if requested (requires internet access) + packages := t.withPackages + if t.withWebserverPort > 0 && !slices.Contains(packages, "python3") { + packages = append(packages, "python3") + } + + if len(packages) > 0 { + commands = append(commands, "apk add --no-cache "+strings.Join(packages, " ")) + } + + // Update CA certificates + commands = append(commands, "update-ca-certificates") + + // Run extra commands if any + commands = append(commands, t.withExtraCommands...) + + // Start webserver in background if requested + // Use subshell to avoid & interfering with command joining + if t.withWebserverPort > 0 { + commands = append(commands, + fmt.Sprintf("(python3 -m http.server --bind :: %d &)", t.withWebserverPort)) + } + + // Start tailscaled (must be last as it's the foreground process) + commands = append(commands, "tailscaled --tun=tsdev --verbose=10") + + return []string{"/bin/sh", "-c", strings.Join(commands, " ; ")} +} + // New returns a new TailscaleInContainer instance. 
func New( pool *dockertest.Pool, version string, - network *dockertest.Network, opts ...Option, ) (*TailscaleInContainer, error) { hash, err := util.GenerateRandomStringDNSSafe(tsicHashLength) @@ -170,71 +307,208 @@ func New( return nil, err } - hostname := fmt.Sprintf("ts-%s-%s", strings.ReplaceAll(version, ".", "-"), hash) + // Include run ID in hostname for easier identification of which test run owns this container + runID := dockertestutil.GetIntegrationRunID() + + var hostname string + + if runID != "" { + // Use last 6 chars of run ID (the random hash part) for brevity + runIDShort := runID[len(runID)-6:] + hostname = fmt.Sprintf("ts-%s-%s-%s", runIDShort, strings.ReplaceAll(version, ".", "-"), hash) + } else { + hostname = fmt.Sprintf("ts-%s-%s", strings.ReplaceAll(version, ".", "-"), hash) + } tsic := &TailscaleInContainer{ version: version, hostname: hostname, - pool: pool, - network: network, - - withEntrypoint: []string{ - "/bin/sh", - "-c", - "/bin/sleep 3 ; update-ca-certificates ; tailscaled --tun=tsdev --verbose=10", - }, + pool: pool, } for _, opt := range opts { opt(tsic) } - tailscaleOptions := &dockertest.RunOptions{ - Name: hostname, - Networks: []*dockertest.Network{tsic.network}, - // Cmd: []string{ - // "tailscaled", "--tun=tsdev", - // }, - Entrypoint: tsic.withEntrypoint, - ExtraHosts: tsic.withExtraHosts, + // Build the entrypoint command dynamically based on options. + // Only build if no custom entrypoint was provided via WithDockerEntrypoint. + if len(tsic.withEntrypoint) == 0 { + tsic.withEntrypoint = tsic.buildEntrypoint() } - if tsic.headscaleHostname != "" { - tailscaleOptions.ExtraHosts = []string{ - "host.docker.internal:host-gateway", - fmt.Sprintf("%s:host-gateway", tsic.headscaleHostname), - } + if tsic.network == nil { + return nil, fmt.Errorf("no network set, called from: \n%s", string(debug.Stack())) } + tailscaleOptions := &dockertest.RunOptions{ + Name: hostname, + Networks: []*dockertest.Network{tsic.network}, + Entrypoint: tsic.withEntrypoint, + ExtraHosts: tsic.withExtraHosts, + Env: []string{}, + } + + if tsic.withWebsocketDERP { + if version != VersionHead { + return tsic, errInvalidClientConfig + } + + WithBuildTag("ts_debug_websockets")(tsic) + + tailscaleOptions.Env = append( + tailscaleOptions.Env, + fmt.Sprintf("TS_DEBUG_DERP_WS_CLIENT=%t", tsic.withWebsocketDERP), + ) + } + + tailscaleOptions.ExtraHosts = append(tailscaleOptions.ExtraHosts, + "host.docker.internal:host-gateway") + if tsic.workdir != "" { tailscaleOptions.WorkingDir = tsic.workdir } - // dockertest isnt very good at handling containers that has already - // been created, this is an attempt to make sure this container isnt + // dockertest isn't very good at handling containers that has already + // been created, this is an attempt to make sure this container isn't // present. err = pool.RemoveContainerByName(hostname) if err != nil { return nil, err } + // Add integration test labels if running under hi tool + dockertestutil.DockerAddIntegrationLabels(tailscaleOptions, "tailscale") + var container *dockertest.Resource + + if version != VersionHead { + // build options are not meaningful with pre-existing images, + // let's not lead anyone astray by pretending otherwise. 
+ defaultBuildConfig := TailscaleInContainerBuildConfig{} + + hasBuildConfig := !reflect.DeepEqual(defaultBuildConfig, tsic.buildConfig) + if hasBuildConfig { + return tsic, errInvalidClientConfig + } + } + switch version { - case "head": - buildOptions := &dockertest.BuildOptions{ - Dockerfile: "Dockerfile.tailscale-HEAD", - ContextDir: dockerContextPath, - BuildArgs: []docker.BuildArg{}, + case VersionHead: + // Check if a pre-built image is available via environment variable + prebuiltImage := os.Getenv("HEADSCALE_INTEGRATION_TAILSCALE_IMAGE") + + // If custom build tags are required (e.g., for websocket DERP), we cannot use + // the pre-built image as it won't have the necessary code compiled in. + hasBuildTags := len(tsic.buildConfig.tags) > 0 + if hasBuildTags && prebuiltImage != "" { + log.Printf("Ignoring pre-built image %s because custom build tags are required: %v", + prebuiltImage, tsic.buildConfig.tags) + prebuiltImage = "" } - container, err = pool.BuildAndRunWithBuildOptions( - buildOptions, - tailscaleOptions, - dockertestutil.DockerRestartPolicy, - dockertestutil.DockerAllowLocalIPv6, - dockertestutil.DockerAllowNetworkAdministration, - ) + if prebuiltImage != "" { + log.Printf("Using pre-built tailscale image: %s", prebuiltImage) + + // Parse image into repository and tag + repo, tag, ok := strings.Cut(prebuiltImage, ":") + if !ok { + return nil, errInvalidTailscaleImageFormat + } + + tailscaleOptions.Repository = repo + tailscaleOptions.Tag = tag + + container, err = pool.RunWithOptions( + tailscaleOptions, + dockertestutil.DockerRestartPolicy, + dockertestutil.DockerAllowLocalIPv6, + dockertestutil.DockerAllowNetworkAdministration, + dockertestutil.DockerMemoryLimit, + ) + if err != nil { + return nil, fmt.Errorf("could not run pre-built tailscale container %q: %w", prebuiltImage, err) + } + } else if util.IsCI() && !hasBuildTags { + // In CI, we require a pre-built image unless custom build tags are needed + return nil, errTailscaleImageRequiredInCI + } else { + buildOptions := &dockertest.BuildOptions{ + Dockerfile: "Dockerfile.tailscale-HEAD", + ContextDir: dockerContextPath, + BuildArgs: []docker.BuildArg{}, + } + + buildTags := strings.Join(tsic.buildConfig.tags, ",") + if len(buildTags) > 0 { + buildOptions.BuildArgs = append( + buildOptions.BuildArgs, + docker.BuildArg{ + Name: "BUILD_TAGS", + Value: buildTags, + }, + ) + } + + container, err = pool.BuildAndRunWithBuildOptions( + buildOptions, + tailscaleOptions, + dockertestutil.DockerRestartPolicy, + dockertestutil.DockerAllowLocalIPv6, + dockertestutil.DockerAllowNetworkAdministration, + dockertestutil.DockerMemoryLimit, + ) + if err != nil { + // Try to get more detailed build output + log.Printf("Docker build failed for %s, attempting to get detailed output...", hostname) + + buildOutput, buildErr := dockertestutil.RunDockerBuildForDiagnostics(dockerContextPath, "Dockerfile.tailscale-HEAD") + + // Show the last 100 lines of build output to avoid overwhelming the logs + lines := strings.Split(buildOutput, "\n") + + const maxLines = 100 + + startLine := 0 + if len(lines) > maxLines { + startLine = len(lines) - maxLines + } + + relevantOutput := strings.Join(lines[startLine:], "\n") + + if buildErr != nil { + // The diagnostic build also failed - this is the real error + return nil, fmt.Errorf( + "%s could not start tailscale container (version: %s): %w\n\nDocker build failed. 
Last %d lines of output:\n%s", + hostname, + version, + err, + maxLines, + relevantOutput, + ) + } + + if buildOutput != "" { + // Build succeeded on retry but container creation still failed + return nil, fmt.Errorf( + "%s could not start tailscale container (version: %s): %w\n\nDocker build succeeded on retry, but container creation failed. Last %d lines of build output:\n%s", + hostname, + version, + err, + maxLines, + relevantOutput, + ) + } + + // No output at all - diagnostic build command may have failed + return nil, fmt.Errorf( + "%s could not start tailscale container (version: %s): %w\n\nUnable to get diagnostic build output (command may have failed silently)", + hostname, + version, + err, + ) + } + } case "unstable": tailscaleOptions.Repository = "tailscale/tailscale" tailscaleOptions.Tag = version @@ -244,7 +518,11 @@ func New( dockertestutil.DockerRestartPolicy, dockertestutil.DockerAllowLocalIPv6, dockertestutil.DockerAllowNetworkAdministration, + dockertestutil.DockerMemoryLimit, ) + if err != nil { + log.Printf("Docker run failed for %s (unstable), error: %v", hostname, err) + } default: tailscaleOptions.Repository = "tailscale/tailscale" tailscaleOptions.Tag = "v" + version @@ -254,7 +532,11 @@ func New( dockertestutil.DockerRestartPolicy, dockertestutil.DockerAllowLocalIPv6, dockertestutil.DockerAllowNetworkAdministration, + dockertestutil.DockerMemoryLimit, ) + if err != nil { + log.Printf("Docker run failed for %s (version: v%s), error: %v", hostname, version, err) + } } if err != nil { @@ -265,12 +547,13 @@ func New( err, ) } + log.Printf("Created %s container\n", hostname) tsic.container = container - if tsic.hasTLS() { - err = tsic.WriteFile(headscaleCertPath, tsic.headscaleCert) + for i, cert := range tsic.caCerts { + err = tsic.WriteFile(fmt.Sprintf("%s/user-%d.crt", caCertRoot, i), cert) if err != nil { return nil, fmt.Errorf("failed to write TLS certificate to container: %w", err) } @@ -279,13 +562,9 @@ func New( return tsic, nil } -func (t *TailscaleInContainer) hasTLS() bool { - return len(t.headscaleCert) != 0 -} - // Shutdown stops and cleans up the Tailscale container. -func (t *TailscaleInContainer) Shutdown() error { - err := t.SaveLog("/tmp/control") +func (t *TailscaleInContainer) Shutdown() (string, string, error) { + stdoutPath, stderrPath, err := t.SaveLog("/tmp/control") if err != nil { log.Printf( "Failed to save log from %s: %s", @@ -294,7 +573,7 @@ func (t *TailscaleInContainer) Shutdown() error { ) } - return t.pool.Purge(t.container) + return stdoutPath, stderrPath, t.pool.Purge(t.container) } // Hostname returns the hostname of the Tailscale instance. @@ -309,7 +588,7 @@ func (t *TailscaleInContainer) Version() string { // ID returns the Docker container ID of the TailscaleInContainer // instance. -func (t *TailscaleInContainer) ID() string { +func (t *TailscaleInContainer) ContainerID() string { return t.container.Container.ID } @@ -328,7 +607,6 @@ func (t *TailscaleInContainer) Execute( if err != nil { // log.Printf("command issued: %s", strings.Join(command, " ")) // log.Printf("command stderr: %s\n", stderr) - if stdout != "" { log.Printf("command stdout: %s\n", stdout) } @@ -343,18 +621,32 @@ func (t *TailscaleInContainer) Execute( return stdout, stderr, nil } -// Up runs the login routine on the given Tailscale instance. -// This login mechanism uses the authorised key for authentication. -func (t *TailscaleInContainer) Login( +// Retrieve container logs. 
+func (t *TailscaleInContainer) Logs(stdout, stderr io.Writer) error { + return dockertestutil.WriteLog( + t.pool, + t.container, + stdout, stderr, + ) +} + +func (t *TailscaleInContainer) buildLoginCommand( loginServer, authKey string, -) error { +) []string { command := []string{ "tailscale", "up", "--login-server=" + loginServer, - "--authkey=" + authKey, "--hostname=" + t.hostname, - "--accept-routes=false", + fmt.Sprintf("--accept-routes=%t", t.withAcceptRoutes), + } + + if authKey != "" { + command = append(command, "--authkey="+authKey) + } + + if t.extraLoginArgs != nil { + command = append(command, t.extraLoginArgs...) } if t.withSSH { @@ -367,10 +659,20 @@ func (t *TailscaleInContainer) Login( if len(t.withTags) > 0 { command = append(command, - fmt.Sprintf(`--advertise-tags=%s`, strings.Join(t.withTags, ",")), + "--advertise-tags="+strings.Join(t.withTags, ","), ) } + return command +} + +// Login runs the login routine on the given Tailscale instance. +// This login mechanism uses the authorised key for authentication. +func (t *TailscaleInContainer) Login( + loginServer, authKey string, +) error { + command := t.buildLoginCommand(loginServer, authKey) + if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil { return fmt.Errorf( "%s failed to join tailscale client (%s): %w", @@ -387,29 +689,22 @@ func (t *TailscaleInContainer) Login( // This login mechanism uses web + command line flow for authentication. func (t *TailscaleInContainer) LoginWithURL( loginServer string, -) (*url.URL, error) { - command := []string{ - "tailscale", - "up", - "--login-server=" + loginServer, - "--hostname=" + t.hostname, - "--accept-routes=false", - } +) (loginURL *url.URL, err error) { + command := t.buildLoginCommand(loginServer, "") - _, stderr, err := t.Execute(command) + stdout, stderr, err := t.Execute(command) if errors.Is(err, errTailscaleNotLoggedIn) { return nil, errTailscaleCannotUpWithoutAuthkey } - urlStr := strings.ReplaceAll(stderr, "\nTo authenticate, visit:\n\n\t", "") - urlStr = strings.TrimSpace(urlStr) + defer func() { + if err != nil { + log.Printf("join command: %q", strings.Join(command, " ")) + } + }() - // parse URL - loginURL, err := url.Parse(urlStr) + loginURL, err = util.ParseLoginURLFromCLILogin(stdout + stderr) if err != nil { - log.Printf("Could not parse login URL: %s", err) - log.Printf("Original join command result: %s", stderr) - return nil, err } @@ -418,11 +713,49 @@ func (t *TailscaleInContainer) LoginWithURL( // Logout runs the logout routine on the given Tailscale instance. func (t *TailscaleInContainer) Logout() error { - _, _, err := t.Execute([]string{"tailscale", "logout"}) + stdout, stderr, err := t.Execute([]string{"tailscale", "logout"}) if err != nil { return err } + stdout, stderr, _ = t.Execute([]string{"tailscale", "status"}) + if !strings.Contains(stdout+stderr, "Logged out.") { + return fmt.Errorf("failed to logout, stdout: %s, stderr: %s", stdout, stderr) + } + + return t.waitForBackendState("NeedsLogin", integrationutil.PeerSyncTimeout()) +} + +// Restart restarts the Tailscale container using Docker API. +// This simulates a container restart (e.g., docker restart or Kubernetes pod restart). +// The container's entrypoint will re-execute, which typically includes running +// "tailscale up" with any auth keys stored in environment variables. 
+func (t *TailscaleInContainer) Restart() error { + if t.container == nil { + return errContainerNotInitialized + } + + // Use Docker API to restart the container + err := t.pool.Client.RestartContainer(t.container.Container.ID, 30) + if err != nil { + return fmt.Errorf("failed to restart container %s: %w", t.hostname, err) + } + + // Wait for the container to be back up and tailscaled to be ready + // We use exponential backoff to poll until we can successfully execute a command + _, err = backoff.Retry(context.Background(), func() (struct{}, error) { + // Try to execute a simple command to verify the container is responsive + _, _, err := t.Execute([]string{"tailscale", "version"}, dockertestutil.ExecuteCommandTimeout(5*time.Second)) + if err != nil { + return struct{}{}, fmt.Errorf("container not ready: %w", err) + } + + return struct{}{}, nil + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(30*time.Second)) + if err != nil { + return fmt.Errorf("timeout waiting for container %s to restart and become ready: %w", t.hostname, err) + } + return nil } @@ -466,39 +799,97 @@ func (t *TailscaleInContainer) Down() error { // IPs returns the netip.Addr of the Tailscale instance. func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) { - if t.ips != nil && len(t.ips) != 0 { + if len(t.ips) != 0 { return t.ips, nil } - ips := make([]netip.Addr, 0) - - command := []string{ - "tailscale", - "ip", - } - - result, _, err := t.Execute(command) - if err != nil { - return []netip.Addr{}, fmt.Errorf("%s failed to join tailscale client: %w", t.hostname, err) - } - - for _, address := range strings.Split(result, "\n") { - address = strings.TrimSuffix(address, "\n") - if len(address) < 1 { - continue + // Retry with exponential backoff to handle eventual consistency + ips, err := backoff.Retry(context.Background(), func() ([]netip.Addr, error) { + command := []string{ + "tailscale", + "ip", } - ip, err := netip.ParseAddr(address) + + result, _, err := t.Execute(command) if err != nil { - return nil, err + return nil, fmt.Errorf("%s failed to get IPs: %w", t.hostname, err) } - ips = append(ips, ip) + + ips := make([]netip.Addr, 0) + + for address := range strings.SplitSeq(result, "\n") { + address = strings.TrimSuffix(address, "\n") + if len(address) < 1 { + continue + } + + ip, err := netip.ParseAddr(address) + if err != nil { + return nil, fmt.Errorf("failed to parse IP %s: %w", address, err) + } + + ips = append(ips, ip) + } + + if len(ips) == 0 { + return nil, fmt.Errorf("no IPs returned yet for %s", t.hostname) + } + + return ips, nil + }, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second)) + if err != nil { + return nil, fmt.Errorf("failed to get IPs for %s after retries: %w", t.hostname, err) } return ips, nil } +func (t *TailscaleInContainer) MustIPs() []netip.Addr { + ips, err := t.IPs() + if err != nil { + panic(err) + } + + return ips +} + +// IPv4 returns the IPv4 address of the Tailscale instance. 
+func (t *TailscaleInContainer) IPv4() (netip.Addr, error) { + ips, err := t.IPs() + if err != nil { + return netip.Addr{}, err + } + + for _, ip := range ips { + if ip.Is4() { + return ip, nil + } + } + + return netip.Addr{}, fmt.Errorf("no IPv4 address found for %s", t.hostname) +} + +func (t *TailscaleInContainer) MustIPv4() netip.Addr { + ip, err := t.IPv4() + if err != nil { + panic(err) + } + + return ip +} + +func (t *TailscaleInContainer) MustIPv6() netip.Addr { + for _, ip := range t.MustIPs() { + if ip.Is6() { + return ip + } + } + + panic("no ipv6 found") +} + // Status returns the ipnstate.Status of the Tailscale instance. -func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) { +func (t *TailscaleInContainer) Status(save ...bool) (*ipnstate.Status, error) { command := []string{ "tailscale", "status", @@ -511,37 +902,274 @@ func (t *TailscaleInContainer) Status() (*ipnstate.Status, error) { } var status ipnstate.Status + err = json.Unmarshal([]byte(result), &status) if err != nil { return nil, fmt.Errorf("failed to unmarshal tailscale status: %w", err) } + err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_status.json", t.hostname), []byte(result), 0o755) + if err != nil { + return nil, fmt.Errorf("saving status to /tmp/control: %w", err) + } + return &status, err } +// MustStatus returns the ipnstate.Status of the Tailscale instance. +func (t *TailscaleInContainer) MustStatus() *ipnstate.Status { + status, err := t.Status() + if err != nil { + panic(err) + } + + return status +} + +// MustID returns the ID of the Tailscale instance. +func (t *TailscaleInContainer) MustID() types.NodeID { + status, err := t.Status() + if err != nil { + panic(err) + } + + id, err := strconv.ParseUint(string(status.Self.ID), 10, 64) + if err != nil { + panic(fmt.Sprintf("failed to parse ID: %s", err)) + } + + return types.NodeID(id) +} + +// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance. +// Only works with Tailscale 1.56 and newer. +// Panics if the version is lower than the minimum. +func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { + if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { + panic("tsic.Netmap() called with unsupported version: " + t.version) + } + + command := []string{ + "tailscale", + "debug", + "netmap", + } + + result, stderr, err := t.Execute(command) + if err != nil { + fmt.Printf("stderr: %s\n", stderr) + return nil, fmt.Errorf("failed to execute tailscale debug netmap command: %w", err) + } + + var nm netmap.NetworkMap + + err = json.Unmarshal([]byte(result), &nm) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal tailscale netmap: %w", err) + } + + err = os.WriteFile(fmt.Sprintf("/tmp/control/%s_netmap.json", t.hostname), []byte(result), 0o755) + if err != nil { + return nil, fmt.Errorf("saving netmap to /tmp/control: %w", err) + } + + return &nm, err +} + +// Netmap returns the current Netmap (netmap.NetworkMap) of the Tailscale instance. +// This implementation is based on getting the netmap from `tailscale debug watch-ipn` +// as there seems to be some weirdness omitting endpoint and DERP info if we use +// Patch updates. +// This implementation works on all supported versions. +// func (t *TailscaleInContainer) Netmap() (*netmap.NetworkMap, error) { +// // watch-ipn will only give an update if something is happening, +// // since we send keep alives, the worst case for this should be +// // 1 minute, but set a slightly more conservative time.
+// ctx, _ := context.WithTimeout(context.Background(), 3*time.Minute) + +// notify, err := t.watchIPN(ctx) +// if err != nil { +// return nil, err +// } + +// if notify.NetMap == nil { +// return nil, fmt.Errorf("no netmap present in ipn.Notify") +// } + +// return notify.NetMap, nil +// } + +// watchIPN watches `tailscale debug watch-ipn` for a ipn.Notify object until +// it gets one that has a netmap.NetworkMap. +func (t *TailscaleInContainer) watchIPN(ctx context.Context) (*ipn.Notify, error) { + pr, pw := io.Pipe() + + type result struct { + notify *ipn.Notify + err error + } + + resultChan := make(chan result, 1) + + // There is no good way to kill the goroutine with watch-ipn, + // so make a nice func to send a kill command to issue when + // we are done. + killWatcher := func() { + stdout, stderr, err := t.Execute([]string{ + "/bin/sh", "-c", `kill $(ps aux | grep "tailscale debug watch-ipn" | grep -v grep | awk '{print $1}') || true`, + }) + if err != nil { + log.Printf("failed to kill tailscale watcher, \nstdout: %s\nstderr: %s\nerr: %s", stdout, stderr, err) + } + } + + go func() { + _, _ = t.container.Exec( + // Prior to 1.56, the initial "Connected." message was printed to stdout, + // filter out with grep. + []string{"/bin/sh", "-c", `tailscale debug watch-ipn | grep -v "Connected."`}, + dockertest.ExecOptions{ + // The interesting output is sent to stdout, so ignore stderr. + StdOut: pw, + // StdErr: pw, + }, + ) + }() + + go func() { + decoder := json.NewDecoder(pr) + for decoder.More() { + var notify ipn.Notify + + err := decoder.Decode(¬ify) + if err != nil { + resultChan <- result{nil, fmt.Errorf("parse notify: %w", err)} + } + + if notify.NetMap != nil { + resultChan <- result{¬ify, nil} + } + } + }() + + select { + case <-ctx.Done(): + killWatcher() + + return nil, ctx.Err() + + case result := <-resultChan: + killWatcher() + + if result.err != nil { + return nil, result.err + } + + return result.notify, nil + } +} + +func (t *TailscaleInContainer) DebugDERPRegion(region string) (*ipnstate.DebugDERPRegionReport, error) { + if !util.TailscaleVersionNewerOrEqual("1.34", t.version) { + panic("tsic.DebugDERPRegion() called with unsupported version: " + t.version) + } + + command := []string{ + "tailscale", + "debug", + "derp", + region, + } + + result, stderr, err := t.Execute(command) + if err != nil { + fmt.Printf("stderr: %s\n", stderr) // nolint + + return nil, fmt.Errorf("failed to execute tailscale debug derp command: %w", err) + } + + var report ipnstate.DebugDERPRegionReport + + err = json.Unmarshal([]byte(result), &report) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal tailscale derp region report: %w", err) + } + + return &report, err +} + +// Netcheck returns the current Netcheck Report (netcheck.Report) of the Tailscale instance. +func (t *TailscaleInContainer) Netcheck() (*netcheck.Report, error) { + command := []string{ + "tailscale", + "netcheck", + "--format=json", + } + + result, stderr, err := t.Execute(command) + if err != nil { + fmt.Printf("stderr: %s\n", stderr) + return nil, fmt.Errorf("failed to execute tailscale debug netcheck command: %w", err) + } + + var nm netcheck.Report + + err = json.Unmarshal([]byte(result), &nm) + if err != nil { + return nil, fmt.Errorf("failed to unmarshal tailscale netcheck: %w", err) + } + + return &nm, err +} + // FQDN returns the FQDN as a string of the Tailscale instance. 
 func (t *TailscaleInContainer) FQDN() (string, error) {
 	if t.fqdn != "" {
 		return t.fqdn, nil
 	}
 
-	status, err := t.Status()
+	// Retry with exponential backoff to handle eventual consistency
+	fqdn, err := backoff.Retry(context.Background(), func() (string, error) {
+		status, err := t.Status()
+		if err != nil {
+			return "", fmt.Errorf("failed to get status: %w", err)
+		}
+
+		if status.Self.DNSName == "" {
+			return "", errFQDNNotYetAvailable
+		}
+
+		return status.Self.DNSName, nil
+	}, backoff.WithBackOff(backoff.NewExponentialBackOff()), backoff.WithMaxElapsedTime(10*time.Second))
 	if err != nil {
-		return "", fmt.Errorf("failed to get FQDN: %w", err)
+		return "", fmt.Errorf("failed to get FQDN for %s after retries: %w", t.hostname, err)
 	}
 
-	return status.Self.DNSName, nil
+	return fqdn, nil
 }
 
-// PrettyPeers returns a formatted-ish table of peers in the client.
-func (t *TailscaleInContainer) PrettyPeers() (string, error) {
-	status, err := t.Status()
+// MustFQDN returns the FQDN as a string of the Tailscale instance, panicking on error.
+func (t *TailscaleInContainer) MustFQDN() string {
+	fqdn, err := t.FQDN()
 	if err != nil {
-		return "", fmt.Errorf("failed to get FQDN: %w", err)
+		panic(err)
 	}
 
-	str := fmt.Sprintf("Peers of %s\n", t.hostname)
-	str += "Hostname\tOnline\tLastSeen\n"
+	return fqdn
+}
+
+// FailingPeersAsString returns a formatted-ish multi-line string of peers in the client
+// and a bool indicating whether the client's online count and peer count are equal.
+func (t *TailscaleInContainer) FailingPeersAsString() (string, bool, error) {
+	status, err := t.Status()
+	if err != nil {
+		return "", false, fmt.Errorf("failed to get status: %w", err)
+	}
+
+	var b strings.Builder
+
+	fmt.Fprintf(&b, "Peers of %s\n", t.hostname)
+	fmt.Fprint(&b, "Hostname\tOnline\tLastSeen\n")
 
 	peerCount := len(status.Peers())
 	onlineCount := 0
@@ -553,87 +1181,129 @@ func (t *TailscaleInContainer) PrettyPeers() (string, error) {
 			onlineCount++
 		}
 
-		str += fmt.Sprintf("%s\t%t\t%s\n", peer.HostName, peer.Online, peer.LastSeen)
+		fmt.Fprintf(&b, "%s\t%t\t%s\n", peer.HostName, peer.Online, peer.LastSeen)
 	}
 
-	str += fmt.Sprintf("Peer Count: %d, Online Count: %d\n\n", peerCount, onlineCount)
+	fmt.Fprintf(&b, "Peer Count: %d, Online Count: %d\n\n", peerCount, onlineCount)
 
-	return str, nil
+	return b.String(), peerCount == onlineCount, nil
 }
 
 // WaitForNeedsLogin blocks until the Tailscale (tailscaled) instance has
 // started and needs to be logged into.
-func (t *TailscaleInContainer) WaitForNeedsLogin() error {
-	return t.pool.Retry(func() error {
-		status, err := t.Status()
-		if err != nil {
-			return errTailscaleStatus(t.hostname, err)
-		}
-
-		// ipnstate.Status.CurrentTailnet was added in Tailscale 1.22.0
-		// https://github.com/tailscale/tailscale/pull/3865
-		//
-		// Before that, we can check the BackendState to see if the
-		// tailscaled daemon is connected to the control system.
-		if status.BackendState == "NeedsLogin" {
-			return nil
-		}
-
-		return errTailscaledNotReadyForLogin
-	})
+func (t *TailscaleInContainer) WaitForNeedsLogin(timeout time.Duration) error {
+	return t.waitForBackendState("NeedsLogin", timeout)
 }
 
 // WaitForRunning blocks until the Tailscale (tailscaled) instance is logged in
 // and ready to be used.
-func (t *TailscaleInContainer) WaitForRunning() error { - return t.pool.Retry(func() error { - status, err := t.Status() - if err != nil { - return errTailscaleStatus(t.hostname, err) - } +func (t *TailscaleInContainer) WaitForRunning(timeout time.Duration) error { + return t.waitForBackendState("Running", timeout) +} - // ipnstate.Status.CurrentTailnet was added in Tailscale 1.22.0 - // https://github.com/tailscale/tailscale/pull/3865 - // - // Before that, we can check the BackendState to see if the - // tailscaled daemon is connected to the control system. - if status.BackendState == "Running" { - return nil - } +func (t *TailscaleInContainer) waitForBackendState(state string, timeout time.Duration) error { + ticker := time.NewTicker(integrationutil.PeerSyncRetryInterval()) + defer ticker.Stop() - return errTailscaleNotConnected - }) + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + for { + select { + case <-ctx.Done(): + return fmt.Errorf("timeout waiting for backend state %s on %s after %v", state, t.hostname, timeout) + case <-ticker.C: + status, err := t.Status() + if err != nil { + continue // Keep retrying on status errors + } + + // ipnstate.Status.CurrentTailnet was added in Tailscale 1.22.0 + // https://github.com/tailscale/tailscale/pull/3865 + // + // Before that, we can check the BackendState to see if the + // tailscaled daemon is connected to the control system. + if status.BackendState == state { + return nil + } + } + } } // WaitForPeers blocks until N number of peers is present in the // Peer list of the Tailscale instance and is reporting Online. -func (t *TailscaleInContainer) WaitForPeers(expected int) error { - return t.pool.Retry(func() error { - status, err := t.Status() - if err != nil { - return errTailscaleStatus(t.hostname, err) - } +// +// The method verifies that each peer: +// - Has the expected peer count +// - All peers are Online +// - All peers have a hostname +// - All peers have a DERP relay assigned +// +// Uses multierr to collect all validation errors. +func (t *TailscaleInContainer) WaitForPeers(expected int, timeout, retryInterval time.Duration) error { + ticker := time.NewTicker(retryInterval) + defer ticker.Stop() - if peers := status.Peers(); len(peers) != expected { - return fmt.Errorf( - "%s err: %w expected %d, got %d", - t.hostname, - errTailscaleWrongPeerCount, - expected, - len(peers), - ) - } else { - for _, peerKey := range peers { + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + var lastErrs []error + + for { + select { + case <-ctx.Done(): + if len(lastErrs) > 0 { + return fmt.Errorf("timeout waiting for %d peers on %s after %v, errors: %w", expected, t.hostname, timeout, multierr.New(lastErrs...)) + } + + return fmt.Errorf("timeout waiting for %d peers on %s after %v", expected, t.hostname, timeout) + case <-ticker.C: + status, err := t.Status() + if err != nil { + lastErrs = []error{errTailscaleStatus(t.hostname, err)} + continue // Keep retrying on status errors + } + + if peers := status.Peers(); len(peers) != expected { + lastErrs = []error{fmt.Errorf( + "%s err: %w expected %d, got %d", + t.hostname, + errTailscaleWrongPeerCount, + expected, + len(peers), + )} + + continue + } + + // Verify that the peers of a given node is Online + // has a hostname and a DERP relay. 
+ var peerErrors []error + + for _, peerKey := range status.Peers() { peer := status.Peer[peerKey] if !peer.Online { - return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName) + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)) + } + + if peer.HostName == "" { + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a Hostname", t.hostname, peer.HostName)) + } + + if peer.Relay == "" { + peerErrors = append(peerErrors, fmt.Errorf("[%s] peer count correct, but %s does not have a DERP", t.hostname, peer.HostName)) } } - } - return nil - }) + if len(peerErrors) > 0 { + lastErrs = peerErrors + continue + } + + return nil + } + } } type ( @@ -689,7 +1359,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err "tailscale", "ping", fmt.Sprintf("--timeout=%s", args.timeout), fmt.Sprintf("--c=%d", args.count), - fmt.Sprintf("--until-direct=%s", strconv.FormatBool(args.direct)), + "--until-direct=" + strconv.FormatBool(args.direct), } command = append(command, hostnameOrIP) @@ -701,6 +1371,7 @@ func (t *TailscaleInContainer) Ping(hostnameOrIP string, opts ...PingOption) err ), ) if err != nil { + log.Printf("command: %v", command) log.Printf( "failed to run ping command from %s to %s, err: %s", t.Hostname(), @@ -768,11 +1439,11 @@ func WithCurlRetry(ret int) CurlOption { } const ( - defaultConnectionTimeout = 3 * time.Second - defaultMaxTime = 10 * time.Second - defaultRetry = 5 - defaultRetryDelay = 0 * time.Second - defaultRetryMaxTime = 50 * time.Second + defaultConnectionTimeout = 1 * time.Second + defaultMaxTime = 3 * time.Second + defaultRetry = 3 + defaultRetryDelay = 200 * time.Millisecond + defaultRetryMaxTime = 5 * time.Second ) // Curl executes the Tailscale curl command and curls a hostname @@ -793,15 +1464,16 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err command := []string{ "curl", "--silent", - "--connect-timeout", fmt.Sprintf("%d", int(args.connectionTimeout.Seconds())), - "--max-time", fmt.Sprintf("%d", int(args.maxTime.Seconds())), - "--retry", fmt.Sprintf("%d", args.retry), - "--retry-delay", fmt.Sprintf("%d", int(args.retryDelay.Seconds())), - "--retry-max-time", fmt.Sprintf("%d", int(args.retryMaxTime.Seconds())), + "--connect-timeout", strconv.Itoa(int(args.connectionTimeout.Seconds())), + "--max-time", strconv.Itoa(int(args.maxTime.Seconds())), + "--retry", strconv.Itoa(args.retry), + "--retry-delay", strconv.Itoa(int(args.retryDelay.Seconds())), + "--retry-max-time", strconv.Itoa(int(args.retryMaxTime.Seconds())), url, } var result string + result, _, err := t.Execute(command) if err != nil { log.Printf( @@ -817,6 +1489,38 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err return result, nil } +// CurlFailFast executes the Tailscale curl command with aggressive timeouts +// optimized for testing expected connection failures. It uses minimal timeouts +// to quickly detect blocked connections without waiting for multiple retries. 
+func (t *TailscaleInContainer) CurlFailFast(url string) (string, error) { + // Use aggressive timeouts for fast failure detection + return t.Curl(url, + WithCurlConnectionTimeout(1*time.Second), + WithCurlMaxTime(2*time.Second), + WithCurlRetry(1)) +} + +func (t *TailscaleInContainer) Traceroute(ip netip.Addr) (util.Traceroute, error) { + command := []string{ + "traceroute", + ip.String(), + } + + var result util.Traceroute + + stdout, stderr, err := t.Execute(command) + if err != nil { + return result, err + } + + result, err = util.ParseTraceroute(stdout + stderr) + if err != nil { + return result, err + } + + return result, nil +} + // WriteFile save file inside the Tailscale container. func (t *TailscaleInContainer) WriteFile(path string, data []byte) error { return integrationutil.WriteFileToContainer(t.pool, t.container, path, data) @@ -824,6 +1528,102 @@ func (t *TailscaleInContainer) WriteFile(path string, data []byte) error { // SaveLog saves the current stdout log of the container to a path // on the host system. -func (t *TailscaleInContainer) SaveLog(path string) error { +func (t *TailscaleInContainer) SaveLog(path string) (string, string, error) { + // TODO(kradalby): Assert if tailscale logs contains panics. + // NOTE(enoperm): `t.WriteLog | countMatchingLines` + // is probably most of what is for that, + // but I'd rather not change the behaviour here, + // as it may affect all the other tests + // I have not otherwise touched. return dockertestutil.SaveLog(t.pool, t.container, path) } + +// WriteLogs writes the current stdout/stderr log of the container to +// the given io.Writers. +func (t *TailscaleInContainer) WriteLogs(stdout, stderr io.Writer) error { + return dockertestutil.WriteLog(t.pool, t.container, stdout, stderr) +} + +// ReadFile reads a file from the Tailscale container. +// It returns the content of the file as a byte slice. 
+func (t *TailscaleInContainer) ReadFile(path string) ([]byte, error) {
+	tarBytes, err := integrationutil.FetchPathFromContainer(t.pool, t.container, path)
+	if err != nil {
+		return nil, fmt.Errorf("reading file from container: %w", err)
+	}
+
+	var out bytes.Buffer
+
+	tr := tar.NewReader(bytes.NewReader(tarBytes))
+	for {
+		hdr, err := tr.Next()
+		if err == io.EOF {
+			break // End of archive
+		}
+
+		if err != nil {
+			return nil, fmt.Errorf("reading tar header: %w", err)
+		}
+
+		if !strings.Contains(path, hdr.Name) {
+			return nil, fmt.Errorf("file not found in tar archive, looking for: %s, header was: %s", path, hdr.Name)
+		}
+
+		if _, err := io.Copy(&out, tr); err != nil {
+			return nil, fmt.Errorf("copying file to buffer: %w", err)
+		}
+
+		// Only support reading the first file
+		break
+	}
+
+	if out.Len() == 0 {
+		return nil, errors.New("file is empty")
+	}
+
+	return out.Bytes(), nil
+}
+
+// GetNodePrivateKey reads the node's private key from the tailscaled state file
+// inside the container.
+func (t *TailscaleInContainer) GetNodePrivateKey() (*key.NodePrivate, error) {
+	state, err := t.ReadFile(paths.DefaultTailscaledStateFile())
+	if err != nil {
+		return nil, fmt.Errorf("failed to read state file: %w", err)
+	}
+
+	store := &mem.Store{}
+	if err = store.LoadFromJSON(state); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal state file: %w", err)
+	}
+
+	currentProfileKey, err := store.ReadState(ipn.CurrentProfileStateKey)
+	if err != nil {
+		return nil, fmt.Errorf("failed to read current profile state key: %w", err)
+	}
+
+	currentProfile, err := store.ReadState(ipn.StateKey(currentProfileKey))
+	if err != nil {
+		return nil, fmt.Errorf("failed to read current profile state: %w", err)
+	}
+
+	p := &ipn.Prefs{}
+	if err = json.Unmarshal(currentProfile, &p); err != nil {
+		return nil, fmt.Errorf("failed to unmarshal current profile state: %w", err)
+	}
+
+	return &p.Persist.PrivateNodeKey, nil
+}
+
+// PacketFilter returns the current packet filter rules from the client's network map.
+// This is useful for verifying that policy changes have propagated to the client.
+func (t *TailscaleInContainer) PacketFilter() ([]filter.Match, error) { + if !util.TailscaleVersionNewerOrEqual("1.56", t.version) { + return nil, fmt.Errorf("tsic.PacketFilter() requires Tailscale 1.56+, current version: %s", t.version) + } + + nm, err := t.Netmap() + if err != nil { + return nil, fmt.Errorf("failed to get netmap: %w", err) + } + + return nm.PacketFilter, nil +} diff --git a/integration/utils.go b/integration/utils.go deleted file mode 100644 index e17e18a2..00000000 --- a/integration/utils.go +++ /dev/null @@ -1,229 +0,0 @@ -package integration - -import ( - "os" - "strings" - "testing" - "time" - - "github.com/juanfont/headscale/integration/tsic" -) - -const ( - derpPingTimeout = 2 * time.Second - derpPingCount = 10 -) - -func assertNoErr(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "unexpected error: %s", err) -} - -func assertNoErrf(t *testing.T, msg string, err error) { - t.Helper() - if err != nil { - t.Fatalf(msg, err) - } -} - -func assertNotNil(t *testing.T, thing interface{}) { - t.Helper() - if thing == nil { - t.Fatal("got unexpected nil") - } -} - -func assertNoErrHeadscaleEnv(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to create headscale environment: %s", err) -} - -func assertNoErrGetHeadscale(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to get headscale: %s", err) -} - -func assertNoErrListClients(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to list clients: %s", err) -} - -func assertNoErrListClientIPs(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to get client IPs: %s", err) -} - -func assertNoErrSync(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to have all clients sync up: %s", err) -} - -func assertNoErrListFQDN(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to list FQDNs: %s", err) -} - -func assertNoErrLogout(t *testing.T, err error) { - t.Helper() - assertNoErrf(t, "failed to log out tailscale nodes: %s", err) -} - -func assertContains(t *testing.T, str, subStr string) { - t.Helper() - if !strings.Contains(str, subStr) { - t.Fatalf("%#v does not contain %#v", str, subStr) - } -} - -func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { - t.Helper() - success := 0 - - for _, client := range clients { - for _, addr := range addrs { - err := client.Ping(addr, opts...) 
- if err != nil { - t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err) - } else { - success++ - } - } - } - - return success -} - -func pingDerpAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { - t.Helper() - success := 0 - - for _, client := range clients { - for _, addr := range addrs { - if isSelfClient(client, addr) { - continue - } - - err := client.Ping( - addr, - tsic.WithPingTimeout(derpPingTimeout), - tsic.WithPingCount(derpPingCount), - tsic.WithPingUntilDirect(false), - ) - if err != nil { - t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err) - } else { - success++ - } - } - } - - return success -} - -func isSelfClient(client TailscaleClient, addr string) bool { - if addr == client.Hostname() { - return true - } - - ips, err := client.IPs() - if err != nil { - return false - } - - for _, ip := range ips { - if ip.String() == addr { - return true - } - } - - return false -} - -func isCI() bool { - if _, ok := os.LookupEnv("CI"); ok { - return true - } - - if _, ok := os.LookupEnv("GITHUB_RUN_ID"); ok { - return true - } - - return false -} - -func dockertestMaxWait() time.Duration { - wait := 60 * time.Second //nolint - - if isCI() { - wait = 300 * time.Second //nolint - } - - return wait -} - -// func dockertestCommandTimeout() time.Duration { -// timeout := 10 * time.Second //nolint -// -// if isCI() { -// timeout = 60 * time.Second //nolint -// } -// -// return timeout -// } - -// pingAllNegativeHelper is intended to have 1 or more nodes timeing out from the ping, -// it counts failures instead of successes. -// func pingAllNegativeHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { -// t.Helper() -// failures := 0 -// -// timeout := 100 -// count := 3 -// -// for _, client := range clients { -// for _, addr := range addrs { -// err := client.Ping( -// addr, -// tsic.WithPingTimeout(time.Duration(timeout)*time.Millisecond), -// tsic.WithPingCount(count), -// ) -// if err != nil { -// failures++ -// } -// } -// } -// -// return failures -// } - -// // findPeerByIP takes an IP and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus -// // if there is a peer with the given IP. If no peer is found, nil is returned. -// func findPeerByIP( -// ip netip.Addr, -// peers map[key.NodePublic]*ipnstate.PeerStatus, -// ) *ipnstate.PeerStatus { -// for _, peer := range peers { -// for _, peerIP := range peer.TailscaleIPs { -// if ip == peerIP { -// return peer -// } -// } -// } -// -// return nil -// } -// -// // findPeerByHostname takes a hostname and a map of peers from status.Peer, and returns a *ipnstate.PeerStatus -// // if there is a peer with the given hostname. If no peer is found, nil is returned. 
-// func findPeerByHostname( -// hostname string, -// peers map[key.NodePublic]*ipnstate.PeerStatus, -// ) *ipnstate.PeerStatus { -// for _, peer := range peers { -// if hostname == peer.HostName { -// return peer -// } -// } -// -// return nil -// } diff --git a/mkdocs.yml b/mkdocs.yml index 86a15469..c5d82eec 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,5 +1,6 @@ +--- site_name: Headscale -site_url: https://juanfont.github.io/headscale +site_url: https://juanfont.github.io/headscale/ edit_uri: blob/main/docs/ # Change the master branch to main as we are using main as a main branch site_author: Headscale authors site_description: >- @@ -10,7 +11,7 @@ repo_name: juanfont/headscale repo_url: https://github.com/juanfont/headscale # Copyright -copyright: Copyright © 2023 Headscale authors +copyright: Copyright © 2025 Headscale authors # Configuration theme: @@ -40,31 +41,64 @@ theme: - toc.follow # - toc.integrate palette: - - scheme: default + - media: "(prefers-color-scheme)" + toggle: + icon: material/brightness-auto + name: Switch to light mode + - media: "(prefers-color-scheme: light)" + scheme: default primary: white toggle: icon: material/brightness-7 name: Switch to dark mode - - scheme: slate + - media: "(prefers-color-scheme: dark)" + scheme: slate toggle: icon: material/brightness-4 - name: Switch to light mode + name: Switch to system preference font: text: Roboto code: Roboto Mono favicon: assets/favicon.png - logo: ./logo/headscale3-dots.svg + logo: assets/logo/headscale3-dots.svg + +# Excludes +exclude_docs: | + /requirements.txt # Plugins plugins: - search: separator: '[\s\-,:!=\[\]()"`/]+|\.(?!\d)|&[lg]t;|(?!\b)(?=[A-Z][a-z])' + - macros: + - include-markdown: - minify: minify_html: true + - mike: - social: {} + - redirects: + redirect_maps: + acls.md: ref/acls.md + android-client.md: usage/connect/android.md + apple-client.md: usage/connect/apple.md + dns-records.md: ref/dns.md + exit-node.md: ref/routes.md + faq.md: about/faq.md + iOS-client.md: usage/connect/apple.md#ios + oidc.md: ref/oidc.md + ref/exit-node.md: ref/routes.md + ref/remote-cli.md: ref/api.md#grpc + remote-cli.md: ref/api.md#grpc + reverse-proxy.md: ref/integration/reverse-proxy.md + tls.md: ref/tls.md + web-ui.md: ref/integration/web-ui.md + windows-client.md: usage/connect/windows.md # Customization extra: + version: + alias: true + provider: mike annotate: json: [.s2] social: @@ -76,6 +110,8 @@ extra: link: https://github.com/juanfont/headscale/pkgs/container/headscale - icon: fontawesome/brands/discord link: https://discord.gg/c84AZQhmpx + headscale: + version: 0.28.0-beta.1 # Extensions markdown_extensions: @@ -121,27 +157,41 @@ markdown_extensions: # Page tree nav: - - Home: index.md - - FAQ: faq.md - - Getting started: + - Welcome: index.md + - About: + - FAQ: about/faq.md + - Features: about/features.md + - Clients: about/clients.md + - Getting help: about/help.md + - Releases: about/releases.md + - Contributing: about/contributing.md + - Sponsor: about/sponsor.md + + - Setup: + - Requirements and Assumptions: setup/requirements.md - Installation: - - Linux: running-headscale-linux.md - - OpenBSD: running-headscale-openbsd.md - - Container: running-headscale-container.md - - Configuration: - - Web UI: web-ui.md - - OIDC authentication: oidc.md - - Exit node: exit-node.md - - Reverse proxy: reverse-proxy.md - - TLS: tls.md - - ACLs: acls.md - - Custom DNS records: dns-records.md - - Remote CLI: remote-cli.md - - Usage: - - Android: android-client.md - - Windows: windows-client.md - - iOS: 
iOS-client.md - - Proposals: - - ACLs: proposals/001-acls.md - - Better routing: proposals/002-better-routing.md - - Glossary: glossary.md + - Official releases: setup/install/official.md + - Community packages: setup/install/community.md + - Container: setup/install/container.md + - Build from source: setup/install/source.md + - Upgrade: setup/upgrade.md + - Usage: + - Getting started: usage/getting-started.md + - Connect a node: + - Android: usage/connect/android.md + - Apple: usage/connect/apple.md + - Windows: usage/connect/windows.md + - Reference: + - Configuration: ref/configuration.md + - OpenID Connect: ref/oidc.md + - Routes: ref/routes.md + - TLS: ref/tls.md + - ACLs: ref/acls.md + - DNS: ref/dns.md + - DERP: ref/derp.md + - API: ref/api.md + - Debug: ref/debug.md + - Integration: + - Reverse proxy: ref/integration/reverse-proxy.md + - Web UI: ref/integration/web-ui.md + - Tools: ref/integration/tools.md diff --git a/nix/README.md b/nix/README.md new file mode 100644 index 00000000..533e4b5e --- /dev/null +++ b/nix/README.md @@ -0,0 +1,41 @@ +# Headscale NixOS Module + +This directory contains the NixOS module for Headscale. + +## Rationale + +The module is maintained in this repository to keep the code and module +synchronized at the same commit. This allows faster iteration and ensures the +module stays compatible with the latest Headscale changes. All changes should +aim to be upstreamed to nixpkgs. + +## Files + +- **[`module.nix`](./module.nix)** - The NixOS module implementation +- **[`example-configuration.nix`](./example-configuration.nix)** - Example + configuration demonstrating all major features +- **[`tests/`](./tests/)** - NixOS integration tests + +## Usage + +Add to your flake inputs: + +```nix +inputs.headscale.url = "github:juanfont/headscale"; +``` + +Then import the module: + +```nix +imports = [ inputs.headscale.nixosModules.default ]; +``` + +See [`example-configuration.nix`](./example-configuration.nix) for configuration +options. + +## Upstream + +- [nixpkgs module](https://github.com/NixOS/nixpkgs/blob/master/nixos/modules/services/networking/headscale.nix) +- [nixpkgs package](https://github.com/NixOS/nixpkgs/blob/master/pkgs/by-name/he/headscale/package.nix) + +The module in this repository may be newer than the nixpkgs version. diff --git a/nix/example-configuration.nix b/nix/example-configuration.nix new file mode 100644 index 00000000..e1f6cec7 --- /dev/null +++ b/nix/example-configuration.nix @@ -0,0 +1,145 @@ +# Example NixOS configuration using the headscale module +# +# This file demonstrates how to use the headscale NixOS module from this flake. +# To use in your own configuration, add this to your flake.nix inputs: +# +# inputs.headscale.url = "github:juanfont/headscale"; +# +# Then import the module: +# +# imports = [ inputs.headscale.nixosModules.default ]; +# + +{ config, pkgs, ... 
}: + +{ + # Import the headscale module + # In a real configuration, this would come from the flake input + # imports = [ inputs.headscale.nixosModules.default ]; + + services.headscale = { + enable = true; + + # Optional: Use a specific package (defaults to pkgs.headscale) + # package = pkgs.headscale; + + # Listen on all interfaces (default is 127.0.0.1) + address = "0.0.0.0"; + port = 8080; + + settings = { + # The URL clients will connect to + server_url = "https://headscale.example.com"; + + # IP prefixes for the tailnet + # These use the freeform settings - you can set any headscale config option + prefixes = { + v4 = "100.64.0.0/10"; + v6 = "fd7a:115c:a1e0::/48"; + allocation = "sequential"; + }; + + # DNS configuration with MagicDNS + dns = { + magic_dns = true; + base_domain = "tailnet.example.com"; + + # Whether to override client's local DNS settings (default: true) + # When true, nameservers.global must be set + override_local_dns = true; + + nameservers = { + global = [ "1.1.1.1" "8.8.8.8" ]; + }; + }; + + # DERP (relay) configuration + derp = { + # Use default Tailscale DERP servers + urls = [ "https://controlplane.tailscale.com/derpmap/default" ]; + auto_update_enabled = true; + update_frequency = "24h"; + + # Optional: Run your own DERP server + # server = { + # enabled = true; + # region_id = 999; + # stun_listen_addr = "0.0.0.0:3478"; + # }; + }; + + # Database configuration (SQLite is recommended) + database = { + type = "sqlite"; + sqlite = { + path = "/var/lib/headscale/db.sqlite"; + write_ahead_log = true; + }; + + # PostgreSQL example (not recommended for new deployments) + # type = "postgres"; + # postgres = { + # host = "localhost"; + # port = 5432; + # name = "headscale"; + # user = "headscale"; + # password_file = "/run/secrets/headscale-db-password"; + # }; + }; + + # Logging configuration + log = { + level = "info"; + format = "text"; + }; + + # Optional: OIDC authentication + # oidc = { + # issuer = "https://accounts.google.com"; + # client_id = "your-client-id"; + # client_secret_path = "/run/secrets/oidc-client-secret"; + # scope = [ "openid" "profile" "email" ]; + # allowed_domains = [ "example.com" ]; + # }; + + # Optional: Let's Encrypt TLS certificates + # tls_letsencrypt_hostname = "headscale.example.com"; + # tls_letsencrypt_challenge_type = "HTTP-01"; + + # Optional: Provide your own TLS certificates + # tls_cert_path = "/path/to/cert.pem"; + # tls_key_path = "/path/to/key.pem"; + + # ACL policy configuration + policy = { + mode = "file"; + path = "/var/lib/headscale/policy.hujson"; + }; + + # You can add ANY headscale configuration option here thanks to freeform settings + # For example, experimental features or settings not explicitly defined above: + # experimental_feature = true; + # custom_setting = "value"; + }; + }; + + # Optional: Open firewall ports + networking.firewall = { + allowedTCPPorts = [ 8080 ]; + # If running a DERP server: + # allowedUDPPorts = [ 3478 ]; + }; + + # Optional: Use with nginx reverse proxy for TLS termination + # services.nginx = { + # enable = true; + # virtualHosts."headscale.example.com" = { + # enableACME = true; + # forceSSL = true; + # locations."/" = { + # proxyPass = "http://127.0.0.1:8080"; + # proxyWebsockets = true; + # }; + # }; + # }; +} diff --git a/nix/module.nix b/nix/module.nix new file mode 100644 index 00000000..a75398fb --- /dev/null +++ b/nix/module.nix @@ -0,0 +1,727 @@ +{ config +, lib +, pkgs +, ... 
+}: +let + cfg = config.services.headscale; + + dataDir = "/var/lib/headscale"; + runDir = "/run/headscale"; + + cliConfig = { + # Turn off update checks since the origin of our package + # is nixpkgs and not Github. + disable_check_updates = true; + + unix_socket = "${runDir}/headscale.sock"; + }; + + settingsFormat = pkgs.formats.yaml { }; + configFile = settingsFormat.generate "headscale.yaml" cfg.settings; + cliConfigFile = settingsFormat.generate "headscale.yaml" cliConfig; + + assertRemovedOption = option: message: { + assertion = !lib.hasAttrByPath option cfg; + message = + "The option `services.headscale.${lib.options.showOption option}` was removed. " + message; + }; +in +{ + # Disable the upstream NixOS module to prevent conflicts + disabledModules = [ "services/networking/headscale.nix" ]; + + options = { + services.headscale = { + enable = lib.mkEnableOption "headscale, Open Source coordination server for Tailscale"; + + package = lib.mkPackageOption pkgs "headscale" { }; + + user = lib.mkOption { + default = "headscale"; + type = lib.types.str; + description = '' + User account under which headscale runs. + + ::: {.note} + If left as the default value this user will automatically be created + on system activation, otherwise you are responsible for + ensuring the user exists before the headscale service starts. + ::: + ''; + }; + + group = lib.mkOption { + default = "headscale"; + type = lib.types.str; + description = '' + Group under which headscale runs. + + ::: {.note} + If left as the default value this group will automatically be created + on system activation, otherwise you are responsible for + ensuring the user exists before the headscale service starts. + ::: + ''; + }; + + address = lib.mkOption { + type = lib.types.str; + default = "127.0.0.1"; + description = '' + Listening address of headscale. + ''; + example = "0.0.0.0"; + }; + + port = lib.mkOption { + type = lib.types.port; + default = 8080; + description = '' + Listening port of headscale. + ''; + example = 443; + }; + + settings = lib.mkOption { + description = '' + Overrides to {file}`config.yaml` as a Nix attribute set. + Check the [example config](https://github.com/juanfont/headscale/blob/main/config-example.yaml) + for possible options. + ''; + type = lib.types.submodule { + freeformType = settingsFormat.type; + + options = { + server_url = lib.mkOption { + type = lib.types.str; + default = "http://127.0.0.1:8080"; + description = '' + The url clients will connect to. + ''; + example = "https://myheadscale.example.com:443"; + }; + + noise.private_key_path = lib.mkOption { + type = lib.types.path; + default = "${dataDir}/noise_private.key"; + description = '' + Path to noise private key file, generated automatically if it does not exist. + ''; + }; + + prefixes = + let + prefDesc = '' + Each prefix consists of either an IPv4 or IPv6 address, + and the associated prefix length, delimited by a slash. + It must be within IP ranges supported by the Tailscale + client - i.e., subnets of 100.64.0.0/10 and fd7a:115c:a1e0::/48. 
+ ''; + in + { + v4 = lib.mkOption { + type = lib.types.str; + default = "100.64.0.0/10"; + description = prefDesc; + }; + + v6 = lib.mkOption { + type = lib.types.str; + default = "fd7a:115c:a1e0::/48"; + description = prefDesc; + }; + + allocation = lib.mkOption { + type = lib.types.enum [ + "sequential" + "random" + ]; + example = "random"; + default = "sequential"; + description = '' + Strategy used for allocation of IPs to nodes, available options: + - sequential (default): assigns the next free IP from the previous given IP. + - random: assigns the next free IP from a pseudo-random IP generator (crypto/rand). + ''; + }; + }; + + derp = { + urls = lib.mkOption { + type = lib.types.listOf lib.types.str; + default = [ "https://controlplane.tailscale.com/derpmap/default" ]; + description = '' + List of urls containing DERP maps. + See [How Tailscale works](https://tailscale.com/blog/how-tailscale-works/) for more information on DERP maps. + ''; + }; + + paths = lib.mkOption { + type = lib.types.listOf lib.types.path; + default = [ ]; + description = '' + List of file paths containing DERP maps. + See [How Tailscale works](https://tailscale.com/blog/how-tailscale-works/) for more information on DERP maps. + ''; + }; + + auto_update_enabled = lib.mkOption { + type = lib.types.bool; + default = true; + description = '' + Whether to automatically update DERP maps on a set frequency. + ''; + example = false; + }; + + update_frequency = lib.mkOption { + type = lib.types.str; + default = "24h"; + description = '' + Frequency to update DERP maps. + ''; + example = "5m"; + }; + + server.private_key_path = lib.mkOption { + type = lib.types.path; + default = "${dataDir}/derp_server_private.key"; + description = '' + Path to derp private key file, generated automatically if it does not exist. + ''; + }; + }; + + ephemeral_node_inactivity_timeout = lib.mkOption { + type = lib.types.str; + default = "30m"; + description = '' + Time before an inactive ephemeral node is deleted. + ''; + example = "5m"; + }; + + database = { + type = lib.mkOption { + type = lib.types.enum [ + "sqlite" + "sqlite3" + "postgres" + ]; + example = "postgres"; + default = "sqlite"; + description = '' + Database engine to use. + Please note that using Postgres is highly discouraged as it is only supported for legacy reasons. + All new development, testing and optimisations are done with SQLite in mind. + ''; + }; + + sqlite = { + path = lib.mkOption { + type = lib.types.nullOr lib.types.str; + default = "${dataDir}/db.sqlite"; + description = "Path to the sqlite3 database file."; + }; + + write_ahead_log = lib.mkOption { + type = lib.types.bool; + default = true; + description = '' + Enable WAL mode for SQLite. This is recommended for production environments. 
+ <https://www.sqlite.org/wal.html> + ''; + example = true; + }; + }; + + postgres = { + host = lib.mkOption { + type = lib.types.nullOr lib.types.str; + default = null; + example = "127.0.0.1"; + description = "Database host address."; + }; + + port = lib.mkOption { + type = lib.types.nullOr lib.types.port; + default = null; + example = 3306; + description = "Database host port."; + }; + + name = lib.mkOption { + type = lib.types.nullOr lib.types.str; + default = null; + example = "headscale"; + description = "Database name."; + }; + + user = lib.mkOption { + type = lib.types.nullOr lib.types.str; + default = null; + example = "headscale"; + description = "Database user."; + }; + + password_file = lib.mkOption { + type = lib.types.nullOr lib.types.path; + default = null; + example = "/run/keys/headscale-dbpassword"; + description = '' + A file containing the password corresponding to + {option}`database.user`. + ''; + }; + }; + }; + + log = { + level = lib.mkOption { + type = lib.types.str; + default = "info"; + description = '' + headscale log level. + ''; + example = "debug"; + }; + + format = lib.mkOption { + type = lib.types.str; + default = "text"; + description = '' + headscale log format. + ''; + example = "json"; + }; + }; + + dns = { + magic_dns = lib.mkOption { + type = lib.types.bool; + default = true; + description = '' + Whether to use [MagicDNS](https://tailscale.com/kb/1081/magicdns/). + ''; + example = false; + }; + + base_domain = lib.mkOption { + type = lib.types.str; + default = ""; + description = '' + Defines the base domain to create the hostnames for MagicDNS. + This domain must be different from the {option}`server_url` + domain. + {option}`base_domain` must be a FQDN, without the trailing dot. + The FQDN of the hosts will be `hostname.base_domain` (e.g. + `myhost.tailnet.example.com`). + ''; + example = "tailnet.example.com"; + }; + + override_local_dns = lib.mkOption { + type = lib.types.bool; + default = true; + description = '' + Whether to use the local DNS settings of a node or override + the local DNS settings and force the use of Headscale's DNS + configuration. + ''; + example = false; + }; + + nameservers = { + global = lib.mkOption { + type = lib.types.listOf lib.types.str; + default = [ ]; + description = '' + List of nameservers to pass to Tailscale clients. + Required when {option}`override_local_dns` is true. + ''; + }; + }; + + search_domains = lib.mkOption { + type = lib.types.listOf lib.types.str; + default = [ ]; + description = '' + Search domains to inject to Tailscale clients. + ''; + example = [ "mydomain.internal" ]; + }; + }; + + oidc = { + issuer = lib.mkOption { + type = lib.types.str; + default = ""; + description = '' + URL to OpenID issuer. + ''; + example = "https://openid.example.com"; + }; + + client_id = lib.mkOption { + type = lib.types.str; + default = ""; + description = '' + OpenID Connect client ID. + ''; + }; + + client_secret_path = lib.mkOption { + type = lib.types.nullOr lib.types.str; + default = null; + description = '' + Path to OpenID Connect client secret file. Expands environment variables in format ''${VAR}. + ''; + }; + + scope = lib.mkOption { + type = lib.types.listOf lib.types.str; + default = [ + "openid" + "profile" + "email" + ]; + description = '' + Scopes used in the OIDC flow. + ''; + }; + + extra_params = lib.mkOption { + type = lib.types.attrsOf lib.types.str; + default = { }; + description = '' + Custom query parameters to send with the Authorize Endpoint request. 
+              '';
+              example = {
+                domain_hint = "example.com";
+              };
+            };
+
+            allowed_domains = lib.mkOption {
+              type = lib.types.listOf lib.types.str;
+              default = [ ];
+              description = ''
+                Allowed principal domains. If an authenticated user's domain
+                is not in this list, the authentication request will be rejected.
+              '';
+              example = [ "example.com" ];
+            };
+
+            allowed_users = lib.mkOption {
+              type = lib.types.listOf lib.types.str;
+              default = [ ];
+              description = ''
+                Users allowed to authenticate even if not in allowed_domains.
+              '';
+              example = [ "alice@example.com" ];
+            };
+
+            pkce = {
+              enabled = lib.mkOption {
+                type = lib.types.bool;
+                default = false;
+                description = ''
+                  Enable or disable PKCE (Proof Key for Code Exchange) support.
+                  PKCE adds an additional layer of security to the OAuth 2.0
+                  authorization code flow by preventing authorization code
+                  interception attacks.
+                  See https://datatracker.ietf.org/doc/html/rfc7636
+                '';
+                example = true;
+              };
+
+              method = lib.mkOption {
+                type = lib.types.str;
+                default = "S256";
+                description = ''
+                  PKCE method to use:
+                  - plain: Use plain code verifier
+                  - S256: Use SHA256 hashed code verifier (default, recommended)
+                '';
+              };
+            };
+          };
+
+          tls_letsencrypt_hostname = lib.mkOption {
+            type = lib.types.nullOr lib.types.str;
+            default = "";
+            description = ''
+              Domain name to request a TLS certificate for.
+            '';
+          };
+
+          tls_letsencrypt_challenge_type = lib.mkOption {
+            type = lib.types.enum [
+              "TLS-ALPN-01"
+              "HTTP-01"
+            ];
+            default = "HTTP-01";
+            description = ''
+              Type of ACME challenge to use, currently supported types:
+              `HTTP-01` or `TLS-ALPN-01`.
+            '';
+          };
+
+          tls_letsencrypt_listen = lib.mkOption {
+            type = lib.types.nullOr lib.types.str;
+            default = ":http";
+            description = ''
+              When HTTP-01 challenge is chosen, letsencrypt must set up a
+              verification endpoint, and it will be listening on:
+              `:http = port 80`.
+            '';
+          };
+
+          tls_cert_path = lib.mkOption {
+            type = lib.types.nullOr lib.types.path;
+            default = null;
+            description = ''
+              Path to already created certificate.
+            '';
+          };
+
+          tls_key_path = lib.mkOption {
+            type = lib.types.nullOr lib.types.path;
+            default = null;
+            description = ''
+              Path to key for already created certificate.
+            '';
+          };
+
+          policy = {
+            mode = lib.mkOption {
+              type = lib.types.enum [
+                "file"
+                "database"
+              ];
+              default = "file";
+              description = ''
+                The mode can be "file" or "database" and defines
+                where the ACL policies are stored and read from.
+              '';
+            };
+
+            path = lib.mkOption {
+              type = lib.types.nullOr lib.types.path;
+              default = null;
+              description = ''
+                If the mode is set to "file", the path to a
+                HuJSON file containing ACL policies.
+ ''; + }; + }; + }; + }; + }; + }; + }; + + imports = with lib; [ + (mkRenamedOptionModule + [ "services" "headscale" "derp" "autoUpdate" ] + [ "services" "headscale" "settings" "derp" "auto_update_enabled" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "derp" "auto_update_enable" ] + [ "services" "headscale" "settings" "derp" "auto_update_enabled" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "derp" "paths" ] + [ "services" "headscale" "settings" "derp" "paths" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "derp" "updateFrequency" ] + [ "services" "headscale" "settings" "derp" "update_frequency" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "derp" "urls" ] + [ "services" "headscale" "settings" "derp" "urls" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "ephemeralNodeInactivityTimeout" ] + [ "services" "headscale" "settings" "ephemeral_node_inactivity_timeout" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "logLevel" ] + [ "services" "headscale" "settings" "log" "level" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "openIdConnect" "clientId" ] + [ "services" "headscale" "settings" "oidc" "client_id" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "openIdConnect" "clientSecretFile" ] + [ "services" "headscale" "settings" "oidc" "client_secret_path" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "openIdConnect" "issuer" ] + [ "services" "headscale" "settings" "oidc" "issuer" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "serverUrl" ] + [ "services" "headscale" "settings" "server_url" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "tls" "certFile" ] + [ "services" "headscale" "settings" "tls_cert_path" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "tls" "keyFile" ] + [ "services" "headscale" "settings" "tls_key_path" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "tls" "letsencrypt" "challengeType" ] + [ "services" "headscale" "settings" "tls_letsencrypt_challenge_type" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "tls" "letsencrypt" "hostname" ] + [ "services" "headscale" "settings" "tls_letsencrypt_hostname" ] + ) + (mkRenamedOptionModule + [ "services" "headscale" "tls" "letsencrypt" "httpListen" ] + [ "services" "headscale" "settings" "tls_letsencrypt_listen" ] + ) + + (mkRemovedOptionModule [ "services" "headscale" "openIdConnect" "domainMap" ] '' + Headscale no longer uses domain_map. If you're using an old version of headscale you can still set this option via services.headscale.settings.oidc.domain_map. 
+ '') + ]; + + config = lib.mkIf cfg.enable { + assertions = [ + { + assertion = with cfg.settings; dns.magic_dns -> dns.base_domain != ""; + message = "dns.base_domain must be set when using MagicDNS"; + } + { + assertion = with cfg.settings; dns.override_local_dns -> (dns.nameservers.global != [ ]); + message = "dns.nameservers.global must be set when dns.override_local_dns is true"; + } + (assertRemovedOption [ "settings" "acl_policy_path" ] "Use `policy.path` instead.") + (assertRemovedOption [ "settings" "db_host" ] "Use `database.postgres.host` instead.") + (assertRemovedOption [ "settings" "db_name" ] "Use `database.postgres.name` instead.") + (assertRemovedOption [ + "settings" + "db_password_file" + ] "Use `database.postgres.password_file` instead.") + (assertRemovedOption [ "settings" "db_path" ] "Use `database.sqlite.path` instead.") + (assertRemovedOption [ "settings" "db_port" ] "Use `database.postgres.port` instead.") + (assertRemovedOption [ "settings" "db_type" ] "Use `database.type` instead.") + (assertRemovedOption [ "settings" "db_user" ] "Use `database.postgres.user` instead.") + (assertRemovedOption [ "settings" "dns_config" ] "Use `dns` instead.") + (assertRemovedOption [ "settings" "dns_config" "domains" ] "Use `dns.search_domains` instead.") + (assertRemovedOption [ + "settings" + "dns_config" + "nameservers" + ] "Use `dns.nameservers.global` instead.") + (assertRemovedOption [ + "settings" + "oidc" + "strip_email_domain" + ] "The strip_email_domain option got removed upstream") + ]; + + services.headscale.settings = lib.mkMerge [ + cliConfig + { + listen_addr = lib.mkDefault "${cfg.address}:${toString cfg.port}"; + + tls_letsencrypt_cache_dir = "${dataDir}/.cache"; + } + ]; + + environment = { + # Headscale CLI needs a minimal config to be able to locate the unix socket + # to talk to the server instance. + etc."headscale/config.yaml".source = cliConfigFile; + + systemPackages = [ cfg.package ]; + }; + + users.groups.headscale = lib.mkIf (cfg.group == "headscale") { }; + + users.users.headscale = lib.mkIf (cfg.user == "headscale") { + description = "headscale user"; + home = dataDir; + group = cfg.group; + isSystemUser = true; + }; + + systemd.services.headscale = { + description = "headscale coordination server for Tailscale"; + wants = [ "network-online.target" ]; + after = [ "network-online.target" ]; + wantedBy = [ "multi-user.target" ]; + + script = '' + ${lib.optionalString (cfg.settings.database.postgres.password_file != null) '' + export HEADSCALE_DATABASE_POSTGRES_PASS="$(head -n1 ${lib.escapeShellArg cfg.settings.database.postgres.password_file})" + ''} + + exec ${lib.getExe cfg.package} serve --config ${configFile} + ''; + + serviceConfig = + let + capabilityBoundingSet = [ "CAP_CHOWN" ] ++ lib.optional (cfg.port < 1024) "CAP_NET_BIND_SERVICE"; + in + { + Restart = "always"; + RestartSec = "5s"; + Type = "simple"; + User = cfg.user; + Group = cfg.group; + + # Hardening options + RuntimeDirectory = "headscale"; + # Allow headscale group access so users can be added and use the CLI. 
+ RuntimeDirectoryMode = "0750"; + + StateDirectory = "headscale"; + StateDirectoryMode = "0750"; + + ProtectSystem = "strict"; + ProtectHome = true; + PrivateTmp = true; + PrivateDevices = true; + ProtectKernelTunables = true; + ProtectControlGroups = true; + RestrictSUIDSGID = true; + PrivateMounts = true; + ProtectKernelModules = true; + ProtectKernelLogs = true; + ProtectHostname = true; + ProtectClock = true; + ProtectProc = "invisible"; + ProcSubset = "pid"; + RestrictNamespaces = true; + RemoveIPC = true; + UMask = "0077"; + + CapabilityBoundingSet = capabilityBoundingSet; + AmbientCapabilities = capabilityBoundingSet; + NoNewPrivileges = true; + LockPersonality = true; + RestrictRealtime = true; + SystemCallFilter = [ + "@system-service" + "~@privileged" + "@chown" + ]; + SystemCallArchitectures = "native"; + RestrictAddressFamilies = "AF_INET AF_INET6 AF_UNIX"; + }; + }; + }; + + meta.maintainers = with lib.maintainers; [ + kradalby + misterio77 + ]; +} diff --git a/nix/tests/headscale.nix b/nix/tests/headscale.nix new file mode 100644 index 00000000..7dc93870 --- /dev/null +++ b/nix/tests/headscale.nix @@ -0,0 +1,102 @@ +{ pkgs, lib, ... }: +let + tls-cert = pkgs.runCommand "selfSignedCerts" { buildInputs = [ pkgs.openssl ]; } '' + openssl req \ + -x509 -newkey rsa:4096 -sha256 -days 365 \ + -nodes -out cert.pem -keyout key.pem \ + -subj '/CN=headscale' -addext "subjectAltName=DNS:headscale" + + mkdir -p $out + cp key.pem cert.pem $out + ''; +in +{ + name = "headscale"; + meta.maintainers = with lib.maintainers; [ + kradalby + misterio77 + ]; + + nodes = + let + headscalePort = 8080; + stunPort = 3478; + peer = { + services.tailscale.enable = true; + security.pki.certificateFiles = [ "${tls-cert}/cert.pem" ]; + }; + in + { + peer1 = peer; + peer2 = peer; + + headscale = { + services = { + headscale = { + enable = true; + port = headscalePort; + settings = { + server_url = "https://headscale"; + ip_prefixes = [ "100.64.0.0/10" ]; + derp.server = { + enabled = true; + region_id = 999; + stun_listen_addr = "0.0.0.0:${toString stunPort}"; + }; + dns = { + base_domain = "tailnet"; + extra_records = [ + { + name = "foo.bar"; + type = "A"; + value = "100.64.0.2"; + } + ]; + override_local_dns = false; + }; + }; + }; + nginx = { + enable = true; + virtualHosts.headscale = { + addSSL = true; + sslCertificate = "${tls-cert}/cert.pem"; + sslCertificateKey = "${tls-cert}/key.pem"; + locations."/" = { + proxyPass = "http://127.0.0.1:${toString headscalePort}"; + proxyWebsockets = true; + }; + }; + }; + }; + networking.firewall = { + allowedTCPPorts = [ + 80 + 443 + ]; + allowedUDPPorts = [ stunPort ]; + }; + environment.systemPackages = [ pkgs.headscale ]; + }; + }; + + testScript = '' + start_all() + headscale.wait_for_unit("headscale") + headscale.wait_for_open_port(443) + + # Create headscale user and preauth-key + headscale.succeed("headscale users create test") + authkey = headscale.succeed("headscale preauthkeys -u 1 create --reusable") + + # Connect peers + up_cmd = f"tailscale up --login-server 'https://headscale' --auth-key {authkey}" + peer1.execute(up_cmd) + peer2.execute(up_cmd) + + # Check that they are reachable from the tailnet + peer1.wait_until_succeeds("tailscale ping peer2") + peer2.wait_until_succeeds("tailscale ping peer1.tailnet") + assert (res := peer1.wait_until_succeeds("${lib.getExe pkgs.dig} +short foo.bar").strip()) == "100.64.0.2", f"Domain {res} did not match 100.64.0.2" + ''; +} diff --git a/packaging/README.md b/packaging/README.md new file mode 100644 index 
00000000..b731d3f0 --- /dev/null +++ b/packaging/README.md @@ -0,0 +1,5 @@ +# Packaging + +We use [nFPM](https://nfpm.goreleaser.com/) for making `.deb` packages. + +This folder contains files we need to package with these releases. diff --git a/packaging/deb/postinst b/packaging/deb/postinst new file mode 100644 index 00000000..d249a432 --- /dev/null +++ b/packaging/deb/postinst @@ -0,0 +1,87 @@ +#!/bin/sh +# postinst script for headscale. + +set -e + +# Summary of how this script can be called: +# * <postinst> 'configure' <most-recently-configured-version> +# * <old-postinst> 'abort-upgrade' <new version> +# * <conflictor's-postinst> 'abort-remove' 'in-favour' <package> +# <new-version> +# * <postinst> 'abort-remove' +# * <deconfigured's-postinst> 'abort-deconfigure' 'in-favour' +# <failed-install-package> <version> 'removing' +# <conflicting-package> <version> +# for details, see https://www.debian.org/doc/debian-policy/ or +# the debian-policy package. + +HEADSCALE_USER="headscale" +HEADSCALE_GROUP="headscale" +HEADSCALE_HOME_DIR="/var/lib/headscale" +HEADSCALE_SHELL="/usr/sbin/nologin" +HEADSCALE_SERVICE="headscale.service" + +case "$1" in + configure) + groupadd --force --system "$HEADSCALE_GROUP" + if ! id -u "$HEADSCALE_USER" >/dev/null 2>&1; then + useradd --system --shell "$HEADSCALE_SHELL" \ + --gid "$HEADSCALE_GROUP" --home-dir "$HEADSCALE_HOME_DIR" \ + --comment "headscale default user" "$HEADSCALE_USER" + fi + + if dpkg --compare-versions "$2" lt-nl "0.27"; then + # < 0.24.0-beta.1 used /home/headscale as home and /bin/sh as shell. + # The directory /home/headscale was not created by the package or + # useradd but the service always used /var/lib/headscale which was + # always shipped by the package as empty directory. Previous versions + # of the package did not update the user account properties. + usermod --home "$HEADSCALE_HOME_DIR" --shell "$HEADSCALE_SHELL" \ + "$HEADSCALE_USER" >/dev/null + fi + + if dpkg --compare-versions "$2" lt-nl "0.27" \ + && [ $(id --user "$HEADSCALE_USER") -ge 1000 ] \ + && [ $(id --group "$HEADSCALE_GROUP") -ge 1000 ]; then + # < 0.26.0-beta.1 created a regular user/group to run headscale. + # Previous versions of the package did not migrate to system uid/gid. + # Assume that the *default* uid/gid range is in use and only run this + # migration when the current uid/gid is allocated in the user range. + # Create a temporary system user/group to guarantee the allocation of a + # uid/gid in the system range. Assign this new uid/gid to the existing + # user and group and remove the temporary user/group afterwards. 
+ tmp_name="headscaletmp" + useradd --system --no-log-init --no-create-home --shell "$HEADSCALE_SHELL" "$tmp_name" + tmp_uid="$(id --user "$tmp_name")" + tmp_gid="$(id --group "$tmp_name")" + usermod --non-unique --uid "$tmp_uid" --gid "$tmp_gid" "$HEADSCALE_USER" + groupmod --non-unique --gid "$tmp_gid" "$HEADSCALE_USER" + userdel --force "$tmp_name" + fi + + # Enable service and keep track of its state + if deb-systemd-helper --quiet was-enabled "$HEADSCALE_SERVICE"; then + deb-systemd-helper enable "$HEADSCALE_SERVICE" >/dev/null || true + else + deb-systemd-helper update-state "$HEADSCALE_SERVICE" >/dev/null || true + fi + + # Bounce service + if [ -d /run/systemd/system ]; then + systemctl --system daemon-reload >/dev/null || true + if [ -n "$2" ]; then + deb-systemd-invoke restart "$HEADSCALE_SERVICE" >/dev/null || true + else + deb-systemd-invoke start "$HEADSCALE_SERVICE" >/dev/null || true + fi + fi + ;; + + abort-upgrade|abort-remove|abort-deconfigure) + ;; + + *) + echo "postinst called with unknown argument '$1'" >&2 + exit 1 + ;; +esac diff --git a/packaging/deb/postrm b/packaging/deb/postrm new file mode 100644 index 00000000..664bc51e --- /dev/null +++ b/packaging/deb/postrm @@ -0,0 +1,42 @@ +#!/bin/sh +# postrm script for headscale. + +set -e + +# Summary of how this script can be called: +# * <postrm> 'remove' +# * <postrm> 'purge' +# * <old-postrm> 'upgrade' <new-version> +# * <new-postrm> 'failed-upgrade' <old-version> +# * <new-postrm> 'abort-install' +# * <new-postrm> 'abort-install' <old-version> +# * <new-postrm> 'abort-upgrade' <old-version> +# * <disappearer's-postrm> 'disappear' <overwriter> +# <overwriter-version> +# for details, see https://www.debian.org/doc/debian-policy/ or +# the debian-policy package. + + +case "$1" in + remove) + if [ -d /run/systemd/system ]; then + systemctl --system daemon-reload >/dev/null || true + fi + ;; + + purge) + userdel headscale + rm -rf /var/lib/headscale + if [ -x "/usr/bin/deb-systemd-helper" ]; then + deb-systemd-helper purge headscale.service >/dev/null || true + fi + ;; + + upgrade|failed-upgrade|abort-install|abort-upgrade|disappear) + ;; + + *) + echo "postrm called with unknown argument '$1'" >&2 + exit 1 + ;; +esac diff --git a/packaging/deb/prerm b/packaging/deb/prerm new file mode 100644 index 00000000..2cee63a2 --- /dev/null +++ b/packaging/deb/prerm @@ -0,0 +1,34 @@ +#!/bin/sh +# prerm script for headscale. + +set -e + +# Summary of how this script can be called: +# * <prerm> 'remove' +# * <old-prerm> 'upgrade' <new-version> +# * <new-prerm> 'failed-upgrade' <old-version> +# * <conflictor's-prerm> 'remove' 'in-favour' <package> <new-version> +# * <deconfigured's-prerm> 'deconfigure' 'in-favour' +# <package-being-installed> <version> 'removing' +# <conflicting-package> <version> +# for details, see https://www.debian.org/doc/debian-policy/ or +# the debian-policy package. 
+ + +case "$1" in + remove) + if [ -d /run/systemd/system ]; then + deb-systemd-invoke stop headscale.service >/dev/null || true + fi + ;; + upgrade|deconfigure) + ;; + + failed-upgrade) + ;; + + *) + echo "prerm called with unknown argument '$1'" >&2 + exit 1 + ;; +esac diff --git a/docs/packaging/headscale.systemd.service b/packaging/systemd/headscale.service similarity index 93% rename from docs/packaging/headscale.systemd.service rename to packaging/systemd/headscale.service index 14e31618..7d20444f 100644 --- a/docs/packaging/headscale.systemd.service +++ b/packaging/systemd/headscale.service @@ -1,5 +1,4 @@ [Unit] -After=syslog.target After=network.target Description=headscale coordination server for Tailscale X-Restart-Triggers=/etc/headscale/config.yaml @@ -9,11 +8,12 @@ Type=simple User=headscale Group=headscale ExecStart=/usr/bin/headscale serve +ExecReload=/usr/bin/kill -HUP $MAINPID Restart=always RestartSec=5 WorkingDirectory=/var/lib/headscale -ReadWritePaths=/var/lib/headscale /var/run +ReadWritePaths=/var/lib/headscale AmbientCapabilities=CAP_NET_BIND_SERVICE CAP_CHOWN CapabilityBoundingSet=CAP_NET_BIND_SERVICE CAP_CHOWN diff --git a/proto/buf.lock b/proto/buf.lock index 7e075d76..31cd0644 100644 --- a/proto/buf.lock +++ b/proto/buf.lock @@ -4,12 +4,15 @@ deps: - remote: buf.build owner: googleapis repository: googleapis - commit: 62f35d8aed1149c291d606d958a7ce32 + commit: 61b203b9a9164be9a834f58c37be6f62 + digest: shake256:e619113001d6e284ee8a92b1561e5d4ea89a47b28bf0410815cb2fa23914df8be9f1a6a98dcf069f5bc2d829a2cfb1ac614863be45cd4f8a5ad8606c5f200224 - remote: buf.build owner: grpc-ecosystem repository: grpc-gateway - commit: bc28b723cd774c32b6fbc77621518765 + commit: 4c5ba75caaf84e928b7137ae5c18c26a + digest: shake256:e174ad9408f3e608f6157907153ffec8d310783ee354f821f57178ffbeeb8faa6bb70b41b61099c1783c82fe16210ebd1279bc9c9ee6da5cffba9f0e675b8b99 - remote: buf.build owner: ufoundit-dev repository: protoc-gen-gorm commit: e2ecbaa0d37843298104bd29fd866df8 + digest: shake256:088347669906bc49513b40d58fd7ae877769668928fca038e070732ce0f9855c03f21885b0099e0d27acf9475feca0a34dbcedac22bb374bf2cd7c1e352de56c diff --git a/proto/headscale/v1/apikey.proto b/proto/headscale/v1/apikey.proto index 749e5c22..6ea0d669 100644 --- a/proto/headscale/v1/apikey.proto +++ b/proto/headscale/v1/apikey.proto @@ -1,35 +1,35 @@ syntax = "proto3"; package headscale.v1; -option go_package = "github.com/juanfont/headscale/gen/go/v1"; +option go_package = "github.com/juanfont/headscale/gen/go/v1"; import "google/protobuf/timestamp.proto"; message ApiKey { - uint64 id = 1; - string prefix = 2; - google.protobuf.Timestamp expiration = 3; - google.protobuf.Timestamp created_at = 4; - google.protobuf.Timestamp last_seen = 5; + uint64 id = 1; + string prefix = 2; + google.protobuf.Timestamp expiration = 3; + google.protobuf.Timestamp created_at = 4; + google.protobuf.Timestamp last_seen = 5; } -message CreateApiKeyRequest { - google.protobuf.Timestamp expiration = 1; -} +message CreateApiKeyRequest { google.protobuf.Timestamp expiration = 1; } -message CreateApiKeyResponse { - string api_key = 1; -} +message CreateApiKeyResponse { string api_key = 1; } message ExpireApiKeyRequest { - string prefix = 1; + string prefix = 1; + uint64 id = 2; } -message ExpireApiKeyResponse { +message ExpireApiKeyResponse {} + +message ListApiKeysRequest {} + +message ListApiKeysResponse { repeated ApiKey api_keys = 1; } + +message DeleteApiKeyRequest { + string prefix = 1; + uint64 id = 2; } -message ListApiKeysRequest { -} - 
-message ListApiKeysResponse { - repeated ApiKey api_keys = 1; -} +message DeleteApiKeyResponse {} diff --git a/proto/headscale/v1/device.proto b/proto/headscale/v1/device.proto index 207ff374..6c75df88 100644 --- a/proto/headscale/v1/device.proto +++ b/proto/headscale/v1/device.proto @@ -1,6 +1,6 @@ syntax = "proto3"; package headscale.v1; -option go_package = "github.com/juanfont/headscale/gen/go/v1"; +option go_package = "github.com/juanfont/headscale/gen/go/v1"; import "google/protobuf/timestamp.proto"; @@ -8,76 +8,69 @@ import "google/protobuf/timestamp.proto"; // https://github.com/tailscale/tailscale/blob/main/api.md message Latency { - float latency_ms = 1; - bool preferred = 2; + float latency_ms = 1; + bool preferred = 2; } message ClientSupports { - bool hair_pinning = 1; - bool ipv6 = 2; - bool pcp = 3; - bool pmp = 4; - bool udp = 5; - bool upnp = 6; + bool hair_pinning = 1; + bool ipv6 = 2; + bool pcp = 3; + bool pmp = 4; + bool udp = 5; + bool upnp = 6; } message ClientConnectivity { - repeated string endpoints = 1; - string derp = 2; - bool mapping_varies_by_dest_ip = 3; - map<string, Latency> latency = 4; - ClientSupports client_supports = 5; + repeated string endpoints = 1; + string derp = 2; + bool mapping_varies_by_dest_ip = 3; + map<string, Latency> latency = 4; + ClientSupports client_supports = 5; } -message GetDeviceRequest { - string id = 1; -} +message GetDeviceRequest { string id = 1; } message GetDeviceResponse { - repeated string addresses = 1; - string id = 2; - string user = 3; - string name = 4; - string hostname = 5; - string client_version = 6; - bool update_available = 7; - string os = 8; - google.protobuf.Timestamp created = 9; - google.protobuf.Timestamp last_seen = 10; - bool key_expiry_disabled = 11; - google.protobuf.Timestamp expires = 12; - bool authorized = 13; - bool is_external = 14; - string machine_key = 15; - string node_key = 16; - bool blocks_incoming_connections = 17; - repeated string enabled_routes = 18; - repeated string advertised_routes = 19; - ClientConnectivity client_connectivity = 20; + repeated string addresses = 1; + string id = 2; + string user = 3; + string name = 4; + string hostname = 5; + string client_version = 6; + bool update_available = 7; + string os = 8; + google.protobuf.Timestamp created = 9; + google.protobuf.Timestamp last_seen = 10; + bool key_expiry_disabled = 11; + google.protobuf.Timestamp expires = 12; + bool authorized = 13; + bool is_external = 14; + string machine_key = 15; + string node_key = 16; + bool blocks_incoming_connections = 17; + repeated string enabled_routes = 18; + repeated string advertised_routes = 19; + ClientConnectivity client_connectivity = 20; } -message DeleteDeviceRequest { - string id = 1; -} +message DeleteDeviceRequest { string id = 1; } -message DeleteDeviceResponse { -} +message DeleteDeviceResponse {} -message GetDeviceRoutesRequest { - string id = 1; -} +message GetDeviceRoutesRequest { string id = 1; } message GetDeviceRoutesResponse { - repeated string enabled_routes = 1; - repeated string advertised_routes = 2; + repeated string enabled_routes = 1; + repeated string advertised_routes = 2; } message EnableDeviceRoutesRequest { - string id = 1; - repeated string routes = 2; + string id = 1; + repeated string routes = 2; } message EnableDeviceRoutesResponse { - repeated string enabled_routes = 1; - repeated string advertised_routes = 2; + repeated string enabled_routes = 1; + repeated string advertised_routes = 2; } diff --git a/proto/headscale/v1/headscale.proto 
b/proto/headscale/v1/headscale.proto index e113c192..5e556255 100644 --- a/proto/headscale/v1/headscale.proto +++ b/proto/headscale/v1/headscale.proto @@ -1,207 +1,225 @@ syntax = "proto3"; package headscale.v1; -option go_package = "github.com/juanfont/headscale/gen/go/v1"; +option go_package = "github.com/juanfont/headscale/gen/go/v1"; import "google/api/annotations.proto"; import "headscale/v1/user.proto"; import "headscale/v1/preauthkey.proto"; import "headscale/v1/node.proto"; -import "headscale/v1/routes.proto"; import "headscale/v1/apikey.proto"; -// import "headscale/v1/device.proto"; +import "headscale/v1/policy.proto"; service HeadscaleService { - // --- User start --- - rpc GetUser(GetUserRequest) returns(GetUserResponse) { - option(google.api.http) = { - get : "/api/v1/user/{name}" - }; - } + // --- User start --- + rpc CreateUser(CreateUserRequest) returns (CreateUserResponse) { + option (google.api.http) = { + post : "/api/v1/user" + body : "*" + }; + } - rpc CreateUser(CreateUserRequest) returns(CreateUserResponse) { - option(google.api.http) = { - post : "/api/v1/user" - body : "*" - }; - } + rpc RenameUser(RenameUserRequest) returns (RenameUserResponse) { + option (google.api.http) = { + post : "/api/v1/user/{old_id}/rename/{new_name}" + }; + } - rpc RenameUser(RenameUserRequest) returns(RenameUserResponse) { - option(google.api.http) = { - post : "/api/v1/user/{old_name}/rename/{new_name}" - }; - } + rpc DeleteUser(DeleteUserRequest) returns (DeleteUserResponse) { + option (google.api.http) = { + delete : "/api/v1/user/{id}" + }; + } - rpc DeleteUser(DeleteUserRequest) returns(DeleteUserResponse) { - option(google.api.http) = { - delete : "/api/v1/user/{name}" - }; - } + rpc ListUsers(ListUsersRequest) returns (ListUsersResponse) { + option (google.api.http) = { + get : "/api/v1/user" + }; + } + // --- User end --- - rpc ListUsers(ListUsersRequest) returns(ListUsersResponse) { - option(google.api.http) = { - get : "/api/v1/user" - }; - } - // --- User end --- + // --- PreAuthKeys start --- + rpc CreatePreAuthKey(CreatePreAuthKeyRequest) + returns (CreatePreAuthKeyResponse) { + option (google.api.http) = { + post : "/api/v1/preauthkey" + body : "*" + }; + } - // --- PreAuthKeys start --- - rpc CreatePreAuthKey(CreatePreAuthKeyRequest) returns(CreatePreAuthKeyResponse) { - option(google.api.http) = { - post : "/api/v1/preauthkey" - body : "*" - }; - } + rpc ExpirePreAuthKey(ExpirePreAuthKeyRequest) + returns (ExpirePreAuthKeyResponse) { + option (google.api.http) = { + post : "/api/v1/preauthkey/expire" + body : "*" + }; + } - rpc ExpirePreAuthKey(ExpirePreAuthKeyRequest) returns(ExpirePreAuthKeyResponse) { - option(google.api.http) = { - post : "/api/v1/preauthkey/expire" - body : "*" - }; - } + rpc DeletePreAuthKey(DeletePreAuthKeyRequest) + returns (DeletePreAuthKeyResponse) { + option (google.api.http) = { + delete : "/api/v1/preauthkey" + }; + } - rpc ListPreAuthKeys(ListPreAuthKeysRequest) returns(ListPreAuthKeysResponse) { - option(google.api.http) = { - get : "/api/v1/preauthkey" - }; - } - // --- PreAuthKeys end --- + rpc ListPreAuthKeys(ListPreAuthKeysRequest) + returns (ListPreAuthKeysResponse) { + option (google.api.http) = { + get : "/api/v1/preauthkey" + }; + } + // --- PreAuthKeys end --- - // --- Node start --- - rpc DebugCreateNode(DebugCreateNodeRequest) returns(DebugCreateNodeResponse) { - option(google.api.http) = { - post : "/api/v1/debug/node" - body : "*" - }; - } + // --- Node start --- + rpc DebugCreateNode(DebugCreateNodeRequest) + returns 
(DebugCreateNodeResponse) { + option (google.api.http) = { + post : "/api/v1/debug/node" + body : "*" + }; + } - rpc GetNode(GetNodeRequest) returns(GetNodeResponse) { - option(google.api.http) = { - get : "/api/v1/node/{node_id}" - }; - } + rpc GetNode(GetNodeRequest) returns (GetNodeResponse) { + option (google.api.http) = { + get : "/api/v1/node/{node_id}" + }; + } - rpc SetTags(SetTagsRequest) returns(SetTagsResponse) { - option(google.api.http) = { - post : "/api/v1/node/{node_id}/tags" - body : "*" - }; - } + rpc SetTags(SetTagsRequest) returns (SetTagsResponse) { + option (google.api.http) = { + post : "/api/v1/node/{node_id}/tags" + body : "*" + }; + } - rpc RegisterNode(RegisterNodeRequest) returns(RegisterNodeResponse) { - option(google.api.http) = { - post : "/api/v1/node/register" - }; - } + rpc SetApprovedRoutes(SetApprovedRoutesRequest) + returns (SetApprovedRoutesResponse) { + option (google.api.http) = { + post : "/api/v1/node/{node_id}/approve_routes" + body : "*" + }; + } - rpc DeleteNode(DeleteNodeRequest) returns(DeleteNodeResponse) { - option(google.api.http) = { - delete : "/api/v1/node/{node_id}" - }; - } + rpc RegisterNode(RegisterNodeRequest) returns (RegisterNodeResponse) { + option (google.api.http) = { + post : "/api/v1/node/register" + }; + } - rpc ExpireNode(ExpireNodeRequest) returns(ExpireNodeResponse) { - option(google.api.http) = { - post : "/api/v1/node/{node_id}/expire" - }; - } + rpc DeleteNode(DeleteNodeRequest) returns (DeleteNodeResponse) { + option (google.api.http) = { + delete : "/api/v1/node/{node_id}" + }; + } - rpc RenameNode(RenameNodeRequest) returns(RenameNodeResponse) { - option(google.api.http) = { - post : "/api/v1/node/{node_id}/rename/{new_name}" - }; - } + rpc ExpireNode(ExpireNodeRequest) returns (ExpireNodeResponse) { + option (google.api.http) = { + post : "/api/v1/node/{node_id}/expire" + }; + } - rpc ListNodes(ListNodesRequest) returns(ListNodesResponse) { - option(google.api.http) = { - get : "/api/v1/node" - }; - } + rpc RenameNode(RenameNodeRequest) returns (RenameNodeResponse) { + option (google.api.http) = { + post : "/api/v1/node/{node_id}/rename/{new_name}" + }; + } - rpc MoveNode(MoveNodeRequest) returns(MoveNodeResponse) { - option(google.api.http) = { - post : "/api/v1/node/{node_id}/user" - }; - } - // --- Node end --- + rpc ListNodes(ListNodesRequest) returns (ListNodesResponse) { + option (google.api.http) = { + get : "/api/v1/node" + }; + } - // --- Route start --- - rpc GetRoutes(GetRoutesRequest) returns(GetRoutesResponse) { - option(google.api.http) = { - get : "/api/v1/routes" - }; - } + rpc BackfillNodeIPs(BackfillNodeIPsRequest) + returns (BackfillNodeIPsResponse) { + option (google.api.http) = { + post : "/api/v1/node/backfillips" + }; + } - rpc EnableRoute(EnableRouteRequest) returns(EnableRouteResponse) { - option(google.api.http) = { - post : "/api/v1/routes/{route_id}/enable" - }; - } + // --- Node end --- - rpc DisableRoute(DisableRouteRequest) returns(DisableRouteResponse) { - option(google.api.http) = { - post : "/api/v1/routes/{route_id}/disable" - }; - } + // --- ApiKeys start --- + rpc CreateApiKey(CreateApiKeyRequest) returns (CreateApiKeyResponse) { + option (google.api.http) = { + post : "/api/v1/apikey" + body : "*" + }; + } - rpc GetNodeRoutes(GetNodeRoutesRequest) returns(GetNodeRoutesResponse) { - option(google.api.http) = { - get : "/api/v1/node/{node_id}/routes" - }; - } + rpc ExpireApiKey(ExpireApiKeyRequest) returns (ExpireApiKeyResponse) { + option (google.api.http) = { + post : 
"/api/v1/apikey/expire" + body : "*" + }; + } - rpc DeleteRoute(DeleteRouteRequest) returns(DeleteRouteResponse) { - option(google.api.http) = { - delete : "/api/v1/routes/{route_id}" - }; - } + rpc ListApiKeys(ListApiKeysRequest) returns (ListApiKeysResponse) { + option (google.api.http) = { + get : "/api/v1/apikey" + }; + } - // --- Route end --- + rpc DeleteApiKey(DeleteApiKeyRequest) returns (DeleteApiKeyResponse) { + option (google.api.http) = { + delete : "/api/v1/apikey/{prefix}" + }; + } + // --- ApiKeys end --- - // --- ApiKeys start --- - rpc CreateApiKey(CreateApiKeyRequest) returns(CreateApiKeyResponse) { - option(google.api.http) = { - post : "/api/v1/apikey" - body : "*" - }; - } + // --- Policy start --- + rpc GetPolicy(GetPolicyRequest) returns (GetPolicyResponse) { + option (google.api.http) = { + get : "/api/v1/policy" + }; + } - rpc ExpireApiKey(ExpireApiKeyRequest) returns(ExpireApiKeyResponse) { - option(google.api.http) = { - post : "/api/v1/apikey/expire" - body : "*" - }; - } + rpc SetPolicy(SetPolicyRequest) returns (SetPolicyResponse) { + option (google.api.http) = { + put : "/api/v1/policy" + body : "*" + }; + } + // --- Policy end --- - rpc ListApiKeys(ListApiKeysRequest) returns(ListApiKeysResponse) { - option(google.api.http) = { - get : "/api/v1/apikey" - }; - } - // --- ApiKeys end --- + // --- Health start --- + rpc Health(HealthRequest) returns (HealthResponse) { + option (google.api.http) = { + get : "/api/v1/health" + }; + } + // --- Health end --- - // Implement Tailscale API - // rpc GetDevice(GetDeviceRequest) returns(GetDeviceResponse) { - // option(google.api.http) = { - // get : "/api/v1/device/{id}" - // }; - // } + // Implement Tailscale API + // rpc GetDevice(GetDeviceRequest) returns(GetDeviceResponse) { + // option(google.api.http) = { + // get : "/api/v1/device/{id}" + // }; + // } - // rpc DeleteDevice(DeleteDeviceRequest) returns(DeleteDeviceResponse) { - // option(google.api.http) = { - // delete : "/api/v1/device/{id}" - // }; - // } + // rpc DeleteDevice(DeleteDeviceRequest) returns(DeleteDeviceResponse) { + // option(google.api.http) = { + // delete : "/api/v1/device/{id}" + // }; + // } - // rpc GetDeviceRoutes(GetDeviceRoutesRequest) returns(GetDeviceRoutesResponse) { - // option(google.api.http) = { - // get : "/api/v1/device/{id}/routes" - // }; - // } + // rpc GetDeviceRoutes(GetDeviceRoutesRequest) + // returns(GetDeviceRoutesResponse) { + // option(google.api.http) = { + // get : "/api/v1/device/{id}/routes" + // }; + // } - // rpc EnableDeviceRoutes(EnableDeviceRoutesRequest) returns(EnableDeviceRoutesResponse) { - // option(google.api.http) = { - // post : "/api/v1/device/{id}/routes" - // }; - // } + // rpc EnableDeviceRoutes(EnableDeviceRoutesRequest) + // returns(EnableDeviceRoutesResponse) { + // option(google.api.http) = { + // post : "/api/v1/device/{id}/routes" + // }; + // } +} + +message HealthRequest {} + +message HealthResponse { + bool database_connectivity = 1; } diff --git a/proto/headscale/v1/node.proto b/proto/headscale/v1/node.proto index 476aa59a..3ce83c4b 100644 --- a/proto/headscale/v1/node.proto +++ b/proto/headscale/v1/node.proto @@ -1,126 +1,142 @@ syntax = "proto3"; package headscale.v1; -option go_package = "github.com/juanfont/headscale/gen/go/v1"; import "google/protobuf/timestamp.proto"; -import "headscale/v1/user.proto"; import "headscale/v1/preauthkey.proto"; +import "headscale/v1/user.proto"; + +option go_package = "github.com/juanfont/headscale/gen/go/v1"; enum RegisterMethod { - 
REGISTER_METHOD_UNSPECIFIED = 0; - REGISTER_METHOD_AUTH_KEY = 1; - REGISTER_METHOD_CLI = 2; - REGISTER_METHOD_OIDC = 3; + REGISTER_METHOD_UNSPECIFIED = 0; + REGISTER_METHOD_AUTH_KEY = 1; + REGISTER_METHOD_CLI = 2; + REGISTER_METHOD_OIDC = 3; } message Node { - uint64 id = 1; - string machine_key = 2; - string node_key = 3; - string disco_key = 4; - repeated string ip_addresses = 5; - string name = 6; - User user = 7; + // 9: removal of last_successful_update + reserved 9; - google.protobuf.Timestamp last_seen = 8; - google.protobuf.Timestamp last_successful_update = 9; - google.protobuf.Timestamp expiry = 10; + uint64 id = 1; + string machine_key = 2; + string node_key = 3; + string disco_key = 4; + repeated string ip_addresses = 5; + string name = 6; + User user = 7; - PreAuthKey pre_auth_key = 11; + google.protobuf.Timestamp last_seen = 8; + google.protobuf.Timestamp expiry = 10; - google.protobuf.Timestamp created_at = 12; + PreAuthKey pre_auth_key = 11; - RegisterMethod register_method = 13; + google.protobuf.Timestamp created_at = 12; - reserved 14 to 17; - // google.protobuf.Timestamp updated_at = 14; - // google.protobuf.Timestamp deleted_at = 15; + RegisterMethod register_method = 13; - // bytes host_info = 15; - // bytes endpoints = 16; - // bytes enabled_routes = 17; + reserved 14 to 20; + // google.protobuf.Timestamp updated_at = 14; + // google.protobuf.Timestamp deleted_at = 15; - repeated string forced_tags = 18; - repeated string invalid_tags = 19; - repeated string valid_tags = 20; - string given_name = 21; - bool online = 22; + // bytes host_info = 15; + // bytes endpoints = 16; + // bytes enabled_routes = 17; + + // Deprecated + // repeated string forced_tags = 18; + // repeated string invalid_tags = 19; + // repeated string valid_tags = 20; + string given_name = 21; + bool online = 22; + repeated string approved_routes = 23; + repeated string available_routes = 24; + repeated string subnet_routes = 25; + repeated string tags = 26; } message RegisterNodeRequest { - string user = 1; - string key = 2; + string user = 1; + string key = 2; } message RegisterNodeResponse { - Node node = 1; + Node node = 1; } message GetNodeRequest { - uint64 node_id = 1; + uint64 node_id = 1; } message GetNodeResponse { - Node node = 1; + Node node = 1; } message SetTagsRequest { - uint64 node_id = 1; - repeated string tags = 2; + uint64 node_id = 1; + repeated string tags = 2; } message SetTagsResponse { - Node node = 1; + Node node = 1; +} + +message SetApprovedRoutesRequest { + uint64 node_id = 1; + repeated string routes = 2; +} + +message SetApprovedRoutesResponse { + Node node = 1; } message DeleteNodeRequest { - uint64 node_id = 1; + uint64 node_id = 1; } -message DeleteNodeResponse { -} +message DeleteNodeResponse {} message ExpireNodeRequest { - uint64 node_id = 1; + uint64 node_id = 1; + google.protobuf.Timestamp expiry = 2; } message ExpireNodeResponse { - Node node = 1; + Node node = 1; } message RenameNodeRequest { - uint64 node_id = 1; - string new_name = 2; + uint64 node_id = 1; + string new_name = 2; } message RenameNodeResponse { - Node node = 1; + Node node = 1; } message ListNodesRequest { - string user = 1; + string user = 1; } message ListNodesResponse { - repeated Node nodes = 1; -} - -message MoveNodeRequest { - uint64 node_id = 1; - string user = 2; -} - -message MoveNodeResponse { - Node node = 1; + repeated Node nodes = 1; } message DebugCreateNodeRequest { - string user = 1; - string key = 2; - string name = 3; - repeated string routes = 4; + string user = 1; + 
string key = 2; + string name = 3; + repeated string routes = 4; } message DebugCreateNodeResponse { - Node node = 1; + Node node = 1; +} + +message BackfillNodeIPsRequest { + bool confirmed = 1; +} + +message BackfillNodeIPsResponse { + repeated string changes = 1; } diff --git a/proto/headscale/v1/policy.proto b/proto/headscale/v1/policy.proto new file mode 100644 index 00000000..6c52c01f --- /dev/null +++ b/proto/headscale/v1/policy.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; +package headscale.v1; +option go_package = "github.com/juanfont/headscale/gen/go/v1"; + +import "google/protobuf/timestamp.proto"; + +message SetPolicyRequest { string policy = 1; } + +message SetPolicyResponse { + string policy = 1; + google.protobuf.Timestamp updated_at = 2; +} + +message GetPolicyRequest {} + +message GetPolicyResponse { + string policy = 1; + google.protobuf.Timestamp updated_at = 2; +} diff --git a/proto/headscale/v1/preauthkey.proto b/proto/headscale/v1/preauthkey.proto index 7d0de294..04e88821 100644 --- a/proto/headscale/v1/preauthkey.proto +++ b/proto/headscale/v1/preauthkey.proto @@ -1,45 +1,49 @@ syntax = "proto3"; package headscale.v1; -option go_package = "github.com/juanfont/headscale/gen/go/v1"; import "google/protobuf/timestamp.proto"; +import "headscale/v1/user.proto"; + +option go_package = "github.com/juanfont/headscale/gen/go/v1"; message PreAuthKey { - string user = 1; - string id = 2; - string key = 3; - bool reusable = 4; - bool ephemeral = 5; - bool used = 6; - google.protobuf.Timestamp expiration = 7; - google.protobuf.Timestamp created_at = 8; - repeated string acl_tags = 9; + User user = 1; + uint64 id = 2; + string key = 3; + bool reusable = 4; + bool ephemeral = 5; + bool used = 6; + google.protobuf.Timestamp expiration = 7; + google.protobuf.Timestamp created_at = 8; + repeated string acl_tags = 9; } message CreatePreAuthKeyRequest { - string user = 1; - bool reusable = 2; - bool ephemeral = 3; - google.protobuf.Timestamp expiration = 4; - repeated string acl_tags = 5; + uint64 user = 1; + bool reusable = 2; + bool ephemeral = 3; + google.protobuf.Timestamp expiration = 4; + repeated string acl_tags = 5; } message CreatePreAuthKeyResponse { - PreAuthKey pre_auth_key = 1; + PreAuthKey pre_auth_key = 1; } message ExpirePreAuthKeyRequest { - string user = 1; - string key = 2; + uint64 id = 1; } -message ExpirePreAuthKeyResponse { +message ExpirePreAuthKeyResponse {} + +message DeletePreAuthKeyRequest { + uint64 id = 1; } -message ListPreAuthKeysRequest { - string user = 1; -} +message DeletePreAuthKeyResponse {} + +message ListPreAuthKeysRequest {} message ListPreAuthKeysResponse { - repeated PreAuthKey pre_auth_keys = 1; + repeated PreAuthKey pre_auth_keys = 1; } diff --git a/proto/headscale/v1/routes.proto b/proto/headscale/v1/routes.proto deleted file mode 100644 index ea900259..00000000 --- a/proto/headscale/v1/routes.proto +++ /dev/null @@ -1,55 +0,0 @@ -syntax = "proto3"; -package headscale.v1; -option go_package = "github.com/juanfont/headscale/gen/go/v1"; - -import "google/protobuf/timestamp.proto"; -import "headscale/v1/node.proto"; - -message Route { - uint64 id = 1; - Node node = 2; - string prefix = 3; - bool advertised = 4; - bool enabled = 5; - bool is_primary = 6; - - google.protobuf.Timestamp created_at = 7; - google.protobuf.Timestamp updated_at = 8; - google.protobuf.Timestamp deleted_at = 9; -} - -message GetRoutesRequest { -} - -message GetRoutesResponse { - repeated Route routes = 1; -} - -message EnableRouteRequest { - uint64 route_id = 1; -} - 
-message EnableRouteResponse { -} - -message DisableRouteRequest { - uint64 route_id = 1; -} - -message DisableRouteResponse { -} - -message GetNodeRoutesRequest { - uint64 node_id = 1; -} - -message GetNodeRoutesResponse { - repeated Route routes = 1; -} - -message DeleteRouteRequest { - uint64 route_id = 1; -} - -message DeleteRouteResponse { -} diff --git a/proto/headscale/v1/user.proto b/proto/headscale/v1/user.proto index 4bc3c886..bd71bcb1 100644 --- a/proto/headscale/v1/user.proto +++ b/proto/headscale/v1/user.proto @@ -1,50 +1,44 @@ syntax = "proto3"; package headscale.v1; -option go_package = "github.com/juanfont/headscale/gen/go/v1"; +option go_package = "github.com/juanfont/headscale/gen/go/v1"; import "google/protobuf/timestamp.proto"; message User { - string id = 1; - string name = 2; - google.protobuf.Timestamp created_at = 3; -} - -message GetUserRequest { - string name = 1; -} - -message GetUserResponse { - User user = 1; + uint64 id = 1; + string name = 2; + google.protobuf.Timestamp created_at = 3; + string display_name = 4; + string email = 5; + string provider_id = 6; + string provider = 7; + string profile_pic_url = 8; } message CreateUserRequest { - string name = 1; + string name = 1; + string display_name = 2; + string email = 3; + string picture_url = 4; } -message CreateUserResponse { - User user = 1; -} +message CreateUserResponse { User user = 1; } message RenameUserRequest { - string old_name = 1; - string new_name = 2; + uint64 old_id = 1; + string new_name = 2; } -message RenameUserResponse { - User user = 1; -} +message RenameUserResponse { User user = 1; } -message DeleteUserRequest { - string name = 1; -} +message DeleteUserRequest { uint64 id = 1; } -message DeleteUserResponse { -} +message DeleteUserResponse {} message ListUsersRequest { + uint64 id = 1; + string name = 2; + string email = 3; } -message ListUsersResponse { - repeated User users = 1; -} +message ListUsersResponse { repeated User users = 1; } diff --git a/swagger.go b/swagger.go index 306fc1f6..fa764568 100644 --- a/swagger.go +++ b/swagger.go @@ -20,7 +20,7 @@ func SwaggerUI( <html> <head> <link rel="stylesheet" type="text/css" href="https://unpkg.com/swagger-ui-dist@3/swagger-ui.css"> - + <link rel="icon" href="/favicon.ico"> <script src="https://unpkg.com/swagger-ui-dist@3/swagger-ui-standalone-preset.js"></script> <script src="https://unpkg.com/swagger-ui-dist@3/swagger-ui-bundle.js" charset="UTF-8"></script> </head> @@ -57,6 +57,7 @@ func SwaggerUI( writer.Header().Set("Content-Type", "text/plain; charset=utf-8") writer.WriteHeader(http.StatusInternalServerError) + _, err := writer.Write([]byte("Could not render Swagger")) if err != nil { log.Error(). @@ -70,6 +71,7 @@ func SwaggerUI( writer.Header().Set("Content-Type", "text/html; charset=utf-8") writer.WriteHeader(http.StatusOK) + _, err := writer.Write(payload.Bytes()) if err != nil { log.Error(). @@ -85,6 +87,7 @@ func SwaggerAPIv1( ) { writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) + if _, err := writer.Write(apiV1JSON); err != nil { log.Error(). Caller(). 
diff --git a/tools/capver/main.go b/tools/capver/main.go new file mode 100644 index 00000000..80468c4a --- /dev/null +++ b/tools/capver/main.go @@ -0,0 +1,485 @@ +package main + +//go:generate go run main.go + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "go/format" + "io" + "log" + "net/http" + "os" + "regexp" + "slices" + "sort" + "strconv" + "strings" + + xmaps "golang.org/x/exp/maps" + "tailscale.com/tailcfg" +) + +const ( + ghcrTokenURL = "https://ghcr.io/token?service=ghcr.io&scope=repository:tailscale/tailscale:pull" //nolint:gosec + ghcrTagsURL = "https://ghcr.io/v2/tailscale/tailscale/tags/list?n=10000" + rawFileURL = "https://github.com/tailscale/tailscale/raw/refs/tags/%s/tailcfg/tailcfg.go" + outputFile = "../../hscontrol/capver/capver_generated.go" + testFile = "../../hscontrol/capver/capver_test_data.go" + fallbackCapVer = 90 + maxTestCases = 4 + supportedMajorMinorVersions = 10 + filePermissions = 0o600 + semverMatchGroups = 4 + latest3Count = 3 + latest2Count = 2 +) + +var errUnexpectedStatusCode = errors.New("unexpected status code") + +// GHCRTokenResponse represents the response from GHCR token endpoint. +type GHCRTokenResponse struct { + Token string `json:"token"` +} + +// GHCRTagsResponse represents the response from GHCR tags list endpoint. +type GHCRTagsResponse struct { + Name string `json:"name"` + Tags []string `json:"tags"` +} + +// getGHCRToken fetches an anonymous token from GHCR for accessing public container images. +func getGHCRToken(ctx context.Context) (string, error) { + client := &http.Client{} + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ghcrTokenURL, nil) + if err != nil { + return "", fmt.Errorf("error creating token request: %w", err) + } + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("error fetching GHCR token: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("%w: %d", errUnexpectedStatusCode, resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("error reading token response: %w", err) + } + + var tokenResp GHCRTokenResponse + + err = json.Unmarshal(body, &tokenResp) + if err != nil { + return "", fmt.Errorf("error parsing token response: %w", err) + } + + return tokenResp.Token, nil +} + +// getGHCRTags fetches all available tags from GHCR for tailscale/tailscale. +func getGHCRTags(ctx context.Context) ([]string, error) { + token, err := getGHCRToken(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get GHCR token: %w", err) + } + + client := &http.Client{} + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, ghcrTagsURL, nil) + if err != nil { + return nil, fmt.Errorf("error creating tags request: %w", err) + } + + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := client.Do(req) + if err != nil { + return nil, fmt.Errorf("error fetching tags: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("%w: %d", errUnexpectedStatusCode, resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("error reading tags response: %w", err) + } + + var tagsResp GHCRTagsResponse + + err = json.Unmarshal(body, &tagsResp) + if err != nil { + return nil, fmt.Errorf("error parsing tags response: %w", err) + } + + return tagsResp.Tags, nil +} + +// semverRegex matches semantic version tags like v1.90.0 or v1.90.1. 
+var semverRegex = regexp.MustCompile(`^v(\d+)\.(\d+)\.(\d+)$`) + +// parseSemver extracts major, minor, patch from a semver tag. +// Returns -1 for all values if not a valid semver. +func parseSemver(tag string) (int, int, int) { + matches := semverRegex.FindStringSubmatch(tag) + if len(matches) != semverMatchGroups { + return -1, -1, -1 + } + + major, _ := strconv.Atoi(matches[1]) + minor, _ := strconv.Atoi(matches[2]) + patch, _ := strconv.Atoi(matches[3]) + + return major, minor, patch +} + +// getMinorVersionsFromTags processes container tags and returns a map of minor versions +// to the first available patch version for each minor. +// For example: {"v1.90": "v1.90.0", "v1.92": "v1.92.0"}. +func getMinorVersionsFromTags(tags []string) map[string]string { + // Map minor version (e.g., "v1.90") to lowest patch version available + minorToLowestPatch := make(map[string]struct { + patch int + fullVer string + }) + + for _, tag := range tags { + major, minor, patch := parseSemver(tag) + if major < 0 { + continue // Not a semver tag + } + + minorKey := fmt.Sprintf("v%d.%d", major, minor) + + existing, exists := minorToLowestPatch[minorKey] + if !exists || patch < existing.patch { + minorToLowestPatch[minorKey] = struct { + patch int + fullVer string + }{ + patch: patch, + fullVer: tag, + } + } + } + + // Convert to simple map + result := make(map[string]string) + for minorVer, info := range minorToLowestPatch { + result[minorVer] = info.fullVer + } + + return result +} + +// getCapabilityVersions fetches container tags from GHCR, identifies minor versions, +// and fetches the capability version for each from the Tailscale source. +func getCapabilityVersions(ctx context.Context) (map[string]tailcfg.CapabilityVersion, error) { + // Fetch container tags from GHCR + tags, err := getGHCRTags(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get container tags: %w", err) + } + + log.Printf("Found %d container tags", len(tags)) + + // Get minor versions with their representative patch versions + minorVersions := getMinorVersionsFromTags(tags) + log.Printf("Found %d minor versions", len(minorVersions)) + + // Regular expression to find the CurrentCapabilityVersion line + re := regexp.MustCompile(`const CurrentCapabilityVersion CapabilityVersion = (\d+)`) + + versions := make(map[string]tailcfg.CapabilityVersion) + client := &http.Client{} + + for minorVer, patchVer := range minorVersions { + // Fetch the raw Go file for the patch version + rawURL := fmt.Sprintf(rawFileURL, patchVer) + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, rawURL, nil) //nolint:gosec + if err != nil { + log.Printf("Warning: failed to create request for %s: %v", patchVer, err) + continue + } + + resp, err := client.Do(req) + if err != nil { + log.Printf("Warning: failed to fetch %s: %v", patchVer, err) + continue + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + log.Printf("Warning: got status %d for %s", resp.StatusCode, patchVer) + continue + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Printf("Warning: failed to read response for %s: %v", patchVer, err) + continue + } + + // Find the CurrentCapabilityVersion + matches := re.FindStringSubmatch(string(body)) + if len(matches) > 1 { + capabilityVersionStr := matches[1] + capabilityVersion, _ := strconv.Atoi(capabilityVersionStr) + versions[minorVer] = tailcfg.CapabilityVersion(capabilityVersion) + log.Printf(" %s (from %s): capVer %d", minorVer, patchVer, capabilityVersion) + } + } + + return versions, nil 
+} + +func calculateMinSupportedCapabilityVersion(versions map[string]tailcfg.CapabilityVersion) tailcfg.CapabilityVersion { + // Since we now store minor versions directly, just sort and take the oldest of the latest N + minorVersions := xmaps.Keys(versions) + sort.Strings(minorVersions) + + supportedCount := min(len(minorVersions), supportedMajorMinorVersions) + + if supportedCount == 0 { + return fallbackCapVer + } + + // The minimum supported version is the oldest of the latest 10 + oldestSupportedMinor := minorVersions[len(minorVersions)-supportedCount] + + return versions[oldestSupportedMinor] +} + +func writeCapabilityVersionsToFile(versions map[string]tailcfg.CapabilityVersion, minSupportedCapVer tailcfg.CapabilityVersion) error { + // Generate the Go code as a string + var content strings.Builder + content.WriteString("package capver\n\n") + content.WriteString("// Generated DO NOT EDIT\n\n") + content.WriteString(`import "tailscale.com/tailcfg"`) + content.WriteString("\n\n") + content.WriteString("var tailscaleToCapVer = map[string]tailcfg.CapabilityVersion{\n") + + sortedVersions := xmaps.Keys(versions) + sort.Strings(sortedVersions) + + for _, version := range sortedVersions { + fmt.Fprintf(&content, "\t\"%s\": %d,\n", version, versions[version]) + } + + content.WriteString("}\n") + + content.WriteString("\n\n") + content.WriteString("var capVerToTailscaleVer = map[tailcfg.CapabilityVersion]string{\n") + + capVarToTailscaleVer := make(map[tailcfg.CapabilityVersion]string) + + for _, v := range sortedVersions { + capabilityVersion := versions[v] + + // If it is already set, skip and continue, + // we only want the first tailscale version per + // capability version. + if _, ok := capVarToTailscaleVer[capabilityVersion]; ok { + continue + } + + capVarToTailscaleVer[capabilityVersion] = v + } + + capsSorted := xmaps.Keys(capVarToTailscaleVer) + slices.Sort(capsSorted) + + for _, capVer := range capsSorted { + fmt.Fprintf(&content, "\t%d:\t\t\"%s\",\n", capVer, capVarToTailscaleVer[capVer]) + } + + content.WriteString("}\n\n") + + // Add the SupportedMajorMinorVersions constant + content.WriteString("// SupportedMajorMinorVersions is the number of major.minor Tailscale versions supported.\n") + fmt.Fprintf(&content, "const SupportedMajorMinorVersions = %d\n\n", supportedMajorMinorVersions) + + // Add the MinSupportedCapabilityVersion constant + content.WriteString("// MinSupportedCapabilityVersion represents the minimum capability version\n") + content.WriteString("// supported by this Headscale instance (latest 10 minor versions)\n") + fmt.Fprintf(&content, "const MinSupportedCapabilityVersion tailcfg.CapabilityVersion = %d\n", minSupportedCapVer) + + // Format the generated code + formatted, err := format.Source([]byte(content.String())) + if err != nil { + return fmt.Errorf("error formatting Go code: %w", err) + } + + // Write to file + err = os.WriteFile(outputFile, formatted, filePermissions) + if err != nil { + return fmt.Errorf("error writing file: %w", err) + } + + return nil +} + +func writeTestDataFile(versions map[string]tailcfg.CapabilityVersion, minSupportedCapVer tailcfg.CapabilityVersion) error { + // Sort minor versions + minorVersions := xmaps.Keys(versions) + sort.Strings(minorVersions) + + // Take latest N + supportedCount := min(len(minorVersions), supportedMajorMinorVersions) + + latest10 := minorVersions[len(minorVersions)-supportedCount:] + latest3 := minorVersions[len(minorVersions)-min(latest3Count, len(minorVersions)):] + latest2 := 
minorVersions[len(minorVersions)-min(latest2Count, len(minorVersions)):] + + // Generate test data file content + var content strings.Builder + content.WriteString("package capver\n\n") + content.WriteString("// Generated DO NOT EDIT\n\n") + content.WriteString("import \"tailscale.com/tailcfg\"\n\n") + + // Generate complete test struct for TailscaleLatestMajorMinor + content.WriteString("var tailscaleLatestMajorMinorTests = []struct {\n") + content.WriteString("\tn int\n") + content.WriteString("\tstripV bool\n") + content.WriteString("\texpected []string\n") + content.WriteString("}{\n") + + // Latest 3 with v prefix + content.WriteString("\t{3, false, []string{") + + for i, version := range latest3 { + content.WriteString(fmt.Sprintf("\"%s\"", version)) + + if i < len(latest3)-1 { + content.WriteString(", ") + } + } + + content.WriteString("}},\n") + + // Latest 2 without v prefix + content.WriteString("\t{2, true, []string{") + + for i, version := range latest2 { + // Strip v prefix for this test case + verNoV := strings.TrimPrefix(version, "v") + content.WriteString(fmt.Sprintf("\"%s\"", verNoV)) + + if i < len(latest2)-1 { + content.WriteString(", ") + } + } + + content.WriteString("}},\n") + + // Latest N without v prefix (all supported) + content.WriteString(fmt.Sprintf("\t{%d, true, []string{\n", supportedMajorMinorVersions)) + + for _, version := range latest10 { + verNoV := strings.TrimPrefix(version, "v") + content.WriteString(fmt.Sprintf("\t\t\"%s\",\n", verNoV)) + } + + content.WriteString("\t}},\n") + + // Empty case + content.WriteString("\t{0, false, nil},\n") + content.WriteString("}\n\n") + + // Build capVerToTailscaleVer for test data + capVerToTailscaleVer := make(map[tailcfg.CapabilityVersion]string) + sortedVersions := xmaps.Keys(versions) + sort.Strings(sortedVersions) + + for _, v := range sortedVersions { + capabilityVersion := versions[v] + if _, ok := capVerToTailscaleVer[capabilityVersion]; !ok { + capVerToTailscaleVer[capabilityVersion] = v + } + } + + // Generate complete test struct for CapVerMinimumTailscaleVersion + content.WriteString("var capVerMinimumTailscaleVersionTests = []struct {\n") + content.WriteString("\tinput tailcfg.CapabilityVersion\n") + content.WriteString("\texpected string\n") + content.WriteString("}{\n") + + // Add minimum supported version + minVersionString := capVerToTailscaleVer[minSupportedCapVer] + content.WriteString(fmt.Sprintf("\t{%d, \"%s\"},\n", minSupportedCapVer, minVersionString)) + + // Add a few more test cases + capsSorted := xmaps.Keys(capVerToTailscaleVer) + slices.Sort(capsSorted) + + testCount := 0 + for _, capVer := range capsSorted { + if testCount >= maxTestCases { + break + } + + if capVer != minSupportedCapVer { // Don't duplicate the min version test + version := capVerToTailscaleVer[capVer] + content.WriteString(fmt.Sprintf("\t{%d, \"%s\"},\n", capVer, version)) + + testCount++ + } + } + + // Edge cases + content.WriteString("\t{9001, \"\"}, // Test case for a version higher than any in the map\n") + content.WriteString("\t{60, \"\"}, // Test case for a version lower than any in the map\n") + content.WriteString("}\n") + + // Format the generated code + formatted, err := format.Source([]byte(content.String())) + if err != nil { + return fmt.Errorf("error formatting test data Go code: %w", err) + } + + // Write to file + err = os.WriteFile(testFile, formatted, filePermissions) + if err != nil { + return fmt.Errorf("error writing test data file: %w", err) + } + + return nil +} + +func main() { + ctx := 
context.Background() + + versions, err := getCapabilityVersions(ctx) + if err != nil { + log.Println("Error:", err) + return + } + + // Calculate the minimum supported capability version + minSupportedCapVer := calculateMinSupportedCapabilityVersion(versions) + + err = writeCapabilityVersionsToFile(versions, minSupportedCapVer) + if err != nil { + log.Println("Error writing to file:", err) + return + } + + err = writeTestDataFile(versions, minSupportedCapVer) + if err != nil { + log.Println("Error writing test data file:", err) + return + } + + log.Println("Capability versions written to", outputFile) +}
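
Note on the capability-version generator added above: `tools/capver/main.go` carries a `//go:generate go run main.go` directive and writes `hscontrol/capver/capver_generated.go` plus `hscontrol/capver/capver_test_data.go`, so the tables should be refreshable with a plain `go generate` run from that package. The sketch below is a minimal, self-contained illustration of the shape of the generated `capVerToTailscaleVer` table and of an exact-match lookup over it; it is not code from this patch, the map values are hypothetical samples, and the real helpers in `hscontrol/capver` may implement different (for example range-based) semantics.

```go
package main

import "fmt"

// CapabilityVersion stands in for tailcfg.CapabilityVersion (an integer protocol version).
type CapabilityVersion int

// Hypothetical sample of the generated capVerToTailscaleVer table: for each
// capability version, the first Tailscale minor release recorded for it.
var capVerToTailscaleVer = map[CapabilityVersion]string{
	106: "v1.80",
	109: "v1.82",
	113: "v1.84",
	115: "v1.86",
}

// tailscaleVersionFor returns the Tailscale release recorded for an exact
// capability version, or "" when the version is unknown (whether lower or
// higher than anything in the table), matching the "" edge cases exercised
// by the generated test data in this patch.
func tailscaleVersionFor(v CapabilityVersion) string {
	return capVerToTailscaleVer[v] // a missing key yields the zero value ""
}

func main() {
	fmt.Println(tailscaleVersionFor(113))  // "v1.84" with the sample data above
	fmt.Println(tailscaleVersionFor(9001)) // "" — not present in the table
}
```

Because the generator derives its version list from the GHCR container tags for tailscale/tailscale (rather than from Git tags), only releases with published container images end up in the generated tables.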