mirror of
https://github.com/photoprism/photoprism.git
synced 2026-01-23 02:24:24 +00:00
CLI: Refactor "photoprism vision" subcommands #5233
Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
parent
e21174c297
commit
84e11829dc
6 changed files with 219 additions and 55 deletions
|
|
@ -1,8 +1,37 @@
|
|||
package commands
|
||||
|
||||
import "github.com/urfave/cli/v2"
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/internal/entity/search"
|
||||
)
|
||||
|
||||
// JsonFlag returns the shared CLI flag definition for JSON output across commands.
|
||||
func JsonFlag() *cli.BoolFlag {
|
||||
return &cli.BoolFlag{Name: "json", Aliases: []string{"j"}, Usage: "print machine-readable JSON"}
|
||||
}
|
||||
|
||||
// PicturesCountFlag returns a shared flag definition limiting how many pictures a batch operation processes.
|
||||
// Usage: commands from the vision or import tooling that need to cap result size per invocation.
|
||||
func PicturesCountFlag() *cli.IntFlag {
|
||||
return &cli.IntFlag{
|
||||
Name: "count",
|
||||
Aliases: []string{"n"},
|
||||
Usage: "maximum `NUMBER` of pictures to be processed",
|
||||
Value: search.MaxResults,
|
||||
}
|
||||
}
|
||||
|
||||
// VisionSourceFlag returns the CLI flag used to choose a metadata source for computer-vision commands.
|
||||
// Allowing only whitelisted aliases keeps CLI input aligned with entity.VisionSrcNames.
|
||||
func VisionSourceFlag() *cli.StringFlag {
|
||||
return &cli.StringFlag{
|
||||
Name: "source",
|
||||
Aliases: []string{"s"},
|
||||
Usage: fmt.Sprintf("custom data source `TYPE` (%s)", visionSourceUsage()),
|
||||
Value: entity.SrcImage,
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -10,8 +10,6 @@ import (
|
|||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/internal/entity/search"
|
||||
"github.com/photoprism/photoprism/internal/workers"
|
||||
"github.com/photoprism/photoprism/pkg/txt"
|
||||
)
|
||||
|
|
@ -28,18 +26,8 @@ var VisionResetCommand = &cli.Command{
|
|||
Usage: "computer vision `MODELS` to reset, e.g. caption or labels",
|
||||
Value: "",
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "count",
|
||||
Aliases: []string{"n"},
|
||||
Usage: "maximum `NUMBER` of pictures to be processed",
|
||||
Value: search.MaxResults,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "source",
|
||||
Aliases: []string{"s"},
|
||||
Usage: "generated data source `TYPE` to reset, e.g. vision or ollama",
|
||||
Value: entity.SrcVision,
|
||||
},
|
||||
PicturesCountFlag(),
|
||||
VisionSourceFlag(),
|
||||
&cli.BoolFlag{
|
||||
Name: "yes",
|
||||
Aliases: []string{"y"},
|
||||
|
|
@ -61,6 +49,7 @@ func visionResetAction(ctx *cli.Context) error {
|
|||
}
|
||||
|
||||
selectedModels := make([]string, 0, 2)
|
||||
|
||||
if resetCaptions {
|
||||
selectedModels = append(selectedModels, vision.ModelTypeCaption)
|
||||
}
|
||||
|
|
@ -80,11 +69,17 @@ func visionResetAction(ctx *cli.Context) error {
|
|||
|
||||
worker := workers.NewVision(conf)
|
||||
filter := strings.TrimSpace(strings.Join(ctx.Args().Slice(), " "))
|
||||
source, err := sanitizeVisionSource(ctx.String("source"))
|
||||
|
||||
if err != nil {
|
||||
return cli.Exit(err.Error(), 1)
|
||||
}
|
||||
|
||||
return worker.Reset(
|
||||
filter,
|
||||
ctx.Int("count"),
|
||||
selectedModels,
|
||||
ctx.String("source"),
|
||||
string(source),
|
||||
)
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,8 +7,6 @@ import (
|
|||
|
||||
"github.com/photoprism/photoprism/internal/ai/vision"
|
||||
"github.com/photoprism/photoprism/internal/config"
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/internal/entity/search"
|
||||
"github.com/photoprism/photoprism/internal/workers"
|
||||
)
|
||||
|
||||
|
|
@ -24,18 +22,8 @@ var VisionRunCommand = &cli.Command{
|
|||
Usage: "computer vision `MODELS` to run, e.g. caption, labels, or nsfw",
|
||||
Value: "caption",
|
||||
},
|
||||
&cli.IntFlag{
|
||||
Name: "count",
|
||||
Aliases: []string{"n"},
|
||||
Usage: "maximum `NUMBER` of pictures to be processed",
|
||||
Value: search.MaxResults,
|
||||
},
|
||||
&cli.StringFlag{
|
||||
Name: "source",
|
||||
Aliases: []string{"s"},
|
||||
Value: entity.SrcImage,
|
||||
Usage: "custom data source `TYPE` e.g. default, image, meta, vision, or admin",
|
||||
},
|
||||
PicturesCountFlag(),
|
||||
VisionSourceFlag(),
|
||||
&cli.BoolFlag{
|
||||
Name: "force",
|
||||
Aliases: []string{"f"},
|
||||
|
|
@ -50,11 +38,17 @@ func visionRunAction(ctx *cli.Context) error {
|
|||
return CallWithDependencies(ctx, func(conf *config.Config) error {
|
||||
worker := workers.NewVision(conf)
|
||||
filter := strings.TrimSpace(strings.Join(ctx.Args().Slice(), " "))
|
||||
source, err := sanitizeVisionSource(ctx.String("source"))
|
||||
|
||||
if err != nil {
|
||||
return cli.Exit(err.Error(), 1)
|
||||
}
|
||||
|
||||
return worker.Start(
|
||||
filter,
|
||||
ctx.Int("count"),
|
||||
vision.ParseTypes(ctx.String("models")),
|
||||
ctx.String("source"),
|
||||
string(source),
|
||||
ctx.Bool("force"),
|
||||
)
|
||||
})
|
||||
|
|
|
|||
59
internal/commands/vision_sources.go
Normal file
59
internal/commands/vision_sources.go
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
package commands
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
"github.com/photoprism/photoprism/pkg/txt"
|
||||
)
|
||||
|
||||
var (
|
||||
visionSourceNames []string
|
||||
visionSourcesOnce sync.Once
|
||||
)
|
||||
|
||||
func initVisionSources() {
|
||||
visionSourcesOnce.Do(func() {
|
||||
namesSet := make(map[string]struct{}, len(entity.VisionSrcNames))
|
||||
|
||||
for alias := range entity.VisionSrcNames {
|
||||
normalized := strings.TrimSpace(alias)
|
||||
|
||||
if normalized == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := namesSet[normalized]; ok {
|
||||
continue
|
||||
}
|
||||
namesSet[normalized] = struct{}{}
|
||||
visionSourceNames = append(visionSourceNames, normalized)
|
||||
}
|
||||
|
||||
sort.Strings(visionSourceNames)
|
||||
})
|
||||
}
|
||||
|
||||
func sanitizeVisionSource(raw string) (entity.Src, error) {
|
||||
initVisionSources()
|
||||
|
||||
value := strings.ToLower(strings.TrimSpace(raw))
|
||||
if value == "" {
|
||||
return entity.SrcAuto, nil
|
||||
}
|
||||
|
||||
if src, ok := entity.VisionSrcNames[value]; ok {
|
||||
return src, nil
|
||||
}
|
||||
|
||||
allowed := append([]string(nil), visionSourceNames...)
|
||||
return "", fmt.Errorf("vision: unsupported source %q (allowed: %s)", raw, txt.JoinAnd(allowed))
|
||||
}
|
||||
|
||||
func visionSourceUsage() string {
|
||||
initVisionSources()
|
||||
return strings.Join(visionSourceNames, ", ")
|
||||
}
|
||||
44
internal/commands/vision_sources_test.go
Normal file
44
internal/commands/vision_sources_test.go
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
package commands
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/entity"
|
||||
)
|
||||
|
||||
func TestSanitizeVisionSource(t *testing.T) {
|
||||
cases := map[string]entity.Src{
|
||||
"": entity.SrcAuto,
|
||||
"auto": entity.SrcAuto,
|
||||
"AUTO": entity.SrcAuto,
|
||||
"default": entity.SrcDefault,
|
||||
"DEFAULT": entity.SrcDefault,
|
||||
"image": entity.SrcImage,
|
||||
"ollama": entity.SrcOllama,
|
||||
"openai": entity.SrcOpenAI,
|
||||
"vision": entity.SrcVision,
|
||||
}
|
||||
|
||||
for input, expected := range cases {
|
||||
result, err := sanitizeVisionSource(input)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, expected, result)
|
||||
}
|
||||
|
||||
if _, err := sanitizeVisionSource("meta"); err == nil {
|
||||
t.Fatalf("expected error for unsupported source")
|
||||
}
|
||||
}
|
||||
|
||||
func TestVisionSourceUsage(t *testing.T) {
|
||||
display := visionSourceUsage()
|
||||
|
||||
for _, name := range []string{"auto", "default", "image", "ollama", "openai", "vision"} {
|
||||
if !strings.Contains(display, name) {
|
||||
t.Fatalf("expected usage to list %s", name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -3,6 +3,7 @@ package entity
|
|||
import (
|
||||
"sort"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/photoprism/photoprism/internal/ai/classify"
|
||||
)
|
||||
|
|
@ -16,6 +17,9 @@ type Priority = int
|
|||
// Priorities maps source strings to their relative priorities.
|
||||
type Priorities map[Src]Priority
|
||||
|
||||
// SrcMap maps source name strings to their Src values.
|
||||
type SrcMap map[string]Src
|
||||
|
||||
// Supported metadata source strings.
|
||||
const (
|
||||
SrcAuto Src = classify.SrcAuto // Prio 1
|
||||
|
|
@ -79,31 +83,52 @@ var SrcPriority = Priorities{
|
|||
SrcAdmin: 128,
|
||||
}
|
||||
|
||||
// VisionSrcNames maps source names to the sources that can be used as arguments for computer vision commands.
// The keys are the accepted names; SrcAuto is registered both under its own value and under
// SrcString(SrcAuto), so either key resolves to SrcAuto.
|
||||
var VisionSrcNames = SrcMap{
|
||||
SrcAuto: SrcAuto,
|
||||
SrcString(SrcAuto): SrcAuto,
|
||||
SrcDefault: SrcDefault,
|
||||
SrcMarker: SrcMarker,
|
||||
SrcImage: SrcImage,
|
||||
SrcOllama: SrcOllama,
|
||||
SrcOpenAI: SrcOpenAI,
|
||||
SrcVision: SrcVision,
|
||||
}
|
||||
|
||||
// VisionSrc contains all the sources commonly used by computer vision models and services:
// SrcMarker, SrcImage, SrcOllama, SrcOpenAI, and SrcVision, in that order.
|
||||
var VisionSrc = []Src{
|
||||
SrcMarker,
|
||||
SrcImage,
|
||||
SrcOllama,
|
||||
SrcOpenAI,
|
||||
SrcVision,
|
||||
}
|
||||
|
||||
// SrcDesc maps source strings to their descriptions for documentation purposes.
|
||||
var SrcDesc = map[Src]string{
|
||||
SrcAuto: SrcString(SrcAuto),
|
||||
SrcDefault: "default",
|
||||
SrcEstimate: "estimated data",
|
||||
SrcFile: "filesystem metadata",
|
||||
SrcName: "file name",
|
||||
SrcYaml: "YAML sidecar file",
|
||||
SrcAuto: "Auto",
|
||||
SrcDefault: "Default",
|
||||
SrcEstimate: "Estimated",
|
||||
SrcFile: "File System",
|
||||
SrcName: "File Name",
|
||||
SrcYaml: "YAML Sidecar",
|
||||
SrcOIDC: "OpenID Connect (OIDC)",
|
||||
SrcLDAP: "LDAP / Active Directory",
|
||||
SrcLocation: "GPS position",
|
||||
SrcMarker: "face / object detection",
|
||||
SrcOllama: "Ollama",
|
||||
SrcOpenAI: "OpenAI",
|
||||
SrcImage: "computer vision (default)",
|
||||
SrcTitle: "picture title",
|
||||
SrcCaption: "picture caption",
|
||||
SrcSubject: "subject / person",
|
||||
SrcKeyword: "picture keywords",
|
||||
SrcMeta: "embedded metadata",
|
||||
SrcXmp: "XMP sidecar file",
|
||||
SrcBatch: "batch edit",
|
||||
SrcVision: "computer vision (manual)",
|
||||
SrcManual: "manually changed",
|
||||
SrcAdmin: "overrides manual changes",
|
||||
SrcLocation: "GPS Position",
|
||||
SrcMarker: "Object Detection",
|
||||
SrcImage: "Computer Vision (default)",
|
||||
SrcOllama: "Computer Vision (Ollama)",
|
||||
SrcOpenAI: "Computer Vision (OpenAI)",
|
||||
SrcTitle: "Picture Title",
|
||||
SrcCaption: "Picture Caption",
|
||||
SrcSubject: "Person",
|
||||
SrcKeyword: "Picture Keywords",
|
||||
SrcMeta: "Embedded Metadata",
|
||||
SrcXmp: "XMP Sidecar",
|
||||
SrcBatch: "Batch Edit",
|
||||
SrcVision: "Computer Vision (manual)",
|
||||
SrcManual: "Edited Manually",
|
||||
SrcAdmin: "Admin Override",
|
||||
}
|
||||
|
||||
// Report returns a metadata sources documentation table.
|
||||
|
|
@ -111,15 +136,33 @@ func (p Priorities) Report() (rows [][]string, cols []string) {
|
|||
cols = []string{"Source", "Priority", "Description"}
|
||||
|
||||
keys := make([]string, 0, len(SrcPriority))
|
||||
|
||||
for s := range SrcPriority {
|
||||
keys = append(keys, s)
|
||||
}
|
||||
|
||||
sort.Slice(keys, func(i, j int) bool {
|
||||
pi, pj := SrcPriority[keys[i]], SrcPriority[keys[j]]
|
||||
if pi == pj {
|
||||
return keys[i] < keys[j]
|
||||
|
||||
if pi != pj {
|
||||
return pi < pj
|
||||
}
|
||||
return pi < pj
|
||||
|
||||
di := strings.ToLower(SrcDesc[keys[i]])
|
||||
if di == "" {
|
||||
di = strings.ToLower(keys[i])
|
||||
}
|
||||
|
||||
dj := strings.ToLower(SrcDesc[keys[j]])
|
||||
if dj == "" {
|
||||
dj = strings.ToLower(keys[j])
|
||||
}
|
||||
|
||||
if di != dj {
|
||||
return di < dj
|
||||
}
|
||||
|
||||
return strings.ToLower(keys[i]) < strings.ToLower(keys[j])
|
||||
})
|
||||
|
||||
rows = make([][]string, len(keys))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue