mirror of
https://github.com/photoprism/photoprism.git
synced 2026-01-23 02:24:24 +00:00
AI: Improve conflict resolution when merging face clusters #5167
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
parent
3502251f7e
commit
68e1ddcc89
8 changed files with 267 additions and 10 deletions
|
|
@ -37,6 +37,10 @@ var FacesCommands = &cli.Command{
|
|||
Aliases: []string{"f"},
|
||||
Usage: "fix discovered issues",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "subject",
|
||||
Usage: "limit audit to the specific subject UID",
|
||||
},
|
||||
},
|
||||
Action: facesAuditAction,
|
||||
},
|
||||
|
|
@ -75,8 +79,14 @@ var FacesCommands = &cli.Command{
|
|||
Action: facesUpdateAction,
|
||||
},
|
||||
{
|
||||
Name: "optimize",
|
||||
Usage: "Optimizes face clusters",
|
||||
Name: "optimize",
|
||||
Usage: "Optimizes face clusters",
|
||||
Flags: []cli.Flag{
|
||||
&cli.BoolFlag{
|
||||
Name: "retry",
|
||||
Usage: "reset merge retry counters before optimizing",
|
||||
},
|
||||
},
|
||||
Action: facesOptimizeAction,
|
||||
},
|
||||
},
|
||||
|
|
@ -130,7 +140,9 @@ func facesAuditAction(ctx *cli.Context) error {
|
|||
|
||||
w := get.Faces()
|
||||
|
||||
if err := w.Audit(ctx.Bool("fix")); err != nil {
|
||||
subject := strings.TrimSpace(ctx.String("subject"))
|
||||
|
||||
if err := w.Audit(ctx.Bool("fix"), subject); err != nil {
|
||||
return err
|
||||
} else {
|
||||
elapsed := time.Since(start)
|
||||
|
|
@ -348,6 +360,14 @@ func facesOptimizeAction(ctx *cli.Context) error {
|
|||
|
||||
w := get.Faces()
|
||||
|
||||
if ctx.Bool("retry") {
|
||||
if reset, err := query.ResetFaceMergeRetry(""); err != nil {
|
||||
return err
|
||||
} else if reset > 0 {
|
||||
log.Infof("faces: reset merge retry counters for %s", english.Plural(reset, "cluster", "clusters"))
|
||||
}
|
||||
}
|
||||
|
||||
if res, err := w.Optimize(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -28,6 +28,8 @@ type Face struct {
|
|||
SampleRadius float64 `json:"SampleRadius" yaml:"SampleRadius,omitempty"`
|
||||
Collisions int `json:"Collisions" yaml:"Collisions,omitempty"`
|
||||
CollisionRadius float64 `json:"CollisionRadius" yaml:"CollisionRadius,omitempty"`
|
||||
MergeRetry uint8 `gorm:"type:TINYINT(3);default:0" json:"-" yaml:"-"`
|
||||
MergeNotes string `gorm:"type:VARCHAR(255);default:'';" json:"-" yaml:"-"`
|
||||
EmbeddingJSON json.RawMessage `gorm:"type:MEDIUMBLOB;" json:"-" yaml:"EmbeddingJSON,omitempty"`
|
||||
embedding face.Embedding `gorm:"-" yaml:"-"`
|
||||
MatchedAt *time.Time `json:"MatchedAt" yaml:"MatchedAt,omitempty"`
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@ package query
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jinzhu/gorm"
|
||||
|
||||
|
|
@ -23,6 +26,21 @@ type FaceMap map[string]entity.Face
|
|||
// because markers still reference them. Callers may treat this as a non-fatal warning.
|
||||
var ErrRetainedManualClusters = errors.New("faces: retained manual clusters after merge")
|
||||
|
||||
// MergeMaxRetry limits how often the optimiser retries stubborn manual clusters (0 = unlimited).
|
||||
var MergeMaxRetry = 1
|
||||
|
||||
func init() {
|
||||
if v := os.Getenv("PHOTOPRISM_FACE_MERGE_MAX_RETRY"); v != "" {
|
||||
if n, err := strconv.Atoi(v); err == nil {
|
||||
if n < 0 {
|
||||
n = 0
|
||||
}
|
||||
|
||||
MergeMaxRetry = n
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// FacesByID retrieves faces from the database and returns a map with the Face ID as key.
|
||||
func FacesByID(knownOnly, unmatchedOnly, hidden, ignored bool) (FaceMap, IDs, error) {
|
||||
faces, err := Faces(knownOnly, unmatchedOnly, hidden, ignored)
|
||||
|
|
@ -79,6 +97,10 @@ func ManuallyAddedFaces(hidden, ignored bool, subj_uid string) (result entity.Fa
|
|||
stmt = stmt.Where("subj_uid <> ''")
|
||||
}
|
||||
|
||||
if MergeMaxRetry > 0 {
|
||||
stmt = stmt.Where("merge_retry < ?", MergeMaxRetry)
|
||||
}
|
||||
|
||||
if !ignored {
|
||||
stmt = stmt.Where("face_kind <= 1")
|
||||
}
|
||||
|
|
@ -229,12 +251,45 @@ func MergeFaces(merge entity.Faces, ignored bool) (merged *entity.Face, err erro
|
|||
} else if removed > 0 {
|
||||
log.Debugf("faces: removed %d orphans of %d candidate for subject %s", removed, len(merge), clean.Log(subjUID))
|
||||
} else {
|
||||
note := fmt.Sprintf("retained markers after merge attempt on %s", time.Now().UTC().Format(time.RFC3339))
|
||||
|
||||
for _, candidate := range merge {
|
||||
updates := entity.Values{
|
||||
"MergeRetry": gorm.Expr("merge_retry + 1"),
|
||||
"MergeNotes": note,
|
||||
}
|
||||
|
||||
if err := Db().Model(&entity.Face{}).Where("id = ?", candidate.ID).Updates(updates).Error; err != nil {
|
||||
log.Warnf("faces: failed updating merge retry for %s (%s)", candidate.ID, err)
|
||||
} else {
|
||||
candidate.MergeRetry++
|
||||
candidate.MergeNotes = note
|
||||
}
|
||||
}
|
||||
|
||||
return merged, fmt.Errorf("%w: kept %d candidate cluster(s) [%s] for subject %s because markers still reference them", ErrRetainedManualClusters, len(merge), clean.Log(strings.Join(merge.IDs(), ", ")), clean.Log(subjUID))
|
||||
}
|
||||
|
||||
return merged, err
|
||||
}
|
||||
|
||||
// ResetFaceMergeRetry clears merge retry metadata for all (or subject-specific) clusters.
|
||||
func ResetFaceMergeRetry(subjUID string) (int, error) {
|
||||
stmt := Db().Model(&entity.Face{}).Where("merge_retry > 0")
|
||||
|
||||
if subjUID != "" {
|
||||
stmt = stmt.Where("subj_uid = ?", subjUID)
|
||||
}
|
||||
|
||||
res := stmt.UpdateColumns(entity.Values{"MergeRetry": 0, "MergeNotes": ""})
|
||||
|
||||
if res.Error != nil {
|
||||
return 0, res.Error
|
||||
}
|
||||
|
||||
return int(res.RowsAffected), nil
|
||||
}
|
||||
|
||||
// ResolveFaceCollisions resolves collisions of different subject's faces.
|
||||
func ResolveFaceCollisions() (conflicts, resolved int, err error) {
|
||||
faces, ids, err := FacesByID(true, false, false, false)
|
||||
|
|
|
|||
|
|
@ -1,12 +1,15 @@
|
|||
package query
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/rnd"
|
||||
)
|
||||
|
||||
func TestFaces(t *testing.T) {
|
||||
|
|
@ -229,6 +232,61 @@ func TestMergeFaces(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestMergeFacesRetainedClusters(t *testing.T) {
|
||||
subjUID := rnd.GenerateUID('j')
|
||||
|
||||
embeddingA := face.RandomEmbeddings(1, face.RegularFace)
|
||||
embeddingB := face.RandomEmbeddings(1, face.RegularFace)
|
||||
|
||||
faceA := entity.NewFace(subjUID, entity.SrcManual, embeddingA)
|
||||
require.NoError(t, faceA.Create())
|
||||
|
||||
faceB := entity.NewFace(subjUID, entity.SrcManual, embeddingB)
|
||||
require.NoError(t, faceB.Create())
|
||||
|
||||
// Create markers that deliberately fail to match the merged embedding.
|
||||
neutralEmbedding := face.Embeddings{face.NullEmbedding}
|
||||
neutralJSON := neutralEmbedding.JSON()
|
||||
|
||||
markers := []*entity.Marker{
|
||||
{
|
||||
FileUID: rnd.GenerateUID('f'),
|
||||
MarkerType: entity.MarkerFace,
|
||||
MarkerSrc: entity.SrcManual,
|
||||
FaceID: faceA.ID,
|
||||
EmbeddingsJSON: neutralJSON,
|
||||
},
|
||||
{
|
||||
FileUID: rnd.GenerateUID('f'),
|
||||
MarkerType: entity.MarkerFace,
|
||||
MarkerSrc: entity.SrcManual,
|
||||
FaceID: faceB.ID,
|
||||
EmbeddingsJSON: neutralJSON,
|
||||
},
|
||||
}
|
||||
|
||||
for _, marker := range markers {
|
||||
require.NoError(t, entity.Db().Create(marker).Error)
|
||||
}
|
||||
|
||||
_, err := MergeFaces(entity.Faces{*faceA, *faceB}, false)
|
||||
require.Error(t, err)
|
||||
require.True(t, errors.Is(err, ErrRetainedManualClusters))
|
||||
|
||||
var updated entity.Face
|
||||
require.NoError(t, entity.Db().Where("id = ?", faceA.ID).First(&updated).Error)
|
||||
require.NotZero(t, updated.MergeRetry)
|
||||
require.NotEmpty(t, updated.MergeNotes)
|
||||
|
||||
resetCount, err := ResetFaceMergeRetry(subjUID)
|
||||
require.NoError(t, err)
|
||||
require.GreaterOrEqual(t, resetCount, 1)
|
||||
|
||||
require.NoError(t, entity.Db().Where("id = ?", faceA.ID).First(&updated).Error)
|
||||
require.Zero(t, updated.MergeRetry)
|
||||
require.Empty(t, updated.MergeNotes)
|
||||
}
|
||||
|
||||
func TestResolveFaceCollisions(t *testing.T) {
|
||||
c, r, err := ResolveFaceCollisions()
|
||||
|
||||
|
|
|
|||
|
|
@ -122,6 +122,39 @@ func Embeddings(single, unclustered bool, size, score int) (result face.Embeddin
|
|||
return result, nil
|
||||
}
|
||||
|
||||
// MarkerCountsByFaceIDs returns a map of marker counts for the provided face IDs.
|
||||
func MarkerCountsByFaceIDs(faceIDs []string) (map[string]int, error) {
|
||||
counts := make(map[string]int, len(faceIDs))
|
||||
|
||||
if len(faceIDs) == 0 {
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
type row struct {
|
||||
FaceID string
|
||||
Count int
|
||||
}
|
||||
|
||||
var rows []row
|
||||
|
||||
if err := Db().
|
||||
Model(&entity.Marker{}).
|
||||
Select("face_id, COUNT(*) AS count").
|
||||
Where("marker_invalid = 0").
|
||||
Where("marker_type = ?", entity.MarkerFace).
|
||||
Where("face_id IN (?)", faceIDs).
|
||||
Group("face_id").
|
||||
Scan(&rows).Error; err != nil {
|
||||
return counts, err
|
||||
}
|
||||
|
||||
for _, r := range rows {
|
||||
counts[r.FaceID] = r.Count
|
||||
}
|
||||
|
||||
return counts, nil
|
||||
}
|
||||
|
||||
// RemoveInvalidMarkerReferences removes face and subject references from invalid markers.
|
||||
func RemoveInvalidMarkerReferences() (removed int64, err error) {
|
||||
result := Db().
|
||||
|
|
|
|||
|
|
@ -151,6 +151,36 @@ func TestEmbeddings(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestMarkerCountsByFaceIDs(t *testing.T) {
|
||||
counts, err := MarkerCountsByFaceIDs(nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
assert.Empty(t, counts)
|
||||
|
||||
faces, err := Faces(false, false, false, false)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(faces) == 0 {
|
||||
t.Skip("no faces available in test dataset")
|
||||
}
|
||||
|
||||
ids := []string{faces[0].ID}
|
||||
|
||||
counts, err = MarkerCountsByFaceIDs(ids)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(counts) == 0 {
|
||||
t.Skip("no markers found for sampled face")
|
||||
}
|
||||
|
||||
assert.GreaterOrEqual(t, counts[faces[0].ID], 0)
|
||||
}
|
||||
|
||||
func TestRemoveInvalidMarkerReferences(t *testing.T) {
|
||||
affected, err := RemoveInvalidMarkerReferences()
|
||||
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ import (
|
|||
)
|
||||
|
||||
// Audit face clusters and subjects.
|
||||
func (w *Faces) Audit(fix bool) (err error) {
|
||||
func (w *Faces) Audit(fix bool, subjUID string) (err error) {
|
||||
invalidFaces, invalidSubj, err := query.MarkersWithNonExistentReferences()
|
||||
|
||||
if err != nil {
|
||||
|
|
@ -29,10 +29,14 @@ func (w *Faces) Audit(fix bool) (err error) {
|
|||
log.Errorf("faces: %s (find subjects)", err)
|
||||
}
|
||||
|
||||
if n := len(subj); n == 0 {
|
||||
log.Infof("faces: found no subjects")
|
||||
if subjUID == "" {
|
||||
if n := len(subj); n == 0 {
|
||||
log.Infof("faces: found no subjects")
|
||||
} else {
|
||||
log.Infof("faces: found %s", english.Plural(n, "subject", "subjects"))
|
||||
}
|
||||
} else {
|
||||
log.Infof("faces: found %s", english.Plural(n, "subject", "subjects"))
|
||||
log.Infof("faces: auditing subject %s (%s)", entity.SubjNames.Log(subjUID), clean.Log(subjUID))
|
||||
}
|
||||
|
||||
// Fix non-existent marker subjects references?
|
||||
|
|
@ -71,6 +75,54 @@ func (w *Faces) Audit(fix bool) (err error) {
|
|||
return err
|
||||
}
|
||||
|
||||
if subjUID != "" {
|
||||
filtered := make(query.FaceMap, len(faces))
|
||||
filteredIDs := make(query.IDs, 0, len(ids))
|
||||
|
||||
for _, id := range ids {
|
||||
faceEntry := faces[id]
|
||||
if faceEntry.SubjUID != subjUID {
|
||||
continue
|
||||
}
|
||||
|
||||
filtered[id] = faceEntry
|
||||
filteredIDs = append(filteredIDs, id)
|
||||
}
|
||||
|
||||
faces = filtered
|
||||
ids = filteredIDs
|
||||
|
||||
if len(ids) == 0 {
|
||||
log.Infof("faces: found no clusters for subject %s", entity.SubjNames.Log(subjUID))
|
||||
}
|
||||
}
|
||||
|
||||
stubborn := make([]entity.Face, 0)
|
||||
stubbornIDs := make([]string, 0)
|
||||
|
||||
for _, id := range ids {
|
||||
if entry, ok := faces[id]; ok {
|
||||
if entry.MergeRetry > 0 {
|
||||
stubborn = append(stubborn, entry)
|
||||
stubbornIDs = append(stubbornIDs, entry.ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(stubborn) > 0 {
|
||||
counts, countErr := query.MarkerCountsByFaceIDs(stubbornIDs)
|
||||
if countErr != nil {
|
||||
logErr("faces", "marker counts", countErr)
|
||||
} else if subjUID != "" {
|
||||
log.Warnf("faces: %s awaiting merge for subject %s", english.Plural(len(stubborn), "manual cluster", "manual clusters"), entity.SubjNames.Log(subjUID))
|
||||
for _, entry := range stubborn {
|
||||
log.Warnf("faces: cluster %s retry=%d markers=%d notes=%s", entry.ID, entry.MergeRetry, counts[entry.ID], clean.Log(entry.MergeNotes))
|
||||
}
|
||||
} else {
|
||||
log.Warnf("faces: %s pending manual cluster merge – use 'photoprism faces audit --subject=<uid>' for details", english.Plural(len(stubborn), "manual cluster", "manual clusters"))
|
||||
}
|
||||
}
|
||||
|
||||
// Remembers matched combinations.
|
||||
done := make(map[string]bool, len(ids)*len(ids))
|
||||
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ func TestFaces_Audit(t *testing.T) {
|
|||
|
||||
m := NewFaces(c)
|
||||
|
||||
err := m.Audit(true)
|
||||
err := m.Audit(true, "")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
|
@ -30,12 +30,19 @@ func TestFaces_Audit(t *testing.T) {
|
|||
|
||||
m := NewFaces(c)
|
||||
|
||||
err := m.Audit(false)
|
||||
err := m.Audit(false, "")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
t.Run("SubjectFilter", func(t *testing.T) {
|
||||
c := config.TestConfig()
|
||||
|
||||
m := NewFaces(c)
|
||||
|
||||
require.NoError(t, m.Audit(false, "jr0ncy131y7igds8"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestFaces_AuditNormalizesEmbeddings(t *testing.T) {
|
||||
|
|
@ -83,7 +90,7 @@ func TestFaces_AuditNormalizesEmbeddings(t *testing.T) {
|
|||
hashNorm := sha1.Sum(normalizeEmbeddingCopy(raw).JSON())
|
||||
expectedID := base32.StdEncoding.EncodeToString(hashNorm[:])
|
||||
|
||||
require.NoError(t, m.Audit(true))
|
||||
require.NoError(t, m.Audit(true, ""))
|
||||
|
||||
var updated entity.Face
|
||||
require.NoError(t, entity.Db().Where("id = ?", expectedID).First(&updated).Error)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue