From 898bb72568a49a8b2b923dc8b8e7b3fea640c8d6 Mon Sep 17 00:00:00 2001 From: Janis Jansons Date: Mon, 12 Jan 2026 02:02:33 +0200 Subject: [PATCH] ACL testing (#1803) --- CHANGELOG.md | 4 + cmd/headscale/cli/policy.go | 356 ++++- docs/ref/acls.md | 159 +++ gen/go/headscale/v1/headscale.pb.go | 113 +- gen/go/headscale/v1/headscale.pb.gw.go | 63 + gen/go/headscale/v1/headscale_grpc.pb.go | 38 + gen/go/headscale/v1/policy.pb.go | 329 ++++- .../headscale/v1/headscale.swagger.json | 141 ++ hscontrol/grpcv1.go | 85 ++ hscontrol/policy/pm.go | 6 + hscontrol/policy/v2/policy.go | 10 + hscontrol/policy/v2/test.go | 385 ++++++ hscontrol/policy/v2/test_test.go | 1199 +++++++++++++++++ hscontrol/policy/v2/types.go | 1 + hscontrol/state/state.go | 11 + proto/headscale/v1/headscale.proto | 7 + proto/headscale/v1/policy.proto | 45 + 17 files changed, 2886 insertions(+), 66 deletions(-) create mode 100644 hscontrol/policy/v2/test.go create mode 100644 hscontrol/policy/v2/test_test.go diff --git a/CHANGELOG.md b/CHANGELOG.md index 0ef22ff2..2b7e6c2a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -73,6 +73,10 @@ sequentially through each stable release, selecting the latest patch version ava - Log ACME/autocert errors for easier debugging [#2933](https://github.com/juanfont/headscale/pull/2933) - Improve CLI list output formatting [#2951](https://github.com/juanfont/headscale/pull/2951) - Use Debian 13 distroless base images for containers [#2944](https://github.com/juanfont/headscale/pull/2944) +- Add ACL testing functionality via CLI, API, and embedded policy tests [#3005](https://github.com/juanfont/headscale/pull/3005) + - New `headscale policy test` command to verify ACL rules + - New `POST /api/v1/policy/test` endpoint for third-party UI integration + - Support for `tests` section in policy files with automatic validation on policy updates - Fix ACL policy not applied to new OIDC nodes until client restart [#2890](https://github.com/juanfont/headscale/pull/2890) - Fix autogroup:self preventing visibility of nodes matched by other ACL rules [#2882](https://github.com/juanfont/headscale/pull/2882) - Fix nodes being rejected after pre-authentication key expiration [#2917](https://github.com/juanfont/headscale/pull/2917) diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index f99d5390..73e2ff01 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -1,13 +1,16 @@ package cli import ( + "encoding/json" "fmt" "io" "os" + "strings" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/policy" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -16,7 +19,10 @@ import ( ) const ( - bypassFlag = "bypass-grpc-and-access-database-directly" + bypassFlag = "bypass-grpc-and-access-database-directly" //nolint:gosec // Not credentials + separatorWidth = 50 + outputFormatJSON = "json" + outputFormatJSONLine = "json-line" ) func init() { @@ -37,6 +43,17 @@ func init() { log.Fatal().Err(err).Msg("") } policyCmd.AddCommand(checkPolicy) + + // Test command flags + testPolicy.Flags().StringP("src", "s", "", "Source alias to test from (user, group, tag, host, or IP)") + testPolicy.Flags().StringSliceP("accept", "a", nil, "Destinations that should be allowed (repeatable, format: host:port)") + testPolicy.Flags().StringSliceP("deny", "d", nil, "Destinations that should be denied (repeatable, format: host:port)") + testPolicy.Flags().StringP("proto", "p", "", "Protocol to test (tcp, udp, icmp)") + testPolicy.Flags().StringP("file", "f", "", "Path to a JSON file with test definitions") + testPolicy.Flags().StringP("policy-file", "", "", "Test against a proposed policy file instead of current policy") + testPolicy.Flags().BoolP("embedded", "e", false, "Run tests embedded in the current policy") + testPolicy.Flags().BoolP(bypassFlag, "", false, "Uses the headscale config to directly access the database, bypassing gRPC and does not require the server to be running") + policyCmd.AddCommand(testPolicy) } var policyCmd = &cobra.Command{ @@ -210,3 +227,340 @@ var checkPolicy = &cobra.Command{ SuccessOutput(nil, "Policy is valid", "") }, } + +var testPolicy = &cobra.Command{ + Use: "test", + Short: "Test ACL rules", + Long: `Test ACL rules to verify access between sources and destinations. + +Examples: + # Test if user can access server + headscale policy test --src "alice@example.com" --accept "tag:server:22" + + # Test with deny rules + headscale policy test --src "alice@" --accept "10.0.0.1:80" --deny "10.0.0.2:443" + + # Run tests from a JSON file + headscale policy test --file tests.json + + # Run embedded tests from current policy + headscale policy test --embedded + + # Test against a proposed policy file + headscale policy test --src "alice@" --accept "10.0.0.1:22" --policy-file new-policy.json`, + Run: func(cmd *cobra.Command, args []string) { + output, _ := cmd.Flags().GetString("output") + + // Collect tests from various sources + var tests []policyv2.ACLTest + + // Get flags + src, _ := cmd.Flags().GetString("src") + accept, _ := cmd.Flags().GetStringSlice("accept") + deny, _ := cmd.Flags().GetStringSlice("deny") + proto, _ := cmd.Flags().GetString("proto") + testFile, _ := cmd.Flags().GetString("file") + policyFile, _ := cmd.Flags().GetString("policy-file") + embedded, _ := cmd.Flags().GetBool("embedded") + bypass, _ := cmd.Flags().GetBool(bypassFlag) + + // Build test from command line flags if src is provided + if src != "" { + tests = append(tests, policyv2.ACLTest{ + Src: src, + Proto: policyv2.Protocol(proto), + Accept: accept, + Deny: deny, + }) + } + + // Load tests from file if provided + if testFile != "" { + fileTests, err := loadTestsFromFile(testFile) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error loading tests from file: %s", err), output) + return + } + tests = append(tests, fileTests...) + } + + // Read policy file if provided (for testing against proposed policy) + var policyBytes []byte + if policyFile != "" { + f, err := os.Open(policyFile) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error opening policy file: %s", err), output) + return + } + defer f.Close() + + policyBytes, err = io.ReadAll(f) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Error reading policy file: %s", err), output) + return + } + } + + var results policyv2.ACLTestResults + + if bypass { + results = runTestsBypass(cmd, output, tests, policyBytes, embedded) + } else { + results = runTestsGRPC(cmd, output, tests, policyBytes, embedded) + } + + // Output results + if output == outputFormatJSON || output == outputFormatJSONLine { + SuccessOutput(results, "", output) + } else { + printHumanReadableResults(results) + } + }, +} + +func loadTestsFromFile(path string) ([]policyv2.ACLTest, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + var tests []policyv2.ACLTest + + decoder := json.NewDecoder(f) + + err = decoder.Decode(&tests) + if err != nil { + return nil, err + } + + return tests, nil +} + +func runTestsBypass(cmd *cobra.Command, output string, tests []policyv2.ACLTest, policyBytes []byte, embedded bool) policyv2.ACLTestResults { + confirm := false + + force, _ := cmd.Flags().GetBool("force") + if !force { + confirm = util.YesNo("DO NOT run this command if an instance of headscale is running, are you sure headscale is not running?") + } + + if !confirm && !force { + ErrorOutput(nil, "Aborting command", output) + return policyv2.ACLTestResults{} + } + + cfg, err := types.LoadServerConfig() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed loading config: %s", err), output) + return policyv2.ACLTestResults{} + } + + d, err := db.NewHeadscaleDatabase( + cfg.Database, + cfg.BaseDomain, + nil, + ) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to open database: %s", err), output) + return policyv2.ACLTestResults{} + } + + users, err := d.ListUsers() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to load users: %s", err), output) + return policyv2.ACLTestResults{} + } + + nodes, err := d.ListNodes() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to load nodes: %s", err), output) + return policyv2.ACLTestResults{} + } + + // Convert nodes to NodeView slice + nodeViews := make([]types.NodeView, len(nodes)) + for i, n := range nodes { + nodeViews[i] = n.View() + } + + // Determine which policy to test against + var polBytes []byte + if len(policyBytes) > 0 { + polBytes = policyBytes + } else { + pol, err := d.GetPolicy() + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to load policy: %s", err), output) + return policyv2.ACLTestResults{} + } + + polBytes = []byte(pol.Data) + } + + pm, err := policyv2.NewPolicyManager(polBytes, users, views.SliceOf(nodeViews)) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to parse policy: %s", err), output) + return policyv2.ACLTestResults{} + } + + // If embedded flag is set, get tests from the policy + if embedded { + pol := pm.Policy() + if pol != nil && len(pol.Tests) > 0 { + tests = append(tests, pol.Tests...) + } + } + + if len(tests) == 0 { + ErrorOutput(nil, "No tests to run. Use --src, --file, or --embedded to specify tests.", output) + return policyv2.ACLTestResults{} + } + + return pm.RunTests(tests) +} + +func runTestsGRPC(_ *cobra.Command, output string, tests []policyv2.ACLTest, policyBytes []byte, embedded bool) policyv2.ACLTestResults { + ctx, client, conn, cancel := newHeadscaleCLIWithConfig() + defer cancel() + defer conn.Close() + + // If embedded, get tests from current policy first + if embedded { + policyResp, err := client.GetPolicy(ctx, &v1.GetPolicyRequest{}) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to get current policy: %s", err), output) + return policyv2.ACLTestResults{} + } + + // Parse policy to extract embedded tests + pm, err := policyv2.NewPolicyManager([]byte(policyResp.GetPolicy()), nil, views.Slice[types.NodeView]{}) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to parse policy: %s", err), output) + return policyv2.ACLTestResults{} + } + + pol := pm.Policy() + if pol != nil && len(pol.Tests) > 0 { + tests = append(tests, pol.Tests...) + } + } + + if len(tests) == 0 { + ErrorOutput(nil, "No tests to run. Use --src, --file, or --embedded to specify tests.", output) + return policyv2.ACLTestResults{} + } + + // Convert tests to proto format + protoTests := make([]*v1.ACLTest, len(tests)) + for i, t := range tests { + protoTests[i] = &v1.ACLTest{ + Src: t.Src, + Proto: string(t.Proto), + Accept: t.Accept, + Deny: t.Deny, + } + } + + request := &v1.TestACLRequest{ + Tests: protoTests, + Policy: string(policyBytes), + } + + response, err := client.TestACL(ctx, request) + if err != nil { + ErrorOutput(err, fmt.Sprintf("Failed to run ACL tests: %s", err), output) + return policyv2.ACLTestResults{} + } + + // Convert proto response to internal format + results := policyv2.ACLTestResults{ + AllPassed: response.GetAllPassed(), + Results: make([]policyv2.ACLTestResult, len(response.GetResults())), + } + + for i, r := range response.GetResults() { + results.Results[i] = policyv2.ACLTestResult{ + Src: r.GetSrc(), + Passed: r.GetPassed(), + Errors: r.GetErrors(), + AcceptOK: r.GetAcceptOk(), + AcceptFail: r.GetAcceptFail(), + DenyOK: r.GetDenyOk(), + DenyFail: r.GetDenyFail(), + } + } + + return results +} + +func printHumanReadableResults(results policyv2.ACLTestResults) { + fmt.Println("ACL Test Results") + fmt.Println(strings.Repeat("=", separatorWidth)) + fmt.Println() + + passedCount := 0 + totalCount := len(results.Results) + + for _, result := range results.Results { + fmt.Printf("Source: %s\n", result.Src) + fmt.Println() + + if len(result.Errors) > 0 { + fmt.Println(" Errors:") + + for _, e := range result.Errors { + fmt.Printf(" ! %s\n", e) + } + + fmt.Println() + } + + if len(result.AcceptOK) > 0 || len(result.AcceptFail) > 0 { + fmt.Println(" Accept Tests:") + + for _, dest := range result.AcceptOK { + fmt.Printf(" [PASS] %s - ALLOWED (expected)\n", dest) + } + + for _, dest := range result.AcceptFail { + fmt.Printf(" [FAIL] %s - DENIED (expected ALLOWED)\n", dest) + } + + fmt.Println() + } + + if len(result.DenyOK) > 0 || len(result.DenyFail) > 0 { + fmt.Println(" Deny Tests:") + + for _, dest := range result.DenyOK { + fmt.Printf(" [PASS] %s - DENIED (expected)\n", dest) + } + + for _, dest := range result.DenyFail { + fmt.Printf(" [FAIL] %s - ALLOWED (expected DENIED)\n", dest) + } + + fmt.Println() + } + + if result.Passed { + passedCount++ + + fmt.Println(" Result: PASSED") + } else { + fmt.Println(" Result: FAILED") + } + + fmt.Println() + fmt.Println(strings.Repeat("-", separatorWidth)) + fmt.Println() + } + + // Summary + if results.AllPassed { + fmt.Printf("Overall: PASSED (%d/%d tests passed)\n", passedCount, totalCount) + } else { + fmt.Printf("Overall: FAILED (%d/%d tests passed)\n", passedCount, totalCount) + } +} diff --git a/docs/ref/acls.md b/docs/ref/acls.md index 3368ab61..3f3b87c8 100644 --- a/docs/ref/acls.md +++ b/docs/ref/acls.md @@ -285,3 +285,162 @@ Used in Tailscale SSH rules to allow access to any user except root. Can only be "users": ["autogroup:nonroot"] } ``` + +## Testing ACLs + +Headscale provides ACL testing functionality to verify that your policy rules work as expected. You can test ACLs using embedded tests in your policy file or via the CLI. + +### Embedded Tests in Policy + +You can include a `tests` section in your policy file to define test cases that are automatically validated when the policy is loaded or updated. **If any embedded test fails, the policy update will be rejected**, providing regression protection when modifying ACL rules. + +```json +{ + "groups": { + "group:dev": ["dev1@", "dev2@"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:dev"], + "dst": ["tag:dev-servers:*"] + } + ], + "tests": [ + { + "src": "dev1@", + "accept": ["tag:dev-servers:22", "tag:dev-servers:80"] + }, + { + "src": "dev1@", + "deny": ["tag:prod-servers:22"] + }, + { + "src": "group:dev", + "proto": "tcp", + "accept": ["tag:dev-servers:443"] + } + ] +} +``` + +Each test case supports the following fields: + +| Field | Description | +|----------|----------------------------------------------------------| +| `src` | Source alias to test from (user, group, tag, host, or IP) | +| `accept` | List of destinations that should be **allowed** (format: `host:port`) | +| `deny` | List of destinations that should be **denied** (format: `host:port`) | +| `proto` | Optional protocol filter (`tcp`, `udp`, `icmp`) | + +### CLI Testing + +The `headscale policy test` command allows you to test ACL rules without modifying your policy. + +#### Test Specific Access + +```bash +# Test if a user can access a server on port 22 +headscale policy test --src "alice@example.com" --accept "tag:server:22" + +# Test with multiple destinations +headscale policy test --src "group:dev" --accept "tag:dev:22" --accept "tag:dev:80" + +# Test both allowed and denied access +headscale policy test --src "alice@" --accept "10.0.0.1:80" --deny "10.0.0.2:443" + +# Test with protocol filter +headscale policy test --src "tag:monitoring" --proto tcp --accept "tag:servers:9090" +``` + +#### Run Embedded Tests + +```bash +# Run all tests defined in the current policy's tests section +headscale policy test --embedded +``` + +#### Test a Proposed Policy + +Before applying a new policy, you can test it without affecting the running configuration: + +```bash +# Test against a proposed policy file +headscale policy test --src "alice@" --accept "server:22" --policy-file new-acl.json + +# Run embedded tests from a proposed policy file +headscale policy test --embedded --policy-file new-acl.json +``` + +#### Test from a File + +You can define multiple tests in a JSON file: + +```bash +headscale policy test --file tests.json +``` + +Where `tests.json` contains: + +```json +[ + { + "src": "alice@example.com", + "accept": ["server1:22", "server2:80"], + "deny": ["database:5432"] + }, + { + "src": "tag:ci", + "accept": ["tag:staging:*"] + } +] +``` + +#### Output Formats + +By default, the CLI shows human-readable output. + +For programmatic use, JSON output is available: + +```bash +headscale policy test --src "alice@" --accept "server:22" --output json +``` + +### API Endpoint + +Third-party UIs can use the gRPC/HTTP API to test ACL rules: + +**Endpoint:** `POST /api/v1/policy/test` + +**Request:** + +```json +{ + "tests": [ + { + "src": "alice@example.com", + "accept": ["server1:22"], + "deny": ["database:5432"] + } + ], + "policy": "" +} +``` + +The optional `policy` field allows testing against a proposed policy instead of the current active policy. If empty, tests run against the current policy. + +**Response:** + +```json +{ + "all_passed": true, + "results": [ + { + "src": "alice@example.com", + "passed": true, + "accept_ok": ["server1:22"], + "deny_ok": ["database:5432"] + } + ] +} +``` diff --git a/gen/go/headscale/v1/headscale.pb.go b/gen/go/headscale/v1/headscale.pb.go index 3d16778c..4f4ab82a 100644 --- a/gen/go/headscale/v1/headscale.pb.go +++ b/gen/go/headscale/v1/headscale.pb.go @@ -109,7 +109,7 @@ const file_headscale_v1_headscale_proto_rawDesc = "" + "\x1cheadscale/v1/headscale.proto\x12\fheadscale.v1\x1a\x1cgoogle/api/annotations.proto\x1a\x17headscale/v1/user.proto\x1a\x1dheadscale/v1/preauthkey.proto\x1a\x17headscale/v1/node.proto\x1a\x19headscale/v1/apikey.proto\x1a\x19headscale/v1/policy.proto\"\x0f\n" + "\rHealthRequest\"E\n" + "\x0eHealthResponse\x123\n" + - "\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\x8c\x17\n" + + "\x15database_connectivity\x18\x01 \x01(\bR\x14databaseConnectivity2\xf4\x17\n" + "\x10HeadscaleService\x12h\n" + "\n" + "CreateUser\x12\x1f.headscale.v1.CreateUserRequest\x1a .headscale.v1.CreateUserResponse\"\x17\x82\xd3\xe4\x93\x02\x11:\x01*\"\f/api/v1/user\x12\x80\x01\n" + @@ -140,7 +140,8 @@ const file_headscale_v1_headscale_proto_rawDesc = "" + "\vListApiKeys\x12 .headscale.v1.ListApiKeysRequest\x1a!.headscale.v1.ListApiKeysResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/apikey\x12v\n" + "\fDeleteApiKey\x12!.headscale.v1.DeleteApiKeyRequest\x1a\".headscale.v1.DeleteApiKeyResponse\"\x1f\x82\xd3\xe4\x93\x02\x19*\x17/api/v1/apikey/{prefix}\x12d\n" + "\tGetPolicy\x12\x1e.headscale.v1.GetPolicyRequest\x1a\x1f.headscale.v1.GetPolicyResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/policy\x12g\n" + - "\tSetPolicy\x12\x1e.headscale.v1.SetPolicyRequest\x1a\x1f.headscale.v1.SetPolicyResponse\"\x19\x82\xd3\xe4\x93\x02\x13:\x01*\x1a\x0e/api/v1/policy\x12[\n" + + "\tSetPolicy\x12\x1e.headscale.v1.SetPolicyRequest\x1a\x1f.headscale.v1.SetPolicyResponse\"\x19\x82\xd3\xe4\x93\x02\x13:\x01*\x1a\x0e/api/v1/policy\x12f\n" + + "\aTestACL\x12\x1c.headscale.v1.TestACLRequest\x1a\x1d.headscale.v1.TestACLResponse\"\x1e\x82\xd3\xe4\x93\x02\x18:\x01*\"\x13/api/v1/policy/test\x12[\n" + "\x06Health\x12\x1b.headscale.v1.HealthRequest\x1a\x1c.headscale.v1.HealthResponse\"\x16\x82\xd3\xe4\x93\x02\x10\x12\x0e/api/v1/healthB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" var ( @@ -183,30 +184,32 @@ var file_headscale_v1_headscale_proto_goTypes = []any{ (*DeleteApiKeyRequest)(nil), // 23: headscale.v1.DeleteApiKeyRequest (*GetPolicyRequest)(nil), // 24: headscale.v1.GetPolicyRequest (*SetPolicyRequest)(nil), // 25: headscale.v1.SetPolicyRequest - (*CreateUserResponse)(nil), // 26: headscale.v1.CreateUserResponse - (*RenameUserResponse)(nil), // 27: headscale.v1.RenameUserResponse - (*DeleteUserResponse)(nil), // 28: headscale.v1.DeleteUserResponse - (*ListUsersResponse)(nil), // 29: headscale.v1.ListUsersResponse - (*CreatePreAuthKeyResponse)(nil), // 30: headscale.v1.CreatePreAuthKeyResponse - (*ExpirePreAuthKeyResponse)(nil), // 31: headscale.v1.ExpirePreAuthKeyResponse - (*DeletePreAuthKeyResponse)(nil), // 32: headscale.v1.DeletePreAuthKeyResponse - (*ListPreAuthKeysResponse)(nil), // 33: headscale.v1.ListPreAuthKeysResponse - (*DebugCreateNodeResponse)(nil), // 34: headscale.v1.DebugCreateNodeResponse - (*GetNodeResponse)(nil), // 35: headscale.v1.GetNodeResponse - (*SetTagsResponse)(nil), // 36: headscale.v1.SetTagsResponse - (*SetApprovedRoutesResponse)(nil), // 37: headscale.v1.SetApprovedRoutesResponse - (*RegisterNodeResponse)(nil), // 38: headscale.v1.RegisterNodeResponse - (*DeleteNodeResponse)(nil), // 39: headscale.v1.DeleteNodeResponse - (*ExpireNodeResponse)(nil), // 40: headscale.v1.ExpireNodeResponse - (*RenameNodeResponse)(nil), // 41: headscale.v1.RenameNodeResponse - (*ListNodesResponse)(nil), // 42: headscale.v1.ListNodesResponse - (*BackfillNodeIPsResponse)(nil), // 43: headscale.v1.BackfillNodeIPsResponse - (*CreateApiKeyResponse)(nil), // 44: headscale.v1.CreateApiKeyResponse - (*ExpireApiKeyResponse)(nil), // 45: headscale.v1.ExpireApiKeyResponse - (*ListApiKeysResponse)(nil), // 46: headscale.v1.ListApiKeysResponse - (*DeleteApiKeyResponse)(nil), // 47: headscale.v1.DeleteApiKeyResponse - (*GetPolicyResponse)(nil), // 48: headscale.v1.GetPolicyResponse - (*SetPolicyResponse)(nil), // 49: headscale.v1.SetPolicyResponse + (*TestACLRequest)(nil), // 26: headscale.v1.TestACLRequest + (*CreateUserResponse)(nil), // 27: headscale.v1.CreateUserResponse + (*RenameUserResponse)(nil), // 28: headscale.v1.RenameUserResponse + (*DeleteUserResponse)(nil), // 29: headscale.v1.DeleteUserResponse + (*ListUsersResponse)(nil), // 30: headscale.v1.ListUsersResponse + (*CreatePreAuthKeyResponse)(nil), // 31: headscale.v1.CreatePreAuthKeyResponse + (*ExpirePreAuthKeyResponse)(nil), // 32: headscale.v1.ExpirePreAuthKeyResponse + (*DeletePreAuthKeyResponse)(nil), // 33: headscale.v1.DeletePreAuthKeyResponse + (*ListPreAuthKeysResponse)(nil), // 34: headscale.v1.ListPreAuthKeysResponse + (*DebugCreateNodeResponse)(nil), // 35: headscale.v1.DebugCreateNodeResponse + (*GetNodeResponse)(nil), // 36: headscale.v1.GetNodeResponse + (*SetTagsResponse)(nil), // 37: headscale.v1.SetTagsResponse + (*SetApprovedRoutesResponse)(nil), // 38: headscale.v1.SetApprovedRoutesResponse + (*RegisterNodeResponse)(nil), // 39: headscale.v1.RegisterNodeResponse + (*DeleteNodeResponse)(nil), // 40: headscale.v1.DeleteNodeResponse + (*ExpireNodeResponse)(nil), // 41: headscale.v1.ExpireNodeResponse + (*RenameNodeResponse)(nil), // 42: headscale.v1.RenameNodeResponse + (*ListNodesResponse)(nil), // 43: headscale.v1.ListNodesResponse + (*BackfillNodeIPsResponse)(nil), // 44: headscale.v1.BackfillNodeIPsResponse + (*CreateApiKeyResponse)(nil), // 45: headscale.v1.CreateApiKeyResponse + (*ExpireApiKeyResponse)(nil), // 46: headscale.v1.ExpireApiKeyResponse + (*ListApiKeysResponse)(nil), // 47: headscale.v1.ListApiKeysResponse + (*DeleteApiKeyResponse)(nil), // 48: headscale.v1.DeleteApiKeyResponse + (*GetPolicyResponse)(nil), // 49: headscale.v1.GetPolicyResponse + (*SetPolicyResponse)(nil), // 50: headscale.v1.SetPolicyResponse + (*TestACLResponse)(nil), // 51: headscale.v1.TestACLResponse } var file_headscale_v1_headscale_proto_depIdxs = []int32{ 2, // 0: headscale.v1.HeadscaleService.CreateUser:input_type -> headscale.v1.CreateUserRequest @@ -233,34 +236,36 @@ var file_headscale_v1_headscale_proto_depIdxs = []int32{ 23, // 21: headscale.v1.HeadscaleService.DeleteApiKey:input_type -> headscale.v1.DeleteApiKeyRequest 24, // 22: headscale.v1.HeadscaleService.GetPolicy:input_type -> headscale.v1.GetPolicyRequest 25, // 23: headscale.v1.HeadscaleService.SetPolicy:input_type -> headscale.v1.SetPolicyRequest - 0, // 24: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest - 26, // 25: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse - 27, // 26: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse - 28, // 27: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse - 29, // 28: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse - 30, // 29: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse - 31, // 30: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse - 32, // 31: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse - 33, // 32: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse - 34, // 33: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse - 35, // 34: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse - 36, // 35: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse - 37, // 36: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse - 38, // 37: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse - 39, // 38: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse - 40, // 39: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse - 41, // 40: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse - 42, // 41: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse - 43, // 42: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse - 44, // 43: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse - 45, // 44: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse - 46, // 45: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse - 47, // 46: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse - 48, // 47: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse - 49, // 48: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse - 1, // 49: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse - 25, // [25:50] is the sub-list for method output_type - 0, // [0:25] is the sub-list for method input_type + 26, // 24: headscale.v1.HeadscaleService.TestACL:input_type -> headscale.v1.TestACLRequest + 0, // 25: headscale.v1.HeadscaleService.Health:input_type -> headscale.v1.HealthRequest + 27, // 26: headscale.v1.HeadscaleService.CreateUser:output_type -> headscale.v1.CreateUserResponse + 28, // 27: headscale.v1.HeadscaleService.RenameUser:output_type -> headscale.v1.RenameUserResponse + 29, // 28: headscale.v1.HeadscaleService.DeleteUser:output_type -> headscale.v1.DeleteUserResponse + 30, // 29: headscale.v1.HeadscaleService.ListUsers:output_type -> headscale.v1.ListUsersResponse + 31, // 30: headscale.v1.HeadscaleService.CreatePreAuthKey:output_type -> headscale.v1.CreatePreAuthKeyResponse + 32, // 31: headscale.v1.HeadscaleService.ExpirePreAuthKey:output_type -> headscale.v1.ExpirePreAuthKeyResponse + 33, // 32: headscale.v1.HeadscaleService.DeletePreAuthKey:output_type -> headscale.v1.DeletePreAuthKeyResponse + 34, // 33: headscale.v1.HeadscaleService.ListPreAuthKeys:output_type -> headscale.v1.ListPreAuthKeysResponse + 35, // 34: headscale.v1.HeadscaleService.DebugCreateNode:output_type -> headscale.v1.DebugCreateNodeResponse + 36, // 35: headscale.v1.HeadscaleService.GetNode:output_type -> headscale.v1.GetNodeResponse + 37, // 36: headscale.v1.HeadscaleService.SetTags:output_type -> headscale.v1.SetTagsResponse + 38, // 37: headscale.v1.HeadscaleService.SetApprovedRoutes:output_type -> headscale.v1.SetApprovedRoutesResponse + 39, // 38: headscale.v1.HeadscaleService.RegisterNode:output_type -> headscale.v1.RegisterNodeResponse + 40, // 39: headscale.v1.HeadscaleService.DeleteNode:output_type -> headscale.v1.DeleteNodeResponse + 41, // 40: headscale.v1.HeadscaleService.ExpireNode:output_type -> headscale.v1.ExpireNodeResponse + 42, // 41: headscale.v1.HeadscaleService.RenameNode:output_type -> headscale.v1.RenameNodeResponse + 43, // 42: headscale.v1.HeadscaleService.ListNodes:output_type -> headscale.v1.ListNodesResponse + 44, // 43: headscale.v1.HeadscaleService.BackfillNodeIPs:output_type -> headscale.v1.BackfillNodeIPsResponse + 45, // 44: headscale.v1.HeadscaleService.CreateApiKey:output_type -> headscale.v1.CreateApiKeyResponse + 46, // 45: headscale.v1.HeadscaleService.ExpireApiKey:output_type -> headscale.v1.ExpireApiKeyResponse + 47, // 46: headscale.v1.HeadscaleService.ListApiKeys:output_type -> headscale.v1.ListApiKeysResponse + 48, // 47: headscale.v1.HeadscaleService.DeleteApiKey:output_type -> headscale.v1.DeleteApiKeyResponse + 49, // 48: headscale.v1.HeadscaleService.GetPolicy:output_type -> headscale.v1.GetPolicyResponse + 50, // 49: headscale.v1.HeadscaleService.SetPolicy:output_type -> headscale.v1.SetPolicyResponse + 51, // 50: headscale.v1.HeadscaleService.TestACL:output_type -> headscale.v1.TestACLResponse + 1, // 51: headscale.v1.HeadscaleService.Health:output_type -> headscale.v1.HealthResponse + 26, // [26:52] is the sub-list for method output_type + 0, // [0:26] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name diff --git a/gen/go/headscale/v1/headscale.pb.gw.go b/gen/go/headscale/v1/headscale.pb.gw.go index ffa6964f..6e216f1e 100644 --- a/gen/go/headscale/v1/headscale.pb.gw.go +++ b/gen/go/headscale/v1/headscale.pb.gw.go @@ -813,6 +813,30 @@ func local_request_HeadscaleService_SetPolicy_0(ctx context.Context, marshaler r return msg, metadata, err } +func request_HeadscaleService_TestACL_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq TestACLRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := client.TestACL(ctx, &protoReq, grpc.Header(&metadata.HeaderMD), grpc.Trailer(&metadata.TrailerMD)) + return msg, metadata, err +} + +func local_request_HeadscaleService_TestACL_0(ctx context.Context, marshaler runtime.Marshaler, server HeadscaleServiceServer, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { + var ( + protoReq TestACLRequest + metadata runtime.ServerMetadata + ) + if err := marshaler.NewDecoder(req.Body).Decode(&protoReq); err != nil && !errors.Is(err, io.EOF) { + return nil, metadata, status.Errorf(codes.InvalidArgument, "%v", err) + } + msg, err := server.TestACL(ctx, &protoReq) + return msg, metadata, err +} + func request_HeadscaleService_Health_0(ctx context.Context, marshaler runtime.Marshaler, client HeadscaleServiceClient, req *http.Request, pathParams map[string]string) (proto.Message, runtime.ServerMetadata, error) { var ( protoReq HealthRequest @@ -1317,6 +1341,26 @@ func RegisterHeadscaleServiceHandlerServer(ctx context.Context, mux *runtime.Ser } forward_HeadscaleService_SetPolicy_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_TestACL_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + var stream runtime.ServerTransportStream + ctx = grpc.NewContextWithServerTransportStream(ctx, &stream) + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateIncomingContext(ctx, mux, req, "/headscale.v1.HeadscaleService/TestACL", runtime.WithHTTPPathPattern("/api/v1/policy/test")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := local_request_HeadscaleService_TestACL_0(annotatedContext, inboundMarshaler, server, req, pathParams) + md.HeaderMD, md.TrailerMD = metadata.Join(md.HeaderMD, stream.Header()), metadata.Join(md.TrailerMD, stream.Trailer()) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_TestACL_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) mux.Handle(http.MethodGet, pattern_HeadscaleService_Health_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -1785,6 +1829,23 @@ func RegisterHeadscaleServiceHandlerClient(ctx context.Context, mux *runtime.Ser } forward_HeadscaleService_SetPolicy_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) }) + mux.Handle(http.MethodPost, pattern_HeadscaleService_TestACL_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { + ctx, cancel := context.WithCancel(req.Context()) + defer cancel() + inboundMarshaler, outboundMarshaler := runtime.MarshalerForRequest(mux, req) + annotatedContext, err := runtime.AnnotateContext(ctx, mux, req, "/headscale.v1.HeadscaleService/TestACL", runtime.WithHTTPPathPattern("/api/v1/policy/test")) + if err != nil { + runtime.HTTPError(ctx, mux, outboundMarshaler, w, req, err) + return + } + resp, md, err := request_HeadscaleService_TestACL_0(annotatedContext, inboundMarshaler, client, req, pathParams) + annotatedContext = runtime.NewServerMetadataContext(annotatedContext, md) + if err != nil { + runtime.HTTPError(annotatedContext, mux, outboundMarshaler, w, req, err) + return + } + forward_HeadscaleService_TestACL_0(annotatedContext, mux, outboundMarshaler, w, req, resp, mux.GetForwardResponseOptions()...) + }) mux.Handle(http.MethodGet, pattern_HeadscaleService_Health_0, func(w http.ResponseWriter, req *http.Request, pathParams map[string]string) { ctx, cancel := context.WithCancel(req.Context()) defer cancel() @@ -1830,6 +1891,7 @@ var ( pattern_HeadscaleService_DeleteApiKey_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 1, 0, 4, 1, 5, 3}, []string{"api", "v1", "apikey", "prefix"}, "")) pattern_HeadscaleService_GetPolicy_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "policy"}, "")) pattern_HeadscaleService_SetPolicy_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "policy"}, "")) + pattern_HeadscaleService_TestACL_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2, 2, 3}, []string{"api", "v1", "policy", "test"}, "")) pattern_HeadscaleService_Health_0 = runtime.MustPattern(runtime.NewPattern(1, []int{2, 0, 2, 1, 2, 2}, []string{"api", "v1", "health"}, "")) ) @@ -1858,5 +1920,6 @@ var ( forward_HeadscaleService_DeleteApiKey_0 = runtime.ForwardResponseMessage forward_HeadscaleService_GetPolicy_0 = runtime.ForwardResponseMessage forward_HeadscaleService_SetPolicy_0 = runtime.ForwardResponseMessage + forward_HeadscaleService_TestACL_0 = runtime.ForwardResponseMessage forward_HeadscaleService_Health_0 = runtime.ForwardResponseMessage ) diff --git a/gen/go/headscale/v1/headscale_grpc.pb.go b/gen/go/headscale/v1/headscale_grpc.pb.go index a3963935..b2b98fe3 100644 --- a/gen/go/headscale/v1/headscale_grpc.pb.go +++ b/gen/go/headscale/v1/headscale_grpc.pb.go @@ -43,6 +43,7 @@ const ( HeadscaleService_DeleteApiKey_FullMethodName = "/headscale.v1.HeadscaleService/DeleteApiKey" HeadscaleService_GetPolicy_FullMethodName = "/headscale.v1.HeadscaleService/GetPolicy" HeadscaleService_SetPolicy_FullMethodName = "/headscale.v1.HeadscaleService/SetPolicy" + HeadscaleService_TestACL_FullMethodName = "/headscale.v1.HeadscaleService/TestACL" HeadscaleService_Health_FullMethodName = "/headscale.v1.HeadscaleService/Health" ) @@ -79,6 +80,7 @@ type HeadscaleServiceClient interface { // --- Policy start --- GetPolicy(ctx context.Context, in *GetPolicyRequest, opts ...grpc.CallOption) (*GetPolicyResponse, error) SetPolicy(ctx context.Context, in *SetPolicyRequest, opts ...grpc.CallOption) (*SetPolicyResponse, error) + TestACL(ctx context.Context, in *TestACLRequest, opts ...grpc.CallOption) (*TestACLResponse, error) // --- Health start --- Health(ctx context.Context, in *HealthRequest, opts ...grpc.CallOption) (*HealthResponse, error) } @@ -331,6 +333,16 @@ func (c *headscaleServiceClient) SetPolicy(ctx context.Context, in *SetPolicyReq return out, nil } +func (c *headscaleServiceClient) TestACL(ctx context.Context, in *TestACLRequest, opts ...grpc.CallOption) (*TestACLResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(TestACLResponse) + err := c.cc.Invoke(ctx, HeadscaleService_TestACL_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *headscaleServiceClient) Health(ctx context.Context, in *HealthRequest, opts ...grpc.CallOption) (*HealthResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(HealthResponse) @@ -374,6 +386,7 @@ type HeadscaleServiceServer interface { // --- Policy start --- GetPolicy(context.Context, *GetPolicyRequest) (*GetPolicyResponse, error) SetPolicy(context.Context, *SetPolicyRequest) (*SetPolicyResponse, error) + TestACL(context.Context, *TestACLRequest) (*TestACLResponse, error) // --- Health start --- Health(context.Context, *HealthRequest) (*HealthResponse, error) mustEmbedUnimplementedHeadscaleServiceServer() @@ -458,6 +471,9 @@ func (UnimplementedHeadscaleServiceServer) GetPolicy(context.Context, *GetPolicy func (UnimplementedHeadscaleServiceServer) SetPolicy(context.Context, *SetPolicyRequest) (*SetPolicyResponse, error) { return nil, status.Error(codes.Unimplemented, "method SetPolicy not implemented") } +func (UnimplementedHeadscaleServiceServer) TestACL(context.Context, *TestACLRequest) (*TestACLResponse, error) { + return nil, status.Error(codes.Unimplemented, "method TestACL not implemented") +} func (UnimplementedHeadscaleServiceServer) Health(context.Context, *HealthRequest) (*HealthResponse, error) { return nil, status.Error(codes.Unimplemented, "method Health not implemented") } @@ -914,6 +930,24 @@ func _HeadscaleService_SetPolicy_Handler(srv interface{}, ctx context.Context, d return interceptor(ctx, in, info, handler) } +func _HeadscaleService_TestACL_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(TestACLRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(HeadscaleServiceServer).TestACL(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: HeadscaleService_TestACL_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(HeadscaleServiceServer).TestACL(ctx, req.(*TestACLRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _HeadscaleService_Health_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(HealthRequest) if err := dec(in); err != nil { @@ -1035,6 +1069,10 @@ var HeadscaleService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SetPolicy", Handler: _HeadscaleService_SetPolicy_Handler, }, + { + MethodName: "TestACL", + Handler: _HeadscaleService_TestACL_Handler, + }, { MethodName: "Health", Handler: _HeadscaleService_Health_Handler, diff --git a/gen/go/headscale/v1/policy.pb.go b/gen/go/headscale/v1/policy.pb.go index faa3fc40..b2dae98e 100644 --- a/gen/go/headscale/v1/policy.pb.go +++ b/gen/go/headscale/v1/policy.pb.go @@ -206,6 +206,286 @@ func (x *GetPolicyResponse) GetUpdatedAt() *timestamppb.Timestamp { return nil } +type ACLTest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Source alias (user, group, tag, host, or IP) to test from. + Src string `protobuf:"bytes,1,opt,name=src,proto3" json:"src,omitempty"` + // Protocol to test (tcp, udp, icmp). Defaults to TCP/UDP if empty. + Proto string `protobuf:"bytes,2,opt,name=proto,proto3" json:"proto,omitempty"` + // Destinations (in "host:port" format) that should be allowed. + Accept []string `protobuf:"bytes,3,rep,name=accept,proto3" json:"accept,omitempty"` + // Destinations (in "host:port" format) that should be denied. + Deny []string `protobuf:"bytes,4,rep,name=deny,proto3" json:"deny,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ACLTest) Reset() { + *x = ACLTest{} + mi := &file_headscale_v1_policy_proto_msgTypes[4] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ACLTest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ACLTest) ProtoMessage() {} + +func (x *ACLTest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_policy_proto_msgTypes[4] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ACLTest.ProtoReflect.Descriptor instead. +func (*ACLTest) Descriptor() ([]byte, []int) { + return file_headscale_v1_policy_proto_rawDescGZIP(), []int{4} +} + +func (x *ACLTest) GetSrc() string { + if x != nil { + return x.Src + } + return "" +} + +func (x *ACLTest) GetProto() string { + if x != nil { + return x.Proto + } + return "" +} + +func (x *ACLTest) GetAccept() []string { + if x != nil { + return x.Accept + } + return nil +} + +func (x *ACLTest) GetDeny() []string { + if x != nil { + return x.Deny + } + return nil +} + +type ACLTestResult struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Source alias that was tested. + Src string `protobuf:"bytes,1,opt,name=src,proto3" json:"src,omitempty"` + // Whether the test passed (all assertions correct). + Passed bool `protobuf:"varint,2,opt,name=passed,proto3" json:"passed,omitempty"` + // Errors encountered during test execution. + Errors []string `protobuf:"bytes,3,rep,name=errors,proto3" json:"errors,omitempty"` + // Destinations that were correctly allowed. + AcceptOk []string `protobuf:"bytes,4,rep,name=accept_ok,json=acceptOk,proto3" json:"accept_ok,omitempty"` + // Destinations that should have been allowed but were denied. + AcceptFail []string `protobuf:"bytes,5,rep,name=accept_fail,json=acceptFail,proto3" json:"accept_fail,omitempty"` + // Destinations that were correctly denied. + DenyOk []string `protobuf:"bytes,6,rep,name=deny_ok,json=denyOk,proto3" json:"deny_ok,omitempty"` + // Destinations that should have been denied but were allowed. + DenyFail []string `protobuf:"bytes,7,rep,name=deny_fail,json=denyFail,proto3" json:"deny_fail,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *ACLTestResult) Reset() { + *x = ACLTestResult{} + mi := &file_headscale_v1_policy_proto_msgTypes[5] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *ACLTestResult) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*ACLTestResult) ProtoMessage() {} + +func (x *ACLTestResult) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_policy_proto_msgTypes[5] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use ACLTestResult.ProtoReflect.Descriptor instead. +func (*ACLTestResult) Descriptor() ([]byte, []int) { + return file_headscale_v1_policy_proto_rawDescGZIP(), []int{5} +} + +func (x *ACLTestResult) GetSrc() string { + if x != nil { + return x.Src + } + return "" +} + +func (x *ACLTestResult) GetPassed() bool { + if x != nil { + return x.Passed + } + return false +} + +func (x *ACLTestResult) GetErrors() []string { + if x != nil { + return x.Errors + } + return nil +} + +func (x *ACLTestResult) GetAcceptOk() []string { + if x != nil { + return x.AcceptOk + } + return nil +} + +func (x *ACLTestResult) GetAcceptFail() []string { + if x != nil { + return x.AcceptFail + } + return nil +} + +func (x *ACLTestResult) GetDenyOk() []string { + if x != nil { + return x.DenyOk + } + return nil +} + +func (x *ACLTestResult) GetDenyFail() []string { + if x != nil { + return x.DenyFail + } + return nil +} + +type TestACLRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Tests to run. + Tests []*ACLTest `protobuf:"bytes,1,rep,name=tests,proto3" json:"tests,omitempty"` + // Optional: policy content to test against a proposed policy. + // If empty, tests run against the current active policy. + Policy string `protobuf:"bytes,2,opt,name=policy,proto3" json:"policy,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TestACLRequest) Reset() { + *x = TestACLRequest{} + mi := &file_headscale_v1_policy_proto_msgTypes[6] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TestACLRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TestACLRequest) ProtoMessage() {} + +func (x *TestACLRequest) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_policy_proto_msgTypes[6] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TestACLRequest.ProtoReflect.Descriptor instead. +func (*TestACLRequest) Descriptor() ([]byte, []int) { + return file_headscale_v1_policy_proto_rawDescGZIP(), []int{6} +} + +func (x *TestACLRequest) GetTests() []*ACLTest { + if x != nil { + return x.Tests + } + return nil +} + +func (x *TestACLRequest) GetPolicy() string { + if x != nil { + return x.Policy + } + return "" +} + +type TestACLResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + // Whether all tests passed. + AllPassed bool `protobuf:"varint,1,opt,name=all_passed,json=allPassed,proto3" json:"all_passed,omitempty"` + // Individual test results. + Results []*ACLTestResult `protobuf:"bytes,2,rep,name=results,proto3" json:"results,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *TestACLResponse) Reset() { + *x = TestACLResponse{} + mi := &file_headscale_v1_policy_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *TestACLResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*TestACLResponse) ProtoMessage() {} + +func (x *TestACLResponse) ProtoReflect() protoreflect.Message { + mi := &file_headscale_v1_policy_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use TestACLResponse.ProtoReflect.Descriptor instead. +func (*TestACLResponse) Descriptor() ([]byte, []int) { + return file_headscale_v1_policy_proto_rawDescGZIP(), []int{7} +} + +func (x *TestACLResponse) GetAllPassed() bool { + if x != nil { + return x.AllPassed + } + return false +} + +func (x *TestACLResponse) GetResults() []*ACLTestResult { + if x != nil { + return x.Results + } + return nil +} + var File_headscale_v1_policy_proto protoreflect.FileDescriptor const file_headscale_v1_policy_proto_rawDesc = "" + @@ -221,7 +501,28 @@ const file_headscale_v1_policy_proto_rawDesc = "" + "\x11GetPolicyResponse\x12\x16\n" + "\x06policy\x18\x01 \x01(\tR\x06policy\x129\n" + "\n" + - "updated_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\tupdatedAtB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" + "updated_at\x18\x02 \x01(\v2\x1a.google.protobuf.TimestampR\tupdatedAt\"]\n" + + "\aACLTest\x12\x10\n" + + "\x03src\x18\x01 \x01(\tR\x03src\x12\x14\n" + + "\x05proto\x18\x02 \x01(\tR\x05proto\x12\x16\n" + + "\x06accept\x18\x03 \x03(\tR\x06accept\x12\x12\n" + + "\x04deny\x18\x04 \x03(\tR\x04deny\"\xc5\x01\n" + + "\rACLTestResult\x12\x10\n" + + "\x03src\x18\x01 \x01(\tR\x03src\x12\x16\n" + + "\x06passed\x18\x02 \x01(\bR\x06passed\x12\x16\n" + + "\x06errors\x18\x03 \x03(\tR\x06errors\x12\x1b\n" + + "\taccept_ok\x18\x04 \x03(\tR\bacceptOk\x12\x1f\n" + + "\vaccept_fail\x18\x05 \x03(\tR\n" + + "acceptFail\x12\x17\n" + + "\adeny_ok\x18\x06 \x03(\tR\x06denyOk\x12\x1b\n" + + "\tdeny_fail\x18\a \x03(\tR\bdenyFail\"U\n" + + "\x0eTestACLRequest\x12+\n" + + "\x05tests\x18\x01 \x03(\v2\x15.headscale.v1.ACLTestR\x05tests\x12\x16\n" + + "\x06policy\x18\x02 \x01(\tR\x06policy\"g\n" + + "\x0fTestACLResponse\x12\x1d\n" + + "\n" + + "all_passed\x18\x01 \x01(\bR\tallPassed\x125\n" + + "\aresults\x18\x02 \x03(\v2\x1b.headscale.v1.ACLTestResultR\aresultsB)Z'github.com/juanfont/headscale/gen/go/v1b\x06proto3" var ( file_headscale_v1_policy_proto_rawDescOnce sync.Once @@ -235,22 +536,28 @@ func file_headscale_v1_policy_proto_rawDescGZIP() []byte { return file_headscale_v1_policy_proto_rawDescData } -var file_headscale_v1_policy_proto_msgTypes = make([]protoimpl.MessageInfo, 4) +var file_headscale_v1_policy_proto_msgTypes = make([]protoimpl.MessageInfo, 8) var file_headscale_v1_policy_proto_goTypes = []any{ (*SetPolicyRequest)(nil), // 0: headscale.v1.SetPolicyRequest (*SetPolicyResponse)(nil), // 1: headscale.v1.SetPolicyResponse (*GetPolicyRequest)(nil), // 2: headscale.v1.GetPolicyRequest (*GetPolicyResponse)(nil), // 3: headscale.v1.GetPolicyResponse - (*timestamppb.Timestamp)(nil), // 4: google.protobuf.Timestamp + (*ACLTest)(nil), // 4: headscale.v1.ACLTest + (*ACLTestResult)(nil), // 5: headscale.v1.ACLTestResult + (*TestACLRequest)(nil), // 6: headscale.v1.TestACLRequest + (*TestACLResponse)(nil), // 7: headscale.v1.TestACLResponse + (*timestamppb.Timestamp)(nil), // 8: google.protobuf.Timestamp } var file_headscale_v1_policy_proto_depIdxs = []int32{ - 4, // 0: headscale.v1.SetPolicyResponse.updated_at:type_name -> google.protobuf.Timestamp - 4, // 1: headscale.v1.GetPolicyResponse.updated_at:type_name -> google.protobuf.Timestamp - 2, // [2:2] is the sub-list for method output_type - 2, // [2:2] is the sub-list for method input_type - 2, // [2:2] is the sub-list for extension type_name - 2, // [2:2] is the sub-list for extension extendee - 0, // [0:2] is the sub-list for field type_name + 8, // 0: headscale.v1.SetPolicyResponse.updated_at:type_name -> google.protobuf.Timestamp + 8, // 1: headscale.v1.GetPolicyResponse.updated_at:type_name -> google.protobuf.Timestamp + 4, // 2: headscale.v1.TestACLRequest.tests:type_name -> headscale.v1.ACLTest + 5, // 3: headscale.v1.TestACLResponse.results:type_name -> headscale.v1.ACLTestResult + 4, // [4:4] is the sub-list for method output_type + 4, // [4:4] is the sub-list for method input_type + 4, // [4:4] is the sub-list for extension type_name + 4, // [4:4] is the sub-list for extension extendee + 0, // [0:4] is the sub-list for field type_name } func init() { file_headscale_v1_policy_proto_init() } @@ -264,7 +571,7 @@ func file_headscale_v1_policy_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_headscale_v1_policy_proto_rawDesc), len(file_headscale_v1_policy_proto_rawDesc)), NumEnums: 0, - NumMessages: 4, + NumMessages: 8, NumExtensions: 0, NumServices: 0, }, diff --git a/gen/openapiv2/headscale/v1/headscale.swagger.json b/gen/openapiv2/headscale/v1/headscale.swagger.json index 0791e8ca..66bcdce4 100644 --- a/gen/openapiv2/headscale/v1/headscale.swagger.json +++ b/gen/openapiv2/headscale/v1/headscale.swagger.json @@ -549,6 +549,38 @@ ] } }, + "/api/v1/policy/test": { + "post": { + "operationId": "HeadscaleService_TestACL", + "responses": { + "200": { + "description": "A successful response.", + "schema": { + "$ref": "#/definitions/v1TestACLResponse" + } + }, + "default": { + "description": "An unexpected error response.", + "schema": { + "$ref": "#/definitions/rpcStatus" + } + } + }, + "parameters": [ + { + "name": "body", + "in": "body", + "required": true, + "schema": { + "$ref": "#/definitions/v1TestACLRequest" + } + } + ], + "tags": [ + "HeadscaleService" + ] + } + }, "/api/v1/preauthkey": { "get": { "operationId": "HeadscaleService_ListPreAuthKeys", @@ -872,6 +904,81 @@ } } }, + "v1ACLTest": { + "type": "object", + "properties": { + "src": { + "type": "string", + "description": "Source alias (user, group, tag, host, or IP) to test from." + }, + "proto": { + "type": "string", + "description": "Protocol to test (tcp, udp, icmp). Defaults to TCP/UDP if empty." + }, + "accept": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Destinations (in \"host:port\" format) that should be allowed." + }, + "deny": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Destinations (in \"host:port\" format) that should be denied." + } + } + }, + "v1ACLTestResult": { + "type": "object", + "properties": { + "src": { + "type": "string", + "description": "Source alias that was tested." + }, + "passed": { + "type": "boolean", + "description": "Whether the test passed (all assertions correct)." + }, + "errors": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Errors encountered during test execution." + }, + "acceptOk": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Destinations that were correctly allowed." + }, + "acceptFail": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Destinations that should have been allowed but were denied." + }, + "denyOk": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Destinations that were correctly denied." + }, + "denyFail": { + "type": "array", + "items": { + "type": "string" + }, + "description": "Destinations that should have been denied but were allowed." + } + } + }, "v1ApiKey": { "type": "object", "properties": { @@ -1330,6 +1437,40 @@ } } }, + "v1TestACLRequest": { + "type": "object", + "properties": { + "tests": { + "type": "array", + "items": { + "type": "object", + "$ref": "#/definitions/v1ACLTest" + }, + "description": "Tests to run." + }, + "policy": { + "type": "string", + "description": "Optional: policy content to test against a proposed policy.\nIf empty, tests run against the current active policy." + } + } + }, + "v1TestACLResponse": { + "type": "object", + "properties": { + "allPassed": { + "type": "boolean", + "description": "Whether all tests passed." + }, + "results": { + "type": "array", + "items": { + "type": "object", + "$ref": "#/definitions/v1ACLTestResult" + }, + "description": "Individual test results." + } + } + }, "v1User": { "type": "object", "properties": { diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 9573d1ea..beaba863 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -26,6 +26,7 @@ import ( "tailscale.com/types/views" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/state" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" @@ -726,6 +727,30 @@ func (api headscaleV1APIServer) SetPolicy( } } + // Run embedded ACL tests if present + // Failed tests block the policy update (like Tailscale) + users, err := api.h.state.ListAllUsers() + if err != nil { + return nil, fmt.Errorf("loading users for test validation: %w", err) + } + + testPM, err := policyv2.NewPolicyManager([]byte(p), users, nodes) + if err != nil { + return nil, fmt.Errorf("parsing policy for tests: %w", err) + } + + pol := testPM.Policy() + if pol != nil && len(pol.Tests) > 0 { + results := testPM.RunTests(pol.Tests) + if !results.AllPassed { + return nil, status.Errorf(codes.InvalidArgument, + "ACL tests failed: %s", results.Errors()) + } + log.Info(). + Int("tests_passed", len(results.Results)). + Msg("All embedded ACL tests passed") + } + updated, err := api.h.state.SetPolicyInDB(p) if err != nil { return nil, err @@ -835,4 +860,64 @@ func (api headscaleV1APIServer) Health( return response, healthErr } +func (api headscaleV1APIServer) TestACL( + ctx context.Context, + request *v1.TestACLRequest, +) (*v1.TestACLResponse, error) { + if len(request.GetTests()) == 0 { + return nil, status.Error(codes.InvalidArgument, "at least one test is required") + } + + // Convert proto tests to internal ACLTest structs + tests := make([]policyv2.ACLTest, len(request.GetTests())) + for i, t := range request.GetTests() { + tests[i] = policyv2.ACLTest{ + Src: t.GetSrc(), + Proto: policyv2.Protocol(t.GetProto()), + Accept: t.GetAccept(), + Deny: t.GetDeny(), + } + } + + var results policyv2.ACLTestResults + + if request.GetPolicy() != "" { + // Test against a proposed policy + users, err := api.h.state.ListAllUsers() + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list users: %v", err) + } + nodes := api.h.state.ListNodes() + + pm, err := policyv2.NewPolicyManager([]byte(request.GetPolicy()), users, nodes) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid policy: %v", err) + } + + results = pm.RunTests(tests) + } else { + // Test against current active policy + results = api.h.state.RunACLTests(tests) + } + + // Convert results to proto response + protoResults := make([]*v1.ACLTestResult, len(results.Results)) + for i, r := range results.Results { + protoResults[i] = &v1.ACLTestResult{ + Src: r.Src, + Passed: r.Passed, + Errors: r.Errors, + AcceptOk: r.AcceptOK, + AcceptFail: r.AcceptFail, + DenyOk: r.DenyOK, + DenyFail: r.DenyFail, + } + } + + return &v1.TestACLResponse{ + AllPassed: results.AllPassed, + Results: protoResults, + }, nil +} + func (api headscaleV1APIServer) mustEmbedUnimplementedHeadscaleServiceServer() {} diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index f4db88a4..ee9128cc 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -32,6 +32,12 @@ type PolicyManager interface { // NodeCanApproveRoute reports whether the given node can approve the given route. NodeCanApproveRoute(types.NodeView, netip.Prefix) bool + // RunTests runs multiple ACL tests and returns aggregated results. + RunTests(tests []policyv2.ACLTest) policyv2.ACLTestResults + + // RunTest evaluates a single ACL test against the current policy. + RunTest(test policyv2.ACLTest) policyv2.ACLTestResult + Version() int DebugString() string } diff --git a/hscontrol/policy/v2/policy.go b/hscontrol/policy/v2/policy.go index c5d87722..a2e86f29 100644 --- a/hscontrol/policy/v2/policy.go +++ b/hscontrol/policy/v2/policy.go @@ -721,6 +721,16 @@ func (pm *PolicyManager) Version() int { return 2 } +// Policy returns the underlying Policy struct. +// This can be used to access embedded tests or other policy data. +func (pm *PolicyManager) Policy() *Policy { + if pm == nil { + return nil + } + + return pm.pol +} + func (pm *PolicyManager) DebugString() string { if pm == nil { return "PolicyManager is not setup" diff --git a/hscontrol/policy/v2/test.go b/hscontrol/policy/v2/test.go new file mode 100644 index 00000000..9309f595 --- /dev/null +++ b/hscontrol/policy/v2/test.go @@ -0,0 +1,385 @@ +package v2 + +import ( + "errors" + "fmt" + "net/netip" + "slices" + "strings" + + "github.com/juanfont/headscale/hscontrol/types" + "go4.org/netipx" + "tailscale.com/tailcfg" + "tailscale.com/types/views" +) + +var errDestinationNoIPs = errors.New("destination resolved to no IP addresses") + +// ACLTest represents a single ACL test case. +// It defines a source and lists of destinations that should be allowed or denied. +type ACLTest struct { + // Src is the source alias (user, group, tag, host, or IP) to test from. + Src string `json:"src"` + + // Proto is the protocol to test. If empty, defaults to TCP/UDP. + Proto Protocol `json:"proto,omitempty"` + + // Accept is a list of destinations (in "host:port" format) that should be allowed. + Accept []string `json:"accept,omitempty"` + + // Deny is a list of destinations (in "host:port" format) that should be denied. + Deny []string `json:"deny,omitempty"` +} + +// ACLTestResult represents the result of running a single ACL test. +type ACLTestResult struct { + // Src is the source alias that was tested. + Src string `json:"src"` + + // Proto is the protocol that was tested. Empty means TCP/UDP (default). + Proto Protocol `json:"proto,omitempty"` + + // Passed indicates whether the test passed (all assertions correct). + Passed bool `json:"passed"` + + // Errors contains any errors encountered during test execution. + Errors []string `json:"errors,omitempty"` + + // AcceptOK lists destinations that were correctly allowed. + AcceptOK []string `json:"accept_ok,omitempty"` + + // AcceptFail lists destinations that should have been allowed but were denied. + AcceptFail []string `json:"accept_fail,omitempty"` + + // DenyOK lists destinations that were correctly denied. + DenyOK []string `json:"deny_ok,omitempty"` + + // DenyFail lists destinations that should have been denied but were allowed. + DenyFail []string `json:"deny_fail,omitempty"` +} + +// ACLTestResults represents the aggregated results of running multiple ACL tests. +type ACLTestResults struct { + // AllPassed indicates whether all tests passed. + AllPassed bool `json:"all_passed"` + + // Results contains the individual test results. + Results []ACLTestResult `json:"results"` +} + +// Errors returns a combined error message from all failed tests. +// Each error is on a separate line for readability. +func (r ACLTestResults) Errors() string { + var errs []string + + for _, result := range r.Results { + if !result.Passed { + // Build protocol suffix for error messages + protoSuffix := "" + if result.Proto != "" { + protoSuffix = fmt.Sprintf(" (%s)", result.Proto) + } + + for _, e := range result.Errors { + errs = append(errs, fmt.Sprintf("%s%s: %s", result.Src, protoSuffix, e)) + } + + for _, dest := range result.AcceptFail { + errs = append(errs, fmt.Sprintf("%s -> %s%s: expected ALLOWED, got DENIED", result.Src, dest, protoSuffix)) + } + + for _, dest := range result.DenyFail { + errs = append(errs, fmt.Sprintf("%s -> %s%s: expected DENIED, got ALLOWED", result.Src, dest, protoSuffix)) + } + } + } + + return strings.Join(errs, "\n") +} + +// RunTests runs multiple ACL tests and returns aggregated results. +func (pm *PolicyManager) RunTests(tests []ACLTest) ACLTestResults { + results := ACLTestResults{ + AllPassed: true, + Results: make([]ACLTestResult, 0, len(tests)), + } + + for _, test := range tests { + result := pm.RunTest(test) + + results.Results = append(results.Results, result) + if !result.Passed { + results.AllPassed = false + } + } + + return results +} + +// RunTest evaluates a single ACL test against the current policy. +// It resolves the source alias to IPs, then checks each accept/deny destination. +func (pm *PolicyManager) RunTest(test ACLTest) ACLTestResult { + result := ACLTestResult{ + Src: test.Src, + Proto: test.Proto, + Passed: true, + } + + if pm == nil || pm.pol == nil { + result.Passed = false + result.Errors = append(result.Errors, "no policy configured") + + return result + } + + pm.mu.Lock() + defer pm.mu.Unlock() + + // Resolve the source alias to an IP set + srcIPs, srcErr := pm.resolveTestAlias(test.Src) + if srcErr != nil { + result.Passed = false + result.Errors = append(result.Errors, fmt.Sprintf("failed to resolve source %q: %v", test.Src, srcErr)) + + return result + } + + if srcIPs == nil || len(srcIPs.Prefixes()) == 0 { + result.Passed = false + result.Errors = append(result.Errors, fmt.Sprintf("source %q resolved to no IP addresses", test.Src)) + + return result + } + + // Test each destination in Accept list + for _, dest := range test.Accept { + allowed, err := pm.testAccess(srcIPs, dest, test.Proto) + if err != nil { + result.Passed = false + result.Errors = append(result.Errors, fmt.Sprintf("error testing %q: %v", dest, err)) + + continue + } + + if allowed { + result.AcceptOK = append(result.AcceptOK, dest) + } else { + result.Passed = false + result.AcceptFail = append(result.AcceptFail, dest) + } + } + + // Test each destination in Deny list + for _, dest := range test.Deny { + allowed, err := pm.testAccess(srcIPs, dest, test.Proto) + if err != nil { + result.Passed = false + result.Errors = append(result.Errors, fmt.Sprintf("error testing %q: %v", dest, err)) + + continue + } + + if !allowed { + result.DenyOK = append(result.DenyOK, dest) + } else { + result.Passed = false + result.DenyFail = append(result.DenyFail, dest) + } + } + + return result +} + +// resolveTestAlias resolves a test alias string to an IP set. +// It supports all standard alias types: user, group, tag, host, prefix, and autogroup. +func (pm *PolicyManager) resolveTestAlias(aliasStr string) (*netipx.IPSet, error) { + alias, err := parseAlias(aliasStr) + if err != nil { + return nil, fmt.Errorf("invalid alias: %w", err) + } + + ipSet, err := alias.Resolve(pm.pol, pm.users, pm.nodes) + if err != nil { + return nil, fmt.Errorf("failed to resolve alias: %w", err) + } + + return ipSet, nil +} + +// testAccess checks if traffic from srcIPs to the destination is allowed. +// The destination is in "host:port" format (e.g., "server:22" or "10.0.0.1:80"). +func (pm *PolicyManager) testAccess(srcIPs *netipx.IPSet, dest string, proto Protocol) (bool, error) { + // Parse the destination as AliasWithPorts + destWithPorts, err := pm.parseDestination(dest) + if err != nil { + return false, fmt.Errorf("invalid destination %q: %w", dest, err) + } + + // Resolve destination alias to IPs + destIPs, err := destWithPorts.Resolve(pm.pol, pm.users, pm.nodes) + if err != nil { + return false, fmt.Errorf("failed to resolve destination: %w", err) + } + + if destIPs == nil || len(destIPs.Prefixes()) == 0 { + return false, errDestinationNoIPs + } + + // Check access using the matchers + // We need to check if any rule allows srcIPs to reach destIPs + return pm.checkMatcherAccess(srcIPs, destIPs, destWithPorts.Ports, proto), nil +} + +// parseDestination parses a destination string in "host:port" format. +func (pm *PolicyManager) parseDestination(dest string) (*AliasWithPorts, error) { + var awp AliasWithPorts + + // Use the existing AliasWithPorts unmarshaling logic + err := awp.UnmarshalJSON([]byte(`"` + dest + `"`)) + if err != nil { + return nil, err + } + + return &awp, nil +} + +// checkMatcherAccess checks if access is allowed from srcIPs to destIPs for the given ports. +// It uses the compiled filter rules (not just matchers) to properly check port restrictions. +// The proto parameter filters by protocol - rules with IPProto set will only match if the +// requested protocol is in the rule's protocol list. +func (pm *PolicyManager) checkMatcherAccess(srcIPs, destIPs *netipx.IPSet, ports []tailcfg.PortRange, proto Protocol) bool { + // Get source prefixes + srcPrefixes := srcIPs.Prefixes() + if len(srcPrefixes) == 0 { + return false + } + + // ALL source prefixes must have access to the destination + // If any source prefix cannot reach the destination, return false + for _, srcPrefix := range srcPrefixes { + if !pm.checkSingleSourceAccess(srcPrefix, destIPs, ports, proto) { + return false + } + } + + return true +} + +// checkSingleSourceAccess checks if a single source prefix has access to the destination. +func (pm *PolicyManager) checkSingleSourceAccess(srcPrefix netip.Prefix, destIPs *netipx.IPSet, ports []tailcfg.PortRange, proto Protocol) bool { + // Check against filter rules (which include port information) + for _, rule := range pm.filter { + // Check if this source prefix matches the rule's source IPs + srcMatches := false + + for _, ruleSrcIP := range rule.SrcIPs { + // Parse the rule's source IP as a prefix + rulePrefix, err := netip.ParsePrefix(ruleSrcIP) + if err != nil { + // Try parsing as single IP + ruleAddr, err := netip.ParseAddr(ruleSrcIP) + if err != nil { + continue + } + + rulePrefix = netip.PrefixFrom(ruleAddr, ruleAddr.BitLen()) + } + + // Check if the source prefix overlaps with the rule's source + if srcPrefix.Overlaps(rulePrefix) { + srcMatches = true + + break + } + } + + if !srcMatches { + continue + } + + // Check if protocol matches + // If the rule has IPProto set, only match if the requested protocol is in the list + if len(rule.IPProto) > 0 { + requestedProtos, _ := proto.parseProtocol() + protoMatches := false + + for _, ruleProto := range rule.IPProto { + if slices.Contains(requestedProtos, ruleProto) { + protoMatches = true + + break + } + } + + if !protoMatches { + continue + } + } + + // Check if any destination port range matches + for _, dstPort := range rule.DstPorts { + // Parse the rule's destination IP + dstPrefix, err := netip.ParsePrefix(dstPort.IP) + if err != nil { + dstAddr, err := netip.ParseAddr(dstPort.IP) + if err != nil { + continue + } + + dstPrefix = netip.PrefixFrom(dstAddr, dstAddr.BitLen()) + } + + // Check if destination IPs overlap + dstMatches := false + + for _, prefix := range destIPs.Prefixes() { + if prefix.Overlaps(dstPrefix) { + dstMatches = true + + break + } + } + + if !dstMatches { + continue + } + + // Check if ports match + if portsMatch(ports, dstPort.Ports) { + return true + } + } + } + + return false +} + +// portsMatch checks if the requested ports are allowed by the rule's port range. +func portsMatch(requestedPorts []tailcfg.PortRange, allowedPorts tailcfg.PortRange) bool { + // If no specific ports requested, check if any port is allowed + if len(requestedPorts) == 0 { + return true + } + + // Check if any requested port is within the allowed range + for _, requested := range requestedPorts { + // Check if the requested port range is within the allowed range + if requested.First >= allowedPorts.First && requested.Last <= allowedPorts.Last { + return true + } + } + + return false +} + +// RunTestsWithPolicy creates a temporary PolicyManager from the given policy bytes +// and runs the provided tests against it. This is useful for testing a proposed +// policy before saving it. +func RunTestsWithPolicy(policyBytes []byte, users types.Users, nodes views.Slice[types.NodeView], tests []ACLTest) (ACLTestResults, error) { + pm, err := NewPolicyManager(policyBytes, users, nodes) + if err != nil { + return ACLTestResults{}, fmt.Errorf("failed to parse policy: %w", err) + } + + return pm.RunTests(tests), nil +} diff --git a/hscontrol/policy/v2/test_test.go b/hscontrol/policy/v2/test_test.go new file mode 100644 index 00000000..cd1cfa16 --- /dev/null +++ b/hscontrol/policy/v2/test_test.go @@ -0,0 +1,1199 @@ +package v2 + +import ( + "testing" + + "github.com/juanfont/headscale/hscontrol/types" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "gorm.io/gorm" + "tailscale.com/types/ptr" +) + +// testPolicyAliceToBob22 is a basic test policy allowing alice to access bob on port 22. +const testPolicyAliceToBob22 = `{ + "acls": [ + { + "action": "accept", + "src": ["alice@example.com"], + "dst": ["bob@example.com:22"] + } + ] +}` + +func TestRunTest_BasicAccept(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "bob-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + pm, err := NewPolicyManager([]byte(testPolicyAliceToBob22), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []struct { + name string + test ACLTest + wantPassed bool + }{ + { + name: "alice_can_access_bob_port_22", + test: ACLTest{ + Src: "alice@example.com", + Accept: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_access_bob_port_80", + test: ACLTest{ + Src: "alice@example.com", + Deny: []string{"bob@example.com:80"}, + }, + wantPassed: true, + }, + { + name: "bob_cannot_access_alice", + test: ACLTest{ + Src: "bob@example.com", + Deny: []string{"alice@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "accept_fails_when_access_denied", + test: ACLTest{ + Src: "bob@example.com", + Accept: []string{"alice@example.com:22"}, + }, + wantPassed: false, + }, + { + name: "deny_fails_when_access_allowed", + test: ACLTest{ + Src: "alice@example.com", + Deny: []string{"bob@example.com:22"}, + }, + wantPassed: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pm.RunTest(tt.test) + assert.Equal(t, tt.wantPassed, result.Passed, "test result mismatch for %s", tt.name) + }) + } +} + +func TestRunTest_WithTags(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "web-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + Tags: []string{"tag:webserver"}, + }, + { + ID: 3, + Hostname: "db-server", + IPv4: ap("100.64.0.3"), + IPv6: ap("fd7a:115c:a1e0::3"), + Tags: []string{"tag:database"}, + }, + } + + policy := `{ + "tagOwners": { + "tag:webserver": ["alice@example.com"], + "tag:database": ["alice@example.com"] + }, + "acls": [ + { + "action": "accept", + "src": ["alice@example.com"], + "dst": ["tag:webserver:80", "tag:webserver:443"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []struct { + name string + test ACLTest + wantPassed bool + }{ + { + name: "alice_can_access_webserver", + test: ACLTest{ + Src: "alice@example.com", + Accept: []string{"tag:webserver:80"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_access_database", + test: ACLTest{ + Src: "alice@example.com", + Deny: []string{"tag:database:5432"}, + }, + wantPassed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pm.RunTest(tt.test) + assert.Equal(t, tt.wantPassed, result.Passed, "test result mismatch for %s", tt.name) + }) + } +} + +func TestRunTest_WithGroups(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "bob-laptop", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + { + ID: 3, + Hostname: "server", + IPv4: ap("100.64.0.3"), + IPv6: ap("fd7a:115c:a1e0::3"), + Tags: []string{"tag:server"}, + }, + } + + policy := `{ + "groups": { + "group:admins": ["alice@example.com"] + }, + "tagOwners": { + "tag:server": ["group:admins"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["tag:server:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []struct { + name string + test ACLTest + wantPassed bool + }{ + { + name: "admin_group_can_access_server", + test: ACLTest{ + Src: "group:admins", + Accept: []string{"tag:server:22"}, + }, + wantPassed: true, + }, + { + name: "non_admin_cannot_access_server", + test: ACLTest{ + Src: "bob@example.com", + Deny: []string{"tag:server:22"}, + }, + wantPassed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pm.RunTest(tt.test) + assert.Equal(t, tt.wantPassed, result.Passed, "test result mismatch for %s", tt.name) + }) + } +} + +func TestRunTest_InvalidSource(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + } + + policy := `{ + "acls": [ + { + "action": "accept", + "src": ["alice@example.com"], + "dst": ["alice@example.com:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + // Test with non-existent user + result := pm.RunTest(ACLTest{ + Src: "nonexistent@example.com", + Accept: []string{"alice@example.com:22"}, + }) + + assert.False(t, result.Passed, "test should fail for non-existent source") + assert.NotEmpty(t, result.Errors, "should have error messages") +} + +func TestRunTests_Multiple(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "bob-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + pm, err := NewPolicyManager([]byte(testPolicyAliceToBob22), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []ACLTest{ + { + Src: "alice@example.com", + Accept: []string{"bob@example.com:22"}, + }, + { + Src: "bob@example.com", + Deny: []string{"alice@example.com:22"}, + }, + } + + results := pm.RunTests(tests) + + assert.True(t, results.AllPassed, "all tests should pass") + assert.Len(t, results.Results, 2, "should have 2 results") +} + +func TestRunTests_SomeFail(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "bob-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + pm, err := NewPolicyManager([]byte(testPolicyAliceToBob22), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []ACLTest{ + { + Src: "alice@example.com", + Accept: []string{"bob@example.com:22"}, + }, + { + // This test should fail - bob cannot access alice + Src: "bob@example.com", + Accept: []string{"alice@example.com:22"}, + }, + } + + results := pm.RunTests(tests) + + assert.False(t, results.AllPassed, "not all tests should pass") + assert.Len(t, results.Results, 2, "should have 2 results") + assert.True(t, results.Results[0].Passed, "first test should pass") + assert.False(t, results.Results[1].Passed, "second test should fail") + assert.NotEmpty(t, results.Errors(), "should have error description") +} + +func TestPolicyWithEmbeddedTests(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "bob-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + // Policy with embedded tests + policy := `{ + "acls": [ + { + "action": "accept", + "src": ["alice@example.com"], + "dst": ["bob@example.com:22"] + } + ], + "tests": [ + { + "src": "alice@example.com", + "accept": ["bob@example.com:22"] + }, + { + "src": "bob@example.com", + "deny": ["alice@example.com:22"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + // Verify the tests were parsed + require.NotNil(t, pm.pol.Tests) + require.Len(t, pm.pol.Tests, 2) + + // Run the embedded tests + results := pm.RunTests(pm.pol.Tests) + assert.True(t, results.AllPassed, "embedded tests should pass") +} + +func TestRunTestsWithPolicy(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "bob-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + tests := []ACLTest{ + { + Src: "alice@example.com", + Accept: []string{"bob@example.com:22"}, + }, + } + + results, err := RunTestsWithPolicy([]byte(testPolicyAliceToBob22), users, nodes.ViewSlice(), tests) + require.NoError(t, err) + assert.True(t, results.AllPassed, "tests should pass") +} + +func TestRunTest_ProtocolFiltering(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "bob-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + // Policy with ICMP-only rule (wildcard IPs) and TCP rule for specific access + policy := `{ + "acls": [ + { + "action": "accept", + "proto": "icmp", + "src": ["*"], + "dst": ["*:*"] + }, + { + "action": "accept", + "src": ["alice@example.com"], + "dst": ["bob@example.com:22"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []struct { + name string + test ACLTest + wantPassed bool + }{ + { + name: "alice_can_access_bob_tcp_22", + test: ACLTest{ + Src: "alice@example.com", + Proto: "tcp", + Accept: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "alice_can_ping_bob_icmp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "icmp", + Accept: []string{"bob@example.com:*"}, + }, + wantPassed: true, + }, + { + name: "bob_can_ping_alice_icmp", + test: ACLTest{ + Src: "bob@example.com", + Proto: "icmp", + Accept: []string{"alice@example.com:*"}, + }, + wantPassed: true, + }, + { + name: "bob_cannot_access_alice_tcp_22_only_icmp_allowed", + test: ACLTest{ + Src: "bob@example.com", + Proto: "tcp", + Deny: []string{"alice@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "bob_cannot_access_alice_default_proto_only_icmp_allowed", + test: ACLTest{ + Src: "bob@example.com", + Deny: []string{"alice@example.com:22"}, + }, + wantPassed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pm.RunTest(tt.test) + assert.Equal(t, tt.wantPassed, result.Passed, "test result mismatch for %s: %v", tt.name, result) + }) + } +} + +func TestRunTest_TCPOnlyRule(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "bob-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + // Policy with TCP-only rule + policy := `{ + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["alice@example.com"], + "dst": ["bob@example.com:22,80,443"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []struct { + name string + test ACLTest + wantPassed bool + }{ + { + name: "alice_can_ssh_bob_tcp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "tcp", + Accept: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "alice_can_http_bob_tcp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "tcp", + Accept: []string{"bob@example.com:80"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_access_bob_udp_same_ports", + test: ACLTest{ + Src: "alice@example.com", + Proto: "udp", + Deny: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_access_bob_udp_dns", + test: ACLTest{ + Src: "alice@example.com", + Proto: "udp", + Deny: []string{"bob@example.com:53"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_ping_bob_icmp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "icmp", + Deny: []string{"bob@example.com:*"}, + }, + wantPassed: true, + }, + { + name: "default_proto_matches_tcp_rule", + test: ACLTest{ + Src: "alice@example.com", + Accept: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pm.RunTest(tt.test) + assert.Equal(t, tt.wantPassed, result.Passed, "test result mismatch for %s: %v", tt.name, result) + }) + } +} + +func TestRunTest_UDPOnlyRule(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "dns-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + // Policy with UDP-only rule for DNS + policy := `{ + "acls": [ + { + "action": "accept", + "proto": "udp", + "src": ["alice@example.com"], + "dst": ["bob@example.com:53,123"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []struct { + name string + test ACLTest + wantPassed bool + }{ + { + name: "alice_can_dns_bob_udp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "udp", + Accept: []string{"bob@example.com:53"}, + }, + wantPassed: true, + }, + { + name: "alice_can_ntp_bob_udp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "udp", + Accept: []string{"bob@example.com:123"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_ssh_bob_tcp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "tcp", + Deny: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_dns_bob_tcp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "tcp", + Deny: []string{"bob@example.com:53"}, + }, + wantPassed: true, + }, + { + name: "default_proto_matches_udp_rule", + test: ACLTest{ + Src: "alice@example.com", + Accept: []string{"bob@example.com:53"}, + }, + wantPassed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pm.RunTest(tt.test) + assert.Equal(t, tt.wantPassed, result.Passed, "test result mismatch for %s: %v", tt.name, result) + }) + } +} + +func TestRunTest_MixedProtocolRules(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + // Policy with separate TCP, UDP, and ICMP rules + policy := `{ + "acls": [ + { + "action": "accept", + "proto": "tcp", + "src": ["alice@example.com"], + "dst": ["bob@example.com:22,80,443"] + }, + { + "action": "accept", + "proto": "udp", + "src": ["alice@example.com"], + "dst": ["bob@example.com:53,123"] + }, + { + "action": "accept", + "proto": "icmp", + "src": ["alice@example.com"], + "dst": ["bob@example.com:*"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []struct { + name string + test ACLTest + wantPassed bool + }{ + { + name: "alice_can_ssh_tcp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "tcp", + Accept: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "alice_can_dns_udp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "udp", + Accept: []string{"bob@example.com:53"}, + }, + wantPassed: true, + }, + { + name: "alice_can_ping_icmp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "icmp", + Accept: []string{"bob@example.com:*"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_ssh_udp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "udp", + Deny: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_dns_tcp", + test: ACLTest{ + Src: "alice@example.com", + Proto: "tcp", + Deny: []string{"bob@example.com:53"}, + }, + wantPassed: true, + }, + { + name: "bob_cannot_access_alice_any_proto", + test: ACLTest{ + Src: "bob@example.com", + Deny: []string{"alice@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "bob_cannot_ping_alice", + test: ACLTest{ + Src: "bob@example.com", + Proto: "icmp", + Deny: []string{"alice@example.com:*"}, + }, + wantPassed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pm.RunTest(tt.test) + assert.Equal(t, tt.wantPassed, result.Passed, "test result mismatch for %s: %v", tt.name, result) + }) + } +} + +func TestRunTest_NoProtoDefaultsTCPUDP(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + // Policy WITHOUT proto field - defaults to TCP+UDP + policy := `{ + "acls": [ + { + "action": "accept", + "src": ["alice@example.com"], + "dst": ["bob@example.com:22,53,80"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(policy), users, nodes.ViewSlice()) + require.NoError(t, err) + + tests := []struct { + name string + test ACLTest + wantPassed bool + }{ + { + name: "alice_can_ssh_tcp_no_proto_rule", + test: ACLTest{ + Src: "alice@example.com", + Proto: "tcp", + Accept: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + { + name: "alice_can_dns_udp_no_proto_rule", + test: ACLTest{ + Src: "alice@example.com", + Proto: "udp", + Accept: []string{"bob@example.com:53"}, + }, + wantPassed: true, + }, + { + name: "alice_can_http_tcp_no_proto_rule", + test: ACLTest{ + Src: "alice@example.com", + Proto: "tcp", + Accept: []string{"bob@example.com:80"}, + }, + wantPassed: true, + }, + { + name: "alice_cannot_ping_icmp_no_proto_rule", + test: ACLTest{ + Src: "alice@example.com", + Proto: "icmp", + Deny: []string{"bob@example.com:*"}, + }, + wantPassed: true, + }, + { + name: "default_test_proto_matches_no_proto_rule", + test: ACLTest{ + Src: "alice@example.com", + Accept: []string{"bob@example.com:22"}, + }, + wantPassed: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := pm.RunTest(tt.test) + assert.Equal(t, tt.wantPassed, result.Passed, "test result mismatch for %s: %v", tt.name, result) + }) + } +} + +func TestACLTestResult_Fields(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "alice", Email: "alice@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "bob", Email: "bob@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "alice-laptop", + IPv4: ap("100.64.0.1"), + IPv6: ap("fd7a:115c:a1e0::1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "bob-server", + IPv4: ap("100.64.0.2"), + IPv6: ap("fd7a:115c:a1e0::2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + } + + pm, err := NewPolicyManager([]byte(testPolicyAliceToBob22), users, nodes.ViewSlice()) + require.NoError(t, err) + + // Test that correctly populates AcceptOK and DenyOK + result := pm.RunTest(ACLTest{ + Src: "alice@example.com", + Accept: []string{"bob@example.com:22"}, + Deny: []string{"bob@example.com:80"}, + }) + + assert.True(t, result.Passed) + assert.Contains(t, result.AcceptOK, "bob@example.com:22") + assert.Contains(t, result.DenyOK, "bob@example.com:80") + assert.Empty(t, result.AcceptFail) + assert.Empty(t, result.DenyFail) + + // Test that correctly populates AcceptFail and DenyFail + result = pm.RunTest(ACLTest{ + Src: "bob@example.com", + Accept: []string{"alice@example.com:22"}, // Should fail - not allowed + Deny: []string{"bob@example.com:22"}, // Should fail - alice can access bob:22, not bob accessing bob:22 + }) + + assert.False(t, result.Passed) + assert.Contains(t, result.AcceptFail, "alice@example.com:22") +} + +// TestRunTest_AllSemantics tests that group access uses "ALL" semantics - +// ALL members of a group must have access for the test to pass, not just some. +// This prevents false positives when a user is in multiple groups with different privileges. +func TestRunTest_AllSemantics(t *testing.T) { + users := types.Users{ + {Model: gorm.Model{ID: 1}, Name: "admin", Email: "admin@example.com"}, + {Model: gorm.Model{ID: 2}, Name: "user1", Email: "user1@example.com"}, + {Model: gorm.Model{ID: 3}, Name: "user2", Email: "user2@example.com"}, + } + + nodes := types.Nodes{ + { + ID: 1, + Hostname: "admin-pc", + IPv4: ap("100.64.0.1"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + { + ID: 2, + Hostname: "user1-pc", + IPv4: ap("100.64.0.2"), + User: ptr.To(users[1]), + UserID: ptr.To(users[1].ID), + }, + { + ID: 3, + Hostname: "user2-pc", + IPv4: ap("100.64.0.3"), + User: ptr.To(users[2]), + UserID: ptr.To(users[2].ID), + }, + { + ID: 4, + Hostname: "server", + IPv4: ap("100.64.0.100"), + User: ptr.To(users[0]), + UserID: ptr.To(users[0].ID), + }, + } + + // Policy where: + // - group:admins (admin@) has full access + // - group:users (admin@, user1@, user2@) has limited access (only port 80) + // Admin is in BOTH groups + const mixedGroupPolicy = `{ + "groups": { + "group:admins": ["admin@example.com"], + "group:users": ["admin@example.com", "user1@example.com", "user2@example.com"] + }, + "acls": [ + { + "action": "accept", + "src": ["group:admins"], + "dst": ["100.64.0.100:*"] + }, + { + "action": "accept", + "src": ["group:users"], + "dst": ["100.64.0.100:80"] + } + ] + }` + + pm, err := NewPolicyManager([]byte(mixedGroupPolicy), users, nodes.ViewSlice()) + require.NoError(t, err) + + t.Run("admin_can_access_server_all_ports", func(t *testing.T) { + // Admin alone can access all ports + result := pm.RunTest(ACLTest{ + Src: "admin@example.com", + Accept: []string{"100.64.0.100:22", "100.64.0.100:80", "100.64.0.100:443"}, + }) + assert.True(t, result.Passed, "admin should have full access") + }) + + t.Run("group_users_can_only_access_port_80", func(t *testing.T) { + // group:users can only access port 80 + result := pm.RunTest(ACLTest{ + Src: "group:users", + Accept: []string{"100.64.0.100:80"}, + }) + assert.True(t, result.Passed, "group:users should access port 80") + }) + + t.Run("group_users_cannot_access_port_22_ALL_semantics", func(t *testing.T) { + // With "ALL" semantics: group:users -> :22 should FAIL + // because user1@ and user2@ don't have access to port 22 + // (even though admin@ does via group:admins) + result := pm.RunTest(ACLTest{ + Src: "group:users", + Accept: []string{"100.64.0.100:22"}, + }) + assert.False(t, result.Passed, + "group:users should NOT have access to port 22 - user1 and user2 don't have access") + }) + + t.Run("group_admins_can_access_port_22", func(t *testing.T) { + // group:admins -> :22 should pass (only admin@ is in this group) + result := pm.RunTest(ACLTest{ + Src: "group:admins", + Accept: []string{"100.64.0.100:22"}, + }) + assert.True(t, result.Passed, "group:admins should have access to port 22") + }) + + t.Run("individual_user_without_access_fails", func(t *testing.T) { + // user2@ alone should fail to access port 22 + result := pm.RunTest(ACLTest{ + Src: "user2@example.com", + Accept: []string{"100.64.0.100:22"}, + }) + assert.False(t, result.Passed, "user2 should not have access to port 22") + }) +} diff --git a/hscontrol/policy/v2/types.go b/hscontrol/policy/v2/types.go index 75b16bc1..a8e46dc7 100644 --- a/hscontrol/policy/v2/types.go +++ b/hscontrol/policy/v2/types.go @@ -1510,6 +1510,7 @@ type Policy struct { ACLs []ACL `json:"acls,omitempty"` AutoApprovers AutoApproverPolicy `json:"autoApprovers"` SSHs []SSH `json:"ssh,omitempty"` + Tests []ACLTest `json:"tests,omitempty"` } // MarshalJSON is deliberately not implemented for Policy. diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index b365269c..7dd16d3f 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -20,6 +20,7 @@ import ( hsdb "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/policy/matcher" + policyv2 "github.com/juanfont/headscale/hscontrol/policy/v2" "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/types/change" @@ -879,6 +880,16 @@ func (s *State) SetPolicy(pol []byte) (bool, error) { return s.polMan.SetPolicy(pol) } +// RunACLTests runs multiple ACL tests against the current policy and returns aggregated results. +func (s *State) RunACLTests(tests []policyv2.ACLTest) policyv2.ACLTestResults { + return s.polMan.RunTests(tests) +} + +// RunACLTest evaluates a single ACL test against the current policy. +func (s *State) RunACLTest(test policyv2.ACLTest) policyv2.ACLTestResult { + return s.polMan.RunTest(test) +} + // AutoApproveRoutes checks if a node's routes should be auto-approved. // AutoApproveRoutes checks if any routes should be auto-approved for a node and updates them. func (s *State) AutoApproveRoutes(nv types.NodeView) (change.Change, error) { diff --git a/proto/headscale/v1/headscale.proto b/proto/headscale/v1/headscale.proto index 5e556255..7bcb7b26 100644 --- a/proto/headscale/v1/headscale.proto +++ b/proto/headscale/v1/headscale.proto @@ -180,6 +180,13 @@ service HeadscaleService { body : "*" }; } + + rpc TestACL(TestACLRequest) returns (TestACLResponse) { + option (google.api.http) = { + post : "/api/v1/policy/test" + body : "*" + }; + } // --- Policy end --- // --- Health start --- diff --git a/proto/headscale/v1/policy.proto b/proto/headscale/v1/policy.proto index 6c52c01f..0342b377 100644 --- a/proto/headscale/v1/policy.proto +++ b/proto/headscale/v1/policy.proto @@ -17,3 +17,48 @@ message GetPolicyResponse { string policy = 1; google.protobuf.Timestamp updated_at = 2; } + +// ACL Testing messages + +message ACLTest { + // Source alias (user, group, tag, host, or IP) to test from. + string src = 1; + // Protocol to test (tcp, udp, icmp). Defaults to TCP/UDP if empty. + string proto = 2; + // Destinations (in "host:port" format) that should be allowed. + repeated string accept = 3; + // Destinations (in "host:port" format) that should be denied. + repeated string deny = 4; +} + +message ACLTestResult { + // Source alias that was tested. + string src = 1; + // Whether the test passed (all assertions correct). + bool passed = 2; + // Errors encountered during test execution. + repeated string errors = 3; + // Destinations that were correctly allowed. + repeated string accept_ok = 4; + // Destinations that should have been allowed but were denied. + repeated string accept_fail = 5; + // Destinations that were correctly denied. + repeated string deny_ok = 6; + // Destinations that should have been denied but were allowed. + repeated string deny_fail = 7; +} + +message TestACLRequest { + // Tests to run. + repeated ACLTest tests = 1; + // Optional: policy content to test against a proposed policy. + // If empty, tests run against the current active policy. + string policy = 2; +} + +message TestACLResponse { + // Whether all tests passed. + bool all_passed = 1; + // Individual test results. + repeated ACLTestResult results = 2; +}