From 4eeea90c5a018edc591a87f640683e6c805bd2cf Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Wed, 21 Jan 2026 11:40:50 +0000 Subject: [PATCH] db: consolidate PolicyBytes into single function Move the duplicated policy loading logic from state/state.go and db/db.go into a single PolicyBytes function in db/policy.go. This standalone function can be used in contexts where HSDatabase is not fully initialized (such as during migrations) by accepting a raw *gorm.DB transaction instead of requiring *HSDatabase. Updates #3006 --- hscontrol/db/db.go | 40 +++------------------------------- hscontrol/db/policy.go | 40 ++++++++++++++++++++++++++++++++++ hscontrol/state/debug.go | 3 ++- hscontrol/state/state.go | 47 ++-------------------------------------- 4 files changed, 47 insertions(+), 83 deletions(-) diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index 84520e65..a1429aa6 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -7,7 +7,6 @@ import ( "errors" "fmt" "net/netip" - "os" "path/filepath" "slices" "strconv" @@ -52,40 +51,6 @@ type HSDatabase struct { regCache *zcache.Cache[types.RegistrationID, types.RegisterNode] } -// loadPolicyBytes loads policy from file or database based on configuration. -// This is used during migrations when HSDatabase is not yet fully initialized. -func loadPolicyBytes(tx *gorm.DB, cfg *types.Config) ([]byte, error) { - switch cfg.Policy.Mode { - case types.PolicyModeFile: - if cfg.Policy.Path == "" { - return nil, nil - } - - absPath := util.AbsolutePathFromConfigPath(cfg.Policy.Path) - - return os.ReadFile(absPath) - - case types.PolicyModeDB: - p, err := GetPolicy(tx) - if err != nil { - if errors.Is(err, types.ErrPolicyNotFound) { - return nil, nil - } - - return nil, err - } - - if p.Data == "" { - return nil, nil - } - - return []byte(p.Data), nil - - default: - return nil, nil - } -} - // NewHeadscaleDatabase creates a new database connection and runs migrations. // It accepts the full configuration to allow migrations access to policy settings. func NewHeadscaleDatabase( @@ -628,7 +593,7 @@ AND auth_key_id NOT IN ( ID: "202601121700-migrate-hostinfo-request-tags", Migrate: func(tx *gorm.DB) error { // 1. Load policy from file or database based on configuration - policyData, err := loadPolicyBytes(tx, cfg) + policyData, err := PolicyBytes(tx, cfg) if err != nil { log.Warn().Err(err).Msg("Failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)") return nil @@ -705,7 +670,8 @@ AND auth_key_id NOT IN ( return fmt.Errorf("serializing merged tags for node %d: %w", node.ID, err) } - if err := tx.Exec("UPDATE nodes SET tags = ? WHERE id = ?", string(tagsJSON), node.ID).Error; err != nil { + err = tx.Exec("UPDATE nodes SET tags = ? WHERE id = ?", string(tagsJSON), node.ID).Error + if err != nil { return fmt.Errorf("updating tags for node %d: %w", node.ID, err) } diff --git a/hscontrol/db/policy.go b/hscontrol/db/policy.go index a874b602..bdc8af41 100644 --- a/hscontrol/db/policy.go +++ b/hscontrol/db/policy.go @@ -2,8 +2,10 @@ package db import ( "errors" + "os" "github.com/juanfont/headscale/hscontrol/types" + "github.com/juanfont/headscale/hscontrol/util" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -49,3 +51,41 @@ func GetPolicy(tx *gorm.DB) (*types.Policy, error) { return &p, nil } + +// PolicyBytes loads policy configuration from file or database based on the configured mode. +// Returns nil if no policy is configured, which is valid. +// This standalone function can be used in contexts where HSDatabase is not available, +// such as during migrations. +func PolicyBytes(tx *gorm.DB, cfg *types.Config) ([]byte, error) { + switch cfg.Policy.Mode { + case types.PolicyModeFile: + path := cfg.Policy.Path + + // It is fine to start headscale without a policy file. + if len(path) == 0 { + return nil, nil + } + + absPath := util.AbsolutePathFromConfigPath(path) + + return os.ReadFile(absPath) + + case types.PolicyModeDB: + p, err := GetPolicy(tx) + if err != nil { + if errors.Is(err, types.ErrPolicyNotFound) { + return nil, nil + } + + return nil, err + } + + if p.Data == "" { + return nil, nil + } + + return []byte(p.Data), nil + } + + return nil, nil +} diff --git a/hscontrol/state/debug.go b/hscontrol/state/debug.go index 02d674d5..3ed1d79f 100644 --- a/hscontrol/state/debug.go +++ b/hscontrol/state/debug.go @@ -5,6 +5,7 @@ import ( "strings" "time" + hsdb "github.com/juanfont/headscale/hscontrol/db" "github.com/juanfont/headscale/hscontrol/routes" "github.com/juanfont/headscale/hscontrol/types" "tailscale.com/tailcfg" @@ -228,7 +229,7 @@ func (s *State) DebugPolicy() (string, error) { return p.Data, nil case types.PolicyModeFile: - pol, err := policyBytes(s.db, s.cfg) + pol, err := hsdb.PolicyBytes(s.db.DB, s.cfg) if err != nil { return "", err } diff --git a/hscontrol/state/state.go b/hscontrol/state/state.go index 114968a2..d1401ef0 100644 --- a/hscontrol/state/state.go +++ b/hscontrol/state/state.go @@ -8,9 +8,7 @@ import ( "context" "errors" "fmt" - "io" "net/netip" - "os" "slices" "strings" "sync" @@ -142,7 +140,7 @@ func NewState(cfg *types.Config) (*State, error) { return nil, fmt.Errorf("loading users: %w", err) } - pol, err := policyBytes(db, cfg) + pol, err := hsdb.PolicyBytes(db.DB, cfg) if err != nil { return nil, fmt.Errorf("loading policy: %w", err) } @@ -198,47 +196,6 @@ func (s *State) Close() error { return nil } -// policyBytes loads policy configuration from file or database based on the configured mode. -// Returns nil if no policy is configured, which is valid. -func policyBytes(db *hsdb.HSDatabase, cfg *types.Config) ([]byte, error) { - switch cfg.Policy.Mode { - case types.PolicyModeFile: - path := cfg.Policy.Path - - // It is fine to start headscale without a policy file. - if len(path) == 0 { - return nil, nil - } - - absPath := util.AbsolutePathFromConfigPath(path) - policyFile, err := os.Open(absPath) - if err != nil { - return nil, err - } - defer policyFile.Close() - - return io.ReadAll(policyFile) - - case types.PolicyModeDB: - p, err := db.GetPolicy() - if err != nil { - if errors.Is(err, types.ErrPolicyNotFound) { - return nil, nil - } - - return nil, err - } - - if p.Data == "" { - return nil, nil - } - - return []byte(p.Data), err - } - - return nil, fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, cfg.Policy.Mode) -} - // SetDERPMap updates the DERP relay configuration. func (s *State) SetDERPMap(dm *tailcfg.DERPMap) { s.derpMap.Store(dm) @@ -252,7 +209,7 @@ func (s *State) DERPMap() tailcfg.DERPMapView { // ReloadPolicy reloads the access control policy and triggers auto-approval if changed. // Returns true if the policy changed. func (s *State) ReloadPolicy() ([]change.Change, error) { - pol, err := policyBytes(s.db, s.cfg) + pol, err := hsdb.PolicyBytes(s.db.DB, s.cfg) if err != nil { return nil, fmt.Errorf("loading policy: %w", err) }