Auth: Improve JWKS Fetch Concurrency & Timeouts #5230
Signed-off-by: Michael Mayer <michael@photoprism.app>
Parent commit: bae8ceb3a7 · This commit: 633d4222ab
8 changed files with 211 additions and 38 deletions

@@ -260,6 +260,7 @@ If anything in this file conflicts with the `Makefile` or the Developer Guide, t
 - Respect precedence: `options.yml` overrides CLI/env values, which override defaults. When adding a new option, update `internal/config/options.go` (yaml/flag tags), register it in `internal/config/flags.go`, expose a getter, surface it in `*config.Report()`, and write generated values back to `options.yml` by setting `c.options.OptionsYaml` before persisting. Use `CliTestContext` in `internal/config/test.go` to exercise new flags.
 - When touching configuration in Go code, use the public accessors on `*config.Config` (e.g. `Config.JWKSUrl()`, `Config.SetJWKSUrl()`, `Config.ClusterUUID()`) instead of mutating `Config.Options()` directly; reserve raw option tweaks for test fixtures only.
+- Logging: use the shared logger (`event.Log`) via the package-level `log` variable (see `internal/auth/jwt/logger.go`) instead of direct `fmt.Print*` or ad-hoc loggers.
 - Favor explicit CLI flags: check `c.cliCtx.IsSet("<flag>")` before overriding user-supplied values, and follow the `ClusterUUID` pattern (`options.yml` → CLI/env → generated UUIDv4 persisted); a sketch of this resolution order follows this hunk.
 - Database helpers: reuse `conf.Db()` / `conf.Database*()`, avoid GORM `WithContext`, quote MySQL identifiers, and reject unsupported drivers early.
 - Handler conventions: reuse limiter stacks (`limiter.Auth`, `limiter.Login`) and `limiter.AbortJSON` for 429s, lean on `api.ClientIP`, `header.BearerToken`, and `Abort*` helpers, compare secrets with constant-time checks, set `Cache-Control: no-store` on sensitive responses, and register routes in `internal/server/routes.go`. For new list endpoints default `count=100` (max 1000) and `offset≥0`, document parameters explicitly, and set portal mode via `PHOTOPRISM_NODE_ROLE=portal` plus `PHOTOPRISM_JOIN_TOKEN` when needed.
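
The flag-precedence bullet is easiest to see in code. Below is a minimal, self-contained sketch of the `ClusterUUID`-style resolution order; the types and helper names are simplified stand-ins for illustration, not PhotoPrism's actual API:

```go
package main

import "fmt"

// Illustration only: simplified stand-ins for PhotoPrism's real config
// plumbing (*config.Config, c.cliCtx, options.yml persistence).
type options struct{ ClusterUUID string }

type config struct {
    options options
    cliFlag string // stands in for a value the user set via CLI/env
    cliSet  bool   // stands in for c.cliCtx.IsSet("cluster-uuid")
}

// ClusterUUID resolves the value with the precedence described above:
// options.yml first, then an explicit CLI/env flag, then a generated
// value that is written back so it stays stable across restarts.
func (c *config) ClusterUUID() string {
    if c.options.ClusterUUID != "" { // already persisted in options.yml
        return c.options.ClusterUUID
    }
    if c.cliSet { // user supplied the flag explicitly
        c.options.ClusterUUID = c.cliFlag // persist for the next start
        return c.options.ClusterUUID
    }
    c.options.ClusterUUID = generateUUIDv4() // generated fallback, persisted
    return c.options.ClusterUUID
}

// generateUUIDv4 is a placeholder; PhotoPrism uses its own rnd helpers.
func generateUUIDv4() string { return "11111111-2222-4333-8444-555555555555" }

func main() {
    c := &config{}
    fmt.Println(c.ClusterUUID())
}
```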

@@ -29,6 +29,7 @@ type ClaimsSpec struct {
 	TTL time.Duration
 }
 
+// validate performs sanity checks on the claim specification before issuing a token.
 func (s ClaimsSpec) validate() error {
 	if strings.TrimSpace(s.Issuer) == "" {
 		return errors.New("jwt: issuer required")

@@ -4,10 +4,23 @@ import (
 	"os"
 	"testing"
 
+	"github.com/sirupsen/logrus"
+
+	"github.com/photoprism/photoprism/internal/config"
+	"github.com/photoprism/photoprism/internal/event"
 	"github.com/photoprism/photoprism/pkg/fs"
 )
 
 func TestMain(m *testing.M) {
+	// Init test logger.
+	log = logrus.StandardLogger()
+	log.SetLevel(logrus.TraceLevel)
+	event.AuditLog = log
+
+	c := config.TestConfig()
+	defer c.CloseDb()
+
+	// Run unit tests.
 	code := m.Run()
 
 	// Remove temporary SQLite files after running the tests.
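
Note: with this `TestMain` in place, the package-level `log` variable (bound in the new `internal/auth/jwt/logger.go` below) points at a trace-level logrus logger, so the `log.Debugf` retry messages added to `verifier.go` become visible when running the package tests with `-v`. The shared `config.TestConfig()` set up here is presumably what lets the per-test `cfg.NewTestConfig(...)` calls be dropped in the test files further down.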

internal/auth/jwt/logger.go (new file, +6)

@@ -0,0 +1,6 @@
+package jwt
+
+import "github.com/photoprism/photoprism/internal/event"
+
+// log provides package-wide logging using the shared event logger.
+var log = event.Log
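
This tiny file is the seam the updated guideline above points at: every file in the package logs through `log`, and tests can rebind it (as `TestMain` does). A standalone sketch of the same pattern, with a plain logrus logger standing in for `event.Log` so it compiles on its own:

```go
package jwtsketch

import "github.com/sirupsen/logrus"

// In PhotoPrism this variable aliases the shared event logger (event.Log);
// a logrus logger stands in here so the sketch is self-contained.
var log = logrus.StandardLogger()

// fetchSomething shows how package code logs via the shared variable
// instead of fmt.Print* or an ad-hoc logger.
func fetchSomething(url string) {
    log.Debugf("jwt: fetching %s", url)
}
```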

@@ -134,6 +134,7 @@ func (m *Manager) AllKeys() []*Key {
 	return out
 }
 
+// loadKeys reads existing key records from disk into memory.
 func (m *Manager) loadKeys() error {
 	dir := m.keyDir()
 

@@ -208,6 +209,7 @@ func (m *Manager) loadKeys() error {
 	return nil
 }
 
+// generateKey creates a fresh Ed25519 key pair, persists it, and returns a clone.
 func (m *Manager) generateKey() (*Key, error) {
 	seed := make([]byte, ed25519.SeedSize)
 	if _, err := rand.Read(seed); err != nil {

@@ -243,6 +245,7 @@ func (m *Manager) generateKey() (*Key, error) {
 	return k.clone(), nil
 }
 
+// persistKey writes the private and public key records to disk using secure permissions.
 func (m *Manager) persistKey(k *Key) error {
 	dir := m.keyDir()
 	if err := fs.MkdirAll(dir); err != nil {
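
For context on `generateKey` above: it fills an `ed25519.SeedSize` (32-byte) buffer from `crypto/rand` and derives the key pair from that seed. A self-contained sketch of that standard-library flow, with PhotoPrism's persistence step omitted:

```go
package main

import (
    "crypto/ed25519"
    "crypto/rand"
    "fmt"
)

func main() {
    // Read a 32-byte seed from the OS CSPRNG, as generateKey does.
    seed := make([]byte, ed25519.SeedSize)
    if _, err := rand.Read(seed); err != nil {
        panic(err)
    }

    // The private key is derived deterministically from the seed,
    // so persisting the seed is enough to restore the key pair.
    priv := ed25519.NewKeyFromSeed(seed)
    pub := priv.Public().(ed25519.PublicKey)
    fmt.Printf("public key: %x\n", pub)
}
```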

@@ -14,7 +14,7 @@ import (
 )
 
 func TestManagerEnsureActiveKey(t *testing.T) {
-	c := cfg.NewTestConfig("jwt-manager-active")
+	c := cfg.TestConfig()
 	m, err := NewManager(c)
 	require.NoError(t, err)
 	require.NotNil(t, m)

@@ -53,7 +53,7 @@ func TestManagerEnsureActiveKey(t *testing.T) {
 }
 
 func TestManagerGenerateSecondKey(t *testing.T) {
-	c := cfg.NewTestConfig("jwt-manager-rotate")
+	c := cfg.TestConfig()
 	m, err := NewManager(c)
 	require.NoError(t, err)
 

@@ -7,6 +7,7 @@ import (
 	"encoding/json"
 	"errors"
 	"fmt"
+	"math/rand"
 	"net/http"
 	"os"
 	"path/filepath"

@@ -24,6 +25,19 @@ var (
 	errKeyNotFound = errors.New("jwt: key not found")
 )
 
+const (
+	// jwksFetchMaxRetries caps the number of immediate retry attempts after a fetch error.
+	jwksFetchMaxRetries = 3
+	// jwksFetchBaseDelay is the initial retry delay (with jitter) applied after the first failure.
+	jwksFetchBaseDelay = 200 * time.Millisecond
+	// jwksFetchMaxDelay is the upper bound for retry delays to prevent unbounded backoff.
+	jwksFetchMaxDelay = 2 * time.Second
+)
+
+// randInt63n is defined for deterministic testing of jitter (overridable in tests).
+var randInt63n = rand.Int63n
+
+// cacheEntry stores the JWKS material cached on disk and in memory.
 type cacheEntry struct {
 	URL  string `json:"url"`
 	ETag string `json:"etag,omitempty"`
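
Taken together, these constants bound the retry schedule: the pre-jitter delay for attempt n is min(200 ms · 2^(n−1), 2 s), and jitter adds at most half of that, so attempts 1 and 2 sleep at most 300 ms and 600 ms. Since `jwksFetchMaxRetries` is 3, a persistently failing endpoint costs at most two sleeps (≈900 ms worst case) before the error is surfaced to the caller.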

@@ -149,6 +163,7 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
 	return claims, nil
 }
 
+// publicKeyForKid resolves the public key for the given key ID, fetching JWKS data if needed.
 func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force bool) (ed25519.PublicKey, error) {
 	keys, err := v.keysForURL(ctx, url, force)
 	if err != nil {

@@ -172,71 +187,155 @@ func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force b
 	return nil, errKeyNotFound
 }
 
 // keysForURL returns JWKS keys for the specified endpoint, reusing cache when possible.
 func (v *Verifier) keysForURL(ctx context.Context, url string, force bool) ([]PublicJWK, error) {
-	v.mu.Lock()
-	defer v.mu.Unlock()
-
 	ttl := 300 * time.Second
 	if v.conf != nil && v.conf.JWKSCacheTTL() > 0 {
 		ttl = time.Duration(v.conf.JWKSCacheTTL()) * time.Second
 	}
 
-	if !force && v.cache.URL == url && len(v.cache.Keys) > 0 {
-		age := v.now().Unix() - v.cache.FetchedAt
-		if age >= 0 && time.Duration(age)*time.Second <= ttl {
-			return append([]PublicJWK(nil), v.cache.Keys...), nil
-		}
-	}
+	attempts := 0
+
+	for {
+		cached := v.snapshotCache()
+
+		if keys, ok := v.cachedKeys(url, ttl, cached, force); ok {
+			return keys, nil
+		}
+
+		etag := ""
+		if !force && cached.URL == url {
+			etag = cached.ETag
+		}
+
+		result, err := v.fetchJWKS(ctx, url, etag)
+		if err != nil {
+			if !force && cached.URL == url && len(cached.Keys) > 0 {
+				return append([]PublicJWK(nil), cached.Keys...), nil
+			}
+
+			attempts++
+			if attempts >= jwksFetchMaxRetries {
+				return nil, err
+			}
+
+			delay := backoffDuration(attempts)
+			log.Debugf("jwt: jwks fetch retry %d for %s in %s (%s)", attempts, url, delay, err)
+
+			select {
+			case <-time.After(delay):
+				continue
+			case <-ctx.Done():
+				return nil, ctx.Err()
+			}
+		}
+
+		if keys, ok := v.updateCache(url, result); ok {
+			return keys, nil
+		}
+		// Cache changed by another goroutine between snapshot and update; retry.
+	}
+}
+
+// snapshotCache returns the current JWKS cache entry under lock for safe reading.
+func (v *Verifier) snapshotCache() cacheEntry {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+	cache := v.cache
+	return cache
+}
+
+// cachedKeys returns cached JWKS keys if they are fresh enough and match the target URL.
+func (v *Verifier) cachedKeys(url string, ttl time.Duration, cache cacheEntry, force bool) ([]PublicJWK, bool) {
+	if force || cache.URL != url || len(cache.Keys) == 0 {
+		return nil, false
+	}
+	age := v.now().Unix() - cache.FetchedAt
+	if age < 0 {
+		return nil, false
+	}
+	if time.Duration(age)*time.Second > ttl {
+		return nil, false
+	}
+	return append([]PublicJWK(nil), cache.Keys...), true
+}
+
+type jwksFetchResult struct {
+	keys        []PublicJWK
+	etag        string
+	fetchedAt   int64
+	notModified bool
+}
 
 // fetchJWKS downloads the JWKS document (respecting conditional requests) and returns the parsed keys.
+func (v *Verifier) fetchJWKS(ctx context.Context, url, etag string) (*jwksFetchResult, error) {
 	req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
 	if err != nil {
 		return nil, err
 	}
-	if v.cache.URL == url && v.cache.ETag != "" {
-		req.Header.Set("If-None-Match", v.cache.ETag)
+	if etag != "" {
+		req.Header.Set("If-None-Match", etag)
 	}
 
 	resp, err := v.httpClient.Do(req)
 	if err != nil {
-		if v.cache.URL == url && len(v.cache.Keys) > 0 {
-			return append([]PublicJWK(nil), v.cache.Keys...), nil
-		}
 		return nil, err
 	}
 	defer resp.Body.Close()
 
-	if resp.StatusCode == http.StatusNotModified {
-		v.cache.FetchedAt = v.now().Unix()
-		_ = v.saveCacheLocked()
-		return append([]PublicJWK(nil), v.cache.Keys...), nil
-	}
-
-	if resp.StatusCode != http.StatusOK {
-		if v.cache.URL == url && len(v.cache.Keys) > 0 {
-			return append([]PublicJWK(nil), v.cache.Keys...), nil
+	switch resp.StatusCode {
+	case http.StatusNotModified:
+		return &jwksFetchResult{
+			etag:        etag,
+			fetchedAt:   v.now().Unix(),
+			notModified: true,
+		}, nil
+	case http.StatusOK:
+		var body JWKS
+		if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
+			return nil, err
+		}
+		if len(body.Keys) == 0 {
+			return nil, errors.New("jwt: jwks contains no keys")
+		}
+		return &jwksFetchResult{
+			keys:      append([]PublicJWK(nil), body.Keys...),
+			etag:      resp.Header.Get("ETag"),
+			fetchedAt: v.now().Unix(),
+		}, nil
+	default:
+		return nil, fmt.Errorf("jwt: jwks fetch failed: %s", resp.Status)
 	}
 }
 
-	var body JWKS
-	if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
-		return nil, err
-	}
-	if len(body.Keys) == 0 {
-		return nil, errors.New("jwt: jwks contains no keys")
+// updateCache stores the JWKS fetch result on success and returns the fresh keys.
+func (v *Verifier) updateCache(url string, result *jwksFetchResult) ([]PublicJWK, bool) {
+	v.mu.Lock()
+	defer v.mu.Unlock()
+
+	if result.notModified {
+		if v.cache.URL != url {
+			return nil, false
+		}
+		v.cache.FetchedAt = result.fetchedAt
+		if result.etag != "" {
+			v.cache.ETag = result.etag
+		}
+		_ = v.saveCacheLocked()
+		return append([]PublicJWK(nil), v.cache.Keys...), true
+	}
 
 	v.cache = cacheEntry{
 		URL:       url,
-		ETag:      resp.Header.Get("ETag"),
-		Keys:      append([]PublicJWK(nil), body.Keys...),
-		FetchedAt: v.now().Unix(),
+		ETag:      result.etag,
+		Keys:      append([]PublicJWK(nil), result.keys...),
+		FetchedAt: result.fetchedAt,
 	}
 	_ = v.saveCacheLocked()
 
-	return append([]PublicJWK(nil), body.Keys...), nil
+	return append([]PublicJWK(nil), v.cache.Keys...), true
 }
 
+// loadCache restores a previously persisted JWKS cache entry from disk.
 func (v *Verifier) loadCache() error {
 	if v.cachePath == "" || !fs.FileExists(v.cachePath) {
 		return nil
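
The heart of this rewrite is that the HTTP fetch no longer runs while `v.mu` is held: `keysForURL` snapshots the cache under the lock, performs the slow fetch unlocked, and `updateCache` re-acquires the lock and re-checks state (the 304 path verifies the cached URL still matches) before committing, looping if the cache moved underneath it. Holding the mutex across `httpClient.Do` would have serialized every verification behind a single slow endpoint. A minimal, self-contained sketch of that snapshot/re-check pattern with generic names, not the PhotoPrism types:

```go
package main

import (
    "fmt"
    "sync"
)

type store struct {
    mu  sync.Mutex
    val string
}

// snapshot copies the state under the lock, like Verifier.snapshotCache.
func (s *store) snapshot() string {
    s.mu.Lock()
    defer s.mu.Unlock()
    return s.val
}

// commitIfUnchanged re-checks under the lock before writing, like
// Verifier.updateCache; callers loop on failure.
func (s *store) commitIfUnchanged(seen, next string) bool {
    s.mu.Lock()
    defer s.mu.Unlock()
    if s.val != seen {
        return false // another goroutine won; caller re-evaluates
    }
    s.val = next
    return true
}

func main() {
    s := &store{val: "stale"}
    for {
        seen := s.snapshot()
        fresh := seen + "-refreshed" // slow work happens with no lock held
        if s.commitIfUnchanged(seen, fresh) {
            break
        }
    }
    fmt.Println(s.val)
}
```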

@@ -256,6 +355,7 @@ func (v *Verifier) loadCache() error {
 	return nil
 }
 
+// saveCacheLocked persists the current cache entry to disk; caller must hold the mutex.
 func (v *Verifier) saveCacheLocked() error {
 	if v.cachePath == "" {
 		return nil

@@ -269,3 +369,22 @@ func (v *Verifier) saveCacheLocked() error {
 	}
 	return os.WriteFile(v.cachePath, data, fs.ModeSecretFile)
 }
+
+// backoffDuration returns the retry delay for the given fetch attempt, adding jitter.
+func backoffDuration(attempt int) time.Duration {
+	if attempt < 1 {
+		attempt = 1
+	}
+
+	base := jwksFetchBaseDelay << (attempt - 1)
+	if base > jwksFetchMaxDelay {
+		base = jwksFetchMaxDelay
+	}
+
+	jitterRange := base / 2
+	if jitterRange > 0 {
+		base += time.Duration(randInt63n(int64(jitterRange) + 1))
+	}
+
+	return base
+}
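
Design note: the jitter term draws uniformly from [0, base/2] so that many workers whose caches expire together do not retry in lockstep, and `randInt63n` is a package variable precisely so `TestBackoffDuration` below can pin it and assert exact delays.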

@@ -18,7 +18,7 @@ import (
 )
 
 func TestVerifierPrimeAndVerify(t *testing.T) {
-	portalCfg := cfg.NewTestConfig("jwt-verifier-portal")
+	portalCfg := cfg.TestConfig()
 	clusterUUID := rnd.UUIDv7()
 	portalCfg.Options().ClusterUUID = clusterUUID
 

@@ -105,7 +105,7 @@ func TestVerifierPrimeAndVerify(t *testing.T) {
 }
 
 func TestIssuerClampTTL(t *testing.T) {
-	portalCfg := cfg.NewTestConfig("jwt-issuer-ttl")
+	portalCfg := cfg.TestConfig()
 	mgr, err := NewManager(portalCfg)
 	require.NoError(t, err)
 	mgr.now = func() time.Time { return time.Unix(0, 0) }

@@ -136,3 +136,33 @@ func TestIssuerClampTTL(t *testing.T) {
 	ttl := parsed.ExpiresAt.Time.Sub(parsed.IssuedAt.Time)
 	require.Equal(t, MaxTokenTTL, ttl)
 }
+
+func TestBackoffDuration(t *testing.T) {
+	origRand := randInt63n
+	randInt63n = func(n int64) int64 {
+		if n <= 0 {
+			return 0
+		}
+		return n - 1
+	}
+	t.Cleanup(func() { randInt63n = origRand })
+
+	tests := []struct {
+		name    string
+		attempt int
+		expect  time.Duration
+	}{
+		{"Attempt1", 1, 300 * time.Millisecond},
+		{"Attempt2", 2, 600 * time.Millisecond},
+		{"Attempt3", 3, 1200 * time.Millisecond},
+		{"Attempt4", 4, 2400 * time.Millisecond},
+		{"Attempt5", 5, 3 * time.Second},
+		{"AttemptZero", 0, 300 * time.Millisecond},
+	}
+
+	for _, tt := range tests {
+		if got := backoffDuration(tt.attempt); got != tt.expect {
+			t.Errorf("%s: expected %s, got %s", tt.name, tt.expect, got)
+		}
+	}
+}
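
The stub returns n−1, the maximum value `rand.Int63n(n)` can produce, so every expected delay is exactly 1.5× the capped base: 300/600/1200/2400 ms for attempts 1-4, and 3 s once the 2 s cap applies (2 s base + 1 s max jitter). The `AttemptZero` case confirms that out-of-range attempts are clamped to the first delay.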