AI: Generate Labels using the Ollama API #5232

Signed-off-by: Michael Mayer <michael@photoprism.app>
This commit is contained in:
Michael Mayer 2025-09-28 13:29:49 +02:00
parent cf06f52025
commit 8a7c61f467
6 changed files with 288 additions and 9 deletions

View file

@ -7,6 +7,7 @@ import (
"fmt"
"io"
"net/http"
"strings"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/pkg/clean"
@ -48,6 +49,10 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
return apiResponse, clientErr
}
defer func() {
_ = clientResp.Body.Close()
}()
apiResponse = &ApiResponse{}
// Parse and return response, or an error if the request failed.
@ -61,13 +66,17 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
log.Debugf("vision: %s (status code %d)", apiJson, clientResp.StatusCode)
}
case ApiFormatOllama:
ollamaResponse := &ApiResponseOllama{}
apiJson, apiErr := io.ReadAll(clientResp.Body)
if apiErr != nil {
return apiResponse, apiErr
}
if apiJson, apiErr := io.ReadAll(clientResp.Body); apiErr != nil {
return apiResponse, apiErr
} else if apiErr = json.Unmarshal(apiJson, ollamaResponse); apiErr != nil {
return apiResponse, apiErr
} else if clientResp.StatusCode >= 300 {
ollamaResponse, decodeErr := decodeOllamaResponse(apiJson)
if decodeErr != nil {
return apiResponse, decodeErr
}
if clientResp.StatusCode >= 300 {
log.Debugf("vision: %s (status code %d)", apiJson, clientResp.StatusCode)
}
@ -77,9 +86,41 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
Name: ollamaResponse.Model,
}
apiResponse.Result.Caption = &CaptionResult{
Text: ollamaResponse.Response,
Source: entity.SrcImage,
// Copy structured results when provided.
if len(ollamaResponse.Result.Labels) > 0 {
apiResponse.Result.Labels = append(apiResponse.Result.Labels, ollamaResponse.Result.Labels...)
}
parsedLabels := false
if len(apiResponse.Result.Labels) > 0 {
parsedLabels = true
}
if !parsedLabels {
if apiRequest.Format == FormatJSON {
if labels, parseErr := parseOllamaLabels(ollamaResponse.Response); parseErr != nil {
log.Debugf("vision: %s (parse ollama labels)", clean.Error(parseErr))
} else if len(labels) > 0 {
apiResponse.Result.Labels = append(apiResponse.Result.Labels, labels...)
parsedLabels = true
}
}
}
if parsedLabels {
for i := range apiResponse.Result.Labels {
if apiResponse.Result.Labels[i].Source == "" {
apiResponse.Result.Labels[i].Source = entity.SrcVision
}
}
} else {
if caption := strings.TrimSpace(ollamaResponse.Response); caption != "" {
apiResponse.Result.Caption = &CaptionResult{
Text: caption,
Source: entity.SrcImage,
}
}
}
default:
return apiResponse, fmt.Errorf("unsupported response format %s", clean.Log(apiRequest.ResponseFormat))
@ -87,3 +128,39 @@ func PerformApiRequest(apiRequest *ApiRequest, uri, method, key string) (apiResp
return apiResponse, nil
}
func decodeOllamaResponse(data []byte) (*ApiResponseOllama, error) {
resp := &ApiResponseOllama{}
dec := json.NewDecoder(bytes.NewReader(data))
for {
var chunk ApiResponseOllama
if err := dec.Decode(&chunk); err != nil {
if errors.Is(err, io.EOF) {
break
}
return nil, err
}
*resp = chunk
}
return resp, nil
}
func parseOllamaLabels(raw string) ([]LabelResult, error) {
raw = strings.TrimSpace(raw)
if raw == "" {
return nil, nil
}
var payload struct {
Labels []LabelResult `json:"labels"`
}
if err := json.Unmarshal([]byte(raw), &payload); err != nil {
return nil, err
}
return payload.Labels, nil
}

View file

@ -1,6 +1,9 @@
package vision
import (
"encoding/json"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
@ -39,3 +42,57 @@ func TestNewApiRequest(t *testing.T) {
}
})
}
func TestPerformApiRequestOllama(t *testing.T) {
t.Run("Labels", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var req ApiRequest
assert.NoError(t, json.NewDecoder(r.Body).Decode(&req))
assert.Equal(t, FormatJSON, req.Format)
assert.NoError(t, json.NewEncoder(w).Encode(ApiResponseOllama{
Model: "qwen2.5vl:latest",
Response: `{"labels":[{"name":"test","confidence":0.9,"topicality":0.8}]}`,
}))
}))
defer server.Close()
apiRequest := &ApiRequest{
Id: "test",
Model: "qwen2.5vl:latest",
Format: FormatJSON,
Images: []string{""},
ResponseFormat: ApiFormatOllama,
}
resp, err := PerformApiRequest(apiRequest, server.URL, http.MethodPost, "")
assert.NoError(t, err)
assert.Len(t, resp.Result.Labels, 1)
assert.Equal(t, "test", resp.Result.Labels[0].Name)
assert.Nil(t, resp.Result.Caption)
})
t.Run("CaptionFallback", func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
assert.NoError(t, json.NewEncoder(w).Encode(ApiResponseOllama{
Model: "qwen2.5vl:latest",
Response: "plain text",
}))
}))
defer server.Close()
apiRequest := &ApiRequest{
Id: "test2",
Model: "qwen2.5vl:latest",
Format: FormatJSON,
Images: []string{""},
ResponseFormat: ApiFormatOllama,
}
resp, err := PerformApiRequest(apiRequest, server.URL, http.MethodPost, "")
assert.NoError(t, err)
assert.Len(t, resp.Result.Labels, 0)
if assert.NotNil(t, resp.Result.Caption) {
assert.Equal(t, "plain text", resp.Result.Caption.Text)
}
})
}

