From 3bf01d4cee572b5387ab0caba4a88f2057ce7b1c Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 21 Jan 2026 10:51:56 +0000 Subject: [PATCH] db: use PolicyManager for RequestTags migration Refactor the RequestTags migration (202601121700-migrate-hostinfo-request-tags) to use PolicyManager.NodeCanHaveTag() instead of reimplementing tag validation. Changes: - NewHeadscaleDatabase now accepts *types.Config to allow migrations access to policy configuration - Add loadPolicyBytes helper to load policy from file or DB based on config - Add standalone GetPolicy(tx *gorm.DB) for use during migrations - Replace custom tag validation logic with PolicyManager Benefits: - Full HuJSON parsing support (not just JSON) - Proper group expansion via PolicyManager - Support for nested tags and autogroups - Works with both file and database policy modes - Single source of truth for tag validation Updates #3006 --- cmd/headscale/cli/policy.go | 6 +- hscontrol/db/db.go | 214 +++++++++++++------------------ hscontrol/db/db_test.go | 28 ++-- hscontrol/db/policy.go | 12 +- hscontrol/db/suite_test.go | 38 +++--- hscontrol/mapper/batcher_test.go | 3 +- hscontrol/state/state.go | 3 +- 7 files changed, 144 insertions(+), 160 deletions(-) diff --git a/cmd/headscale/cli/policy.go b/cmd/headscale/cli/policy.go index f99d5390..2aaebcfa 100644 --- a/cmd/headscale/cli/policy.go +++ b/cmd/headscale/cli/policy.go @@ -69,8 +69,7 @@ var getPolicy = &cobra.Command{ } d, err := db.NewHeadscaleDatabase( - cfg.Database, - cfg.BaseDomain, + cfg, nil, ) if err != nil { @@ -145,8 +144,7 @@ var setPolicy = &cobra.Command{ } d, err := db.NewHeadscaleDatabase( - cfg.Database, - cfg.BaseDomain, + cfg, nil, ) if err != nil { diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 1e456011..68f32e74 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -7,15 +7,16 @@ import ( "errors" "fmt" "net/netip" + "os" "path/filepath" "slices" "strconv" - "strings" "time" "github.com/glebarez/sqlite" "github.com/go-gormigrate/gormigrate/v2" "github.com/juanfont/headscale/hscontrol/db/sqliteconfig" + "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/rs/zerolog/log" @@ -54,20 +55,51 @@ type KV struct { type HSDatabase struct { DB *gorm.DB - cfg *types.DatabaseConfig + cfg *types.Config regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] - - baseDomain string } -// TODO(kradalby): assemble this struct from toptions or something typed -// rather than arguments. +// loadPolicyBytes loads policy from file or database based on configuration. +// This is used during migrations when HSDatabase is not yet fully initialized. +func loadPolicyBytes(tx *gorm.DB, cfg *types.Config) ([]byte, error) { + switch cfg.Policy.Mode { + case types.PolicyModeFile: + if cfg.Policy.Path == "" { + return nil, nil + } + + absPath := util.AbsolutePathFromConfigPath(cfg.Policy.Path) + + return os.ReadFile(absPath) + + case types.PolicyModeDB: + p, err := GetPolicy(tx) + if err != nil { + if errors.Is(err, types.ErrPolicyNotFound) { + return nil, nil + } + + return nil, err + } + + if p.Data == "" { + return nil, nil + } + + return []byte(p.Data), nil + + default: + return nil, nil + } +} + +// NewHeadscaleDatabase creates a new database connection and runs migrations. +// It accepts the full configuration to allow migrations access to policy settings. func NewHeadscaleDatabase( - cfg types.DatabaseConfig, - baseDomain string, + cfg *types.Config, regCache *zcache.Cache[types.RegistrationID, types.RegisterNode], ) (*HSDatabase, error) { - dbConn, err := openDB(cfg) + dbConn, err := openDB(cfg.Database) if err != nil { return nil, err } @@ -254,7 +286,7 @@ AND auth_key_id NOT IN ( ID: "202507021200", Migrate: func(tx *gorm.DB) error { // Only run on SQLite - if cfg.Type != types.DatabaseSqlite { + if cfg.Database.Type != types.DatabaseSqlite { log.Info().Msg("Skipping schema migration on non-SQLite database") return nil } @@ -602,119 +634,55 @@ AND auth_key_id NOT IN ( // Fixes: https://github.com/juanfont/headscale/issues/3006 ID: "202601121700-migrate-hostinfo-request-tags", Migrate: func(tx *gorm.DB) error { - // 1. Load policy from database - var policyData string - err := tx.Raw("SELECT data FROM policies ORDER BY id DESC LIMIT 1").Scan(&policyData).Error - if err != nil || policyData == "" { - log.Info().Msg("No policy found in database, skipping RequestTags migration (tags will be validated on node reconnect)") - return nil - } - - // 2. Parse tagOwners and groups from policy - type migrationPolicy struct { - TagOwners map[string][]string `json:"tagOwners"` - Groups map[string][]string `json:"groups"` - } - var pol migrationPolicy - if err := json.Unmarshal([]byte(policyData), &pol); err != nil { - log.Warn().Err(err).Msg("Failed to parse policy JSON, skipping RequestTags migration (tags will be validated on node reconnect)") - return nil - } - - if len(pol.TagOwners) == 0 { - log.Info().Msg("No tagOwners defined in policy, skipping RequestTags migration") - return nil - } - - // Helper function to check if a user can have a tag - canUserHaveTag := func(username string, tag string) bool { - owners, exists := pol.TagOwners[tag] - if !exists { - return false // Tag not defined in policy - } - - for _, owner := range owners { - // Direct username match - if owner == username { - return true - } - - // Group expansion - if strings.HasPrefix(owner, "group:") { - if groupMembers, ok := pol.Groups[owner]; ok { - if slices.Contains(groupMembers, username) { - return true - } - } - } - } - return false - } - - // 3. Query nodes with user info - type nodeRow struct { - ID uint64 - HostInfo string - Tags string - UserID *uint64 - Username *string - } - var nodes []nodeRow - err = tx.Raw(` - SELECT n.id, n.host_info, n.tags, n.user_id, u.name as username - FROM nodes n - LEFT JOIN users u ON n.user_id = u.id - WHERE n.host_info IS NOT NULL AND n.host_info != '' AND n.host_info != '{}' - `).Scan(&nodes).Error + // 1. Load policy from file or database based on configuration + policyData, err := loadPolicyBytes(tx, cfg) if err != nil { - return fmt.Errorf("querying nodes for RequestTags migration: %w", err) + log.Warn().Err(err).Msg("Failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)") + return nil + } + + if len(policyData) == 0 { + log.Info().Msg("No policy found, skipping RequestTags migration (tags will be validated on node reconnect)") + return nil + } + + // 2. Load users and nodes to create PolicyManager + users, err := ListUsers(tx) + if err != nil { + return fmt.Errorf("loading users for RequestTags migration: %w", err) + } + + nodes, err := ListNodes(tx) + if err != nil { + return fmt.Errorf("loading nodes for RequestTags migration: %w", err) + } + + // 3. Create PolicyManager (handles HuJSON parsing, groups, nested tags, etc.) + polMan, err := policy.NewPolicyManager(policyData, users, nodes.ViewSlice()) + if err != nil { + log.Warn().Err(err).Msg("Failed to parse policy, skipping RequestTags migration (tags will be validated on node reconnect)") + return nil } // 4. Process each node for _, node := range nodes { - // Parse host_info JSON to extract RequestTags - var hostInfo struct { - RequestTags []string `json:"RequestTags"` - } - if err := json.Unmarshal([]byte(node.HostInfo), &hostInfo); err != nil { - log.Trace(). - Uint64("node.id", node.ID). - Err(err). - Msg("Skipping node with invalid host_info JSON during RequestTags migration") + if node.Hostinfo == nil { continue } - // Skip if no RequestTags in host_info - if len(hostInfo.RequestTags) == 0 { + requestTags := node.Hostinfo.RequestTags + if len(requestTags) == 0 { continue } - // Skip if no username (can't validate) - if node.Username == nil || *node.Username == "" { - log.Debug(). - Uint64("node.id", node.ID). - Strs("request_tags", hostInfo.RequestTags). - Msg("Skipping node without username during RequestTags migration") - continue - } + existingTags := node.Tags - // Parse existing tags from the tags column - var existingTags []string - if node.Tags != "" && node.Tags != "null" { - if err := json.Unmarshal([]byte(node.Tags), &existingTags); err != nil { - log.Trace(). - Uint64("node.id", node.ID). - Err(err). - Msg("Skipping node with invalid tags JSON during RequestTags migration") - continue - } - } + var validatedTags, rejectedTags []string - // Validate and merge RequestTags - var validatedTags []string - var rejectedTags []string - for _, tag := range hostInfo.RequestTags { - if canUserHaveTag(*node.Username, tag) { + nodeView := node.View() + + for _, tag := range requestTags { + if polMan.NodeCanHaveTag(nodeView, tag) { if !slices.Contains(existingTags, tag) { validatedTags = append(validatedTags, tag) } @@ -723,37 +691,34 @@ AND auth_key_id NOT IN ( } } - // Skip if no validated tags to add if len(validatedTags) == 0 { if len(rejectedTags) > 0 { log.Debug(). - Uint64("node.id", node.ID). - Str("username", *node.Username). + Uint64("node.id", uint64(node.ID)). + Str("node.name", node.Hostname). Strs("rejected_tags", rejectedTags). - Msg("RequestTags rejected during migration (user not authorized)") + Msg("RequestTags rejected during migration (not authorized)") } + continue } - // Merge validated tags with existing tags mergedTags := append(existingTags, validatedTags...) slices.Sort(mergedTags) mergedTags = slices.Compact(mergedTags) - // Serialize back to JSON tagsJSON, err := json.Marshal(mergedTags) if err != nil { return fmt.Errorf("serializing merged tags for node %d: %w", node.ID, err) } - // Update the tags column if err := tx.Exec("UPDATE nodes SET tags = ? WHERE id = ?", string(tagsJSON), node.ID).Error; err != nil { return fmt.Errorf("updating tags for node %d: %w", node.ID, err) } log.Info(). - Uint64("node.id", node.ID). - Str("username", *node.Username). + Uint64("node.id", uint64(node.ID)). + Str("node.name", node.Hostname). Strs("validated_tags", validatedTags). Strs("rejected_tags", rejectedTags). Strs("existing_tags", existingTags). @@ -821,7 +786,8 @@ AND auth_key_id NOT IN ( return nil }) - if err := runMigrations(cfg, dbConn, migrations); err != nil { + err = runMigrations(cfg.Database, dbConn, migrations) + if err != nil { return nil, fmt.Errorf("migration failed: %w", err) } @@ -829,7 +795,7 @@ AND auth_key_id NOT IN ( // This is currently only done on sqlite as squibble does not // support Postgres and we use our sqlite schema as our source of // truth. - if cfg.Type == types.DatabaseSqlite { + if cfg.Database.Type == types.DatabaseSqlite { sqlConn, err := dbConn.DB() if err != nil { return nil, fmt.Errorf("getting DB from gorm: %w", err) @@ -861,10 +827,8 @@ AND auth_key_id NOT IN ( db := HSDatabase{ DB: dbConn, - cfg: &cfg, + cfg: cfg, regCache: regCache, - - baseDomain: baseDomain, } return &db, err @@ -1107,7 +1071,7 @@ func (hsdb *HSDatabase) Close() error { return err } - if hsdb.cfg.Type == types.DatabaseSqlite && hsdb.cfg.Sqlite.WriteAheadLog { + if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog { db.Exec("VACUUM") } diff --git a/hscontrol/db/db_test.go b/hscontrol/db/db_test.go index 9d534269..89cdcc6c 100644 --- a/hscontrol/db/db_test.go +++ b/hscontrol/db/db_test.go @@ -288,13 +288,17 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase { } db, err := NewHeadscaleDatabase( - types.DatabaseConfig{ - Type: "sqlite3", - Sqlite: types.SqliteConfig{ - Path: dbPath, + &types.Config{ + Database: types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, + }, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, }, }, - "", emptyCache(), ) if err != nil { @@ -343,13 +347,17 @@ func TestSQLiteAllTestdataMigrations(t *testing.T) { require.NoError(t, err) _, err = NewHeadscaleDatabase( - types.DatabaseConfig{ - Type: "sqlite3", - Sqlite: types.SqliteConfig{ - Path: dbPath, + &types.Config{ + Database: types.DatabaseConfig{ + Type: "sqlite3", + Sqlite: types.SqliteConfig{ + Path: dbPath, + }, + }, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, }, }, - "", emptyCache(), ) require.NoError(t, err) diff --git a/hscontrol/db/policy.go b/hscontrol/db/policy.go index 49b419b5..a874b602 100644 --- a/hscontrol/db/policy.go +++ b/hscontrol/db/policy.go @@ -24,14 +24,22 @@ func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) { // GetPolicy returns the latest policy in the database. func (hsdb *HSDatabase) GetPolicy() (*types.Policy, error) { + return GetPolicy(hsdb.DB) +} + +// GetPolicy returns the latest policy from the database. +// This standalone function can be used in contexts where HSDatabase is not available, +// such as during migrations. +func GetPolicy(tx *gorm.DB) (*types.Policy, error) { var p types.Policy // Query: // SELECT * FROM policies ORDER BY id DESC LIMIT 1; - if err := hsdb.DB. + err := tx. Order("id DESC"). Limit(1). - First(&p).Error; err != nil { + First(&p).Error + if err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return nil, types.ErrPolicyNotFound } diff --git a/hscontrol/db/suite_test.go b/hscontrol/db/suite_test.go index 4ebccbdd..15a85cf8 100644 --- a/hscontrol/db/suite_test.go +++ b/hscontrol/db/suite_test.go @@ -23,13 +23,17 @@ func newSQLiteTestDB() (*HSDatabase, error) { zerolog.SetGlobalLevel(zerolog.Disabled) db, err := NewHeadscaleDatabase( - types.DatabaseConfig{ - Type: types.DatabaseSqlite, - Sqlite: types.SqliteConfig{ - Path: tmpDir + "/headscale_test.db", + &types.Config{ + Database: types.DatabaseConfig{ + Type: types.DatabaseSqlite, + Sqlite: types.SqliteConfig{ + Path: tmpDir + "/headscale_test.db", + }, + }, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, }, }, - "", emptyCache(), ) if err != nil { @@ -72,18 +76,22 @@ func newHeadscaleDBFromPostgresURL(t *testing.T, pu *url.URL) *HSDatabase { port, _ := strconv.Atoi(pu.Port()) db, err := NewHeadscaleDatabase( - types.DatabaseConfig{ - Type: types.DatabasePostgres, - Postgres: types.PostgresConfig{ - Host: pu.Hostname(), - User: pu.User.Username(), - Name: strings.TrimLeft(pu.Path, "/"), - Pass: pass, - Port: port, - Ssl: "disable", + &types.Config{ + Database: types.DatabaseConfig{ + Type: types.DatabasePostgres, + Postgres: types.PostgresConfig{ + Host: pu.Hostname(), + User: pu.User.Username(), + Name: strings.TrimLeft(pu.Path, "/"), + Pass: pass, + Port: port, + Ssl: "disable", + }, + }, + Policy: types.PolicyConfig{ + Mode: types.PolicyModeDB, }, }, - "", emptyCache(), ) if err != nil { diff --git a/hscontrol/mapper/batcher_test.go b/hscontrol/mapper/batcher_test.go index 3cbd4e2d..70d5e377 100644 --- a/hscontrol/mapper/batcher_test.go +++ b/hscontrol/mapper/batcher_test.go @@ -213,8 +213,7 @@ func setupBatcherWithTestData( // Create database and populate it with test data database, err := db.NewHeadscaleDatabase( - cfg.Database, - "", + cfg, emptyCache(), ) if err != nil { diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 61dbe7b5..114968a2 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -115,8 +115,7 @@ func NewState(cfg *types.Config) (*State, error) { ) db, err := hsdb.NewHeadscaleDatabase( - cfg.Database, - cfg.BaseDomain, + cfg, registrationCache, ) if err != nil {