db: use PolicyManager for RequestTags migration

Refactor the RequestTags migration (202601121700-migrate-hostinfo-request-tags)
to use PolicyManager.NodeCanHaveTag() instead of reimplementing tag validation.

Changes:
- NewHeadscaleDatabase now accepts *types.Config to give migrations access
  to the policy configuration
- Add loadPolicyBytes helper to load policy from file or DB based on config
- Add standalone GetPolicy(tx *gorm.DB) for use during migrations
- Replace custom tag validation logic with PolicyManager

Benefits:
- Full HuJSON parsing support (not just JSON)
- Proper group expansion via PolicyManager
- Support for nested tags and autogroups
- Works with both file and database policy modes
- Single source of truth for tag validation

Updates #3006
Kristoffer Dalby, 2026-01-21 10:51:56 +00:00
commit 3bf01d4cee, parent 740f650a4d
7 changed files with 144 additions and 160 deletions
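For orientation, here is a condensed sketch of the new flow described above, assembled from the diff below. The wrapper name migrateRequestTagsSketch is invented for illustration; the helpers it calls (loadPolicyBytes, ListUsers, ListNodes, policy.NewPolicyManager, NodeCanHaveTag) are the ones introduced or used in this commit, and the snippet assumes the hscontrol/db package context with its existing imports.

// A condensed sketch of the migration's new validation path (illustrative only;
// the real code lives in the gormigrate callback shown in the diff below).
func migrateRequestTagsSketch(tx *gorm.DB, cfg *types.Config) error {
	// Load the policy from file or DB, depending on cfg.Policy.Mode.
	policyData, err := loadPolicyBytes(tx, cfg)
	if err != nil || len(policyData) == 0 {
		return nil // no policy to validate against; tags are validated on node reconnect
	}

	users, err := ListUsers(tx)
	if err != nil {
		return err
	}

	nodes, err := ListNodes(tx)
	if err != nil {
		return err
	}

	// PolicyManager handles HuJSON parsing, group expansion, nested tags and autogroups.
	polMan, err := policy.NewPolicyManager(policyData, users, nodes.ViewSlice())
	if err != nil {
		return nil // unparsable policy: skip rather than fail the migration
	}

	for _, node := range nodes {
		if node.Hostinfo == nil {
			continue
		}

		for _, tag := range node.Hostinfo.RequestTags {
			if polMan.NodeCanHaveTag(node.View(), tag) {
				// Authorized: the real migration merges the tag into node.Tags
				// and writes the deduplicated list back to the nodes table.
			}
		}
	}

	return nil
}

The deliberate design choice, visible in the log messages, is that policy problems never abort the migration: a missing or unparsable policy simply skips validation, and tags are re-validated when the node reconnects.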

View file

@@ -69,8 +69,7 @@ var getPolicy = &cobra.Command{
 		}
 
 		d, err := db.NewHeadscaleDatabase(
-			cfg.Database,
-			cfg.BaseDomain,
+			cfg,
 			nil,
 		)
 		if err != nil {
@@ -145,8 +144,7 @@ var setPolicy = &cobra.Command{
 		}
 
 		d, err := db.NewHeadscaleDatabase(
-			cfg.Database,
-			cfg.BaseDomain,
+			cfg,
 			nil,
 		)
 		if err != nil {

View file

@@ -7,15 +7,16 @@ import (
 	"errors"
 	"fmt"
 	"net/netip"
+	"os"
 	"path/filepath"
 	"slices"
 	"strconv"
-	"strings"
 	"time"
 
 	"github.com/glebarez/sqlite"
 	"github.com/go-gormigrate/gormigrate/v2"
 	"github.com/juanfont/headscale/hscontrol/db/sqliteconfig"
+	"github.com/juanfont/headscale/hscontrol/policy"
 	"github.com/juanfont/headscale/hscontrol/types"
 	"github.com/juanfont/headscale/hscontrol/util"
 	"github.com/rs/zerolog/log"
@@ -54,20 +55,51 @@ type KV struct {
 
 type HSDatabase struct {
 	DB  *gorm.DB
-	cfg *types.DatabaseConfig
+	cfg *types.Config
 
 	regCache   *zcache.Cache[types.RegistrationID, types.RegisterNode]
-	baseDomain string
 }
 
-// TODO(kradalby): assemble this struct from toptions or something typed
-// rather than arguments.
+// loadPolicyBytes loads policy from file or database based on configuration.
+// This is used during migrations when HSDatabase is not yet fully initialized.
+func loadPolicyBytes(tx *gorm.DB, cfg *types.Config) ([]byte, error) {
+	switch cfg.Policy.Mode {
+	case types.PolicyModeFile:
+		if cfg.Policy.Path == "" {
+			return nil, nil
+		}
+
+		absPath := util.AbsolutePathFromConfigPath(cfg.Policy.Path)
+
+		return os.ReadFile(absPath)
+	case types.PolicyModeDB:
+		p, err := GetPolicy(tx)
+		if err != nil {
+			if errors.Is(err, types.ErrPolicyNotFound) {
+				return nil, nil
+			}
+
+			return nil, err
+		}
+
+		if p.Data == "" {
+			return nil, nil
+		}
+
+		return []byte(p.Data), nil
+	default:
+		return nil, nil
+	}
+}
+
+// NewHeadscaleDatabase creates a new database connection and runs migrations.
+// It accepts the full configuration to allow migrations access to policy settings.
 func NewHeadscaleDatabase(
-	cfg types.DatabaseConfig,
-	baseDomain string,
+	cfg *types.Config,
 	regCache *zcache.Cache[types.RegistrationID, types.RegisterNode],
 ) (*HSDatabase, error) {
-	dbConn, err := openDB(cfg)
+	dbConn, err := openDB(cfg.Database)
 	if err != nil {
 		return nil, err
 	}
@@ -254,7 +286,7 @@ AND auth_key_id NOT IN (
 			ID: "202507021200",
 			Migrate: func(tx *gorm.DB) error {
 				// Only run on SQLite
-				if cfg.Type != types.DatabaseSqlite {
+				if cfg.Database.Type != types.DatabaseSqlite {
 					log.Info().Msg("Skipping schema migration on non-SQLite database")
 					return nil
 				}
@@ -602,119 +634,55 @@ AND auth_key_id NOT IN (
 			// Fixes: https://github.com/juanfont/headscale/issues/3006
 			ID: "202601121700-migrate-hostinfo-request-tags",
 			Migrate: func(tx *gorm.DB) error {
-				// 1. Load policy from database
-				var policyData string
-				err := tx.Raw("SELECT data FROM policies ORDER BY id DESC LIMIT 1").Scan(&policyData).Error
-				if err != nil || policyData == "" {
-					log.Info().Msg("No policy found in database, skipping RequestTags migration (tags will be validated on node reconnect)")
-					return nil
-				}
-
-				// 2. Parse tagOwners and groups from policy
-				type migrationPolicy struct {
-					TagOwners map[string][]string `json:"tagOwners"`
-					Groups    map[string][]string `json:"groups"`
-				}
-
-				var pol migrationPolicy
-				if err := json.Unmarshal([]byte(policyData), &pol); err != nil {
-					log.Warn().Err(err).Msg("Failed to parse policy JSON, skipping RequestTags migration (tags will be validated on node reconnect)")
-					return nil
-				}
-
-				if len(pol.TagOwners) == 0 {
-					log.Info().Msg("No tagOwners defined in policy, skipping RequestTags migration")
-					return nil
-				}
-
-				// Helper function to check if a user can have a tag
-				canUserHaveTag := func(username string, tag string) bool {
-					owners, exists := pol.TagOwners[tag]
-					if !exists {
-						return false // Tag not defined in policy
-					}
-
-					for _, owner := range owners {
-						// Direct username match
-						if owner == username {
-							return true
-						}
-
-						// Group expansion
-						if strings.HasPrefix(owner, "group:") {
-							if groupMembers, ok := pol.Groups[owner]; ok {
-								if slices.Contains(groupMembers, username) {
-									return true
-								}
-							}
-						}
-					}
-
-					return false
-				}
-
-				// 3. Query nodes with user info
-				type nodeRow struct {
-					ID       uint64
-					HostInfo string
-					Tags     string
-					UserID   *uint64
-					Username *string
-				}
-
-				var nodes []nodeRow
-				err = tx.Raw(`
-					SELECT n.id, n.host_info, n.tags, n.user_id, u.name as username
-					FROM nodes n
-					LEFT JOIN users u ON n.user_id = u.id
-					WHERE n.host_info IS NOT NULL AND n.host_info != '' AND n.host_info != '{}'
-				`).Scan(&nodes).Error
+				// 1. Load policy from file or database based on configuration
+				policyData, err := loadPolicyBytes(tx, cfg)
 				if err != nil {
-					return fmt.Errorf("querying nodes for RequestTags migration: %w", err)
+					log.Warn().Err(err).Msg("Failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)")
+					return nil
+				}
+
+				if len(policyData) == 0 {
+					log.Info().Msg("No policy found, skipping RequestTags migration (tags will be validated on node reconnect)")
+					return nil
+				}
+
+				// 2. Load users and nodes to create PolicyManager
+				users, err := ListUsers(tx)
+				if err != nil {
+					return fmt.Errorf("loading users for RequestTags migration: %w", err)
+				}
+
+				nodes, err := ListNodes(tx)
+				if err != nil {
+					return fmt.Errorf("loading nodes for RequestTags migration: %w", err)
+				}
+
+				// 3. Create PolicyManager (handles HuJSON parsing, groups, nested tags, etc.)
+				polMan, err := policy.NewPolicyManager(policyData, users, nodes.ViewSlice())
+				if err != nil {
+					log.Warn().Err(err).Msg("Failed to parse policy, skipping RequestTags migration (tags will be validated on node reconnect)")
+					return nil
 				}
 
 				// 4. Process each node
 				for _, node := range nodes {
-					// Parse host_info JSON to extract RequestTags
-					var hostInfo struct {
-						RequestTags []string `json:"RequestTags"`
-					}
-					if err := json.Unmarshal([]byte(node.HostInfo), &hostInfo); err != nil {
-						log.Trace().
-							Uint64("node.id", node.ID).
-							Err(err).
-							Msg("Skipping node with invalid host_info JSON during RequestTags migration")
+					if node.Hostinfo == nil {
 						continue
 					}
 
-					// Skip if no RequestTags in host_info
-					if len(hostInfo.RequestTags) == 0 {
+					requestTags := node.Hostinfo.RequestTags
+					if len(requestTags) == 0 {
 						continue
 					}
 
-					// Skip if no username (can't validate)
-					if node.Username == nil || *node.Username == "" {
-						log.Debug().
-							Uint64("node.id", node.ID).
-							Strs("request_tags", hostInfo.RequestTags).
-							Msg("Skipping node without username during RequestTags migration")
-						continue
-					}
+					existingTags := node.Tags
 
-					// Parse existing tags from the tags column
-					var existingTags []string
-					if node.Tags != "" && node.Tags != "null" {
-						if err := json.Unmarshal([]byte(node.Tags), &existingTags); err != nil {
-							log.Trace().
-								Uint64("node.id", node.ID).
-								Err(err).
-								Msg("Skipping node with invalid tags JSON during RequestTags migration")
-							continue
-						}
-					}
+					var validatedTags, rejectedTags []string
 
-					// Validate and merge RequestTags
-					var validatedTags []string
-					var rejectedTags []string
-					for _, tag := range hostInfo.RequestTags {
-						if canUserHaveTag(*node.Username, tag) {
+					nodeView := node.View()
+					for _, tag := range requestTags {
+						if polMan.NodeCanHaveTag(nodeView, tag) {
 							if !slices.Contains(existingTags, tag) {
 								validatedTags = append(validatedTags, tag)
 							}
@@ -723,37 +691,34 @@ AND auth_key_id NOT IN (
 						}
 					}
 
-					// Skip if no validated tags to add
 					if len(validatedTags) == 0 {
 						if len(rejectedTags) > 0 {
 							log.Debug().
-								Uint64("node.id", node.ID).
-								Str("username", *node.Username).
+								Uint64("node.id", uint64(node.ID)).
+								Str("node.name", node.Hostname).
 								Strs("rejected_tags", rejectedTags).
-								Msg("RequestTags rejected during migration (user not authorized)")
+								Msg("RequestTags rejected during migration (not authorized)")
 						}
 
 						continue
 					}
 
-					// Merge validated tags with existing tags
 					mergedTags := append(existingTags, validatedTags...)
 					slices.Sort(mergedTags)
 					mergedTags = slices.Compact(mergedTags)
 
-					// Serialize back to JSON
 					tagsJSON, err := json.Marshal(mergedTags)
 					if err != nil {
 						return fmt.Errorf("serializing merged tags for node %d: %w", node.ID, err)
 					}
 
-					// Update the tags column
 					if err := tx.Exec("UPDATE nodes SET tags = ? WHERE id = ?", string(tagsJSON), node.ID).Error; err != nil {
 						return fmt.Errorf("updating tags for node %d: %w", node.ID, err)
 					}
 
 					log.Info().
-						Uint64("node.id", node.ID).
-						Str("username", *node.Username).
+						Uint64("node.id", uint64(node.ID)).
+						Str("node.name", node.Hostname).
 						Strs("validated_tags", validatedTags).
 						Strs("rejected_tags", rejectedTags).
 						Strs("existing_tags", existingTags).
@@ -821,7 +786,8 @@ AND auth_key_id NOT IN (
 			return nil
 		})
 
-	if err := runMigrations(cfg, dbConn, migrations); err != nil {
+	err = runMigrations(cfg.Database, dbConn, migrations)
+	if err != nil {
 		return nil, fmt.Errorf("migration failed: %w", err)
 	}
@@ -829,7 +795,7 @@ AND auth_key_id NOT IN (
 	// This is currently only done on sqlite as squibble does not
 	// support Postgres and we use our sqlite schema as our source of
 	// truth.
-	if cfg.Type == types.DatabaseSqlite {
+	if cfg.Database.Type == types.DatabaseSqlite {
 		sqlConn, err := dbConn.DB()
 		if err != nil {
 			return nil, fmt.Errorf("getting DB from gorm: %w", err)
@@ -861,10 +827,8 @@ AND auth_key_id NOT IN (
 	db := HSDatabase{
 		DB:  dbConn,
-		cfg: &cfg,
+		cfg: cfg,
 
 		regCache:   regCache,
-		baseDomain: baseDomain,
 	}
 
 	return &db, err
@@ -1107,7 +1071,7 @@ func (hsdb *HSDatabase) Close() error {
 		return err
 	}
 
-	if hsdb.cfg.Type == types.DatabaseSqlite && hsdb.cfg.Sqlite.WriteAheadLog {
+	if hsdb.cfg.Database.Type == types.DatabaseSqlite && hsdb.cfg.Database.Sqlite.WriteAheadLog {
 		db.Exec("VACUUM")
 	}

View file

@@ -288,13 +288,17 @@ func dbForTestWithPath(t *testing.T, sqlFilePath string) *HSDatabase {
 	}
 
 	db, err := NewHeadscaleDatabase(
-		types.DatabaseConfig{
-			Type: "sqlite3",
-			Sqlite: types.SqliteConfig{
-				Path: dbPath,
+		&types.Config{
+			Database: types.DatabaseConfig{
+				Type: "sqlite3",
+				Sqlite: types.SqliteConfig{
+					Path: dbPath,
+				},
+			},
+			Policy: types.PolicyConfig{
+				Mode: types.PolicyModeDB,
 			},
 		},
-		"",
 		emptyCache(),
 	)
 	if err != nil {
@@ -343,13 +347,17 @@ func TestSQLiteAllTestdataMigrations(t *testing.T) {
 			require.NoError(t, err)
 
 			_, err = NewHeadscaleDatabase(
-				types.DatabaseConfig{
-					Type: "sqlite3",
-					Sqlite: types.SqliteConfig{
-						Path: dbPath,
+				&types.Config{
+					Database: types.DatabaseConfig{
+						Type: "sqlite3",
+						Sqlite: types.SqliteConfig{
+							Path: dbPath,
+						},
+					},
+					Policy: types.PolicyConfig{
+						Mode: types.PolicyModeDB,
 					},
 				},
-				"",
 				emptyCache(),
 			)
 			require.NoError(t, err)

View file

@@ -24,14 +24,22 @@ func (hsdb *HSDatabase) SetPolicy(policy string) (*types.Policy, error) {
 
 // GetPolicy returns the latest policy in the database.
 func (hsdb *HSDatabase) GetPolicy() (*types.Policy, error) {
+	return GetPolicy(hsdb.DB)
+}
+
+// GetPolicy returns the latest policy from the database.
+// This standalone function can be used in contexts where HSDatabase is not available,
+// such as during migrations.
+func GetPolicy(tx *gorm.DB) (*types.Policy, error) {
 	var p types.Policy
 
 	// Query:
 	// SELECT * FROM policies ORDER BY id DESC LIMIT 1;
-	if err := hsdb.DB.
+	err := tx.
 		Order("id DESC").
 		Limit(1).
-		First(&p).Error; err != nil {
+		First(&p).Error
+	if err != nil {
 		if errors.Is(err, gorm.ErrRecordNotFound) {
 			return nil, types.ErrPolicyNotFound
 		}

View file

@@ -23,13 +23,17 @@ func newSQLiteTestDB() (*HSDatabase, error) {
 	zerolog.SetGlobalLevel(zerolog.Disabled)
 
 	db, err := NewHeadscaleDatabase(
-		types.DatabaseConfig{
-			Type: types.DatabaseSqlite,
-			Sqlite: types.SqliteConfig{
-				Path: tmpDir + "/headscale_test.db",
+		&types.Config{
+			Database: types.DatabaseConfig{
+				Type: types.DatabaseSqlite,
+				Sqlite: types.SqliteConfig{
+					Path: tmpDir + "/headscale_test.db",
+				},
+			},
+			Policy: types.PolicyConfig{
+				Mode: types.PolicyModeDB,
 			},
 		},
-		"",
 		emptyCache(),
 	)
 	if err != nil {
@@ -72,18 +76,22 @@ func newHeadscaleDBFromPostgresURL(t *testing.T, pu *url.URL) *HSDatabase {
 	port, _ := strconv.Atoi(pu.Port())
 
 	db, err := NewHeadscaleDatabase(
-		types.DatabaseConfig{
-			Type: types.DatabasePostgres,
-			Postgres: types.PostgresConfig{
-				Host: pu.Hostname(),
-				User: pu.User.Username(),
-				Name: strings.TrimLeft(pu.Path, "/"),
-				Pass: pass,
-				Port: port,
-				Ssl:  "disable",
+		&types.Config{
+			Database: types.DatabaseConfig{
+				Type: types.DatabasePostgres,
+				Postgres: types.PostgresConfig{
+					Host: pu.Hostname(),
+					User: pu.User.Username(),
+					Name: strings.TrimLeft(pu.Path, "/"),
+					Pass: pass,
+					Port: port,
+					Ssl:  "disable",
+				},
+			},
+			Policy: types.PolicyConfig{
+				Mode: types.PolicyModeDB,
 			},
 		},
-		"",
 		emptyCache(),
 	)
 	if err != nil {

View file

@@ -213,8 +213,7 @@ func setupBatcherWithTestData(
 	// Create database and populate it with test data
 	database, err := db.NewHeadscaleDatabase(
-		cfg.Database,
-		"",
+		cfg,
 		emptyCache(),
 	)
 	if err != nil {

View file

@@ -115,8 +115,7 @@ func NewState(cfg *types.Config) (*State, error) {
 	)
 
 	db, err := hsdb.NewHeadscaleDatabase(
-		cfg.Database,
-		cfg.BaseDomain,
+		cfg,
 		registrationCache,
 	)
 	if err != nil {