make tags first class node owner (#2885)
Some checks failed
Build / build-nix (push) Has been cancelled
Build / build-cross (GOARCH=amd64 GOOS=darwin) (push) Has been cancelled
Build / build-cross (GOARCH=amd64 GOOS=linux) (push) Has been cancelled
Build / build-cross (GOARCH=arm64 GOOS=darwin) (push) Has been cancelled
Build / build-cross (GOARCH=arm64 GOOS=linux) (push) Has been cancelled
Check Generated Files / check-generated (push) Has been cancelled
NixOS Module Tests / nix-module-check (push) Has been cancelled
Tests / test (push) Has been cancelled

This PR changes tags to be something that exists on nodes in addition to users, to being its own thing. It is part of moving our tags support towards the correct tailscale compatible implementation.

There are probably rough edges in this PR, but the intention is to get it in, and then start fixing bugs from 0.28.0 milestone (long standing tags issue) to discover what works and what doesnt.

Updates #2417
Closes #2619
This commit is contained in:
Kristoffer Dalby 2025-12-02 12:01:25 +01:00 committed by GitHub
parent 705b239677
commit eb788cd007
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
49 changed files with 3102 additions and 757 deletions

View file

@ -55,11 +55,12 @@ jobs:
- TestPreAuthKeyCorrectUserLoggedInCommand
- TestApiKeyCommand
- TestNodeTagCommand
- TestTaggedNodeRegistration
- TestTagPersistenceAcrossRestart
- TestNodeAdvertiseTagCommand
- TestNodeCommand
- TestNodeExpireCommand
- TestNodeRenameCommand
- TestNodeMoveCommand
- TestPolicyCommand
- TestPolicyBrokenConfigCommand
- TestDERPVerifyEndpoint

336
AGENTS.md
View file

@ -237,6 +237,21 @@ headscale/
- `policy.go`: Policy storage and retrieval
- Schema migrations in `schema.sql` with extensive test data coverage
**CRITICAL DATABASE MIGRATION RULES**:
1. **NEVER reorder existing migrations** - Migration order is immutable once committed
2. **ONLY add new migrations to the END** of the migrations array
3. **NEVER disable foreign keys** in new migrations - no new migrations should be added to `migrationsRequiringFKDisabled`
4. **Migration ID format**: `YYYYMMDDHHSS-short-description` (timestamp + descriptive suffix)
- Example: `202511131500-add-user-roles`
- The timestamp must be chronologically ordered
5. **New migrations go after the comment** "As of 2025-07-02, no new IDs should be added here"
6. If you need to rename a column that other migrations depend on:
- Accept that the old column name will exist in intermediate migration states
- Update code to work with the new column name
- Let AutoMigrate create the new column if needed
- Do NOT try to rename columns that later migrations reference
**Policy Engine (`hscontrol/policy/`)**
- `policy.go`: Core ACL evaluation logic, HuJSON parsing
@ -687,6 +702,326 @@ assert.EventuallyWithT(t, func(c *assert.CollectT) {
}, 10*time.Second, 500*time.Millisecond, "mixed operations")
```
## Tags-as-Identity Architecture
### Overview
Headscale implements a **tags-as-identity** model where tags and user ownership are mutually exclusive ways to identify nodes. This is a fundamental architectural principle that affects node registration, ownership, ACL evaluation, and API behavior.
### Core Principle: Tags XOR User Ownership
Every node in Headscale is **either** tagged **or** user-owned, never both:
- **Tagged Nodes**: Ownership is defined by tags (e.g., `tag:server`, `tag:database`)
- Tags are set during registration via tagged PreAuthKey
- Tags are immutable after registration (cannot be changed via API)
- May have `UserID` set for "created by" tracking, but ownership is via tags
- Identified by: `node.IsTagged()` returns `true`
- **User-Owned Nodes**: Ownership is defined by user assignment
- Registered via OIDC, web auth, or untagged PreAuthKey
- Node belongs to a specific user's namespace
- No tags (empty tags array)
- Identified by: `node.UserID().Valid() && !node.IsTagged()`
### Critical Implementation Details
#### Node Identification Methods
```go
// Primary methods for determining node ownership
node.IsTagged() // Returns true if node has tags OR AuthKey.Tags
node.HasTag(tag) // Returns true if node has specific tag
node.IsUserOwned() // Returns true if UserID set AND not tagged
// IMPORTANT: UserID can be set on tagged nodes for tracking!
// Always use IsTagged() to determine actual ownership, not just UserID.Valid()
```
#### UserID Field Semantics
**Critical distinction**: `UserID` has different meanings depending on node type:
- **Tagged nodes**: `UserID` is optional "created by" tracking
- Indicates which user created the tagged PreAuthKey
- Does NOT define ownership (tags define ownership)
- Example: User "alice" creates tagged PreAuthKey with `tag:server`, node gets `UserID=alice.ID` + `Tags=["tag:server"]`
- **User-owned nodes**: `UserID` defines ownership
- Required field for non-tagged nodes
- Defines which user namespace the node belongs to
- Example: User "bob" registers via OIDC, node gets `UserID=bob.ID` + `Tags=[]`
#### Mapper Behavior (mapper/tail.go)
The mapper converts internal nodes to Tailscale protocol format, handling the TaggedDevices special user:
```go
// From mapper/tail.go:102-116
User: func() tailcfg.UserID {
// IMPORTANT: Tags-as-identity model
// Tagged nodes ALWAYS use TaggedDevices user, even if UserID is set
if node.IsTagged() {
return tailcfg.UserID(int64(types.TaggedDevices.ID))
}
// User-owned nodes: use the actual user ID
return tailcfg.UserID(int64(node.UserID().Get()))
}()
```
**TaggedDevices constant** (`types.TaggedDevices.ID = 2147455555`): Special user ID for all tagged nodes in MapResponse protocol.
#### Registration Flow
**Tagged Node Registration** (via tagged PreAuthKey):
1. User creates PreAuthKey with tags: `pak.Tags = ["tag:server"]`
2. Node registers with PreAuthKey
3. Node gets: `Tags = ["tag:server"]`, `UserID = user.ID` (optional tracking), `AuthKeyID = pak.ID`
4. `IsTagged()` returns `true` (ownership via tags)
5. MapResponse sends `User = TaggedDevices.ID`
**User-Owned Node Registration** (via OIDC/web/untagged PreAuthKey):
1. User authenticates or uses untagged PreAuthKey
2. Node registers
3. Node gets: `Tags = []`, `UserID = user.ID` (required)
4. `IsTagged()` returns `false` (ownership via user)
5. MapResponse sends `User = user.ID`
#### API Validation (SetTags)
The SetTags gRPC API enforces tags-as-identity rules:
```go
// From grpcv1.go:340-347
// User-owned nodes are nodes with UserID that are NOT tagged
isUserOwned := nodeView.UserID().Valid() && !nodeView.IsTagged()
if isUserOwned && len(request.GetTags()) > 0 {
return error("cannot set tags on user-owned nodes")
}
```
**Key validation rules**:
- ✅ Can call SetTags on tagged nodes (tags already define ownership)
- ❌ Cannot set tags on user-owned nodes (would violate XOR rule)
- ❌ Cannot remove all tags from tagged nodes (would orphan the node)
#### Database Layer (db/node.go)
**Tag storage**: Tags are stored in PostgreSQL ARRAY column and SQLite JSON column:
```sql
-- From schema.sql
tags TEXT[] DEFAULT '{}' NOT NULL, -- PostgreSQL
tags TEXT DEFAULT '[]' NOT NULL, -- SQLite (JSON array)
```
**Validation** (`state/tags.go`):
- `validateNodeOwnership()`: Enforces tags XOR user rule
- `validateAndNormalizeTags()`: Validates tag format (`tag:name`) and uniqueness
#### Policy Layer
**Tag Ownership** (policy/v2/policy.go):
```go
func NodeCanHaveTag(node types.NodeView, tag string) bool {
// Checks if node's IP is in the tagOwnerMap IP set
// This is IP-based authorization, not UserID-based
if ips, ok := pm.tagOwnerMap[Tag(tag)]; ok {
if slices.ContainsFunc(node.IPs(), ips.Contains) {
return true
}
}
return false
}
```
**Important**: Tag authorization is based on IP ranges in ACL, not UserID. Tags define identity, ACL authorizes that identity.
### Testing Tags-as-Identity
**Unit Tests** (`hscontrol/types/node_tags_test.go`):
- `TestNodeIsTagged`: Validates IsTagged() for various scenarios
- `TestNodeOwnershipModel`: Tests tags XOR user ownership
- `TestUserTypedID`: Helper method validation
**API Tests** (`hscontrol/grpcv1_test.go`):
- `TestSetTags_UserXORTags`: Validates rejection of setting tags on user-owned nodes
- `TestSetTags_TaggedNode`: Validates that tagged nodes (even with UserID) are not rejected
**Auth Tests** (`hscontrol/auth_test.go:890-928`):
- Tests node registration with tagged PreAuthKey
- Validates tags are applied during registration
### Common Pitfalls
1. **Don't check only `UserID.Valid()` to determine user ownership**
- ❌ Wrong: `if node.UserID().Valid() { /* user-owned */ }`
- ✅ Correct: `if node.UserID().Valid() && !node.IsTagged() { /* user-owned */ }`
2. **Don't assume tagged nodes never have UserID set**
- Tagged nodes MAY have UserID for "created by" tracking
- Always use `IsTagged()` to determine ownership type
3. **Don't allow setting tags on user-owned nodes**
- This violates the tags XOR user principle
- Use API validation to prevent this
4. **Don't forget TaggedDevices in mapper**
- All tagged nodes MUST use `TaggedDevices.ID` in MapResponse
- User ID is only for actual user-owned nodes
### Migration Considerations
When nodes transition between ownership models:
- **No automatic migration**: Tags-as-identity is set at registration and immutable
- **Re-registration required**: To change from user-owned to tagged (or vice versa), node must be deleted and re-registered
- **UserID persistence**: UserID on tagged nodes is informational and not cleared
### Architecture Benefits
The tags-as-identity model provides:
1. **Clear ownership semantics**: No ambiguity about who/what owns a node
2. **ACL simplicity**: Tag-based access control without user conflicts
3. **API safety**: Validation prevents invalid ownership states
4. **Protocol compatibility**: TaggedDevices special user aligns with Tailscale's model
## Logging Patterns
### Incremental Log Event Building
When building log statements with multiple fields, especially with conditional fields, use the **incremental log event pattern** instead of long single-line chains. This improves readability and allows conditional field addition.
**Pattern:**
```go
// GOOD: Incremental building with conditional fields
logEvent := log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString())
if node.User != nil {
logEvent = logEvent.Str("user", node.User.Username())
} else if node.UserID != nil {
logEvent = logEvent.Uint("user_id", *node.UserID)
} else {
logEvent = logEvent.Str("user", "none")
}
logEvent.Msg("Registering node")
```
**Key rules:**
1. **Assign chained calls back to the variable**: `logEvent = logEvent.Str(...)` - zerolog methods return a new event, so you must capture the return value
2. **Use for conditional fields**: When fields depend on runtime conditions, build incrementally
3. **Use for long log lines**: When a log line exceeds ~100 characters, split it for readability
4. **Call `.Msg()` at the end**: The final `.Msg()` or `.Msgf()` sends the log event
**Anti-pattern to avoid:**
```go
// BAD: Long single-line chains are hard to read and can't have conditional fields
log.Debug().Caller().Str("node", node.Hostname).Str("machine_key", node.MachineKey.ShortString()).Str("node_key", node.NodeKey.ShortString()).Str("user", node.User.Username()).Msg("Registering node")
// BAD: Forgetting to assign the return value (field is lost!)
logEvent := log.Debug().Str("node", node.Hostname)
logEvent.Str("user", username) // This field is LOST - not assigned back
logEvent.Msg("message") // Only has "node" field
```
**When to use this pattern:**
- Log statements with 4+ fields
- Any log with conditional fields
- Complex logging in loops or error handling
- When you need to add context incrementally
**Example from codebase** (`hscontrol/db/node.go`):
```go
logEvent := log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString())
if node.User != nil {
logEvent = logEvent.Str("user", node.User.Username())
} else if node.UserID != nil {
logEvent = logEvent.Uint("user_id", *node.UserID)
} else {
logEvent = logEvent.Str("user", "none")
}
logEvent.Msg("Registering test node")
```
### Avoiding Log Helper Functions
Prefer the incremental log event pattern over creating helper functions that return multiple logging closures. Helper functions like `logPollFunc` create unnecessary indirection and allocate closures.
**Instead of:**
```go
// AVOID: Helper function returning closures
func logPollFunc(req tailcfg.MapRequest, node *types.Node) (
func(string, ...any), // warnf
func(string, ...any), // infof
func(string, ...any), // tracef
func(error, string, ...any), // errf
) {
return func(msg string, a ...any) {
log.Warn().
Caller().
Bool("omitPeers", req.OmitPeers).
Bool("stream", req.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
// ... more closures
}
```
**Prefer:**
```go
// BETTER: Build log events inline with shared context
func (m *mapSession) logTrace(msg string) {
log.Trace().
Caller().
Bool("omitPeers", m.req.OmitPeers).
Bool("stream", m.req.Stream).
Uint64("node.id", m.node.ID.Uint64()).
Str("node.name", m.node.Hostname).
Msg(msg)
}
// Or use incremental building for complex cases
logEvent := log.Trace().
Caller().
Bool("omitPeers", m.req.OmitPeers).
Bool("stream", m.req.Stream).
Uint64("node.id", m.node.ID.Uint64()).
Str("node.name", m.node.Hostname)
if additionalContext {
logEvent = logEvent.Str("extra", value)
}
logEvent.Msg("Operation completed")
```
## Important Notes
- **Dependencies**: Use `nix develop` for consistent toolchain (Go, buf, protobuf tools, linting)
@ -697,3 +1032,4 @@ assert.EventuallyWithT(t, func(c *assert.CollectT) {
- **Integration Tests**: Require Docker and can consume significant disk space - use headscale-integration-tester agent
- **Performance**: NodeStore optimizations are critical for scale - be careful with changes to state management
- **Quality Assurance**: Always use appropriate specialized agents for testing and validation tasks
- **Tags-as-Identity**: Tags and user ownership are mutually exclusive - always use `IsTagged()` to determine ownership

View file

@ -21,6 +21,10 @@ at creation time. When listing keys, only the prefix is shown (e.g.,
`hskey-auth-{prefix}-{secret}`. Legacy plaintext keys continue to work for
backwards compatibility.
### Tags
Tags are now implemented following the Tailscale model where tags and user ownership are mutually exclusive. Devices can be either user-owned (authenticated via web/OIDC) or tagged (authenticated via tagged PreAuthKeys). Tagged devices receive their identity from tags rather than users, making them suitable for servers and infrastructure. Applying a tag to a device removes user-based authentication. See the [Tailscale tags documentation](https://tailscale.com/kb/1068/tags) for details on how tags work.
### Database migration support removed for pre-0.25.0 databases
Headscale no longer supports direct upgrades from databases created before
@ -30,6 +34,8 @@ release.
### BREAKING
- **Tags**: The gRPC `SetTags` endpoint now allows converting user-owned nodes to tagged nodes by setting tags. Once a node is tagged, it cannot be converted back to a user-owned node.
- Database migration support removed for pre-0.25.0 databases [#2883](https://github.com/juanfont/headscale/pull/2883)
- If you are running a version older than 0.25.0, you must upgrade to 0.25.1 first, then upgrade to this release
- See the [upgrade path documentation](https://headscale.net/stable/about/faq/#what-is-the-recommended-update-path-can-i-skip-multiple-versions-while-updating) for detailed guidance

View file

@ -233,11 +233,7 @@ func isAuthKey(req tailcfg.RegisterRequest) bool {
}
func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse {
return &tailcfg.RegisterResponse{
// TODO(kradalby): Only send for user-owned nodes
// and not tagged nodes when tags is working.
User: node.UserView().TailscaleUser(),
Login: node.UserView().TailscaleLogin(),
resp := &tailcfg.RegisterResponse{
NodeKeyExpired: node.IsExpired(),
// Headscale does not implement the concept of machine authorization
@ -245,6 +241,18 @@ func nodeToRegisterResponse(node types.NodeView) *tailcfg.RegisterResponse {
// Revisit this if #2176 gets implemented.
MachineAuthorized: true,
}
// For tagged nodes, use the TaggedDevices special user
// For user-owned nodes, include User and Login information from the actual user
if node.IsTagged() {
resp.User = types.TaggedDevices.View().TailscaleUser()
resp.Login = types.TaggedDevices.View().TailscaleLogin()
} else if node.UserView().Valid() {
resp.User = node.UserView().TailscaleUser()
resp.Login = node.UserView().TailscaleLogin()
}
return resp
}
func (h *Headscale) waitForFollowup(

535
hscontrol/auth_tags_test.go Normal file
View file

@ -0,0 +1,535 @@
package hscontrol
import (
"testing"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
// TestTaggedPreAuthKeyCreatesTaggedNode tests that a PreAuthKey with tags creates
// a tagged node with:
// - Tags from the PreAuthKey
// - UserID tracking who created the key (informational "created by")
// - IsTagged() returns true.
func TestTaggedPreAuthKeyCreatesTaggedNode(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server", "tag:prod"}
// Create a tagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
require.NoError(t, err)
require.NotEmpty(t, pak.Tags, "PreAuthKey should have tags")
require.ElementsMatch(t, tags, pak.Tags, "PreAuthKey should have specified tags")
// Register a node using the tagged key
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.NoError(t, err)
require.True(t, resp.MachineAuthorized)
// Verify the node was created with tags
node, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
// Critical assertions for tags-as-identity model
assert.True(t, node.IsTagged(), "Node should be tagged")
assert.ElementsMatch(t, tags, node.Tags().AsSlice(), "Node should have tags from PreAuthKey")
assert.True(t, node.UserID().Valid(), "Node should have UserID tracking creator")
assert.Equal(t, user.ID, node.UserID().Get(), "UserID should track PreAuthKey creator")
// Verify node is identified correctly
assert.True(t, node.IsTagged(), "Tagged node is not user-owned")
assert.True(t, node.HasTag("tag:server"), "Node should have tag:server")
assert.True(t, node.HasTag("tag:prod"), "Node should have tag:prod")
assert.False(t, node.HasTag("tag:other"), "Node should not have tag:other")
}
// TestReAuthDoesNotReapplyTags tests that when a node re-authenticates using the
// same PreAuthKey, the tags are NOT re-applied. Tags are only set during initial
// authentication. This is critical for the container restart scenario (#2830).
//
// NOTE: This test verifies that re-authentication preserves the node's current tags
// without testing tag modification via SetNodeTags (which requires ACL policy setup).
func TestReAuthDoesNotReapplyTags(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
initialTags := []string{"tag:server", "tag:dev"}
// Create a tagged PreAuthKey with reusable=true for re-auth
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, initialTags)
require.NoError(t, err)
// Initial registration
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "reauth-test-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.NoError(t, err)
require.True(t, resp.MachineAuthorized)
// Verify initial tags
node, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
require.True(t, node.IsTagged())
require.ElementsMatch(t, initialTags, node.Tags().AsSlice())
// Re-authenticate with the SAME PreAuthKey (container restart scenario)
// Key behavior: Tags should NOT be re-applied during re-auth
reAuthReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, // Same key
},
NodeKey: nodeKey.Public(), // Same node key
Hostinfo: &tailcfg.Hostinfo{
Hostname: "reauth-test-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
reAuthResp, err := app.handleRegisterWithAuthKey(reAuthReq, machineKey.Public())
require.NoError(t, err)
require.True(t, reAuthResp.MachineAuthorized)
// CRITICAL: Tags should remain unchanged after re-auth
// They should match the original tags, proving they weren't re-applied
nodeAfterReauth, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, nodeAfterReauth.IsTagged(), "Node should still be tagged")
assert.ElementsMatch(t, initialTags, nodeAfterReauth.Tags().AsSlice(), "Tags should remain unchanged on re-auth")
// Verify only one node was created (no duplicates)
nodes := app.state.ListNodesByUser(types.UserID(user.ID))
assert.Equal(t, 1, nodes.Len(), "Should have exactly one node")
}
// NOTE: TestSetTagsOnUserOwnedNode functionality is covered by gRPC tests in grpcv1_test.go
// which properly handle ACL policy setup. The test verifies that SetTags can convert
// user-owned nodes to tagged nodes while preserving UserID.
// TestCannotRemoveAllTags tests that attempting to remove all tags from a
// tagged node fails with ErrCannotRemoveAllTags. Once a node is tagged,
// it must always have at least one tag (Tailscale requirement).
func TestCannotRemoveAllTags(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server"}
// Create a tagged node
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
require.NoError(t, err)
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.NoError(t, err)
require.True(t, resp.MachineAuthorized)
// Verify node is tagged
node, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
require.True(t, node.IsTagged())
// Attempt to remove all tags by setting empty array
_, _, err = app.state.SetNodeTags(node.ID(), []string{})
require.Error(t, err, "Should not be able to remove all tags")
require.ErrorIs(t, err, types.ErrCannotRemoveAllTags, "Error should be ErrCannotRemoveAllTags")
// Verify node still has original tags
nodeAfter, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, nodeAfter.IsTagged(), "Node should still be tagged")
assert.ElementsMatch(t, tags, nodeAfter.Tags().AsSlice(), "Tags should be unchanged")
}
// TestUserOwnedNodeCreatedWithUntaggedPreAuthKey tests that using a PreAuthKey
// without tags creates a user-owned node (no tags, UserID is the owner).
func TestUserOwnedNodeCreatedWithUntaggedPreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("node-owner")
// Create an untagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
require.Empty(t, pak.Tags, "PreAuthKey should not be tagged")
require.Empty(t, pak.Tags, "PreAuthKey should have no tags")
// Register a node
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "user-owned-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp, err := app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.NoError(t, err)
require.True(t, resp.MachineAuthorized)
// Verify node is user-owned
node, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
// Critical assertions for user-owned node
assert.False(t, node.IsTagged(), "Node should not be tagged")
assert.False(t, node.IsTagged(), "Node should be user-owned (not tagged)")
assert.Empty(t, node.Tags().AsSlice(), "Node should have no tags")
assert.True(t, node.UserID().Valid(), "Node should have UserID")
assert.Equal(t, user.ID, node.UserID().Get(), "UserID should be the PreAuthKey owner")
}
// TestMultipleNodesWithSameReusableTaggedPreAuthKey tests that a reusable
// PreAuthKey with tags can be used to register multiple nodes, and all nodes
// receive the same tags from the key.
func TestMultipleNodesWithSameReusableTaggedPreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server", "tag:prod"}
// Create a REUSABLE tagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
require.NoError(t, err)
require.ElementsMatch(t, tags, pak.Tags)
// Register first node
machineKey1 := key.NewMachine()
nodeKey1 := key.NewNode()
regReq1 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node-1",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public())
require.NoError(t, err)
require.True(t, resp1.MachineAuthorized)
// Register second node with SAME PreAuthKey
machineKey2 := key.NewMachine()
nodeKey2 := key.NewNode()
regReq2 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, // Same key
},
NodeKey: nodeKey2.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node-2",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public())
require.NoError(t, err)
require.True(t, resp2.MachineAuthorized)
// Verify both nodes exist and have the same tags
node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found)
node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public())
require.True(t, found)
// Both nodes should be tagged with the same tags
assert.True(t, node1.IsTagged(), "First node should be tagged")
assert.True(t, node2.IsTagged(), "Second node should be tagged")
assert.ElementsMatch(t, tags, node1.Tags().AsSlice(), "First node should have PreAuthKey tags")
assert.ElementsMatch(t, tags, node2.Tags().AsSlice(), "Second node should have PreAuthKey tags")
// Both nodes should track the same creator
assert.Equal(t, user.ID, node1.UserID().Get(), "First node should track creator")
assert.Equal(t, user.ID, node2.UserID().Get(), "Second node should track creator")
// Verify we have exactly 2 nodes
nodes := app.state.ListNodesByUser(types.UserID(user.ID))
assert.Equal(t, 2, nodes.Len(), "Should have exactly two nodes")
}
// TestNonReusableTaggedPreAuthKey tests that a non-reusable PreAuthKey with tags
// can only be used once. The second attempt should fail.
func TestNonReusableTaggedPreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server"}
// Create a NON-REUSABLE tagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, tags)
require.NoError(t, err)
require.ElementsMatch(t, tags, pak.Tags)
// Register first node - should succeed
machineKey1 := key.NewMachine()
nodeKey1 := key.NewNode()
regReq1 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node-1",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public())
require.NoError(t, err)
require.True(t, resp1.MachineAuthorized)
// Verify first node was created with tags
node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found)
assert.True(t, node1.IsTagged())
assert.ElementsMatch(t, tags, node1.Tags().AsSlice())
// Attempt to register second node with SAME non-reusable key - should fail
machineKey2 := key.NewMachine()
nodeKey2 := key.NewNode()
regReq2 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key, // Same non-reusable key
},
NodeKey: nodeKey2.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node-2",
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq2, machineKey2.Public())
require.Error(t, err, "Should not be able to reuse non-reusable PreAuthKey")
// Verify only one node was created
nodes := app.state.ListNodesByUser(types.UserID(user.ID))
assert.Equal(t, 1, nodes.Len(), "Should have exactly one node")
}
// TestExpiredTaggedPreAuthKey tests that an expired PreAuthKey with tags
// cannot be used to register a node.
func TestExpiredTaggedPreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server"}
// Create a PreAuthKey that expires immediately
expiration := time.Now().Add(-1 * time.Hour) // Already expired
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, &expiration, tags)
require.NoError(t, err)
require.ElementsMatch(t, tags, pak.Tags)
// Attempt to register with expired key
machineKey := key.NewMachine()
nodeKey := key.NewNode()
regReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
_, err = app.handleRegisterWithAuthKey(regReq, machineKey.Public())
require.Error(t, err, "Should not be able to use expired PreAuthKey")
// Verify no node was created
_, found := app.state.GetNodeByNodeKey(nodeKey.Public())
assert.False(t, found, "No node should be created with expired key")
}
// TestSingleVsMultipleTags tests that PreAuthKeys work correctly with both
// a single tag and multiple tags.
func TestSingleVsMultipleTags(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
// Test with single tag
singleTag := []string{"tag:server"}
pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, singleTag)
require.NoError(t, err)
machineKey1 := key.NewMachine()
nodeKey1 := key.NewNode()
regReq1 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak1.Key,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "single-tag-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public())
require.NoError(t, err)
require.True(t, resp1.MachineAuthorized)
node1, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found)
assert.True(t, node1.IsTagged())
assert.ElementsMatch(t, singleTag, node1.Tags().AsSlice())
// Test with multiple tags
multipleTags := []string{"tag:server", "tag:prod", "tag:database"}
pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, multipleTags)
require.NoError(t, err)
machineKey2 := key.NewMachine()
nodeKey2 := key.NewNode()
regReq2 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak2.Key,
},
NodeKey: nodeKey2.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "multi-tag-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public())
require.NoError(t, err)
require.True(t, resp2.MachineAuthorized)
node2, found := app.state.GetNodeByNodeKey(nodeKey2.Public())
require.True(t, found)
assert.True(t, node2.IsTagged())
assert.ElementsMatch(t, multipleTags, node2.Tags().AsSlice())
// Verify HasTag works for all tags
assert.True(t, node2.HasTag("tag:server"))
assert.True(t, node2.HasTag("tag:prod"))
assert.True(t, node2.HasTag("tag:database"))
assert.False(t, node2.HasTag("tag:other"))
}
// TestReAuthWithDifferentMachineKey tests the edge case where a node attempts
// to re-authenticate with the same NodeKey but a DIFFERENT MachineKey.
// This scenario should be handled gracefully (currently creates a new node).
func TestReAuthWithDifferentMachineKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("tag-creator")
tags := []string{"tag:server"}
// Create a reusable tagged PreAuthKey
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
require.NoError(t, err)
// Initial registration
machineKey1 := key.NewMachine()
nodeKey := key.NewNode() // Same NodeKey for both attempts
regReq1 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp1, err := app.handleRegisterWithAuthKey(regReq1, machineKey1.Public())
require.NoError(t, err)
require.True(t, resp1.MachineAuthorized)
// Verify initial node
node1, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, node1.IsTagged())
// Re-authenticate with DIFFERENT MachineKey but SAME NodeKey
machineKey2 := key.NewMachine() // Different machine key
regReq2 := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(), // Same NodeKey
Hostinfo: &tailcfg.Hostinfo{
Hostname: "test-node",
},
Expiry: time.Now().Add(24 * time.Hour),
}
resp2, err := app.handleRegisterWithAuthKey(regReq2, machineKey2.Public())
require.NoError(t, err)
require.True(t, resp2.MachineAuthorized)
// Verify the node still exists and has tags
// Note: Depending on implementation, this might be the same node or a new node
node2, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, node2.IsTagged())
assert.ElementsMatch(t, tags, node2.Tags().AsSlice())
}

View file

@ -70,7 +70,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_valid_new_node",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("preauth-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -111,7 +112,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_reusable_multiple_nodes",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("reusable-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -177,7 +179,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_single_use_exhausted",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("single-use-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
if err != nil {
return "", err
}
@ -264,7 +267,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_ephemeral_node",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("ephemeral-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
if err != nil {
return "", err
}
@ -370,7 +374,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_logout",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("logout-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -429,7 +434,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_machine_key_mismatch",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("mismatch-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -477,7 +483,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_key_extension_not_allowed",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("extend-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -525,7 +532,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_expired_forces_reauth",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("reauth-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -585,7 +593,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "ephemeral_node_logout_deletion",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("ephemeral-logout-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
if err != nil {
return "", err
}
@ -767,7 +776,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "empty_hostname",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("empty-hostname-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -805,7 +815,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "nil_hostinfo",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("nil-hostinfo-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -848,7 +859,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("expired-pak-user")
expiry := time.Now().Add(-1 * time.Hour) // Expired 1 hour ago
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil)
if err != nil {
return "", err
}
@ -880,7 +892,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("tagged-pak-user")
tags := []string{"tag:server", "tag:database"}
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, tags)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, tags)
if err != nil {
return "", err
}
@ -926,7 +939,7 @@ func TestAuthenticationFlows(t *testing.T) {
user := app.state.CreateUserForTest("reauth-user")
// First, register with initial auth key
pak1, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -953,7 +966,7 @@ func TestAuthenticationFlows(t *testing.T) {
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
// Create new auth key for re-authentication
pak2, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak2, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -992,7 +1005,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "existing_node_reauth_interactive_flow",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("interactive-reauth-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1053,7 +1067,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "node_key_rotation_same_machine",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("rotation-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1081,7 +1096,7 @@ func TestAuthenticationFlows(t *testing.T) {
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
// Create new auth key for rotation
pakRotation, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pakRotation, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1129,7 +1144,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "malformed_expiry_zero_time",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("zero-expiry-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1167,7 +1183,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "malformed_hostinfo_invalid_data",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("invalid-hostinfo-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1353,7 +1370,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "preauth_key_usage_count_tracking",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("usage-count-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // Single use
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // Single use
if err != nil {
return "", err
}
@ -1432,7 +1450,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "concurrent_registration_same_node_key",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("concurrent-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1473,7 +1492,8 @@ func TestAuthenticationFlows(t *testing.T) {
user := app.state.CreateUserForTest("future-expiry-user")
// Auth key expires in the future
expiry := time.Now().Add(48 * time.Hour)
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil)
if err != nil {
return "", err
}
@ -1517,7 +1537,7 @@ func TestAuthenticationFlows(t *testing.T) {
user2 := app.state.CreateUserForTest("user2-context")
// Register node with user1's auth key
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1544,7 +1564,7 @@ func TestAuthenticationFlows(t *testing.T) {
}, 1*time.Second, 50*time.Millisecond, "waiting for node to be available in NodeStore")
// Return user2's auth key for re-authentication
pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil)
pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1571,15 +1591,15 @@ func TestAuthenticationFlows(t *testing.T) {
// Verify NEW node was created for user2
node2, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2))
require.True(t, found, "new node should exist for user2")
assert.Equal(t, uint(2), node2.UserID(), "new node should belong to user2")
assert.Equal(t, uint(2), node2.UserID().Get(), "new node should belong to user2")
user := node2.User()
assert.Equal(t, "user2-context", user.Username(), "new node should show user2 username")
assert.Equal(t, "user2-context", user.Name(), "new node should show user2 username")
// Verify original node still exists for user1
node1, found := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1))
require.True(t, found, "original node should still exist for user1")
assert.Equal(t, uint(1), node1.UserID(), "original node should still belong to user1")
assert.Equal(t, uint(1), node1.UserID().Get(), "original node should still belong to user1")
// Verify they are different nodes (different IDs)
assert.NotEqual(t, node1.ID(), node2.ID(), "should be different node IDs")
@ -1595,7 +1615,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
// Create user1 and register a node with auth key
user1 := app.state.CreateUserForTest("interactive-user-1")
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1645,16 +1666,16 @@ func TestAuthenticationFlows(t *testing.T) {
// User1's original node should STILL exist (not transferred)
node1, found1 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(1))
require.True(t, found1, "user1's original node should still exist")
assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1")
assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1")
assert.Equal(t, nodeKey1.Public(), node1.NodeKey(), "user1's node should have original node key")
// User2 should have a NEW node created
node2, found2 := app.state.GetNodeByMachineKey(machineKey1.Public(), types.UserID(2))
require.True(t, found2, "user2 should have new node created")
assert.Equal(t, uint(2), node2.UserID(), "user2's node should belong to user2")
assert.Equal(t, uint(2), node2.UserID().Get(), "user2's node should belong to user2")
user := node2.User()
assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should show correct username")
assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should show correct username")
// Both nodes should have the same machine key but different IDs
assert.NotEqual(t, node1.ID(), node2.ID(), "should be different nodes (different IDs)")
@ -1720,7 +1741,8 @@ func TestAuthenticationFlows(t *testing.T) {
name: "logout_with_exactly_now_expiry",
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
user := app.state.CreateUserForTest("exact-now-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1813,7 +1835,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
// First create a node under user1
user1 := app.state.CreateUserForTest("existing-user-1")
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -1863,7 +1886,7 @@ func TestAuthenticationFlows(t *testing.T) {
// User1's original node with nodeKey1 should STILL exist
node1, found1 := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found1, "user1's original node with nodeKey1 should still exist")
assert.Equal(t, uint(1), node1.UserID(), "user1's node should still belong to user1")
assert.Equal(t, uint(1), node1.UserID().Get(), "user1's node should still belong to user1")
assert.Equal(t, uint64(1), node1.ID().Uint64(), "user1's node should be ID=1")
// User2 should have a NEW node with nodeKey2
@ -1872,7 +1895,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.Equal(t, "existing-node-user2", node2.Hostname(), "hostname should be from new registration")
user := node2.User()
assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2")
assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2")
assert.Equal(t, machineKey1.Public(), node2.MachineKey(), "machine key should be the same")
// Verify it's a NEW node, not transferred
@ -2022,7 +2045,8 @@ func TestAuthenticationFlows(t *testing.T) {
setupFunc: func(t *testing.T, app *Headscale) (string, error) {
// Register initial node
user := app.state.CreateUserForTest("rotation-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
if err != nil {
return "", err
}
@ -2072,7 +2096,7 @@ func TestAuthenticationFlows(t *testing.T) {
// User1's original node with nodeKey1 should STILL exist
oldNode, foundOld := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, foundOld, "user1's original node with nodeKey1 should still exist")
assert.Equal(t, uint(1), oldNode.UserID(), "user1's node should still belong to user1")
assert.Equal(t, uint(1), oldNode.UserID().Get(), "user1's node should still belong to user1")
assert.Equal(t, uint64(1), oldNode.ID().Uint64(), "user1's node should be ID=1")
// User2 should have a NEW node with nodeKey2
@ -2082,7 +2106,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.Equal(t, machineKey1.Public(), newNode.MachineKey())
user := newNode.User()
assert.Equal(t, "interactive-test-user", user.Username(), "user2's node should belong to user2")
assert.Equal(t, "interactive-test-user", user.Name(), "user2's node should belong to user2")
// Verify it's a NEW node, not transferred
assert.NotEqual(t, uint64(1), newNode.ID().Uint64(), "should be a NEW node (different ID)")
@ -2333,7 +2357,7 @@ func TestAuthenticationFlows(t *testing.T) {
assert.True(t, found, "node should be registered")
if found {
assert.Equal(t, "pending-node-2", node.Hostname())
assert.Equal(t, "second-registration-user", node.User().Name)
assert.Equal(t, "second-registration-user", node.User().Name())
}
// First registration should still be in cache (not completed)
@ -2593,7 +2617,7 @@ func TestNodeStoreLookup(t *testing.T) {
nodeKey := key.NewNode()
user := app.state.CreateUserForTest("test-user")
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
// Register a node
@ -2642,9 +2666,9 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
user2 := app.state.CreateUserForTest("user2")
// Create pre-auth keys for both users
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
require.NoError(t, err)
pak2, err := app.state.CreatePreAuthKey(types.UserID(user2.ID), true, false, nil, nil)
pak2, err := app.state.CreatePreAuthKey(user2.TypedID(), true, false, nil, nil)
require.NoError(t, err)
// Create machine and node keys for 4 nodes (2 per user)
@ -2720,7 +2744,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
t.Logf("All nodes logged out")
// Create a new pre-auth key for user1 (reusable for all nodes)
newPak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
newPak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
require.NoError(t, err)
// Re-login all nodes using user1's new pre-auth key
@ -2765,7 +2789,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// User1's original nodes should still be owned by user1
registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID))
require.True(t, found, "User1's original node %s should still exist", node.hostname)
require.Equal(t, user1.ID, registeredNode.UserID(), "Node %s should still belong to user1", node.hostname)
require.Equal(t, user1.ID, registeredNode.UserID().Get(), "Node %s should still belong to user1", node.hostname)
t.Logf("✓ User1's original node %s (ID=%d) still owned by user1", node.hostname, registeredNode.ID().Uint64())
}
@ -2774,7 +2798,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// User2's original nodes should still be owned by user2
registeredNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user2.ID))
require.True(t, found, "User2's original node %s should still exist", node.hostname)
require.Equal(t, user2.ID, registeredNode.UserID(), "Node %s should still belong to user2", node.hostname)
require.Equal(t, user2.ID, registeredNode.UserID().Get(), "Node %s should still belong to user2", node.hostname)
t.Logf("✓ User2's original node %s (ID=%d) still owned by user2", node.hostname, registeredNode.ID().Uint64())
}
@ -2785,7 +2809,7 @@ func TestPreAuthKeyLogoutAndReloginDifferentUser(t *testing.T) {
// Should be able to find a node with user1 and this machine key (the new one)
newNode, found := app.state.GetNodeByMachineKey(node.machineKey.Public(), types.UserID(user1.ID))
require.True(t, found, "Should have created new node for user1 with machine key from %s", node.hostname)
require.Equal(t, user1.ID, newNode.UserID(), "New node should belong to user1")
require.Equal(t, user1.ID, newNode.UserID().Get(), "New node should belong to user1")
t.Logf("✓ New node created for user1 with machine key from %s (ID=%d)", node.hostname, newNode.ID().Uint64())
}
}
@ -2813,7 +2837,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
// Step 1: Register node for user1 via pre-auth key (simulating initial web flow registration)
user1 := app.state.CreateUserForTest("user1")
pak1, err := app.state.CreatePreAuthKey(types.UserID(user1.ID), true, false, nil, nil)
pak1, err := app.state.CreatePreAuthKey(user1.TypedID(), true, false, nil, nil)
require.NoError(t, err)
regReq1 := tailcfg.RegisterRequest{
@ -2834,7 +2858,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
// Verify node exists for user1
user1Node, found := app.state.GetNodeByMachineKey(machineKey.Public(), types.UserID(user1.ID))
require.True(t, found, "Node should exist for user1")
require.Equal(t, user1.ID, user1Node.UserID(), "Node should belong to user1")
require.Equal(t, user1.ID, user1Node.UserID().Get(), "Node should belong to user1")
user1NodeID := user1Node.ID()
t.Logf("✓ User1 node created with ID: %d", user1NodeID)
@ -2896,7 +2920,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
t.Fatal("User1's node was transferred or deleted - this breaks the integration test!")
}
assert.Equal(t, user1.ID, user1NodeAfter.UserID(), "User1's node should still belong to user1")
assert.Equal(t, user1.ID, user1NodeAfter.UserID().Get(), "User1's node should still belong to user1")
assert.Equal(t, user1NodeID, user1NodeAfter.ID(), "Should be the same node (same ID)")
assert.True(t, user1NodeAfter.IsExpired(), "User1's node should still be expired")
t.Logf("✓ User1's original node still exists (ID: %d, expired: %v)", user1NodeAfter.ID(), user1NodeAfter.IsExpired())
@ -2911,7 +2935,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
t.Fatal("User2 doesn't have a node - registration failed!")
}
assert.Equal(t, user2.ID, user2Node.UserID(), "User2's node should belong to user2")
assert.Equal(t, user2.ID, user2Node.UserID().Get(), "User2's node should belong to user2")
assert.NotEqual(t, user1NodeID, user2Node.ID(), "Should be a NEW node (different ID), not transfer!")
assert.Equal(t, machineKey.Public(), user2Node.MachineKey(), "Should have same machine key")
assert.Equal(t, nodeKey2.Public(), user2Node.NodeKey(), "Should have new node key")
@ -2921,7 +2945,7 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
t.Run("returned_node_is_user2_new_node", func(t *testing.T) {
// The node returned from HandleNodeFromAuthPath should be user2's NEW node
assert.Equal(t, user2.ID, node.UserID(), "Returned node should belong to user2")
assert.Equal(t, user2.ID, node.UserID().Get(), "Returned node should belong to user2")
assert.NotEqual(t, user1NodeID, node.ID(), "Returned node should be NEW, not transferred from user1")
t.Logf("✓ HandleNodeFromAuthPath returned user2's new node (ID: %d)", node.ID())
})
@ -2949,10 +2973,11 @@ func TestWebFlowReauthDifferentUser(t *testing.T) {
user2Nodes := 0
for i := 0; i < allNodesSlice.Len(); i++ {
n := allNodesSlice.At(i)
if n.UserID() == user1.ID {
if n.UserID().Get() == user1.ID {
user1Nodes++
}
if n.UserID() == user2.ID {
if n.UserID().Get() == user2.ID {
user2Nodes++
}
}
@ -3026,7 +3051,7 @@ func TestGitHubIssue2830_NodeRestartWithUsedPreAuthKey(t *testing.T) {
// Create user and single-use pre-auth key
user := app.state.CreateUserForTest("test-user")
pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil) // reusable=false
pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil) // reusable=false
require.NoError(t, err)
// Fetch the full pre-auth key to check Reusable field
@ -3117,7 +3142,7 @@ func TestNodeReregistrationWithReusablePreAuthKey(t *testing.T) {
app := createTestApp(t)
user := app.state.CreateUserForTest("test-user")
pakNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil) // reusable=true
pakNew, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, nil, nil) // reusable=true
require.NoError(t, err)
// Fetch the full pre-auth key to check Reusable field
@ -3173,7 +3198,7 @@ func TestNodeReregistrationWithExpiredPreAuthKey(t *testing.T) {
user := app.state.CreateUserForTest("test-user")
expiry := time.Now().Add(-1 * time.Hour) // Already expired
pak, err := app.state.CreatePreAuthKey(types.UserID(user.ID), true, false, &expiry, nil)
pak, err := app.state.CreatePreAuthKey(user.TypedID(), true, false, &expiry, nil)
require.NoError(t, err)
machineKey := key.NewMachine()
@ -3306,7 +3331,7 @@ func TestGitHubIssue2830_ExistingNodeCanReregisterWithUsedPreAuthKey(t *testing.
// Create a SINGLE-USE pre-auth key (reusable=false)
// This is the type of key that triggers the bug in issue #2830
preAuthKeyNew, err := app.state.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
preAuthKeyNew, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
// Fetch the full pre-auth key to check Reusable and Used fields

View file

@ -577,6 +577,21 @@ AND auth_key_id NOT IN (
},
Rollback: func(db *gorm.DB) error { return nil },
},
{
// Rename forced_tags column to tags in nodes table.
// This must run after migration 202505141324 which creates tables with forced_tags.
ID: "202511131445-node-forced-tags-to-tags",
Migrate: func(tx *gorm.DB) error {
// Rename the column from forced_tags to tags
err := tx.Migrator().RenameColumn(&types.Node{}, "forced_tags", "tags")
if err != nil {
return fmt.Errorf("renaming forced_tags to tags: %w", err)
}
return nil
},
Rollback: func(db *gorm.DB) error { return nil },
},
},
)

View file

@ -231,8 +231,7 @@ func TestPostgresMigrationAndDataValidation(t *testing.T) {
name string
dbPath string
wantFunc func(*testing.T, *HSDatabase)
}{
}
}{}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View file

@ -95,7 +95,7 @@ func TestIPAllocatorSequential(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
@ -123,7 +123,7 @@ func TestIPAllocatorSequential(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.2"),
IPv6: nap("fd7a:115c:a1e0::2"),
})
@ -309,7 +309,7 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
})
@ -334,7 +334,7 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv6: nap("fd7a:115c:a1e0::1"),
})
@ -359,7 +359,7 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
@ -383,7 +383,7 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
IPv6: nap("fd7a:115c:a1e0::1"),
})
@ -407,19 +407,19 @@ func TestBackfillIPAddresses(t *testing.T) {
db.DB.Save(&user)
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.1"),
})
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.2"),
})
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.3"),
})
db.DB.Save(&types.Node{
User: user,
User: &user,
IPv4: nap("100.64.0.4"),
})

View file

@ -196,8 +196,9 @@ func SetTags(
tags []string,
) error {
if len(tags) == 0 {
// if no tags are provided, we remove all forced tags
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", "[]").Error; err != nil {
// if no tags are provided, we remove all tags
err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", "[]").Error
if err != nil {
return fmt.Errorf("removing tags: %w", err)
}
@ -211,7 +212,8 @@ func SetTags(
return err
}
if err := tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("forced_tags", string(b)).Error; err != nil {
err = tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("tags", string(b)).Error
if err != nil {
return fmt.Errorf("updating tags: %w", err)
}
@ -349,12 +351,20 @@ func RegisterNodeForTest(tx *gorm.DB, node types.Node, ipv4 *netip.Addr, ipv6 *n
panic("RegisterNodeForTest can only be called during tests")
}
log.Debug().
logEvent := log.Debug().
Str("node", node.Hostname).
Str("machine_key", node.MachineKey.ShortString()).
Str("node_key", node.NodeKey.ShortString()).
Str("user", node.User.Username()).
Msg("Registering test node")
Str("node_key", node.NodeKey.ShortString())
if node.User != nil {
logEvent = logEvent.Str("user", node.User.Username())
} else if node.UserID != nil {
logEvent = logEvent.Uint("user_id", *node.UserID)
} else {
logEvent = logEvent.Str("user", "none")
}
logEvent.Msg("Registering test node")
// If the a new node is registered with the same machine key, to the same user,
// update the existing node.
@ -642,7 +652,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
}
// Create a preauth key for the node
pak, err := hsdb.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := hsdb.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
if err != nil {
panic(fmt.Sprintf("failed to create preauth key for test node: %v", err))
}
@ -656,7 +666,7 @@ func (hsdb *HSDatabase) CreateNodeForTest(user *types.User, hostname ...string)
NodeKey: nodeKey.Public(),
DiscoKey: discoKey.Public(),
Hostname: nodeName,
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}

View file

@ -83,7 +83,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode")
@ -97,7 +97,7 @@ func (s *Suite) TestExpireNode(c *check.C) {
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
Expiry: &time.Time{},
@ -124,7 +124,7 @@ func (s *Suite) TestSetTags(c *check.C) {
user, err := db.CreateUser(types.User{Name: "test"})
c.Assert(err, check.IsNil)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
c.Assert(err, check.IsNil)
_, err = db.getNode(types.UserID(user.ID), "testnode")
@ -138,7 +138,7 @@ func (s *Suite) TestSetTags(c *check.C) {
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
@ -152,7 +152,7 @@ func (s *Suite) TestSetTags(c *check.C) {
c.Assert(err, check.IsNil)
node, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, sTags)
c.Assert(node.Tags, check.DeepEquals, sTags)
// assign duplicate tags, expect no errors but no doubles in DB
eTags := []string{"tag:bar", "tag:test", "tag:unknown", "tag:test"}
@ -161,17 +161,10 @@ func (s *Suite) TestSetTags(c *check.C) {
node, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil)
c.Assert(
node.ForcedTags,
node.Tags,
check.DeepEquals,
[]string{"tag:bar", "tag:test", "tag:unknown"},
)
// test removing tags
err = db.SetTags(node.ID, []string{})
c.Assert(err, check.IsNil)
node, err = db.getNode(types.UserID(user.ID), "testnode")
c.Assert(err, check.IsNil)
c.Assert(node.ForcedTags, check.DeepEquals, []string{})
}
func TestHeadscale_generateGivenName(t *testing.T) {
@ -430,7 +423,7 @@ func TestAutoApproveRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.routes,
@ -446,12 +439,12 @@ func TestAutoApproveRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "taggednode",
UserID: taggedUser.ID,
UserID: &taggedUser.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.routes,
},
ForcedTags: []string{"tag:exit"},
Tags: []string{"tag:exit"},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
}
@ -593,10 +586,10 @@ func TestListEphemeralNodes(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
pakEph, err := db.CreatePreAuthKey(types.UserID(user.ID), false, true, nil, nil)
pakEph, err := db.CreatePreAuthKey(user.TypedID(), false, true, nil, nil)
require.NoError(t, err)
node := types.Node{
@ -604,7 +597,7 @@ func TestListEphemeralNodes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}
@ -614,7 +607,7 @@ func TestListEphemeralNodes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "ephemeral",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pakEph.ID),
}
@ -657,7 +650,7 @@ func TestNodeNaming(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@ -667,7 +660,7 @@ func TestNodeNaming(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@ -680,7 +673,7 @@ func TestNodeNaming(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "我的电脑",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
@ -688,7 +681,7 @@ func TestNodeNaming(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "a",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
}
@ -808,7 +801,7 @@ func TestRenameNodeComprehensive(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@ -931,7 +924,7 @@ func TestListPeers(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test1",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@ -941,7 +934,7 @@ func TestListPeers(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test2",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@ -1016,7 +1009,7 @@ func TestListNodes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test1",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}
@ -1026,7 +1019,7 @@ func TestListNodes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test2",
UserID: user2.ID,
UserID: &user2.ID,
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{},
}

View file

@ -15,15 +15,15 @@ import (
)
var (
ErrPreAuthKeyNotFound = errors.New("AuthKey not found")
ErrPreAuthKeyExpired = errors.New("AuthKey expired")
ErrSingleUseAuthKeyHasBeenUsed = errors.New("AuthKey has already been used")
ErrPreAuthKeyNotFound = errors.New("auth-key not found")
ErrPreAuthKeyExpired = errors.New("auth-key expired")
ErrSingleUseAuthKeyHasBeenUsed = errors.New("auth-key has already been used")
ErrUserMismatch = errors.New("user mismatch")
ErrPreAuthKeyACLTagInvalid = errors.New("AuthKey tag is invalid")
ErrPreAuthKeyACLTagInvalid = errors.New("auth-key tag is invalid")
)
func (hsdb *HSDatabase) CreatePreAuthKey(
uid types.UserID,
uid *types.UserID,
reusable bool,
ephemeral bool,
expiration *time.Time,
@ -41,17 +41,40 @@ const (
)
// CreatePreAuthKey creates a new PreAuthKey in a user, and returns it.
// The uid parameter can be nil for system-created tagged keys.
// For tagged keys, uid tracks "created by" (who created the key).
// For user-owned keys, uid tracks the node owner.
func CreatePreAuthKey(
tx *gorm.DB,
uid types.UserID,
uid *types.UserID,
reusable bool,
ephemeral bool,
expiration *time.Time,
aclTags []string,
) (*types.PreAuthKeyNew, error) {
user, err := GetUserByID(tx, uid)
if err != nil {
return nil, err
// Validate: must be tagged OR user-owned, not neither
if uid == nil && len(aclTags) == 0 {
return nil, ErrPreAuthKeyNotTaggedOrOwned
}
// If uid != nil && len(aclTags) > 0:
// Both are allowed: UserID tracks "created by", tags define node ownership
// This is valid per the new model
var (
user *types.User
userID *uint
)
if uid != nil {
var err error
user, err = GetUserByID(tx, *uid)
if err != nil {
return nil, err
}
userID = &user.ID
}
// Remove duplicates and sort for consistency
@ -108,15 +131,15 @@ func CreatePreAuthKey(
}
key := types.PreAuthKey{
UserID: user.ID,
User: *user,
UserID: userID, // nil for system-created keys, or "created by" for tagged keys
User: user, // nil for system-created keys
Reusable: reusable,
Ephemeral: ephemeral,
CreatedAt: &now,
Expiration: expiration,
Tags: aclTags,
Prefix: prefix, // Store prefix
Hash: hash, // Store hash
Tags: aclTags, // empty for user-owned keys
Prefix: prefix, // Store prefix
Hash: hash, // Store hash
}
if err := tx.Save(&key).Error; err != nil {
@ -149,14 +172,19 @@ func ListPreAuthKeysByUser(tx *gorm.DB, uid types.UserID) ([]types.PreAuthKey, e
}
keys := []types.PreAuthKey{}
if err := tx.Preload("User").Where(&types.PreAuthKey{UserID: user.ID}).Find(&keys).Error; err != nil {
err = tx.Preload("User").Where(&types.PreAuthKey{UserID: &user.ID}).Find(&keys).Error
if err != nil {
return nil, err
}
return keys, nil
}
var ErrPreAuthKeyFailedToParse = errors.New("failed to parse AuthKey")
var (
ErrPreAuthKeyFailedToParse = errors.New("failed to parse auth-key")
ErrPreAuthKeyNotTaggedOrOwned = errors.New("auth-key must be either tagged or owned by user")
)
func findAuthKey(tx *gorm.DB, keyStr string) (*types.PreAuthKey, error) {
var pak types.PreAuthKey

View file

@ -24,7 +24,7 @@ func TestCreatePreAuthKey(t *testing.T) {
test: func(t *testing.T, db *HSDatabase) {
t.Helper()
_, err := db.CreatePreAuthKey(12345, true, false, nil, nil)
_, err := db.CreatePreAuthKey(ptr.To(types.UserID(12345)), true, false, nil, nil)
assert.Error(t, err)
},
},
@ -36,7 +36,7 @@ func TestCreatePreAuthKey(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
key, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
key, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
assert.NotEmpty(t, key.Key)
@ -83,7 +83,7 @@ func TestPreAuthKeyACLTags(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test-tags-1"})
require.NoError(t, err)
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"badtag"})
_, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"badtag"})
assert.Error(t, err)
},
},
@ -98,7 +98,7 @@ func TestPreAuthKeyACLTags(t *testing.T) {
expectedTags := []string{"tag:test1", "tag:test2"}
tagsWithDuplicate := []string{"tag:test1", "tag:test2", "tag:test2"}
_, err = db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, tagsWithDuplicate)
_, err = db.CreatePreAuthKey(user.TypedID(), false, false, nil, tagsWithDuplicate)
require.NoError(t, err)
listedPaks, err := db.ListPreAuthKeys(types.UserID(user.ID))
@ -128,13 +128,13 @@ func TestCannotDeleteAssignedPreAuthKey(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test8"})
require.NoError(t, err)
key, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, []string{"tag:good"})
key, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:good"})
require.NoError(t, err)
node := types.Node{
ID: 0,
Hostname: "testest",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(key.ID),
}
@ -180,7 +180,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
validateResult: func(t *testing.T, pak *types.PreAuthKey) {
t.Helper()
assert.Equal(t, user.ID, pak.UserID)
assert.Equal(t, user.ID, *pak.UserID)
assert.NotEmpty(t, pak.Key) // Legacy keys have Key populated
assert.Empty(t, pak.Prefix) // Legacy keys have empty Prefix
assert.Nil(t, pak.Hash) // Legacy keys have nil Hash
@ -191,7 +191,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
setupKey: func() string {
// Create new key via API
keyStr, err := db.CreatePreAuthKey(
types.UserID(user.ID),
user.TypedID(),
true, false, nil, []string{"tag:test"},
)
require.NoError(t, err)
@ -203,7 +203,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
validateResult: func(t *testing.T, pak *types.PreAuthKey) {
t.Helper()
assert.Equal(t, user.ID, pak.UserID)
assert.Equal(t, user.ID, *pak.UserID)
assert.Empty(t, pak.Key) // New keys have empty Key
assert.NotEmpty(t, pak.Prefix) // New keys have Prefix
assert.NotNil(t, pak.Hash) // New keys have Hash
@ -214,7 +214,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
name: "new_key_format_validation",
setupKey: func() string {
keyStr, err := db.CreatePreAuthKey(
types.UserID(user.ID),
user.TypedID(),
true, false, nil, nil,
)
require.NoError(t, err)
@ -244,7 +244,7 @@ func TestPreAuthKeyAuthentication(t *testing.T) {
setupKey: func() string {
// Create valid key
key, err := db.CreatePreAuthKey(
types.UserID(user.ID),
user.TypedID(),
true, false, nil, nil,
)
require.NoError(t, err)
@ -415,11 +415,11 @@ func TestMultipleLegacyKeysAllowed(t *testing.T) {
assert.Len(t, legacyKeys, 5, "should have created 5 legacy keys")
// Now create new bcrypt-based keys - these should have unique prefixes
key1, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
key1, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
assert.NotEmpty(t, key1.Key)
key2, err := db.CreatePreAuthKey(types.UserID(user.ID), true, false, nil, nil)
key2, err := db.CreatePreAuthKey(user.TypedID(), true, false, nil, nil)
require.NoError(t, err)
assert.NotEmpty(t, key2.Key)

View file

@ -81,7 +81,7 @@ CREATE TABLE nodes(
given_name varchar(63),
user_id integer,
register_method text,
forced_tags text,
tags text,
auth_key_id integer,
last_seen datetime,
expiry datetime,

View file

@ -189,7 +189,11 @@ func (hsdb *HSDatabase) GetUserByName(name string) (*types.User, error) {
// ListNodesByUser gets all the nodes in a given user.
func ListNodesByUser(tx *gorm.DB, uid types.UserID) (types.Nodes, error) {
nodes := types.Nodes{}
if err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: uint(uid)}).Find(&nodes).Error; err != nil {
uidPtr := uint(uid)
err := tx.Preload("AuthKey").Preload("AuthKey.User").Preload("User").Where(&types.Node{UserID: &uidPtr}).Find(&nodes).Error
if err != nil {
return nil, err
}

View file

@ -50,7 +50,7 @@ func TestDestroyUserErrors(t *testing.T) {
user := db.CreateUserForTest("test")
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
err = db.DestroyUser(types.UserID(user.ID))
@ -71,13 +71,13 @@ func TestDestroyUserErrors(t *testing.T) {
user, err := db.CreateUser(types.User{Name: "test"})
require.NoError(t, err)
pak, err := db.CreatePreAuthKey(types.UserID(user.ID), false, false, nil, nil)
pak, err := db.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
node := types.Node{
ID: 0,
Hostname: "testnode",
UserID: user.ID,
UserID: &user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: ptr.To(pak.ID),
}

View file

@ -172,7 +172,7 @@ func (api headscaleV1APIServer) CreatePreAuthKey(
}
preAuthKey, err := api.h.state.CreatePreAuthKey(
types.UserID(user.ID),
user.TypedID(),
request.GetReusable(),
request.GetEphemeral(),
&expiration,
@ -341,6 +341,17 @@ func (api headscaleV1APIServer) SetTags(
ctx context.Context,
request *v1.SetTagsRequest,
) (*v1.SetTagsResponse, error) {
// Validate tags not empty - tagged nodes must have at least one tag
if len(request.GetTags()) == 0 {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(
codes.InvalidArgument,
"cannot remove all tags from a node - tagged nodes must have at least one tag",
)
}
// Validate tag format
for _, tag := range request.GetTags() {
err := validateTag(tag)
if err != nil {
@ -348,6 +359,16 @@ func (api headscaleV1APIServer) SetTags(
}
}
// User XOR Tags: nodes are either tagged or user-owned, never both.
// Setting tags on a user-owned node converts it to a tagged node.
// Once tagged, a node cannot be converted back to user-owned.
_, found := api.h.state.GetNodeByID(types.NodeID(request.GetNodeId()))
if !found {
return &v1.SetTagsResponse{
Node: nil,
}, status.Error(codes.NotFound, "node not found")
}
node, nodeChange, err := api.h.state.SetNodeTags(types.NodeID(request.GetNodeId()), request.GetTags())
if err != nil {
return &v1.SetTagsResponse{
@ -529,13 +550,19 @@ func nodesToProto(state *state.State, nodes views.Slice[types.NodeView]) []*v1.N
for index, node := range nodes.All() {
resp := node.Proto()
// Tags-as-identity: tagged nodes show as TaggedDevices user in API responses
// (UserID may be set internally for "created by" tracking)
if node.IsTagged() {
resp.User = types.TaggedDevices.Proto()
}
var tags []string
for _, tag := range node.RequestTags() {
if state.NodeCanHaveTag(node, tag) {
tags = append(tags, tag)
}
}
resp.ValidTags = lo.Uniq(append(tags, node.ForcedTags().AsSlice()...))
resp.ValidTags = lo.Uniq(append(tags, node.Tags().AsSlice()...))
resp.SubnetRoutes = util.PrefixesToString(append(state.GetNodePrimaryRoutes(node.ID()), node.ExitRoutes()...))
response[index] = resp
@ -780,7 +807,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
NodeKey: key.NewNode().Public(),
MachineKey: key.NewMachine().Public(),
Hostname: request.GetName(),
User: *user,
User: user,
Expiry: &time.Time{},
LastSeen: &time.Time{},

View file

@ -1,6 +1,17 @@
package hscontrol
import "testing"
import (
"context"
"testing"
v1 "github.com/juanfont/headscale/gen/go/headscale/v1"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
func Test_validateTag(t *testing.T) {
type args struct {
@ -40,3 +51,212 @@ func Test_validateTag(t *testing.T) {
})
}
}
// TestSetTags_Conversion tests the conversion of user-owned nodes to tagged nodes.
// The tags-as-identity model allows one-way conversion from user-owned to tagged.
// Tag authorization is checked via the policy manager - unauthorized tags are rejected.
func TestSetTags_Conversion(t *testing.T) {
t.Parallel()
app := createTestApp(t)
// Create test user and nodes
user := app.state.CreateUserForTest("test-user")
// Create a pre-auth key WITHOUT tags for user-owned node
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, nil)
require.NoError(t, err)
machineKey1 := key.NewMachine()
nodeKey1 := key.NewNode()
// Register a user-owned node (via untagged PreAuthKey)
userOwnedReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey1.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "user-owned-node",
},
}
_, err = app.handleRegisterWithAuthKey(userOwnedReq, machineKey1.Public())
require.NoError(t, err)
// Get the created node
userOwnedNode, found := app.state.GetNodeByNodeKey(nodeKey1.Public())
require.True(t, found)
// Create API server instance
apiServer := newHeadscaleV1APIServer(app)
tests := []struct {
name string
nodeID uint64
tags []string
wantErr bool
wantCode codes.Code
wantErrMessage string
}{
{
// Conversion is allowed, but tag authorization fails without tagOwners
name: "reject unauthorized tags on user-owned node",
nodeID: uint64(userOwnedNode.ID()),
tags: []string{"tag:server"},
wantErr: true,
wantCode: codes.InvalidArgument,
wantErrMessage: "invalid or unauthorized tags",
},
{
// Conversion is allowed, but tag authorization fails without tagOwners
name: "reject multiple unauthorized tags",
nodeID: uint64(userOwnedNode.ID()),
tags: []string{"tag:server", "tag:database"},
wantErr: true,
wantCode: codes.InvalidArgument,
wantErrMessage: "invalid or unauthorized tags",
},
{
name: "reject non-existent node",
nodeID: 99999,
tags: []string{"tag:server"},
wantErr: true,
wantCode: codes.NotFound,
wantErrMessage: "node not found",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{
NodeId: tt.nodeID,
Tags: tt.tags,
})
if tt.wantErr {
require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok, "error should be a gRPC status error")
assert.Equal(t, tt.wantCode, st.Code())
assert.Contains(t, st.Message(), tt.wantErrMessage)
assert.Nil(t, resp.GetNode())
} else {
require.NoError(t, err)
assert.NotNil(t, resp)
assert.NotNil(t, resp.GetNode())
}
})
}
}
// TestSetTags_TaggedNode tests that SetTags correctly identifies tagged nodes
// and doesn't reject them with the "user-owned nodes" error.
// Note: This test doesn't validate ACL tag authorization - that's tested elsewhere.
func TestSetTags_TaggedNode(t *testing.T) {
t.Parallel()
app := createTestApp(t)
// Create test user and tagged pre-auth key
user := app.state.CreateUserForTest("test-user")
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:initial"})
require.NoError(t, err)
machineKey := key.NewMachine()
nodeKey := key.NewNode()
// Register a tagged node (via tagged PreAuthKey)
taggedReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
}
_, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public())
require.NoError(t, err)
// Get the created node
taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, taggedNode.IsTagged(), "Node should be tagged")
assert.True(t, taggedNode.UserID().Valid(), "Tagged node should have UserID for tracking")
// Create API server instance
apiServer := newHeadscaleV1APIServer(app)
// Test: SetTags should NOT reject tagged nodes with "user-owned" error
// (Even though they have UserID set, IsTagged() identifies them correctly)
resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{
NodeId: uint64(taggedNode.ID()),
Tags: []string{"tag:initial"}, // Keep existing tag to avoid ACL validation issues
})
// The call should NOT fail with "cannot set tags on user-owned nodes"
if err != nil {
st, ok := status.FromError(err)
require.True(t, ok)
// If error is about unauthorized tags, that's fine - ACL validation is working
// If error is about user-owned nodes, that's the bug we're testing for
assert.NotContains(t, st.Message(), "user-owned nodes", "Should not reject tagged nodes as user-owned")
} else {
// Success is also fine
assert.NotNil(t, resp)
}
}
// TestSetTags_CannotRemoveAllTags tests that SetTags rejects attempts to remove
// all tags from a tagged node, enforcing Tailscale's requirement that tagged
// nodes must have at least one tag.
func TestSetTags_CannotRemoveAllTags(t *testing.T) {
t.Parallel()
app := createTestApp(t)
// Create test user and tagged pre-auth key
user := app.state.CreateUserForTest("test-user")
pak, err := app.state.CreatePreAuthKey(user.TypedID(), false, false, nil, []string{"tag:server"})
require.NoError(t, err)
machineKey := key.NewMachine()
nodeKey := key.NewNode()
// Register a tagged node
taggedReq := tailcfg.RegisterRequest{
Auth: &tailcfg.RegisterResponseAuth{
AuthKey: pak.Key,
},
NodeKey: nodeKey.Public(),
Hostinfo: &tailcfg.Hostinfo{
Hostname: "tagged-node",
},
}
_, err = app.handleRegisterWithAuthKey(taggedReq, machineKey.Public())
require.NoError(t, err)
// Get the created node
taggedNode, found := app.state.GetNodeByNodeKey(nodeKey.Public())
require.True(t, found)
assert.True(t, taggedNode.IsTagged())
// Create API server instance
apiServer := newHeadscaleV1APIServer(app)
// Attempt to remove all tags (empty array)
resp, err := apiServer.SetTags(context.Background(), &v1.SetTagsRequest{
NodeId: uint64(taggedNode.ID()),
Tags: []string{}, // Empty - attempting to remove all tags
})
// Should fail with InvalidArgument error
require.Error(t, err)
st, ok := status.FromError(err)
require.True(t, ok, "error should be a gRPC status error")
assert.Equal(t, codes.InvalidArgument, st.Code())
assert.Contains(t, st.Message(), "cannot remove all tags")
assert.Nil(t, resp.GetNode())
}

View file

@ -73,15 +73,17 @@ func generateUserProfiles(
node types.NodeView,
peers views.Slice[types.NodeView],
) []tailcfg.UserProfile {
userMap := make(map[uint]*types.User)
userMap := make(map[uint]*types.UserView)
ids := make([]uint, 0, len(userMap))
user := node.User()
userMap[user.ID] = &user
ids = append(ids, user.ID)
userID := user.Model().ID
userMap[userID] = &user
ids = append(ids, userID)
for _, peer := range peers.All() {
peerUser := peer.User()
userMap[peerUser.ID] = &peerUser
ids = append(ids, peerUser.ID)
peerUserID := peerUser.Model().ID
userMap[peerUserID] = &peerUser
ids = append(ids, peerUserID)
}
slices.Sort(ids)

View file

@ -14,6 +14,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"tailscale.com/tailcfg"
"tailscale.com/types/dnstype"
"tailscale.com/types/ptr"
)
var iap = func(ipStr string) *netip.Addr {
@ -50,8 +51,8 @@ func TestDNSConfigMapResponse(t *testing.T) {
mach := func(hostname, username string, userid uint) *types.Node {
return &types.Node{
Hostname: hostname,
UserID: userid,
User: types.User{
UserID: ptr.To(userid),
User: &types.User{
Name: username,
},
}

View file

@ -83,7 +83,8 @@ func tailNode(
tags = append(tags, tag)
}
}
for _, tag := range node.ForcedTags().All() {
for _, tag := range node.Tags().All() {
tags = append(tags, tag)
}
tags = lo.Uniq(tags)
@ -99,7 +100,7 @@ func tailNode(
Name: hostname,
Cap: capVer,
User: tailcfg.UserID(node.UserID()),
User: node.TailscaleUserID(),
Key: node.NodeKey(),
KeyExpiry: keyExpiry.UTC(),

View file

@ -15,6 +15,7 @@ import (
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestTailNode(t *testing.T) {
@ -97,14 +98,14 @@ func TestTailNode(t *testing.T) {
IPv4: iap("100.64.0.1"),
Hostname: "mini",
GivenName: "mini",
UserID: 0,
User: types.User{
UserID: ptr.To(uint(0)),
User: &types.User{
Name: "mini",
},
ForcedTags: []string{},
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
Tags: []string{},
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
tsaddr.AllIPv4(),

View file

@ -1,13 +1,10 @@
package hscontrol
import (
"os"
"path/filepath"
"testing"
"github.com/juanfont/headscale/hscontrol/templates"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestOIDCCallbackTemplate(t *testing.T) {
@ -49,15 +46,6 @@ func TestOIDCCallbackTemplate(t *testing.T) {
assert.Contains(t, html, "<svg")
assert.Contains(t, html, "class=\"headscale-logo\"")
assert.Contains(t, html, "id=\"checkbox\"")
// Save the output for manual inspection
testDataDir := filepath.Join("testdata", "oidc_templates")
err := os.MkdirAll(testDataDir, 0o755)
require.NoError(t, err)
outputFile := filepath.Join(testDataDir, tt.name+".html")
err = os.WriteFile(outputFile, []byte(html), 0o600)
require.NoError(t, err)
})
}
}

View file

@ -32,11 +32,11 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "test-node",
UserID: user1.ID,
User: user1,
UserID: ptr.To(user1.ID),
User: ptr.To(user1),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ForcedTags: []string{"tag:test"},
Tags: []string{"tag:test"},
}
node2 := &types.Node{
@ -44,8 +44,8 @@ func TestApproveRoutesWithPolicy_NeverRemovesApprovedRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "other-node",
UserID: user2.ID,
User: user2,
UserID: ptr.To(user2.ID),
User: ptr.To(user2),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.2")),
}
@ -304,8 +304,8 @@ func TestApproveRoutesWithPolicy_NilAndEmptyCases(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
UserID: ptr.To(user.ID),
User: ptr.To(user),
RegisterMethod: util.RegisterMethodAuthKey,
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,

View file

@ -168,15 +168,15 @@ func TestApproveRoutesWithPolicy_NeverRemovesRoutes(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: tt.nodeHostname,
UserID: user.ID,
User: user,
UserID: ptr.To(user.ID),
User: ptr.To(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
},
IPv4: ptr.To(netip.MustParseAddr("100.64.0.1")),
ApprovedRoutes: tt.currentApproved,
ForcedTags: tt.nodeTags,
Tags: tt.nodeTags,
}
nodes := types.Nodes{&node}
@ -294,8 +294,8 @@ func TestApproveRoutesWithPolicy_EdgeCases(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
UserID: ptr.To(user.ID),
User: ptr.To(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tt.announcedRoutes,
@ -343,8 +343,8 @@ func TestApproveRoutesWithPolicy_NilPolicyManagerCase(t *testing.T) {
MachineKey: key.NewMachine().Public(),
NodeKey: key.NewNode().Public(),
Hostname: "testnode",
UserID: user.ID,
User: user,
UserID: ptr.To(user.ID),
User: ptr.To(user),
RegisterMethod: util.RegisterMethodAuthKey,
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: announcedRoutes,

View file

@ -14,6 +14,7 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
var ap = func(ipStr string) *netip.Addr {
@ -44,17 +45,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{
@ -68,19 +69,19 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
},
want: types.Nodes{
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
},
@ -91,17 +92,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@ -115,14 +116,14 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
},
want: types.Nodes{
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
},
@ -133,17 +134,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@ -157,14 +158,14 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
want: types.Nodes{
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
},
@ -175,17 +176,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@ -199,14 +200,14 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
},
want: types.Nodes{
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
},
@ -217,17 +218,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@ -241,19 +242,19 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
want: types.Nodes{
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
},
@ -264,17 +265,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@ -288,19 +289,19 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
want: types.Nodes{
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
},
@ -311,17 +312,17 @@ func TestReduceNodes(t *testing.T) {
&types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "joe"},
User: &types.User{Name: "joe"},
},
&types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
&types.Node{
ID: 3,
IPv4: ap("100.64.0.3"),
User: types.User{Name: "mickael"},
User: &types.User{Name: "mickael"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@ -329,7 +330,7 @@ func TestReduceNodes(t *testing.T) {
node: &types.Node{ // current nodes
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "marc"},
User: &types.User{Name: "marc"},
},
},
want: nil,
@ -347,28 +348,28 @@ func TestReduceNodes(t *testing.T) {
Hostname: "ts-head-upcrmb",
IPv4: ap("100.64.0.3"),
IPv6: ap("fd7a:115c:a1e0::3"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
&types.Node{
ID: 2,
Hostname: "ts-unstable-rlwpvr",
IPv4: ap("100.64.0.4"),
IPv6: ap("fd7a:115c:a1e0::4"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
&types.Node{
ID: 3,
Hostname: "ts-head-8w6paa",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
&types.Node{
ID: 4,
Hostname: "ts-unstable-lys2ib",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
rules: []tailcfg.FilterRule{ // list of all ACLRules registered
@ -390,7 +391,7 @@ func TestReduceNodes(t *testing.T) {
Hostname: "ts-head-8w6paa",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
want: types.Nodes{
@ -399,14 +400,14 @@ func TestReduceNodes(t *testing.T) {
Hostname: "ts-head-upcrmb",
IPv4: ap("100.64.0.3"),
IPv6: ap("fd7a:115c:a1e0::3"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
&types.Node{
ID: 2,
Hostname: "ts-unstable-rlwpvr",
IPv4: ap("100.64.0.4"),
IPv6: ap("fd7a:115c:a1e0::4"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
},
},
@ -418,13 +419,13 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "peer1",
User: types.User{Name: "mini"},
User: &types.User{Name: "mini"},
},
{
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "peer2",
User: types.User{Name: "peer2"},
User: &types.User{Name: "peer2"},
},
},
rules: []tailcfg.FilterRule{
@ -440,7 +441,7 @@ func TestReduceNodes(t *testing.T) {
ID: 0,
IPv4: ap("100.64.0.1"),
Hostname: "mini",
User: types.User{Name: "mini"},
User: &types.User{Name: "mini"},
},
},
want: []*types.Node{
@ -448,7 +449,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "peer2",
User: types.User{Name: "peer2"},
User: &types.User{Name: "peer2"},
},
},
},
@ -460,19 +461,19 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "user1-2",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 0,
IPv4: ap("100.64.0.1"),
Hostname: "user1-1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 3,
IPv4: ap("100.64.0.4"),
Hostname: "user2-2",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
rules: []tailcfg.FilterRule{
@ -509,7 +510,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "user-2-1",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
want: []*types.Node{
@ -517,19 +518,19 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "user1-2",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 0,
IPv4: ap("100.64.0.1"),
Hostname: "user1-1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 3,
IPv4: ap("100.64.0.4"),
Hostname: "user2-2",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
},
@ -541,19 +542,19 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "user1-2",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "user-2-1",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
{
ID: 3,
IPv4: ap("100.64.0.4"),
Hostname: "user2-2",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
rules: []tailcfg.FilterRule{
@ -590,7 +591,7 @@ func TestReduceNodes(t *testing.T) {
ID: 0,
IPv4: ap("100.64.0.1"),
Hostname: "user1-1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
},
want: []*types.Node{
@ -598,19 +599,19 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.2"),
Hostname: "user1-2",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 2,
IPv4: ap("100.64.0.3"),
Hostname: "user-2-1",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
{
ID: 3,
IPv4: ap("100.64.0.4"),
Hostname: "user2-2",
User: types.User{Name: "user2"},
User: &types.User{Name: "user2"},
},
},
},
@ -622,13 +623,13 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "user1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
{
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
},
@ -649,7 +650,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "user1",
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
},
want: []*types.Node{
@ -657,7 +658,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.33.0.0/16")},
},
@ -673,7 +674,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")},
},
@ -683,7 +684,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "node",
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
},
rules: []tailcfg.FilterRule{
@ -700,7 +701,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")},
},
@ -712,7 +713,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "node",
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
},
},
@ -724,7 +725,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")},
},
@ -734,7 +735,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "node",
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
},
rules: []tailcfg.FilterRule{
@ -751,7 +752,7 @@ func TestReduceNodes(t *testing.T) {
ID: 2,
IPv4: ap("100.64.0.2"),
Hostname: "node",
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
},
want: []*types.Node{
@ -759,7 +760,7 @@ func TestReduceNodes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
Hostname: "router",
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("10.99.0.0/16")},
},
@ -804,7 +805,7 @@ func TestReduceNodesFromPolicy(t *testing.T) {
ID: id,
IPv4: ap(ip),
Hostname: hostname,
User: types.User{Name: username},
User: &types.User{Name: username},
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: routes,
},
@ -812,8 +813,6 @@ func TestReduceNodesFromPolicy(t *testing.T) {
}
}
type args struct {
}
tests := []struct {
name string
nodes types.Nodes
@ -1075,22 +1074,22 @@ func TestSSHPolicyRules(t *testing.T) {
nodeUser1 := types.Node{
Hostname: "user1-device",
IPv4: ap("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(uint(1)),
User: ptr.To(users[0]),
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: ap("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
}
taggedClient := types.Node{
Hostname: "tagged-client",
IPv4: ap("100.64.0.4"),
UserID: 2,
User: users[1],
ForcedTags: []string{"tag:client"},
Hostname: "tagged-client",
IPv4: ap("100.64.0.4"),
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
Tags: []string{"tag:client"},
}
tests := []struct {
@ -1447,7 +1446,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@ -1475,7 +1474,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@ -1501,7 +1500,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@ -1529,7 +1528,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@ -1556,7 +1555,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@ -1581,7 +1580,7 @@ func TestReduceRoutes(t *testing.T) {
ID: 1,
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: types.User{Name: "user1"},
User: &types.User{Name: "user1"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/24"),
@ -1614,7 +1613,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"), // Node IP
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@ -1646,7 +1645,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@ -1673,7 +1672,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@ -1701,7 +1700,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"),
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@ -1739,7 +1738,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"), // node with IP 100.64.0.2
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@ -1774,7 +1773,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.64.0.1"), // router with IP 100.64.0.1
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@ -1816,7 +1815,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"), // node
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@ -1850,7 +1849,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.64.0.2"), // node
User: types.User{Name: "node"},
User: &types.User{Name: "node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("10.10.10.0/24"),
@ -1887,7 +1886,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.123.45.89"), // Node B - regular node
User: types.User{Name: "node-b"},
User: &types.User{Name: "node-b"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"), // Subnet connected to Node A
@ -1917,7 +1916,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 1,
IPv4: ap("100.123.45.67"), // Node A - router node
User: types.User{Name: "router"},
User: &types.User{Name: "router"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"), // Subnet connected to this router
@ -1946,7 +1945,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.123.45.89"), // Node B - regular node that should be reachable
User: types.User{Name: "node-b"},
User: &types.User{Name: "node-b"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"), // Subnet behind router
@ -1984,7 +1983,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 3,
IPv4: ap("100.123.45.99"), // Node C - isolated node
User: types.User{Name: "isolated-node"},
User: &types.User{Name: "isolated-node"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/24"), // Subnet behind router
@ -2027,7 +2026,7 @@ func TestReduceRoutes(t *testing.T) {
node: &types.Node{
ID: 2,
IPv4: ap("100.123.45.89"), // Node B - regular node
User: types.User{Name: "node-b"},
User: &types.User{Name: "node-b"},
},
routes: []netip.Prefix{
netip.MustParsePrefix("192.168.1.0/14"), // Network 192.168.1.0/14 as mentioned in original issue

View file

@ -16,6 +16,7 @@ import (
"gorm.io/gorm"
"tailscale.com/net/tsaddr"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
"tailscale.com/util/must"
)
@ -143,13 +144,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2221"),
User: users[0],
User: ptr.To(users[0]),
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0:ab12:4843:2222:6273:2222"),
User: users[0],
User: ptr.To(users[0]),
},
},
want: []tailcfg.FilterRule{},
@ -190,7 +191,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{
netip.MustParsePrefix("10.33.0.0/16"),
@ -201,7 +202,7 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -282,19 +283,19 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
// "internal" exit node
&types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -343,7 +344,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -352,12 +353,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -452,7 +453,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: tsaddr.ExitRoutes(),
},
@ -461,12 +462,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -564,7 +565,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/16"), netip.MustParsePrefix("16.0.0.0/16")},
},
@ -573,12 +574,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -654,7 +655,7 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("8.0.0.0/8"), netip.MustParsePrefix("16.0.0.0/8")},
},
@ -663,12 +664,12 @@ func TestReduceFilterRules(t *testing.T) {
&types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[2],
User: ptr.To(users[2]),
},
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -736,17 +737,17 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.100"),
IPv6: ap("fd7a:115c:a1e0::100"),
User: users[3],
User: ptr.To(users[3]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/24")},
},
ForcedTags: []string{"tag:access-servers"},
Tags: []string{"tag:access-servers"},
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
},
},
want: []tailcfg.FilterRule{
@ -803,13 +804,13 @@ func TestReduceFilterRules(t *testing.T) {
node: &types.Node{
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[3],
User: ptr.To(users[3]),
},
peers: types.Nodes{
&types.Node{
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[1],
User: ptr.To(users[1]),
Hostinfo: &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{p("172.16.0.0/24"), p("10.10.11.0/24"), p("10.10.12.0/24")},
},

View file

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/types/ptr"
)
func TestNodeCanApproveRoute(t *testing.T) {
@ -24,34 +25,34 @@ func TestNodeCanApproveRoute(t *testing.T) {
ID: 1,
Hostname: "user1-device",
IPv4: ap("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(uint(1)),
User: ptr.To(users[0]),
}
exitNode := types.Node{
ID: 2,
Hostname: "user2-device",
IPv4: ap("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
}
taggedNode := types.Node{
ID: 3,
Hostname: "tagged-server",
IPv4: ap("100.64.0.3"),
UserID: 3,
User: users[2],
ForcedTags: []string{"tag:router"},
ID: 3,
Hostname: "tagged-server",
IPv4: ap("100.64.0.3"),
UserID: ptr.To(uint(3)),
User: ptr.To(users[2]),
Tags: []string{"tag:router"},
}
multiTagNode := types.Node{
ID: 4,
Hostname: "multi-tag-node",
IPv4: ap("100.64.0.4"),
UserID: 2,
User: users[1],
ForcedTags: []string{"tag:router", "tag:server"},
ID: 4,
Hostname: "multi-tag-node",
IPv4: ap("100.64.0.4"),
UserID: ptr.To(uint(2)),
User: ptr.To(users[1]),
Tags: []string{"tag:router", "tag:server"},
}
tests := []struct {

View file

@ -168,7 +168,7 @@ func (pol *Policy) compileACLWithAutogroupSelf(
// Pre-filter to same-user untagged devices once - reuse for both sources and destinations
sameUserNodes := make([]types.NodeView, 0)
for _, n := range nodes.All() {
if n.User().ID == node.User().ID && !n.IsTagged() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
sameUserNodes = append(sameUserNodes, n)
}
}
@ -349,7 +349,7 @@ func (pol *Policy) compileSSHPolicy(
// Build destination set for autogroup:self (same-user untagged devices only)
var dest netipx.IPSetBuilder
for _, n := range nodes.All() {
if n.User().ID == node.User().ID && !n.IsTagged() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
n.AppendToIPSet(&dest)
}
}
@ -365,7 +365,7 @@ func (pol *Policy) compileSSHPolicy(
// Pre-filter to same-user untagged devices for efficiency
sameUserNodes := make([]types.NodeView, 0)
for _, n := range nodes.All() {
if n.User().ID == node.User().ID && !n.IsTagged() {
if n.User().ID() == node.User().ID() && !n.IsTagged() {
sameUserNodes = append(sameUserNodes, n)
}
}

View file

@ -15,6 +15,7 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
// aliasWithPorts creates an AliasWithPorts structure from an alias and ports.
@ -381,7 +382,7 @@ func TestParsing(t *testing.T) {
},
&types.Node{
IPv4: ap("200.200.200.200"),
User: users[0],
User: &users[0],
Hostinfo: &tailcfg.Hostinfo{},
},
}.ViewSlice())
@ -409,14 +410,14 @@ func TestCompileSSHPolicy_UserMapping(t *testing.T) {
nodeUser1 := types.Node{
Hostname: "user1-device",
IPv4: createAddr("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(users[0].ID),
User: ptr.To(users[0]),
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: createAddr("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(users[1].ID),
User: ptr.To(users[1]),
}
nodes := types.Nodes{&nodeUser1, &nodeUser2}
@ -621,14 +622,14 @@ func TestCompileSSHPolicy_CheckAction(t *testing.T) {
nodeUser1 := types.Node{
Hostname: "user1-device",
IPv4: createAddr("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(users[0].ID),
User: ptr.To(users[0]),
}
nodeUser2 := types.Node{
Hostname: "user2-device",
IPv4: createAddr("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(users[1].ID),
User: ptr.To(users[1]),
}
nodes := types.Nodes{&nodeUser1, &nodeUser2}
@ -682,15 +683,15 @@ func TestSSHIntegrationReproduction(t *testing.T) {
node1 := &types.Node{
Hostname: "user1-node",
IPv4: createAddr("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: ptr.To(users[0].ID),
User: ptr.To(users[0]),
}
node2 := &types.Node{
Hostname: "user2-node",
IPv4: createAddr("100.64.0.2"),
UserID: 2,
User: users[1],
UserID: ptr.To(users[1].ID),
User: ptr.To(users[1]),
}
nodes := types.Nodes{node1, node2}
@ -741,11 +742,12 @@ func TestSSHJSONSerialization(t *testing.T) {
{Name: "user1", Model: gorm.Model{ID: 1}},
}
uid := uint(1)
node := &types.Node{
Hostname: "test-node",
IPv4: createAddr("100.64.0.1"),
UserID: 1,
User: users[0],
UserID: &uid,
User: &users[0],
}
nodes := types.Nodes{node}
@ -804,32 +806,32 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
nodes := types.Nodes{
{
User: users[0],
User: ptr.To(users[0]),
IPv4: ap("100.64.0.1"),
},
{
User: users[0],
User: ptr.To(users[0]),
IPv4: ap("100.64.0.2"),
},
{
User: users[1],
User: ptr.To(users[1]),
IPv4: ap("100.64.0.3"),
},
{
User: users[1],
User: ptr.To(users[1]),
IPv4: ap("100.64.0.4"),
},
// Tagged device for user1
{
User: users[0],
IPv4: ap("100.64.0.5"),
ForcedTags: []string{"tag:test"},
User: &users[0],
IPv4: ap("100.64.0.5"),
Tags: []string{"tag:test"},
},
// Tagged device for user2
{
User: users[1],
IPv4: ap("100.64.0.6"),
ForcedTags: []string{"tag:test"},
User: &users[1],
IPv4: ap("100.64.0.6"),
Tags: []string{"tag:test"},
},
}
@ -925,6 +927,251 @@ func TestCompileFilterRulesForNodeWithAutogroupSelf(t *testing.T) {
}
}
// TestTagUserMutualExclusivity tests that user-owned nodes and tagged nodes
// are treated as separate identity classes and cannot inadvertently access each other.
func TestTagUserMutualExclusivity(t *testing.T) {
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
{Model: gorm.Model{ID: 2}, Name: "user2"},
}
nodes := types.Nodes{
// User-owned nodes
{
User: ptr.To(users[0]),
IPv4: ap("100.64.0.1"),
},
{
User: ptr.To(users[1]),
IPv4: ap("100.64.0.2"),
},
// Tagged nodes
{
User: &users[0], // "created by" tracking
IPv4: ap("100.64.0.10"),
Tags: []string{"tag:server"},
},
{
User: &users[1], // "created by" tracking
IPv4: ap("100.64.0.11"),
Tags: []string{"tag:database"},
},
}
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
},
ACLs: []ACL{
// Rule 1: user1 (user-owned) should NOT be able to reach tagged nodes
{
Action: "accept",
Sources: []Alias{up("user1@")},
Destinations: []AliasWithPorts{
aliasWithPorts(tp("tag:server"), tailcfg.PortRangeAny),
},
},
// Rule 2: tag:server should be able to reach tag:database
{
Action: "accept",
Sources: []Alias{tp("tag:server")},
Destinations: []AliasWithPorts{
aliasWithPorts(tp("tag:database"), tailcfg.PortRangeAny),
},
},
},
}
err := policy.validate()
if err != nil {
t.Fatalf("policy validation failed: %v", err)
}
// Test user1's user-owned node (100.64.0.1)
userNode := nodes[0].View()
userRules, err := policy.compileFilterRulesForNode(users, userNode, nodes.ViewSlice())
if err != nil {
t.Fatalf("unexpected error for user node: %v", err)
}
// User1's user-owned node should NOT reach tag:server (100.64.0.10)
// because user1@ as a source only matches user1's user-owned devices, NOT tagged devices
for _, rule := range userRules {
for _, dst := range rule.DstPorts {
if dst.IP == "100.64.0.10" {
t.Errorf("SECURITY: user-owned node should NOT reach tagged node (got dest %s in rule)", dst.IP)
}
}
}
// Test tag:server node (100.64.0.10)
// compileFilterRulesForNode returns rules for what the node can ACCESS (as source)
taggedNode := nodes[2].View()
taggedRules, err := policy.compileFilterRulesForNode(users, taggedNode, nodes.ViewSlice())
if err != nil {
t.Fatalf("unexpected error for tagged node: %v", err)
}
// Tag:server (as source) should be able to reach tag:database (100.64.0.11)
// Check destinations in the rules for this node
foundDatabaseDest := false
for _, rule := range taggedRules {
// Check if this rule applies to tag:server as source
if !slices.Contains(rule.SrcIPs, "100.64.0.10/32") {
continue
}
// Check if tag:database is in destinations
for _, dst := range rule.DstPorts {
if dst.IP == "100.64.0.11/32" {
foundDatabaseDest = true
break
}
}
if foundDatabaseDest {
break
}
}
if !foundDatabaseDest {
t.Errorf("tag:server should reach tag:database but didn't find 100.64.0.11 in destinations")
}
}
// TestAutogroupTagged tests that autogroup:tagged correctly selects all devices
// with tag-based identity (IsTagged() == true or has requested tags in tagOwners).
func TestAutogroupTagged(t *testing.T) {
t.Parallel()
users := types.Users{
{Model: gorm.Model{ID: 1}, Name: "user1"},
{Model: gorm.Model{ID: 2}, Name: "user2"},
}
nodes := types.Nodes{
// User-owned nodes (not tagged)
{
User: ptr.To(users[0]),
IPv4: ap("100.64.0.1"),
},
{
User: ptr.To(users[1]),
IPv4: ap("100.64.0.2"),
},
// Tagged nodes
{
User: &users[0], // "created by" tracking
IPv4: ap("100.64.0.10"),
Tags: []string{"tag:server"},
},
{
User: &users[1], // "created by" tracking
IPv4: ap("100.64.0.11"),
Tags: []string{"tag:database"},
},
{
User: &users[0],
IPv4: ap("100.64.0.12"),
Tags: []string{"tag:web", "tag:prod"},
},
}
policy := &Policy{
TagOwners: TagOwners{
Tag("tag:server"): Owners{ptr.To(Username("user1@"))},
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
Tag("tag:web"): Owners{ptr.To(Username("user1@"))},
Tag("tag:prod"): Owners{ptr.To(Username("user1@"))},
},
ACLs: []ACL{
// Rule: autogroup:tagged can reach user-owned nodes
{
Action: "accept",
Sources: []Alias{agp("autogroup:tagged")},
Destinations: []AliasWithPorts{
aliasWithPorts(up("user1@"), tailcfg.PortRangeAny),
aliasWithPorts(up("user2@"), tailcfg.PortRangeAny),
},
},
},
}
err := policy.validate()
require.NoError(t, err)
// Verify autogroup:tagged includes all tagged nodes
taggedIPs, err := AutoGroupTagged.Resolve(policy, users, nodes.ViewSlice())
require.NoError(t, err)
require.NotNil(t, taggedIPs)
// Should contain all tagged nodes
assert.True(t, taggedIPs.Contains(*ap("100.64.0.10")), "should include tag:server")
assert.True(t, taggedIPs.Contains(*ap("100.64.0.11")), "should include tag:database")
assert.True(t, taggedIPs.Contains(*ap("100.64.0.12")), "should include tag:web,tag:prod")
// Should NOT contain user-owned nodes
assert.False(t, taggedIPs.Contains(*ap("100.64.0.1")), "should not include user1 node")
assert.False(t, taggedIPs.Contains(*ap("100.64.0.2")), "should not include user2 node")
// Test ACL filtering: all tagged nodes should be able to reach user nodes
tests := []struct {
name string
sourceNode types.NodeView
shouldReach []string // IP strings for comparison
}{
{
name: "tag:server can reach user-owned nodes",
sourceNode: nodes[2].View(),
shouldReach: []string{"100.64.0.1", "100.64.0.2"},
},
{
name: "tag:database can reach user-owned nodes",
sourceNode: nodes[3].View(),
shouldReach: []string{"100.64.0.1", "100.64.0.2"},
},
{
name: "tag:web,tag:prod can reach user-owned nodes",
sourceNode: nodes[4].View(),
shouldReach: []string{"100.64.0.1", "100.64.0.2"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
rules, err := policy.compileFilterRulesForNode(users, tt.sourceNode, nodes.ViewSlice())
require.NoError(t, err)
// Verify all expected destinations are reachable
for _, expectedDest := range tt.shouldReach {
found := false
for _, rule := range rules {
for _, dstPort := range rule.DstPorts {
// DstPort.IP is CIDR notation like "100.64.0.1/32"
if strings.HasPrefix(dstPort.IP, expectedDest+"/") || dstPort.IP == expectedDest {
found = true
break
}
}
if found {
break
}
}
assert.True(t, found, "Expected to find destination %s in rules", expectedDest)
}
})
}
}
func TestAutogroupSelfInSourceIsRejected(t *testing.T) {
// Test that autogroup:self cannot be used in sources (per Tailscale spec)
policy := &Policy{
@ -959,10 +1206,10 @@ func TestAutogroupSelfWithSpecificUserSource(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1")},
{User: users[0], IPv4: ap("100.64.0.2")},
{User: users[1], IPv4: ap("100.64.0.3")},
{User: users[1], IPv4: ap("100.64.0.4")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
}
policy := &Policy{
@ -1026,11 +1273,11 @@ func TestAutogroupSelfWithGroupSource(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1")},
{User: users[0], IPv4: ap("100.64.0.2")},
{User: users[1], IPv4: ap("100.64.0.3")},
{User: users[1], IPv4: ap("100.64.0.4")},
{User: users[2], IPv4: ap("100.64.0.5")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
}
policy := &Policy{
@ -1095,13 +1342,13 @@ func TestSSHWithAutogroupSelfInDestination(t *testing.T) {
nodes := types.Nodes{
// User1's nodes
{User: users[0], IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
{User: users[0], IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-node1"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-node2"},
// User2's nodes
{User: users[1], IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
{User: users[1], IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-node1"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-node2"},
// Tagged node for user1 (should be excluded)
{User: users[0], IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", ForcedTags: []string{"tag:server"}},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.5"), Hostname: "user1-tagged", Tags: []string{"tag:server"}},
}
policy := &Policy{
@ -1173,10 +1420,10 @@ func TestSSHWithAutogroupSelfAndSpecificUser(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1")},
{User: users[0], IPv4: ap("100.64.0.2")},
{User: users[1], IPv4: ap("100.64.0.3")},
{User: users[1], IPv4: ap("100.64.0.4")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
}
policy := &Policy{
@ -1227,11 +1474,11 @@ func TestSSHWithAutogroupSelfAndGroup(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1")},
{User: users[0], IPv4: ap("100.64.0.2")},
{User: users[1], IPv4: ap("100.64.0.3")},
{User: users[1], IPv4: ap("100.64.0.4")},
{User: users[2], IPv4: ap("100.64.0.5")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1")},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3")},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4")},
{User: ptr.To(users[2]), IPv4: ap("100.64.0.5")},
}
policy := &Policy{
@ -1284,10 +1531,10 @@ func TestSSHWithAutogroupSelfExcludesTaggedDevices(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
{User: users[0], IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
{User: users[0], IPv4: ap("100.64.0.3"), Hostname: "tagged1", ForcedTags: []string{"tag:server"}},
{User: users[0], IPv4: ap("100.64.0.4"), Hostname: "tagged2", ForcedTags: []string{"tag:web"}},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "untagged1"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "untagged2"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.3"), Hostname: "tagged1", Tags: []string{"tag:server"}},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.4"), Hostname: "tagged2", Tags: []string{"tag:web"}},
}
policy := &Policy{
@ -1344,10 +1591,10 @@ func TestSSHWithAutogroupSelfAndMixedDestinations(t *testing.T) {
}
nodes := types.Nodes{
{User: users[0], IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
{User: users[0], IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
{User: users[1], IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
{User: users[1], IPv4: ap("100.64.0.4"), Hostname: "user2-router", ForcedTags: []string{"tag:router"}},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.1"), Hostname: "user1-device"},
{User: ptr.To(users[0]), IPv4: ap("100.64.0.2"), Hostname: "user1-device2"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.3"), Hostname: "user2-device"},
{User: ptr.To(users[1]), IPv4: ap("100.64.0.4"), Hostname: "user2-router", Tags: []string{"tag:router"}},
}
policy := &Policy{

View file

@ -697,14 +697,14 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
// Check for removed nodes
for nodeID, oldNode := range oldNodeMap {
if _, exists := newNodeMap[nodeID]; !exists {
affectedUsers[oldNode.User().ID] = struct{}{}
affectedUsers[oldNode.User().ID()] = struct{}{}
}
}
// Check for added nodes
for nodeID, newNode := range newNodeMap {
if _, exists := oldNodeMap[nodeID]; !exists {
affectedUsers[newNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
}
}
@ -712,26 +712,26 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
for nodeID, newNode := range newNodeMap {
if oldNode, exists := oldNodeMap[nodeID]; exists {
// Check if user changed
if oldNode.User().ID != newNode.User().ID {
affectedUsers[oldNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID] = struct{}{}
if oldNode.User().ID() != newNode.User().ID() {
affectedUsers[oldNode.User().ID()] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
}
// Check if tag status changed
if oldNode.IsTagged() != newNode.IsTagged() {
affectedUsers[newNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
}
// Check if IPs changed (simple check - could be more sophisticated)
oldIPs := oldNode.IPs()
newIPs := newNode.IPs()
if len(oldIPs) != len(newIPs) {
affectedUsers[newNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
} else {
// Check if any IPs are different
for i, oldIP := range oldIPs {
if i >= len(newIPs) || oldIP != newIPs[i] {
affectedUsers[newNode.User().ID] = struct{}{}
affectedUsers[newNode.User().ID()] = struct{}{}
break
}
}
@ -750,7 +750,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
// Check in new nodes first
for _, node := range newNodes.All() {
if node.ID() == nodeID {
nodeUserID = node.User().ID
nodeUserID = node.User().ID()
found = true
break
}
@ -760,7 +760,7 @@ func (pm *PolicyManager) invalidateAutogroupSelfCache(oldNodes, newNodes views.S
if !found {
for _, node := range oldNodes.All() {
if node.ID() == nodeID {
nodeUserID = node.User().ID
nodeUserID = node.User().ID()
found = true
break
}

View file

@ -11,6 +11,7 @@ import (
"github.com/stretchr/testify/require"
"gorm.io/gorm"
"tailscale.com/tailcfg"
"tailscale.com/types/ptr"
)
func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo) *types.Node {
@ -19,8 +20,8 @@ func node(name, ipv4, ipv6 string, user types.User, hostinfo *tailcfg.Hostinfo)
Hostname: name,
IPv4: ap(ipv4),
IPv6: ap(ipv6),
User: user,
UserID: user.ID,
User: ptr.To(user),
UserID: ptr.To(user.ID),
Hostinfo: hostinfo,
}
}
@ -456,8 +457,8 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
Hostname: "test-1-device",
IPv4: ap("100.64.0.1"),
IPv6: ap("fd7a:115c:a1e0::1"),
User: users[0],
UserID: users[0].ID,
User: ptr.To(users[0]),
UserID: ptr.To(users[0].ID),
Hostinfo: &tailcfg.Hostinfo{},
}
@ -467,9 +468,9 @@ func TestAutogroupSelfWithOtherRules(t *testing.T) {
Hostname: "test-2-router",
IPv4: ap("100.64.0.2"),
IPv6: ap("fd7a:115c:a1e0::2"),
User: users[1],
UserID: users[1].ID,
ForcedTags: []string{"tag:node-router"},
User: ptr.To(users[1]),
UserID: ptr.To(users[1].ID),
Tags: []string{"tag:node-router"},
Hostinfo: &tailcfg.Hostinfo{},
}

View file

@ -206,7 +206,12 @@ func (u Username) Resolve(_ *Policy, users types.Users, nodes views.Slice[types.
continue
}
if node.User().ID == user.ID {
// Skip nodes without a user (defensive check for tests)
if !node.User().Valid() {
continue
}
if node.User().ID() == user.ID {
node.AppendToIPSet(&ips)
}
}
@ -311,8 +316,8 @@ func (t Tag) Resolve(p *Policy, users types.Users, nodes views.Slice[types.NodeV
}
for _, node := range nodes.All() {
// Check if node has this tag in all tags (ForcedTags + AuthKey.Tags)
if slices.Contains(node.Tags(), string(t)) {
// Check if node has this tag
if node.HasTag(string(t)) {
node.AppendToIPSet(&ips)
}

View file

@ -1549,7 +1549,17 @@ func TestResolvePolicy(t *testing.T) {
"groupuser1": {Model: gorm.Model{ID: 3}, Name: "groupuser1"},
"groupuser2": {Model: gorm.Model{ID: 4}, Name: "groupuser2"},
"notme": {Model: gorm.Model{ID: 5}, Name: "notme"},
"testuser2": {Model: gorm.Model{ID: 6}, Name: "testuser2"},
}
// Extract users to variables so we can take their addresses
testuser := users["testuser"]
groupuser := users["groupuser"]
groupuser1 := users["groupuser1"]
groupuser2 := users["groupuser2"]
notme := users["notme"]
testuser2 := users["testuser2"]
tests := []struct {
name string
nodes types.Nodes
@ -1579,29 +1589,27 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Not matching other user
{
User: users["notme"],
User: ptr.To(notme),
IPv4: ap("100.100.101.1"),
},
// Not matching forced tags
{
User: users["testuser"],
ForcedTags: []string{"tag:anything"},
User: ptr.To(testuser),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.2"),
},
// not matching pak tag
// not matching because it's tagged (tags copied from AuthKey)
{
User: users["testuser"],
AuthKey: &types.PreAuthKey{
Tags: []string{"alsotagged"},
},
User: ptr.To(testuser),
Tags: []string{"alsotagged"},
IPv4: ap("100.100.101.3"),
},
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.103"),
},
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.104"),
},
},
@ -1613,29 +1621,27 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Not matching other user
{
User: users["notme"],
User: ptr.To(notme),
IPv4: ap("100.100.101.4"),
},
// Not matching forced tags
{
User: users["groupuser"],
ForcedTags: []string{"tag:anything"},
User: ptr.To(groupuser),
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.5"),
},
// not matching pak tag
// not matching because it's tagged (tags copied from AuthKey)
{
User: users["groupuser"],
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:alsotagged"},
},
User: ptr.To(groupuser),
Tags: []string{"tag:alsotagged"},
IPv4: ap("100.100.101.6"),
},
{
User: users["groupuser"],
User: ptr.To(groupuser),
IPv4: ap("100.100.101.203"),
},
{
User: users["groupuser"],
User: ptr.To(groupuser),
IPv4: ap("100.100.101.204"),
},
},
@ -1653,12 +1659,12 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Not matching other user
{
User: users["notme"],
User: ptr.To(notme),
IPv4: ap("100.100.101.9"),
},
// Not matching forced tags
{
ForcedTags: []string{"tag:anything"},
Tags: []string{"tag:anything"},
IPv4: ap("100.100.101.10"),
},
// not matching pak tag
@ -1670,14 +1676,12 @@ func TestResolvePolicy(t *testing.T) {
},
// Not matching forced tags
{
ForcedTags: []string{"tag:test"},
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
},
// not matching pak tag
// matching tag (tags copied from AuthKey during registration)
{
AuthKey: &types.PreAuthKey{
Tags: []string{"tag:test"},
},
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.239"),
},
},
@ -1706,11 +1710,11 @@ func TestResolvePolicy(t *testing.T) {
toResolve: ptr.To(Group("group:testgroup")),
nodes: types.Nodes{
{
User: users["groupuser1"],
User: ptr.To(groupuser1),
IPv4: ap("100.100.101.203"),
},
{
User: users["groupuser2"],
User: ptr.To(groupuser2),
IPv4: ap("100.100.101.204"),
},
},
@ -1731,7 +1735,7 @@ func TestResolvePolicy(t *testing.T) {
toResolve: ptr.To(Username("invaliduser@")),
nodes: types.Nodes{
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.103"),
},
},
@ -1742,7 +1746,7 @@ func TestResolvePolicy(t *testing.T) {
toResolve: tp("tag:invalid"),
nodes: types.Nodes{
{
ForcedTags: []string{"tag:test"},
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.234"),
},
},
@ -1763,18 +1767,18 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Node with no tags (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.1"),
},
// Node with forced tags (should be excluded)
{
User: users["testuser"],
ForcedTags: []string{"tag:test"},
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
@ -1782,7 +1786,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with non-allowed requested tag (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed"},
},
@ -1790,7 +1794,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, one allowed (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test", "tag:notallowed"},
},
@ -1798,7 +1802,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, none allowed (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed1", "tag:notallowed2"},
},
@ -1822,18 +1826,18 @@ func TestResolvePolicy(t *testing.T) {
nodes: types.Nodes{
// Node with no tags (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.1"),
},
// Node with forced tag (should be included)
{
User: users["testuser"],
ForcedTags: []string{"tag:test"},
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.2"),
},
// Node with allowed requested tag (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
@ -1841,7 +1845,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with non-allowed requested tag (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed"},
},
@ -1849,7 +1853,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, one allowed (should be included)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test", "tag:notallowed"},
},
@ -1857,7 +1861,7 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple requested tags, none allowed (should be excluded)
{
User: users["testuser"],
User: ptr.To(testuser),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:notallowed1", "tag:notallowed2"},
},
@ -1865,8 +1869,8 @@ func TestResolvePolicy(t *testing.T) {
},
// Node with multiple forced tags (should be included)
{
User: users["testuser"],
ForcedTags: []string{"tag:test", "tag:other"},
User: ptr.To(testuser),
Tags: []string{"tag:test", "tag:other"},
IPv4: ap("100.100.101.7"),
},
},
@ -1886,20 +1890,20 @@ func TestResolvePolicy(t *testing.T) {
toResolve: ptr.To(AutoGroupSelf),
nodes: types.Nodes{
{
User: users["testuser"],
User: ptr.To(testuser),
IPv4: ap("100.100.101.1"),
},
{
User: users["testuser2"],
User: ptr.To(testuser2),
IPv4: ap("100.100.101.2"),
},
{
User: users["testuser"],
ForcedTags: []string{"tag:test"},
User: ptr.To(testuser),
Tags: []string{"tag:test"},
IPv4: ap("100.100.101.3"),
},
{
User: users["testuser2"],
User: ptr.To(testuser2),
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:test"},
},
@ -1961,23 +1965,23 @@ func TestResolveAutoApprovers(t *testing.T) {
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: users[0],
User: &users[0],
},
{
IPv4: ap("100.64.0.2"),
User: users[1],
User: &users[1],
},
{
IPv4: ap("100.64.0.3"),
User: users[2],
User: &users[2],
},
{
IPv4: ap("100.64.0.4"),
ForcedTags: []string{"tag:testtag"},
Tags: []string{"tag:testtag"},
},
{
IPv4: ap("100.64.0.5"),
ForcedTags: []string{"tag:exittest"},
Tags: []string{"tag:exittest"},
},
}
@ -2280,15 +2284,15 @@ func TestNodeCanApproveRoute(t *testing.T) {
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: users[0],
User: &users[0],
},
{
IPv4: ap("100.64.0.2"),
User: users[1],
User: &users[1],
},
{
IPv4: ap("100.64.0.3"),
User: users[2],
User: &users[2],
},
}
@ -2413,15 +2417,15 @@ func TestResolveTagOwners(t *testing.T) {
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: users[0],
User: &users[0],
},
{
IPv4: ap("100.64.0.2"),
User: users[1],
User: &users[1],
},
{
IPv4: ap("100.64.0.3"),
User: users[2],
User: &users[2],
},
}
@ -2498,15 +2502,15 @@ func TestNodeCanHaveTag(t *testing.T) {
nodes := types.Nodes{
{
IPv4: ap("100.64.0.1"),
User: users[0],
User: &users[0],
},
{
IPv4: ap("100.64.0.2"),
User: users[1],
User: &users[1],
},
{
IPv4: ap("100.64.0.3"),
User: users[2],
User: &users[2],
},
}
@ -2580,6 +2584,49 @@ func TestNodeCanHaveTag(t *testing.T) {
tag: "tag:test",
want: false,
},
{
name: "node-with-unauthorized-tag-different-user",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:prod"): Owners{ptr.To(Username("user1@"))},
},
},
node: nodes[2], // user3's node
tag: "tag:prod",
want: false,
},
{
name: "node-with-multiple-tags-one-unauthorized",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:web"): Owners{ptr.To(Username("user1@"))},
Tag("tag:database"): Owners{ptr.To(Username("user2@"))},
},
},
node: nodes[0], // user1's node
tag: "tag:database",
want: false, // user1 cannot have tag:database (owned by user2)
},
{
name: "empty-tagowners-map",
policy: &Policy{
TagOwners: TagOwners{},
},
node: nodes[0],
tag: "tag:test",
want: false, // No one can have tags if tagOwners is empty
},
{
name: "tag-not-in-tagowners",
policy: &Policy{
TagOwners: TagOwners{
Tag("tag:prod"): Owners{ptr.To(Username("user1@"))},
},
},
node: nodes[0],
tag: "tag:dev", // This tag is not defined in tagOwners
want: false,
},
}
for _, tt := range tests {

View file

@ -11,6 +11,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"github.com/sasha-s/go-deadlock"
"tailscale.com/tailcfg"
@ -42,11 +43,6 @@ type mapSession struct {
node *types.Node
w http.ResponseWriter
warnf func(string, ...any)
infof func(string, ...any)
tracef func(string, ...any)
errf func(error, string, ...any)
}
func (h *Headscale) newMapSession(
@ -55,8 +51,6 @@ func (h *Headscale) newMapSession(
w http.ResponseWriter,
node *types.Node,
) *mapSession {
warnf, infof, tracef, errf := logPollFunc(req, node)
ka := keepAliveInterval + (time.Duration(rand.IntN(9000)) * time.Millisecond)
return &mapSession{
@ -73,12 +67,6 @@ func (h *Headscale) newMapSession(
keepAlive: ka,
keepAliveTicker: nil,
// Loggers
warnf: warnf,
infof: infof,
tracef: tracef,
errf: errf,
}
}
@ -295,6 +283,7 @@ func (m *mapSession) writeMap(msg *tailcfg.MapResponse) error {
}
data := make([]byte, reservedResponseHeaderSize)
//nolint:gosec // G115: JSON response size will not exceed uint32 max
binary.LittleEndian.PutUint32(data, uint32(len(jsonBody)))
data = append(data, jsonBody...)
@ -365,45 +354,22 @@ func logTracePeerChange(hostname string, hostinfoChange bool, peerChange *tailcf
trace.Time("last_seen", *peerChange.LastSeen).Msg("PeerChange received")
}
func logPollFunc(
mapRequest tailcfg.MapRequest,
node *types.Node,
) (func(string, ...any), func(string, ...any), func(string, ...any), func(error, string, ...any)) {
return func(msg string, a ...any) {
log.Warn().
Caller().
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(msg string, a ...any) {
log.Info().
Caller().
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(msg string, a ...any) {
log.Trace().
Caller().
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node.name", node.Hostname).
Msgf(msg, a...)
},
func(err error, msg string, a ...any) {
log.Error().
Caller().
Bool("omitPeers", mapRequest.OmitPeers).
Bool("stream", mapRequest.Stream).
Uint64("node.id", node.ID.Uint64()).
Str("node.name", node.Hostname).
Err(err).
Msgf(msg, a...)
}
// logf adds common mapSession context to a zerolog event.
func (m *mapSession) logf(event *zerolog.Event) *zerolog.Event {
return event.
Bool("omitPeers", m.req.OmitPeers).
Bool("stream", m.req.Stream).
Uint64("node.id", m.node.ID.Uint64()).
Str("node.name", m.node.Hostname)
}
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
func (m *mapSession) infof(msg string, a ...any) { m.logf(log.Info().Caller()).Msgf(msg, a...) }
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
func (m *mapSession) tracef(msg string, a ...any) { m.logf(log.Trace().Caller()).Msgf(msg, a...) }
//nolint:zerologlint // logf returns *zerolog.Event which is properly terminated with Msgf
func (m *mapSession) errf(err error, msg string, a ...any) {
m.logf(log.Error().Caller()).Err(err).Msgf(msg, a...)
}

View file

@ -78,7 +78,7 @@ func (s *State) DebugOverview() string {
now := time.Now()
for _, node := range allNodes.All() {
if node.Valid() {
userName := node.User().Name
userName := node.User().Name()
userNodeCounts[userName]++
if node.IsOnline().Valid() && node.IsOnline().Get() {
@ -281,7 +281,7 @@ func (s *State) DebugOverviewJSON() DebugOverviewInfo {
for _, node := range allNodes.All() {
if node.Valid() {
userName := node.User().Name
userName := node.User().Name()
info.Users[userName]++
if node.IsOnline().Valid() && node.IsOnline().Get() {

View file

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestNetInfoFromMapRequest(t *testing.T) {
@ -148,8 +149,8 @@ func createTestNodeSimple(id types.NodeID) *types.Node {
node := &types.Node{
ID: id,
Hostname: "test-node",
UserID: uint(id),
User: user,
UserID: ptr.To(uint(id)),
User: &user,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
IPv4: &netip.Addr{},

View file

@ -408,7 +408,7 @@ func snapshotFromNodes(nodes map[types.NodeID]types.Node, peersFunc PeersFunc) S
// Build nodesByUser, nodesByNodeKey, and nodesByMachineKey maps
for _, n := range nodes {
nodeView := n.View()
userID := types.UserID(n.UserID)
userID := n.TypedUserID()
newSnap.nodesByUser[userID] = append(newSnap.nodesByUser[userID], nodeView)
newSnap.nodesByNodeKey[n.NodeKey] = nodeView
@ -515,7 +515,7 @@ func (s *NodeStore) DebugString() string {
if len(nodes) > 0 {
userName := "unknown"
if len(nodes) > 0 && nodes[0].Valid() {
userName = nodes[0].User().Name
userName = nodes[0].User().Name()
}
sb.WriteString(fmt.Sprintf(" - User %d (%s): %d nodes\n", userID, userName, len(nodes)))
}

View file

@ -13,6 +13,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"tailscale.com/types/key"
"tailscale.com/types/ptr"
)
func TestSnapshotFromNodes(t *testing.T) {
@ -173,8 +174,8 @@ func createTestNode(nodeID types.NodeID, userID uint, username, hostname string)
DiscoKey: discoKey.Public(),
Hostname: hostname,
GivenName: hostname,
UserID: userID,
User: types.User{
UserID: ptr.To(userID),
User: &types.User{
Name: username,
DisplayName: username,
},
@ -627,7 +628,7 @@ func TestNodeStoreOperations(t *testing.T) {
go func() {
resultNode3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.ForcedTags = []string{"tag1", "tag2"}
n.Tags = []string{"tag1", "tag2"}
})
close(done3)
}()
@ -648,24 +649,24 @@ func TestNodeStoreOperations(t *testing.T) {
// resultNode1 (from hostname update) should also have the givenname and tags changes
assert.Equal(t, "multi-update-hostname", resultNode1.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode1.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.ForcedTags().AsSlice())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode1.Tags().AsSlice())
// resultNode2 (from givenname update) should also have the hostname and tags changes
assert.Equal(t, "multi-update-hostname", resultNode2.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode2.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.ForcedTags().AsSlice())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode2.Tags().AsSlice())
// resultNode3 (from tags update) should also have the hostname and givenname changes
assert.Equal(t, "multi-update-hostname", resultNode3.Hostname())
assert.Equal(t, "multi-update-givenname", resultNode3.GivenName())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.ForcedTags().AsSlice())
assert.Equal(t, []string{"tag1", "tag2"}, resultNode3.Tags().AsSlice())
// Verify the snapshot also has all changes
snapshot := store.data.Load()
finalNode := snapshot.nodesByID[1]
assert.Equal(t, "multi-update-hostname", finalNode.Hostname)
assert.Equal(t, "multi-update-givenname", finalNode.GivenName)
assert.Equal(t, []string{"tag1", "tag2"}, finalNode.ForcedTags)
assert.Equal(t, []string{"tag1", "tag2"}, finalNode.Tags)
},
},
},
@ -687,7 +688,7 @@ func TestNodeStoreOperations(t *testing.T) {
resultNode, ok := store.UpdateNode(1, func(n *types.Node) {
n.Hostname = "db-save-hostname"
n.GivenName = "db-save-given"
n.ForcedTags = []string{"db-tag1", "db-tag2"}
n.Tags = []string{"db-tag1", "db-tag2"}
})
assert.True(t, ok, "UpdateNode should succeed")
@ -696,21 +697,21 @@ func TestNodeStoreOperations(t *testing.T) {
// Verify the returned node has all expected values
assert.Equal(t, "db-save-hostname", resultNode.Hostname())
assert.Equal(t, "db-save-given", resultNode.GivenName())
assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.ForcedTags().AsSlice())
assert.Equal(t, []string{"db-tag1", "db-tag2"}, resultNode.Tags().AsSlice())
// Convert to struct as would be done for database save
nodePtr := resultNode.AsStruct()
assert.NotNil(t, nodePtr)
assert.Equal(t, "db-save-hostname", nodePtr.Hostname)
assert.Equal(t, "db-save-given", nodePtr.GivenName)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.ForcedTags)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, nodePtr.Tags)
// Verify the snapshot also reflects the same state
snapshot := store.data.Load()
storedNode := snapshot.nodesByID[1]
assert.Equal(t, "db-save-hostname", storedNode.Hostname)
assert.Equal(t, "db-save-given", storedNode.GivenName)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.ForcedTags)
assert.Equal(t, []string{"db-tag1", "db-tag2"}, storedNode.Tags)
},
},
{
@ -742,7 +743,7 @@ func TestNodeStoreOperations(t *testing.T) {
go func() {
result3, ok3 = store.UpdateNode(1, func(n *types.Node) {
n.ForcedTags = []string{"concurrent-tag"}
n.Tags = []string{"concurrent-tag"}
})
close(done3)
}()
@ -767,22 +768,22 @@ func TestNodeStoreOperations(t *testing.T) {
// All should have the complete final state
assert.Equal(t, "concurrent-db-hostname", nodePtr1.Hostname)
assert.Equal(t, "concurrent-db-given", nodePtr1.GivenName)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.ForcedTags)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr1.Tags)
assert.Equal(t, "concurrent-db-hostname", nodePtr2.Hostname)
assert.Equal(t, "concurrent-db-given", nodePtr2.GivenName)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.ForcedTags)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr2.Tags)
assert.Equal(t, "concurrent-db-hostname", nodePtr3.Hostname)
assert.Equal(t, "concurrent-db-given", nodePtr3.GivenName)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.ForcedTags)
assert.Equal(t, []string{"concurrent-tag"}, nodePtr3.Tags)
// Verify consistency with stored state
snapshot := store.data.Load()
storedNode := snapshot.nodesByID[1]
assert.Equal(t, nodePtr1.Hostname, storedNode.Hostname)
assert.Equal(t, nodePtr1.GivenName, storedNode.GivenName)
assert.Equal(t, nodePtr1.ForcedTags, storedNode.ForcedTags)
assert.Equal(t, nodePtr1.Tags, storedNode.Tags)
},
},
{
@ -855,8 +856,8 @@ func createConcurrentTestNode(id types.NodeID, hostname string) types.Node {
Hostname: hostname,
MachineKey: machineKey.Public(),
NodeKey: nodeKey.Public(),
UserID: 1,
User: types.User{
UserID: ptr.To(uint(1)),
User: &types.User{
Name: "concurrent-test-user",
},
}

View file

@ -53,6 +53,9 @@ const (
// ErrUnsupportedPolicyMode is returned for invalid policy modes. Valid modes are "file" and "db".
var ErrUnsupportedPolicyMode = errors.New("unsupported policy mode")
// ErrNodeNotFound is returned when a node cannot be found by its ID.
var ErrNodeNotFound = errors.New("node not found")
// State manages Headscale's core state, coordinating between database, policy management,
// IP allocation, and DERP routing. All methods are thread-safe.
type State struct {
@ -651,13 +654,36 @@ func (s *State) SetNodeExpiry(nodeID types.NodeID, expiry time.Time) (types.Node
return s.persistNodeToDB(n)
}
// SetNodeTags assigns tags to a node for use in access control policies.
// SetNodeTags assigns tags to a node, making it a "tagged node".
// Once a node is tagged, it cannot be un-tagged (only tags can be changed).
// The UserID is preserved as "created by" information.
func (s *State) SetNodeTags(nodeID types.NodeID, tags []string) (types.NodeView, change.ChangeSet, error) {
// CANNOT REMOVE ALL TAGS
if len(tags) == 0 {
return types.NodeView{}, change.EmptySet, types.ErrCannotRemoveAllTags
}
// Get node for validation
existingNode, exists := s.nodeStore.GetNode(nodeID)
if !exists {
return types.NodeView{}, change.EmptySet, fmt.Errorf("%w: %d", ErrNodeNotFound, nodeID)
}
// Validate tags against policy
validatedTags, err := s.validateAndNormalizeTags(existingNode.AsStruct(), tags)
if err != nil {
return types.NodeView{}, change.EmptySet, err
}
// Log the operation
logTagOperation(existingNode, validatedTags)
// Update NodeStore before database to ensure consistency. The NodeStore update is
// blocking and will be the source of truth for the batcher. The database update must
// make the exact same change.
n, ok := s.nodeStore.UpdateNode(nodeID, func(node *types.Node) {
node.ForcedTags = tags
node.Tags = validatedTags
// UserID is preserved as "created by" - do NOT set to nil
})
if !ok {
@ -927,7 +953,8 @@ func (s *State) DestroyAPIKey(key types.APIKey) error {
}
// CreatePreAuthKey generates a new pre-authentication key for a user.
func (s *State) CreatePreAuthKey(userID types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) {
// The userID parameter is now optional (can be nil) for system-created tagged keys.
func (s *State) CreatePreAuthKey(userID *types.UserID, reusable bool, ephemeral bool, expiration *time.Time, aclTags []string) (*types.PreAuthKeyNew, error) {
return s.db.CreatePreAuthKey(userID, reusable, ephemeral, expiration, aclTags)
}
@ -1063,8 +1090,6 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
// Prepare the node for registration
nodeToRegister := types.Node{
Hostname: params.Hostname,
UserID: params.User.ID,
User: params.User,
MachineKey: params.MachineKey,
NodeKey: params.NodeKey,
DiscoKey: params.DiscoKey,
@ -1075,11 +1100,38 @@ func (s *State) createAndSaveNewNode(params newNodeParams) (types.NodeView, erro
Expiry: params.Expiry,
}
// Pre-auth key specific fields
// Assign ownership based on PreAuthKey
if params.PreAuthKey != nil {
nodeToRegister.ForcedTags = params.PreAuthKey.Proto().GetAclTags()
if params.PreAuthKey.IsTagged() {
// TAGGED NODE
// Tags from PreAuthKey are assigned ONLY during initial authentication
nodeToRegister.Tags = params.PreAuthKey.Proto().GetAclTags()
// Set UserID to track "created by" (who created the PreAuthKey)
if params.PreAuthKey.UserID != nil {
nodeToRegister.UserID = params.PreAuthKey.UserID
nodeToRegister.User = params.PreAuthKey.User
}
// If PreAuthKey.UserID is nil, the node is "orphaned" (system-created)
} else {
// USER-OWNED NODE
nodeToRegister.UserID = &params.PreAuthKey.User.ID
nodeToRegister.User = params.PreAuthKey.User
nodeToRegister.Tags = nil
}
nodeToRegister.AuthKey = params.PreAuthKey
nodeToRegister.AuthKeyID = &params.PreAuthKey.ID
} else {
// Non-PreAuthKey registration (OIDC, CLI) - always user-owned
nodeToRegister.UserID = &params.User.ID
nodeToRegister.User = &params.User
nodeToRegister.Tags = nil
}
// Validate before saving
err := validateNodeOwnership(&nodeToRegister)
if err != nil {
return types.NodeView{}, err
}
// Allocate new IPs
@ -1156,7 +1208,7 @@ func (s *State) HandleNodeFromAuthPath(
logHostinfoValidation(
regEntry.Node.MachineKey.ShortString(),
regEntry.Node.NodeKey.String(),
user.Username(),
user.Name,
hostname,
regEntry.Node.Hostinfo,
)
@ -1171,7 +1223,7 @@ func (s *State) HandleNodeFromAuthPath(
log.Debug().
Caller().
Str("registration_id", registrationID.String()).
Str("user.name", user.Username()).
Str("user.name", user.Name).
Str("registrationMethod", registrationMethod).
Str("node.name", existingNodeSameUser.Hostname()).
Uint64("node.id", existingNodeSameUser.ID().Uint64()).
@ -1233,7 +1285,7 @@ func (s *State) HandleNodeFromAuthPath(
// Check if node exists with this machine key for a different user (for netinfo preservation)
existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(regEntry.Node.MachineKey)
if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != user.ID {
if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != user.ID {
// Node exists but belongs to a different user
// Create a NEW node for the new user (do not transfer)
// This allows the same machine to have separate node identities per user
@ -1243,8 +1295,8 @@ func (s *State) HandleNodeFromAuthPath(
Str("existing.node.name", existingNodeAnyUser.Hostname()).
Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()).
Str("machine.key", regEntry.Node.MachineKey.ShortString()).
Str("old.user", oldUser.Username()).
Str("new.user", user.Username()).
Str("old.user", oldUser.Name()).
Str("new.user", user.Name).
Str("method", registrationMethod).
Msg("Creating new node for different user (same machine key exists for another user)")
}
@ -1253,7 +1305,7 @@ func (s *State) HandleNodeFromAuthPath(
log.Debug().
Caller().
Str("registration_id", registrationID.String()).
Str("user.name", user.Username()).
Str("user.name", user.Name).
Str("registrationMethod", registrationMethod).
Str("expiresAt", fmt.Sprintf("%v", expiry)).
Msg("Registering new node from auth callback")
@ -1416,8 +1468,11 @@ func (s *State) HandleNodeFromPreAuthKey(
node.RegisterMethod = util.RegisterMethodAuthKey
// TODO(kradalby): This might need a rework as part of #2417
node.ForcedTags = pak.Proto().GetAclTags()
// CRITICAL: Tags from PreAuthKey are ONLY applied during initial authentication
// On re-registration, we MUST NOT change tags or node ownership
// The node keeps whatever tags/user ownership it already has
//
// Only update AuthKey reference
node.AuthKey = pak
node.AuthKeyID = &pak.ID
node.IsOnline = ptr.To(false)
@ -1467,7 +1522,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Check if node exists with this machine key for a different user
existingNodeAnyUser, existsAnyUser := s.nodeStore.GetNodeByMachineKeyAnyUser(machineKey)
if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID() != pak.User.ID {
if existsAnyUser && existingNodeAnyUser.Valid() && existingNodeAnyUser.UserID().Get() != pak.User.ID {
// Node exists but belongs to a different user
// Create a NEW node for the new user (do not transfer)
// This allows the same machine to have separate node identities per user
@ -1477,7 +1532,7 @@ func (s *State) HandleNodeFromPreAuthKey(
Str("existing.node.name", existingNodeAnyUser.Hostname()).
Uint64("existing.node.id", existingNodeAnyUser.ID().Uint64()).
Str("machine.key", machineKey.ShortString()).
Str("old.user", oldUser.Username()).
Str("old.user", oldUser.Name()).
Str("new.user", pak.User.Username()).
Msg("Creating new node for different user (same machine key exists for another user)")
}
@ -1488,7 +1543,7 @@ func (s *State) HandleNodeFromPreAuthKey(
// Create and save new node
var err error
finalNode, err = s.createAndSaveNewNode(newNodeParams{
User: pak.User,
User: *pak.User,
MachineKey: machineKey,
NodeKey: regReq.NodeKey,
DiscoKey: key.DiscoPublic{}, // DiscoKey not available in RegisterRequest

107
hscontrol/state/tags.go Normal file
View file

@ -0,0 +1,107 @@
package state
import (
"errors"
"fmt"
"slices"
"strings"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/rs/zerolog/log"
)
var (
// ErrNodeMarkedTaggedButHasNoTags is returned when a node is marked as tagged but has no tags.
ErrNodeMarkedTaggedButHasNoTags = errors.New("node marked as tagged but has no tags")
// ErrNodeHasNeitherUserNorTags is returned when a node has neither a user nor tags.
ErrNodeHasNeitherUserNorTags = errors.New("node has neither user nor tags - must be owned by user or tagged")
// ErrInvalidOrUnauthorizedTags is returned when tags are invalid or unauthorized.
ErrInvalidOrUnauthorizedTags = errors.New("invalid or unauthorized tags")
)
// validateNodeOwnership ensures proper node ownership model.
// A node must be EITHER user-owned OR tagged (mutually exclusive by behavior).
// Tagged nodes CAN have a UserID for "created by" tracking, but the tag is the owner.
func validateNodeOwnership(node *types.Node) error {
isTagged := node.IsTagged()
// Tagged nodes: Must have tags, UserID is optional (just "created by")
if isTagged {
if len(node.Tags) == 0 {
return fmt.Errorf("%w: %q", ErrNodeMarkedTaggedButHasNoTags, node.Hostname)
}
// UserID can be set (created by) or nil (orphaned), both valid for tagged nodes
return nil
}
// User-owned nodes: Must have UserID, must NOT have tags
if node.UserID == nil {
return fmt.Errorf("%w: %q", ErrNodeHasNeitherUserNorTags, node.Hostname)
}
return nil
}
// validateAndNormalizeTags validates tags against policy and normalizes them.
// Returns validated and normalized tags, or an error if validation fails.
func (s *State) validateAndNormalizeTags(node *types.Node, requestedTags []string) ([]string, error) {
if len(requestedTags) == 0 {
return nil, nil
}
var (
validTags []string
invalidTags []string
)
for _, tag := range requestedTags {
// Validate format
if !strings.HasPrefix(tag, "tag:") {
invalidTags = append(invalidTags, tag)
continue
}
// Validate against policy
nodeView := node.View()
if s.polMan.NodeCanHaveTag(nodeView, tag) {
validTags = append(validTags, tag)
} else {
invalidTags = append(invalidTags, tag)
}
}
if len(invalidTags) > 0 {
return nil, fmt.Errorf("%w: %v", ErrInvalidOrUnauthorizedTags, invalidTags)
}
// Normalize: sort and deduplicate
slices.Sort(validTags)
return slices.Compact(validTags), nil
}
// logTagOperation logs tag assignment operations for audit purposes.
func logTagOperation(existingNode types.NodeView, newTags []string) {
if existingNode.IsTagged() {
log.Info().
Uint64("node.id", existingNode.ID().Uint64()).
Str("node.name", existingNode.Hostname()).
Strs("old.tags", existingNode.Tags().AsSlice()).
Strs("new.tags", newTags).
Msg("Updating tags on already-tagged node")
} else {
var userID uint
if existingNode.UserID().Valid() {
userID = existingNode.UserID().Get()
}
log.Info().
Uint64("node.id", existingNode.ID().Uint64()).
Str("node.name", existingNode.Hostname()).
Uint("created.by.user", userID).
Strs("new.tags", newTags).
Msg("Converting user-owned node to tagged node (irreversible)")
}
}

View file

@ -6,7 +6,6 @@ import (
"net/netip"
"regexp"
"slices"
"sort"
"strconv"
"strings"
"time"
@ -28,6 +27,7 @@ var (
ErrHostnameTooLong = errors.New("hostname too long, cannot except 255 ASCII chars")
ErrNodeHasNoGivenName = errors.New("node has no given name")
ErrNodeUserHasNoName = errors.New("node user has no name")
ErrCannotRemoveAllTags = errors.New("cannot remove all tags from node")
invalidDNSRegex = regexp.MustCompile("[^a-z0-9-.]+")
)
@ -97,16 +97,21 @@ type Node struct {
// GivenName is the name used in all DNS related
// parts of headscale.
GivenName string `gorm:"type:varchar(63);unique_index"`
UserID uint
User User `gorm:"constraint:OnDelete:CASCADE;"`
// UserID is set for ALL nodes (tagged and user-owned) to track "created by".
// For tagged nodes, this is informational only - the tag is the owner.
// For user-owned nodes, this identifies the owner.
// Only nil for orphaned nodes (should not happen in normal operation).
UserID *uint
User *User `gorm:"constraint:OnDelete:CASCADE;"`
RegisterMethod string
// ForcedTags are tags set by CLI/API. It is not considered
// the source of truth, but is one of the sources from
// which a tag might originate.
// ForcedTags are _always_ applied to the node.
ForcedTags []string `gorm:"column:forced_tags;serializer:json"`
// Tags is the definitive owner for tagged nodes.
// When non-empty, the node is "tagged" and tags define its identity.
// Empty for user-owned nodes.
// Tags cannot be removed once set (one-way transition).
Tags []string `gorm:"column:tags;serializer:json"`
// When a node has been created with a PreAuthKey, we need to
// prevent the preauthkey from being deleted before the node.
@ -196,55 +201,32 @@ func (node *Node) HasIP(i netip.Addr) bool {
return false
}
// IsTagged reports if a device is tagged
// and therefore should not be treated as a
// user owned device.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys).
// IsTagged reports if a device is tagged and therefore should not be treated
// as a user-owned device.
// When a node has tags, the tags define its identity (not the user).
func (node *Node) IsTagged() bool {
if len(node.ForcedTags) > 0 {
return true
}
return len(node.Tags) > 0
}
if node.AuthKey != nil && len(node.AuthKey.Tags) > 0 {
return true
}
if node.Hostinfo == nil {
return false
}
// TODO(kradalby): Figure out how tagging should work
// and hostinfo.requestedtags.
// Do this in other work.
return false
// IsUserOwned returns true if node is owned by a user (not tagged).
// Tagged nodes may have a UserID for "created by" tracking, but the tag is the owner.
func (node *Node) IsUserOwned() bool {
return !node.IsTagged()
}
// HasTag reports if a node has a given tag.
// Currently, this function only handles tags set
// via CLI ("forced tags" and preauthkeys).
func (node *Node) HasTag(tag string) bool {
return slices.Contains(node.Tags(), tag)
return slices.Contains(node.Tags, tag)
}
func (node *Node) Tags() []string {
var tags []string
if node.AuthKey != nil {
tags = append(tags, node.AuthKey.Tags...)
// TypedUserID returns the UserID as a typed UserID type.
// Returns 0 if UserID is nil.
func (node *Node) TypedUserID() UserID {
if node.UserID == nil {
return 0
}
// TODO(kradalby): Figure out how tagging should work
// and hostinfo.requestedtags.
// Do this in other work.
// #2417
tags = append(tags, node.ForcedTags...)
sort.Strings(tags)
tags = slices.Compact(tags)
return tags
return UserID(*node.UserID)
}
func (node *Node) RequestTags() []string {
@ -389,8 +371,8 @@ func (node *Node) Proto() *v1.Node {
IpAddresses: node.IPsAsString(),
Name: node.Hostname,
GivenName: node.GivenName,
User: node.User.Proto(),
ForcedTags: node.ForcedTags,
User: nil, // Will be set below based on node type
ForcedTags: node.Tags,
Online: node.IsOnline != nil && *node.IsOnline,
// Only ApprovedRoutes and AvailableRoutes is set here. SubnetRoutes has
@ -404,6 +386,13 @@ func (node *Node) Proto() *v1.Node {
CreatedAt: timestamppb.New(node.CreatedAt),
}
// Set User field based on node ownership
// Note: User will be set to TaggedDevices in the gRPC layer (grpcv1.go)
// for proper MapResponse formatting
if node.User != nil {
nodeProto.User = node.User.Proto()
}
if node.AuthKey != nil {
nodeProto.PreAuthKey = node.AuthKey.Proto()
}
@ -701,8 +690,20 @@ func (nodes Nodes) DebugString() string {
func (node Node) DebugString() string {
var sb strings.Builder
fmt.Fprintf(&sb, "%s(%s):\n", node.Hostname, node.ID)
fmt.Fprintf(&sb, "\tUser: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username())
fmt.Fprintf(&sb, "\tTags: %v\n", node.Tags())
// Show ownership status
if node.IsTagged() {
fmt.Fprintf(&sb, "\tTagged: %v\n", node.Tags)
if node.User != nil {
fmt.Fprintf(&sb, "\tCreated by: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username())
}
} else if node.User != nil {
fmt.Fprintf(&sb, "\tUser-owned: %s (%d, %q)\n", node.User.Display(), node.User.ID, node.User.Username())
} else {
fmt.Fprintf(&sb, "\tOrphaned: no user or tags\n")
}
fmt.Fprintf(&sb, "\tIPs: %v\n", node.IPs())
fmt.Fprintf(&sb, "\tApprovedRoutes: %v\n", node.ApprovedRoutes)
fmt.Fprintf(&sb, "\tAnnouncedRoutes: %v\n", node.AnnouncedRoutes())
@ -714,8 +715,7 @@ func (node Node) DebugString() string {
}
func (v NodeView) UserView() UserView {
u := v.User()
return u.View()
return v.User()
}
func (v NodeView) IPs() []netip.Addr {
@ -790,13 +790,6 @@ func (v NodeView) RequestTagsSlice() views.Slice[string] {
return v.Hostinfo().RequestTags()
}
func (v NodeView) Tags() []string {
if !v.Valid() {
return nil
}
return v.ж.Tags()
}
// IsTagged reports if a device is tagged
// and therefore should not be treated as a
// user owned device.
@ -893,6 +886,32 @@ func (v NodeView) HasTag(tag string) bool {
return v.ж.HasTag(tag)
}
// TypedUserID returns the UserID as a typed UserID type.
// Returns 0 if UserID is nil or node is invalid.
func (v NodeView) TypedUserID() UserID {
if !v.Valid() {
return 0
}
return v.ж.TypedUserID()
}
// TailscaleUserID returns the user ID to use in Tailscale protocol.
// Tagged nodes always return TaggedDevices.ID, user-owned nodes return their actual UserID.
func (v NodeView) TailscaleUserID() tailcfg.UserID {
if !v.Valid() {
return 0
}
if v.IsTagged() {
//nolint:gosec // G115: TaggedDevices.ID is a constant that fits in int64
return tailcfg.UserID(int64(TaggedDevices.ID))
}
//nolint:gosec // G115: UserID values are within int64 range
return tailcfg.UserID(int64(v.UserID().Get()))
}
// Prefixes returns the node IPs as netip.Prefix.
func (v NodeView) Prefixes() []netip.Prefix {
if !v.Valid() {

View file

@ -0,0 +1,295 @@
package types
import (
"testing"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/stretchr/testify/assert"
"gorm.io/gorm"
"tailscale.com/types/ptr"
)
// TestNodeIsTagged tests the IsTagged() method for determining if a node is tagged.
func TestNodeIsTagged(t *testing.T) {
tests := []struct {
name string
node Node
want bool
}{
{
name: "node with tags - is tagged",
node: Node{
Tags: []string{"tag:server", "tag:prod"},
},
want: true,
},
{
name: "node with single tag - is tagged",
node: Node{
Tags: []string{"tag:web"},
},
want: true,
},
{
name: "node with no tags - not tagged",
node: Node{
Tags: []string{},
},
want: false,
},
{
name: "node with nil tags - not tagged",
node: Node{
Tags: nil,
},
want: false,
},
{
// Tags should be copied from AuthKey during registration, so a node
// with only AuthKey.Tags and no Tags would be invalid in practice.
// IsTagged() only checks node.Tags, not AuthKey.Tags.
name: "node registered with tagged authkey only - not tagged (tags should be copied)",
node: Node{
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
want: false,
},
{
name: "node with both tags and authkey tags - is tagged",
node: Node{
Tags: []string{"tag:server"},
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
want: true,
},
{
name: "node with user and no tags - not tagged",
node: Node{
UserID: ptr.To(uint(42)),
Tags: []string{},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.node.IsTagged()
assert.Equal(t, tt.want, got, "IsTagged() returned unexpected value")
})
}
}
// TestNodeViewIsTagged tests the IsTagged() method on NodeView.
func TestNodeViewIsTagged(t *testing.T) {
tests := []struct {
name string
node Node
want bool
}{
{
name: "tagged node via Tags field",
node: Node{
Tags: []string{"tag:server"},
},
want: true,
},
{
// Tags should be copied from AuthKey during registration, so a node
// with only AuthKey.Tags and no Tags would be invalid in practice.
name: "node with only AuthKey tags - not tagged (tags should be copied)",
node: Node{
AuthKey: &PreAuthKey{
Tags: []string{"tag:web"},
},
},
want: false, // IsTagged() only checks node.Tags
},
{
name: "user-owned node",
node: Node{
UserID: ptr.To(uint(1)),
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
view := tt.node.View()
got := view.IsTagged()
assert.Equal(t, tt.want, got, "NodeView.IsTagged() returned unexpected value")
})
}
}
// TestNodeHasTag tests the HasTag() method for checking specific tag membership.
func TestNodeHasTag(t *testing.T) {
tests := []struct {
name string
node Node
tag string
want bool
}{
{
name: "node has the tag",
node: Node{
Tags: []string{"tag:server", "tag:prod"},
},
tag: "tag:server",
want: true,
},
{
name: "node does not have the tag",
node: Node{
Tags: []string{"tag:server", "tag:prod"},
},
tag: "tag:web",
want: false,
},
{
// Tags should be copied from AuthKey during registration
// HasTag() only checks node.Tags, not AuthKey.Tags
name: "node has tag only in authkey - returns false",
node: Node{
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
tag: "tag:database",
want: false,
},
{
// node.Tags is what matters, not AuthKey.Tags
name: "node has tag in Tags but not in AuthKey",
node: Node{
Tags: []string{"tag:server"},
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
tag: "tag:server",
want: true,
},
{
name: "invalid tag format still returns false",
node: Node{
Tags: []string{"tag:server"},
},
tag: "invalid-tag",
want: false,
},
{
name: "empty tag returns false",
node: Node{
Tags: []string{"tag:server"},
},
tag: "",
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.node.HasTag(tt.tag)
assert.Equal(t, tt.want, got, "HasTag() returned unexpected value")
})
}
}
// TestNodeTagsImmutableAfterRegistration tests that tags can only be set during registration.
func TestNodeTagsImmutableAfterRegistration(t *testing.T) {
// Test that a node registered with tags keeps them
taggedNode := Node{
ID: 1,
Tags: []string{"tag:server"},
AuthKey: &PreAuthKey{
Tags: []string{"tag:server"},
},
RegisterMethod: util.RegisterMethodAuthKey,
}
// Node should be tagged
assert.True(t, taggedNode.IsTagged(), "Node registered with tags should be tagged")
// Node should have the tag
has := taggedNode.HasTag("tag:server")
assert.True(t, has, "Node should have the tag it was registered with")
// Test that a user-owned node is not tagged
userNode := Node{
ID: 2,
UserID: ptr.To(uint(42)),
Tags: []string{},
RegisterMethod: util.RegisterMethodOIDC,
}
assert.False(t, userNode.IsTagged(), "User-owned node should not be tagged")
}
// TestNodeOwnershipModel tests the tags-as-identity model.
func TestNodeOwnershipModel(t *testing.T) {
tests := []struct {
name string
node Node
wantIsTagged bool
description string
}{
{
name: "tagged node has tags, UserID is informational",
node: Node{
ID: 1,
UserID: ptr.To(uint(5)), // "created by" user 5
Tags: []string{"tag:server"},
},
wantIsTagged: true,
description: "Tagged nodes may have UserID set for tracking, but ownership is defined by tags",
},
{
name: "user-owned node has no tags",
node: Node{
ID: 2,
UserID: ptr.To(uint(5)),
Tags: []string{},
},
wantIsTagged: false,
description: "User-owned nodes are owned by the user, not by tags",
},
{
// Tags should be copied from AuthKey to Node during registration
// IsTagged() only checks node.Tags, not AuthKey.Tags
name: "node with only authkey tags - not tagged (tags should be copied)",
node: Node{
ID: 3,
UserID: ptr.To(uint(5)), // "created by" user 5
AuthKey: &PreAuthKey{
Tags: []string{"tag:database"},
},
},
wantIsTagged: false,
description: "IsTagged() only checks node.Tags; AuthKey.Tags should be copied during registration",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := tt.node.IsTagged()
assert.Equal(t, tt.wantIsTagged, got, tt.description)
})
}
}
// TestUserTypedID tests the TypedID() helper method.
func TestUserTypedID(t *testing.T) {
user := User{
Model: gorm.Model{ID: 42},
}
typedID := user.TypedID()
assert.NotNil(t, typedID, "TypedID() should return non-nil pointer")
assert.Equal(t, UserID(42), *typedID, "TypedID() should return correct UserID value")
}

View file

@ -139,7 +139,7 @@ func TestNodeFQDN(t *testing.T) {
name: "no-dnsconfig-with-username",
node: Node{
GivenName: "test",
User: User{
User: &User{
Name: "user",
},
},
@ -150,7 +150,7 @@ func TestNodeFQDN(t *testing.T) {
name: "all-set",
node: Node{
GivenName: "test",
User: User{
User: &User{
Name: "user",
},
},
@ -160,7 +160,7 @@ func TestNodeFQDN(t *testing.T) {
{
name: "no-given-name",
node: Node{
User: User{
User: &User{
Name: "user",
},
},
@ -179,7 +179,7 @@ func TestNodeFQDN(t *testing.T) {
name: "no-dnsconfig",
node: Node{
GivenName: "test",
User: User{
User: &User{
Name: "user",
},
},

View file

@ -23,16 +23,19 @@ type PreAuthKey struct {
Prefix string
Hash []byte // bcrypt
UserID uint
User User `gorm:"constraint:OnDelete:SET NULL;"`
// For tagged keys: UserID tracks who created the key (informational)
// For user-owned keys: UserID tracks the node owner
// Can be nil for system-created tagged keys
UserID *uint
User *User `gorm:"constraint:OnDelete:SET NULL;"`
Reusable bool
Ephemeral bool `gorm:"default:false"`
Used bool `gorm:"default:false"`
// Tags are always applied to the node and is one of
// the sources of tags a node might have. They are copied
// from the PreAuthKey when the node logs in the first time,
// and ignored after.
// Tags to assign to nodes registered with this key.
// Tags are copied to the node during registration.
// If non-empty, this creates tagged nodes (not user-owned).
Tags []string `gorm:"serializer:json"`
CreatedAt *time.Time
@ -48,19 +51,23 @@ type PreAuthKeyNew struct {
Tags []string
Expiration *time.Time
CreatedAt *time.Time
User User
User *User // Can be nil for system-created tagged keys
}
func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
Id: key.ID,
Key: key.Key,
User: key.User.Proto(),
User: nil, // Will be set below if not nil
Reusable: key.Reusable,
Ephemeral: key.Ephemeral,
AclTags: key.Tags,
}
if key.User != nil {
protoKey.User = key.User.Proto()
}
if key.Expiration != nil {
protoKey.Expiration = timestamppb.New(*key.Expiration)
}
@ -74,7 +81,7 @@ func (key *PreAuthKeyNew) Proto() *v1.PreAuthKey {
func (key *PreAuthKey) Proto() *v1.PreAuthKey {
protoKey := v1.PreAuthKey{
User: key.User.Proto(),
User: nil, // Will be set below if not nil
Id: key.ID,
Ephemeral: key.Ephemeral,
Reusable: key.Reusable,
@ -82,6 +89,10 @@ func (key *PreAuthKey) Proto() *v1.PreAuthKey {
AclTags: key.Tags,
}
if key.User != nil {
protoKey.User = key.User.Proto()
}
// For new keys (with prefix/hash), show the prefix so users can identify the key
// For legacy keys (with plaintext key), show the full key for backwards compatibility
if key.Prefix != "" {
@ -139,3 +150,9 @@ func (pak *PreAuthKey) Validate() error {
return nil
}
// IsTagged returns true if this PreAuthKey creates tagged nodes.
// When a PreAuthKey has tags, nodes registered with it will be tagged nodes.
func (pak *PreAuthKey) IsTagged() bool {
return len(pak.Tags) > 0
}

View file

@ -54,7 +54,13 @@ func (src *Node) Clone() *Node {
if dst.IPv6 != nil {
dst.IPv6 = ptr.To(*src.IPv6)
}
dst.ForcedTags = append(src.ForcedTags[:0:0], src.ForcedTags...)
if dst.UserID != nil {
dst.UserID = ptr.To(*src.UserID)
}
if dst.User != nil {
dst.User = ptr.To(*src.User)
}
dst.Tags = append(src.Tags[:0:0], src.Tags...)
if dst.AuthKeyID != nil {
dst.AuthKeyID = ptr.To(*src.AuthKeyID)
}
@ -87,10 +93,10 @@ var _NodeCloneNeedsRegeneration = Node(struct {
IPv6 *netip.Addr
Hostname string
GivenName string
UserID uint
User User
UserID *uint
User *User
RegisterMethod string
ForcedTags []string
Tags []string
AuthKeyID *uint64
AuthKey *PreAuthKey
Expiry *time.Time
@ -111,6 +117,12 @@ func (src *PreAuthKey) Clone() *PreAuthKey {
dst := new(PreAuthKey)
*dst = *src
dst.Hash = append(src.Hash[:0:0], src.Hash...)
if dst.UserID != nil {
dst.UserID = ptr.To(*src.UserID)
}
if dst.User != nil {
dst.User = ptr.To(*src.User)
}
dst.Tags = append(src.Tags[:0:0], src.Tags...)
if dst.CreatedAt != nil {
dst.CreatedAt = ptr.To(*src.CreatedAt)
@ -127,8 +139,8 @@ var _PreAuthKeyCloneNeedsRegeneration = PreAuthKey(struct {
Key string
Prefix string
Hash []byte
UserID uint
User User
UserID *uint
User *User
Reusable bool
Ephemeral bool
Used bool

View file

@ -139,12 +139,13 @@ func (v NodeView) IPv4() views.ValuePointer[netip.Addr] { return views.ValuePo
func (v NodeView) IPv6() views.ValuePointer[netip.Addr] { return views.ValuePointerOf(v.ж.IPv6) }
func (v NodeView) Hostname() string { return v.ж.Hostname }
func (v NodeView) GivenName() string { return v.ж.GivenName }
func (v NodeView) UserID() uint { return v.ж.UserID }
func (v NodeView) User() User { return v.ж.User }
func (v NodeView) Hostname() string { return v.ж.Hostname }
func (v NodeView) GivenName() string { return v.ж.GivenName }
func (v NodeView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) }
func (v NodeView) User() UserView { return v.ж.User.View() }
func (v NodeView) RegisterMethod() string { return v.ж.RegisterMethod }
func (v NodeView) ForcedTags() views.Slice[string] { return views.SliceOf(v.ж.ForcedTags) }
func (v NodeView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) }
func (v NodeView) AuthKeyID() views.ValuePointer[uint64] { return views.ValuePointerOf(v.ж.AuthKeyID) }
func (v NodeView) AuthKey() PreAuthKeyView { return v.ж.AuthKey.View() }
@ -179,10 +180,10 @@ var _NodeViewNeedsRegeneration = Node(struct {
IPv6 *netip.Addr
Hostname string
GivenName string
UserID uint
User User
UserID *uint
User *User
RegisterMethod string
ForcedTags []string
Tags []string
AuthKeyID *uint64
AuthKey *PreAuthKey
Expiry *time.Time
@ -239,16 +240,17 @@ func (v *PreAuthKeyView) UnmarshalJSON(b []byte) error {
return nil
}
func (v PreAuthKeyView) ID() uint64 { return v.ж.ID }
func (v PreAuthKeyView) Key() string { return v.ж.Key }
func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix }
func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) }
func (v PreAuthKeyView) UserID() uint { return v.ж.UserID }
func (v PreAuthKeyView) User() User { return v.ж.User }
func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable }
func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral }
func (v PreAuthKeyView) Used() bool { return v.ж.Used }
func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) }
func (v PreAuthKeyView) ID() uint64 { return v.ж.ID }
func (v PreAuthKeyView) Key() string { return v.ж.Key }
func (v PreAuthKeyView) Prefix() string { return v.ж.Prefix }
func (v PreAuthKeyView) Hash() views.ByteSlice[[]byte] { return views.ByteSliceOf(v.ж.Hash) }
func (v PreAuthKeyView) UserID() views.ValuePointer[uint] { return views.ValuePointerOf(v.ж.UserID) }
func (v PreAuthKeyView) User() UserView { return v.ж.User.View() }
func (v PreAuthKeyView) Reusable() bool { return v.ж.Reusable }
func (v PreAuthKeyView) Ephemeral() bool { return v.ж.Ephemeral }
func (v PreAuthKeyView) Used() bool { return v.ж.Used }
func (v PreAuthKeyView) Tags() views.Slice[string] { return views.SliceOf(v.ж.Tags) }
func (v PreAuthKeyView) CreatedAt() views.ValuePointer[time.Time] {
return views.ValuePointerOf(v.ж.CreatedAt)
}
@ -263,8 +265,8 @@ var _PreAuthKeyViewNeedsRegeneration = PreAuthKey(struct {
Key string
Prefix string
Hash []byte
UserID uint
User User
UserID *uint
User *User
Reusable bool
Ephemeral bool
Used bool

View file

@ -22,6 +22,21 @@ type UserID uint64
type Users []User
const (
// TaggedDevicesUserID is the special user ID for tagged devices.
// This ID is used when rendering tagged nodes in the Tailscale protocol.
TaggedDevicesUserID = 2147455555
)
// TaggedDevices is a special user used in MapResponse for tagged nodes.
// Tagged nodes don't belong to a real user - the tag is their identity.
// This special user ID is used when rendering tagged nodes in the Tailscale protocol.
var TaggedDevices = User{
Model: gorm.Model{ID: TaggedDevicesUserID},
Name: "tagged-devices",
DisplayName: "Tagged Devices",
}
func (u Users) String() string {
var sb strings.Builder
sb.WriteString("[ ")
@ -77,6 +92,13 @@ func (u *User) StringID() string {
return strconv.FormatUint(uint64(u.ID), 10)
}
// TypedID returns a pointer to the user's ID as a UserID type.
// This is a convenience method to avoid ugly casting like ptr.To(types.UserID(user.ID)).
func (u *User) TypedID() *UserID {
uid := UserID(u.ID)
return &uid
}
// Username is the main way to get the username of a user,
// it will return the email if it exists, the name if it exists,
// the OIDCIdentifier if it exists, and the ID if nothing else exists.
@ -117,6 +139,13 @@ func (u UserView) TailscaleUser() tailcfg.User {
return u.ж.TailscaleUser()
}
// ID returns the user's ID.
// This is a custom accessor because gorm.Model.ID is embedded
// and the viewer generator doesn't always produce it.
func (u UserView) ID() uint {
return u.ж.ID
}
func (u *User) TailscaleLogin() tailcfg.Login {
return tailcfg.Login{
ID: tailcfg.LoginID(u.ID),

View file

@ -4,6 +4,7 @@ import (
"cmp"
"encoding/json"
"fmt"
"slices"
"strconv"
"strings"
"testing"
@ -18,7 +19,6 @@ import (
"github.com/juanfont/headscale/integration/tsic"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"tailscale.com/tailcfg"
)
@ -643,7 +643,9 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
status, err := client.Status()
assert.NoError(ct, err)
assert.Equal(ct, "Running", status.BackendState, "Expected node to be logged in, backend state: %s", status.BackendState)
assert.Equal(ct, "userid:2", status.Self.UserID.String(), "Expected node to be logged in as userid:2")
// With tags-as-identity model, tagged nodes show as TaggedDevices user (2147455555)
// The PreAuthKey was created with tags, so the node is tagged
assert.Equal(ct, "userid:2147455555", status.Self.UserID.String(), "Expected node to be logged in as tagged-devices user")
}, 30*time.Second, 2*time.Second)
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
@ -652,7 +654,8 @@ func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
assert.NoError(ct, err)
assert.Len(ct, listNodes, 2, "Should have 2 nodes after re-login")
assert.Equal(ct, user1, listNodes[0].GetUser().GetName(), "First node should belong to user1")
assert.Equal(ct, user2, listNodes[1].GetUser().GetName(), "Second node should belong to user2")
// Second node is tagged (created with tagged PreAuthKey), so it shows as "tagged-devices"
assert.Equal(ct, "tagged-devices", listNodes[1].GetUser().GetName(), "Second node should be tagged-devices")
}, 20*time.Second, 1*time.Second)
}
@ -847,118 +850,455 @@ func TestNodeTagCommand(t *testing.T) {
headscale, err := scenario.Headscale()
require.NoError(t, err)
regIDs := []string{
types.MustRegistrationID().String(),
types.MustRegistrationID().String(),
}
nodes := make([]*v1.Node, len(regIDs))
// Test 1: Verify that tags require authorization via ACL policy
// The tags-as-identity model allows conversion from user-owned to tagged, but only
// if the tag is authorized via tagOwners in the ACL policy.
regID := types.MustRegistrationID().String()
_, err = headscale.Execute(
[]string{
"headscale",
"debug",
"create-node",
"--name",
"user-owned-node",
"--user",
"user1",
"--key",
regID,
"--output",
"json",
},
)
assert.NoError(t, err)
for index, regID := range regIDs {
_, err := headscale.Execute(
var userOwnedNode v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"debug",
"create-node",
"--name",
fmt.Sprintf("node-%d", index+1),
"nodes",
"--user",
"user1",
"register",
"--key",
regID,
"--output",
"json",
},
)
assert.NoError(t, err)
var node v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"--user",
"user1",
"register",
"--key",
regID,
"--output",
"json",
},
&node,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Waiting for node registration")
nodes[index] = &node
}
assert.EventuallyWithT(t, func(ct *assert.CollectT) {
assert.Len(ct, nodes, len(regIDs), "Should have correct number of nodes after CLI operations")
}, 15*time.Second, 1*time.Second)
var node v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"tag",
"-i", "1",
"-t", "tag:test",
"--output", "json",
},
&node,
&userOwnedNode,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Waiting for node tag command")
}, 10*time.Second, 200*time.Millisecond, "Waiting for user-owned node registration")
assert.Equal(t, []string{"tag:test"}, node.GetForcedTags())
// Verify node is user-owned (no tags)
assert.Empty(t, userOwnedNode.GetValidTags(), "User-owned node should not have tags")
assert.Empty(t, userOwnedNode.GetForcedTags(), "User-owned node should not have forced tags")
// Attempt to set tags on user-owned node should FAIL because there's no ACL policy
// authorizing the tag. The tags-as-identity model allows conversion from user-owned
// to tagged, but only if the tag is authorized via tagOwners in the ACL policy.
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
"tag",
"-i", "2",
"-t", "wrong-tag",
"-i", strconv.FormatUint(userOwnedNode.GetId(), 10),
"-t", "tag:test",
"--output", "json",
},
)
assert.ErrorContains(t, err, "tag must start with the string 'tag:'")
require.ErrorContains(t, err, "invalid or unauthorized tags", "Setting unauthorized tags should fail")
// Test 2: Verify tag format validation
// Create a PreAuthKey with tags to create a tagged node
// Get the user ID from the node
userID := userOwnedNode.GetUser().GetId()
var preAuthKey v1.PreAuthKey
// Test list all nodes after added seconds
resultMachines := make([]*v1.Node, len(regIDs))
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"preauthkeys",
"--user", strconv.FormatUint(userID, 10),
"create",
"--reusable",
"--tags", "tag:integration-test",
"--output", "json",
},
&resultMachines,
&preAuthKey,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Waiting for nodes list after tagging")
found := false
for _, node := range resultMachines {
if node.GetForcedTags() != nil {
for _, tag := range node.GetForcedTags() {
if tag == "tag:test" {
found = true
}
}
}, 10*time.Second, 200*time.Millisecond, "Creating PreAuthKey with tags")
// Verify PreAuthKey has tags
assert.Contains(t, preAuthKey.GetAclTags(), "tag:integration-test", "PreAuthKey should have tags")
// Test 3: Verify invalid tag format is rejected
_, err = headscale.Execute(
[]string{
"headscale",
"preauthkeys",
"--user", strconv.FormatUint(userID, 10),
"create",
"--tags", "wrong-tag", // Missing "tag:" prefix
"--output", "json",
},
)
assert.ErrorContains(t, err, "tag must start with the string 'tag:'", "Invalid tag format should be rejected")
}
func TestTaggedNodeRegistration(t *testing.T) {
IntegrationSkip(t)
// ACL policy that authorizes the tags used in tagged PreAuthKeys
// user1 and user2 can assign these tags when creating PreAuthKeys
policy := &policyv2.Policy{
TagOwners: policyv2.TagOwners{
"tag:server": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")},
"tag:prod": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")},
"tag:forbidden": policyv2.Owners{usernameOwner("user1@"), usernameOwner("user2@")},
},
ACLs: []policyv2.ACL{
{
Action: "accept",
Sources: []policyv2.Alias{policyv2.Wildcard},
Destinations: []policyv2.AliasWithPorts{{Alias: policyv2.Wildcard, Ports: []tailcfg.PortRange{tailcfg.PortRangeAny}}},
},
},
}
spec := ScenarioSpec{
Users: []string{"user1", "user2"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv(
[]tsic.Option{},
hsic.WithACLPolicy(policy),
hsic.WithTestName("tagged-reg"),
)
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Get users (they were already created by ScenarioSpec)
users, err := headscale.ListUsers()
require.NoError(t, err)
require.Len(t, users, 2, "Should have 2 users")
var user1, user2 *v1.User
for _, u := range users {
if u.GetName() == "user1" {
user1 = u
} else if u.GetName() == "user2" {
user2 = u
}
}
assert.True(
t,
found,
"should find a node with the tag 'tag:test' in the list of nodes",
require.NotNil(t, user1, "Should find user1")
require.NotNil(t, user2, "Should find user2")
// Test 1: Create a PreAuthKey with tags
var taggedKey v1.PreAuthKey
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"preauthkeys",
"--user", strconv.FormatUint(user1.GetId(), 10),
"create",
"--reusable",
"--tags", "tag:server,tag:prod",
"--output", "json",
},
&taggedKey,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Creating tagged PreAuthKey")
// Verify PreAuthKey has both tags
assert.Contains(t, taggedKey.GetAclTags(), "tag:server", "PreAuthKey should have tag:server")
assert.Contains(t, taggedKey.GetAclTags(), "tag:prod", "PreAuthKey should have tag:prod")
assert.Len(t, taggedKey.GetAclTags(), 2, "PreAuthKey should have exactly 2 tags")
// Test 2: Register a node using the tagged PreAuthKey
err = scenario.CreateTailscaleNodesInUser("user1", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0]))
require.NoError(t, err)
err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey())
require.NoError(t, err)
// Wait for the node to be registered
var registeredNode *v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.GreaterOrEqual(c, len(nodes), 1, "Should have at least 1 node")
// Find the tagged node - it will have user "tagged-devices" per tags-as-identity model
for _, node := range nodes {
if node.GetUser().GetName() == "tagged-devices" && len(node.GetValidTags()) > 0 {
registeredNode = node
break
}
}
assert.NotNil(c, registeredNode, "Should find a tagged node")
}, 30*time.Second, 500*time.Millisecond, "Waiting for tagged node registration")
// Test 3: Verify the registered node has the tags from the PreAuthKey
assert.Contains(t, registeredNode.GetValidTags(), "tag:server", "Node should have tag:server")
assert.Contains(t, registeredNode.GetValidTags(), "tag:prod", "Node should have tag:prod")
assert.Len(t, registeredNode.GetValidTags(), 2, "Node should have exactly 2 tags")
// Test 4: Verify the node shows as TaggedDevices user (tags-as-identity model)
// Tagged nodes always show as "tagged-devices" in API responses, even though
// internally UserID may be set for "created by" tracking
assert.Equal(t, "tagged-devices", registeredNode.GetUser().GetName(), "Tagged node should show as tagged-devices user")
// Test 5: Verify the node is identified as tagged
assert.NotEmpty(t, registeredNode.GetValidTags(), "Tagged node should have tags")
// Test 6: Verify tag modification on tagged nodes
// NOTE: Changing tags requires complex ACL authorization where the node's IP
// must be authorized for the new tags via tagOwners. For simplicity, we skip
// this test and instead verify that tags cannot be arbitrarily changed without
// proper ACL authorization.
//
// This is expected behavior - tag changes must be authorized by ACL policy.
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
"tag",
"-i", strconv.FormatUint(registeredNode.GetId(), 10),
"-t", "tag:unauthorized",
"--output", "json",
},
)
// This SHOULD fail because tag:unauthorized is not in our ACL policy
require.ErrorContains(t, err, "invalid or unauthorized tags", "Unauthorized tag should be rejected")
// Test 7: Create a user-owned node for comparison
var userOwnedKey v1.PreAuthKey
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"preauthkeys",
"--user", strconv.FormatUint(user2.GetId(), 10),
"create",
"--reusable",
"--output", "json",
},
&userOwnedKey,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Creating user-owned PreAuthKey")
// Verify this PreAuthKey has NO tags
assert.Empty(t, userOwnedKey.GetAclTags(), "User-owned PreAuthKey should have no tags")
err = scenario.CreateTailscaleNodesInUser("user2", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0]))
require.NoError(t, err)
err = scenario.RunTailscaleUp("user2", headscale.GetEndpoint(), userOwnedKey.GetKey())
require.NoError(t, err)
// Wait for the user-owned node to be registered
var userOwnedNode *v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.GreaterOrEqual(c, len(nodes), 2, "Should have at least 2 nodes")
// Find the node registered with user2
for _, node := range nodes {
if node.GetUser().GetName() == "user2" {
userOwnedNode = node
break
}
}
assert.NotNil(c, userOwnedNode, "Should find a node for user2")
}, 30*time.Second, 500*time.Millisecond, "Waiting for user-owned node registration")
// Test 8: Verify user-owned node has NO tags
assert.Empty(t, userOwnedNode.GetValidTags(), "User-owned node should have no tags")
assert.NotZero(t, userOwnedNode.GetUser().GetId(), "User-owned node should have UserID")
// Test 9: Verify attempting to set UNAUTHORIZED tags on user-owned node fails
// Note: Under tags-as-identity model, user-owned nodes CAN be converted to tagged nodes
// if the tags are authorized. We use an unauthorized tag to test rejection.
_, err = headscale.Execute(
[]string{
"headscale",
"nodes",
"tag",
"-i", strconv.FormatUint(userOwnedNode.GetId(), 10),
"-t", "tag:not-in-policy",
"--output", "json",
},
)
require.ErrorContains(t, err, "invalid or unauthorized tags", "Setting unauthorized tags should fail")
// Test 10: Verify basic connectivity - wait for sync
err = scenario.WaitForTailscaleSync()
require.NoError(t, err, "Clients should be able to sync")
}
// TestTagPersistenceAcrossRestart validates that tags persist across container
// restarts and that re-authentication doesn't re-apply tags from PreAuthKey.
// This is a regression test for issue #2830.
func TestTagPersistenceAcrossRestart(t *testing.T) {
IntegrationSkip(t)
spec := ScenarioSpec{
Users: []string{"user1"},
}
scenario, err := NewScenario(spec)
require.NoError(t, err)
defer scenario.ShutdownAssertNoPanics(t)
err = scenario.CreateHeadscaleEnv([]tsic.Option{}, hsic.WithTestName("tag-persist"))
require.NoError(t, err)
headscale, err := scenario.Headscale()
require.NoError(t, err)
// Get user
users, err := headscale.ListUsers()
require.NoError(t, err)
require.Len(t, users, 1)
user1 := users[0]
// Create a reusable PreAuthKey with tags
var taggedKey v1.PreAuthKey
assert.EventuallyWithT(t, func(c *assert.CollectT) {
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"preauthkeys",
"--user", strconv.FormatUint(user1.GetId(), 10),
"create",
"--reusable", // Critical: key must be reusable for container restart
"--tags", "tag:server,tag:prod",
"--output", "json",
},
&taggedKey,
)
assert.NoError(c, err)
}, 10*time.Second, 200*time.Millisecond, "Creating reusable tagged PreAuthKey")
require.True(t, taggedKey.GetReusable(), "PreAuthKey must be reusable for restart scenario")
require.Contains(t, taggedKey.GetAclTags(), "tag:server")
require.Contains(t, taggedKey.GetAclTags(), "tag:prod")
// Register initial node with tagged PreAuthKey
err = scenario.CreateTailscaleNodesInUser("user1", "unstable", 1, tsic.WithNetwork(scenario.Networks()[0]))
require.NoError(t, err)
err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey())
require.NoError(t, err)
// Wait for node registration and get initial node state
var initialNode *v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
assert.GreaterOrEqual(c, len(nodes), 1, "Should have at least 1 node")
for _, node := range nodes {
if node.GetUser().GetId() == user1.GetId() || node.GetUser().GetName() == "tagged-devices" {
initialNode = node
break
}
}
assert.NotNil(c, initialNode, "Should find the registered node")
}, 30*time.Second, 500*time.Millisecond, "Waiting for initial node registration")
// Verify initial tags
require.Contains(t, initialNode.GetValidTags(), "tag:server", "Initial node should have tag:server")
require.Contains(t, initialNode.GetValidTags(), "tag:prod", "Initial node should have tag:prod")
require.Len(t, initialNode.GetValidTags(), 2, "Initial node should have exactly 2 tags")
initialNodeID := initialNode.GetId()
t.Logf("Initial node registered with ID %d and tags %v", initialNodeID, initialNode.GetValidTags())
// Simulate container restart by shutting down and restarting Tailscale client
allClients, err := scenario.ListTailscaleClients()
require.NoError(t, err)
require.Len(t, allClients, 1, "Should have exactly 1 client")
client := allClients[0]
// Stop the client (simulates container stop)
err = client.Down()
require.NoError(t, err)
// Wait a bit to ensure the client is fully stopped
time.Sleep(2 * time.Second)
// Restart the client with the SAME PreAuthKey (container restart scenario)
// This simulates what happens when a Docker container restarts with a reusable PreAuthKey
err = scenario.RunTailscaleUp("user1", headscale.GetEndpoint(), taggedKey.GetKey())
require.NoError(t, err)
// Wait for re-authentication
var nodeAfterRestart *v1.Node
assert.EventuallyWithT(t, func(c *assert.CollectT) {
nodes, err := headscale.ListNodes()
assert.NoError(c, err)
for _, node := range nodes {
if node.GetId() == initialNodeID {
nodeAfterRestart = node
break
}
}
assert.NotNil(c, nodeAfterRestart, "Should find the same node after restart")
}, 30*time.Second, 500*time.Millisecond, "Waiting for node re-authentication")
// CRITICAL ASSERTION: Tags should NOT be re-applied from PreAuthKey
// Tags are only applied during INITIAL authentication, not re-authentication
// The node should keep its existing tags (which happen to be the same in this case)
assert.Contains(t, nodeAfterRestart.GetValidTags(), "tag:server", "Node should still have tag:server after restart")
assert.Contains(t, nodeAfterRestart.GetValidTags(), "tag:prod", "Node should still have tag:prod after restart")
assert.Len(t, nodeAfterRestart.GetValidTags(), 2, "Node should still have exactly 2 tags after restart")
// Verify it's the SAME node (same ID), not a new registration
assert.Equal(t, initialNodeID, nodeAfterRestart.GetId(), "Should be the same node, not a new registration")
// Verify node count hasn't increased (no duplicate nodes)
finalNodes, err := headscale.ListNodes()
require.NoError(t, err)
assert.Len(t, finalNodes, 1, "Should still have exactly 1 node (no duplicates from restart)")
t.Logf("Container restart validation complete - node %d maintained tags across restart", initialNodeID)
}
func TestNodeAdvertiseTagCommand(t *testing.T) {