Auth: Improve JWKS Fetch Concurrency & Timeouts #5230

Signed-off-by: Michael Mayer <michael@photoprism.app>
Michael Mayer 2025-09-25 18:46:24 +02:00
parent bae8ceb3a7
commit 633d4222ab
8 changed files with 211 additions and 38 deletions


@@ -260,6 +260,7 @@ If anything in this file conflicts with the `Makefile` or the Developer Guide, t
- Respect precedence: `options.yml` overrides CLI/env values, which override defaults. When adding a new option, update `internal/config/options.go` (yaml/flag tags), register it in `internal/config/flags.go`, expose a getter, surface it in `*config.Report()`, and write generated values back to `options.yml` by setting `c.options.OptionsYaml` before persisting. Use `CliTestContext` in `internal/config/test.go` to exercise new flags.
- When touching configuration in Go code, use the public accessors on `*config.Config` (e.g. `Config.JWKSUrl()`, `Config.SetJWKSUrl()`, `Config.ClusterUUID()`) instead of mutating `Config.Options()` directly; reserve raw option tweaks for test fixtures only.
- Logging: use the shared logger (`event.Log`) via the package-level `log` variable (see `internal/auth/jwt/logger.go`) instead of direct `fmt.Print*` or ad-hoc loggers.
- Favor explicit CLI flags: check `c.cliCtx.IsSet("<flag>")` before overriding user-supplied values, and follow the `ClusterUUID` pattern (`options.yml` → CLI/env → generated UUIDv4 persisted); a precedence sketch follows this list.
- Database helpers: reuse `conf.Db()` / `conf.Database*()`, avoid GORM `WithContext`, quote MySQL identifiers, and reject unsupported drivers early.
- Handler conventions: reuse limiter stacks (`limiter.Auth`, `limiter.Login`) and `limiter.AbortJSON` for 429s, lean on `api.ClientIP`, `header.BearerToken`, and `Abort*` helpers, compare secrets with constant-time checks, set `Cache-Control: no-store` on sensitive responses, and register routes in `internal/server/routes.go`. For new list endpoints, default to `count=100` (max 1000) and `offset≥0`, document parameters explicitly, and set portal mode via `PHOTOPRISM_NODE_ROLE=portal` plus `PHOTOPRISM_JOIN_TOKEN` when needed.
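
As a concrete illustration of the flag-precedence bullet above, here is a minimal, self-contained sketch for a hypothetical `ExampleToken` option. The type, field, and option names are invented for illustration, as are the `urfave/cli/v2` and `google/uuid` imports; only the `IsSet`/`String` checks on the CLI context and the documented precedence order reflect the real pattern.

```go
package example

import (
	"github.com/google/uuid"
	"github.com/urfave/cli/v2"
)

// options stands in for the values loaded from options.yml (highest precedence).
type options struct {
	ExampleToken string
}

// config is a simplified stand-in for *config.Config with the fields named above.
type config struct {
	options options
	cliCtx  *cli.Context
}

// ExampleToken resolves a hypothetical option using the documented precedence:
// options.yml > explicitly set CLI/env flag > generated default.
func (c *config) ExampleToken() string {
	if c.options.ExampleToken != "" {
		return c.options.ExampleToken // value from options.yml wins outright
	}
	if c.cliCtx != nil && c.cliCtx.IsSet("example-token") {
		return c.cliCtx.String("example-token") // user-supplied flag or env value
	}
	// Otherwise generate a value once and remember it, mirroring the ClusterUUID
	// pattern; the real accessor would also persist it back to options.yml.
	c.options.ExampleToken = uuid.NewString()
	return c.options.ExampleToken
}
```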


@@ -29,6 +29,7 @@ type ClaimsSpec struct {
TTL time.Duration
}
// validate performs sanity checks on the claim specification before issuing a token.
func (s ClaimsSpec) validate() error {
if strings.TrimSpace(s.Issuer) == "" {
return errors.New("jwt: issuer required")


@@ -4,10 +4,23 @@ import (
"os"
"testing"
"github.com/sirupsen/logrus"
"github.com/photoprism/photoprism/internal/config"
"github.com/photoprism/photoprism/internal/event"
"github.com/photoprism/photoprism/pkg/fs"
)
func TestMain(m *testing.M) {
// Init test logger.
log = logrus.StandardLogger()
log.SetLevel(logrus.TraceLevel)
event.AuditLog = log
c := config.TestConfig()
defer c.CloseDb()
// Run unit tests.
code := m.Run()
// Remove temporary SQLite files after running the tests.


@@ -0,0 +1,6 @@
package jwt
import "github.com/photoprism/photoprism/internal/event"
// log provides package-wide logging using the shared event logger.
var log = event.Log


@@ -134,6 +134,7 @@ func (m *Manager) AllKeys() []*Key {
return out
}
// loadKeys reads existing key records from disk into memory.
func (m *Manager) loadKeys() error {
dir := m.keyDir()
@@ -208,6 +209,7 @@ func (m *Manager) loadKeys() error {
return nil
}
// generateKey creates a fresh Ed25519 key pair, persists it, and returns a clone.
func (m *Manager) generateKey() (*Key, error) {
seed := make([]byte, ed25519.SeedSize)
if _, err := rand.Read(seed); err != nil {
@@ -243,6 +245,7 @@ func (m *Manager) generateKey() (*Key, error) {
return k.clone(), nil
}
// persistKey writes the private and public key records to disk using secure permissions.
func (m *Manager) persistKey(k *Key) error {
dir := m.keyDir()
if err := fs.MkdirAll(dir); err != nil {


@@ -14,7 +14,7 @@ import (
)
func TestManagerEnsureActiveKey(t *testing.T) {
c := cfg.NewTestConfig("jwt-manager-active")
c := cfg.TestConfig()
m, err := NewManager(c)
require.NoError(t, err)
require.NotNil(t, m)
@@ -53,7 +53,7 @@ func TestManagerEnsureActiveKey(t *testing.T) {
}
func TestManagerGenerateSecondKey(t *testing.T) {
c := cfg.NewTestConfig("jwt-manager-rotate")
c := cfg.TestConfig()
m, err := NewManager(c)
require.NoError(t, err)


@@ -7,6 +7,7 @@ import (
"encoding/json"
"errors"
"fmt"
"math/rand"
"net/http"
"os"
"path/filepath"
@@ -24,6 +25,19 @@ var (
errKeyNotFound = errors.New("jwt: key not found")
)
const (
// jwksFetchMaxRetries caps the number of immediate retry attempts after a fetch error.
jwksFetchMaxRetries = 3
// jwksFetchBaseDelay is the initial retry delay (with jitter) applied after the first failure.
jwksFetchBaseDelay = 200 * time.Millisecond
// jwksFetchMaxDelay is the upper bound for retry delays to prevent unbounded backoff.
jwksFetchMaxDelay = 2 * time.Second
)
// randInt63n wraps rand.Int63n so tests can override it for deterministic jitter.
var randInt63n = rand.Int63n
// cacheEntry stores the JWKS material cached on disk and in memory.
type cacheEntry struct {
URL string `json:"url"`
ETag string `json:"etag,omitempty"`
@@ -149,6 +163,7 @@ func (v *Verifier) VerifyToken(ctx context.Context, tokenString string, expected
return claims, nil
}
// publicKeyForKid resolves the public key for the given key ID, fetching JWKS data if needed.
func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force bool) (ed25519.PublicKey, error) {
keys, err := v.keysForURL(ctx, url, force)
if err != nil {
@@ -172,71 +187,155 @@ func (v *Verifier) publicKeyForKid(ctx context.Context, url, kid string, force b
return nil, errKeyNotFound
}
// keysForURL returns JWKS keys for the specified endpoint, reusing cache when possible.
func (v *Verifier) keysForURL(ctx context.Context, url string, force bool) ([]PublicJWK, error) {
v.mu.Lock()
defer v.mu.Unlock()
ttl := 300 * time.Second
if v.conf != nil && v.conf.JWKSCacheTTL() > 0 {
ttl = time.Duration(v.conf.JWKSCacheTTL()) * time.Second
}
if !force && v.cache.URL == url && len(v.cache.Keys) > 0 {
age := v.now().Unix() - v.cache.FetchedAt
if age >= 0 && time.Duration(age)*time.Second <= ttl {
return append([]PublicJWK(nil), v.cache.Keys...), nil
}
}
attempts := 0
for {
cached := v.snapshotCache()
if keys, ok := v.cachedKeys(url, ttl, cached, force); ok {
return keys, nil
}
etag := ""
if !force && cached.URL == url {
etag = cached.ETag
}
result, err := v.fetchJWKS(ctx, url, etag)
if err != nil {
if !force && cached.URL == url && len(cached.Keys) > 0 {
return append([]PublicJWK(nil), cached.Keys...), nil
}
attempts++
if attempts >= jwksFetchMaxRetries {
return nil, err
}
delay := backoffDuration(attempts)
log.Debugf("jwt: jwks fetch retry %d for %s in %s (%s)", attempts, url, delay, err)
select {
case <-time.After(delay):
continue
case <-ctx.Done():
return nil, ctx.Err()
}
}
if keys, ok := v.updateCache(url, result); ok {
return keys, nil
}
// Cache changed by another goroutine between snapshot and update; retry.
}
}
// snapshotCache returns the current JWKS cache entry under lock for safe reading.
func (v *Verifier) snapshotCache() cacheEntry {
v.mu.Lock()
defer v.mu.Unlock()
cache := v.cache
return cache
}
// cachedKeys returns cached JWKS keys if they are fresh enough and match the target URL.
func (v *Verifier) cachedKeys(url string, ttl time.Duration, cache cacheEntry, force bool) ([]PublicJWK, bool) {
if force || cache.URL != url || len(cache.Keys) == 0 {
return nil, false
}
age := v.now().Unix() - cache.FetchedAt
if age < 0 {
return nil, false
}
if time.Duration(age)*time.Second > ttl {
return nil, false
}
return append([]PublicJWK(nil), cache.Keys...), true
}
type jwksFetchResult struct {
keys []PublicJWK
etag string
fetchedAt int64
notModified bool
}
// fetchJWKS downloads the JWKS document (respecting conditional requests) and returns the parsed keys.
func (v *Verifier) fetchJWKS(ctx context.Context, url, etag string) (*jwksFetchResult, error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
if err != nil {
return nil, err
}
if v.cache.URL == url && v.cache.ETag != "" {
req.Header.Set("If-None-Match", v.cache.ETag)
if etag != "" {
req.Header.Set("If-None-Match", etag)
}
resp, err := v.httpClient.Do(req)
if err != nil {
if v.cache.URL == url && len(v.cache.Keys) > 0 {
return append([]PublicJWK(nil), v.cache.Keys...), nil
}
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode == http.StatusNotModified {
v.cache.FetchedAt = v.now().Unix()
_ = v.saveCacheLocked()
return append([]PublicJWK(nil), v.cache.Keys...), nil
}
if resp.StatusCode != http.StatusOK {
if v.cache.URL == url && len(v.cache.Keys) > 0 {
return append([]PublicJWK(nil), v.cache.Keys...), nil
switch resp.StatusCode {
case http.StatusNotModified:
return &jwksFetchResult{
etag: etag,
fetchedAt: v.now().Unix(),
notModified: true,
}, nil
case http.StatusOK:
var body JWKS
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
return nil, err
}
if len(body.Keys) == 0 {
return nil, errors.New("jwt: jwks contains no keys")
}
return &jwksFetchResult{
keys: append([]PublicJWK(nil), body.Keys...),
etag: resp.Header.Get("ETag"),
fetchedAt: v.now().Unix(),
}, nil
default:
return nil, fmt.Errorf("jwt: jwks fetch failed: %s", resp.Status)
}
}
var body JWKS
if err := json.NewDecoder(resp.Body).Decode(&body); err != nil {
return nil, err
}
if len(body.Keys) == 0 {
return nil, errors.New("jwt: jwks contains no keys")
// updateCache stores the JWKS fetch result on success and returns the fresh keys.
func (v *Verifier) updateCache(url string, result *jwksFetchResult) ([]PublicJWK, bool) {
v.mu.Lock()
defer v.mu.Unlock()
if result.notModified {
if v.cache.URL != url {
return nil, false
}
v.cache.FetchedAt = result.fetchedAt
if result.etag != "" {
v.cache.ETag = result.etag
}
_ = v.saveCacheLocked()
return append([]PublicJWK(nil), v.cache.Keys...), true
}
v.cache = cacheEntry{
URL: url,
ETag: resp.Header.Get("ETag"),
Keys: append([]PublicJWK(nil), body.Keys...),
FetchedAt: v.now().Unix(),
ETag: result.etag,
Keys: append([]PublicJWK(nil), result.keys...),
FetchedAt: result.fetchedAt,
}
_ = v.saveCacheLocked()
return append([]PublicJWK(nil), body.Keys...), nil
return append([]PublicJWK(nil), v.cache.Keys...), true
}
// loadCache restores a previously persisted JWKS cache entry from disk.
func (v *Verifier) loadCache() error {
if v.cachePath == "" || !fs.FileExists(v.cachePath) {
return nil
@@ -256,6 +355,7 @@ func (v *Verifier) loadCache() error {
return nil
}
// saveCacheLocked persists the current cache entry to disk; caller must hold the mutex.
func (v *Verifier) saveCacheLocked() error {
if v.cachePath == "" {
return nil
@@ -269,3 +369,22 @@ func (v *Verifier) saveCacheLocked() error {
}
return os.WriteFile(v.cachePath, data, fs.ModeSecretFile)
}
// backoffDuration returns the retry delay for the given fetch attempt, adding jitter.
func backoffDuration(attempt int) time.Duration {
if attempt < 1 {
attempt = 1
}
base := jwksFetchBaseDelay << (attempt - 1)
if base > jwksFetchMaxDelay {
base = jwksFetchMaxDelay
}
jitterRange := base / 2
if jitterRange > 0 {
base += time.Duration(randInt63n(int64(jitterRange) + 1))
}
return base
}
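
The key concurrency change in `keysForURL` is that the mutex is no longer held across the HTTP request: the cache is snapshotted under the lock, the fetch runs unlocked, and the result is published under the lock again, retrying if another goroutine changed the cache in between. Below is a generic sketch of that locking shape with invented names; it omits the ETag, TTL, retry, and URL checks of the real code.

```go
package example

import "sync"

// snapshotFetchPublish illustrates the locking pattern used by keysForURL:
// copy shared state under the mutex, run the slow fetch without holding it,
// then re-acquire the mutex only to publish the result, so concurrent callers
// never block each other for the duration of a network request.
type snapshotFetchPublish struct {
	mu    sync.Mutex
	value []byte
}

func (s *snapshotFetchPublish) get(fetch func() ([]byte, error)) ([]byte, error) {
	s.mu.Lock()
	snapshot := append([]byte(nil), s.value...) // copy while locked
	s.mu.Unlock()

	fresh, err := fetch() // slow network call runs without the lock
	if err != nil {
		if len(snapshot) > 0 {
			return snapshot, nil // fall back to stale data, like the cached-keys path
		}
		return nil, err
	}

	s.mu.Lock()
	s.value = append([]byte(nil), fresh...) // publish under the lock
	s.mu.Unlock()
	return fresh, nil
}
```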


@@ -18,7 +18,7 @@ import (
)
func TestVerifierPrimeAndVerify(t *testing.T) {
portalCfg := cfg.NewTestConfig("jwt-verifier-portal")
portalCfg := cfg.TestConfig()
clusterUUID := rnd.UUIDv7()
portalCfg.Options().ClusterUUID = clusterUUID
@@ -105,7 +105,7 @@ func TestVerifierPrimeAndVerify(t *testing.T) {
}
func TestIssuerClampTTL(t *testing.T) {
portalCfg := cfg.NewTestConfig("jwt-issuer-ttl")
portalCfg := cfg.TestConfig()
mgr, err := NewManager(portalCfg)
require.NoError(t, err)
mgr.now = func() time.Time { return time.Unix(0, 0) }
@@ -136,3 +136,33 @@ func TestIssuerClampTTL(t *testing.T) {
ttl := parsed.ExpiresAt.Time.Sub(parsed.IssuedAt.Time)
require.Equal(t, MaxTokenTTL, ttl)
}
func TestBackoffDuration(t *testing.T) {
origRand := randInt63n
randInt63n = func(n int64) int64 {
if n <= 0 {
return 0
}
return n - 1
}
t.Cleanup(func() { randInt63n = origRand })
tests := []struct {
name string
attempt int
expect time.Duration
}{
{"Attempt1", 1, 300 * time.Millisecond},
{"Attempt2", 2, 600 * time.Millisecond},
{"Attempt3", 3, 1200 * time.Millisecond},
{"Attempt4", 4, 2400 * time.Millisecond},
{"Attempt5", 5, 3 * time.Second},
{"AttemptZero", 0, 300 * time.Millisecond},
}
for _, tt := range tests {
if got := backoffDuration(tt.attempt); got != tt.expect {
t.Errorf("%s: expected %s, got %s", tt.name, tt.expect, got)
}
}
}