View file

@ -21,6 +21,10 @@ import (
type Files = []string
const (
FormatJSON = "json"
)
// ApiRequestOptions represents additional model parameters listed in the documentation.
type ApiRequestOptions struct {
NumKeep int `yaml:"NumKeep,omitempty" json:"num_keep,omitempty"`

View file

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"sort"
"strings"
"github.com/photoprism/photoprism/internal/ai/classify"
"github.com/photoprism/photoprism/pkg/clean"
@ -30,6 +31,10 @@ func Labels(images Files, mediaSrc media.Src, labelSrc string) (result classify.
return result, err
}
if format := model.GetFormat(); format != "" {
apiRequest.Format = format
}
switch model.Service.RequestFormat {
case ApiFormatOllama:
apiRequest.Model, _, _ = model.Model()
@ -45,6 +50,15 @@ func Labels(images Files, mediaSrc media.Src, labelSrc string) (result classify.
apiRequest.Prompt = model.Prompt
}
if schemaPrompt := model.SchemaInstructions(); schemaPrompt != "" {
prompt := strings.TrimSpace(apiRequest.Prompt)
if prompt != "" {
apiRequest.Prompt = fmt.Sprintf("%s\n\n%s", prompt, schemaPrompt)
} else {
apiRequest.Prompt = schemaPrompt
}
}
// Log JSON request data in trace mode.
apiRequest.WriteLog()

View file

@ -2,6 +2,7 @@ package vision
import (
"fmt"
"os"
"strings"
"sync"
@ -10,11 +11,14 @@ import (
"github.com/photoprism/photoprism/internal/ai/nsfw"
"github.com/photoprism/photoprism/internal/ai/tensorflow"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/service/http/scheme"
)
var modelMutex = sync.Mutex{}
const labelSchemaEnvVar = "PHOTOPRISM_VISION_LABEL_SCHEMA_FILE"
// Default model version strings.
var (
VersionLatest = "latest"
@ -30,6 +34,9 @@ type Model struct {
Version string `yaml:"Version,omitempty" json:"version,omitempty"`
System string `yaml:"System,omitempty" json:"system,omitempty"`
Prompt string `yaml:"Prompt,omitempty" json:"prompt,omitempty"`
Format string `yaml:"Format,omitempty" json:"format,omitempty"`
Schema string `yaml:"Schema,omitempty" json:"schema,omitempty"`
SchemaFile string `yaml:"SchemaFile,omitempty" json:"schemaFile,omitempty"`
Resolution int `yaml:"Resolution,omitempty" json:"resolution,omitempty"`
TensorFlow *tensorflow.ModelInfo `yaml:"TensorFlow,omitempty" json:"tensorflow,omitempty"`
Options *ApiRequestOptions `yaml:"Options,omitempty" json:"options,omitempty"`
@ -39,6 +46,8 @@ type Model struct {
classifyModel *classify.Model
faceModel *face.Model
nsfwModel *nsfw.Model
schemaOnce sync.Once
schema string
}
// Models represents a set of computer vision models.
@ -152,6 +161,19 @@ func (m *Model) GetSystemPrompt() string {
}
}
// GetFormat returns the configured response format or a sensible default.
func (m *Model) GetFormat() string {
if f := strings.TrimSpace(strings.ToLower(m.Format)); f != "" {
return f
}
if m.Type == ModelTypeLabels && m.EndpointResponseFormat() == ApiFormatOllama {
return FormatJSON
}
return ""
}
// GetOptions returns the API request options.
func (m *Model) GetOptions() *ApiRequestOptions {
if m.Options != nil {
@ -174,6 +196,56 @@ func (m *Model) GetOptions() *ApiRequestOptions {
}
}
// SchemaTemplate returns the model-specific JSON schema template, if any.
func (m *Model) SchemaTemplate() string {
m.schemaOnce.Do(func() {
var schema string
if m.Type == ModelTypeLabels {
if envFile := strings.TrimSpace(os.Getenv(labelSchemaEnvVar)); envFile != "" {
path := fs.Abs(envFile)
if path == "" {
path = envFile
}
if data, err := os.ReadFile(path); err != nil {
log.Warnf("vision: failed to read schema from %s (%s)", clean.Log(path), err)
} else {
schema = string(data)
}
}
}
if schema == "" && strings.TrimSpace(m.Schema) != "" {
schema = m.Schema
}
if schema == "" && strings.TrimSpace(m.SchemaFile) != "" {
path := fs.Abs(m.SchemaFile)
if path == "" {
path = m.SchemaFile
}
if data, err := os.ReadFile(path); err != nil {
log.Warnf("vision: failed to read schema from %s (%s)", clean.Log(path), err)
} else {
schema = string(data)
}
}
m.schema = strings.TrimSpace(schema)
})
return m.schema
}
// SchemaInstructions returns a helper string that can be appended to prompts.
func (m *Model) SchemaInstructions() string {
if schema := m.SchemaTemplate(); schema != "" {
return fmt.Sprintf("Return JSON that matches this schema:\n%s", schema)
}
return ""
}
// ClassifyModel returns the matching classify model instance, if any.
func (m *Model) ClassifyModel() *classify.Model {
// Use mutex to prevent models from being loaded and

View file

@ -2,10 +2,13 @@ package vision
import (
"net/http"
"os"
"path/filepath"
"testing"
"github.com/stretchr/testify/assert"
"github.com/photoprism/photoprism/pkg/fs"
"github.com/photoprism/photoprism/pkg/service/http/scheme"
)
@ -71,3 +74,55 @@ func TestParseTypes(t *testing.T) {
assert.Equal(t, ModelTypes{}, result)
})
}
func TestModelFormatAndSchema(t *testing.T) {
t.Run("DefaultOllamaFormat", func(t *testing.T) {
m := &Model{
Type: ModelTypeLabels,
Service: Service{
RequestFormat: ApiFormatOllama,
ResponseFormat: ApiFormatOllama,
},
}
assert.Equal(t, FormatJSON, m.GetFormat())
})
t.Run("InlineSchema", func(t *testing.T) {
schema := "{\n \"labels\": []\n}"
m := &Model{Schema: schema}
assert.Equal(t, schema, m.SchemaTemplate())
assert.Contains(t, m.SchemaInstructions(), "Return JSON")
})
t.Run("SchemaFileAndEnv", func(t *testing.T) {
tempDir := t.TempDir()
filePath := filepath.Join(tempDir, "schema.json")
content := "{\n \"labels\": [{\"name\": \"test\"}]\n}"
assert.NoError(t, os.WriteFile(filePath, []byte(content), fs.ModeConfigFile))
m := &Model{
Type: ModelTypeLabels,
SchemaFile: filePath,
}
// First read should use file content.
assert.Equal(t, content, m.SchemaTemplate())
// Reset and use env override with a different file.
otherFile := filepath.Join(tempDir, "schema-override.json")
otherContent := "{\n \"labels\": []\n, \"markers\": []\n}"
assert.NoError(t, os.WriteFile(otherFile, []byte(otherContent), fs.ModeConfigFile))
t.Setenv(labelSchemaEnvVar, otherFile)
m2 := &Model{Type: ModelTypeLabels}
assert.Equal(t, otherContent, m2.SchemaTemplate())
})
t.Run("FormatOverride", func(t *testing.T) {
m := &Model{Format: "JSON"}
assert.Equal(t, FormatJSON, m.GetFormat())
})
}