mirror of
https://github.com/juanfont/headscale.git
synced 2026-01-23 02:24:10 +00:00
db: consolidate PolicyBytes into single function
Move the duplicated policy loading logic from state/state.go and db/db.go into a single PolicyBytes function in db/policy.go. This standalone function can be used in contexts where HSDatabase is not fully initialized (such as during migrations) by accepting a raw *gorm.DB transaction instead of requiring *HSDatabase. Updates #3006
This commit is contained in:
parent
ecb4b488ba
commit
4eeea90c5a
4 changed files with 47 additions and 83 deletions
|
|
@ -7,7 +7,6 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
|
@ -52,40 +51,6 @@ type HSDatabase struct {
|
|||
regCache *zcache.Cache[types.RegistrationID, types.RegisterNode]
|
||||
}
|
||||
|
||||
// loadPolicyBytes loads policy from file or database based on configuration.
|
||||
// This is used during migrations when HSDatabase is not yet fully initialized.
|
||||
func loadPolicyBytes(tx *gorm.DB, cfg *types.Config) ([]byte, error) {
|
||||
switch cfg.Policy.Mode {
|
||||
case types.PolicyModeFile:
|
||||
if cfg.Policy.Path == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
absPath := util.AbsolutePathFromConfigPath(cfg.Policy.Path)
|
||||
|
||||
return os.ReadFile(absPath)
|
||||
|
||||
case types.PolicyModeDB:
|
||||
p, err := GetPolicy(tx)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrPolicyNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.Data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []byte(p.Data), nil
|
||||
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewHeadscaleDatabase creates a new database connection and runs migrations.
|
||||
// It accepts the full configuration to allow migrations access to policy settings.
|
||||
func NewHeadscaleDatabase(
|
||||
|
|
@ -628,7 +593,7 @@ AND auth_key_id NOT IN (
|
|||
ID: "202601121700-migrate-hostinfo-request-tags",
|
||||
Migrate: func(tx *gorm.DB) error {
|
||||
// 1. Load policy from file or database based on configuration
|
||||
policyData, err := loadPolicyBytes(tx, cfg)
|
||||
policyData, err := PolicyBytes(tx, cfg)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to load policy, skipping RequestTags migration (tags will be validated on node reconnect)")
|
||||
return nil
|
||||
|
|
@ -705,7 +670,8 @@ AND auth_key_id NOT IN (
|
|||
return fmt.Errorf("serializing merged tags for node %d: %w", node.ID, err)
|
||||
}
|
||||
|
||||
if err := tx.Exec("UPDATE nodes SET tags = ? WHERE id = ?", string(tagsJSON), node.ID).Error; err != nil {
|
||||
err = tx.Exec("UPDATE nodes SET tags = ? WHERE id = ?", string(tagsJSON), node.ID).Error
|
||||
if err != nil {
|
||||
return fmt.Errorf("updating tags for node %d: %w", node.ID, err)
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@ package db
|
|||
|
||||
import (
|
||||
"errors"
|
||||
"os"
|
||||
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"github.com/juanfont/headscale/hscontrol/util"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
)
|
||||
|
|
@ -49,3 +51,41 @@ func GetPolicy(tx *gorm.DB) (*types.Policy, error) {
|
|||
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// PolicyBytes loads policy configuration from file or database based on the configured mode.
|
||||
// Returns nil if no policy is configured, which is valid.
|
||||
// This standalone function can be used in contexts where HSDatabase is not available,
|
||||
// such as during migrations.
|
||||
func PolicyBytes(tx *gorm.DB, cfg *types.Config) ([]byte, error) {
|
||||
switch cfg.Policy.Mode {
|
||||
case types.PolicyModeFile:
|
||||
path := cfg.Policy.Path
|
||||
|
||||
// It is fine to start headscale without a policy file.
|
||||
if len(path) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
absPath := util.AbsolutePathFromConfigPath(path)
|
||||
|
||||
return os.ReadFile(absPath)
|
||||
|
||||
case types.PolicyModeDB:
|
||||
p, err := GetPolicy(tx)
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrPolicyNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.Data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []byte(p.Data), nil
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ import (
|
|||
"strings"
|
||||
"time"
|
||||
|
||||
hsdb "github.com/juanfont/headscale/hscontrol/db"
|
||||
"github.com/juanfont/headscale/hscontrol/routes"
|
||||
"github.com/juanfont/headscale/hscontrol/types"
|
||||
"tailscale.com/tailcfg"
|
||||
|
|
@ -228,7 +229,7 @@ func (s *State) DebugPolicy() (string, error) {
|
|||
|
||||
return p.Data, nil
|
||||
case types.PolicyModeFile:
|
||||
pol, err := policyBytes(s.db, s.cfg)
|
||||
pol, err := hsdb.PolicyBytes(s.db.DB, s.cfg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,9 +8,7 @@ import (
|
|||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/netip"
|
||||
"os"
|
||||
"slices"
|
||||
"strings"
|
||||
"sync"
|
||||
|
|
@ -142,7 +140,7 @@ func NewState(cfg *types.Config) (*State, error) {
|
|||
return nil, fmt.Errorf("loading users: %w", err)
|
||||
}
|
||||
|
||||
pol, err := policyBytes(db, cfg)
|
||||
pol, err := hsdb.PolicyBytes(db.DB, cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading policy: %w", err)
|
||||
}
|
||||
|
|
@ -198,47 +196,6 @@ func (s *State) Close() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
// policyBytes loads policy configuration from file or database based on the configured mode.
|
||||
// Returns nil if no policy is configured, which is valid.
|
||||
func policyBytes(db *hsdb.HSDatabase, cfg *types.Config) ([]byte, error) {
|
||||
switch cfg.Policy.Mode {
|
||||
case types.PolicyModeFile:
|
||||
path := cfg.Policy.Path
|
||||
|
||||
// It is fine to start headscale without a policy file.
|
||||
if len(path) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
absPath := util.AbsolutePathFromConfigPath(path)
|
||||
policyFile, err := os.Open(absPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer policyFile.Close()
|
||||
|
||||
return io.ReadAll(policyFile)
|
||||
|
||||
case types.PolicyModeDB:
|
||||
p, err := db.GetPolicy()
|
||||
if err != nil {
|
||||
if errors.Is(err, types.ErrPolicyNotFound) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if p.Data == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return []byte(p.Data), err
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("%w: %s", ErrUnsupportedPolicyMode, cfg.Policy.Mode)
|
||||
}
|
||||
|
||||
// SetDERPMap updates the DERP relay configuration.
|
||||
func (s *State) SetDERPMap(dm *tailcfg.DERPMap) {
|
||||
s.derpMap.Store(dm)
|
||||
|
|
@ -252,7 +209,7 @@ func (s *State) DERPMap() tailcfg.DERPMapView {
|
|||
// ReloadPolicy reloads the access control policy and triggers auto-approval if changed.
|
||||
// Returns true if the policy changed.
|
||||
func (s *State) ReloadPolicy() ([]change.Change, error) {
|
||||
pol, err := policyBytes(s.db, s.cfg)
|
||||
pol, err := hsdb.PolicyBytes(s.db.DB, s.cfg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("loading policy: %w", err)
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue