Entity: Ensure that AuthID wrap/unwrap is used for auth_user and auth_sessions, and that auth_sessions wrap/unwrap on create/save/find as required

This commit is contained in:
Keith Martin 2025-11-07 23:52:49 +10:00
parent d90182833e
commit babcb59d22
5 changed files with 212 additions and 25 deletions

View file

@ -50,7 +50,7 @@ type Session struct {
AuthProvider string `gorm:"type:VARBINARY(128);default:'';" json:"AuthProvider" yaml:"AuthProvider,omitempty"`
AuthMethod string `gorm:"type:VARBINARY(128);default:'';" json:"AuthMethod" yaml:"AuthMethod,omitempty"`
AuthIssuer string `gorm:"type:VARBINARY(255);default:'';" json:"AuthIssuer,omitempty" yaml:"AuthIssuer,omitempty"`
AuthID string `gorm:"type:VARBINARY(255);index;default:'';" json:"AuthID" yaml:"AuthID,omitempty"`
AuthID string `gorm:"type:VARBINARY(264);index;default:'';" json:"AuthID" yaml:"AuthID,omitempty"` // Make sure that you wrap and unwrap if using auth_id in a query.
AuthScope string `gorm:"size:1024;default:'';" json:"AuthScope" yaml:"AuthScope,omitempty"`
GrantType string `gorm:"type:VARBINARY(64);default:'';" json:"GrantType" yaml:"GrantType,omitempty"`
LastActive int64 `json:"LastActive" yaml:"LastActive,omitempty"`
@ -276,6 +276,27 @@ func (m *Session) Updates(values interface{}) error {
return UnscopedDb().Model(m).Updates(values).Error
}
// Wraps a string value in pseudo XML to force type to string
func wrapString(s string) (r string) {
r = s
if s != "" && !strings.HasPrefix(s, "<pp>") && !strings.HasSuffix(s, "</pp>") {
r = fmt.Sprintf("<pp>%s</pp>", s)
}
return r
}
// Wraps the AuthID field so that SQLite will save it correctly
func (m *Session) wrapAuthID() {
m.AuthID = wrapString(m.AuthID)
}
// Unwraps the AuthID field so that PhotoPrism can use it correctly
func (m *Session) unwrapAuthID() {
if m.AuthID != "" && strings.HasPrefix(m.AuthID, "<pp>") && strings.HasSuffix(m.AuthID, "</pp>") {
m.AuthID = strings.TrimSuffix(strings.TrimPrefix(m.AuthID, "<pp>"), "</pp>")
}
}
// BeforeCreate creates a random UID if needed before inserting a new row to the database.
func (m *Session) BeforeCreate(scope *gorm.Scope) error {
if rnd.InvalidRefID(m.RefID) {
@ -283,6 +304,7 @@ func (m *Session) BeforeCreate(scope *gorm.Scope) error {
Log("session", "set ref id", scope.SetColumn("RefID", m.RefID))
}
m.wrapAuthID()
if rnd.IsSessionID(m.ID) {
return nil
}
@ -292,6 +314,36 @@ func (m *Session) BeforeCreate(scope *gorm.Scope) error {
return scope.SetColumn("ID", m.ID)
}
// BeforeSave ensures that the AuthID will save correctly on SQLite
func (m *Session) BeforeSave(scope *gorm.Scope) error {
m.wrapAuthID()
return nil
}
// BeforeUpdate ensures that the AuthID will save correctly on SQLite
func (m *Session) BeforeUpdate(scope *gorm.Scope) error {
m.wrapAuthID()
return nil
}
// AfterSave ensures that the AuthID will not have the prefix and suffix added so that it will save correctly on SQLite
func (m *Session) AfterSave(scope *gorm.Scope) error {
m.unwrapAuthID()
return nil
}
// AfterUpdate ensures that the AuthID will not have the prefix and suffix added so that it will save correctly on SQLite
func (m *Session) AfterUpdate(scope *gorm.Scope) error {
m.unwrapAuthID()
return nil
}
// AfterFind ensures that the AuthID will not have the prefix and suffix added so that it will save correctly on SQLite
func (m *Session) AfterFind(scope *gorm.Scope) error {
m.unwrapAuthID()
return nil
}
// SetClient sets the client of this session.
func (m *Session) SetClient(c *Client) *Session {
if c == nil {

View file

@ -49,7 +49,7 @@ func DeleteChildSessions(s *Session) (deleted int) {
found := Sessions{}
if err := Db().Where("auth_id = ? AND auth_method = ?", s.ID, authn.MethodSession.String()).Find(&found).Error; err != nil {
if err := Db().Where("auth_id = ? AND auth_method = ?", wrapString(s.ID), authn.MethodSession.String()).Find(&found).Error; err != nil {
event.AuditErr([]string{"failed to find child sessions", status.Error(err)})
return deleted
}

View file

@ -8,6 +8,7 @@ import (
"github.com/gin-gonic/gin"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/photoprism/photoprism/internal/auth/acl"
"github.com/photoprism/photoprism/pkg/authn"
@ -238,10 +239,11 @@ func TestSession_Create(t *testing.T) {
s.SetAuthToken("69be27ac5ca305b394046a83f6fda18167ca3d3f2dbe7xxx")
err := s.Create()
require.Nil(t, err)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
s.Delete()
})
m2 := FindSessionByRefID("sessxkkcxxxx")
assert.Equal(t, "charles", m2.UserName)
@ -264,18 +266,19 @@ func TestSession_Create(t *testing.T) {
s.SetAuthToken(authToken)
err := s.Create()
require.Nil(t, err)
if err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
s.Delete()
})
m2, _ := FindSession(id)
assert.NotEqual(t, "123", m2.RefID)
})
t.Run("IdAlreadyExists", func(t *testing.T) {
authToken := "69be27ac5ca305b394046a83f6fda18167ca3d3f2dbe7ac0"
m := FindSessionByRefID("sessxkkcxxxx")
assert.Empty(t, m)
s := &Session{
UserName: "charles",
SessExpires: unix.Day * 3,
@ -283,11 +286,54 @@ func TestSession_Create(t *testing.T) {
RefID: "sessxkkcxxxx",
}
s.SetAuthToken(authToken)
s.SetAuthToken("69be27ac5ca305b394046a83f6fda18167ca3d3f2dbe7xxx")
err := s.Create()
require.Nil(t, err)
t.Cleanup(func() {
s.Delete()
})
authToken := "69be27ac5ca305b394046a83f6fda18167ca3d3f2dbe7ac0"
s2 := &Session{
UserName: "charles",
SessExpires: unix.Day * 3,
SessTimeout: unix.Now() + unix.Week,
RefID: "sessxkkcxxxx",
}
s2.SetAuthToken(authToken)
err = s2.Create()
assert.Error(t, err)
})
t.Run("LongNumericAuthID", func(t *testing.T) {
refID := rnd.RefID("ts")
m := FindSessionByRefID(refID)
assert.Empty(t, m)
s := &Session{
UserName: "charles",
SessExpires: unix.Day * 3,
SessTimeout: unix.Now() + unix.Week,
RefID: refID,
AuthID: "012345678901234567890",
}
s.SetAuthToken("69be27ac5ca305b394046a83f6fda18167ca3d3f2dbe7xxs")
err := s.Create()
require.Nil(t, err)
t.Cleanup(func() {
s.Delete()
})
m2 := FindSessionByRefID(refID)
assert.Equal(t, "charles", m2.UserName)
assert.Equal(t, "012345678901234567890", m2.AuthID)
})
}
func TestSession_Save(t *testing.T) {
@ -304,14 +350,37 @@ func TestSession_Save(t *testing.T) {
s.SetAuthToken("69be27ac5ca305b394046a83f6fda18167ca3d3f2dbe7xxy")
err := s.Save()
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
m2 := FindSessionByRefID("sessxkkcxxxy")
assert.Equal(t, "chris", m2.UserName)
})
t.Run("LongNumericAuthID", func(t *testing.T) {
refID := rnd.RefID("ts")
m := FindSessionByRefID(refID)
assert.Empty(t, m)
s := &Session{
UserName: "chris",
SessExpires: unix.Day * 3,
SessTimeout: unix.Now() + unix.Week,
RefID: refID,
AuthID: "012345678901234567890",
}
s.SetAuthToken("69be27ac5ca305b394046a83f6fda18167ca3d3f2dbe7xxy")
err := s.Save()
require.Nil(t, err)
t.Cleanup(func() {
s.Delete()
})
m2 := FindSessionByRefID(refID)
assert.Equal(t, "chris", m2.UserName)
assert.Equal(t, "012345678901234567890", m2.AuthID)
})
}
func TestSession_Updates(t *testing.T) {

View file

@ -52,7 +52,7 @@ type User struct {
AuthProvider string `gorm:"type:VARBINARY(128);default:'';" json:"AuthProvider" yaml:"AuthProvider,omitempty"`
AuthMethod string `gorm:"type:VARBINARY(128);default:'';" json:"AuthMethod" yaml:"AuthMethod,omitempty"`
AuthIssuer string `gorm:"type:VARBINARY(255);default:'';" json:"AuthIssuer,omitempty" yaml:"AuthIssuer,omitempty"`
AuthID string `gorm:"type:VARBINARY(264);index;default:'';" json:"AuthID" yaml:"AuthID,omitempty"`
AuthID string `gorm:"type:VARBINARY(264);index;default:'';" json:"AuthID" yaml:"AuthID,omitempty"` // Make sure that you wrap and unwrap if using auth_id in a query. See FindUser below.
UserName string `gorm:"size:200;index;" json:"Name" yaml:"Name,omitempty"`
DisplayName string `gorm:"size:200;" json:"DisplayName" yaml:"DisplayName,omitempty"`
UserEmail string `gorm:"size:255;index;" json:"Email" yaml:"Email,omitempty"`
@ -148,18 +148,18 @@ func FindUser(find User) *User {
stmt = stmt.Where("user_uid = ?", find.UserUID)
} else if authn.ProviderOIDC.Equal(find.AuthProvider) && find.AuthID != "" {
if find.AuthIssuer == "" {
stmt = stmt.Where("auth_provider = ? AND auth_id = ?", find.AuthProvider, find.AuthID)
stmt = stmt.Where("auth_provider = ? AND auth_id = ?", find.AuthProvider, wrapString(find.AuthID))
} else {
stmt = stmt.Where("auth_provider = ? AND (auth_issuer = '' OR auth_issuer = ?) AND auth_id = ?", find.AuthProvider, find.AuthIssuer, find.AuthID)
stmt = stmt.Where("auth_provider = ? AND (auth_issuer = '' OR auth_issuer = ?) AND auth_id = ?", find.AuthProvider, find.AuthIssuer, wrapString(find.AuthID))
}
} else if find.AuthProvider != "" && find.AuthID != "" && find.UserName != "" {
stmt = stmt.Where("auth_provider = ? AND auth_id = ? OR user_name = ?", find.AuthProvider, find.AuthID, find.UserName)
stmt = stmt.Where("auth_provider = ? AND auth_id = ? OR user_name = ?", find.AuthProvider, wrapString(find.AuthID), find.UserName)
} else if find.UserName != "" {
stmt = stmt.Where("user_name = ?", find.UserName)
} else if find.UserEmail != "" {
stmt = stmt.Where("user_email = ?", find.UserEmail)
} else if find.AuthProvider != "" && find.AuthID != "" {
stmt = stmt.Where("auth_provider = ? AND auth_id = ?", find.AuthProvider, find.AuthID)
stmt = stmt.Where("auth_provider = ? AND auth_id = ?", find.AuthProvider, wrapString(find.AuthID))
} else {
return nil
}
@ -413,12 +413,12 @@ func (m *User) BeforeCreate(scope *gorm.Scope) error {
Log("user", "set ref id", scope.SetColumn("RefID", m.RefID))
}
m.wrapAuthID()
if rnd.IsUnique(m.UserUID, UserUID) {
return nil
}
m.wrapAuthID()
m.UserUID = rnd.GenerateUID(UserUID)
return scope.SetColumn("UserUID", m.UserUID)
}
@ -673,17 +673,17 @@ func (m *User) SetMethod(method authn.MethodType) *User {
// SetAuthID sets a custom authentication identifier.
func (m *User) SetAuthID(id, issuer string) *User {
// Update auth id if not empty.
if authId := clean.Auth(id); authId == "" {
if authID := clean.Auth(id); authID == "" {
return m
} else {
m.AuthID = authId
m.AuthID = authID
m.AuthIssuer = clean.Uri(issuer)
}
// Make sure other users do not use the same identifier.
if m.HasUID() && m.AuthProvider != "" {
if err := UnscopedDb().Model(&User{}).
Where("user_uid <> ? AND auth_provider = ? AND auth_id = ? AND super_admin = 0", m.UserUID, m.AuthProvider, m.AuthID).
Where("user_uid <> ? AND auth_provider = ? AND auth_id = ? AND super_admin = 0", m.UserUID, m.AuthProvider, wrapString(m.AuthID)).
Updates(Values{"auth_id": "", "auth_provider": authn.ProviderNone}).Error; err != nil {
event.AuditErr([]string{"user %s", "failed to resolve auth id conflicts", status.Error(err)}, m.RefID)
}

View file

@ -324,6 +324,32 @@ func TestUser_Create(t *testing.T) {
t.Fatal(err)
}
})
t.Run("LongNumericAuthID", func(t *testing.T) {
useruid := rnd.GenerateUID(UserUID)
var m = User{
UserUID: useruid,
UserName: "examplelong",
UserRole: string(acl.RoleGuest),
DisplayName: "Example Long",
SuperAdmin: false,
CanLogin: true,
AuthID: "012345678901234567890",
AuthProvider: string(authn.ProviderOIDC),
}
if err := m.Create(); err != nil {
t.Fatal(err)
}
t.Cleanup(func() {
m.Delete()
UnscopedDb().Delete(m)
})
assert.Equal(t, "examplelong", m.Username())
assert.Equal(t, "examplelong", m.UserName)
assert.Equal(t, "012345678901234567890", m.AuthID)
})
}
func TestUser_UpdateUsername(t *testing.T) {
@ -551,6 +577,14 @@ func TestFindUser(t *testing.T) {
assert.NotEmpty(t, m.UserUID)
assert.Equal(t, "jane.doe", m.UserName)
assert.Equal(t, "oidc", m.AuthProvider)
n := FindUser(User{AuthProvider: authn.ProviderOIDC.String(), AuthID: info.Subject})
require.NotNil(t, n)
assert.NotEmpty(t, n.UserUID)
assert.Equal(t, "jane.doe", n.UserName)
assert.Equal(t, "oidc", n.AuthProvider)
})
t.Run("UserName", func(t *testing.T) {
m := FindUser(User{UserName: "admin"})
@ -1822,6 +1856,38 @@ func TestUser_SetAuthID(t *testing.T) {
assert.Equal(t, uuid, m.AuthID)
assert.Equal(t, "", m.AuthIssuer)
})
t.Run("DupeAuthProviderAndID", func(t *testing.T) {
m := UserFixtures.Get("guest")
n := NewUser()
n.UserName = "guest2"
n.DisplayName = "Guest User2"
n.UserEmail = "guest2@example.com"
n.UserRole = acl.RoleGuest.String()
n.AuthProvider = authn.ProviderOIDC.String()
n.AuthMethod = authn.MethodDefault.String()
n.SuperAdmin = false
n.CanLogin = true
n.SetAuthID(uuid, issuer)
n.Save()
t.Cleanup(func() {
n.Delete()
UnscopedDb().Delete(n)
})
newUserUID := n.UserUID
m.SetAuthID(uuid, issuer)
assert.Equal(t, uuid, m.AuthID)
assert.Equal(t, issuer, m.AuthIssuer)
n = FindUserByUID(newUserUID)
require.NotNil(t, n)
assert.Equal(t, "guest2", n.UserName)
assert.Equal(t, "", n.AuthID)
assert.Equal(t, authn.ProviderNone.String(), n.AuthProvider)
})
}
func TestUser_UpdateAuthID(t *testing.T) {