AI: Improve conflict resolution when merging face clusters #5167

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer 2025-10-07 18:58:21 +02:00
parent 3502251f7e
commit 68e1ddcc89
8 changed files with 267 additions and 10 deletions

View file

@ -37,6 +37,10 @@ var FacesCommands = &cli.Command{
Aliases: []string{"f"},
Usage: "fix discovered issues",
},
&cli.StringFlag{
Name: "subject",
Usage: "limit audit to the specific subject UID",
},
},
Action: facesAuditAction,
},
@ -75,8 +79,14 @@ var FacesCommands = &cli.Command{
Action: facesUpdateAction,
},
{
Name: "optimize",
Usage: "Optimizes face clusters",
Name: "optimize",
Usage: "Optimizes face clusters",
Flags: []cli.Flag{
&cli.BoolFlag{
Name: "retry",
Usage: "reset merge retry counters before optimizing",
},
},
Action: facesOptimizeAction,
},
},
@ -130,7 +140,9 @@ func facesAuditAction(ctx *cli.Context) error {
w := get.Faces()
if err := w.Audit(ctx.Bool("fix")); err != nil {
subject := strings.TrimSpace(ctx.String("subject"))
if err := w.Audit(ctx.Bool("fix"), subject); err != nil {
return err
} else {
elapsed := time.Since(start)
@ -348,6 +360,14 @@ func facesOptimizeAction(ctx *cli.Context) error {
w := get.Faces()
if ctx.Bool("retry") {
if reset, err := query.ResetFaceMergeRetry(""); err != nil {
return err
} else if reset > 0 {
log.Infof("faces: reset merge retry counters for %s", english.Plural(reset, "cluster", "clusters"))
}
}
if res, err := w.Optimize(); err != nil {
return err
} else {

View file

@ -28,6 +28,8 @@ type Face struct {
SampleRadius float64 `json:"SampleRadius" yaml:"SampleRadius,omitempty"`
Collisions int `json:"Collisions" yaml:"Collisions,omitempty"`
CollisionRadius float64 `json:"CollisionRadius" yaml:"CollisionRadius,omitempty"`
MergeRetry uint8 `gorm:"type:TINYINT(3);default:0" json:"-" yaml:"-"`
MergeNotes string `gorm:"type:VARCHAR(255);default:'';" json:"-" yaml:"-"`
EmbeddingJSON json.RawMessage `gorm:"type:MEDIUMBLOB;" json:"-" yaml:"EmbeddingJSON,omitempty"`
embedding face.Embedding `gorm:"-" yaml:"-"`
MatchedAt *time.Time `json:"MatchedAt" yaml:"MatchedAt,omitempty"`

View file

@ -3,7 +3,10 @@ package query
import (
"errors"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/jinzhu/gorm"
@ -23,6 +26,21 @@ type FaceMap map[string]entity.Face
// because markers still reference them. Callers may treat this as a non-fatal warning.
var ErrRetainedManualClusters = errors.New("faces: retained manual clusters after merge")
// MergeMaxRetry limits how often the optimiser retries stubborn manual clusters (0 = unlimited).
var MergeMaxRetry = 1
func init() {
if v := os.Getenv("PHOTOPRISM_FACE_MERGE_MAX_RETRY"); v != "" {
if n, err := strconv.Atoi(v); err == nil {
if n < 0 {
n = 0
}
MergeMaxRetry = n
}
}
}
// FacesByID retrieves faces from the database and returns a map with the Face ID as key.
func FacesByID(knownOnly, unmatchedOnly, hidden, ignored bool) (FaceMap, IDs, error) {
faces, err := Faces(knownOnly, unmatchedOnly, hidden, ignored)
@ -79,6 +97,10 @@ func ManuallyAddedFaces(hidden, ignored bool, subj_uid string) (result entity.Fa
stmt = stmt.Where("subj_uid <> ''")
}
if MergeMaxRetry > 0 {
stmt = stmt.Where("merge_retry < ?", MergeMaxRetry)
}
if !ignored {
stmt = stmt.Where("face_kind <= 1")
}
@ -229,12 +251,45 @@ func MergeFaces(merge entity.Faces, ignored bool) (merged *entity.Face, err erro
} else if removed > 0 {
log.Debugf("faces: removed %d orphans of %d candidate for subject %s", removed, len(merge), clean.Log(subjUID))
} else {
note := fmt.Sprintf("retained markers after merge attempt on %s", time.Now().UTC().Format(time.RFC3339))
for _, candidate := range merge {
updates := entity.Values{
"MergeRetry": gorm.Expr("merge_retry + 1"),
"MergeNotes": note,
}
if err := Db().Model(&entity.Face{}).Where("id = ?", candidate.ID).Updates(updates).Error; err != nil {
log.Warnf("faces: failed updating merge retry for %s (%s)", candidate.ID, err)
} else {
candidate.MergeRetry++
candidate.MergeNotes = note
}
}
return merged, fmt.Errorf("%w: kept %d candidate cluster(s) [%s] for subject %s because markers still reference them", ErrRetainedManualClusters, len(merge), clean.Log(strings.Join(merge.IDs(), ", ")), clean.Log(subjUID))
}
return merged, err
}
// ResetFaceMergeRetry clears merge retry metadata for all (or subject-specific) clusters.
func ResetFaceMergeRetry(subjUID string) (int, error) {
stmt := Db().Model(&entity.Face{}).Where("merge_retry > 0")
if subjUID != "" {
stmt = stmt.Where("subj_uid = ?", subjUID)
}
res := stmt.UpdateColumns(entity.Values{"MergeRetry": 0, "MergeNotes": ""})
if res.Error != nil {
return 0, res.Error
}
return int(res.RowsAffected), nil
}
// ResolveFaceCollisions resolves collisions of different subject's faces.
func ResolveFaceCollisions() (conflicts, resolved int, err error) {
faces, ids, err := FacesByID(true, false, false, false)

View file

@ -1,12 +1,15 @@
package query
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/photoprism/photoprism/internal/ai/face"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/rnd"
)
func TestFaces(t *testing.T) {
@ -229,6 +232,61 @@ func TestMergeFaces(t *testing.T) {
})
}
func TestMergeFacesRetainedClusters(t *testing.T) {
subjUID := rnd.GenerateUID('j')
embeddingA := face.RandomEmbeddings(1, face.RegularFace)
embeddingB := face.RandomEmbeddings(1, face.RegularFace)
faceA := entity.NewFace(subjUID, entity.SrcManual, embeddingA)
require.NoError(t, faceA.Create())
faceB := entity.NewFace(subjUID, entity.SrcManual, embeddingB)
require.NoError(t, faceB.Create())
// Create markers that deliberately fail to match the merged embedding.
neutralEmbedding := face.Embeddings{face.NullEmbedding}
neutralJSON := neutralEmbedding.JSON()
markers := []*entity.Marker{
{
FileUID: rnd.GenerateUID('f'),
MarkerType: entity.MarkerFace,
MarkerSrc: entity.SrcManual,
FaceID: faceA.ID,
EmbeddingsJSON: neutralJSON,
},
{
FileUID: rnd.GenerateUID('f'),
MarkerType: entity.MarkerFace,
MarkerSrc: entity.SrcManual,
FaceID: faceB.ID,
EmbeddingsJSON: neutralJSON,
},
}
for _, marker := range markers {
require.NoError(t, entity.Db().Create(marker).Error)
}
_, err := MergeFaces(entity.Faces{*faceA, *faceB}, false)
require.Error(t, err)
require.True(t, errors.Is(err, ErrRetainedManualClusters))
var updated entity.Face
require.NoError(t, entity.Db().Where("id = ?", faceA.ID).First(&updated).Error)
require.NotZero(t, updated.MergeRetry)
require.NotEmpty(t, updated.MergeNotes)
resetCount, err := ResetFaceMergeRetry(subjUID)
require.NoError(t, err)
require.GreaterOrEqual(t, resetCount, 1)
require.NoError(t, entity.Db().Where("id = ?", faceA.ID).First(&updated).Error)
require.Zero(t, updated.MergeRetry)
require.Empty(t, updated.MergeNotes)
}
func TestResolveFaceCollisions(t *testing.T) {
c, r, err := ResolveFaceCollisions()

View file

@ -122,6 +122,39 @@ func Embeddings(single, unclustered bool, size, score int) (result face.Embeddin
return result, nil
}
// MarkerCountsByFaceIDs returns a map of marker counts for the provided face IDs.
func MarkerCountsByFaceIDs(faceIDs []string) (map[string]int, error) {
counts := make(map[string]int, len(faceIDs))
if len(faceIDs) == 0 {
return counts, nil
}
type row struct {
FaceID string
Count int
}
var rows []row
if err := Db().
Model(&entity.Marker{}).
Select("face_id, COUNT(*) AS count").
Where("marker_invalid = 0").
Where("marker_type = ?", entity.MarkerFace).
Where("face_id IN (?)", faceIDs).
Group("face_id").
Scan(&rows).Error; err != nil {
return counts, err
}
for _, r := range rows {
counts[r.FaceID] = r.Count
}
return counts, nil
}
// RemoveInvalidMarkerReferences removes face and subject references from invalid markers.
func RemoveInvalidMarkerReferences() (removed int64, err error) {
result := Db().

View file

@ -151,6 +151,36 @@ func TestEmbeddings(t *testing.T) {
})
}
func TestMarkerCountsByFaceIDs(t *testing.T) {
counts, err := MarkerCountsByFaceIDs(nil)
if err != nil {
t.Fatal(err)
}
assert.Empty(t, counts)
faces, err := Faces(false, false, false, false)
if err != nil {
t.Fatal(err)
}
if len(faces) == 0 {
t.Skip("no faces available in test dataset")
}
ids := []string{faces[0].ID}
counts, err = MarkerCountsByFaceIDs(ids)
if err != nil {
t.Fatal(err)
}
if len(counts) == 0 {
t.Skip("no markers found for sampled face")
}
assert.GreaterOrEqual(t, counts[faces[0].ID], 0)
}
func TestRemoveInvalidMarkerReferences(t *testing.T) {
affected, err := RemoveInvalidMarkerReferences()

View file

@ -16,7 +16,7 @@ import (
)
// Audit face clusters and subjects.
func (w *Faces) Audit(fix bool) (err error) {
func (w *Faces) Audit(fix bool, subjUID string) (err error) {
invalidFaces, invalidSubj, err := query.MarkersWithNonExistentReferences()
if err != nil {
@ -29,10 +29,14 @@ func (w *Faces) Audit(fix bool) (err error) {
log.Errorf("faces: %s (find subjects)", err)
}
if n := len(subj); n == 0 {
log.Infof("faces: found no subjects")
if subjUID == "" {
if n := len(subj); n == 0 {
log.Infof("faces: found no subjects")
} else {
log.Infof("faces: found %s", english.Plural(n, "subject", "subjects"))
}
} else {
log.Infof("faces: found %s", english.Plural(n, "subject", "subjects"))
log.Infof("faces: auditing subject %s (%s)", entity.SubjNames.Log(subjUID), clean.Log(subjUID))
}
// Fix non-existent marker subjects references?
@ -71,6 +75,54 @@ func (w *Faces) Audit(fix bool) (err error) {
return err
}
if subjUID != "" {
filtered := make(query.FaceMap, len(faces))
filteredIDs := make(query.IDs, 0, len(ids))
for _, id := range ids {
faceEntry := faces[id]
if faceEntry.SubjUID != subjUID {
continue
}
filtered[id] = faceEntry
filteredIDs = append(filteredIDs, id)
}
faces = filtered
ids = filteredIDs
if len(ids) == 0 {
log.Infof("faces: found no clusters for subject %s", entity.SubjNames.Log(subjUID))
}
}
stubborn := make([]entity.Face, 0)
stubbornIDs := make([]string, 0)
for _, id := range ids {
if entry, ok := faces[id]; ok {
if entry.MergeRetry > 0 {
stubborn = append(stubborn, entry)
stubbornIDs = append(stubbornIDs, entry.ID)
}
}
}
if len(stubborn) > 0 {
counts, countErr := query.MarkerCountsByFaceIDs(stubbornIDs)
if countErr != nil {
logErr("faces", "marker counts", countErr)
} else if subjUID != "" {
log.Warnf("faces: %s awaiting merge for subject %s", english.Plural(len(stubborn), "manual cluster", "manual clusters"), entity.SubjNames.Log(subjUID))
for _, entry := range stubborn {
log.Warnf("faces: cluster %s retry=%d markers=%d notes=%s", entry.ID, entry.MergeRetry, counts[entry.ID], clean.Log(entry.MergeNotes))
}
} else {
log.Warnf("faces: %s pending manual cluster merge use 'photoprism faces audit --subject=<uid>' for details", english.Plural(len(stubborn), "manual cluster", "manual clusters"))
}
}
// Remembers matched combinations.
done := make(map[string]bool, len(ids)*len(ids))

View file

@ -19,7 +19,7 @@ func TestFaces_Audit(t *testing.T) {
m := NewFaces(c)
err := m.Audit(true)
err := m.Audit(true, "")
if err != nil {
t.Fatal(err)
@ -30,12 +30,19 @@ func TestFaces_Audit(t *testing.T) {
m := NewFaces(c)
err := m.Audit(false)
err := m.Audit(false, "")
if err != nil {
t.Fatal(err)
}
})
t.Run("SubjectFilter", func(t *testing.T) {
c := config.TestConfig()
m := NewFaces(c)
require.NoError(t, m.Audit(false, "jr0ncy131y7igds8"))
})
}
func TestFaces_AuditNormalizesEmbeddings(t *testing.T) {
@ -83,7 +90,7 @@ func TestFaces_AuditNormalizesEmbeddings(t *testing.T) {
hashNorm := sha1.Sum(normalizeEmbeddingCopy(raw).JSON())
expectedID := base32.StdEncoding.EncodeToString(hashNorm[:])
require.NoError(t, m.Audit(true))
require.NoError(t, m.Audit(true, ""))
var updated entity.Face
require.NoError(t, entity.Db().Where("id = ?", expectedID).First(&updated).Error)