headscale/hscontrol/db/api_key.go
Kristoffer Dalby b36438bf90 all: disable thelper linter and clean up nolint comments
Disable the thelper linter which triggers on inline test closures in
table-driven tests. These closures are intentionally not standalone
helpers and don't benefit from t.Helper().

Also remove explanatory comments from nolint directives throughout the
codebase as they add noise without providing significant value.
2026-01-21 15:56:57 +00:00

297 lines
8.1 KiB
Go

package db
import (
"errors"
"fmt"
"strings"
"time"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"golang.org/x/crypto/bcrypt"
"gorm.io/gorm"
)
const (
apiKeyPrefix = "hskey-api-" //nolint:gosec
apiKeyPrefixLength = 12
apiKeyHashLength = 64
// Legacy format constants.
legacyAPIPrefixLength = 7
legacyAPIKeyLength = 32
)
var (
ErrAPIKeyFailedToParse = errors.New("failed to parse ApiKey")
ErrAPIKeyGenerationFailed = errors.New("failed to generate API key")
ErrAPIKeyInvalidGeneration = errors.New("generated API key failed validation")
)
// CreateAPIKey creates a new ApiKey in a user, and returns it.
func (hsdb *HSDatabase) CreateAPIKey(
expiration *time.Time,
) (string, *types.APIKey, error) {
// Generate public prefix (12 chars)
prefix, err := util.GenerateRandomStringURLSafe(apiKeyPrefixLength)
if err != nil {
return "", nil, err
}
// Validate prefix
if len(prefix) != apiKeyPrefixLength {
return "", nil, fmt.Errorf("%w: generated prefix has invalid length: expected %d, got %d", ErrAPIKeyInvalidGeneration, apiKeyPrefixLength, len(prefix))
}
if !isValidBase64URLSafe(prefix) {
return "", nil, fmt.Errorf("%w: generated prefix contains invalid characters", ErrAPIKeyInvalidGeneration)
}
// Generate secret (64 chars)
secret, err := util.GenerateRandomStringURLSafe(apiKeyHashLength)
if err != nil {
return "", nil, err
}
// Validate secret
if len(secret) != apiKeyHashLength {
return "", nil, fmt.Errorf("%w: generated secret has invalid length: expected %d, got %d", ErrAPIKeyInvalidGeneration, apiKeyHashLength, len(secret))
}
if !isValidBase64URLSafe(secret) {
return "", nil, fmt.Errorf("%w: generated secret contains invalid characters", ErrAPIKeyInvalidGeneration)
}
// Full key string (shown ONCE to user)
keyStr := apiKeyPrefix + prefix + "-" + secret
// bcrypt hash of secret
hash, err := bcrypt.GenerateFromPassword([]byte(secret), bcrypt.DefaultCost)
if err != nil {
return "", nil, err
}
key := types.APIKey{
Prefix: prefix,
Hash: hash,
Expiration: expiration,
}
if err := hsdb.DB.Save(&key).Error; err != nil {
return "", nil, fmt.Errorf("failed to save API key to database: %w", err)
}
return keyStr, &key, nil
}
// ListAPIKeys returns the list of ApiKeys for a user.
func (hsdb *HSDatabase) ListAPIKeys() ([]types.APIKey, error) {
keys := []types.APIKey{}
if err := hsdb.DB.Find(&keys).Error; err != nil {
return nil, err
}
return keys, nil
}
// GetAPIKey returns a ApiKey for a given key.
func (hsdb *HSDatabase) GetAPIKey(prefix string) (*types.APIKey, error) {
key := types.APIKey{}
if result := hsdb.DB.First(&key, "prefix = ?", prefix); result.Error != nil {
return nil, result.Error
}
return &key, nil
}
// GetAPIKeyByID returns a ApiKey for a given id.
func (hsdb *HSDatabase) GetAPIKeyByID(id uint64) (*types.APIKey, error) {
key := types.APIKey{}
if result := hsdb.DB.Find(&types.APIKey{ID: id}).First(&key); result.Error != nil {
return nil, result.Error
}
return &key, nil
}
// DestroyAPIKey destroys a ApiKey. Returns error if the ApiKey
// does not exist.
func (hsdb *HSDatabase) DestroyAPIKey(key types.APIKey) error {
if result := hsdb.DB.Unscoped().Delete(key); result.Error != nil {
return result.Error
}
return nil
}
// ExpireAPIKey marks a ApiKey as expired.
func (hsdb *HSDatabase) ExpireAPIKey(key *types.APIKey) error {
if err := hsdb.DB.Model(&key).Update("Expiration", time.Now()).Error; err != nil {
return err
}
return nil
}
func (hsdb *HSDatabase) ValidateAPIKey(keyStr string) (bool, error) {
key, err := validateAPIKey(hsdb.DB, keyStr)
if err != nil {
return false, err
}
if key.Expiration != nil && key.Expiration.Before(time.Now()) {
return false, nil
}
return true, nil
}
// ParseAPIKeyPrefix extracts the database prefix from a display prefix.
// Handles formats: "hskey-api-{12chars}-***", "hskey-api-{12chars}", or just "{12chars}".
// Returns the 12-character prefix suitable for database lookup.
func ParseAPIKeyPrefix(displayPrefix string) (string, error) {
// If it's already just the 12-character prefix, return it
if len(displayPrefix) == apiKeyPrefixLength && isValidBase64URLSafe(displayPrefix) {
return displayPrefix, nil
}
// If it starts with the API key prefix, parse it
if strings.HasPrefix(displayPrefix, apiKeyPrefix) {
// Remove the "hskey-api-" prefix
_, remainder, found := strings.Cut(displayPrefix, apiKeyPrefix)
if !found {
return "", fmt.Errorf("%w: invalid display prefix format", ErrAPIKeyFailedToParse)
}
// Extract just the first 12 characters (the actual prefix)
if len(remainder) < apiKeyPrefixLength {
return "", fmt.Errorf("%w: prefix too short", ErrAPIKeyFailedToParse)
}
prefix := remainder[:apiKeyPrefixLength]
// Validate it's base64 URL-safe
if !isValidBase64URLSafe(prefix) {
return "", fmt.Errorf("%w: prefix contains invalid characters", ErrAPIKeyFailedToParse)
}
return prefix, nil
}
// For legacy 7-character prefixes or other formats, return as-is
return displayPrefix, nil
}
// validateAPIKey validates an API key and returns the key if valid.
// Handles both new (hskey-api-{prefix}-{secret}) and legacy (prefix.secret) formats.
func validateAPIKey(db *gorm.DB, keyStr string) (*types.APIKey, error) {
// Validate input is not empty
if keyStr == "" {
return nil, ErrAPIKeyFailedToParse
}
// Check for new format: hskey-api-{prefix}-{secret}
_, prefixAndSecret, found := strings.Cut(keyStr, apiKeyPrefix)
if !found {
// Legacy format: prefix.secret
return validateLegacyAPIKey(db, keyStr)
}
// New format: parse and verify
const expectedMinLength = apiKeyPrefixLength + 1 + apiKeyHashLength
if len(prefixAndSecret) < expectedMinLength {
return nil, fmt.Errorf(
"%w: key too short, expected at least %d chars after prefix, got %d",
ErrAPIKeyFailedToParse,
expectedMinLength,
len(prefixAndSecret),
)
}
// Use fixed-length parsing
prefix := prefixAndSecret[:apiKeyPrefixLength]
// Validate separator at expected position
if prefixAndSecret[apiKeyPrefixLength] != '-' {
return nil, fmt.Errorf(
"%w: expected separator '-' at position %d, got '%c'",
ErrAPIKeyFailedToParse,
apiKeyPrefixLength,
prefixAndSecret[apiKeyPrefixLength],
)
}
secret := prefixAndSecret[apiKeyPrefixLength+1:]
// Validate secret length
if len(secret) != apiKeyHashLength {
return nil, fmt.Errorf(
"%w: secret length mismatch, expected %d chars, got %d",
ErrAPIKeyFailedToParse,
apiKeyHashLength,
len(secret),
)
}
// Validate prefix contains only base64 URL-safe characters
if !isValidBase64URLSafe(prefix) {
return nil, fmt.Errorf(
"%w: prefix contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
ErrAPIKeyFailedToParse,
)
}
// Validate secret contains only base64 URL-safe characters
if !isValidBase64URLSafe(secret) {
return nil, fmt.Errorf(
"%w: secret contains invalid characters (expected base64 URL-safe: A-Za-z0-9_-)",
ErrAPIKeyFailedToParse,
)
}
// Look up by prefix (indexed)
var key types.APIKey
err := db.First(&key, "prefix = ?", prefix).Error
if err != nil {
return nil, fmt.Errorf("API key not found: %w", err)
}
// Verify bcrypt hash
err = bcrypt.CompareHashAndPassword(key.Hash, []byte(secret))
if err != nil {
return nil, fmt.Errorf("invalid API key: %w", err)
}
return &key, nil
}
// validateLegacyAPIKey validates a legacy format API key (prefix.secret).
func validateLegacyAPIKey(db *gorm.DB, keyStr string) (*types.APIKey, error) {
// Legacy format uses "." as separator
prefix, secret, found := strings.Cut(keyStr, ".")
if !found {
return nil, ErrAPIKeyFailedToParse
}
// Legacy prefix is 7 chars
if len(prefix) != legacyAPIPrefixLength {
return nil, fmt.Errorf("%w: legacy prefix length mismatch", ErrAPIKeyFailedToParse)
}
var key types.APIKey
err := db.First(&key, "prefix = ?", prefix).Error
if err != nil {
return nil, fmt.Errorf("API key not found: %w", err)
}
// Verify bcrypt (key.Hash stores bcrypt of full secret)
err = bcrypt.CompareHashAndPassword(key.Hash, []byte(secret))
if err != nil {
return nil, fmt.Errorf("invalid API key: %w", err)
}
return &key, nil
}