mirror of
https://github.com/photoprism/photoprism.git
synced 2026-01-23 02:24:24 +00:00
AI: Improve Face Detection with an ONNX-based model #5167
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
parent
677e190b6e
commit
94f8a5f35d
39 changed files with 1713 additions and 43 deletions
7
Makefile
7
Makefile
|
|
@ -65,7 +65,7 @@ endif
|
|||
|
||||
# Declare "make" targets.
|
||||
all: dep build-js
|
||||
dep: dep-tensorflow dep-js
|
||||
dep: dep-tensorflow dep-onnx dep-js
|
||||
biuld: build
|
||||
build: build-go
|
||||
watch: watch-js
|
||||
|
|
@ -177,6 +177,7 @@ install:
|
|||
@[ ! -d "$(DESTDIR)" ] || (echo "ERROR: Install path '$(DESTDIR)' already exists!"; exit 1)
|
||||
mkdir --mode=$(INSTALL_MODE) -p $(DESTDIR)
|
||||
env TMPDIR="$(BUILD_PATH)" ./scripts/dist/install-tensorflow.sh $(DESTDIR)
|
||||
env TMPDIR="$(BUILD_PATH)" ./scripts/dist/install-onnx.sh $(DESTDIR)
|
||||
rm -rf --preserve-root $(DESTDIR)/include
|
||||
(cd $(DESTDIR) && mkdir -p bin lib assets)
|
||||
./scripts/build.sh prod "$(DESTDIR)/bin/$(BINARY_NAME)"
|
||||
|
|
@ -192,6 +193,8 @@ install-go:
|
|||
go build -v ./...
|
||||
install-tensorflow:
|
||||
sudo scripts/dist/install-tensorflow.sh
|
||||
install-onnx:
|
||||
sudo scripts/dist/install-onnx.sh
|
||||
install-darktable:
|
||||
sudo scripts/dist/install-darktable.sh
|
||||
acceptance-sqlite-restart:
|
||||
|
|
@ -280,6 +283,8 @@ dep-tensorflow:
|
|||
scripts/download-facenet.sh
|
||||
scripts/download-nasnet.sh
|
||||
scripts/download-nsfw.sh
|
||||
dep-onnx:
|
||||
scripts/download-scrfs.sh
|
||||
dep-acceptance: storage/acceptance
|
||||
storage/acceptance:
|
||||
[ -f "./storage/acceptance/index.db" ] || (cd storage && rm -rf acceptance && wget -c https://dl.photoprism.app/qa/acceptance.tar.gz -O - | tar -xz)
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
&& \
|
||||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
echo "ALL ALL=(ALL) NOPASSWD:SETENV: ALL" >> /etc/sudoers.d/all && \
|
||||
mkdir -p /etc/skel/.config/go/telemetry && \
|
||||
|
|
|
|||
|
|
@ -113,6 +113,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-mariadb.sh mariadb-client && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
/scripts/install-go-tools.sh && \
|
||||
echo "ALL ALL=(ALL) NOPASSWD:SETENV: ALL" >> /etc/sudoers.d/all && \
|
||||
|
|
|
|||
|
|
@ -109,6 +109,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
&& \
|
||||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
/scripts/install-go-tools.sh && \
|
||||
echo "ALL ALL=(ALL) NOPASSWD:SETENV: ALL" >> /etc/sudoers.d/all && \
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
&& \
|
||||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
/scripts/install-chrome.sh && \
|
||||
/scripts/install-go.sh && \
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-mariadb.sh mariadb-client && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
/scripts/install-yt-dlp.sh && \
|
||||
/scripts/install-libheif.sh && \
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-mariadb.sh mariadb-client && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
/scripts/install-libheif.sh && \
|
||||
/scripts/install-chrome.sh && \
|
||||
|
|
|
|||
|
|
@ -71,6 +71,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-mariadb.sh mariadb-client && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
/scripts/install-libheif.sh && \
|
||||
/scripts/install-chrome.sh && \
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-mariadb.sh mariadb-client && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
/scripts/install-libheif.sh && \
|
||||
/scripts/install-chrome.sh && \
|
||||
|
|
|
|||
|
|
@ -76,6 +76,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-mariadb.sh mariadb-client && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
/scripts/install-libheif.sh && \
|
||||
/scripts/install-chrome.sh && \
|
||||
|
|
|
|||
|
|
@ -78,6 +78,7 @@ RUN echo 'APT::Acquire::Retries "3";' > /etc/apt/apt.conf.d/80retries && \
|
|||
/scripts/install-nodejs.sh && \
|
||||
/scripts/install-mariadb.sh mariadb-client && \
|
||||
/scripts/install-tensorflow.sh && \
|
||||
/scripts/install-onnx.sh && \
|
||||
/scripts/install-darktable.sh && \
|
||||
/scripts/install-yt-dlp.sh && \
|
||||
/scripts/install-libheif.sh && \
|
||||
|
|
|
|||
1
go.mod
1
go.mod
|
|
@ -89,6 +89,7 @@ require (
|
|||
github.com/ugjka/go-tz/v2 v2.2.6
|
||||
github.com/urfave/cli/v2 v2.27.7
|
||||
github.com/wamuir/graft v0.10.0
|
||||
github.com/yalue/onnxruntime_go v1.21.0
|
||||
github.com/zitadel/oidc/v3 v3.45.0
|
||||
golang.org/x/mod v0.28.0
|
||||
golang.org/x/sys v0.36.0
|
||||
|
|
|
|||
2
go.sum
2
go.sum
|
|
@ -446,6 +446,8 @@ github.com/wamuir/graft v0.10.0 h1:HSpBUvm7O+jwsRIuDQlw80xW4xMXRFkOiVLtWaZCU2s=
|
|||
github.com/wamuir/graft v0.10.0/go.mod h1:k6NJX3fCM/xzh5NtHky9USdgHTcz2vAvHp4c23I6UK4=
|
||||
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 h1:gEOO8jv9F4OT7lGCjxCBTO/36wtF6j2nSip77qHd4x4=
|
||||
github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1/go.mod h1:Ohn+xnUBiLI6FVj/9LpzZWtj1/D6lUovWYBkxHVV3aM=
|
||||
github.com/yalue/onnxruntime_go v1.21.0 h1:DdtvfY7OP5gR8mwPDqAOAQckf+KcI30hPNJL8hQaYWI=
|
||||
github.com/yalue/onnxruntime_go v1.21.0/go.mod h1:b4X26A8pekNb1ACJ58wAXgNKeUCGEAQ9dmACut9Sm/4=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/zitadel/logging v0.6.2 h1:MW2kDDR0ieQynPZ0KIZPrh9ote2WkxfBif5QoARDQcU=
|
||||
github.com/zitadel/logging v0.6.2/go.mod h1:z6VWLWUkJpnNVDSLzrPSQSQyttysKZ6bCRongw0ROK4=
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
## Face Detection and Embedding Guidelines
|
||||
|
||||
**Last Updated:** October 5, 2025
|
||||
**Last Updated:** October 6, 2025
|
||||
|
||||
### Overview
|
||||
|
||||
|
|
@ -15,6 +15,13 @@ Key changes:
|
|||
|
||||
### Detection Pipeline
|
||||
|
||||
PhotoPrism now supports two interchangeable detection engines:
|
||||
|
||||
- **Pigo** (default) — CPU-only cascade classifier, retains historical behaviour.
|
||||
- **ONNX SCRFD 0.5g** — optional ONNX Runtime-backed CNN that delivers higher recall on occluded or off-axis faces. The ONNX engine consumes 720 px thumbnails (model input 640 px), schedules work on the meta/vision workers, and defaults to half the available CPUs (minimum 1 thread). The engine is enabled automatically when `FACE_ENGINE=auto` and the bundled SCRFD model is present (the prebuilt runtime targets glibc ≥ 2.27 on x86_64/arm64). Operators can switch at runtime via `photoprism --face-engine=<auto|pigo|onnx>` or `photoprism faces reset --engine=<auto|pigo|onnx>` for a full re-index.
|
||||
|
||||
Runtime selection lives in `Config.FaceEngine()`; `auto` resolves to ONNX when the SCRFD assets are available, otherwise Pigo. `Config.FaceEngineRunType()` mirrors the vision run-type semantics: ONNX defaults to asynchronous `on-demand` mode when only a single inference thread is configured so the indexer remains responsive.
|
||||
|
||||
#### Angle Sweep
|
||||
|
||||
- The detector now evaluates the Pigo cascade at **-0.3, 0, and +0.3 radians**. These angles are exposed via the new `FACE_ANGLE` option.
|
||||
|
|
@ -75,11 +82,14 @@ This guarantees that Euclidean distance comparisons are equivalent to cosine com
|
|||
|
||||
### Configuration Summary
|
||||
|
||||
| Setting | Default | Description |
|
||||
|------------------------|------------------------------|--------------------------------------------------------------------------------------|
|
||||
| `FACE_ANGLE` | `-0.3,0,0.3` | Detection angles (radians) swept by Pigo. |
|
||||
| `FACE_SCORE` | `9.0` (with dynamic offsets) | Base quality threshold before scale adjustments. |
|
||||
| `FACE_OVERLAP` | `42` | Maximum allowed IoU when deduplicating markers. |
|
||||
| Setting | Default | Description |
|
||||
|--------------------------|------------------------------|-------------------------------------------------------------------------------------------------|
|
||||
| `FACE_ENGINE` | `auto` | Detection engine (`auto`, `pigo`, `onnx`). `auto` resolves to ONNX when the SCRFD model exists. |
|
||||
| `FACE_ENGINE_RUN` | `auto` | Run schedule (`auto`, `on-demand`, `on-index`, ...). `auto` is stored as an empty string. |
|
||||
| `FACE_ENGINE_THREADS` | `runtime.NumCPU()/2` (≥1) | ONNX inference threads; ignored by Pigo. |
|
||||
| `FACE_ANGLE` | `-0.3,0,0.3` | Detection angles (radians) swept by Pigo. |
|
||||
| `FACE_SCORE` | `9.0` (with dynamic offsets) | Base quality threshold before scale adjustments. |
|
||||
| `FACE_OVERLAP` | `42` | Maximum allowed IoU when deduplicating markers. |
|
||||
|
||||
### Benchmark Reference
|
||||
|
||||
|
|
|
|||
|
|
@ -64,8 +64,25 @@ var (
|
|||
mouthCascades = []string{"lp93", "lp84", "lp82", "lp81"}
|
||||
)
|
||||
|
||||
// Detector struct contains Pigo face detector general settings.
|
||||
type Detector struct {
|
||||
// pigoEngine implements DetectionEngine using the bundled Pigo cascades.
|
||||
type pigoEngine struct{}
|
||||
|
||||
// newPigoEngine constructs a Pigo-backed DetectionEngine instance.
|
||||
func newPigoEngine() *pigoEngine {
|
||||
return &pigoEngine{}
|
||||
}
|
||||
|
||||
func (p *pigoEngine) Name() string {
|
||||
return EnginePigo
|
||||
}
|
||||
|
||||
// Close releases resources held by the Pigo engine (none at the moment).
|
||||
func (p *pigoEngine) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// pigoDetector contains Pigo face detector general settings.
|
||||
type pigoDetector struct {
|
||||
minSize int
|
||||
shiftFactor float64
|
||||
scaleFactor float64
|
||||
|
|
@ -76,7 +93,7 @@ type Detector struct {
|
|||
}
|
||||
|
||||
// Detect runs the detection algorithm over the provided source image.
|
||||
func Detect(fileName string, findLandmarks bool, minSize int) (faces Faces, err error) {
|
||||
func (p *pigoEngine) Detect(fileName string, findLandmarks bool, minSize int) (faces Faces, err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
log.Errorf("faces: %s (panic)\nstack: %s", r, debug.Stack())
|
||||
|
|
@ -89,7 +106,7 @@ func Detect(fileName string, findLandmarks bool, minSize int) (faces Faces, err
|
|||
|
||||
angles := append([]float64(nil), DetectionAngles...)
|
||||
|
||||
d := &Detector{
|
||||
d := &pigoDetector{
|
||||
minSize: minSize,
|
||||
shiftFactor: 0.1,
|
||||
scaleFactor: 1.1,
|
||||
|
|
@ -123,7 +140,7 @@ func Detect(fileName string, findLandmarks bool, minSize int) (faces Faces, err
|
|||
}
|
||||
|
||||
// Detect runs the detection algorithm over the provided source image.
|
||||
func (d *Detector) Detect(fileName string) (faces []pigo.Detection, params pigo.CascadeParams, err error) {
|
||||
func (d *pigoDetector) Detect(fileName string) (faces []pigo.Detection, params pigo.CascadeParams, err error) {
|
||||
if len(d.angles) == 0 {
|
||||
// Fallback to defaults when the detector is constructed manually (e.g. tests).
|
||||
d.angles = append([]float64(nil), DetectionAngles...)
|
||||
|
|
@ -199,7 +216,7 @@ func (d *Detector) Detect(fileName string) (faces []pigo.Detection, params pigo.
|
|||
}
|
||||
|
||||
// Faces adds landmark coordinates to detected faces and returns the results.
|
||||
func (d *Detector) Faces(det []pigo.Detection, params pigo.CascadeParams, findLandmarks bool) (results Faces, err error) {
|
||||
func (d *pigoDetector) Faces(det []pigo.Detection, params pigo.CascadeParams, findLandmarks bool) (results Faces, err error) {
|
||||
// Sort results by size.
|
||||
sort.Slice(det, func(i, j int) bool {
|
||||
return det[i].Scale > det[j].Scale
|
||||
|
|
|
|||
|
|
@ -175,7 +175,7 @@ func TestDetectQualityFallback(t *testing.T) {
|
|||
func BenchmarkDetectorFacesLandmarks(b *testing.B) {
|
||||
const sample = "testdata/18.jpg"
|
||||
|
||||
d := &Detector{
|
||||
d := &pigoDetector{
|
||||
minSize: 20,
|
||||
shiftFactor: 0.1,
|
||||
scaleFactor: 1.1,
|
||||
|
|
|
|||
127
internal/ai/face/engine.go
Normal file
127
internal/ai/face/engine.go
Normal file
|
|
@ -0,0 +1,127 @@
|
|||
package face
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// EngineName identifies a face detection engine. It is declared as an alias of string,
// so plain string values can be used wherever an EngineName is expected.
type EngineName = string

// Engine names accepted by ParseEngine and ConfigureEngine.
const (
	// EngineAuto leaves the choice of engine to the caller; ParseEngine also returns it for unknown input.
	EngineAuto EngineName = "auto"
	// EnginePigo selects the Pigo engine.
	EnginePigo EngineName = "pigo"
	// EngineONNX selects the ONNX engine.
	EngineONNX EngineName = "onnx"
)
|
||||
|
||||
// ParseEngine normalizes user input and returns a supported engine name or EngineAuto when unknown.
|
||||
func ParseEngine(s string) EngineName {
|
||||
s = strings.ToLower(strings.TrimSpace(s))
|
||||
|
||||
switch s {
|
||||
case EnginePigo, EngineONNX:
|
||||
return s
|
||||
default:
|
||||
return EngineAuto
|
||||
}
|
||||
}
|
||||
|
||||
// DetectionEngine represents a strategy for locating faces in an image.
|
||||
type DetectionEngine interface {
|
||||
Name() string
|
||||
Detect(fileName string, findLandmarks bool, minSize int) (Faces, error)
|
||||
Close() error
|
||||
}
|
||||
|
||||
// EngineSettings capture configuration required to initialize a detection engine.
|
||||
type EngineSettings struct {
|
||||
Name EngineName
|
||||
ONNX ONNXOptions
|
||||
}
|
||||
|
||||
var (
|
||||
engineMu sync.RWMutex
|
||||
activeEngine DetectionEngine
|
||||
)
|
||||
|
||||
func init() {
|
||||
activeEngine = newPigoEngine()
|
||||
}
|
||||
|
||||
// UseEngine replaces the active detection engine and returns the previous instance.
|
||||
func UseEngine(engine DetectionEngine) (previous DetectionEngine) {
|
||||
engineMu.Lock()
|
||||
prev := activeEngine
|
||||
if engine == nil {
|
||||
activeEngine = newPigoEngine()
|
||||
} else {
|
||||
activeEngine = engine
|
||||
}
|
||||
engineMu.Unlock()
|
||||
return prev
|
||||
}
|
||||
|
||||
// ConfigureEngine selects and initializes the face detection engine based on the provided settings.
|
||||
func ConfigureEngine(settings EngineSettings) error {
|
||||
desired := ParseEngine(settings.Name)
|
||||
|
||||
if desired == EngineAuto {
|
||||
desired = EnginePigo
|
||||
}
|
||||
|
||||
var (
|
||||
newEngine DetectionEngine
|
||||
initErr error
|
||||
)
|
||||
|
||||
switch desired {
|
||||
case EngineONNX:
|
||||
if settings.ONNX.ModelPath == "" {
|
||||
initErr = fmt.Errorf("faces: ONNX model path is empty")
|
||||
newEngine = newPigoEngine()
|
||||
break
|
||||
}
|
||||
|
||||
if newEngine, initErr = NewONNXEngine(settings.ONNX); initErr != nil {
|
||||
newEngine = newPigoEngine()
|
||||
}
|
||||
case EnginePigo:
|
||||
fallthrough
|
||||
default:
|
||||
if desired != EnginePigo {
|
||||
log.Warnf("faces: unknown detection engine %q, falling back to pigo", desired)
|
||||
}
|
||||
newEngine = newPigoEngine()
|
||||
}
|
||||
|
||||
prev := UseEngine(newEngine)
|
||||
if prev != nil {
|
||||
_ = prev.Close()
|
||||
}
|
||||
|
||||
return initErr
|
||||
}
|
||||
|
||||
// ActiveEngine returns the currently configured detection engine.
|
||||
func ActiveEngine() DetectionEngine {
|
||||
engineMu.RLock()
|
||||
engine := activeEngine
|
||||
engineMu.RUnlock()
|
||||
return engine
|
||||
}
|
||||
|
||||
// Detect runs the active engine on the provided file and returns the detected faces.
|
||||
func Detect(fileName string, findLandmarks bool, minSize int) (Faces, error) {
|
||||
engine := ActiveEngine()
|
||||
if engine == nil {
|
||||
return Faces{}, fmt.Errorf("faces: detection engine not configured")
|
||||
}
|
||||
return engine.Detect(fileName, findLandmarks, minSize)
|
||||
}
|
||||
|
||||
// resetEngine restores the default Pigo engine.
|
||||
func resetEngine() {
|
||||
engineMu.Lock()
|
||||
activeEngine = newPigoEngine()
|
||||
engineMu.Unlock()
|
||||
}
|
||||
640
internal/ai/face/engine_onnx.go
Normal file
640
internal/ai/face/engine_onnx.go
Normal file
|
|
@ -0,0 +1,640 @@
|
|||
package face
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"image"
|
||||
_ "image/jpeg"
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/disintegration/imaging"
|
||||
onnxruntime "github.com/yalue/onnxruntime_go"
|
||||
)
|
||||
|
||||
// ONNXOptions holds the settings used by NewONNXEngine to set up the ONNX Runtime-backed detector.
type ONNXOptions struct {
	// ModelPath is the location of the ONNX model file; it must not be empty and must exist.
	ModelPath string

	// LibraryPath optionally names the ONNX Runtime shared library; it is tried before the built-in candidates.
	LibraryPath string

	// Threads is the thread count applied to the session; 0 selects half the CPUs (at least 1).
	Threads int

	// ScoreThreshold is the minimum detection score; values <= 0 select the default (0.50).
	ScoreThreshold float32

	// NMSThreshold is passed to non-maximum suppression; values <= 0 select the default (0.40).
	NMSThreshold float32
}
|
||||
|
||||
// DefaultONNXModelFilename is the default file name of the SCRFD model.
const DefaultONNXModelFilename = "scrfd.onnx"

// Defaults and normalization parameters used by the ONNX engine.
const (
	// onnxDefaultScoreThreshold replaces ONNXOptions.ScoreThreshold when it is <= 0.
	onnxDefaultScoreThreshold = 0.50
	// onnxDefaultNMSThreshold replaces ONNXOptions.NMSThreshold when it is <= 0.
	onnxDefaultNMSThreshold = 0.40
	// onnxDefaultInputSize is the input width/height used when the model does not declare a positive size.
	onnxDefaultInputSize = 640
	// onnxInputMean and onnxInputStd normalize each channel value v as (v - mean) / std.
	onnxInputMean = 127.5
	onnxInputStd  = 128.0
)
|
||||
|
||||
// anchorCacheKey is the map key under which a computed grid of anchor centers is cached:
// one entry per combination of feature map height, width, stride and anchor count.
type anchorCacheKey struct {
	height, width, stride, anchors int
}
|
||||
|
||||
// onnxEngine runs face detection using an ONNX Runtime session and SCRFD model.
|
||||
type onnxEngine struct {
|
||||
session *onnxruntime.DynamicAdvancedSession
|
||||
inputName string
|
||||
outputNames []string
|
||||
inputWidth int
|
||||
inputHeight int
|
||||
featStrides []int
|
||||
numAnchors int
|
||||
useKps bool
|
||||
batched bool
|
||||
scoreThreshold float32
|
||||
nmsThreshold float32
|
||||
centerMu sync.Mutex
|
||||
centerCache map[anchorCacheKey][]float32
|
||||
}
|
||||
|
||||
// onnxOnce ensures the ONNX runtime environment is initialized at most once per process.
var onnxOnce sync.Once

// onnxInitErr stores the outcome of that one-time initialization (nil on success).
var onnxInitErr error

// onnxExecutableVar resolves the path of the running executable; it is a variable
// rather than a direct os.Executable call (presumably so tests can replace it — confirm).
var onnxExecutableVar = os.Executable
|
||||
|
||||
// ensureONNXRuntime loads the ONNX runtime shared library and initializes the global environment.
|
||||
func ensureONNXRuntime(libraryPath string) error {
|
||||
onnxOnce.Do(func() {
|
||||
candidates := onnxSharedLibraryCandidates(libraryPath)
|
||||
var errs []string
|
||||
|
||||
for _, candidate := range candidates {
|
||||
onnxruntime.SetSharedLibraryPath(candidate)
|
||||
|
||||
if err := onnxruntime.InitializeEnvironment(); err != nil {
|
||||
// Collect errors so we can surface meaningful diagnostics when all options fail.
|
||||
errs = append(errs, fmt.Sprintf("%s (%v)", candidate, err))
|
||||
continue
|
||||
}
|
||||
|
||||
// Successfully initialized; stop retrying.
|
||||
onnxInitErr = nil
|
||||
return
|
||||
}
|
||||
|
||||
if len(errs) == 0 {
|
||||
onnxInitErr = errors.New("faces: no ONNX runtime library candidates")
|
||||
return
|
||||
}
|
||||
|
||||
onnxInitErr = fmt.Errorf("faces: failed to load ONNX runtime: %s", strings.Join(errs, "; "))
|
||||
})
|
||||
|
||||
return onnxInitErr
|
||||
}
|
||||
|
||||
// onnxSharedLibraryCandidates lists library paths to try when loading the ONNX runtime.
|
||||
func onnxSharedLibraryCandidates(explicit string) []string {
|
||||
appendUnique := func(list []string, seen map[string]struct{}, values ...string) []string {
|
||||
for _, value := range values {
|
||||
if value == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[value]; ok {
|
||||
continue
|
||||
}
|
||||
list = append(list, value)
|
||||
seen[value] = struct{}{}
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{})
|
||||
candidates := make([]string, 0, 8)
|
||||
candidates = appendUnique(candidates, seen, explicit)
|
||||
candidates = appendUnique(candidates, seen,
|
||||
"libonnxruntime.so",
|
||||
"libonnxruntime.so.1",
|
||||
"onnxruntime.so",
|
||||
)
|
||||
|
||||
if exePath, err := onnxExecutableVar(); err == nil {
|
||||
exeDir := filepath.Dir(exePath)
|
||||
rootDir := filepath.Dir(exeDir)
|
||||
|
||||
candidates = appendUnique(candidates, seen,
|
||||
filepath.Join(exeDir, "libonnxruntime.so"),
|
||||
filepath.Join(exeDir, "lib", "libonnxruntime.so"),
|
||||
)
|
||||
|
||||
if rootDir != "" && rootDir != "." && rootDir != exeDir {
|
||||
candidates = appendUnique(candidates, seen, filepath.Join(rootDir, "lib", "libonnxruntime.so"))
|
||||
}
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// NewONNXEngine loads the SCRFD model and returns an ONNX-backed DetectionEngine.
|
||||
func NewONNXEngine(opts ONNXOptions) (DetectionEngine, error) {
|
||||
if opts.ModelPath == "" {
|
||||
return nil, fmt.Errorf("faces: missing ONNX model path")
|
||||
}
|
||||
|
||||
if _, err := os.Stat(opts.ModelPath); err != nil {
|
||||
return nil, fmt.Errorf("faces: %w", err)
|
||||
}
|
||||
|
||||
if opts.ScoreThreshold <= 0 {
|
||||
opts.ScoreThreshold = onnxDefaultScoreThreshold
|
||||
}
|
||||
|
||||
if opts.NMSThreshold <= 0 {
|
||||
opts.NMSThreshold = onnxDefaultNMSThreshold
|
||||
}
|
||||
|
||||
if err := ensureONNXRuntime(opts.LibraryPath); err != nil {
|
||||
return nil, fmt.Errorf("faces: %w", err)
|
||||
}
|
||||
|
||||
sessionOpts, err := onnxruntime.NewSessionOptions()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("faces: %w", err)
|
||||
}
|
||||
defer sessionOpts.Destroy()
|
||||
|
||||
threads := opts.Threads
|
||||
if threads == 0 {
|
||||
threads = runtime.NumCPU() / 2
|
||||
if threads < 1 {
|
||||
threads = 1
|
||||
}
|
||||
}
|
||||
|
||||
if err := sessionOpts.SetIntraOpNumThreads(threads); err != nil {
|
||||
return nil, fmt.Errorf("faces: configure intra-op threads: %w", err)
|
||||
}
|
||||
|
||||
if err := sessionOpts.SetInterOpNumThreads(threads); err != nil {
|
||||
return nil, fmt.Errorf("faces: configure inter-op threads: %w", err)
|
||||
}
|
||||
|
||||
if err := sessionOpts.SetGraphOptimizationLevel(onnxruntime.GraphOptimizationLevelEnableAll); err != nil {
|
||||
return nil, fmt.Errorf("faces: optimize session graph: %w", err)
|
||||
}
|
||||
|
||||
inputInfos, outputInfos, err := onnxruntime.GetInputOutputInfoWithOptions(opts.ModelPath, sessionOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("faces: load ONNX metadata: %w", err)
|
||||
}
|
||||
|
||||
if len(inputInfos) == 0 {
|
||||
return nil, fmt.Errorf("faces: ONNX model has no inputs")
|
||||
}
|
||||
|
||||
if len(outputInfos) == 0 {
|
||||
return nil, fmt.Errorf("faces: ONNX model has no outputs")
|
||||
}
|
||||
|
||||
inputName := inputInfos[0].Name
|
||||
inputDims := inputInfos[0].Dimensions
|
||||
|
||||
width := onnxDefaultInputSize
|
||||
height := onnxDefaultInputSize
|
||||
|
||||
if len(inputDims) >= 4 {
|
||||
if w := int(inputDims[len(inputDims)-1]); w > 0 {
|
||||
width = w
|
||||
}
|
||||
if h := int(inputDims[len(inputDims)-2]); h > 0 {
|
||||
height = h
|
||||
}
|
||||
}
|
||||
|
||||
outputNames := make([]string, len(outputInfos))
|
||||
for i, out := range outputInfos {
|
||||
outputNames[i] = out.Name
|
||||
}
|
||||
|
||||
fmc, numAnchors, useKps, batched, err := deriveONNXLayout(outputInfos)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
featStrides := stridesForFeatureMaps(fmc)
|
||||
|
||||
session, err := onnxruntime.NewDynamicAdvancedSession(opts.ModelPath, []string{inputName}, outputNames, sessionOpts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("faces: initialise ONNX session: %w", err)
|
||||
}
|
||||
|
||||
engine := &onnxEngine{
|
||||
session: session,
|
||||
inputName: inputName,
|
||||
outputNames: outputNames,
|
||||
inputWidth: width,
|
||||
inputHeight: height,
|
||||
featStrides: featStrides,
|
||||
numAnchors: numAnchors,
|
||||
useKps: useKps,
|
||||
batched: batched,
|
||||
scoreThreshold: opts.ScoreThreshold,
|
||||
nmsThreshold: opts.NMSThreshold,
|
||||
centerCache: make(map[anchorCacheKey][]float32),
|
||||
}
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
// deriveONNXLayout infers the number of feature map chains, anchors, and output layout from the model outputs.
|
||||
func deriveONNXLayout(outputs []onnxruntime.InputOutputInfo) (fmc, anchors int, useKps, batched bool, err error) {
|
||||
outCount := len(outputs)
|
||||
|
||||
switch outCount {
|
||||
case 6:
|
||||
fmc = 3
|
||||
anchors = 2
|
||||
case 9:
|
||||
fmc = 3
|
||||
anchors = 2
|
||||
useKps = true
|
||||
case 10:
|
||||
fmc = 5
|
||||
anchors = 1
|
||||
case 15:
|
||||
fmc = 5
|
||||
anchors = 1
|
||||
useKps = true
|
||||
default:
|
||||
return 0, 0, false, false, fmt.Errorf("faces: unsupported ONNX output count %d", outCount)
|
||||
}
|
||||
|
||||
dims := outputs[0].Dimensions
|
||||
if len(dims) == 3 {
|
||||
batched = true
|
||||
}
|
||||
|
||||
return fmc, anchors, useKps, batched, nil
|
||||
}
|
||||
|
||||
// stridesForFeatureMaps maps the number of feature maps to SCRFD's default strides:
// five strides (8 to 128) for fmc == 5, and the three strides 8, 16, 32 for any other value.
func stridesForFeatureMaps(fmc int) []int {
	switch fmc {
	case 5:
		return []int{8, 16, 32, 64, 128}
	default:
		return []int{8, 16, 32}
	}
}
|
||||
|
||||
func (o *onnxEngine) Name() string {
|
||||
return EngineONNX
|
||||
}
|
||||
|
||||
func (o *onnxEngine) Close() error {
|
||||
if o.session != nil {
|
||||
if err := o.session.Destroy(); err != nil {
|
||||
return err
|
||||
}
|
||||
o.session = nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Detect identifies faces in the provided image using the ONNX runtime session.
|
||||
func (o *onnxEngine) Detect(fileName string, findLandmarks bool, minSize int) (Faces, error) {
|
||||
file, err := os.Open(fileName)
|
||||
if err != nil {
|
||||
return Faces{}, err
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
img, _, err := image.Decode(file)
|
||||
if err != nil {
|
||||
return Faces{}, err
|
||||
}
|
||||
|
||||
width := img.Bounds().Dx()
|
||||
height := img.Bounds().Dy()
|
||||
if width == 0 || height == 0 {
|
||||
return Faces{}, fmt.Errorf("faces: invalid image dimensions")
|
||||
}
|
||||
|
||||
blob, detScale, err := o.buildBlob(img)
|
||||
if err != nil {
|
||||
return Faces{}, err
|
||||
}
|
||||
|
||||
shape := onnxruntime.Shape{1, 3, int64(o.inputHeight), int64(o.inputWidth)}
|
||||
tensor, err := onnxruntime.NewTensor(shape, blob)
|
||||
if err != nil {
|
||||
return Faces{}, fmt.Errorf("faces: create tensor: %w", err)
|
||||
}
|
||||
defer tensor.Destroy()
|
||||
|
||||
inputs := []onnxruntime.Value{tensor}
|
||||
outputs := make([]onnxruntime.Value, len(o.outputNames))
|
||||
if err := o.session.Run(inputs, outputs); err != nil {
|
||||
return Faces{}, fmt.Errorf("faces: run session: %w", err)
|
||||
}
|
||||
for _, out := range outputs {
|
||||
if out != nil {
|
||||
defer out.Destroy()
|
||||
}
|
||||
}
|
||||
|
||||
detections, err := o.parseDetections(outputs, detScale, width, height)
|
||||
if err != nil {
|
||||
return Faces{}, err
|
||||
}
|
||||
|
||||
filtered := nonMaxSuppression(detections, o.nmsThreshold)
|
||||
result := make(Faces, 0, len(filtered))
|
||||
|
||||
for _, det := range filtered {
|
||||
faceWidth := det.x2 - det.x1
|
||||
faceHeight := det.y2 - det.y1
|
||||
size := int(math.Max(float64(faceWidth), float64(faceHeight)))
|
||||
if size < minSize {
|
||||
continue
|
||||
}
|
||||
|
||||
row := int((det.y1 + det.y2) * 0.5)
|
||||
col := int((det.x1 + det.x2) * 0.5)
|
||||
score := int(math.Round(float64(det.score * 100)))
|
||||
if score > 100 {
|
||||
score = 100
|
||||
} else if score < 0 {
|
||||
score = 0
|
||||
}
|
||||
|
||||
f := Face{
|
||||
Rows: height,
|
||||
Cols: width,
|
||||
Score: score,
|
||||
Area: NewArea("face", row, col, size),
|
||||
}
|
||||
|
||||
result.Append(f)
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// buildBlob normalises the input image into the tensor layout expected by SCRFD.
|
||||
func (o *onnxEngine) buildBlob(img image.Image) ([]float32, float32, error) {
|
||||
inputWidth := o.inputWidth
|
||||
inputHeight := o.inputHeight
|
||||
|
||||
if inputWidth < 1 {
|
||||
inputWidth = onnxDefaultInputSize
|
||||
}
|
||||
|
||||
if inputHeight < 1 {
|
||||
inputHeight = onnxDefaultInputSize
|
||||
}
|
||||
|
||||
bounds := img.Bounds()
|
||||
width := bounds.Dx()
|
||||
height := bounds.Dy()
|
||||
|
||||
if width == 0 || height == 0 {
|
||||
return nil, 0, fmt.Errorf("faces: invalid image dimensions")
|
||||
}
|
||||
|
||||
imRatio := float32(height) / float32(width)
|
||||
modelRatio := float32(inputHeight) / float32(inputWidth)
|
||||
|
||||
newHeight := inputHeight
|
||||
newWidth := inputWidth
|
||||
if imRatio > modelRatio {
|
||||
newHeight = inputHeight
|
||||
newWidth = int(float32(newHeight) / imRatio)
|
||||
} else {
|
||||
newWidth = inputWidth
|
||||
newHeight = int(float32(newWidth) * imRatio)
|
||||
}
|
||||
|
||||
if newWidth < 1 {
|
||||
newWidth = 1
|
||||
}
|
||||
if newHeight < 1 {
|
||||
newHeight = 1
|
||||
}
|
||||
|
||||
resized := imaging.Resize(img, newWidth, newHeight, imaging.Linear)
|
||||
|
||||
planeSize := inputWidth * inputHeight
|
||||
blob := make([]float32, planeSize*3)
|
||||
|
||||
for y := 0; y < inputHeight; y++ {
|
||||
for x := 0; x < inputWidth; x++ {
|
||||
idx := y*inputWidth + x
|
||||
var r, g, b float32
|
||||
if x < newWidth && y < newHeight {
|
||||
cr, cg, cb, _ := resized.At(x, y).RGBA()
|
||||
r = float32(uint8(cr >> 8))
|
||||
g = float32(uint8(cg >> 8))
|
||||
b = float32(uint8(cb >> 8))
|
||||
}
|
||||
|
||||
blob[idx] = (r - onnxInputMean) / onnxInputStd
|
||||
blob[idx+planeSize] = (g - onnxInputMean) / onnxInputStd
|
||||
blob[idx+planeSize*2] = (b - onnxInputMean) / onnxInputStd
|
||||
}
|
||||
}
|
||||
|
||||
detScale := float32(newHeight) / float32(height)
|
||||
|
||||
return blob, detScale, nil
|
||||
}
|
||||
|
||||
// parseDetections decodes model outputs into bounding boxes in the original image space.
|
||||
func (o *onnxEngine) parseDetections(values []onnxruntime.Value, detScale float32, origWidth, origHeight int) ([]onnxDetection, error) {
|
||||
fmc := len(o.featStrides)
|
||||
detections := make([]onnxDetection, 0, 32)
|
||||
|
||||
for level, stride := range o.featStrides {
|
||||
scoreTensor, ok := values[level].(*onnxruntime.Tensor[float32])
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("faces: unexpected tensor type for scores")
|
||||
}
|
||||
|
||||
bboxTensor, ok := values[level+fmc].(*onnxruntime.Tensor[float32])
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("faces: unexpected tensor type for boxes")
|
||||
}
|
||||
|
||||
scores := scoreTensor.GetData()
|
||||
boxes := bboxTensor.GetData()
|
||||
|
||||
height := o.inputHeight / stride
|
||||
width := o.inputWidth / stride
|
||||
cells := height * width
|
||||
anchors := o.numAnchors
|
||||
expected := cells * anchors
|
||||
|
||||
switch {
|
||||
case len(scores) == expected:
|
||||
// already aligned
|
||||
case len(scores) == expected*2:
|
||||
trimmed := make([]float32, expected)
|
||||
copy(trimmed, scores[len(scores)-expected:])
|
||||
scores = trimmed
|
||||
default:
|
||||
return nil, fmt.Errorf("faces: unexpected score tensor size %d (expected %d)", len(scores), expected)
|
||||
}
|
||||
|
||||
if len(boxes) != expected*4 {
|
||||
return nil, fmt.Errorf("faces: mismatch between scores and boxes")
|
||||
}
|
||||
|
||||
centers := o.anchorCenters(height, width, stride, anchors)
|
||||
|
||||
for idx, score := range scores {
|
||||
if score < o.scoreThreshold {
|
||||
continue
|
||||
}
|
||||
|
||||
cx := centers[idx*2]
|
||||
cy := centers[idx*2+1]
|
||||
boxOffset := idx * 4
|
||||
left := boxes[boxOffset] * float32(stride)
|
||||
top := boxes[boxOffset+1] * float32(stride)
|
||||
right := boxes[boxOffset+2] * float32(stride)
|
||||
bottom := boxes[boxOffset+3] * float32(stride)
|
||||
|
||||
x1 := clampFloat32((cx-left)/detScale, 0, float32(origWidth))
|
||||
y1 := clampFloat32((cy-top)/detScale, 0, float32(origHeight))
|
||||
x2 := clampFloat32((cx+right)/detScale, 0, float32(origWidth))
|
||||
y2 := clampFloat32((cy+bottom)/detScale, 0, float32(origHeight))
|
||||
|
||||
if x2 <= x1 || y2 <= y1 {
|
||||
continue
|
||||
}
|
||||
|
||||
detections = append(detections, onnxDetection{
|
||||
x1: x1,
|
||||
y1: y1,
|
||||
x2: x2,
|
||||
y2: y2,
|
||||
score: score,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return detections, nil
|
||||
}
|
||||
|
||||
// anchorCenters returns cached anchor centers for the given feature map shape.
|
||||
func (o *onnxEngine) anchorCenters(height, width, stride, anchors int) []float32 {
|
||||
key := anchorCacheKey{height: height, width: width, stride: stride, anchors: anchors}
|
||||
|
||||
o.centerMu.Lock()
|
||||
cached, ok := o.centerCache[key]
|
||||
if ok {
|
||||
o.centerMu.Unlock()
|
||||
return cached
|
||||
}
|
||||
|
||||
centers := make([]float32, height*width*anchors*2)
|
||||
idx := 0
|
||||
for y := 0; y < height; y++ {
|
||||
cy := float32(y * stride)
|
||||
for x := 0; x < width; x++ {
|
||||
cx := float32(x * stride)
|
||||
for a := 0; a < anchors; a++ {
|
||||
centers[idx] = cx
|
||||
centers[idx+1] = cy
|
||||
idx += 2
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
o.centerCache[key] = centers
|
||||
o.centerMu.Unlock()
|
||||
return centers
|
||||
}
|
||||
|
||||
// onnxDetection describes one detection candidate: a box in image coordinates
// together with the score assigned to it.
type onnxDetection struct {
	// x1, y1 is one box corner and x2, y2 the opposite one.
	x1, y1 float32
	x2, y2 float32

	// score is the detection score of the candidate.
	score float32
}
|
||||
|
||||
// nonMaxSuppression filters overlapping detection boxes using IoU thresholding.
|
||||
func nonMaxSuppression(boxes []onnxDetection, threshold float32) []onnxDetection {
|
||||
if len(boxes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(boxes, func(i, j int) bool {
|
||||
return boxes[i].score > boxes[j].score
|
||||
})
|
||||
|
||||
picked := make([]onnxDetection, 0, len(boxes))
|
||||
suppressed := make([]bool, len(boxes))
|
||||
|
||||
for i := 0; i < len(boxes); i++ {
|
||||
if suppressed[i] {
|
||||
continue
|
||||
}
|
||||
|
||||
current := boxes[i]
|
||||
picked = append(picked, current)
|
||||
|
||||
for j := i + 1; j < len(boxes); j++ {
|
||||
if suppressed[j] {
|
||||
continue
|
||||
}
|
||||
|
||||
if iou(current, boxes[j]) > threshold {
|
||||
suppressed[j] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return picked
|
||||
}
|
||||
|
||||
// iou calculates the intersection-over-union score for two detections.
|
||||
func iou(a, b onnxDetection) float32 {
|
||||
x1 := float32(math.Max(float64(a.x1), float64(b.x1)))
|
||||
y1 := float32(math.Max(float64(a.y1), float64(b.y1)))
|
||||
x2 := float32(math.Min(float64(a.x2), float64(b.x2)))
|
||||
y2 := float32(math.Min(float64(a.y2), float64(b.y2)))
|
||||
|
||||
w := x2 - x1
|
||||
h := y2 - y1
|
||||
if w <= 0 || h <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
inter := w * h
|
||||
areaA := (a.x2 - a.x1) * (a.y2 - a.y1)
|
||||
areaB := (b.x2 - b.x1) * (b.y2 - b.y1)
|
||||
union := areaA + areaB - inter
|
||||
if union <= 0 {
|
||||
return 0
|
||||
}
|
||||
|
||||
return inter / union
|
||||
}
|
||||
|
||||
// clampFloat32 limits v to the inclusive range [min, max]: values below min yield min,
// values above max yield max, and everything else is returned unchanged.
func clampFloat32(v, min, max float32) float32 {
	switch {
	case v < min:
		return min
	case v > max:
		return max
	default:
		return v
	}
}
|
||||
78
internal/ai/face/engine_onnx_test.go
Normal file
78
internal/ai/face/engine_onnx_test.go
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
package face
|
||||
|
||||
import (
|
||||
"image"
|
||||
"image/color"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
onnxruntime "github.com/yalue/onnxruntime_go"
|
||||
)
|
||||
|
||||
// TestONNXSharedLibraryCandidates_Defaults verifies default search ordering when no explicit path is provided.
|
||||
func TestONNXSharedLibraryCandidates_Defaults(t *testing.T) {
|
||||
t.Cleanup(func() { onnxExecutableVar = os.Executable })
|
||||
onnxExecutableVar = func() (string, error) {
|
||||
return filepath.Join("/opt/photoprism", "bin", "photoprism"), nil
|
||||
}
|
||||
|
||||
candidates := onnxSharedLibraryCandidates("")
|
||||
require.NotEmpty(t, candidates)
|
||||
require.Equal(t, "libonnxruntime.so", candidates[0])
|
||||
require.Contains(t, candidates, filepath.Join("/opt/photoprism", "lib", "libonnxruntime.so"))
|
||||
}
|
||||
|
||||
// TestONNXSharedLibraryCandidates_ExplicitFirst ensures explicit paths remain the first candidate.
|
||||
func TestONNXSharedLibraryCandidates_ExplicitFirst(t *testing.T) {
|
||||
t.Cleanup(func() { onnxExecutableVar = os.Executable })
|
||||
onnxExecutableVar = func() (string, error) { return "/tmp/photoprism", nil }
|
||||
|
||||
explicit := "/custom/libonnxruntime.so"
|
||||
candidates := onnxSharedLibraryCandidates(explicit)
|
||||
require.NotEmpty(t, candidates)
|
||||
require.Equal(t, explicit, candidates[0])
|
||||
}
|
||||
|
||||
func TestDeriveONNXLayout(t *testing.T) {
|
||||
outputs := make([]onnxruntime.InputOutputInfo, 9)
|
||||
outputs[0] = onnxruntime.InputOutputInfo{Dimensions: onnxruntime.Shape{1, 3, 3}}
|
||||
|
||||
fmc, anchors, useKps, batched, err := deriveONNXLayout(outputs)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, fmc)
|
||||
require.Equal(t, 2, anchors)
|
||||
require.True(t, useKps)
|
||||
require.True(t, batched)
|
||||
|
||||
_, _, _, _, err = deriveONNXLayout(make([]onnxruntime.InputOutputInfo, 1))
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestStridesForFeatureMaps(t *testing.T) {
|
||||
require.Equal(t, []int{8, 16, 32, 64, 128}, stridesForFeatureMaps(5))
|
||||
require.Equal(t, []int{8, 16, 32}, stridesForFeatureMaps(3))
|
||||
}
|
||||
|
||||
func TestONNXEngineAnchorCentersCaches(t *testing.T) {
|
||||
engine := &onnxEngine{centerCache: make(map[anchorCacheKey][]float32)}
|
||||
centers1 := engine.anchorCenters(2, 2, 8, 2)
|
||||
require.Len(t, centers1, 16)
|
||||
centers2 := engine.anchorCenters(2, 2, 8, 2)
|
||||
// The cache should return the same backing array.
|
||||
require.Equal(t, ¢ers1[0], ¢ers2[0])
|
||||
}
|
||||
|
||||
func TestONNXEngineBuildBlob(t *testing.T) {
|
||||
engine := &onnxEngine{inputWidth: 4, inputHeight: 4}
|
||||
img := image.NewRGBA(image.Rect(0, 0, 1, 1))
|
||||
img.Set(0, 0, color.RGBA{R: 255, G: 0, B: 0, A: 255})
|
||||
|
||||
blob, scale, err := engine.buildBlob(img)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, blob, 4*4*3)
|
||||
require.InDelta(t, (255-onnxInputMean)/onnxInputStd, blob[0], 1e-3)
|
||||
require.InDelta(t, (0-onnxInputMean)/onnxInputStd, blob[16], 1e-3)
|
||||
require.Equal(t, float32(4), scale)
|
||||
}
|
||||
21
internal/ai/face/engine_test.go
Normal file
21
internal/ai/face/engine_test.go
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
package face
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestParseEngine(t *testing.T) {
|
||||
cases := map[string]EngineName{
|
||||
"": EngineAuto,
|
||||
"AUTO": EngineAuto,
|
||||
"pigo": EnginePigo,
|
||||
" PIGO ": EnginePigo,
|
||||
"onnx": EngineONNX,
|
||||
"OnNx": EngineONNX,
|
||||
"unknown": EngineAuto,
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
if got := ParseEngine(input); got != expected {
|
||||
t.Fatalf("ParseEngine(%q) = %q, expected %q", input, got, expected)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -15,7 +15,8 @@ func FilterModels(models []string, when RunType, allow func(ModelType, RunType)
|
|||
filtered := make([]string, 0, len(models))
|
||||
|
||||
for _, name := range models {
|
||||
modelType := ModelType(strings.TrimSpace(name))
|
||||
modelType := strings.TrimSpace(name)
|
||||
|
||||
if modelType == "" {
|
||||
continue
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,6 +49,10 @@ var FacesCommands = &cli.Command{
|
|||
Aliases: []string{"f"},
|
||||
Usage: "removes all people and faces",
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "engine",
|
||||
Usage: "regenerate markers using detection engine `NAME` (auto, pigo, onnx)",
|
||||
},
|
||||
},
|
||||
Action: facesResetAction,
|
||||
},
|
||||
|
|
@ -169,14 +173,21 @@ func facesResetAction(ctx *cli.Context) error {
|
|||
|
||||
w := get.Faces()
|
||||
|
||||
if err := w.Reset(); err != nil {
|
||||
return err
|
||||
} else {
|
||||
elapsed := time.Since(start)
|
||||
engine := strings.TrimSpace(ctx.String("engine"))
|
||||
|
||||
log.Infof("completed in %s", elapsed)
|
||||
if engine != "" {
|
||||
if err := w.ResetAndReindex(engine); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if err := w.Reset(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
elapsed := time.Since(start)
|
||||
log.Infof("completed in %s", elapsed)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -351,6 +351,15 @@ func (c *Config) Propagate() {
|
|||
face.ClusterDist = c.FaceClusterDist()
|
||||
face.MatchDist = c.FaceMatchDist()
|
||||
face.DetectionAngles = c.FaceAngles()
|
||||
if err := face.ConfigureEngine(face.EngineSettings{
|
||||
Name: c.FaceEngine(),
|
||||
ONNX: face.ONNXOptions{
|
||||
ModelPath: c.FaceEngineModelPath(),
|
||||
Threads: c.FaceEngineThreads(),
|
||||
},
|
||||
}); err != nil {
|
||||
log.Warnf("faces: %s (configure engine)", err)
|
||||
}
|
||||
|
||||
// Set default theme and locale.
|
||||
customize.DefaultTheme = c.DefaultTheme()
|
||||
|
|
|
|||
|
|
@ -2,8 +2,12 @@ package config
|
|||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
)
|
||||
|
||||
// FaceSize returns the face size threshold in pixels.
|
||||
|
|
@ -110,3 +114,140 @@ func (c *Config) FaceAngles() []float64 {
|
|||
|
||||
return angles
|
||||
}
|
||||
|
||||
// FaceEngine returns the configured face detection engine name.
|
||||
func (c *Config) FaceEngine() string {
|
||||
if c == nil {
|
||||
return face.EnginePigo
|
||||
} else if c.options.FaceEngine == face.EnginePigo || c.options.FaceEngine == face.EngineONNX {
|
||||
return c.options.FaceEngine
|
||||
}
|
||||
|
||||
desired := face.ParseEngine(c.options.FaceEngine)
|
||||
modelPath := c.FaceEngineModelPath()
|
||||
|
||||
if desired == face.EngineAuto {
|
||||
if modelPath != "" {
|
||||
if _, err := os.Stat(modelPath); err == nil {
|
||||
desired = face.EngineONNX
|
||||
} else {
|
||||
desired = face.EnginePigo
|
||||
}
|
||||
} else {
|
||||
desired = face.EnginePigo
|
||||
}
|
||||
|
||||
c.options.FaceEngine = desired
|
||||
}
|
||||
|
||||
return desired
|
||||
}
|
||||
|
||||
// FaceEngineRunType returns the configured run type for the face detection engine.
|
||||
func (c *Config) FaceEngineRunType() vision.RunType {
|
||||
if c == nil {
|
||||
return "auto"
|
||||
}
|
||||
|
||||
c.options.FaceEngineRun = vision.ParseRunType(c.options.FaceEngineRun)
|
||||
|
||||
if c.options.FaceEngineRun == vision.RunAuto {
|
||||
if c.FaceEngine() == face.EngineONNX && c.FaceEngineThreads() < 2 {
|
||||
c.options.FaceEngineRun = vision.RunOnDemand
|
||||
}
|
||||
}
|
||||
|
||||
if c.options.FaceEngineRun == vision.RunAuto {
|
||||
return "auto"
|
||||
}
|
||||
|
||||
return c.options.FaceEngineRun
|
||||
}
|
||||
|
||||
// FaceEngineShouldRun reports whether the face detection engine should execute in the
|
||||
// specified scheduling context.
|
||||
func (c *Config) FaceEngineShouldRun(when vision.RunType) bool {
|
||||
if c == nil || c.DisableFaces() {
|
||||
return false
|
||||
}
|
||||
|
||||
run := c.FaceEngineRunType()
|
||||
when = vision.ParseRunType(when)
|
||||
|
||||
switch run {
|
||||
case vision.RunNever:
|
||||
return false
|
||||
case vision.RunManual:
|
||||
return when == vision.RunManual
|
||||
case vision.RunAlways:
|
||||
return when != vision.RunNever
|
||||
case vision.RunNewlyIndexed:
|
||||
return when == vision.RunManual || when == vision.RunNewlyIndexed || when == vision.RunOnDemand
|
||||
case vision.RunOnDemand:
|
||||
return when == vision.RunAuto || when == vision.RunManual || when == vision.RunNewlyIndexed || when == vision.RunOnDemand || when == vision.RunOnSchedule
|
||||
case vision.RunOnSchedule:
|
||||
return when == vision.RunAuto || when == vision.RunManual || when == vision.RunOnSchedule || when == vision.RunOnDemand
|
||||
case vision.RunOnIndex:
|
||||
return when == vision.RunManual || when == vision.RunOnIndex
|
||||
case vision.RunAuto:
|
||||
fallthrough
|
||||
default:
|
||||
switch when {
|
||||
case vision.RunAuto, vision.RunManual, vision.RunOnDemand, vision.RunOnSchedule, vision.RunNewlyIndexed, vision.RunOnIndex:
|
||||
return true
|
||||
case vision.RunAlways:
|
||||
return true
|
||||
case vision.RunNever:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// FaceEngineRetry controls whether detection retries at a higher resolution should be performed.
|
||||
func (c *Config) FaceEngineRetry() bool {
|
||||
if c == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
return c.FaceEngine() == face.EnginePigo && c.IndexWorkers() > 2
|
||||
}
|
||||
|
||||
// FaceEngineThreads returns the configured thread count for ONNX inference.
|
||||
func (c *Config) FaceEngineThreads() int {
|
||||
if c == nil {
|
||||
return 1
|
||||
} else if c.options.FaceEngineThreads <= 0 {
|
||||
threads := runtime.NumCPU() / 2
|
||||
if threads < 1 {
|
||||
threads = 1
|
||||
}
|
||||
|
||||
c.options.FaceEngineThreads = threads
|
||||
|
||||
return threads
|
||||
}
|
||||
|
||||
return c.options.FaceEngineThreads
|
||||
}
|
||||
|
||||
// FaceEngineModelPath returns the absolute path to the bundled SCRFD ONNX detector.
|
||||
func (c *Config) FaceEngineModelPath() string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
dir := filepath.Join(c.ModelsPath(), "scrfs")
|
||||
primary := filepath.Join(dir, face.DefaultONNXModelFilename)
|
||||
if _, err := os.Stat(primary); err == nil {
|
||||
return primary
|
||||
}
|
||||
|
||||
legacy := filepath.Join(dir, "scrfd_500m_bnkps_shape640x640.onnx")
|
||||
if _, err := os.Stat(legacy); err == nil {
|
||||
return legacy
|
||||
}
|
||||
|
||||
return primary
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,11 +2,17 @@ package config
|
|||
|
||||
import (
|
||||
"math"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
)
|
||||
|
||||
func TestConfig_FaceSize(t *testing.T) {
|
||||
|
|
@ -91,3 +97,100 @@ func TestConfig_FaceAngles(t *testing.T) {
|
|||
c.options.FaceAngles = []float64{math.Pi + 0.1, math.NaN(), 4}
|
||||
assert.Equal(t, face.DefaultAngles, c.FaceAngles())
|
||||
}
|
||||
|
||||
func TestConfig_FaceEngineShouldRun(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
|
||||
assert.True(t, c.FaceEngineShouldRun(vision.RunOnIndex))
|
||||
assert.True(t, c.FaceEngineShouldRun(vision.RunNewlyIndexed))
|
||||
assert.True(t, c.FaceEngineShouldRun(vision.RunManual))
|
||||
|
||||
c.options.FaceEngineRun = string(vision.RunOnIndex)
|
||||
assert.True(t, c.FaceEngineShouldRun(vision.RunOnIndex))
|
||||
assert.False(t, c.FaceEngineShouldRun(vision.RunNewlyIndexed))
|
||||
|
||||
c.options.FaceEngineRun = string(vision.RunNever)
|
||||
assert.False(t, c.FaceEngineShouldRun(vision.RunOnIndex))
|
||||
assert.False(t, c.FaceEngineShouldRun(vision.RunNewlyIndexed))
|
||||
|
||||
c.options.FaceEngineRun = string(vision.RunManual)
|
||||
assert.True(t, c.FaceEngineShouldRun(vision.RunManual))
|
||||
assert.False(t, c.FaceEngineShouldRun(vision.RunOnDemand))
|
||||
|
||||
c.options.DisableFaces = true
|
||||
assert.False(t, c.FaceEngineShouldRun(vision.RunOnIndex))
|
||||
}
|
||||
|
||||
func TestConfig_FaceEngine(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
tempModels := t.TempDir()
|
||||
c.options.ModelsPath = tempModels
|
||||
c.options.FaceEngine = face.EnginePigo
|
||||
|
||||
assert.Equal(t, face.EnginePigo, c.FaceEngine())
|
||||
|
||||
modelDir := filepath.Join(tempModels, "scrfs")
|
||||
require.NoError(t, os.MkdirAll(modelDir, 0o755))
|
||||
modelFile := filepath.Join(modelDir, face.DefaultONNXModelFilename)
|
||||
require.NoError(t, os.WriteFile(modelFile, []byte("onnx"), 0o644))
|
||||
|
||||
c.options.FaceEngine = face.EngineAuto
|
||||
assert.Equal(t, face.EngineONNX, c.FaceEngine())
|
||||
|
||||
c.options.FaceEngine = face.EnginePigo
|
||||
assert.Equal(t, face.EnginePigo, c.FaceEngine())
|
||||
|
||||
c.options.FaceEngine = face.EngineONNX
|
||||
assert.Equal(t, face.EngineONNX, c.FaceEngine())
|
||||
}
|
||||
|
||||
func TestConfig_FaceEngineRunType(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
assert.Equal(t, "auto", c.FaceEngineRunType())
|
||||
assert.Equal(t, "", c.options.FaceEngineRun)
|
||||
|
||||
c.options.FaceEngineRun = vision.RunOnDemand
|
||||
assert.Equal(t, vision.RunOnDemand, c.FaceEngineRunType())
|
||||
|
||||
c.options.FaceEngineRun = vision.RunAuto
|
||||
c.options.FaceEngine = face.EngineONNX
|
||||
c.options.FaceEngineThreads = 1
|
||||
assert.Equal(t, "on-demand", c.FaceEngineRunType())
|
||||
assert.Equal(t, "on-demand", vision.ParseRunType(c.options.FaceEngineRun))
|
||||
|
||||
c.options.FaceEngineThreads = 4
|
||||
c.options.FaceEngineRun = vision.RunAuto
|
||||
assert.Equal(t, "auto", c.FaceEngineRunType())
|
||||
assert.Equal(t, "", c.options.FaceEngineRun)
|
||||
}
|
||||
|
||||
func TestConfig_FaceEngineRetry(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
assert.False(t, c.FaceEngineRetry())
|
||||
|
||||
c.options.FaceEngineRetry = false
|
||||
assert.False(t, c.FaceEngineRetry())
|
||||
}
|
||||
|
||||
func TestConfig_FaceEngineThreads(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
expected := runtime.NumCPU() / 2
|
||||
if expected < 1 {
|
||||
expected = 1
|
||||
}
|
||||
assert.Equal(t, expected, c.FaceEngineThreads())
|
||||
|
||||
c.options.FaceEngineThreads = 8
|
||||
assert.Equal(t, 8, c.FaceEngineThreads())
|
||||
}
|
||||
|
||||
func TestConfig_FaceEngineModelPath(t *testing.T) {
|
||||
c := NewConfig(CliTestContext())
|
||||
path := c.FaceEngineModelPath()
|
||||
assert.Contains(t, path, "scrfs")
|
||||
expected := filepath.Join(c.ModelsPath(), "scrfs", face.DefaultONNXModelFilename)
|
||||
if strings.HasSuffix(path, "scrfd_500m_bnkps_shape640x640.onnx") {
|
||||
expected = filepath.Join(c.ModelsPath(), "scrfs", "scrfd_500m_bnkps_shape640x640.onnx")
|
||||
}
|
||||
assert.Equal(t, expected, path)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1167,6 +1167,23 @@ var Flags = CliFlags{
|
|||
Usage: "flags newly added pictures as private if they might be offensive (requires TensorFlow)",
|
||||
EnvVars: EnvVars("DETECT_NSFW"),
|
||||
}}, {
|
||||
Flag: &cli.StringFlag{
|
||||
Name: "face-engine",
|
||||
Usage: "face detection engine `NAME` (auto, pigo, onnx)",
|
||||
Value: face.EngineAuto,
|
||||
EnvVars: EnvVars("FACE_ENGINE"),
|
||||
}}, {
|
||||
Flag: &cli.StringFlag{
|
||||
Name: "face-engine-run",
|
||||
Usage: "face detection run `MODE` (auto, never, manual, newly-indexed, on-demand, on-index, on-schedule, always)",
|
||||
Value: "auto",
|
||||
EnvVars: EnvVars("FACE_ENGINE_RUN"),
|
||||
}}, {
|
||||
Flag: &cli.IntFlag{
|
||||
Name: "face-engine-threads",
|
||||
Usage: "face detection thread `COUNT` (0 uses half the available CPU cores)",
|
||||
EnvVars: EnvVars("FACE_ENGINE_THREADS"),
|
||||
}}, {
|
||||
Flag: &cli.IntFlag{
|
||||
Name: "face-size",
|
||||
Usage: "minimum size of faces in `PIXELS` (20-10000)",
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ import (
|
|||
"github.com/urfave/cli/v2"
|
||||
"gopkg.in/yaml.v2"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
|
@ -229,6 +230,10 @@ type Options struct {
|
|||
VisionSchedule string `yaml:"VisionSchedule" json:"VisionSchedule" flag:"vision-schedule"`
|
||||
VisionFilter string `yaml:"VisionFilter" json:"VisionFilter" flag:"vision-filter"`
|
||||
DetectNSFW bool `yaml:"DetectNSFW" json:"DetectNSFW" flag:"detect-nsfw"`
|
||||
FaceEngine string `yaml:"FaceEngine" json:"-" flag:"face-engine"`
|
||||
FaceEngineRun string `yaml:"FaceEngineRun" json:"-" flag:"face-engine-run"`
|
||||
FaceEngineRetry bool `yaml:"-" json:"-" flag:"-"`
|
||||
FaceEngineThreads int `yaml:"FaceEngineThreads" json:"-" flag:"face-engine-threads"`
|
||||
FaceSize int `yaml:"-" json:"-" flag:"face-size"`
|
||||
FaceScore float64 `yaml:"-" json:"-" flag:"face-score"`
|
||||
FaceAngles []float64 `yaml:"-" json:"-" flag:"face-angle"`
|
||||
|
|
@ -250,7 +255,7 @@ type Options struct {
|
|||
// 2. ApplyCliContext: Which comes after Load and overrides
|
||||
// any previous options giving an option two override file configs through the CLI.
|
||||
func NewOptions(ctx *cli.Context) *Options {
|
||||
c := &Options{}
|
||||
c := &Options{FaceEngine: face.EngineAuto}
|
||||
|
||||
// Has context?
|
||||
if ctx == nil {
|
||||
|
|
|
|||
|
|
@ -283,6 +283,10 @@ func (c *Config) Report() (rows [][]string, cols []string) {
|
|||
{"detect-nsfw", fmt.Sprintf("%t", c.DetectNSFW())},
|
||||
|
||||
// Facial Recognition.
|
||||
{"face-engine", c.FaceEngine()},
|
||||
{"face-engine-run", c.FaceEngineRunType()},
|
||||
{"face-engine-retry", fmt.Sprintf("%t", c.FaceEngineRetry())},
|
||||
{"face-engine-threads", fmt.Sprintf("%d", c.FaceEngineThreads())},
|
||||
{"face-size", fmt.Sprintf("%d", c.FaceSize())},
|
||||
{"face-score", fmt.Sprintf("%f", c.FaceScore())},
|
||||
{"face-angle", fmt.Sprintf("%v", c.FaceAngles())},
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ var OptionsReportSections = []ReportSection{
|
|||
{Start: "PHOTOPRISM_THUMB_LIBRARY", Title: "Preview Images"},
|
||||
{Start: "PHOTOPRISM_JPEG_QUALITY", Title: "Image Quality"},
|
||||
{Start: "PHOTOPRISM_VISION_YAML", Title: "Computer Vision"},
|
||||
{Start: "PHOTOPRISM_FACE_SIZE", Title: "Face Recognition",
|
||||
{Start: "PHOTOPRISM_FACE_ENGINE", Title: "Face Recognition",
|
||||
Info: faceFlagsInfo},
|
||||
{Start: "PHOTOPRISM_PID_FILENAME", Title: "Daemon Mode",
|
||||
Info: "If you start the server as a *daemon* in the background, you can additionally specify a filename for the log and the process ID:"},
|
||||
|
|
@ -61,6 +61,7 @@ var YamlReportSections = []ReportSection{
|
|||
{Start: "ThumbLibrary", Title: "Preview Images"},
|
||||
{Start: "JpegQuality", Title: "Image Quality"},
|
||||
{Start: "VisionYaml", Title: "Computer Vision"},
|
||||
{Start: "FaceEngine", Title: "Face Recognition"},
|
||||
{Start: "PIDFilename", Title: "Daemon Mode",
|
||||
Info: "If you start the server as a *daemon* in the background, you can additionally specify a filename for the log and the process ID:"},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,10 +2,26 @@ package photoprism
|
|||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/dustin/go-humanize/english"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity/query"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
var runFacesReindex = func(conf *config.Config, opt IndexOptions) (fs.Done, int, error) {
|
||||
index := NewIndex(conf, NewConvert(conf), NewFiles(), NewPhotos())
|
||||
if index == nil {
|
||||
return nil, 0, fmt.Errorf("faces: index service unavailable")
|
||||
}
|
||||
|
||||
found, updated := index.Start(opt)
|
||||
return found, updated, nil
|
||||
}
|
||||
|
||||
// Reset removes automatically added face clusters, marker matches, and dangling subjects.
|
||||
func (w *Faces) Reset() (err error) {
|
||||
// Remove automatically added subject and face references from the markers table.
|
||||
|
|
@ -31,3 +47,54 @@ func (w *Faces) Reset() (err error) {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ResetAndReindex resets face data and optionally regenerates markers with the specified engine.
|
||||
|
||||
func (w *Faces) ResetAndReindex(engine string) error {
|
||||
trimmed := strings.TrimSpace(engine)
|
||||
lowered := strings.ToLower(trimmed)
|
||||
if lowered != "" {
|
||||
parsed := face.ParseEngine(lowered)
|
||||
if parsed == face.EngineAuto && !strings.EqualFold(trimmed, string(face.EngineAuto)) {
|
||||
return fmt.Errorf("faces: unsupported detection engine %q", engine)
|
||||
}
|
||||
}
|
||||
|
||||
if err := w.Reset(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if lowered == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.conf == nil {
|
||||
return fmt.Errorf("faces: configuration not available")
|
||||
}
|
||||
|
||||
engineName := face.ParseEngine(lowered)
|
||||
w.conf.Options().FaceEngine = engineName
|
||||
|
||||
if err := face.ConfigureEngine(face.EngineSettings{
|
||||
Name: w.conf.FaceEngine(),
|
||||
ONNX: face.ONNXOptions{
|
||||
ModelPath: w.conf.FaceEngineModelPath(),
|
||||
Threads: w.conf.FaceEngineThreads(),
|
||||
},
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
convert := w.conf.Settings().Index.Convert && w.conf.SidecarWritable()
|
||||
opt := IndexOptionsFacesOnly()
|
||||
opt.Convert = convert
|
||||
|
||||
found, updated, err := runFacesReindex(w.conf, opt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("faces: regenerated %s using %s engine (%s scanned)", english.Plural(updated, "file", "files"), w.conf.FaceEngine(), english.Plural(len(found), "file", "files"))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,7 +3,11 @@ package photoprism
|
|||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/pkg/fs"
|
||||
)
|
||||
|
||||
func TestFaces_Reset(t *testing.T) {
|
||||
|
|
@ -17,3 +21,34 @@ func TestFaces_Reset(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFaces_ResetAndReindex_InvalidEngine(t *testing.T) {
|
||||
c := config.TestConfig()
|
||||
m := NewFaces(c)
|
||||
|
||||
err := m.ResetAndReindex("invalid")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestFaces_ResetAndReindex_Pigo(t *testing.T) {
|
||||
defer func(prev func(*config.Config, IndexOptions) (fs.Done, int, error)) {
|
||||
runFacesReindex = prev
|
||||
}(runFacesReindex)
|
||||
|
||||
called := false
|
||||
var received IndexOptions
|
||||
runFacesReindex = func(conf *config.Config, opt IndexOptions) (fs.Done, int, error) {
|
||||
called = true
|
||||
received = opt
|
||||
return fs.Done{}, 0, nil
|
||||
}
|
||||
|
||||
c := config.TestConfig()
|
||||
m := NewFaces(c)
|
||||
|
||||
err := m.ResetAndReindex(face.EnginePigo)
|
||||
require.NoError(t, err)
|
||||
require.True(t, called)
|
||||
require.True(t, received.FacesOnly)
|
||||
require.Equal(t, face.EnginePigo, c.FaceEngine())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ func NewIndex(conf *config.Config, convert *Convert, files *Files, photos *Photo
|
|||
convert: convert,
|
||||
files: files,
|
||||
photos: photos,
|
||||
findFaces: !conf.DisableFaces(),
|
||||
findFaces: conf.FaceEngineShouldRun(vision.RunOnIndex),
|
||||
findLabels: conf.VisionModelShouldRun(vision.ModelTypeLabels, vision.RunOnIndex),
|
||||
detectNsfw: conf.VisionModelShouldRun(vision.ModelTypeNsfw, vision.RunOnIndex),
|
||||
}
|
||||
|
|
@ -57,6 +57,21 @@ func NewIndex(conf *config.Config, convert *Convert, files *Files, photos *Photo
|
|||
return i
|
||||
}
|
||||
|
||||
// configureFaceDetection updates the face detection flag for a given indexing run.
|
||||
func (ind *Index) configureFaceDetection(o IndexOptions) {
|
||||
if ind == nil || ind.conf == nil {
|
||||
ind.findFaces = false
|
||||
return
|
||||
}
|
||||
|
||||
faceRun := vision.RunOnIndex
|
||||
if o.FacesOnly {
|
||||
faceRun = vision.RunManual
|
||||
}
|
||||
|
||||
ind.findFaces = ind.conf.FaceEngineShouldRun(faceRun)
|
||||
}
|
||||
|
||||
func (ind *Index) shouldFlagPrivate(labels classify.Labels) bool {
|
||||
if ind == nil || ind.conf == nil || !ind.conf.DetectNSFW() {
|
||||
return false
|
||||
|
|
@ -103,6 +118,8 @@ func (ind *Index) Start(o IndexOptions) (found fs.Done, updated int) {
|
|||
return found, updated
|
||||
}
|
||||
|
||||
ind.configureFaceDetection(o)
|
||||
|
||||
originalsPath := ind.originalsPath()
|
||||
optionsPath := filepath.Join(originalsPath, o.Path)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,26 +1,35 @@
|
|||
package photoprism
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/dustin/go-humanize/english"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/face"
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/internal/thumb"
|
||||
"github.com/photoprism/photoprism/pkg/clean"
|
||||
)
|
||||
|
||||
// Faces finds faces in JPEG media files and returns them.
|
||||
func (ind *Index) Faces(jpeg *MediaFile, expected int) face.Faces {
|
||||
// DetectFaces finds faces in JPEG media files and returns them.
|
||||
func DetectFaces(jpeg *MediaFile, expected int) (face.Faces, error) {
|
||||
if jpeg == nil {
|
||||
return face.Faces{}
|
||||
return face.Faces{}, fmt.Errorf("missing media file")
|
||||
}
|
||||
|
||||
engine := face.ActiveEngine()
|
||||
engineName := ""
|
||||
if engine != nil {
|
||||
engineName = engine.Name()
|
||||
}
|
||||
|
||||
var thumbSize thumb.Name
|
||||
|
||||
// Select best thumbnail depending on configured size.
|
||||
if Config().ThumbSizePrecached() < 1280 {
|
||||
if engineName == face.EngineONNX {
|
||||
thumbSize = thumb.Fit720
|
||||
} else if Config().ThumbSizePrecached() < 1280 {
|
||||
thumbSize = thumb.Fit720
|
||||
} else {
|
||||
thumbSize = thumb.Fit1280
|
||||
|
|
@ -30,23 +39,26 @@ func (ind *Index) Faces(jpeg *MediaFile, expected int) face.Faces {
|
|||
|
||||
if err != nil {
|
||||
log.Debugf("vision: %s in %s (detect faces)", err, clean.Log(jpeg.BaseName()))
|
||||
return face.Faces{}
|
||||
return face.Faces{}, err
|
||||
}
|
||||
|
||||
if thumbName == "" {
|
||||
log.Debugf("vision: thumb %s not found in %s (detect faces)", thumbSize, clean.Log(jpeg.BaseName()))
|
||||
return face.Faces{}
|
||||
return face.Faces{}, fmt.Errorf("thumbnail %s not found", thumbSize)
|
||||
}
|
||||
|
||||
start := time.Now()
|
||||
var detectErr error
|
||||
allowRetry := Config().FaceEngineRetry() && (engineName == "" || engineName == face.EnginePigo)
|
||||
|
||||
faces, err := vision.Faces(thumbName, Config().FaceSize(), true, expected)
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("vision: %s in %s (detect faces)", err, clean.Log(jpeg.BaseName()))
|
||||
detectErr = err
|
||||
}
|
||||
|
||||
if thumbSize != thumb.Fit1280 {
|
||||
if allowRetry && thumbSize != thumb.Fit1280 {
|
||||
needRetry := len(faces) == 0
|
||||
|
||||
if !needRetry && expected > 0 && len(faces) < expected {
|
||||
|
|
@ -64,9 +76,11 @@ func (ind *Index) Faces(jpeg *MediaFile, expected int) face.Faces {
|
|||
log.Debugf("vision: thumb %s not found in %s (detect faces @1280)", thumb.Fit1280, clean.Log(jpeg.BaseName()))
|
||||
} else if retryFaces, retryErr := vision.Faces(altThumb, Config().FaceSize(), true, expected); retryErr != nil {
|
||||
log.Debugf("vision: %s in %s (detect faces @1280)", retryErr, clean.Log(jpeg.BaseName()))
|
||||
detectErr = retryErr
|
||||
} else if len(retryFaces) > 0 {
|
||||
log.Debugf("vision: retry face detection for %s using %s", clean.Log(jpeg.BaseName()), thumb.Fit1280)
|
||||
faces = retryFaces
|
||||
detectErr = nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -75,5 +89,40 @@ func (ind *Index) Faces(jpeg *MediaFile, expected int) face.Faces {
|
|||
log.Infof("vision: found %s in %s [%s]", english.Plural(l, "face", "faces"), clean.Log(jpeg.BaseName()), time.Since(start))
|
||||
}
|
||||
|
||||
return faces, detectErr
|
||||
}
|
||||
|
||||
// ApplyDetectedFaces persists detected faces on the given file and updates face counts.
|
||||
func ApplyDetectedFaces(file *entity.File, faces face.Faces) (saved bool, count int, err error) {
|
||||
if file == nil {
|
||||
return false, 0, fmt.Errorf("faces: file is nil")
|
||||
}
|
||||
|
||||
if len(faces) == 0 {
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
file.AddFaces(faces)
|
||||
|
||||
savedMarkers, saveErr := file.SaveMarkers()
|
||||
if saveErr != nil {
|
||||
return false, 0, saveErr
|
||||
}
|
||||
|
||||
if savedMarkers == 0 {
|
||||
return false, 0, nil
|
||||
}
|
||||
|
||||
count, updateErr := file.UpdatePhotoFaceCount()
|
||||
if updateErr != nil {
|
||||
return true, 0, updateErr
|
||||
}
|
||||
|
||||
return true, count, nil
|
||||
}
|
||||
|
||||
// Faces finds faces in JPEG media files and returns them.
|
||||
func (ind *Index) Faces(jpeg *MediaFile, expected int) face.Faces {
|
||||
faces, _ := DetectFaces(jpeg, expected)
|
||||
return faces
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ import (
|
|||
|
||||
"github.com/dustin/go-humanize/english"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
|
|
@ -102,3 +103,32 @@ func TestIndex_File(t *testing.T) {
|
|||
|
||||
assert.Equal(t, IndexFailed, err.Status)
|
||||
}
|
||||
|
||||
// TestIndexConfigureFaceDetectionFacesOnlyManual ensures faces-only runs override manual scheduling.
|
||||
func TestIndexConfigureFaceDetectionFacesOnlyManual(t *testing.T) {
|
||||
cfg := config.NewConfig(config.CliTestContext())
|
||||
cfg.Options().FaceEngineRun = string(vision.RunManual)
|
||||
|
||||
ind := NewIndex(cfg, nil, nil, nil)
|
||||
require.NotNil(t, ind)
|
||||
require.False(t, ind.findFaces)
|
||||
|
||||
opt := NewIndexOptions("", true, false, true, true, true)
|
||||
ind.configureFaceDetection(opt)
|
||||
|
||||
require.True(t, ind.findFaces)
|
||||
}
|
||||
|
||||
// TestIndexConfigureFaceDetectionFacesOnlyNever confirms the scheduler honors the "never" run mode.
|
||||
func TestIndexConfigureFaceDetectionFacesOnlyNever(t *testing.T) {
|
||||
cfg := config.NewConfig(config.CliTestContext())
|
||||
cfg.Options().FaceEngineRun = string(vision.RunNever)
|
||||
|
||||
ind := NewIndex(cfg, nil, nil, nil)
|
||||
require.NotNil(t, ind)
|
||||
|
||||
opt := NewIndexOptions("", true, false, true, true, true)
|
||||
ind.configureFaceDetection(opt)
|
||||
|
||||
require.False(t, ind.findFaces)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -51,16 +51,6 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
|||
// Check time when worker was last executed.
|
||||
updateIndex := force || mutex.MetaWorker.LastRun().Before(time.Now().Add(-1*entity.IndexUpdateInterval))
|
||||
|
||||
// Run faces worker if needed.
|
||||
if updateIndex || entity.UpdateFaces.Load() {
|
||||
log.Debugf("index: running face recognition")
|
||||
if faces := photoprism.NewFaces(w.conf); faces.Disabled() {
|
||||
log.Debugf("index: skipping face recognition")
|
||||
} else if facesErr := faces.Start(photoprism.FacesOptions{}); facesErr != nil {
|
||||
log.Warn(facesErr)
|
||||
}
|
||||
}
|
||||
|
||||
// Refresh index metadata.
|
||||
log.Debugf("index: updating metadata")
|
||||
|
||||
|
|
@ -70,10 +60,12 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
|||
limit := 1000
|
||||
offset := 0
|
||||
optimized := 0
|
||||
facesJobRequired := updateIndex || entity.UpdateFaces.Load()
|
||||
|
||||
labelsModelShouldRun := w.conf.VisionModelShouldRun(vision.ModelTypeLabels, vision.RunNewlyIndexed)
|
||||
captionModelShouldRun := w.conf.VisionModelShouldRun(vision.ModelTypeCaption, vision.RunNewlyIndexed)
|
||||
nsfwModelShouldRun := w.conf.VisionModelShouldRun(vision.ModelTypeNsfw, vision.RunNewlyIndexed)
|
||||
faceRunNewlyIndexed := w.conf.FaceEngineShouldRun(vision.RunNewlyIndexed)
|
||||
|
||||
if nsfwModelShouldRun {
|
||||
log.Debugf("index: cannot run %s model on %s", vision.ModelTypeNsfw, vision.RunNewlyIndexed)
|
||||
|
|
@ -106,9 +98,10 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
|||
generateLabels := labelsModelShouldRun && photo.ShouldGenerateLabels(false)
|
||||
generateCaption := captionModelShouldRun && photo.ShouldGenerateCaption(entity.SrcAuto, false)
|
||||
detectNsfw := w.conf.DetectNSFW() && !photo.PhotoPrivate
|
||||
runDetection := faceRunNewlyIndexed && photo.IsNewlyIndexed()
|
||||
|
||||
// If configured, generate metadata for newly indexed photos using external vision services.
|
||||
if photo.IsNewlyIndexed() && (generateLabels || generateCaption) {
|
||||
if photo.IsNewlyIndexed() && (runDetection || generateLabels || generateCaption) {
|
||||
primaryFile, fileErr := photo.PrimaryFile()
|
||||
|
||||
if fileErr != nil {
|
||||
|
|
@ -124,6 +117,24 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
|||
log.Debugf("index: could not open primary file %s (generate metadata)", clean.Error(mediaErr))
|
||||
}
|
||||
} else {
|
||||
if runDetection {
|
||||
if markers := primaryFile.Markers(); markers == nil {
|
||||
log.Errorf("index: failed loading markers for %s", logName)
|
||||
} else {
|
||||
expected := markers.DetectedFaceCount()
|
||||
faces, detectErr := photoprism.DetectFaces(mediaFile, expected)
|
||||
|
||||
if detectErr != nil {
|
||||
log.Debugf("vision: %s in %s (detect faces)", detectErr, clean.Log(mediaFile.BaseName()))
|
||||
} else if saved, count, applyErr := photoprism.ApplyDetectedFaces(primaryFile, faces); applyErr != nil {
|
||||
log.Warnf("index: %s in %s (save faces)", clean.Error(applyErr), logName)
|
||||
} else if saved {
|
||||
photo.PhotoFaces = count
|
||||
facesJobRequired = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate photo labels if needed.
|
||||
if generateLabels {
|
||||
if labels := mediaFile.GenerateLabels(entity.SrcAuto); len(labels) > 0 {
|
||||
|
|
@ -188,6 +199,7 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
|||
|
||||
// Only update index if necessary.
|
||||
if updateIndex {
|
||||
facesJobRequired = true
|
||||
// Set photo quality scores to -1 if files are missing.
|
||||
if err = query.FlagHiddenPhotos(); err != nil {
|
||||
log.Warnf("index: %s in optimization worker", err)
|
||||
|
|
@ -211,5 +223,14 @@ func (w *Meta) Start(delay, interval time.Duration, force bool) (err error) {
|
|||
}
|
||||
}
|
||||
|
||||
if facesJobRequired {
|
||||
log.Debugf("index: running face recognition")
|
||||
if faces := photoprism.NewFaces(w.conf); faces.Disabled() {
|
||||
log.Debugf("index: skipping face recognition")
|
||||
} else if facesErr := faces.Start(photoprism.FacesOptions{}); facesErr != nil {
|
||||
log.Warn(facesErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
|
|
@ -99,12 +99,17 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
|
|||
defer mutex.VisionWorker.Stop()
|
||||
|
||||
models = vision.FilterModels(models, runType, func(mt vision.ModelType, when vision.RunType) bool {
|
||||
if mt == vision.ModelTypeFace {
|
||||
return w.conf.FaceEngineShouldRun(when)
|
||||
}
|
||||
|
||||
return w.conf.VisionModelShouldRun(mt, when)
|
||||
})
|
||||
|
||||
updateLabels := slices.Contains(models, vision.ModelTypeLabels)
|
||||
updateNsfw := slices.Contains(models, vision.ModelTypeNsfw)
|
||||
updateCaptions := slices.Contains(models, vision.ModelTypeCaption)
|
||||
updateFaces := slices.Contains(models, vision.ModelTypeFace)
|
||||
|
||||
// Refresh index metadata.
|
||||
if n := len(models); n == 0 {
|
||||
|
|
@ -118,6 +123,7 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
|
|||
|
||||
// Check time when worker was last executed.
|
||||
updateIndex := false
|
||||
facesJobRequired := false
|
||||
|
||||
start := time.Now()
|
||||
done := make(map[string]bool)
|
||||
|
|
@ -181,7 +187,7 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
|
|||
generateCaptions := updateCaptions && m.ShouldGenerateCaption(customSrc, force)
|
||||
detectNsfw := updateNsfw && (!photo.PhotoPrivate || force)
|
||||
|
||||
if !(generateLabels || generateCaptions || detectNsfw) {
|
||||
if !(generateLabels || generateCaptions || detectNsfw || updateFaces) {
|
||||
continue
|
||||
}
|
||||
|
||||
|
|
@ -195,6 +201,28 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
|
|||
|
||||
changed := false
|
||||
|
||||
if updateFaces {
|
||||
if primaryFile, err := m.PrimaryFile(); err != nil {
|
||||
log.Debugf("vision: photo %s has invalid primary file (%s)", logName, clean.Error(err))
|
||||
} else if primaryFile == nil {
|
||||
log.Debugf("vision: missing primary file for %s", logName)
|
||||
} else if markers := primaryFile.Markers(); markers == nil {
|
||||
log.Errorf("vision: failed loading markers for %s", logName)
|
||||
} else {
|
||||
expected := markers.DetectedFaceCount()
|
||||
faces, detectErr := photoprism.DetectFaces(file, expected)
|
||||
if detectErr != nil {
|
||||
log.Debugf("vision: %s in %s (detect faces)", detectErr, clean.Log(file.BaseName()))
|
||||
} else if saved, count, applyErr := photoprism.ApplyDetectedFaces(primaryFile, faces); applyErr != nil {
|
||||
log.Warnf("vision: %s in %s (save faces)", clean.Error(applyErr), logName)
|
||||
} else if saved {
|
||||
m.PhotoFaces = count
|
||||
facesJobRequired = true
|
||||
changed = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Generate labels.
|
||||
if generateLabels {
|
||||
if labels := file.GenerateLabels(customSrc); len(labels) > 0 {
|
||||
|
|
@ -273,5 +301,14 @@ func (w *Vision) Start(filter string, count int, models []string, customSrc stri
|
|||
}
|
||||
}
|
||||
|
||||
if facesJobRequired {
|
||||
log.Debugf("vision: running face recognition")
|
||||
if faces := photoprism.NewFaces(w.conf); faces.Disabled() {
|
||||
log.Debugf("vision: skipping face recognition")
|
||||
} else if facesErr := faces.Start(photoprism.FacesOptions{}); facesErr != nil {
|
||||
log.Warn(facesErr)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
|
|
|||
3
scripts/dist/Makefile
vendored
3
scripts/dist/Makefile
vendored
|
|
@ -33,6 +33,9 @@ tensorflow-amd64-avx: install-tensorflow
|
|||
tensorflow-amd64-avx2: install-tensorflow
|
||||
install-tensorflow:
|
||||
/scripts/install-tensorflow.sh auto
|
||||
onnxruntime: install-onnx
|
||||
install-onnx:
|
||||
/scripts/install-onnx.sh
|
||||
tensorflow-gpu: install-tensorflow-gpu
|
||||
install-tensorflow-gpu:
|
||||
/scripts/install-tensorflow.sh gpu
|
||||
|
|
|
|||
126
scripts/dist/install-onnx.sh
vendored
Executable file
126
scripts/dist/install-onnx.sh
vendored
Executable file
|
|
@ -0,0 +1,126 @@
|
|||
#!/usr/bin/env bash

# Installs the prebuilt ONNX Runtime shared libraries from the GitHub release archives.
#
# Usage: install-onnx.sh [destdir]      (destdir defaults to /usr)
#
# Environment:
#   ONNX_VERSION     release version to download (default: 1.22.0)
#   TMPDIR           directory used for the downloaded archive (default: /tmp)
#   PHOTOPRISM_ARCH  overrides the architecture reported by "uname -m"
#
# NOTE(review): the SHA-256 values below are literal per-archive constants, so
# overriding ONNX_VERSION will presumably fail checksum verification -- confirm.

set -euo pipefail

ONNX_VERSION=${ONNX_VERSION:-1.22.0}
TMPDIR=${TMPDIR:-/tmp}
SYSTEM=$(uname -s)
ARCH=${PHOTOPRISM_ARCH:-$(uname -m)}
DESTDIR_ARG="${1:-/usr}"

# Make sure the destination exists before resolving it to an absolute path.
if [[ ! -d "${DESTDIR_ARG}" ]]; then
  mkdir -p "${DESTDIR_ARG}"
fi

DESTDIR=$(realpath "${DESTDIR_ARG}")

# Installing into the system prefixes requires root privileges.
if [[ $(id -u) != 0 && ( "${DESTDIR}" == "/usr" || "${DESTDIR}" == "/usr/local" ) ]]; then
  echo "Error: Run ${0##*/} as root to install in '${DESTDIR}'." >&2
  exit 1
fi

mkdir -p "${DESTDIR}" "${TMPDIR}"

# Release archive name and its expected SHA-256, selected by OS and architecture.
archive=""
sha=""

case "${SYSTEM}" in
  Linux)
    case "${ARCH}" in
      amd64|AMD64|x86_64|x86-64)
        archive="onnxruntime-linux-x64-${ONNX_VERSION}.tgz"
        sha="8344d55f93d5bc5021ce342db50f62079daf39aaafb5d311a451846228be49b3"
        ;;
      arm64|ARM64|aarch64)
        archive="onnxruntime-linux-aarch64-${ONNX_VERSION}.tgz"
        sha="bb76395092d150b52c7092dc6b8f2fe4d80f0f3bf0416d2f269193e347e24702"
        ;;
      *)
        # Not an error: other Linux architectures simply skip the installation.
        echo "Warning: ONNX Runtime is not provided for Linux/${ARCH}; skipping install." >&2
        exit 0
        ;;
    esac
    ;;
  Darwin)
    case "${ARCH}" in
      # A single universal2 archive is used for all supported macOS architectures;
      # accept the same amd64 aliases as the Linux branch for PHOTOPRISM_ARCH overrides.
      arm64|ARM64|aarch64|amd64|AMD64|x86_64|x86-64)
        archive="onnxruntime-osx-universal2-${ONNX_VERSION}.tgz"
        sha="cfa6f6584d87555ed9f6e7e8a000d3947554d589efe3723b8bfa358cd263d03c"
        ;;
      *)
        echo "Unsupported macOS architecture '${ARCH}'." >&2
        exit 1
        ;;
    esac
    ;;
  *)
    echo "Unsupported operating system '${SYSTEM}'." >&2
    exit 1
    ;;
esac
|
||||
|
||||
# verify_sha EXPECTED FILE
#
# Returns 0 if the SHA-256 digest of FILE equals EXPECTED (hex, case-insensitive),
# otherwise prints a message to stderr and returns 1. The digest is computed with
# sha256sum when available and with "shasum -a 256" otherwise, then compared as a
# string so the result does not depend on the checksum-line format accepted by "-c".
verify_sha() {
  local expected file actual

  expected=$(printf '%s' "$1" | tr 'A-F' 'a-f')
  file="$2"

  # Reading from stdin keeps the file name out of the tool's output.
  if command -v sha256sum >/dev/null 2>&1; then
    actual=$(sha256sum < "${file}" | awk '{print $1}') || return 1
  else
    actual=$(shasum -a 256 < "${file}" | awk '{print $1}') || return 1
  fi

  if [[ "${actual}" != "${expected}" ]]; then
    echo "Checksum mismatch for ${file}: expected ${expected}, got ${actual:-none}." >&2
    return 1
  fi

  return 0
}
|
||||
|
||||
if [[ -z "${archive}" ]]; then
  echo "Could not determine ONNX Runtime archive." >&2
  exit 1
fi

url="https://github.com/microsoft/onnxruntime/releases/download/v${ONNX_VERSION}/${archive}"
package_path="${TMPDIR}/${archive}"

# Reuse a previously downloaded archive only if it still matches the expected checksum.
if [[ -f "${package_path}" ]]; then
  if verify_sha "${sha}" "${package_path}"; then
    echo "Using cached archive ${package_path}."
  else
    echo "Cached archive ${package_path} failed checksum, re-downloading..."
    rm -f "${package_path}"
  fi
fi

if [[ ! -f "${package_path}" ]]; then
  echo "Downloading ONNX Runtime ${ONNX_VERSION} (${archive})..."
  curl -fsSL --retry 3 --retry-delay 2 -o "${package_path}" "${url}"
fi

# With "set -e", a failed verification aborts the script here.
echo "Verifying checksum..."
verify_sha "${sha}" "${package_path}"

echo "Extracting to ${DESTDIR}..."
tar --overwrite --mode=755 -C "${DESTDIR}" -xzf "${package_path}"

# Normalize layout: copy libraries into ${DESTDIR}/lib and remove extracted tree.
output_lib_dir="${DESTDIR}/lib"
mkdir -p "${output_lib_dir}"

for extracted in "${DESTDIR}/onnxruntime-linux-x64-${ONNX_VERSION}" "${DESTDIR}/onnxruntime-linux-aarch64-${ONNX_VERSION}" "${DESTDIR}/onnxruntime-osx-universal2-${ONNX_VERSION}"; do
  if [[ -d "${extracted}/lib" ]]; then
    # Regular files first: match Linux shared objects (*.so*) as well as macOS
    # dynamic libraries (*.dylib), since the universal2 archive is also supported.
    find "${extracted}/lib" -maxdepth 1 -type f \( -name "libonnxruntime*.so*" -o -name "libonnxruntime*.dylib" \) -print0 | while IFS= read -r -d '' file; do
      cp -af "${file}" "${output_lib_dir}/"
    done
    # Recreate symlinks with their original targets to preserve SONAME links.
    find "${extracted}/lib" -maxdepth 1 -type l \( -name "libonnxruntime*.so*" -o -name "libonnxruntime*.dylib" \) -print0 | while IFS= read -r -d '' link; do
      target=$(readlink "${link}")
      ln -sf "${target}" "${output_lib_dir}/$(basename "${link}")"
    done
    rm -rf "${extracted}"
  fi
done

# Do not report success if no library ended up in the output directory.
if ! compgen -G "${output_lib_dir}/libonnxruntime*" >/dev/null; then
  echo "Error: no ONNX Runtime libraries were found in ${archive}." >&2
  exit 1
fi

# Refresh the dynamic linker cache (system prefixes) or just the links in the
# custom library directory; failures for custom prefixes are ignored on purpose.
if [[ "${SYSTEM}" == "Linux" ]]; then
  if [[ "${DESTDIR}" == "/usr" || "${DESTDIR}" == "/usr/local" ]]; then
    ldconfig
  else
    ldconfig -n "${DESTDIR}/lib" >/dev/null 2>&1 || true
  fi
fi

echo "ONNX Runtime ${ONNX_VERSION} installed in '${DESTDIR}'."
|
||||
82
scripts/download-scrfs.sh
Executable file
82
scripts/download-scrfs.sh
Executable file
|
|
@ -0,0 +1,82 @@
|
|||
#!/usr/bin/env bash

# Downloads the SCRFD face detection model (ONNX) into assets/models/scrfs.
# MODEL_URL may be set in the environment to override the primary download URL.

set -euo pipefail

# Date stamp (UTC) used in the primary URL query string, the version file and the backup path.
TODAY=$(date -u +%Y%m%d)

# Upstream file name and the name under which the model is stored locally.
MODEL_SOURCE="scrfd_500m_bnkps_shape640x640.onnx"
LOCAL_MODEL_NAME="scrfd.onnx"

# Download locations.
PRIMARY_URL="https://dl.photoprism.app/onnx/scrfd/${MODEL_SOURCE}?${TODAY}"
FALLBACK_URL="https://raw.githubusercontent.com/laolaolulu/FaceTrain/master/model/scrfd/${MODEL_SOURCE}"
MODEL_URL=${MODEL_URL:-"${PRIMARY_URL}"}

# Local paths.
MODELS_PATH="assets/models"
MODEL_DIR="${MODELS_PATH}/scrfs"
MODEL_FILE="${MODEL_DIR}/${LOCAL_MODEL_NAME}"
MODEL_TMP="/tmp/photoprism/${MODEL_SOURCE}"
MODEL_VERSION="${MODEL_DIR}/version.txt"
MODEL_BACKUP="storage/backup/scrfs-${TODAY}"

# Expected SHA-256 followed by the temp file name; the digest is the part before the first space.
MODEL_HASH="ae72185653e279aa2056b288662a19ec3519ced5426d2adeffbe058a86369a24 ${MODEL_TMP}"

# Create all working directories up front.
mkdir -p /tmp/photoprism storage/backup "${MODEL_DIR}"
|
||||
|
||||
# hash_file FILE
#
# Prints the SHA-256 digest of FILE (first field of the tool output), using
# sha256sum when it is available and "shasum -a 256" otherwise.
hash_file() {
  local target="$1"
  local -a hasher=(shasum -a 256)

  if command -v sha256sum >/dev/null 2>&1; then
    hasher=(sha256sum)
  fi

  "${hasher[@]}" "${target}" | awk '{print $1}'
}
|
||||
|
||||
# verify_hash EXPECTED FILE
#
# Returns 0 if the SHA-256 digest of FILE, as printed by hash_file, equals EXPECTED
# (hex, case-insensitive); otherwise prints a message to stderr and returns 1.
# Comparing digests as strings reuses the tool selection in hash_file and does not
# depend on the checksum-line format accepted by "sha256sum -c" / "shasum -c".
verify_hash() {
  local expected actual file

  expected=$(printf '%s' "$1" | tr 'A-F' 'a-f')
  file="$2"

  # Explicit "|| return" is needed because "set -e" is inactive when the
  # function is called as an if-condition.
  actual=$(hash_file "${file}") || return 1

  if [[ "${actual}" != "${expected}" ]]; then
    echo "Checksum mismatch for ${file}: expected ${expected}, got ${actual:-none}." >&2
    return 1
  fi

  return 0
}
|
||||
|
||||
# Skip the download if the installed model already has the expected digest.
# The right-hand side is quoted so it is compared literally, not as a glob pattern.
if [[ -f "${MODEL_FILE}" ]]; then
  CURRENT_HASH=$(hash_file "${MODEL_FILE}")
  if [[ "${CURRENT_HASH}" == "${MODEL_HASH%% *}" ]]; then
    echo "SCRFD model already up to date."
    exit 0
  fi
fi

# Try MODEL_URL first and fall back to FALLBACK_URL unless it was already used.
# The expected digest is the same for both sources.
echo "Downloading SCRFD detector from ${MODEL_URL}..."
if ! curl -fsSL --retry 3 --retry-delay 2 -o "${MODEL_TMP}" "${MODEL_URL}"; then
  if [[ "${MODEL_URL}" == "${FALLBACK_URL}" ]]; then
    echo "Failed to download SCRFD detector." >&2
    exit 1
  fi

  echo "Primary download failed, trying fallback..."
  MODEL_URL="${FALLBACK_URL}"

  if ! curl -fsSL --retry 3 --retry-delay 2 -o "${MODEL_TMP}" "${MODEL_URL}"; then
    echo "Failed to download SCRFD detector." >&2
    exit 1
  fi
fi

# Never install a file that fails verification, and do not leave it behind.
echo "Verifying checksum..."
if ! verify_hash "${MODEL_HASH%% *}" "${MODEL_TMP}"; then
  echo "Checksum verification failed, removing ${MODEL_TMP}." >&2
  rm -f "${MODEL_TMP}"
  exit 1
fi

# Move an existing model (and copy its version file) into a dated backup directory.
if [[ -f "${MODEL_FILE}" ]]; then
  echo "Creating backup of existing detector at ${MODEL_BACKUP}"
  rm -rf "${MODEL_BACKUP}"
  mkdir -p "${MODEL_BACKUP}"
  mv "${MODEL_FILE}" "${MODEL_BACKUP}/"
  if [[ -f "${MODEL_VERSION}" ]]; then
    cp "${MODEL_VERSION}" "${MODEL_BACKUP}/"
  fi
fi

# Install the verified download and record date, digest and source file name.
mv "${MODEL_TMP}" "${MODEL_FILE}"
echo "SCRFD ${TODAY} ${MODEL_HASH%% *} (${MODEL_SOURCE})" > "${MODEL_VERSION}"

echo "SCRFD detector installed in ${MODEL_DIR}."
|
||||
Loading…
Add table
Add a link
Reference in a new issue